From 8f68f1ce824cadf937697b45ac16a3587f6cb909 Mon Sep 17 00:00:00 2001 From: mayuehit Date: Mon, 17 Nov 2025 20:10:34 +0800 Subject: [PATCH 1/3] sync code --- BUILD.bazel | 23 + WORKSPACE | 41 +- api/cpp/BUILD.bazel | 149 +- api/cpp/example/BUILD.bazel | 21 +- api/cpp/example/faas_example.cpp | 57 + api/cpp/example/instance_example.cpp | 11 + api/cpp/example/kv_example.cpp | 1 + api/cpp/example/object_example.cpp | 1 + api/cpp/example/runtime_env_example.cpp | 36 + api/cpp/example/runtime_env_example1.cpp | 87 + api/cpp/example/ssl_example.cpp | 1 + api/cpp/example/stream_example.cpp | 62 + api/cpp/include/faas/Constant.h | 51 + api/cpp/include/faas/Context.h | 78 + api/cpp/include/faas/Function.h | 76 + api/cpp/include/faas/FunctionError.h | 37 + api/cpp/include/faas/FunctionLogger.h | 51 + api/cpp/include/faas/ObjectRef.h | 42 + api/cpp/include/faas/Runtime.h | 52 + api/cpp/include/faas/RuntimeHandler.h | 40 + api/cpp/include/yr/api/affinity.h | 36 + api/cpp/include/yr/api/buffer.h | 23 +- api/cpp/include/yr/api/config.h | 55 +- api/cpp/include/yr/api/constant.h | 9 + api/cpp/include/yr/api/function_handler.h | 10 +- api/cpp/include/yr/api/hetero_manager.h | 8 +- api/cpp/include/yr/api/instance_creator.h | 4 +- api/cpp/include/yr/api/invoke_arg.h | 3 + api/cpp/include/yr/api/invoke_options.h | 54 +- api/cpp/include/yr/api/kv.h | 31 + api/cpp/include/yr/api/kv_manager.h | 20 + api/cpp/include/yr/api/local_mode_runtime.h | 6 +- api/cpp/include/yr/api/local_state_store.h | 2 + api/cpp/include/yr/api/mutable_buffer.h | 33 + api/cpp/include/yr/api/named_instance.h | 34 + api/cpp/include/yr/api/node.h | 28 + api/cpp/include/yr/api/object_ref.h | 6 +- api/cpp/include/yr/api/object_store.h | 17 +- api/cpp/include/yr/api/runtime.h | 41 +- api/cpp/include/yr/api/runtime_env.h | 137 + api/cpp/include/yr/api/serdes.h | 15 +- api/cpp/include/yr/api/stream.h | 209 + api/cpp/include/yr/api/wait_request_manager.h | 2 +- api/cpp/include/yr/api/yr_core.h | 70 + .../yr/parallel/detail/parallel_for_local.h | 12 +- api/cpp/include/yr/yr.h | 202 +- api/cpp/src/cluster_mode_runtime.cpp | 528 +- api/cpp/src/cluster_mode_runtime.h | 41 +- api/cpp/src/config_manager.cpp | 62 +- api/cpp/src/config_manager.h | 24 +- api/cpp/src/datasystem_buffer.cpp | 46 + api/cpp/src/datasystem_buffer.h | 37 + api/cpp/src/executor/executor_holder.cpp | 5 + api/cpp/src/executor/executor_holder.h | 2 + api/cpp/src/faas/context_env.cpp | 109 + api/cpp/src/faas/context_env.h | 55 + api/cpp/src/faas/context_impl.cpp | 186 + api/cpp/src/faas/context_impl.h | 103 + api/cpp/src/faas/context_invoke_params.cpp | 108 + api/cpp/src/faas/context_invoke_params.h | 61 + api/cpp/src/faas/faas_executor.cpp | 327 + api/cpp/src/faas/faas_executor.h | 60 + api/cpp/src/faas/function.cpp | 215 + api/cpp/src/faas/function_error.cpp | 45 + api/cpp/src/faas/function_logger.cpp | 122 + api/cpp/src/faas/object_ref.cpp | 75 + api/cpp/src/faas/register_runtime_handler.cpp | 77 + api/cpp/src/faas/register_runtime_handler.h | 55 + api/cpp/src/faas/runtime.cpp | 88 + api/cpp/src/local_mode_runtime.cpp | 5 + api/cpp/src/local_state_store.cpp | 15 + api/cpp/src/mutable_buffer.cpp | 35 + api/cpp/src/object_store.cpp | 10 +- api/cpp/src/read_only_buffer.h | 4 +- api/cpp/src/runtime_env.cpp | 54 + api/cpp/src/runtime_env_parse.cpp | 211 + api/cpp/src/runtime_env_parse.h | 25 + api/cpp/src/stream_pubsub.cpp | 122 + api/cpp/src/stream_pubsub.h | 144 + api/cpp/src/wait_request_manager.cpp | 5 +- api/cpp/src/yr.cpp | 93 +- api/go/BUILD.bazel | 95 + api/go/README.md | 12 + api/go/example/actor_example.go | 102 + api/go/example/actor_example_exception.go | 149 + api/go/example/entrance_example.go | 41 + api/go/example/kv_example.go | 68 + api/go/example/put_get_wait_example.go | 90 + api/go/example/stream_example.go | 87 + api/go/example/task_example.go | 60 + api/go/example/test_exception.go | 78 + api/go/faassdk/common/alarm/logalarm.go | 242 + api/go/faassdk/common/alarm/logalarm_test.go | 123 + api/go/faassdk/common/aliasroute/alias.go | 358 + .../faassdk/common/aliasroute/alias_test.go | 274 + .../faassdk/common/aliasroute/expression.go | 113 + .../common/aliasroute/expression_test.go | 194 + api/go/faassdk/common/constants/constants.go | 116 + api/go/faassdk/common/constants/error.go | 79 + api/go/faassdk/common/faasscheduler/proxy.go | 115 + .../common/faasscheduler/proxy_test.go | 131 + .../common/functionlog/function_log.go | 437 ++ .../common/functionlog/function_log_test.go | 209 + .../faassdk/common/functionlog/log_fields.go | 91 + .../common/functionlog/log_fields_test.go | 38 + .../faassdk/common/functionlog/log_recoder.go | 497 ++ .../common/functionlog/log_recoder_test.go | 585 ++ .../common/functionlog/log_recorder_option.go | 55 + .../functionlog/log_recorder_option_test.go | 49 + api/go/faassdk/common/functionlog/std_log.go | 85 + .../common/functionlog/std_log_test.go | 97 + api/go/faassdk/common/functionlog/write.go | 99 + api/go/faassdk/common/loadbalance/hash.go | 379 + .../faassdk/common/loadbalance/hash_test.go | 91 + .../faassdk/common/loadbalance/hashcache.go | 43 + .../faassdk/common/loadbalance/loadbalance.go | 66 + .../common/loadbalance/loadbalance_test.go | 229 + .../loadbalance/nolockconsistenthash.go | 126 + .../loadbalance/nolockconsistenthash_test.go | 118 + .../faassdk/common/loadbalance/roundrobin.go | 258 + .../common/loadbalance/roundrobin_test.go | 290 + .../faassdk/common/monitor/monitor_manager.go | 60 + .../common/monitor/monitor_manager_test.go | 28 + .../common/monitor/oom/memory_manager.go | 104 + .../common/monitor/oom/memory_manager_test.go | 84 + api/go/faassdk/common/monitor/oom/parser.go | 98 + .../common/tokentosecret/secret_mgr.go | 72 + .../common/tokentosecret/secret_mgr_test.go | 64 + api/go/faassdk/config/config.go | 112 + api/go/faassdk/config/config_test.go | 33 + api/go/faassdk/entrance.go | 44 + api/go/faassdk/entrance_test.go | 56 + .../faas-sdk/go-api/context/client_context.go | 67 + .../faas-sdk/go-api/function/function.go | 22 + .../faas-sdk/pkg/runtime/context/context.go | 52 + .../pkg/runtime/context/contextenv.go | 142 + .../pkg/runtime/context/contextenv_test.go | 161 + .../pkg/runtime/context/runtime_context.go | 104 + .../runtime/context/runtime_context_test.go | 68 + .../pkg/runtime/userlog/user_logger.go | 26 + api/go/faassdk/faashandler.go | 207 + api/go/faassdk/faashandler_test.go | 302 + api/go/faassdk/handler/event/future.go | 124 + api/go/faassdk/handler/event/future_test.go | 214 + api/go/faassdk/handler/event/handler.go | 340 + api/go/faassdk/handler/event/handler_test.go | 804 ++ api/go/faassdk/handler/handler.go | 48 + api/go/faassdk/handler/handler_test.go | 113 + api/go/faassdk/handler/http/apig.go | 175 + api/go/faassdk/handler/http/apig_test.go | 63 + api/go/faassdk/handler/http/basic_handler.go | 676 ++ .../handler/http/basic_handler_test.go | 435 ++ .../http/crossclusterinvoke/httpclient.go | 43 + .../http/crossclusterinvoke/invoker.go | 370 + .../http/crossclusterinvoke/invoker_test.go | 370 + .../handler/http/custom_container_handler.go | 1150 +++ .../http/custom_container_handler_test.go | 1937 +++++ .../handler/http/custom_container_log.go | 677 ++ .../handler/http/custom_container_log_test.go | 664 ++ api/go/faassdk/handler/http/http_handler.go | 149 + .../faassdk/handler/http/http_handler_test.go | 386 + api/go/faassdk/handler/http/state.go | 232 + api/go/faassdk/handler/http/state_test.go | 290 + api/go/faassdk/handler/mock_utils_test.go | 302 + api/go/faassdk/runtime.go | 86 + api/go/faassdk/runtime_test.go | 43 + api/go/faassdk/sts/sts.go | 65 + api/go/faassdk/sts/sts_test.go | 22 + api/go/faassdk/types/types.go | 321 + api/go/faassdk/types/types_test.go | 94 + api/go/faassdk/utils/handle_response.go | 120 + api/go/faassdk/utils/handle_response_test.go | 71 + api/go/faassdk/utils/signer/sksigner.go | 86 + api/go/faassdk/utils/signer/sksigner_test.go | 90 + api/go/faassdk/utils/urnutils/gadgets.go | 91 + api/go/faassdk/utils/urnutils/gadgets_test.go | 159 + api/go/faassdk/utils/urnutils/urn_utils.go | 124 + .../faassdk/utils/urnutils/urn_utils_test.go | 145 + api/go/faassdk/utils/utils.go | 149 + api/go/faassdk/utils/utils_test.go | 117 + api/go/go.mod | 57 + api/go/libruntime/api/api.go | 100 + api/go/libruntime/api/types.go | 480 ++ api/go/libruntime/api/types_test.go | 158 + api/go/libruntime/clibruntime/clibruntime.go | 2539 +++++++ .../clibruntime/clibruntime_test.go | 1213 +++ api/go/libruntime/common/config.go | 180 + api/go/libruntime/common/config_test.go | 113 + api/go/libruntime/common/config_types.go | 41 + .../libruntime/common/constants/constants.go | 50 + .../common/constants/status_code.go | 33 + .../libruntime/common/faas/logger/logger.go | 64 + .../common/faas/logger/logger_test.go | 39 + .../common/faas/logger/user_logger.go | 109 + .../common/faas/logger/user_logger_test.go | 80 + .../libruntime/common/logger/async/writer.go | 176 + .../common/logger/async/writer_test.go | 94 + .../libruntime/common/logger/config/config.go | 132 + .../common/logger/config/config_test.go | 65 + .../common/logger/custom_encoder.go | 412 + .../common/logger/custom_encoder_test.go | 447 ++ .../common/logger/interface_encoder.go | 349 + .../common/logger/interface_encoder_test.go | 433 ++ .../common/logger/interfacelogger.go | 100 + .../common/logger/interfacelogger_test.go | 97 + api/go/libruntime/common/logger/log/log.go | 243 + .../libruntime/common/logger/log/log_test.go | 241 + api/go/libruntime/common/logger/rollinglog.go | 269 + .../common/logger/rollinglog_test.go | 152 + api/go/libruntime/common/logger/zap/zaplog.go | 162 + .../common/logger/zap/zaplog_test.go | 159 + api/go/libruntime/common/signal.go | 54 + api/go/libruntime/common/token_mgr.go | 100 + api/go/libruntime/common/token_mgr_test.go | 79 + api/go/libruntime/common/types.go | 38 + api/go/libruntime/common/utils/utils.go | 42 + api/go/libruntime/common/utils/utils_test.go | 50 + api/go/libruntime/common/uuid/uuid.go | 93 + api/go/libruntime/common/uuid/uuid_test.go | 50 + api/go/libruntime/config/config.go | 159 + api/go/libruntime/cpplibruntime/BUILD.bazel | 53 + api/go/libruntime/cpplibruntime/clibruntime.h | 493 ++ .../cpplibruntime/cpplibruntime.cpp | 1761 +++++ .../cpplibruntime/mock/mock_cpplibruntime.cpp | 446 ++ api/go/libruntime/execution/execution.go | 54 + api/go/libruntime/libruntime.go | 78 + api/go/libruntime/libruntime_test.go | 148 + .../libruntimesdkimpl/libruntimesdkimpl.go | 252 + .../libruntimesdkimpl_test.go | 337 + api/go/libruntime/pool/pool.go | 41 + api/go/libruntime/pool/pool_test.go | 38 + api/go/posixsdk/posixhandler.go | 345 + api/go/posixsdk/posixhandler_test.go | 170 + api/go/posixsdk/runtime.go | 127 + api/go/posixsdk/runtime_test.go | 55 + api/go/runtime/system_function_bootstrap | 21 + api/go/runtime/yr_runtime_main.go | 76 + api/go/runtime/yr_runtime_main_test.go | 35 + api/go/yr/actorhandler.go | 327 + api/go/yr/actorhandler_test.go | 235 + api/go/yr/cluster_mode_runtime.go | 257 + api/go/yr/cluster_mode_runtime_test.go | 303 + api/go/yr/config.go | 90 + api/go/yr/config_manager.go | 137 + api/go/yr/config_manager_test.go | 90 + api/go/yr/config_test.go | 38 + api/go/yr/function_handler.go | 95 + api/go/yr/function_handler_test.go | 57 + api/go/yr/instance_creator.go | 90 + api/go/yr/instance_creator_test.go | 71 + api/go/yr/instance_function_handler.go | 61 + api/go/yr/instance_function_handler_test.go | 50 + api/go/yr/named_instance.go | 57 + api/go/yr/named_instance_test.go | 66 + api/go/yr/object_ref.go | 49 + api/go/yr/object_ref_test.go | 54 + api/go/yr/runtime.go | 64 + api/go/yr/runtime_holder.go | 52 + api/go/yr/runtime_holder_test.go | 37 + api/go/yr/runtime_test.go | 41 + api/go/yr/stacktrace.go | 170 + api/go/yr/stacktrace_test.go | 128 + api/go/yr/stream.go | 129 + api/go/yr/stream_test.go | 132 + api/go/yr/utils.go | 82 + api/go/yr/utils_test.go | 133 + api/go/yr/yr.go | 266 + api/go/yr/yr_test.go | 354 + api/java/example/GoInstanceExample.java | 72 + api/java/example/InstanceExample.java | 6 +- api/java/example/OptionsExample.java | 2 +- api/java/example/VoidFunctionExample.java | 4 +- api/java/faas-function-sdk/pom.xml | 91 + .../main/java/com/function/CreateOptions.java | 96 + .../src/main/java/com/function/Function.java | 174 + .../src/main/java/com/function/ObjectRef.java | 222 + .../com/function/common/RspErrorCode.java | 61 + .../main/java/com/function/common/Util.java | 179 + .../runtime/exception/InvokeException.java | 106 + .../java/com/function/TestCreateOptions.java | 42 + .../test/java/com/function/TestFunction.java | 164 + .../test/java/com/function/TestObjectRef.java | 139 + .../java/com/function/common/ContextMock.java | 235 + .../java/com/function/common/TestUtil.java | 137 + .../exception/TestInvokeException.java | 31 + .../main/cpp/com_yuanrong_jni_Consumer.cpp | 105 + .../src/main/cpp/com_yuanrong_jni_Consumer.h | 35 + .../main/cpp/com_yuanrong_jni_LibRuntime.cpp | 419 +- .../main/cpp/com_yuanrong_jni_LibRuntime.h | 20 +- .../main/cpp/com_yuanrong_jni_Producer.cpp | 151 + .../src/main/cpp/com_yuanrong_jni_Producer.h | 44 + .../src/main/cpp/jni_errorinfo.cpp | 16 +- .../src/main/cpp/jni_function_meta.cpp | 7 + .../function-common/src/main/cpp/jni_init.cpp | 10 + .../src/main/cpp/jni_stacktrace_element.cpp | 6 +- .../src/main/cpp/jni_stacktrace_info.cpp | 12 +- .../src/main/cpp/jni_types.cpp | 247 +- .../function-common/src/main/cpp/jni_types.h | 87 + .../com/services/enums/FaasErrorCode.java | 73 + .../com/services/exception/FaaSException.java | 85 + .../services/logger/UserFunctionLogger.java | 22 +- .../java/com/services/model/Response.java | 2 + .../java/com/services/runtime/Context.java | 57 + .../services/runtime/action/ContextImpl.java | 40 + .../runtime/action/ContextInvokeParams.java | 6 + .../runtime/action/DelegateDecrypt.java | 8 + .../runtime/action/ExtendedMetaData.java | 3 + .../com/services/runtime/action/PreStop.java | 35 + .../java/com/services/runtime/utils/Util.java | 48 + .../main/java/com/yuanrong/InvokeOptions.java | 17 +- .../java/com/yuanrong/affinity/Affinity.java | 28 + .../com/yuanrong/affinity/AffinityScope.java | 4 +- .../com/yuanrong/affinity/LabelOperator.java | 3 +- .../src/main/java/com/yuanrong/api/Node.java | 52 + .../com/yuanrong/errorcode/ErrorCode.java | 17 + .../handler/traceback/StackTraceUtils.java | 77 +- .../java/com/yuanrong/jni/JniConsumer.java | 67 + .../java/com/yuanrong/jni/JniProducer.java | 91 + .../java/com/yuanrong/jni/LibRuntime.java | 146 +- .../com/yuanrong/jni/LibRuntimeConfig.java | 3 + .../main/java/com/yuanrong/jni/LoadUtil.java | 6 + .../yuanrong/runtime/client/ObjectRef.java | 15 +- .../com/yuanrong/runtime/util/Constants.java | 11 + .../runtime/util/ExtClasspathLoader.java | 3 +- .../runtime/util/FuncClassLoader.java | 3 +- .../java/com/yuanrong/runtime/util/Utils.java | 161 +- .../serialization/strategy/Strategy.java | 3 +- .../java/com/yuanrong/stream/Consumer.java | 69 + .../com/yuanrong/stream/ConsumerImpl.java | 151 + .../java/com/yuanrong/stream/Element.java | 80 + .../java/com/yuanrong/stream/Producer.java | 66 + .../com/yuanrong/stream/ProducerConfig.java | 219 + .../com/yuanrong/stream/ProducerImpl.java | 160 + .../yuanrong/stream/SubscriptionConfig.java | 125 + .../com/yuanrong/stream/SubscriptionType.java | 26 + .../com/services/model/TestFaaSModel.java | 56 + .../runtime/action/TestContextImpl.java | 16 + .../java/com/yuanrong/TestInvokeOptions.java | 26 +- .../com/yuanrong/affinity/TestAffinity.java | 13 + .../affinity/TestInstanceAffinity.java | 116 +- .../affinity/TestResourceAffinity.java | 70 + .../com/yuanrong/affinity/TestSelector.java | 41 + .../yuanrong/affinity/TestSubCondition.java | 48 +- .../java/com/yuanrong/runtime/TestUtils.java | 30 + .../com/yuanrong/stream/TestConsumerImpl.java | 145 + .../java/com/yuanrong/stream/TestElement.java | 40 + .../yuanrong/stream/TestProducerConfig.java | 36 + .../com/yuanrong/stream/TestProducerImpl.java | 161 + .../stream/TestSubscriptionConfig.java | 32 + api/java/yr-api-sdk/resource/sdkpom.xml | 2 +- .../src/main/java/com/yuanrong/Config.java | 202 +- .../main/java/com/yuanrong/ConfigManager.java | 14 +- .../src/main/java/com/yuanrong/YRCall.java | 52 + .../com/yuanrong/api/JobExecutorCaller.java | 288 + .../src/main/java/com/yuanrong/api/YR.java | 173 + .../com/yuanrong/call/CppFunctionHandler.java | 4 +- .../com/yuanrong/call/CppInstanceCreator.java | 5 +- .../com/yuanrong/call/CppInstanceHandler.java | 10 + .../com/yuanrong/call/GoFunctionHandler.java | 89 + .../com/yuanrong/call/GoInstanceCreator.java | 121 + .../call/GoInstanceFunctionHandler.java | 104 + .../com/yuanrong/call/GoInstanceHandler.java | 207 + .../com/yuanrong/call/InstanceCreator.java | 9 +- .../call/InstanceFunctionHandler.java | 4 +- .../com/yuanrong/call/InstanceHandler.java | 16 +- .../yuanrong/call/JavaFunctionHandler.java | 4 +- .../yuanrong/call/JavaInstanceCreator.java | 5 +- .../yuanrong/call/JavaInstanceHandler.java | 10 + .../call/VoidInstanceFunctionHandler.java | 4 +- .../com/yuanrong/function/GoFunction.java | 68 + .../yuanrong/function/GoInstanceClass.java | 48 + .../yuanrong/function/GoInstanceMethod.java | 75 + .../com/yuanrong/jobexecutor/JobExecutor.java | 362 + .../com/yuanrong/jobexecutor/OBSoptions.java | 152 + .../com/yuanrong/jobexecutor/RuntimeEnv.java | 161 + .../com/yuanrong/jobexecutor/YRJobInfo.java | 163 + .../com/yuanrong/jobexecutor/YRJobParam.java | 471 ++ .../com/yuanrong/jobexecutor/YRJobStatus.java | 48 + .../yuanrong/runtime/ClusterModeRuntime.java | 386 +- .../java/com/yuanrong/runtime/Runtime.java | 118 +- .../yuanrong/runtime/client/KVManager.java | 18 +- .../java/com/yuanrong/utils/SdkUtils.java | 6 + .../test/java/com/yuanrong/TestConfig.java | 80 +- .../java/com/yuanrong/TestConfigManager.java | 7 +- .../src/test/java/com/yuanrong/TestGroup.java | 3 +- .../yuanrong/api/TestJobExecutorCaller.java | 214 + .../yuanrong/call/TestCppFunctionHandler.java | 4 +- .../yuanrong/call/TestCppInstanceCreator.java | 4 +- .../call/TestCppInstanceFunctionHandler.java | 4 +- .../yuanrong/call/TestCppInstanceHandler.java | 8 +- .../yuanrong/call/TestFunctionHandler.java | 6 +- .../yuanrong/call/TestGoFunctionHandler.java | 80 + .../yuanrong/call/TestGoInstanceCreator.java | 86 + .../call/TestGoInstanceFunctionHandler.java | 95 + .../yuanrong/call/TestGoInstanceHandler.java | 119 + .../yuanrong/call/TestInstanceHandler.java | 5 +- .../call/TestJavaFunctionHandler.java | 4 +- .../call/TestJavaInstanceCreator.java | 4 +- .../call/TestJavaInstanceFunctionHandler.java | 6 +- .../call/TestJavaInstanceHandler.java | 5 +- .../yuanrong/jobexecutor/TestJobExecutor.java | 65 + .../yuanrong/jobexecutor/TestOBSoptions.java | 66 + .../yuanrong/jobexecutor/TestRuntimeEnv.java | 101 + .../yuanrong/jobexecutor/TestYRJobInfo.java | 60 + .../yuanrong/jobexecutor/TestYRJobParam.java | 254 + .../runtime/TestClusterModeRuntime.java | 326 +- .../java/com/yuanrong/runtime/TestYR.java | 154 +- .../runtime/client/TestKVManager.java | 14 +- .../main/java/com/yuanrong/Entrypoint.java | 7 + .../yuanrong/codemanager/CodeExecutor.java | 7 +- .../com/yuanrong/codemanager/CodeLoader.java | 3 +- .../com/yuanrong/executor/FaaSHandler.java | 745 ++ .../yuanrong/executor/FunctionHandler.java | 4 + .../runtime/server/RuntimeLogger.java | 12 +- .../yr-runtime/src/main/resources/log4j2.xml | 41 +- .../yuanrong/executor/MockFailedClass.java | 45 + .../com/yuanrong/executor/MockNoneClass.java | 19 + .../yuanrong/executor/TestFaaSHandler.java | 652 ++ .../executor/TestFunctionHandler.java | 97 +- .../com/yuanrong/executor/TestReturnType.java | 3 +- .../yuanrong/executor/UserTestHandler.java | 20 + .../runtime/server/TestRuntimeLogger.java | 4 +- api/python/BUILD.bazel | 4 +- api/python/functionsdk.py | 32 + api/python/requirements.txt | 2 +- api/python/requirements_for_py37.txt | 2 +- api/python/yr/__init__.py | 53 +- api/python/yr/affinity.py | 41 +- api/python/yr/apis.py | 125 +- api/python/yr/cluster_mode_runtime.py | 88 +- api/python/yr/code_manager.py | 112 +- api/python/yr/common/constants.py | 3 + api/python/yr/compiled_dag_ref.py | 225 + api/python/yr/config.py | 68 +- api/python/yr/config/python-runtime-log.json | 2 +- api/python/yr/config_manager.py | 15 +- api/python/yr/decorator/function_proxy.py | 4 + api/python/yr/decorator/instance_proxy.py | 23 + api/python/yr/exception.py | 52 + api/python/yr/executor/executor.py | 31 +- api/python/yr/executor/faas_executor.py | 256 + api/python/yr/executor/faas_handler.py | 45 + api/python/yr/fcc.py | 2 +- api/python/yr/fnruntime.pyx | 654 +- api/python/yr/functionsdk/__init__.py | 17 + api/python/yr/functionsdk/context.py | 408 + api/python/yr/functionsdk/error_code.py | 53 + api/python/yr/functionsdk/function.py | 304 + api/python/yr/functionsdk/logger.py | 133 + api/python/yr/functionsdk/logger_manager.py | 179 + api/python/yr/functionsdk/utils.py | 237 + api/python/yr/generator.py | 150 + api/python/yr/includes/affinity.pxd | 4 + api/python/yr/includes/affinity.pxi | 6 + api/python/yr/includes/libruntime.pxd | 65 +- api/python/yr/includes/serialization.pxi | 5 + .../yr/local_mode/local_mode_runtime.py | 44 + api/python/yr/log.py | 16 +- api/python/yr/main/yr_runtime_main.py | 3 +- api/python/yr/object_ref.py | 17 +- api/python/yr/runtime.py | 47 + api/python/yr/runtime_env.py | 175 + api/python/yr/serialization/__init__.py | 3 +- api/python/yr/serialization/serialization.py | 36 +- api/python/yr/stream.py | 110 + api/python/yr/tests/BUILD.bazel | 19 + api/python/yr/tests/test_apis.py | 25 + api/python/yr/tests/test_apis_get.py | 36 + api/python/yr/tests/test_apis_put.py | 106 + .../yr/tests/test_cluster_mode_runtime.py | 53 + api/python/yr/tests/test_code_manager.py | 9 +- api/python/yr/tests/test_executor.py | 26 +- api/python/yr/tests/test_faas_handler.py | 142 + api/python/yr/tests/test_functionsdk.py | 293 + api/python/yr/tests/test_generator.py | 135 + api/python/yr/tests/test_local_mode.py | 15 + api/python/yr/tests/test_runtime_env.py | 160 + api/python/yr/tests/test_serialization.py | 13 + bazel/local_patched_repository.bzl | 6 +- bazel/metrics_sdk.bzl | 1 - bazel/openssl.bazel | 18 +- bazel/preload_opentelemetry.bzl | 27 + bazel/yr_go.bzl | 51 + build.sh | 46 +- .../zh_cn/C++/FunctionHandler-Options.rst | 2 +- .../C++/InstanceFunctionHandler-Options.rst | 2 +- .../zh_cn/C++/struct-InvokeOptions.rst | 2 +- .../development_guide/data_object/KV.md | 2 +- ...e-namespace-and-library-name-with-yr.patch | 6696 +++++++++++++++++ scripts/package_yuanrong.sh | 3 +- src/dto/acquire_options.h | 12 + src/dto/affinity.h | 32 +- src/dto/config.h | 23 +- src/dto/constant.h | 2 + src/dto/data_object.h | 13 + src/dto/debug_config.h | 39 + src/dto/invoke_options.h | 49 +- src/dto/resource_unit.h | 1 + src/dto/stream_conf.h | 75 + .../clientsmanager/clients_manager.cpp | 81 +- .../clientsmanager/clients_manager.h | 14 +- .../driverlog/driverlog_receiver.cpp | 202 + src/libruntime/driverlog/driverlog_receiver.h | 56 + src/libruntime/err_type.h | 71 + src/libruntime/fmclient/fm_client.cpp | 146 +- src/libruntime/fmclient/fm_client.h | 16 +- src/libruntime/fsclient/fs_client.cpp | 14 + src/libruntime/fsclient/fs_client.h | 2 + src/libruntime/fsclient/fs_intf.cpp | 16 +- src/libruntime/fsclient/fs_intf.h | 7 +- src/libruntime/fsclient/fs_intf_impl.cpp | 186 +- src/libruntime/fsclient/fs_intf_impl.h | 14 +- src/libruntime/fsclient/fs_intf_manager.cpp | 2 +- .../fsclient/fs_intf_reader_writer.h | 3 +- .../fs_intf_grpc_client_reader_writer.cpp | 70 +- .../grpc/fs_intf_grpc_client_reader_writer.h | 4 + .../grpc/fs_intf_grpc_reader_writer.h | 1 + .../fs_intf_grpc_server_reader_writer.cpp | 13 +- .../grpc/fs_intf_grpc_server_reader_writer.h | 1 + .../fsclient/grpc/grpc_posix_service.cpp | 69 +- .../fsclient/grpc/grpc_posix_service.h | 1 + .../fsclient/grpc/posix_auth_interceptor.cpp | 183 + .../fsclient/grpc/posix_auth_interceptor.h | 129 + .../fsclient/protobuf/bus_service.proto | 11 + src/libruntime/fsclient/protobuf/common.proto | 9 + .../fsclient/protobuf/core_service.proto | 131 +- src/libruntime/generator/generator_id_map.h | 84 + src/libruntime/generator/generator_notifier.h | 40 + src/libruntime/generator/generator_receiver.h | 33 + .../generator/stream_generator_notifier.cpp | 359 + .../generator/stream_generator_notifier.h | 94 + .../generator/stream_generator_receiver.cpp | 254 + .../generator/stream_generator_receiver.h | 72 + .../groupmanager/function_group.cpp | 11 +- src/libruntime/groupmanager/group.cpp | 6 +- src/libruntime/groupmanager/named_group.cpp | 8 +- src/libruntime/groupmanager/range_group.cpp | 8 +- src/libruntime/gwclient/gw_client.cpp | 1104 +++ src/libruntime/gwclient/gw_client.h | 364 + .../gwclient/gw_datasystem_client_wrapper.h | 67 + .../gwclient/http/async_http_client.cpp | 86 +- .../gwclient/http/async_http_client.h | 6 +- .../gwclient/http/async_https_client.cpp | 68 +- .../gwclient/http/async_https_client.h | 6 +- .../gwclient/http/client_manager.cpp | 124 +- src/libruntime/gwclient/http/client_manager.h | 21 +- src/libruntime/gwclient/http/http_client.h | 129 +- .../heterostore/datasystem_hetero_store.cpp | 9 +- .../heterostore/datasystem_hetero_store.h | 6 +- src/libruntime/heterostore/hetero_future.cpp | 28 +- src/libruntime/heterostore/hetero_future.h | 2 + src/libruntime/heterostore/hetero_store.h | 38 +- src/libruntime/invoke_order_manager.cpp | 46 +- src/libruntime/invoke_order_manager.h | 3 + src/libruntime/invoke_spec.cpp | 118 +- src/libruntime/invoke_spec.h | 73 +- .../invokeadaptor/alias_element.cpp | 88 + src/libruntime/invokeadaptor/alias_element.h | 61 + .../invokeadaptor/alias_routing.cpp | 290 + src/libruntime/invokeadaptor/alias_routing.h | 68 + .../invokeadaptor/faas_instance_manager.cpp | 1042 +++ .../invokeadaptor/faas_instance_manager.h | 83 + .../invokeadaptor/instance_manager.cpp | 294 +- .../invokeadaptor/instance_manager.h | 54 +- .../invokeadaptor/invoke_adaptor.cpp | 508 +- src/libruntime/invokeadaptor/invoke_adaptor.h | 63 +- .../invokeadaptor/limiter_consistant_hash.cpp | 183 + .../invokeadaptor/limiter_consistant_hash.h | 72 + .../invokeadaptor/load_balancer.cpp | 242 + src/libruntime/invokeadaptor/load_balancer.h | 48 + .../invokeadaptor/normal_instance_manager.cpp | 79 +- .../invokeadaptor/normal_instance_manager.h | 2 +- .../invokeadaptor/request_manager.cpp | 6 + .../invokeadaptor/request_manager.h | 1 + .../invokeadaptor/request_queue.cpp | 31 + src/libruntime/invokeadaptor/request_queue.h | 14 + .../invokeadaptor/scheduler_instance_info.cpp | 52 + .../invokeadaptor/scheduler_instance_info.h | 48 + .../invokeadaptor/task_scheduler.cpp | 35 +- src/libruntime/invokeadaptor/task_scheduler.h | 35 +- .../invokeadaptor/task_submitter.cpp | 577 +- src/libruntime/invokeadaptor/task_submitter.h | 45 +- src/libruntime/libruntime.cpp | 386 +- src/libruntime/libruntime.h | 185 +- src/libruntime/libruntime_config.h | 26 +- src/libruntime/libruntime_manager.cpp | 102 +- .../metricsadaptor/metrics_adaptor.cpp | 47 +- .../metricsadaptor/metrics_adaptor.h | 11 +- .../objectstore/datasystem_object_store.cpp | 49 +- .../objectstore/datasystem_object_store.h | 9 +- src/libruntime/objectstore/memory_store.cpp | 198 +- src/libruntime/objectstore/memory_store.h | 6 +- src/libruntime/objectstore/object_store.h | 8 +- .../statestore/datasystem_state_store.cpp | 133 +- .../statestore/datasystem_state_store.h | 26 +- src/libruntime/statestore/state_store.h | 31 +- .../streamstore/datasystem_stream_store.cpp | 233 + .../streamstore/datasystem_stream_store.h | 107 + .../streamstore/stream_producer_consumer.cpp | 132 + .../streamstore/stream_producer_consumer.h | 59 + src/libruntime/streamstore/stream_store.h | 49 + .../exporter/log_file_exporter.cpp | 89 + .../traceadaptor/exporter/log_file_exporter.h | 53 + .../exporter/log_file_exporter_factory.cpp | 29 + .../exporter/log_file_exporter_factory.h | 28 + src/libruntime/traceadaptor/trace_adapter.cpp | 175 + src/libruntime/traceadaptor/trace_adapter.h | 78 + src/libruntime/traceadaptor/trace_struct.h | 40 + src/libruntime/utils/grpc_utils.cpp | 122 + src/libruntime/utils/grpc_utils.h | 40 + src/libruntime/utils/hash_utils.cpp | 56 + src/libruntime/utils/hash_utils.h | 28 + src/libruntime/utils/http_utils.cpp | 101 + src/libruntime/utils/http_utils.h | 48 + src/libruntime/utils/security.cpp | 115 +- src/libruntime/utils/security.h | 44 + src/libruntime/utils/utils.cpp | 42 + src/libruntime/utils/utils.h | 5 + src/libruntime/waiting_object_manager.cpp | 4 +- src/proto/libruntime.proto | 4 +- src/scene/downgrade.cpp | 226 + src/scene/downgrade.h | 107 + src/utility/file_watcher.cpp | 109 + src/utility/file_watcher.h | 56 + src/utility/logger/common.h | 5 +- src/utility/logger/log_handler.cpp | 4 +- src/utility/logger/spd_logger.cpp | 64 +- src/utility/logger/spd_logger.h | 13 +- src/utility/memory.cpp | 14 +- src/utility/memory.h | 14 +- src/utility/timer_worker.cpp | 21 +- src/utility/timer_worker.h | 7 +- test/BUILD.bazel | 232 +- test/api/api_test.cpp | 167 +- test/api/cluster_mode_runtime_test.cpp | 247 +- test/api/config_manager_test.cpp | 55 + test/api/function_manager_test.cpp | 28 +- test/api/local_mode_test.cpp | 57 + test/api/object_ref_test.cpp | 5 + test/api/runtime_env_parse_test.cpp | 227 + test/api/runtime_env_test.cpp | 169 + test/api/stream_pub_sub_test.cpp | 242 + test/clibruntime/clibruntime_test.cpp | 1034 +++ test/common/mock_libruntime.h | 43 +- test/data/cert/ca.crt | 19 + test/data/cert/client.crt | 17 + test/data/cert/client.key | 30 + test/data/cert/server.crt | 18 + test/data/cert/server.key | 30 + test/dto/config_test.cpp | 17 +- test/faas/faas_executor_test.cpp | 435 ++ test/faas/function_test.cpp | 149 + test/libruntime/alias_routing_test.cpp | 212 + test/libruntime/auto_init_test.cpp | 16 + test/libruntime/clients_manager_test.cpp | 8 +- test/libruntime/driverlog_test.cpp | 169 + test/libruntime/execution_manager_test.cpp | 21 + .../libruntime/faas_instance_manager_test.cpp | 547 ++ test/libruntime/fm_client_test.cpp | 135 +- test/libruntime/fs_client_test.cpp | 29 +- test/libruntime/fs_intf_grpc_rw_test.cpp | 30 + test/libruntime/fs_intf_impl_test.cpp | 145 +- test/libruntime/fs_intf_manager_test.cpp | 3 +- test/libruntime/function_group_test.cpp | 6 +- test/libruntime/generator_test.cpp | 620 ++ test/libruntime/grpc_utils_test.cpp | 93 + test/libruntime/gw_client_test.cpp | 1019 +++ test/libruntime/hash_util_test.cpp | 39 + test/libruntime/hetero_future_test.cpp | 23 + test/libruntime/hetero_store_test.cpp | 8 +- test/libruntime/http_utils_test.cpp | 71 + test/libruntime/https_client_test.cpp | 374 + test/libruntime/instance_manager_test.cpp | 6 +- test/libruntime/invoke_adaptor_test.cpp | 370 +- test/libruntime/invoke_order_manager_test.cpp | 18 + test/libruntime/invoke_spec_test.cpp | 56 + test/libruntime/kv_state_store_test.cpp | 11 +- test/libruntime/libruntime_config_test.cpp | 14 + test/libruntime/libruntime_test.cpp | 191 +- .../limiter_consistant_hash_test.cpp | 252 + test/libruntime/load_balancer_test.cpp | 177 + test/libruntime/metrics_adaptor_test.cpp | 166 +- test/libruntime/mock/mock_datasystem.h | 60 +- .../mock/mock_datasystem_client.cpp | 133 +- test/libruntime/mock/mock_fs_intf.h | 3 +- test/libruntime/mock/mock_fs_intf_rw.h | 1 + .../mock/mock_fs_intf_with_callback.h | 30 +- test/libruntime/mock/mock_invoke_adaptor.h | 15 + test/libruntime/mock/mock_task_submitter.h | 8 + .../normal_instance_manager_test.cpp | 3 +- test/libruntime/object_store_test.cpp | 24 + test/libruntime/request_queue_test.cpp | 78 + test/libruntime/resource_group_test.cpp | 2 +- test/libruntime/rt_direct_call_test.cpp | 38 +- .../scheduler_instance_info_test.cpp | 74 + test/libruntime/security_test.cpp | 138 +- test/libruntime/stream_store_test.cpp | 165 + test/libruntime/task_submitter_test.cpp | 246 +- test/libruntime/trace_adapter_test.cpp | 115 + test/libruntime/utils_test.cpp | 36 + test/scene/downgrade_test.cpp | 120 + test/st/cpp/src/base/actor_test.cpp | 20 +- test/st/cpp/src/base/always_local_mode.cpp | 14 +- test/st/cpp/src/base/ds_test.cpp | 16 +- test/st/cpp/src/base/init_test.cpp | 14 +- test/st/cpp/src/base/task_test.cpp | 14 +- test/st/cpp/src/base/utils.h | 14 +- test/st/cpp/src/main.cpp | 14 +- test/st/cpp/src/user_common_func.cpp | 14 +- test/st/cpp/src/user_common_func.h | 14 +- test/st/cpp/src/utils.cpp | 14 +- test/st/cpp/src/utils.h | 14 +- test/st/others/rpc_retry_test/src/main.cpp | 14 +- test/st/python/test_yr_api.py | 3 +- test/test_goruntime_start.sh | 30 + test/utility/file_watcher_test.cpp | 121 + test/utility/logger/logger_test.cpp | 8 +- tools/download_dependency.sh | 3 + tools/openSource.txt | 2 +- yuanrong/build/build.sh | 71 + yuanrong/build/build_function.sh | 124 + yuanrong/build/compile_functions.sh | 36 + .../dashboard/config/dashboard_config.json | 22 + .../build/dashboard/config/dashboard_log.json | 13 + yuanrong/cmd/collector/main.go | 26 + yuanrong/cmd/collector/process/process.go | 48 + yuanrong/cmd/dashboard/main.go | 27 + yuanrong/cmd/dashboard/process/process.go | 85 + yuanrong/cmd/faas/faascontroller/main.go | 163 + yuanrong/cmd/faas/faascontroller/main_test.go | 323 + yuanrong/cmd/faas/faasmanager/main.go | 139 + .../cmd/faas/faasscheduler/function_main.go | 162 + .../faas/faasscheduler/function_main_test.go | 286 + .../cmd/faas/faasscheduler/module_main.go | 210 + yuanrong/go.mod | 56 + yuanrong/pkg/collector/common/connection.go | 66 + .../pkg/collector/common/connection_test.go | 62 + yuanrong/pkg/collector/common/flags.go | 177 + yuanrong/pkg/collector/common/flags_test.go | 141 + yuanrong/pkg/collector/logcollector/common.go | 101 + .../pkg/collector/logcollector/common_test.go | 91 + .../collector/logcollector/log_reporter.go | 225 + .../logcollector/log_reporter_test.go | 300 + .../pkg/collector/logcollector/register.go | 102 + .../collector/logcollector/register_test.go | 70 + .../pkg/collector/logcollector/service.go | 174 + .../collector/logcollector/service_test.go | 152 + .../pkg/common/constants/constant_test.go | 1 + yuanrong/pkg/common/constants/constants.go | 395 + yuanrong/pkg/common/crypto/crypto.go | 267 + yuanrong/pkg/common/crypto/crypto_test.go | 143 + yuanrong/pkg/common/crypto/pem_crypto.go | 180 + yuanrong/pkg/common/crypto/scc_constants.go | 47 + yuanrong/pkg/common/crypto/scc_crypto.go | 115 + yuanrong/pkg/common/crypto/scc_crypto_fake.go | 50 + yuanrong/pkg/common/crypto/scc_crypto_test.go | 83 + yuanrong/pkg/common/crypto/types.go | 300 + yuanrong/pkg/common/crypto/types_test.go | 47 + yuanrong/pkg/common/engine/etcd/etcd.go | 219 + yuanrong/pkg/common/engine/etcd/stream.go | 71 + .../pkg/common/engine/etcd/transaction.go | 309 + .../common/engine/etcd/transaction_test.go | 175 + yuanrong/pkg/common/engine/interface.go | 122 + yuanrong/pkg/common/etcd3/config.go | 169 + yuanrong/pkg/common/etcd3/config_test.go | 162 + yuanrong/pkg/common/etcd3/event.go | 81 + yuanrong/pkg/common/etcd3/event_test.go | 81 + yuanrong/pkg/common/etcd3/scc_config.go | 147 + yuanrong/pkg/common/etcd3/scc_watcher.go | 55 + .../pkg/common/etcd3/scc_watcher_no_scc.go | 59 + yuanrong/pkg/common/etcd3/watcher.go | 215 + yuanrong/pkg/common/etcd3/watcher_test.go | 203 + yuanrong/pkg/common/etcdkey/etcdkey.go | 196 + yuanrong/pkg/common/etcdkey/etcdkey_test.go | 264 + .../pkg/common/faas_common/alarm/config.go | 32 + .../pkg/common/faas_common/alarm/logalarm.go | 242 + .../common/faas_common/alarm/logalarm_test.go | 85 + .../common/faas_common/aliasroute/alias.go | 503 ++ .../faas_common/aliasroute/alias_test.go | 515 ++ .../common/faas_common/aliasroute/event.go | 38 + .../faas_common/aliasroute/expression.go | 101 + .../faas_common/aliasroute/expression_test.go | 172 + .../common/faas_common/autogc/algorithm.go | 60 + .../faas_common/autogc/algorithm_test.go | 60 + .../pkg/common/faas_common/autogc/autogc.go | 108 + .../common/faas_common/autogc/autogc_test.go | 49 + .../pkg/common/faas_common/autogc/util.go | 75 + .../common/faas_common/autogc/util_test.go | 44 + .../pkg/common/faas_common/config/config.go | 25 + .../pkg/common/faas_common/constant/app.go | 72 + .../common/faas_common/constant/constant.go | 524 ++ .../common/faas_common/constant/delegate.go | 96 + .../faas_common/constant/functiongraph.go | 156 + .../common/faas_common/constant/wisecloud.go | 42 + .../faas_common/crypto/cryptoapi_mock.go | 42 + .../common/faas_common/crypto/scc_crypto.go | 166 + .../faas_common/crypto/scc_crypto_test.go | 85 + .../pkg/common/faas_common/etcd3/cache.go | 408 + .../common/faas_common/etcd3/cache_test.go | 362 + .../pkg/common/faas_common/etcd3/client.go | 485 ++ .../common/faas_common/etcd3/client_test.go | 363 + .../pkg/common/faas_common/etcd3/config.go | 278 + .../common/faas_common/etcd3/config_test.go | 469 ++ .../pkg/common/faas_common/etcd3/event.go | 106 + .../common/faas_common/etcd3/event_test.go | 73 + .../faas_common/etcd3/instance_register.go | 147 + .../etcd3/instance_register_test.go | 267 + .../pkg/common/faas_common/etcd3/lease.go | 55 + .../common/faas_common/etcd3/lease_test.go | 69 + yuanrong/pkg/common/faas_common/etcd3/lock.go | 282 + .../pkg/common/faas_common/etcd3/lock_test.go | 341 + yuanrong/pkg/common/faas_common/etcd3/type.go | 109 + .../pkg/common/faas_common/etcd3/utils.go | 170 + .../common/faas_common/etcd3/utils_test.go | 155 + .../pkg/common/faas_common/etcd3/watcher.go | 357 + .../common/faas_common/etcd3/watcher_test.go | 456 ++ .../pkg/common/faas_common/instance/util.go | 58 + .../common/faas_common/instance/util_test.go | 68 + .../common/faas_common/instanceconfig/util.go | 122 + .../faas_common/instanceconfig/util_test.go | 68 + .../pkg/common/faas_common/k8sclient/tools.go | 311 + .../faas_common/k8sclient/tools_test.go | 524 ++ .../kernelrpc/connection/connection.go | 41 + .../kernelrpc/connection/stream_connection.go | 450 ++ .../connection/stream_connection_test.go | 410 + .../rpcclient/basic_stream_client.go | 298 + .../rpcclient/basic_stream_client_test.go | 75 + .../faas_common/kernelrpc/rpcclient/client.go | 167 + .../kernelrpc/rpcclient/client_test.go | 110 + .../faas_common/kernelrpc/rpcserver/server.go | 49 + .../rpcserver/simplified_stream_server.go | 217 + .../simplified_stream_server_test.go | 116 + .../faas_common/kernelrpc/utils/utils.go | 27 + .../faas_common/kernelrpc/utils/utils_test.go | 30 + .../common/faas_common/loadbalance/hash.go | 454 ++ .../faas_common/loadbalance/hash_test.go | 83 + .../faas_common/loadbalance/hashcache.go | 43 + .../faas_common/loadbalance/loadbalance.go | 68 + .../loadbalance/loadbalance_test.go | 238 + .../loadbalance/nolockconsistenthash.go | 126 + .../loadbalance/nolockconsistenthash_test.go | 96 + .../faas_common/loadbalance/roundrobin.go | 129 + .../loadbalance/roundrobin_test.go | 95 + .../common/faas_common/localauth/authcache.go | 195 + .../faas_common/localauth/authcache_test.go | 220 + .../common/faas_common/localauth/authcheck.go | 407 + .../faas_common/localauth/authcheck_test.go | 292 + .../common/faas_common/localauth/crypto.go | 98 + .../faas_common/localauth/crypto_test.go | 52 + .../pkg/common/faas_common/localauth/env.go | 36 + .../common/faas_common/localauth/env_test.go | 73 + .../common/faas_common/logger/async/writer.go | 176 + .../faas_common/logger/async/writer_test.go | 125 + .../faas_common/logger/config/config.go | 121 + .../faas_common/logger/config/config_test.go | 315 + .../faas_common/logger/custom_encoder.go | 391 + .../faas_common/logger/custom_encoder_test.go | 113 + .../faas_common/logger/healthlog/healthlog.go | 46 + .../logger/healthlog/healthlog_test.go | 49 + .../faas_common/logger/interface_encoder.go | 346 + .../logger/interface_encoder_test.go | 122 + .../faas_common/logger/interfacelogger.go | 100 + .../logger/interfacelogger_test.go | 57 + .../common/faas_common/logger/log/logger.go | 263 + .../faas_common/logger/log/logger_test.go | 98 + .../common/faas_common/logger/rollinglog.go | 276 + .../faas_common/logger/rollinglog_test.go | 147 + .../common/faas_common/logger/zap/zaplog.go | 160 + .../faas_common/logger/zap/zaplog_test.go | 183 + .../faas_common/monitor/defaultfilewatcher.go | 176 + .../monitor/defaultfilewatcher_test.go | 90 + .../common/faas_common/monitor/filewatcher.go | 77 + .../faas_common/monitor/filewatcher_test.go | 116 + .../pkg/common/faas_common/monitor/memory.go | 353 + .../common/faas_common/monitor/memory_test.go | 164 + .../faas_common/monitor/mockfilewatcher.go | 42 + .../monitor/mockfilewatcher_test.go | 51 + .../pkg/common/faas_common/monitor/parser.go | 110 + .../common/faas_common/monitor/parser_test.go | 95 + .../pkg/common/faas_common/queue/fifoqueue.go | 152 + .../faas_common/queue/fifoqueue_test.go | 62 + .../common/faas_common/queue/priorityqueue.go | 410 + .../faas_common/queue/priorityqueue_test.go | 274 + .../pkg/common/faas_common/queue/queue.go | 50 + .../faas_common/redisclient/redisclient.go | 510 ++ .../redisclient/redisclient_test.go | 457 ++ .../pkg/common/faas_common/resspeckey/type.go | 120 + .../pkg/common/faas_common/resspeckey/util.go | 103 + .../faas_common/resspeckey/util_test.go | 52 + .../pkg/common/faas_common/signals/signal.go | 55 + .../common/faas_common/signals/signal_test.go | 43 + .../pkg/common/faas_common/snerror/snerror.go | 86 + .../faas_common/snerror/snerror_test.go | 40 + .../pkg/common/faas_common/state/observer.go | 109 + .../common/faas_common/state/observer_test.go | 72 + .../faas_common/statuscode/statuscode.go | 490 ++ .../faas_common/statuscode/statuscode_test.go | 86 + .../pkg/common/faas_common/sts/cert/cert.go | 107 + .../common/faas_common/sts/cert/cert_test.go | 238 + yuanrong/pkg/common/faas_common/sts/common.go | 214 + .../pkg/common/faas_common/sts/common_test.go | 118 + .../pkg/common/faas_common/sts/raw/crypto.go | 84 + .../common/faas_common/sts/raw/crypto_test.go | 64 + .../pkg/common/faas_common/sts/raw/raw.go | 43 + yuanrong/pkg/common/faas_common/sts/sts.go | 88 + .../faas_common/timewheel/simpletimewheel.go | 242 + .../timewheel/simpletimewheel_test.go | 192 + .../common/faas_common/timewheel/timewheel.go | 38 + yuanrong/pkg/common/faas_common/tls/https.go | 401 + .../pkg/common/faas_common/tls/https_test.go | 256 + yuanrong/pkg/common/faas_common/tls/option.go | 113 + .../pkg/common/faas_common/tls/option_test.go | 146 + yuanrong/pkg/common/faas_common/tls/tls.go | 74 + .../pkg/common/faas_common/tls/tls_test.go | 79 + .../faas_common/trafficlimit/trafficlimit.go | 107 + .../trafficlimit/trafficlimit_test.go | 42 + .../pkg/common/faas_common/types/serve.go | 236 + .../common/faas_common/types/serve_test.go | 198 + .../pkg/common/faas_common/types/types.go | 895 +++ .../common/faas_common/urnutils/gadgets.go | 98 + .../faas_common/urnutils/gadgets_test.go | 152 + .../common/faas_common/urnutils/urn_utils.go | 561 ++ .../faas_common/urnutils/urn_utils_test.go | 475 ++ .../common/faas_common/urnutils/urnconv.go | 56 + .../faas_common/urnutils/urnconv_test.go | 48 + .../faas_common/utils/component_util.go | 52 + .../pkg/common/faas_common/utils/file_test.go | 53 + .../faas_common/utils/func_meta_util.go | 193 + .../faas_common/utils/func_meta_util_test.go | 87 + .../pkg/common/faas_common/utils/helper.go | 170 + .../common/faas_common/utils/helper_test.go | 196 + .../faas_common/utils/libruntimeapi_mock.go | 305 + .../utils/libruntimeapi_mock_test.go | 143 + .../common/faas_common/utils/memory_test.go | 22 + .../common/faas_common/utils/mock_utils.go | 138 + .../common/faas_common/utils/resourcepath.go | 59 + .../faas_common/utils/scheduler_option.go | 108 + .../utils/scheduler_option_test.go | 97 + .../pkg/common/faas_common/utils/tools.go | 656 ++ .../common/faas_common/utils/tools_test.go | 649 ++ .../faas_common/wisecloudtool/pod_operator.go | 185 + .../wisecloudtool/pod_operator_test.go | 116 + .../wisecloudtool/prometheus_metrics.go | 285 + .../wisecloudtool/prometheus_metrics_test.go | 313 + .../wisecloudtool/serviceaccount/jwtsign.go | 203 + .../serviceaccount/jwtsign_test.go | 19 + .../wisecloudtool/serviceaccount/parse.go | 79 + .../serviceaccount/parse_test.go | 75 + .../wisecloudtool/serviceaccount/token.go | 80 + .../faas_common/wisecloudtool/types/types.go | 74 + yuanrong/pkg/common/go.mod | 156 + .../pkg/common/httputil/config/adminconfig.go | 32 + .../pkg/common/httputil/http/client/client.go | 116 + .../httputil/http/client/fast/client.go | 188 + .../httputil/http/client/fast/client_test.go | 56 + yuanrong/pkg/common/httputil/http/const.go | 32 + yuanrong/pkg/common/httputil/http/type.go | 35 + yuanrong/pkg/common/httputil/utils/file.go | 52 + .../pkg/common/httputil/utils/file_test.go | 44 + yuanrong/pkg/common/httputil/utils/utils.go | 33 + .../pkg/common/httputil/utils/utils_test.go | 48 + yuanrong/pkg/common/job/config.go | 63 + yuanrong/pkg/common/job/handler.go | 297 + yuanrong/pkg/common/job/handler_test.go | 587 ++ yuanrong/pkg/common/protobuf/adaptor.proto | 65 + yuanrong/pkg/common/protobuf/bus.proto | 132 + .../pkg/common/protobuf/callMessage.proto | 43 + yuanrong/pkg/common/protobuf/deadlock.proto | 55 + yuanrong/pkg/common/protobuf/error.proto | 26 + yuanrong/pkg/common/protobuf/filter.proto | 34 + yuanrong/pkg/common/protobuf/get.proto | 50 + .../protobuf/health/health_service.proto | 34 + yuanrong/pkg/common/protobuf/invoke.proto | 84 + yuanrong/pkg/common/protobuf/readstate.proto | 36 + .../pkg/common/protobuf/rpc/bus_service.proto | 58 + yuanrong/pkg/common/protobuf/rpc/common.proto | 53 + .../common/protobuf/rpc/core_service.proto | 158 + .../common/protobuf/rpc/inner_service.proto | 119 + .../pkg/common/protobuf/rpc/runtime_rpc.proto | 112 + .../common/protobuf/rpc/runtime_service.proto | 134 + yuanrong/pkg/common/protobuf/savestate.proto | 34 + .../scheduler/domainscheduler_service.proto | 78 + .../scheduler/globalscheduler_service.proto | 50 + .../scheduler/localscheduler_service.proto | 43 + .../protobuf/scheduler/scheduler_common.proto | 109 + .../scheduler/worker_agent_service.proto | 178 + yuanrong/pkg/common/protobuf/settimeout.proto | 37 + yuanrong/pkg/common/protobuf/specialize.proto | 41 + yuanrong/pkg/common/protobuf/terminate.proto | 36 + yuanrong/pkg/common/protobuf/wait.proto | 40 + yuanrong/pkg/common/reader/reader.go | 69 + yuanrong/pkg/common/reader/reader_test.go | 63 + yuanrong/pkg/common/tls/https.go | 369 + yuanrong/pkg/common/tls/https_test.go | 312 + yuanrong/pkg/common/tls/option.go | 238 + yuanrong/pkg/common/tls/option_scc.go | 93 + yuanrong/pkg/common/tls/option_scc_fake.go | 86 + yuanrong/pkg/common/tls/option_test.go | 157 + yuanrong/pkg/common/tls/tls.go | 130 + yuanrong/pkg/common/uuid/uuid.go | 153 + yuanrong/pkg/common/uuid/uuid_test.go | 132 + yuanrong/pkg/dashboard/client/index.html | 29 + yuanrong/pkg/dashboard/client/package.json | 37 + yuanrong/pkg/dashboard/client/public/logo.png | Bin 0 -> 1691 bytes yuanrong/pkg/dashboard/client/src/api/api.ts | 50 + .../pkg/dashboard/client/src/api/index.ts | 34 + .../src/components/breadcrumb-component.vue | 42 + .../client/src/components/chart-config.ts | 36 + .../client/src/components/common-card.vue | 52 + .../src/components/log-content-template.vue | 134 + .../src/components/progress-bar-template.ts | 81 + .../client/src/components/warning-notify.ts | 34 + .../pkg/dashboard/client/src/i18n/index.ts | 34 + yuanrong/pkg/dashboard/client/src/index.css | 67 + yuanrong/pkg/dashboard/client/src/main.ts | 82 + .../src/pages/cluster/cluster-chart.vue | 304 + .../src/pages/cluster/cluster-layout.vue | 28 + .../components/empty-log-card.vue | 37 + .../components/instance-info.vue | 185 + .../instance-details-layout.vue | 62 + .../src/pages/instances/instances-chart.vue | 189 + .../pages/job-details/components/job-info.vue | 165 + .../pages/job-details/job-details-layout.vue | 60 + .../client/src/pages/jobs/jobs-chart.vue | 122 + .../pkg/dashboard/client/src/pages/layout.vue | 81 + .../src/pages/log-pages/logs-content.vue | 30 + .../client/src/pages/log-pages/logs-files.vue | 80 + .../client/src/pages/log-pages/logs-nodes.vue | 78 + .../overview/components/cluster-card.vue | 70 + .../overview/components/instances-card.vue | 74 + .../overview/components/resources-card.vue | 145 + .../src/pages/overview/overview-layout.vue | 45 + .../pkg/dashboard/client/src/types/api.d.ts | 150 + .../dashboard/client/src/utils/dayFormat.ts | 22 + .../dashboard/client/src/utils/handleNum.ts | 27 + .../pkg/dashboard/client/src/utils/sort.ts | 25 + .../pkg/dashboard/client/src/utils/swr.ts | 34 + .../components/progress-bar-template.spec.ts | 34 + .../pkg/dashboard/client/tests/main.spec.ts | 24 + .../tests/pages/cluster/cluster-chart.spec.ts | 100 + .../pages/cluster/cluster-layout.spec.ts | 42 + .../components/empty-log-card.spec.ts | 34 + .../components/instance-info.spec.ts | 78 + .../instance-details-layout.spec.ts | 63 + .../pages/instances/instances-chart.spec.ts | 108 + .../job-details/components/job-info.spec.ts | 84 + .../job-details/job-details-layout.spec.ts | 61 + .../tests/pages/jobs/jobs-chart.spec.ts | 48 + .../client/tests/pages/layout.spec.ts | 42 + .../pages/log-pages/logs-content.spec.ts | 35 + .../tests/pages/log-pages/logs-files.spec.ts | 59 + .../tests/pages/log-pages/logs-nodes.spec.ts | 69 + .../overview/components/cluster-card.spec.ts | 78 + .../components/instances-card.spec.ts | 78 + .../components/resources-card.spec.ts | 43 + .../pages/overview/overview-layout.spec.ts | 33 + yuanrong/pkg/dashboard/client/vite.config.js | 78 + .../pkg/dashboard/etcdcache/instance_cache.go | 211 + .../etcdcache/instance_cache_test.go | 49 + yuanrong/pkg/dashboard/flags/flags.go | 229 + yuanrong/pkg/dashboard/flags/flags_test.go | 114 + yuanrong/pkg/dashboard/getinfo/client_pool.go | 95 + .../pkg/dashboard/getinfo/frontend_app.go | 79 + .../pkg/dashboard/getinfo/get_instances.go | 39 + .../pkg/dashboard/getinfo/get_resources.go | 41 + .../handlers/cluster_status_handler.go | 121 + .../handlers/cluster_status_handler_test.go | 165 + .../components_componentid_handler.go | 37 + .../components_componentid_handler_test.go | 44 + .../dashboard/handlers/components_handler.go | 132 + .../handlers/components_handler_test.go | 112 + yuanrong/pkg/dashboard/handlers/err_code.go | 31 + .../dashboard/handlers/instances_handler.go | 146 + .../handlers/instances_handler_test.go | 101 + .../handlers/instances_instanceid_handler.go | 36 + .../instances_instanceid_handler_test.go | 56 + .../handlers/instances_parentid_handler.go | 79 + .../instances_parentid_handler_test.go | 73 + .../handlers/instances_summary_handler.go | 85 + .../instances_summary_handler_test.go | 98 + .../pkg/dashboard/handlers/job_handler.go | 99 + .../dashboard/handlers/job_handler_test.go | 363 + .../dashboard/handlers/prometheus_handler.go | 62 + .../handlers/prometheus_handler_test.go | 51 + .../dashboard/handlers/resources_handler.go | 36 + .../handlers/resources_handler_test.go | 44 + .../handlers/resources_summary_handler.go | 129 + .../resources_summary_handler_test.go | 161 + .../handlers/resources_unitid_handler.go | 37 + .../handlers/resources_unitid_handler_test.go | 44 + .../pkg/dashboard/handlers/serve_handler.go | 328 + .../dashboard/handlers/serve_handler_test.go | 163 + .../dashboard/logmanager/collector_client.go | 104 + .../logmanager/collector_client_test.go | 194 + .../pkg/dashboard/logmanager/http_handlers.go | 188 + .../logmanager/http_handlers_test.go | 254 + yuanrong/pkg/dashboard/logmanager/log_db.go | 135 + .../pkg/dashboard/logmanager/log_db_test.go | 163 + .../pkg/dashboard/logmanager/log_entry.go | 166 + .../dashboard/logmanager/log_entry_test.go | 150 + .../pkg/dashboard/logmanager/log_index.go | 75 + .../dashboard/logmanager/log_index_test.go | 136 + .../pkg/dashboard/logmanager/log_manager.go | 152 + .../dashboard/logmanager/log_manager_test.go | 135 + yuanrong/pkg/dashboard/logmanager/service.go | 56 + .../pkg/dashboard/logmanager/service_test.go | 109 + .../pkg/dashboard/models/common_response.go | 31 + .../pkg/dashboard/models/serve_api_models.go | 58 + yuanrong/pkg/dashboard/routers/cors.go | 40 + yuanrong/pkg/dashboard/routers/router.go | 107 + yuanrong/pkg/dashboard/routers/router_test.go | 45 + yuanrong/pkg/functionmanager/config/config.go | 80 + .../pkg/functionmanager/config/config_test.go | 133 + .../pkg/functionmanager/constant/constant.go | 25 + yuanrong/pkg/functionmanager/faasmanager.go | 720 ++ .../pkg/functionmanager/faasmanager_test.go | 945 +++ yuanrong/pkg/functionmanager/queue.go | 120 + yuanrong/pkg/functionmanager/state/state.go | 133 + .../pkg/functionmanager/state/state_test.go | 146 + yuanrong/pkg/functionmanager/types/types.go | 135 + yuanrong/pkg/functionmanager/utils/utils.go | 123 + .../pkg/functionmanager/utils/utils_test.go | 212 + .../pkg/functionmanager/vpcmanager/plugin.go | 157 + .../functionmanager/vpcmanager/plugin_test.go | 330 + .../functionmanager/vpcmanager/pulltrigger.go | 400 + .../vpcmanager/pulltrigger_test.go | 732 ++ .../pkg/functionmanager/vpcmanager/types.go | 50 + .../vpcmanager/volumeconstant.go | 119 + yuanrong/pkg/functionscaler/config/config.go | 218 + .../pkg/functionscaler/config/config_test.go | 250 + .../functionscaler/config/hotload_config.go | 119 + .../config/hotload_config_test.go | 225 + .../dynamicconfigmanager.go | 113 + .../dynamicconfigmanager_test.go | 40 + yuanrong/pkg/functionscaler/faasscheduler.go | 1151 +++ .../pkg/functionscaler/faasscheduler_test.go | 1794 +++++ .../functionscaler/healthcheck/healthcheck.go | 127 + .../healthcheck/healthcheck_test.go | 158 + .../functionscaler/httpserver/httpserver.go | 177 + .../httpserver/httpserver_test.go | 272 + .../instancepool/componentset.go | 62 + .../instancepool/container_tool.go | 77 + .../instance_operation_adapter.go | 98 + .../instance_operation_adapter_test.go | 203 + .../instancepool/instance_operation_fg.go | 90 + .../instance_operation_fg_test.go | 129 + .../instancepool/instance_operation_kernel.go | 1691 +++++ .../instance_operation_kernel_test.go | 1828 +++++ .../instancepool/instancepool.go | 1509 ++++ .../instancepool/instancepool_test.go | 1351 ++++ .../pkg/functionscaler/instancepool/log.go | 106 + .../functionscaler/instancepool/log_test.go | 66 + .../instancepool/min_instance_alarm.go | 197 + .../instancepool/miscellaneous.go | 310 + .../instancepool/miscellaneous_test.go | 175 + .../instancepool/operatekube.go | 98 + .../instancepool/operatekube_test.go | 75 + .../instancepool/poolmanager.go | 573 ++ .../instancepool/poolmanager_test.go | 1181 +++ .../instancepool/rasp_sidecar.go | 162 + .../functionscaler/instancepool/stateroute.go | 382 + .../instancepool/stateroute_test.go | 334 + .../instancequeue/create_request_queue.go | 127 + .../create_request_queue_test.go | 41 + .../instancequeue/instance_queue.go | 97 + .../instancequeue/instance_queue_builder.go | 246 + .../instance_queue_builder_test.go | 204 + .../instancequeue/instance_queue_test.go | 67 + .../instancequeue/ondemand_instance_queue.go | 200 + .../ondemand_instance_queue_test.go | 183 + .../instancequeue/scaled_instance_queue.go | 509 ++ .../scaled_instance_queue_test.go | 755 ++ .../lease/generic_lease_manager.go | 271 + .../lease/generic_lease_manager_test.go | 128 + yuanrong/pkg/functionscaler/lease/lease.go | 32 + .../functionscaler/registry/agentregistry.go | 195 + .../registry/agentregistry_test.go | 167 + .../functionscaler/registry/aliasregistry.go | 111 + .../registry/faasfrontendregistry.go | 157 + .../registry/faasmanagerregistry.go | 129 + .../registry/faasschedulerregistry.go | 415 + .../registry/faasschedulerregistry_test.go | 252 + .../registry/functionavailableregistry.go | 121 + .../registry/functionregistry.go | 274 + .../registry/instanceconfigregistry.go | 129 + .../registry/instanceregistry.go | 284 + .../registry/instanceregistry_fg.go | 172 + .../registry/instanceregistry_fg_test.go | 130 + .../registry/instanceregistry_test.go | 36 + .../pkg/functionscaler/registry/registry.go | 311 + .../functionscaler/registry/registry_test.go | 1792 +++++ .../registry/rolloutregistry.go | 189 + .../registry/rolloutregistry_test.go | 161 + .../registry/tenantquotaregistry.go | 188 + .../registry/useragencyregistry.go | 126 + .../requestqueue/instance_request_queue.go | 372 + .../instance_request_queue_test.go | 291 + .../functionscaler/rollout/rollouthandler.go | 218 + .../rollout/rollouthandler_test.go | 142 + .../pkg/functionscaler/scaler/autoscaler.go | 394 + .../functionscaler/scaler/autoscaler_test.go | 117 + .../functionscaler/scaler/instance_scaler.go | 54 + .../scaler/instance_scaler_test.go | 563 ++ .../functionscaler/scaler/predictscaler.go | 589 ++ .../functionscaler/scaler/replicascaler.go | 245 + .../scaler/replicascaler_test.go | 81 + .../functionscaler/scaler/wisecloudscaler.go | 227 + .../basic_concurrency_scheduler.go | 1509 ++++ .../basic_concurrency_scheduler_test.go | 1375 ++++ .../grayinstanceallocator.go | 215 + .../reserved_concurrency_scheduler.go | 213 + .../reserved_concurrency_scheduler_test.go | 241 + .../scaled_concurrency_scheduler.go | 393 + .../scaled_concurrency_scheduler_test.go | 274 + .../scheduler/instance_scheduler.go | 89 + .../microservice_scheduler.go | 302 + .../microservice_scheduler_test.go | 307 + .../roundrobin_scheduler.go | 412 + .../roundrobin_scheduler_test.go | 440 ++ .../pkg/functionscaler/selfregister/proxy.go | 177 + .../functionscaler/selfregister/proxy_test.go | 137 + .../selfregister/rolloutregister.go | 265 + .../selfregister/rolloutregister_test.go | 255 + .../selfregister/selfregister.go | 258 + .../selfregister/selfregister_test.go | 253 + .../signalmanager/signalmanager.go | 247 + .../signalmanager/signalmanager_test.go | 317 + yuanrong/pkg/functionscaler/state/state.go | 236 + .../pkg/functionscaler/state/state_test.go | 192 + .../pkg/functionscaler/stateinstance/lease.go | 161 + .../stateinstance/lease_test.go | 122 + .../pkg/functionscaler/sts/sensitiveconfig.go | 114 + .../sts/sensitiveconfig_test.go | 76 + yuanrong/pkg/functionscaler/sts/sts.go | 81 + yuanrong/pkg/functionscaler/sts/sts_test.go | 34 + yuanrong/pkg/functionscaler/sts/types.go | 127 + .../functionscaler/tenantquota/tenantcache.go | 164 + .../tenantquota/tenantcache_test.go | 149 + .../functionscaler/tenantquota/tenantetcd.go | 181 + .../tenantquota/tenantetcd_test.go | 173 + .../pkg/functionscaler/types/constants.go | 137 + yuanrong/pkg/functionscaler/types/types.go | 691 ++ .../pkg/functionscaler/types/types_test.go | 99 + .../functionscaler/utils/configmap_util.go | 81 + .../utils/configmap_util_test.go | 73 + yuanrong/pkg/functionscaler/utils/utils.go | 402 + .../pkg/functionscaler/utils/utils_test.go | 274 + .../pkg/functionscaler/workermanager/lease.go | 113 + .../workermanager/lease_test.go | 118 + .../workermanager/workermanager_client.go | 105 + .../workermanager_client_test.go | 114 + .../workermanager/workermanager_request.go | 329 + .../workermanager_request_test.go | 509 ++ .../config/config.go | 215 + .../config/config_test.go | 308 + .../constant/constant.go | 88 + .../faascontroller/fasscontroller.go | 277 + .../faascontroller/fasscontroller_test.go | 356 + .../faasfrontendmanager.go | 855 +++ .../frontendmanager_test.go | 917 +++ .../faasfunctionmanager.go | 763 ++ .../faasfunctionmanager_test.go | 905 +++ .../faasschedulermanager.go | 882 +++ .../schedulermanager_test.go | 780 ++ .../instancemanager/instancemanager.go | 116 + .../instancemanager/instancemanager_test.go | 337 + .../registry/faasfrontendregistry.go | 160 + .../registry/faasmanagerregistry.go | 160 + .../registry/faasschedulerregistry.go | 160 + .../registry/registry.go | 88 + .../registry/registry_test.go | 711 ++ .../service/frontendservice.go | 80 + .../service/frontendservice_test.go | 28 + .../system_function_controller/state/state.go | 178 + .../state/state_test.go | 130 + .../system_function_controller/types/types.go | 282 + .../system_function_controller/utils/utils.go | 139 + .../utils/utils_test.go | 92 + yuanrong/proto/CMakeLists.txt | 27 + yuanrong/proto/pb/message_pb.h | 177 + yuanrong/proto/pb/posix_pb.h | 102 + yuanrong/proto/posix/affinity.proto | 98 + yuanrong/proto/posix/bus_adapter.proto | 101 + yuanrong/proto/posix/bus_service.proto | 62 + yuanrong/proto/posix/common.proto | 249 + yuanrong/proto/posix/core_service.proto | 200 + yuanrong/proto/posix/inner_service.proto | 118 + yuanrong/proto/posix/log_service.proto | 105 + yuanrong/proto/posix/message.proto | 978 +++ yuanrong/proto/posix/resource.proto | 567 ++ yuanrong/proto/posix/runtime_rpc.proto | 119 + yuanrong/proto/posix/runtime_service.proto | 111 + yuanrong/test/collector/test.sh | 59 + yuanrong/test/common/test.sh | 50 + yuanrong/test/dashboard/test.sh | 53 + yuanrong/test/test.sh | 87 + 1296 files changed, 196690 insertions(+), 2708 deletions(-) create mode 100644 api/cpp/example/faas_example.cpp create mode 100644 api/cpp/example/runtime_env_example.cpp create mode 100644 api/cpp/example/runtime_env_example1.cpp create mode 100644 api/cpp/example/stream_example.cpp create mode 100644 api/cpp/include/faas/Constant.h create mode 100644 api/cpp/include/faas/Context.h create mode 100644 api/cpp/include/faas/Function.h create mode 100644 api/cpp/include/faas/FunctionError.h create mode 100644 api/cpp/include/faas/FunctionLogger.h create mode 100644 api/cpp/include/faas/ObjectRef.h create mode 100644 api/cpp/include/faas/Runtime.h create mode 100644 api/cpp/include/faas/RuntimeHandler.h create mode 100644 api/cpp/include/yr/api/kv.h create mode 100644 api/cpp/include/yr/api/mutable_buffer.h create mode 100644 api/cpp/include/yr/api/node.h create mode 100644 api/cpp/include/yr/api/runtime_env.h create mode 100644 api/cpp/include/yr/api/stream.h create mode 100644 api/cpp/include/yr/api/yr_core.h create mode 100644 api/cpp/src/datasystem_buffer.cpp create mode 100644 api/cpp/src/datasystem_buffer.h create mode 100644 api/cpp/src/faas/context_env.cpp create mode 100644 api/cpp/src/faas/context_env.h create mode 100644 api/cpp/src/faas/context_impl.cpp create mode 100644 api/cpp/src/faas/context_impl.h create mode 100644 api/cpp/src/faas/context_invoke_params.cpp create mode 100644 api/cpp/src/faas/context_invoke_params.h create mode 100644 api/cpp/src/faas/faas_executor.cpp create mode 100644 api/cpp/src/faas/faas_executor.h create mode 100644 api/cpp/src/faas/function.cpp create mode 100644 api/cpp/src/faas/function_error.cpp create mode 100644 api/cpp/src/faas/function_logger.cpp create mode 100644 api/cpp/src/faas/object_ref.cpp create mode 100644 api/cpp/src/faas/register_runtime_handler.cpp create mode 100644 api/cpp/src/faas/register_runtime_handler.h create mode 100644 api/cpp/src/faas/runtime.cpp create mode 100644 api/cpp/src/mutable_buffer.cpp create mode 100644 api/cpp/src/runtime_env.cpp create mode 100644 api/cpp/src/runtime_env_parse.cpp create mode 100644 api/cpp/src/runtime_env_parse.h create mode 100644 api/cpp/src/stream_pubsub.cpp create mode 100644 api/cpp/src/stream_pubsub.h create mode 100644 api/go/BUILD.bazel create mode 100644 api/go/README.md create mode 100644 api/go/example/actor_example.go create mode 100644 api/go/example/actor_example_exception.go create mode 100644 api/go/example/entrance_example.go create mode 100644 api/go/example/kv_example.go create mode 100644 api/go/example/put_get_wait_example.go create mode 100644 api/go/example/stream_example.go create mode 100644 api/go/example/task_example.go create mode 100644 api/go/example/test_exception.go create mode 100644 api/go/faassdk/common/alarm/logalarm.go create mode 100644 api/go/faassdk/common/alarm/logalarm_test.go create mode 100644 api/go/faassdk/common/aliasroute/alias.go create mode 100644 api/go/faassdk/common/aliasroute/alias_test.go create mode 100644 api/go/faassdk/common/aliasroute/expression.go create mode 100644 api/go/faassdk/common/aliasroute/expression_test.go create mode 100644 api/go/faassdk/common/constants/constants.go create mode 100644 api/go/faassdk/common/constants/error.go create mode 100644 api/go/faassdk/common/faasscheduler/proxy.go create mode 100644 api/go/faassdk/common/faasscheduler/proxy_test.go create mode 100644 api/go/faassdk/common/functionlog/function_log.go create mode 100644 api/go/faassdk/common/functionlog/function_log_test.go create mode 100644 api/go/faassdk/common/functionlog/log_fields.go create mode 100644 api/go/faassdk/common/functionlog/log_fields_test.go create mode 100644 api/go/faassdk/common/functionlog/log_recoder.go create mode 100644 api/go/faassdk/common/functionlog/log_recoder_test.go create mode 100644 api/go/faassdk/common/functionlog/log_recorder_option.go create mode 100644 api/go/faassdk/common/functionlog/log_recorder_option_test.go create mode 100644 api/go/faassdk/common/functionlog/std_log.go create mode 100644 api/go/faassdk/common/functionlog/std_log_test.go create mode 100644 api/go/faassdk/common/functionlog/write.go create mode 100644 api/go/faassdk/common/loadbalance/hash.go create mode 100644 api/go/faassdk/common/loadbalance/hash_test.go create mode 100644 api/go/faassdk/common/loadbalance/hashcache.go create mode 100644 api/go/faassdk/common/loadbalance/loadbalance.go create mode 100644 api/go/faassdk/common/loadbalance/loadbalance_test.go create mode 100644 api/go/faassdk/common/loadbalance/nolockconsistenthash.go create mode 100644 api/go/faassdk/common/loadbalance/nolockconsistenthash_test.go create mode 100644 api/go/faassdk/common/loadbalance/roundrobin.go create mode 100644 api/go/faassdk/common/loadbalance/roundrobin_test.go create mode 100644 api/go/faassdk/common/monitor/monitor_manager.go create mode 100644 api/go/faassdk/common/monitor/monitor_manager_test.go create mode 100644 api/go/faassdk/common/monitor/oom/memory_manager.go create mode 100644 api/go/faassdk/common/monitor/oom/memory_manager_test.go create mode 100644 api/go/faassdk/common/monitor/oom/parser.go create mode 100644 api/go/faassdk/common/tokentosecret/secret_mgr.go create mode 100644 api/go/faassdk/common/tokentosecret/secret_mgr_test.go create mode 100644 api/go/faassdk/config/config.go create mode 100644 api/go/faassdk/config/config_test.go create mode 100644 api/go/faassdk/entrance.go create mode 100644 api/go/faassdk/entrance_test.go create mode 100644 api/go/faassdk/faas-sdk/go-api/context/client_context.go create mode 100644 api/go/faassdk/faas-sdk/go-api/function/function.go create mode 100644 api/go/faassdk/faas-sdk/pkg/runtime/context/context.go create mode 100644 api/go/faassdk/faas-sdk/pkg/runtime/context/contextenv.go create mode 100644 api/go/faassdk/faas-sdk/pkg/runtime/context/contextenv_test.go create mode 100644 api/go/faassdk/faas-sdk/pkg/runtime/context/runtime_context.go create mode 100644 api/go/faassdk/faas-sdk/pkg/runtime/context/runtime_context_test.go create mode 100644 api/go/faassdk/faas-sdk/pkg/runtime/userlog/user_logger.go create mode 100644 api/go/faassdk/faashandler.go create mode 100644 api/go/faassdk/faashandler_test.go create mode 100644 api/go/faassdk/handler/event/future.go create mode 100644 api/go/faassdk/handler/event/future_test.go create mode 100644 api/go/faassdk/handler/event/handler.go create mode 100644 api/go/faassdk/handler/event/handler_test.go create mode 100644 api/go/faassdk/handler/handler.go create mode 100644 api/go/faassdk/handler/handler_test.go create mode 100644 api/go/faassdk/handler/http/apig.go create mode 100644 api/go/faassdk/handler/http/apig_test.go create mode 100644 api/go/faassdk/handler/http/basic_handler.go create mode 100644 api/go/faassdk/handler/http/basic_handler_test.go create mode 100644 api/go/faassdk/handler/http/crossclusterinvoke/httpclient.go create mode 100644 api/go/faassdk/handler/http/crossclusterinvoke/invoker.go create mode 100644 api/go/faassdk/handler/http/crossclusterinvoke/invoker_test.go create mode 100644 api/go/faassdk/handler/http/custom_container_handler.go create mode 100644 api/go/faassdk/handler/http/custom_container_handler_test.go create mode 100644 api/go/faassdk/handler/http/custom_container_log.go create mode 100644 api/go/faassdk/handler/http/custom_container_log_test.go create mode 100644 api/go/faassdk/handler/http/http_handler.go create mode 100644 api/go/faassdk/handler/http/http_handler_test.go create mode 100644 api/go/faassdk/handler/http/state.go create mode 100644 api/go/faassdk/handler/http/state_test.go create mode 100644 api/go/faassdk/handler/mock_utils_test.go create mode 100644 api/go/faassdk/runtime.go create mode 100644 api/go/faassdk/runtime_test.go create mode 100644 api/go/faassdk/sts/sts.go create mode 100644 api/go/faassdk/sts/sts_test.go create mode 100644 api/go/faassdk/types/types.go create mode 100644 api/go/faassdk/types/types_test.go create mode 100644 api/go/faassdk/utils/handle_response.go create mode 100644 api/go/faassdk/utils/handle_response_test.go create mode 100644 api/go/faassdk/utils/signer/sksigner.go create mode 100644 api/go/faassdk/utils/signer/sksigner_test.go create mode 100644 api/go/faassdk/utils/urnutils/gadgets.go create mode 100644 api/go/faassdk/utils/urnutils/gadgets_test.go create mode 100644 api/go/faassdk/utils/urnutils/urn_utils.go create mode 100644 api/go/faassdk/utils/urnutils/urn_utils_test.go create mode 100644 api/go/faassdk/utils/utils.go create mode 100644 api/go/faassdk/utils/utils_test.go create mode 100644 api/go/go.mod create mode 100644 api/go/libruntime/api/api.go create mode 100644 api/go/libruntime/api/types.go create mode 100644 api/go/libruntime/api/types_test.go create mode 100644 api/go/libruntime/clibruntime/clibruntime.go create mode 100644 api/go/libruntime/clibruntime/clibruntime_test.go create mode 100644 api/go/libruntime/common/config.go create mode 100644 api/go/libruntime/common/config_test.go create mode 100644 api/go/libruntime/common/config_types.go create mode 100644 api/go/libruntime/common/constants/constants.go create mode 100644 api/go/libruntime/common/constants/status_code.go create mode 100644 api/go/libruntime/common/faas/logger/logger.go create mode 100644 api/go/libruntime/common/faas/logger/logger_test.go create mode 100644 api/go/libruntime/common/faas/logger/user_logger.go create mode 100644 api/go/libruntime/common/faas/logger/user_logger_test.go create mode 100644 api/go/libruntime/common/logger/async/writer.go create mode 100644 api/go/libruntime/common/logger/async/writer_test.go create mode 100644 api/go/libruntime/common/logger/config/config.go create mode 100644 api/go/libruntime/common/logger/config/config_test.go create mode 100644 api/go/libruntime/common/logger/custom_encoder.go create mode 100644 api/go/libruntime/common/logger/custom_encoder_test.go create mode 100644 api/go/libruntime/common/logger/interface_encoder.go create mode 100644 api/go/libruntime/common/logger/interface_encoder_test.go create mode 100644 api/go/libruntime/common/logger/interfacelogger.go create mode 100644 api/go/libruntime/common/logger/interfacelogger_test.go create mode 100644 api/go/libruntime/common/logger/log/log.go create mode 100644 api/go/libruntime/common/logger/log/log_test.go create mode 100644 api/go/libruntime/common/logger/rollinglog.go create mode 100644 api/go/libruntime/common/logger/rollinglog_test.go create mode 100644 api/go/libruntime/common/logger/zap/zaplog.go create mode 100644 api/go/libruntime/common/logger/zap/zaplog_test.go create mode 100644 api/go/libruntime/common/signal.go create mode 100644 api/go/libruntime/common/token_mgr.go create mode 100644 api/go/libruntime/common/token_mgr_test.go create mode 100644 api/go/libruntime/common/types.go create mode 100644 api/go/libruntime/common/utils/utils.go create mode 100644 api/go/libruntime/common/utils/utils_test.go create mode 100644 api/go/libruntime/common/uuid/uuid.go create mode 100644 api/go/libruntime/common/uuid/uuid_test.go create mode 100644 api/go/libruntime/config/config.go create mode 100644 api/go/libruntime/cpplibruntime/BUILD.bazel create mode 100644 api/go/libruntime/cpplibruntime/clibruntime.h create mode 100644 api/go/libruntime/cpplibruntime/cpplibruntime.cpp create mode 100644 api/go/libruntime/cpplibruntime/mock/mock_cpplibruntime.cpp create mode 100644 api/go/libruntime/execution/execution.go create mode 100644 api/go/libruntime/libruntime.go create mode 100644 api/go/libruntime/libruntime_test.go create mode 100644 api/go/libruntime/libruntimesdkimpl/libruntimesdkimpl.go create mode 100644 api/go/libruntime/libruntimesdkimpl/libruntimesdkimpl_test.go create mode 100644 api/go/libruntime/pool/pool.go create mode 100644 api/go/libruntime/pool/pool_test.go create mode 100644 api/go/posixsdk/posixhandler.go create mode 100644 api/go/posixsdk/posixhandler_test.go create mode 100644 api/go/posixsdk/runtime.go create mode 100644 api/go/posixsdk/runtime_test.go create mode 100644 api/go/runtime/system_function_bootstrap create mode 100644 api/go/runtime/yr_runtime_main.go create mode 100644 api/go/runtime/yr_runtime_main_test.go create mode 100644 api/go/yr/actorhandler.go create mode 100644 api/go/yr/actorhandler_test.go create mode 100644 api/go/yr/cluster_mode_runtime.go create mode 100644 api/go/yr/cluster_mode_runtime_test.go create mode 100644 api/go/yr/config.go create mode 100644 api/go/yr/config_manager.go create mode 100644 api/go/yr/config_manager_test.go create mode 100644 api/go/yr/config_test.go create mode 100644 api/go/yr/function_handler.go create mode 100644 api/go/yr/function_handler_test.go create mode 100644 api/go/yr/instance_creator.go create mode 100644 api/go/yr/instance_creator_test.go create mode 100644 api/go/yr/instance_function_handler.go create mode 100644 api/go/yr/instance_function_handler_test.go create mode 100644 api/go/yr/named_instance.go create mode 100644 api/go/yr/named_instance_test.go create mode 100644 api/go/yr/object_ref.go create mode 100644 api/go/yr/object_ref_test.go create mode 100644 api/go/yr/runtime.go create mode 100644 api/go/yr/runtime_holder.go create mode 100644 api/go/yr/runtime_holder_test.go create mode 100644 api/go/yr/runtime_test.go create mode 100644 api/go/yr/stacktrace.go create mode 100644 api/go/yr/stacktrace_test.go create mode 100644 api/go/yr/stream.go create mode 100644 api/go/yr/stream_test.go create mode 100644 api/go/yr/utils.go create mode 100644 api/go/yr/utils_test.go create mode 100644 api/go/yr/yr.go create mode 100644 api/go/yr/yr_test.go create mode 100644 api/java/example/GoInstanceExample.java create mode 100644 api/java/faas-function-sdk/pom.xml create mode 100644 api/java/faas-function-sdk/src/main/java/com/function/CreateOptions.java create mode 100644 api/java/faas-function-sdk/src/main/java/com/function/Function.java create mode 100644 api/java/faas-function-sdk/src/main/java/com/function/ObjectRef.java create mode 100644 api/java/faas-function-sdk/src/main/java/com/function/common/RspErrorCode.java create mode 100644 api/java/faas-function-sdk/src/main/java/com/function/common/Util.java create mode 100644 api/java/faas-function-sdk/src/main/java/com/function/runtime/exception/InvokeException.java create mode 100644 api/java/faas-function-sdk/src/test/java/com/function/TestCreateOptions.java create mode 100644 api/java/faas-function-sdk/src/test/java/com/function/TestFunction.java create mode 100644 api/java/faas-function-sdk/src/test/java/com/function/TestObjectRef.java create mode 100644 api/java/faas-function-sdk/src/test/java/com/function/common/ContextMock.java create mode 100644 api/java/faas-function-sdk/src/test/java/com/function/common/TestUtil.java create mode 100644 api/java/faas-function-sdk/src/test/java/com/function/runtime/exception/TestInvokeException.java create mode 100644 api/java/function-common/src/main/cpp/com_yuanrong_jni_Consumer.cpp create mode 100644 api/java/function-common/src/main/cpp/com_yuanrong_jni_Consumer.h create mode 100644 api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.cpp create mode 100644 api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.h create mode 100644 api/java/function-common/src/main/java/com/services/enums/FaasErrorCode.java create mode 100644 api/java/function-common/src/main/java/com/services/exception/FaaSException.java create mode 100644 api/java/function-common/src/main/java/com/services/runtime/action/PreStop.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/api/Node.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/jni/JniConsumer.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/jni/JniProducer.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/stream/Consumer.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/stream/ConsumerImpl.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/stream/Element.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/stream/Producer.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/stream/ProducerConfig.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/stream/ProducerImpl.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/stream/SubscriptionConfig.java create mode 100644 api/java/function-common/src/main/java/com/yuanrong/stream/SubscriptionType.java create mode 100644 api/java/function-common/src/test/java/com/services/model/TestFaaSModel.java create mode 100644 api/java/function-common/src/test/java/com/yuanrong/affinity/TestSelector.java create mode 100644 api/java/function-common/src/test/java/com/yuanrong/stream/TestConsumerImpl.java create mode 100644 api/java/function-common/src/test/java/com/yuanrong/stream/TestElement.java create mode 100644 api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerConfig.java create mode 100644 api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerImpl.java create mode 100644 api/java/function-common/src/test/java/com/yuanrong/stream/TestSubscriptionConfig.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/api/JobExecutorCaller.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoFunctionHandler.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceCreator.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceFunctionHandler.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceHandler.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoFunction.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoInstanceClass.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoInstanceMethod.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/JobExecutor.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/OBSoptions.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/RuntimeEnv.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobInfo.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobParam.java create mode 100644 api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobStatus.java create mode 100644 api/java/yr-api-sdk/src/test/java/com/yuanrong/api/TestJobExecutorCaller.java create mode 100644 api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoFunctionHandler.java create mode 100644 api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceCreator.java create mode 100644 api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceFunctionHandler.java create mode 100644 api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceHandler.java create mode 100644 api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestJobExecutor.java create mode 100644 api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestOBSoptions.java create mode 100644 api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestRuntimeEnv.java create mode 100644 api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestYRJobInfo.java create mode 100644 api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestYRJobParam.java create mode 100644 api/java/yr-runtime/src/main/java/com/yuanrong/executor/FaaSHandler.java create mode 100644 api/java/yr-runtime/src/test/java/com/yuanrong/executor/MockFailedClass.java create mode 100644 api/java/yr-runtime/src/test/java/com/yuanrong/executor/MockNoneClass.java create mode 100644 api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestFaaSHandler.java create mode 100644 api/python/functionsdk.py create mode 100644 api/python/yr/compiled_dag_ref.py create mode 100644 api/python/yr/executor/faas_executor.py create mode 100644 api/python/yr/executor/faas_handler.py create mode 100644 api/python/yr/functionsdk/__init__.py create mode 100644 api/python/yr/functionsdk/context.py create mode 100644 api/python/yr/functionsdk/error_code.py create mode 100644 api/python/yr/functionsdk/function.py create mode 100644 api/python/yr/functionsdk/logger.py create mode 100644 api/python/yr/functionsdk/logger_manager.py create mode 100644 api/python/yr/functionsdk/utils.py create mode 100644 api/python/yr/generator.py create mode 100644 api/python/yr/runtime_env.py create mode 100644 api/python/yr/stream.py create mode 100644 api/python/yr/tests/test_apis_get.py create mode 100644 api/python/yr/tests/test_apis_put.py create mode 100644 api/python/yr/tests/test_faas_handler.py create mode 100644 api/python/yr/tests/test_functionsdk.py create mode 100644 api/python/yr/tests/test_generator.py create mode 100644 api/python/yr/tests/test_runtime_env.py create mode 100644 bazel/preload_opentelemetry.bzl create mode 100644 bazel/yr_go.bzl create mode 100644 patch/spdlog-change-namespace-and-library-name-with-yr.patch create mode 100644 src/dto/debug_config.h create mode 100644 src/dto/stream_conf.h create mode 100644 src/libruntime/driverlog/driverlog_receiver.cpp create mode 100644 src/libruntime/driverlog/driverlog_receiver.h create mode 100644 src/libruntime/fsclient/grpc/posix_auth_interceptor.h create mode 100644 src/libruntime/generator/generator_id_map.h create mode 100644 src/libruntime/generator/generator_notifier.h create mode 100644 src/libruntime/generator/generator_receiver.h create mode 100644 src/libruntime/generator/stream_generator_notifier.cpp create mode 100644 src/libruntime/generator/stream_generator_notifier.h create mode 100644 src/libruntime/generator/stream_generator_receiver.cpp create mode 100644 src/libruntime/generator/stream_generator_receiver.h create mode 100644 src/libruntime/gwclient/gw_client.cpp create mode 100644 src/libruntime/gwclient/gw_client.h create mode 100644 src/libruntime/gwclient/gw_datasystem_client_wrapper.h create mode 100644 src/libruntime/invokeadaptor/alias_element.cpp create mode 100644 src/libruntime/invokeadaptor/alias_element.h create mode 100644 src/libruntime/invokeadaptor/alias_routing.cpp create mode 100644 src/libruntime/invokeadaptor/alias_routing.h create mode 100644 src/libruntime/invokeadaptor/faas_instance_manager.cpp create mode 100644 src/libruntime/invokeadaptor/faas_instance_manager.h create mode 100644 src/libruntime/invokeadaptor/limiter_consistant_hash.cpp create mode 100644 src/libruntime/invokeadaptor/limiter_consistant_hash.h create mode 100644 src/libruntime/invokeadaptor/load_balancer.cpp create mode 100644 src/libruntime/invokeadaptor/load_balancer.h create mode 100644 src/libruntime/invokeadaptor/scheduler_instance_info.cpp create mode 100644 src/libruntime/invokeadaptor/scheduler_instance_info.h create mode 100644 src/libruntime/streamstore/datasystem_stream_store.cpp create mode 100644 src/libruntime/streamstore/datasystem_stream_store.h create mode 100644 src/libruntime/streamstore/stream_producer_consumer.cpp create mode 100644 src/libruntime/streamstore/stream_producer_consumer.h create mode 100644 src/libruntime/streamstore/stream_store.h create mode 100644 src/libruntime/traceadaptor/exporter/log_file_exporter.cpp create mode 100644 src/libruntime/traceadaptor/exporter/log_file_exporter.h create mode 100644 src/libruntime/traceadaptor/exporter/log_file_exporter_factory.cpp create mode 100644 src/libruntime/traceadaptor/exporter/log_file_exporter_factory.h create mode 100644 src/libruntime/traceadaptor/trace_adapter.cpp create mode 100644 src/libruntime/traceadaptor/trace_adapter.h create mode 100644 src/libruntime/traceadaptor/trace_struct.h create mode 100644 src/libruntime/utils/grpc_utils.cpp create mode 100644 src/libruntime/utils/grpc_utils.h create mode 100644 src/libruntime/utils/hash_utils.cpp create mode 100644 src/libruntime/utils/hash_utils.h create mode 100644 src/libruntime/utils/http_utils.cpp create mode 100644 src/libruntime/utils/http_utils.h create mode 100644 src/scene/downgrade.cpp create mode 100644 src/scene/downgrade.h create mode 100644 src/utility/file_watcher.cpp create mode 100644 src/utility/file_watcher.h create mode 100644 test/api/runtime_env_parse_test.cpp create mode 100644 test/api/runtime_env_test.cpp create mode 100644 test/api/stream_pub_sub_test.cpp create mode 100644 test/clibruntime/clibruntime_test.cpp create mode 100644 test/data/cert/ca.crt create mode 100644 test/data/cert/client.crt create mode 100644 test/data/cert/client.key create mode 100644 test/data/cert/server.crt create mode 100644 test/data/cert/server.key create mode 100644 test/faas/faas_executor_test.cpp create mode 100644 test/faas/function_test.cpp create mode 100644 test/libruntime/alias_routing_test.cpp create mode 100644 test/libruntime/driverlog_test.cpp create mode 100644 test/libruntime/faas_instance_manager_test.cpp create mode 100644 test/libruntime/generator_test.cpp create mode 100644 test/libruntime/grpc_utils_test.cpp create mode 100644 test/libruntime/gw_client_test.cpp create mode 100644 test/libruntime/hash_util_test.cpp create mode 100644 test/libruntime/http_utils_test.cpp create mode 100644 test/libruntime/https_client_test.cpp create mode 100644 test/libruntime/limiter_consistant_hash_test.cpp create mode 100644 test/libruntime/load_balancer_test.cpp create mode 100644 test/libruntime/request_queue_test.cpp create mode 100644 test/libruntime/scheduler_instance_info_test.cpp create mode 100644 test/libruntime/stream_store_test.cpp create mode 100644 test/libruntime/trace_adapter_test.cpp create mode 100644 test/scene/downgrade_test.cpp create mode 100644 test/test_goruntime_start.sh create mode 100644 test/utility/file_watcher_test.cpp create mode 100644 yuanrong/build/build.sh create mode 100644 yuanrong/build/build_function.sh create mode 100644 yuanrong/build/compile_functions.sh create mode 100644 yuanrong/build/dashboard/config/dashboard_config.json create mode 100644 yuanrong/build/dashboard/config/dashboard_log.json create mode 100644 yuanrong/cmd/collector/main.go create mode 100644 yuanrong/cmd/collector/process/process.go create mode 100644 yuanrong/cmd/dashboard/main.go create mode 100644 yuanrong/cmd/dashboard/process/process.go create mode 100644 yuanrong/cmd/faas/faascontroller/main.go create mode 100644 yuanrong/cmd/faas/faascontroller/main_test.go create mode 100644 yuanrong/cmd/faas/faasmanager/main.go create mode 100644 yuanrong/cmd/faas/faasscheduler/function_main.go create mode 100644 yuanrong/cmd/faas/faasscheduler/function_main_test.go create mode 100644 yuanrong/cmd/faas/faasscheduler/module_main.go create mode 100644 yuanrong/go.mod create mode 100644 yuanrong/pkg/collector/common/connection.go create mode 100644 yuanrong/pkg/collector/common/connection_test.go create mode 100644 yuanrong/pkg/collector/common/flags.go create mode 100644 yuanrong/pkg/collector/common/flags_test.go create mode 100644 yuanrong/pkg/collector/logcollector/common.go create mode 100644 yuanrong/pkg/collector/logcollector/common_test.go create mode 100644 yuanrong/pkg/collector/logcollector/log_reporter.go create mode 100644 yuanrong/pkg/collector/logcollector/log_reporter_test.go create mode 100644 yuanrong/pkg/collector/logcollector/register.go create mode 100644 yuanrong/pkg/collector/logcollector/register_test.go create mode 100644 yuanrong/pkg/collector/logcollector/service.go create mode 100644 yuanrong/pkg/collector/logcollector/service_test.go create mode 100644 yuanrong/pkg/common/constants/constant_test.go create mode 100644 yuanrong/pkg/common/constants/constants.go create mode 100644 yuanrong/pkg/common/crypto/crypto.go create mode 100644 yuanrong/pkg/common/crypto/crypto_test.go create mode 100644 yuanrong/pkg/common/crypto/pem_crypto.go create mode 100644 yuanrong/pkg/common/crypto/scc_constants.go create mode 100644 yuanrong/pkg/common/crypto/scc_crypto.go create mode 100644 yuanrong/pkg/common/crypto/scc_crypto_fake.go create mode 100644 yuanrong/pkg/common/crypto/scc_crypto_test.go create mode 100644 yuanrong/pkg/common/crypto/types.go create mode 100644 yuanrong/pkg/common/crypto/types_test.go create mode 100644 yuanrong/pkg/common/engine/etcd/etcd.go create mode 100644 yuanrong/pkg/common/engine/etcd/stream.go create mode 100644 yuanrong/pkg/common/engine/etcd/transaction.go create mode 100644 yuanrong/pkg/common/engine/etcd/transaction_test.go create mode 100644 yuanrong/pkg/common/engine/interface.go create mode 100644 yuanrong/pkg/common/etcd3/config.go create mode 100644 yuanrong/pkg/common/etcd3/config_test.go create mode 100644 yuanrong/pkg/common/etcd3/event.go create mode 100644 yuanrong/pkg/common/etcd3/event_test.go create mode 100644 yuanrong/pkg/common/etcd3/scc_config.go create mode 100644 yuanrong/pkg/common/etcd3/scc_watcher.go create mode 100644 yuanrong/pkg/common/etcd3/scc_watcher_no_scc.go create mode 100644 yuanrong/pkg/common/etcd3/watcher.go create mode 100644 yuanrong/pkg/common/etcd3/watcher_test.go create mode 100644 yuanrong/pkg/common/etcdkey/etcdkey.go create mode 100644 yuanrong/pkg/common/etcdkey/etcdkey_test.go create mode 100644 yuanrong/pkg/common/faas_common/alarm/config.go create mode 100644 yuanrong/pkg/common/faas_common/alarm/logalarm.go create mode 100644 yuanrong/pkg/common/faas_common/alarm/logalarm_test.go create mode 100644 yuanrong/pkg/common/faas_common/aliasroute/alias.go create mode 100644 yuanrong/pkg/common/faas_common/aliasroute/alias_test.go create mode 100644 yuanrong/pkg/common/faas_common/aliasroute/event.go create mode 100644 yuanrong/pkg/common/faas_common/aliasroute/expression.go create mode 100644 yuanrong/pkg/common/faas_common/aliasroute/expression_test.go create mode 100644 yuanrong/pkg/common/faas_common/autogc/algorithm.go create mode 100644 yuanrong/pkg/common/faas_common/autogc/algorithm_test.go create mode 100644 yuanrong/pkg/common/faas_common/autogc/autogc.go create mode 100644 yuanrong/pkg/common/faas_common/autogc/autogc_test.go create mode 100644 yuanrong/pkg/common/faas_common/autogc/util.go create mode 100644 yuanrong/pkg/common/faas_common/autogc/util_test.go create mode 100644 yuanrong/pkg/common/faas_common/config/config.go create mode 100644 yuanrong/pkg/common/faas_common/constant/app.go create mode 100644 yuanrong/pkg/common/faas_common/constant/constant.go create mode 100644 yuanrong/pkg/common/faas_common/constant/delegate.go create mode 100644 yuanrong/pkg/common/faas_common/constant/functiongraph.go create mode 100644 yuanrong/pkg/common/faas_common/constant/wisecloud.go create mode 100644 yuanrong/pkg/common/faas_common/crypto/cryptoapi_mock.go create mode 100644 yuanrong/pkg/common/faas_common/crypto/scc_crypto.go create mode 100644 yuanrong/pkg/common/faas_common/crypto/scc_crypto_test.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/cache.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/cache_test.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/client.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/client_test.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/config.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/config_test.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/event.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/event_test.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/instance_register.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/instance_register_test.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/lease.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/lease_test.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/lock.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/lock_test.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/type.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/utils.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/utils_test.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/watcher.go create mode 100644 yuanrong/pkg/common/faas_common/etcd3/watcher_test.go create mode 100644 yuanrong/pkg/common/faas_common/instance/util.go create mode 100644 yuanrong/pkg/common/faas_common/instance/util_test.go create mode 100644 yuanrong/pkg/common/faas_common/instanceconfig/util.go create mode 100644 yuanrong/pkg/common/faas_common/instanceconfig/util_test.go create mode 100644 yuanrong/pkg/common/faas_common/k8sclient/tools.go create mode 100644 yuanrong/pkg/common/faas_common/k8sclient/tools_test.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/connection/connection.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/connection/stream_connection.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/connection/stream_connection_test.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/basic_stream_client.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/basic_stream_client_test.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/client.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/client_test.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/server.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/simplified_stream_server.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/simplified_stream_server_test.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/utils/utils.go create mode 100644 yuanrong/pkg/common/faas_common/kernelrpc/utils/utils_test.go create mode 100644 yuanrong/pkg/common/faas_common/loadbalance/hash.go create mode 100644 yuanrong/pkg/common/faas_common/loadbalance/hash_test.go create mode 100644 yuanrong/pkg/common/faas_common/loadbalance/hashcache.go create mode 100644 yuanrong/pkg/common/faas_common/loadbalance/loadbalance.go create mode 100644 yuanrong/pkg/common/faas_common/loadbalance/loadbalance_test.go create mode 100644 yuanrong/pkg/common/faas_common/loadbalance/nolockconsistenthash.go create mode 100644 yuanrong/pkg/common/faas_common/loadbalance/nolockconsistenthash_test.go create mode 100644 yuanrong/pkg/common/faas_common/loadbalance/roundrobin.go create mode 100644 yuanrong/pkg/common/faas_common/loadbalance/roundrobin_test.go create mode 100644 yuanrong/pkg/common/faas_common/localauth/authcache.go create mode 100644 yuanrong/pkg/common/faas_common/localauth/authcache_test.go create mode 100644 yuanrong/pkg/common/faas_common/localauth/authcheck.go create mode 100644 yuanrong/pkg/common/faas_common/localauth/authcheck_test.go create mode 100644 yuanrong/pkg/common/faas_common/localauth/crypto.go create mode 100644 yuanrong/pkg/common/faas_common/localauth/crypto_test.go create mode 100644 yuanrong/pkg/common/faas_common/localauth/env.go create mode 100644 yuanrong/pkg/common/faas_common/localauth/env_test.go create mode 100644 yuanrong/pkg/common/faas_common/logger/async/writer.go create mode 100644 yuanrong/pkg/common/faas_common/logger/async/writer_test.go create mode 100644 yuanrong/pkg/common/faas_common/logger/config/config.go create mode 100644 yuanrong/pkg/common/faas_common/logger/config/config_test.go create mode 100644 yuanrong/pkg/common/faas_common/logger/custom_encoder.go create mode 100644 yuanrong/pkg/common/faas_common/logger/custom_encoder_test.go create mode 100644 yuanrong/pkg/common/faas_common/logger/healthlog/healthlog.go create mode 100644 yuanrong/pkg/common/faas_common/logger/healthlog/healthlog_test.go create mode 100644 yuanrong/pkg/common/faas_common/logger/interface_encoder.go create mode 100644 yuanrong/pkg/common/faas_common/logger/interface_encoder_test.go create mode 100644 yuanrong/pkg/common/faas_common/logger/interfacelogger.go create mode 100644 yuanrong/pkg/common/faas_common/logger/interfacelogger_test.go create mode 100644 yuanrong/pkg/common/faas_common/logger/log/logger.go create mode 100644 yuanrong/pkg/common/faas_common/logger/log/logger_test.go create mode 100644 yuanrong/pkg/common/faas_common/logger/rollinglog.go create mode 100644 yuanrong/pkg/common/faas_common/logger/rollinglog_test.go create mode 100644 yuanrong/pkg/common/faas_common/logger/zap/zaplog.go create mode 100644 yuanrong/pkg/common/faas_common/logger/zap/zaplog_test.go create mode 100644 yuanrong/pkg/common/faas_common/monitor/defaultfilewatcher.go create mode 100644 yuanrong/pkg/common/faas_common/monitor/defaultfilewatcher_test.go create mode 100644 yuanrong/pkg/common/faas_common/monitor/filewatcher.go create mode 100644 yuanrong/pkg/common/faas_common/monitor/filewatcher_test.go create mode 100644 yuanrong/pkg/common/faas_common/monitor/memory.go create mode 100644 yuanrong/pkg/common/faas_common/monitor/memory_test.go create mode 100644 yuanrong/pkg/common/faas_common/monitor/mockfilewatcher.go create mode 100644 yuanrong/pkg/common/faas_common/monitor/mockfilewatcher_test.go create mode 100644 yuanrong/pkg/common/faas_common/monitor/parser.go create mode 100644 yuanrong/pkg/common/faas_common/monitor/parser_test.go create mode 100644 yuanrong/pkg/common/faas_common/queue/fifoqueue.go create mode 100644 yuanrong/pkg/common/faas_common/queue/fifoqueue_test.go create mode 100644 yuanrong/pkg/common/faas_common/queue/priorityqueue.go create mode 100644 yuanrong/pkg/common/faas_common/queue/priorityqueue_test.go create mode 100644 yuanrong/pkg/common/faas_common/queue/queue.go create mode 100644 yuanrong/pkg/common/faas_common/redisclient/redisclient.go create mode 100644 yuanrong/pkg/common/faas_common/redisclient/redisclient_test.go create mode 100644 yuanrong/pkg/common/faas_common/resspeckey/type.go create mode 100644 yuanrong/pkg/common/faas_common/resspeckey/util.go create mode 100644 yuanrong/pkg/common/faas_common/resspeckey/util_test.go create mode 100644 yuanrong/pkg/common/faas_common/signals/signal.go create mode 100644 yuanrong/pkg/common/faas_common/signals/signal_test.go create mode 100644 yuanrong/pkg/common/faas_common/snerror/snerror.go create mode 100644 yuanrong/pkg/common/faas_common/snerror/snerror_test.go create mode 100644 yuanrong/pkg/common/faas_common/state/observer.go create mode 100644 yuanrong/pkg/common/faas_common/state/observer_test.go create mode 100644 yuanrong/pkg/common/faas_common/statuscode/statuscode.go create mode 100644 yuanrong/pkg/common/faas_common/statuscode/statuscode_test.go create mode 100644 yuanrong/pkg/common/faas_common/sts/cert/cert.go create mode 100644 yuanrong/pkg/common/faas_common/sts/cert/cert_test.go create mode 100644 yuanrong/pkg/common/faas_common/sts/common.go create mode 100644 yuanrong/pkg/common/faas_common/sts/common_test.go create mode 100644 yuanrong/pkg/common/faas_common/sts/raw/crypto.go create mode 100644 yuanrong/pkg/common/faas_common/sts/raw/crypto_test.go create mode 100644 yuanrong/pkg/common/faas_common/sts/raw/raw.go create mode 100644 yuanrong/pkg/common/faas_common/sts/sts.go create mode 100644 yuanrong/pkg/common/faas_common/timewheel/simpletimewheel.go create mode 100644 yuanrong/pkg/common/faas_common/timewheel/simpletimewheel_test.go create mode 100644 yuanrong/pkg/common/faas_common/timewheel/timewheel.go create mode 100644 yuanrong/pkg/common/faas_common/tls/https.go create mode 100644 yuanrong/pkg/common/faas_common/tls/https_test.go create mode 100644 yuanrong/pkg/common/faas_common/tls/option.go create mode 100644 yuanrong/pkg/common/faas_common/tls/option_test.go create mode 100644 yuanrong/pkg/common/faas_common/tls/tls.go create mode 100644 yuanrong/pkg/common/faas_common/tls/tls_test.go create mode 100644 yuanrong/pkg/common/faas_common/trafficlimit/trafficlimit.go create mode 100644 yuanrong/pkg/common/faas_common/trafficlimit/trafficlimit_test.go create mode 100644 yuanrong/pkg/common/faas_common/types/serve.go create mode 100644 yuanrong/pkg/common/faas_common/types/serve_test.go create mode 100644 yuanrong/pkg/common/faas_common/types/types.go create mode 100644 yuanrong/pkg/common/faas_common/urnutils/gadgets.go create mode 100644 yuanrong/pkg/common/faas_common/urnutils/gadgets_test.go create mode 100644 yuanrong/pkg/common/faas_common/urnutils/urn_utils.go create mode 100644 yuanrong/pkg/common/faas_common/urnutils/urn_utils_test.go create mode 100644 yuanrong/pkg/common/faas_common/urnutils/urnconv.go create mode 100644 yuanrong/pkg/common/faas_common/urnutils/urnconv_test.go create mode 100644 yuanrong/pkg/common/faas_common/utils/component_util.go create mode 100644 yuanrong/pkg/common/faas_common/utils/file_test.go create mode 100644 yuanrong/pkg/common/faas_common/utils/func_meta_util.go create mode 100644 yuanrong/pkg/common/faas_common/utils/func_meta_util_test.go create mode 100644 yuanrong/pkg/common/faas_common/utils/helper.go create mode 100644 yuanrong/pkg/common/faas_common/utils/helper_test.go create mode 100644 yuanrong/pkg/common/faas_common/utils/libruntimeapi_mock.go create mode 100644 yuanrong/pkg/common/faas_common/utils/libruntimeapi_mock_test.go create mode 100644 yuanrong/pkg/common/faas_common/utils/memory_test.go create mode 100644 yuanrong/pkg/common/faas_common/utils/mock_utils.go create mode 100644 yuanrong/pkg/common/faas_common/utils/resourcepath.go create mode 100644 yuanrong/pkg/common/faas_common/utils/scheduler_option.go create mode 100644 yuanrong/pkg/common/faas_common/utils/scheduler_option_test.go create mode 100644 yuanrong/pkg/common/faas_common/utils/tools.go create mode 100644 yuanrong/pkg/common/faas_common/utils/tools_test.go create mode 100644 yuanrong/pkg/common/faas_common/wisecloudtool/pod_operator.go create mode 100644 yuanrong/pkg/common/faas_common/wisecloudtool/pod_operator_test.go create mode 100644 yuanrong/pkg/common/faas_common/wisecloudtool/prometheus_metrics.go create mode 100644 yuanrong/pkg/common/faas_common/wisecloudtool/prometheus_metrics_test.go create mode 100644 yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/jwtsign.go create mode 100644 yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/jwtsign_test.go create mode 100644 yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/parse.go create mode 100644 yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/parse_test.go create mode 100644 yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/token.go create mode 100644 yuanrong/pkg/common/faas_common/wisecloudtool/types/types.go create mode 100644 yuanrong/pkg/common/go.mod create mode 100644 yuanrong/pkg/common/httputil/config/adminconfig.go create mode 100644 yuanrong/pkg/common/httputil/http/client/client.go create mode 100644 yuanrong/pkg/common/httputil/http/client/fast/client.go create mode 100644 yuanrong/pkg/common/httputil/http/client/fast/client_test.go create mode 100644 yuanrong/pkg/common/httputil/http/const.go create mode 100644 yuanrong/pkg/common/httputil/http/type.go create mode 100644 yuanrong/pkg/common/httputil/utils/file.go create mode 100644 yuanrong/pkg/common/httputil/utils/file_test.go create mode 100644 yuanrong/pkg/common/httputil/utils/utils.go create mode 100644 yuanrong/pkg/common/httputil/utils/utils_test.go create mode 100644 yuanrong/pkg/common/job/config.go create mode 100644 yuanrong/pkg/common/job/handler.go create mode 100644 yuanrong/pkg/common/job/handler_test.go create mode 100644 yuanrong/pkg/common/protobuf/adaptor.proto create mode 100644 yuanrong/pkg/common/protobuf/bus.proto create mode 100644 yuanrong/pkg/common/protobuf/callMessage.proto create mode 100644 yuanrong/pkg/common/protobuf/deadlock.proto create mode 100644 yuanrong/pkg/common/protobuf/error.proto create mode 100644 yuanrong/pkg/common/protobuf/filter.proto create mode 100644 yuanrong/pkg/common/protobuf/get.proto create mode 100644 yuanrong/pkg/common/protobuf/health/health_service.proto create mode 100644 yuanrong/pkg/common/protobuf/invoke.proto create mode 100644 yuanrong/pkg/common/protobuf/readstate.proto create mode 100644 yuanrong/pkg/common/protobuf/rpc/bus_service.proto create mode 100644 yuanrong/pkg/common/protobuf/rpc/common.proto create mode 100644 yuanrong/pkg/common/protobuf/rpc/core_service.proto create mode 100644 yuanrong/pkg/common/protobuf/rpc/inner_service.proto create mode 100644 yuanrong/pkg/common/protobuf/rpc/runtime_rpc.proto create mode 100644 yuanrong/pkg/common/protobuf/rpc/runtime_service.proto create mode 100644 yuanrong/pkg/common/protobuf/savestate.proto create mode 100644 yuanrong/pkg/common/protobuf/scheduler/domainscheduler_service.proto create mode 100644 yuanrong/pkg/common/protobuf/scheduler/globalscheduler_service.proto create mode 100644 yuanrong/pkg/common/protobuf/scheduler/localscheduler_service.proto create mode 100644 yuanrong/pkg/common/protobuf/scheduler/scheduler_common.proto create mode 100644 yuanrong/pkg/common/protobuf/scheduler/worker_agent_service.proto create mode 100644 yuanrong/pkg/common/protobuf/settimeout.proto create mode 100644 yuanrong/pkg/common/protobuf/specialize.proto create mode 100644 yuanrong/pkg/common/protobuf/terminate.proto create mode 100644 yuanrong/pkg/common/protobuf/wait.proto create mode 100644 yuanrong/pkg/common/reader/reader.go create mode 100644 yuanrong/pkg/common/reader/reader_test.go create mode 100644 yuanrong/pkg/common/tls/https.go create mode 100644 yuanrong/pkg/common/tls/https_test.go create mode 100644 yuanrong/pkg/common/tls/option.go create mode 100644 yuanrong/pkg/common/tls/option_scc.go create mode 100644 yuanrong/pkg/common/tls/option_scc_fake.go create mode 100644 yuanrong/pkg/common/tls/option_test.go create mode 100644 yuanrong/pkg/common/tls/tls.go create mode 100644 yuanrong/pkg/common/uuid/uuid.go create mode 100644 yuanrong/pkg/common/uuid/uuid_test.go create mode 100644 yuanrong/pkg/dashboard/client/index.html create mode 100644 yuanrong/pkg/dashboard/client/package.json create mode 100644 yuanrong/pkg/dashboard/client/public/logo.png create mode 100644 yuanrong/pkg/dashboard/client/src/api/api.ts create mode 100644 yuanrong/pkg/dashboard/client/src/api/index.ts create mode 100644 yuanrong/pkg/dashboard/client/src/components/breadcrumb-component.vue create mode 100644 yuanrong/pkg/dashboard/client/src/components/chart-config.ts create mode 100644 yuanrong/pkg/dashboard/client/src/components/common-card.vue create mode 100644 yuanrong/pkg/dashboard/client/src/components/log-content-template.vue create mode 100644 yuanrong/pkg/dashboard/client/src/components/progress-bar-template.ts create mode 100644 yuanrong/pkg/dashboard/client/src/components/warning-notify.ts create mode 100644 yuanrong/pkg/dashboard/client/src/i18n/index.ts create mode 100644 yuanrong/pkg/dashboard/client/src/index.css create mode 100644 yuanrong/pkg/dashboard/client/src/main.ts create mode 100644 yuanrong/pkg/dashboard/client/src/pages/cluster/cluster-chart.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/cluster/cluster-layout.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/instance-details/components/empty-log-card.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/instance-details/components/instance-info.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/instance-details/instance-details-layout.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/instances/instances-chart.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/job-details/components/job-info.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/job-details/job-details-layout.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/jobs/jobs-chart.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/layout.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-content.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-files.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-nodes.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/overview/components/cluster-card.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/overview/components/instances-card.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/overview/components/resources-card.vue create mode 100644 yuanrong/pkg/dashboard/client/src/pages/overview/overview-layout.vue create mode 100644 yuanrong/pkg/dashboard/client/src/types/api.d.ts create mode 100644 yuanrong/pkg/dashboard/client/src/utils/dayFormat.ts create mode 100644 yuanrong/pkg/dashboard/client/src/utils/handleNum.ts create mode 100644 yuanrong/pkg/dashboard/client/src/utils/sort.ts create mode 100644 yuanrong/pkg/dashboard/client/src/utils/swr.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/components/progress-bar-template.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/main.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/cluster/cluster-chart.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/cluster/cluster-layout.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/instance-details/components/empty-log-card.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/instance-details/components/instance-info.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/instance-details/instance-details-layout.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/instances/instances-chart.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/job-details/components/job-info.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/job-details/job-details-layout.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/jobs/jobs-chart.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/layout.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-content.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-files.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-nodes.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/overview/components/cluster-card.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/overview/components/instances-card.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/overview/components/resources-card.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/tests/pages/overview/overview-layout.spec.ts create mode 100644 yuanrong/pkg/dashboard/client/vite.config.js create mode 100644 yuanrong/pkg/dashboard/etcdcache/instance_cache.go create mode 100644 yuanrong/pkg/dashboard/etcdcache/instance_cache_test.go create mode 100644 yuanrong/pkg/dashboard/flags/flags.go create mode 100644 yuanrong/pkg/dashboard/flags/flags_test.go create mode 100644 yuanrong/pkg/dashboard/getinfo/client_pool.go create mode 100644 yuanrong/pkg/dashboard/getinfo/frontend_app.go create mode 100644 yuanrong/pkg/dashboard/getinfo/get_instances.go create mode 100644 yuanrong/pkg/dashboard/getinfo/get_resources.go create mode 100644 yuanrong/pkg/dashboard/handlers/cluster_status_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/cluster_status_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/components_componentid_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/components_componentid_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/components_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/components_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/err_code.go create mode 100644 yuanrong/pkg/dashboard/handlers/instances_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/instances_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/instances_instanceid_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/instances_instanceid_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/instances_parentid_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/instances_parentid_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/instances_summary_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/instances_summary_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/job_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/job_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/prometheus_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/prometheus_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/resources_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/resources_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/resources_summary_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/resources_summary_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/resources_unitid_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/resources_unitid_handler_test.go create mode 100644 yuanrong/pkg/dashboard/handlers/serve_handler.go create mode 100644 yuanrong/pkg/dashboard/handlers/serve_handler_test.go create mode 100644 yuanrong/pkg/dashboard/logmanager/collector_client.go create mode 100644 yuanrong/pkg/dashboard/logmanager/collector_client_test.go create mode 100644 yuanrong/pkg/dashboard/logmanager/http_handlers.go create mode 100644 yuanrong/pkg/dashboard/logmanager/http_handlers_test.go create mode 100644 yuanrong/pkg/dashboard/logmanager/log_db.go create mode 100644 yuanrong/pkg/dashboard/logmanager/log_db_test.go create mode 100644 yuanrong/pkg/dashboard/logmanager/log_entry.go create mode 100644 yuanrong/pkg/dashboard/logmanager/log_entry_test.go create mode 100644 yuanrong/pkg/dashboard/logmanager/log_index.go create mode 100644 yuanrong/pkg/dashboard/logmanager/log_index_test.go create mode 100644 yuanrong/pkg/dashboard/logmanager/log_manager.go create mode 100644 yuanrong/pkg/dashboard/logmanager/log_manager_test.go create mode 100644 yuanrong/pkg/dashboard/logmanager/service.go create mode 100644 yuanrong/pkg/dashboard/logmanager/service_test.go create mode 100644 yuanrong/pkg/dashboard/models/common_response.go create mode 100644 yuanrong/pkg/dashboard/models/serve_api_models.go create mode 100644 yuanrong/pkg/dashboard/routers/cors.go create mode 100644 yuanrong/pkg/dashboard/routers/router.go create mode 100644 yuanrong/pkg/dashboard/routers/router_test.go create mode 100644 yuanrong/pkg/functionmanager/config/config.go create mode 100644 yuanrong/pkg/functionmanager/config/config_test.go create mode 100644 yuanrong/pkg/functionmanager/constant/constant.go create mode 100644 yuanrong/pkg/functionmanager/faasmanager.go create mode 100644 yuanrong/pkg/functionmanager/faasmanager_test.go create mode 100644 yuanrong/pkg/functionmanager/queue.go create mode 100644 yuanrong/pkg/functionmanager/state/state.go create mode 100644 yuanrong/pkg/functionmanager/state/state_test.go create mode 100644 yuanrong/pkg/functionmanager/types/types.go create mode 100644 yuanrong/pkg/functionmanager/utils/utils.go create mode 100644 yuanrong/pkg/functionmanager/utils/utils_test.go create mode 100644 yuanrong/pkg/functionmanager/vpcmanager/plugin.go create mode 100644 yuanrong/pkg/functionmanager/vpcmanager/plugin_test.go create mode 100644 yuanrong/pkg/functionmanager/vpcmanager/pulltrigger.go create mode 100644 yuanrong/pkg/functionmanager/vpcmanager/pulltrigger_test.go create mode 100644 yuanrong/pkg/functionmanager/vpcmanager/types.go create mode 100644 yuanrong/pkg/functionmanager/vpcmanager/volumeconstant.go create mode 100644 yuanrong/pkg/functionscaler/config/config.go create mode 100644 yuanrong/pkg/functionscaler/config/config_test.go create mode 100644 yuanrong/pkg/functionscaler/config/hotload_config.go create mode 100644 yuanrong/pkg/functionscaler/config/hotload_config_test.go create mode 100644 yuanrong/pkg/functionscaler/dynamicconfigmanager/dynamicconfigmanager.go create mode 100644 yuanrong/pkg/functionscaler/dynamicconfigmanager/dynamicconfigmanager_test.go create mode 100644 yuanrong/pkg/functionscaler/faasscheduler.go create mode 100644 yuanrong/pkg/functionscaler/faasscheduler_test.go create mode 100644 yuanrong/pkg/functionscaler/healthcheck/healthcheck.go create mode 100644 yuanrong/pkg/functionscaler/healthcheck/healthcheck_test.go create mode 100644 yuanrong/pkg/functionscaler/httpserver/httpserver.go create mode 100644 yuanrong/pkg/functionscaler/httpserver/httpserver_test.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/componentset.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/container_tool.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/instance_operation_adapter.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/instance_operation_adapter_test.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/instance_operation_fg.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/instance_operation_fg_test.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/instance_operation_kernel.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/instance_operation_kernel_test.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/instancepool.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/instancepool_test.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/log.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/log_test.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/min_instance_alarm.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/miscellaneous.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/miscellaneous_test.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/operatekube.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/operatekube_test.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/poolmanager.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/poolmanager_test.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/rasp_sidecar.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/stateroute.go create mode 100644 yuanrong/pkg/functionscaler/instancepool/stateroute_test.go create mode 100644 yuanrong/pkg/functionscaler/instancequeue/create_request_queue.go create mode 100644 yuanrong/pkg/functionscaler/instancequeue/create_request_queue_test.go create mode 100644 yuanrong/pkg/functionscaler/instancequeue/instance_queue.go create mode 100644 yuanrong/pkg/functionscaler/instancequeue/instance_queue_builder.go create mode 100644 yuanrong/pkg/functionscaler/instancequeue/instance_queue_builder_test.go create mode 100644 yuanrong/pkg/functionscaler/instancequeue/instance_queue_test.go create mode 100644 yuanrong/pkg/functionscaler/instancequeue/ondemand_instance_queue.go create mode 100644 yuanrong/pkg/functionscaler/instancequeue/ondemand_instance_queue_test.go create mode 100644 yuanrong/pkg/functionscaler/instancequeue/scaled_instance_queue.go create mode 100644 yuanrong/pkg/functionscaler/instancequeue/scaled_instance_queue_test.go create mode 100644 yuanrong/pkg/functionscaler/lease/generic_lease_manager.go create mode 100644 yuanrong/pkg/functionscaler/lease/generic_lease_manager_test.go create mode 100644 yuanrong/pkg/functionscaler/lease/lease.go create mode 100644 yuanrong/pkg/functionscaler/registry/agentregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/agentregistry_test.go create mode 100644 yuanrong/pkg/functionscaler/registry/aliasregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/faasfrontendregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/faasmanagerregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/faasschedulerregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/faasschedulerregistry_test.go create mode 100644 yuanrong/pkg/functionscaler/registry/functionavailableregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/functionregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/instanceconfigregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/instanceregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/instanceregistry_fg.go create mode 100644 yuanrong/pkg/functionscaler/registry/instanceregistry_fg_test.go create mode 100644 yuanrong/pkg/functionscaler/registry/instanceregistry_test.go create mode 100644 yuanrong/pkg/functionscaler/registry/registry.go create mode 100644 yuanrong/pkg/functionscaler/registry/registry_test.go create mode 100644 yuanrong/pkg/functionscaler/registry/rolloutregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/rolloutregistry_test.go create mode 100644 yuanrong/pkg/functionscaler/registry/tenantquotaregistry.go create mode 100644 yuanrong/pkg/functionscaler/registry/useragencyregistry.go create mode 100644 yuanrong/pkg/functionscaler/requestqueue/instance_request_queue.go create mode 100644 yuanrong/pkg/functionscaler/requestqueue/instance_request_queue_test.go create mode 100644 yuanrong/pkg/functionscaler/rollout/rollouthandler.go create mode 100644 yuanrong/pkg/functionscaler/rollout/rollouthandler_test.go create mode 100644 yuanrong/pkg/functionscaler/scaler/autoscaler.go create mode 100644 yuanrong/pkg/functionscaler/scaler/autoscaler_test.go create mode 100644 yuanrong/pkg/functionscaler/scaler/instance_scaler.go create mode 100644 yuanrong/pkg/functionscaler/scaler/instance_scaler_test.go create mode 100644 yuanrong/pkg/functionscaler/scaler/predictscaler.go create mode 100644 yuanrong/pkg/functionscaler/scaler/replicascaler.go create mode 100644 yuanrong/pkg/functionscaler/scaler/replicascaler_test.go create mode 100644 yuanrong/pkg/functionscaler/scaler/wisecloudscaler.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/basic_concurrency_scheduler.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/basic_concurrency_scheduler_test.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/grayinstanceallocator.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/reserved_concurrency_scheduler.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/reserved_concurrency_scheduler_test.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/scaled_concurrency_scheduler.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/scaled_concurrency_scheduler_test.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/instance_scheduler.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/microservicescheduler/microservice_scheduler.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/microservicescheduler/microservice_scheduler_test.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler/roundrobin_scheduler.go create mode 100644 yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler/roundrobin_scheduler_test.go create mode 100644 yuanrong/pkg/functionscaler/selfregister/proxy.go create mode 100644 yuanrong/pkg/functionscaler/selfregister/proxy_test.go create mode 100644 yuanrong/pkg/functionscaler/selfregister/rolloutregister.go create mode 100644 yuanrong/pkg/functionscaler/selfregister/rolloutregister_test.go create mode 100644 yuanrong/pkg/functionscaler/selfregister/selfregister.go create mode 100644 yuanrong/pkg/functionscaler/selfregister/selfregister_test.go create mode 100644 yuanrong/pkg/functionscaler/signalmanager/signalmanager.go create mode 100644 yuanrong/pkg/functionscaler/signalmanager/signalmanager_test.go create mode 100644 yuanrong/pkg/functionscaler/state/state.go create mode 100644 yuanrong/pkg/functionscaler/state/state_test.go create mode 100644 yuanrong/pkg/functionscaler/stateinstance/lease.go create mode 100644 yuanrong/pkg/functionscaler/stateinstance/lease_test.go create mode 100644 yuanrong/pkg/functionscaler/sts/sensitiveconfig.go create mode 100644 yuanrong/pkg/functionscaler/sts/sensitiveconfig_test.go create mode 100644 yuanrong/pkg/functionscaler/sts/sts.go create mode 100644 yuanrong/pkg/functionscaler/sts/sts_test.go create mode 100644 yuanrong/pkg/functionscaler/sts/types.go create mode 100644 yuanrong/pkg/functionscaler/tenantquota/tenantcache.go create mode 100644 yuanrong/pkg/functionscaler/tenantquota/tenantcache_test.go create mode 100644 yuanrong/pkg/functionscaler/tenantquota/tenantetcd.go create mode 100644 yuanrong/pkg/functionscaler/tenantquota/tenantetcd_test.go create mode 100644 yuanrong/pkg/functionscaler/types/constants.go create mode 100644 yuanrong/pkg/functionscaler/types/types.go create mode 100644 yuanrong/pkg/functionscaler/types/types_test.go create mode 100644 yuanrong/pkg/functionscaler/utils/configmap_util.go create mode 100644 yuanrong/pkg/functionscaler/utils/configmap_util_test.go create mode 100644 yuanrong/pkg/functionscaler/utils/utils.go create mode 100644 yuanrong/pkg/functionscaler/utils/utils_test.go create mode 100644 yuanrong/pkg/functionscaler/workermanager/lease.go create mode 100644 yuanrong/pkg/functionscaler/workermanager/lease_test.go create mode 100644 yuanrong/pkg/functionscaler/workermanager/workermanager_client.go create mode 100644 yuanrong/pkg/functionscaler/workermanager/workermanager_client_test.go create mode 100644 yuanrong/pkg/functionscaler/workermanager/workermanager_request.go create mode 100644 yuanrong/pkg/functionscaler/workermanager/workermanager_request_test.go create mode 100644 yuanrong/pkg/system_function_controller/config/config.go create mode 100644 yuanrong/pkg/system_function_controller/config/config_test.go create mode 100644 yuanrong/pkg/system_function_controller/constant/constant.go create mode 100644 yuanrong/pkg/system_function_controller/faascontroller/fasscontroller.go create mode 100644 yuanrong/pkg/system_function_controller/faascontroller/fasscontroller_test.go create mode 100644 yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager/faasfrontendmanager.go create mode 100644 yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager/frontendmanager_test.go create mode 100644 yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager/faasfunctionmanager.go create mode 100644 yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager/faasfunctionmanager_test.go create mode 100644 yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager/faasschedulermanager.go create mode 100644 yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager/schedulermanager_test.go create mode 100644 yuanrong/pkg/system_function_controller/instancemanager/instancemanager.go create mode 100644 yuanrong/pkg/system_function_controller/instancemanager/instancemanager_test.go create mode 100644 yuanrong/pkg/system_function_controller/registry/faasfrontendregistry.go create mode 100644 yuanrong/pkg/system_function_controller/registry/faasmanagerregistry.go create mode 100644 yuanrong/pkg/system_function_controller/registry/faasschedulerregistry.go create mode 100644 yuanrong/pkg/system_function_controller/registry/registry.go create mode 100644 yuanrong/pkg/system_function_controller/registry/registry_test.go create mode 100644 yuanrong/pkg/system_function_controller/service/frontendservice.go create mode 100644 yuanrong/pkg/system_function_controller/service/frontendservice_test.go create mode 100644 yuanrong/pkg/system_function_controller/state/state.go create mode 100644 yuanrong/pkg/system_function_controller/state/state_test.go create mode 100644 yuanrong/pkg/system_function_controller/types/types.go create mode 100644 yuanrong/pkg/system_function_controller/utils/utils.go create mode 100644 yuanrong/pkg/system_function_controller/utils/utils_test.go create mode 100644 yuanrong/proto/CMakeLists.txt create mode 100644 yuanrong/proto/pb/message_pb.h create mode 100644 yuanrong/proto/pb/posix_pb.h create mode 100644 yuanrong/proto/posix/affinity.proto create mode 100644 yuanrong/proto/posix/bus_adapter.proto create mode 100644 yuanrong/proto/posix/bus_service.proto create mode 100644 yuanrong/proto/posix/common.proto create mode 100644 yuanrong/proto/posix/core_service.proto create mode 100644 yuanrong/proto/posix/inner_service.proto create mode 100644 yuanrong/proto/posix/log_service.proto create mode 100644 yuanrong/proto/posix/message.proto create mode 100644 yuanrong/proto/posix/resource.proto create mode 100644 yuanrong/proto/posix/runtime_rpc.proto create mode 100644 yuanrong/proto/posix/runtime_service.proto create mode 100644 yuanrong/test/collector/test.sh create mode 100644 yuanrong/test/common/test.sh create mode 100644 yuanrong/test/dashboard/test.sh create mode 100644 yuanrong/test/test.sh diff --git a/BUILD.bazel b/BUILD.bazel index ad9d6fa..0645c88 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -26,6 +26,8 @@ cc_library( "src/libruntime/objectstore/*.h", "src/libruntime/statestore/*.cpp", "src/libruntime/statestore/*.h", + "src/libruntime/streamstore/*.cpp", + "src/libruntime/streamstore/*.h", "src/libruntime/heterostore/*.h", "src/libruntime/heterostore/*.cpp", "src/libruntime/utils/*.cpp", @@ -34,10 +36,14 @@ cc_library( "src/libruntime/utils/crypto/*.h", "src/libruntime/invokeadaptor/*.h", "src/libruntime/invokeadaptor/*.cpp", + "src/scene/*.cpp", + "src/scene/*.h", "src/libruntime/fsclient/*.cpp", "src/libruntime/fsclient/*.h", "src/libruntime/fsclient/grpc/*.cpp", "src/libruntime/fsclient/grpc/*.h", + "src/libruntime/gwclient/*.cpp", + "src/libruntime/gwclient/*.h", "src/libruntime/gwclient/http/*.h", "src/libruntime/gwclient/http/*.cpp", "src/libruntime/connect/*.h", @@ -51,10 +57,18 @@ cc_library( "src/libruntime/stacktrace/*.cpp", "src/libruntime/metricsadaptor/*.h", "src/libruntime/metricsadaptor/*.cpp", + "src/libruntime/generator/*.h", + "src/libruntime/generator/*.cpp", + "src/libruntime/driverlog/*.h", + "src/libruntime/driverlog/*.cpp", "src/libruntime/fmclient/*.cpp", "src/libruntime/fmclient/*.h", "src/libruntime/rgroupmanager/*.cpp", "src/libruntime/rgroupmanager/*.h", + "src/libruntime/traceadaptor/*.cpp", + "src/libruntime/traceadaptor/*.h", + "src/libruntime/traceadaptor/exporter/*.cpp", + "src/libruntime/traceadaptor/exporter/*.h", ]) + glob( ["src/libruntime/*.h"], exclude = [ @@ -95,6 +109,11 @@ cc_library( "@nlohmann_json", "@securec", "@com_googlesource_code_re2//:re2", + "@opentelemetry_cpp//api", + "@opentelemetry_cpp//exporters/otlp:otlp_grpc_exporter", + "@opentelemetry_cpp//exporters/ostream:ostream_log_record_exporter", + "@opentelemetry_cpp//sdk/src/trace", + "@opentelemetry_cpp//sdk/src/resource", ], alwayslink = True, ) @@ -105,22 +124,26 @@ cc_library( [ "src/libruntime/objectstore/*.h", "src/libruntime/statestore/*.h", + "src/libruntime/streamstore/*.h", "src/libruntime/heterostore/*.h", "src/libruntime/utils/*.h", "src/libruntime/utils/crypto/*.h", "src/libruntime/invokeadaptor/*.h", "src/libruntime/fsclient/*.h", "src/libruntime/fsclient/grpc/*.h", + "src/libruntime/gwclient/*.h", "src/libruntime/gwclient/http/*.h", "src/libruntime/connect/*.h", "src/libruntime/clientsmanager/*.h", "src/libruntime/groupmanager/*.h", "src/libruntime/stacktrace/*.h", "src/libruntime/metricsadaptor/*.h", + "src/libruntime/generator/*.h", "src/libruntime/fmclient/*.h", "src/libruntime/driverlog/*.h", "src/libruntime/rgroupmanager/*.h", "src/libruntime/*.h", + "src/scene/*.h", ], exclude = [ "src/libruntime/libruntime.h", diff --git a/WORKSPACE b/WORKSPACE index e5074a1..72dcfb8 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -35,16 +35,16 @@ load("@rules_jvm_external//:defs.bzl", "maven_install") maven_install( artifacts = [ - "com.google.code.gson:gson:2.10.1", - "org.apache.commons:commons-lang3:3.14.0", + "com.google.code.gson:gson:2.11.0", + "org.apache.commons:commons-lang3:3.18.0", "org.apache.maven.plugins:maven-assembly-plugin:3.4.2", "org.apache.maven.plugins:maven-compiler-plugin:3.10.1", "commons-io:commons-io:2.16.1", "org.json:json:20230227", "org.msgpack:jackson-dataformat-msgpack:0.9.3", "org.msgpack:msgpack-core:0.9.3", - "com.fasterxml.jackson.core:jackson-core:2.16.2", - "com.fasterxml.jackson.core:jackson-databind:2.16.2", + "com.fasterxml.jackson.core:jackson-core:2.18.2", + "com.fasterxml.jackson.core:jackson-databind:2.18.2", "org.apache.logging.log4j:log4j-slf4j-impl:2.23.1", "org.apache.logging.log4j:log4j-api:2.23.1", "org.apache.logging.log4j:log4j-core:2.23.1", @@ -53,7 +53,7 @@ maven_install( "org.powermock:powermock-api-mockito2:2.0.4", "junit:junit:4.11", "org.jacoco:org.jacoco.agent:0.8.8", - "org.projectlombok:lombok:1.18.22", + "org.projectlombok:lombok:1.18.36", "org.ow2.asm:asm:9.7", ], repositories = [ @@ -61,10 +61,13 @@ maven_install( ], ) -new_local_repository( +local_patched_repository( name = "spdlog", - build_file = "@//bazel:spdlog.bzl", path = "../thirdparty/spdlog/", + build_file = "@//bazel:spdlog.bzl", + patch_files = [ + "@yuanrong_multi_language_runtime//patch:spdlog-change-namespace-and-library-name-with-yr.patch", + ] ) http_archive( @@ -94,6 +97,23 @@ load("//bazel:preload_grpc.bzl", "preload_grpc") preload_grpc() +load("//bazel:preload_opentelemetry.bzl", "preload_opentelemetry") + +preload_opentelemetry() + +http_archive( + name = "opentelemetry_cpp", + sha256 = "7735cc56507149686e6019e06f588317099d4522480be5f38a2a09ec69af1706", + strip_prefix = "opentelemetry-cpp-1.13.0", + urls = ["https://github.com/open-telemetry/opentelemetry-cpp/archive/refs/tags/v1.13.0.tar.gz"], +) + +load("@opentelemetry_cpp//bazel:repository.bzl", "opentelemetry_cpp_deps") +opentelemetry_cpp_deps() + +load("@opentelemetry_cpp//bazel:extra_deps.bzl", "opentelemetry_extra_deps") +opentelemetry_extra_deps() + load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") grpc_deps() @@ -113,6 +133,13 @@ http_archive( url = "https://gitee.com/mirrors/msgpack-c/repository/archive/cpp-5.0.0.zip", ) +http_archive( + name = "yaml-cpp", + sha256 = "6a05c681872d9465b8e2040b5211b1aa5cf30151dc4f3d7ed23ac75ce0fd9944", + strip_prefix = "yaml-cpp-0.8.0", + url = "https://gitee.com/mirrors/yaml-cpp/repository/archive/0.8.0.zip", +) + new_local_repository( name = "datasystem_sdk", build_file = "@//bazel:datasystem_sdk.bzl", diff --git a/api/cpp/BUILD.bazel b/api/cpp/BUILD.bazel index 6f14e44..8d9ea9f 100644 --- a/api/cpp/BUILD.bazel +++ b/api/cpp/BUILD.bazel @@ -41,6 +41,8 @@ cc_library( "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@securec", + "@nlohmann_json", + "@yaml-cpp", ], alwayslink = True, ) @@ -67,6 +69,58 @@ cc_library( ]), strip_include_prefix = "include", visibility = ["//visibility:public"], + deps = [ + "@nlohmann_json", + ], +) + +cc_library( + name = "functionsdk_lib", + srcs = glob([ + "src/faas/*.cpp", + "src/faas/*.h", + ]) + [ + "src/executor/executor.h", + "src/executor/executor_holder.h", + "src/utils/utils.h", + "src/utils/version.h", + ], + hdrs = glob([ + "include/faas/*.h", + ]), + copts = COPTS, + linkopts = ["-lstdc++fs"], + linkstatic = True, + strip_include_prefix = "include/faas", + visibility = ["//visibility:public"], + deps = [ + "yr_cpp_lib", + "//:runtime_lib_hdrs", + "@msgpack", + "@nlohmann_json", + ], + alwayslink = True, +) + +cc_binary( + name = "libfunctionsdk.so", + copts = COPTS, + dynamic_deps = ["//:grpc_dynamic"], + linkopts = LOPTS, + linkshared = True, + linkstatic = True, + visibility = ["//visibility:public"], + deps = [":functionsdk_lib"], +) + +cc_library( + name = "yr_faas_lib", + srcs = ["libfunctionsdk.so"], + hdrs = glob([ + "include/faas/*.h", + ]), + strip_include_prefix = "include/faas", + visibility = ["//visibility:public"], ) cc_binary( @@ -87,6 +141,7 @@ cc_binary( cc_strip( name = "cpp_strip", srcs = [ + "libfunctionsdk.so", "libyr-api.so", "runtime", ], @@ -104,60 +159,62 @@ genrule( srcs = [ ":cpp_strip", "//:grpc_strip", - "@boringssl//:gen_dir", + "@boringssl//:shared", ":cpp_include", ], outs = ["yr_cpp_pkg.out"], cmd = """ - BASE_DIR="$$(pwd)" && - CPP_SDK_DIR=$$BASE_DIR/build/output/runtime/sdk/cpp && - CPP_SERVICE_DIR=$$BASE_DIR/build/output/runtime/service/cpp && - rm -rf $$CPP_SDK_DIR $$CPP_SERVICE_DIR && - mkdir -p $$CPP_SDK_DIR/include $$CPP_SDK_DIR/include/boost $$CPP_SDK_DIR/lib $$CPP_SDK_DIR/bin $$CPP_SERVICE_DIR && - cp -rf $$BASE_DIR/external/msgpack/include $$CPP_SDK_DIR && - cp -rf $(locations @boringssl//:gen_dir)/include/openssl $$CPP_SDK_DIR/include && - cp -rf $$BASE_DIR/external/boost/boost/asio $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/any $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/assert $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/bind $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/callable_traits $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/chrono $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/config $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/container_hash $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/core $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/date_time $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/detail $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/exception $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/fiber $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/function $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/functional $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/fusion $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/integer $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/io $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/move $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/mpl $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/numeric $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/optional $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/predef $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/preprocessor $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/ratio $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/regex $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/smart_ptr $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/system $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/type_index $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/type_traits $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/utility $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/variant $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/winapi $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/*.hpp $$CPP_SDK_DIR/include/boost && - cp -rf $$BASE_DIR/external/boost/boost/*.h $$CPP_SDK_DIR/include/boost && - cp -rf $(locations @boringssl//:gen_dir)/lib/libssl.so.1.1 $$CPP_SDK_DIR/lib - cp -rf $(locations @boringssl//:gen_dir)/lib/libcrypto.so.1.1 $$CPP_SDK_DIR/lib + BASE_DIR="$$(pwd)" + CPP_SDK_DIR=$$BASE_DIR/build/output/runtime/sdk/cpp + CPP_SERVICE_DIR=$$BASE_DIR/build/output/runtime/service/cpp + rm -rf $$CPP_SDK_DIR $$CPP_SERVICE_DIR + mkdir -p $$CPP_SDK_DIR/include $$CPP_SDK_DIR/include/boost $$CPP_SDK_DIR/lib $$CPP_SDK_DIR/bin $$CPP_SERVICE_DIR + cp -rf $$BASE_DIR/external/msgpack/include $$CPP_SDK_DIR + if [ "$$BOOST_VERSION" != "1.72.0" ];then + cp -rf $$BASE_DIR/external/boost/boost/any $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/assert $$CPP_SDK_DIR/include/boost + fi + cp -rf $$BASE_DIR/external/boringssl/install/include/openssl $$CPP_SDK_DIR/include + cp -rf $$BASE_DIR/external/boost/boost/asio $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/align $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/bind $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/callable_traits $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/chrono $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/config $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/container_hash $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/core $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/date_time $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/detail $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/exception $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/fiber $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/function $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/functional $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/fusion $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/integer $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/io $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/move $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/mpl $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/numeric $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/optional $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/predef $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/preprocessor $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/ratio $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/regex $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/smart_ptr $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/system $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/type_index $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/type_traits $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/utility $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/variant $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/winapi $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/*.hpp $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/boost/boost/*.h $$CPP_SDK_DIR/include/boost + cp -rf $$BASE_DIR/external/nlohmann_json/single_include/nlohmann/json.hpp $$CPP_SDK_DIR/include chmod +w $(locations :cpp_strip) $(locations //:grpc_strip) && chrpath -d $(locations :cpp_strip) $(locations //:grpc_strip) && cp -rf $(locations :cpp_strip) $(locations //:grpc_strip) $$CPP_SDK_DIR/lib/ && cp -rf $(locations :cpp_strip) $$CPP_SDK_DIR/bin/ && - rm -rf $$CPP_SDK_DIR/bin/libyr-api.so + rm -rf $$CPP_SDK_DIR/bin/libyr-api.so $$CPP_SDK_DIR/bin/libfunctionsdk.so cp -rf $$BASE_DIR/api/cpp/include $$CPP_SDK_DIR/ && DATASYSTEM_DIR=$$BASE_DIR/external/datasystem_sdk/cpp && cp -rf $$DATASYSTEM_DIR/include $$CPP_SDK_DIR && diff --git a/api/cpp/example/BUILD.bazel b/api/cpp/example/BUILD.bazel index ecc05ac..aebb85f 100644 --- a/api/cpp/example/BUILD.bazel +++ b/api/cpp/example/BUILD.bazel @@ -26,6 +26,18 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "yr_faas", + srcs = [ + "3rd/lib/libfunctionsdk.so", + ], + hdrs = glob([ + "3rd/include/faas/*.h", + ]), + strip_include_prefix = "3rd/include/faas", + visibility = ["//visibility:public"], +) + [ cc_binary( name = example, @@ -47,9 +59,16 @@ cc_library( ] +cc_binary( + name = "faas_example", + srcs = ["faas_example.cpp"], + linkstatic = True, + deps = [":yr_faas"], +) + genrule( name = "example_all", - srcs = [example for example in examples] + [example + ".so" for example in examples], + srcs = [example for example in examples] + [example + ".so" for example in examples] + [ "faas_example" ], outs = ["example_all.out"], cmd = "echo ok > $@", visibility = ["//visibility:public"], diff --git a/api/cpp/example/faas_example.cpp b/api/cpp/example/faas_example.cpp new file mode 100644 index 0000000..05c56a2 --- /dev/null +++ b/api/cpp/example/faas_example.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "Function.h" +#include "Runtime.h" +bool flags = false; +std::string HandleRequest(const std::string &request, Function::Context &context) +{ + Function::FunctionLogger logger = context.GetLogger(); + logger.setLevel("INFO"); + logger.Info("hello cpp %s ", "user info log"); + logger.Error("hello cpp %s ", "user error log"); + logger.Warn("hello cpp %s ", "user warn log"); + logger.Debug("hello cpp %s ", "user debug log"); + logger.Error("hello cpp %s ", context.GetFunctionName().c_str()); + logger.Error("hello cpp %s ", context.GetUserData("b").c_str()); + logger.Error("hello cpp %s ", context.GetUserData("key1").c_str()); + return request; +} + +void InitState(const std::string &request, Function::Context &context) +{ + context.SetState(request); +} + +void Initializer(Function::Context &context) +{ + flags = true; +} + +const std::string DEFAULT_PORT = "31552"; +int main(int argc, char *argv[]) +{ + Function::Runtime rt; + rt.RegisterHandler(HandleRequest); + // 有状态函数 + rt.InitState(InitState); + // 初始化函数入口 + rt.RegisterInitializerFunction(Initializer); + rt.Start(argc, argv); + return 0; +} diff --git a/api/cpp/example/instance_example.cpp b/api/cpp/example/instance_example.cpp index e88bcca..105e1d7 100644 --- a/api/cpp/example/instance_example.cpp +++ b/api/cpp/example/instance_example.cpp @@ -61,6 +61,7 @@ YR_INVOKE(Counter::FactoryCreate, &Counter::Add, &Counter::Get); int main(void) { YR::Config conf; + conf.inCluster = true; YR::Init(conf); { @@ -81,6 +82,16 @@ int main(void) //! [terminate instance sync] } + { + //! [terminate instance async] + auto counter = YR::Instance(Counter::FactoryCreate).Invoke(1); + auto c = counter.Function(&Counter::Add).Invoke(1); + std::cout << "counter is " << *YR::Get(c) << std::endl; + auto f = counter.AsyncTerminate(true); + f.get(); + //! [terminate instance async] + } + { auto counter = YR::Instance(Counter::FactoryCreate, "name_1").Invoke(1); auto c = counter.Function(&Counter::Add).Invoke(1); diff --git a/api/cpp/example/kv_example.cpp b/api/cpp/example/kv_example.cpp index 31df90f..2507e5e 100644 --- a/api/cpp/example/kv_example.cpp +++ b/api/cpp/example/kv_example.cpp @@ -21,6 +21,7 @@ int main() { YR::Config conf; + conf.inCluster = true; conf.functionUrn = "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest"; conf.serverAddr = ""; // bus proxy IP:port conf.dataSystemAddr = ""; // datasystem worker IP:port diff --git a/api/cpp/example/object_example.cpp b/api/cpp/example/object_example.cpp index 7bb817b..8f8e0d7 100644 --- a/api/cpp/example/object_example.cpp +++ b/api/cpp/example/object_example.cpp @@ -21,6 +21,7 @@ int main(int argc, char **argv) { YR::Config conf; + conf.inCluster = true; YR::Init(conf); auto ref = YR::Put(123); std::cout << "get result is " << *YR::Get(ref) << std::endl; diff --git a/api/cpp/example/runtime_env_example.cpp b/api/cpp/example/runtime_env_example.cpp new file mode 100644 index 0000000..07ca186 --- /dev/null +++ b/api/cpp/example/runtime_env_example.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! [runtime env demo] +#include +#include "yr/api/runtime_env.h" +#include "yr/yr.h" +int main(int argc, char **argv) { + std::string pyFunctionUrn = "sn:cn:yrk:12345678901234561234567890123456:function:0-yr-mypython:$latest"; + + YR::Config conf; + YR::Init(conf, argc, argv); + YR::InvokeOptions opts; + YR::RuntimeEnv runtimeEnv; + runtimeEnv.Set("conda", "pytorch_p39"); + runtimeEnv.Set>("env_vars", {{"OMP_NUM_THREADS", "32"}, {"TF_WARNINGS", "none"}, {"YR_CONDA_HOME", "/home/snuser/.conda"}}); + opts.runtimeEnv = runtimeEnv; + auto resFutureSquare = YR::PyFunction("calculator", "square").SetUrn(pyFunctionUrn).Options(opts).Invoke(2); + auto resSquare = *YR::Get(resFutureSquare); + std::cout << resSquare << std::endl; + return 0; +} +//! [runtime env demo] \ No newline at end of file diff --git a/api/cpp/example/runtime_env_example1.cpp b/api/cpp/example/runtime_env_example1.cpp new file mode 100644 index 0000000..d25a8b7 --- /dev/null +++ b/api/cpp/example/runtime_env_example1.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "yr/api/runtime_env.h" +#include "yr/yr.h" +int main(int argc, char **argv) +{ + std::string pyFunctionUrn = "sn:cn:yrk:12345678901234561234567890123456:function:0-yr-mypython:$latest"; + + YR::Config conf; + YR::Init(conf, argc, argv); + YR::InvokeOptions opts; + YR::RuntimeEnv runtimeEnv; + runtimeEnv.Set("conda", "pytorch_p39"); + { + //! [set env_vars] + YR::RuntimeEnv runtimeEnv; + runtimeEnv.Set>("env_vars", {{"OMP_NUM_THREADS", "32"}, {"TF_WARNINGS", "none"}}); + //! [set env_vars] + } + { + //! [set pip] + YR::RuntimeEnv env; + env.Set>("pip", {"numpy=2.3.0", "pandas"}); + //! [set pip] + } + { + //! [set working_dir] + YR::RuntimeEnv env; + runtimeEnv.Set("working_dir", "file:/opt/mycode/cpp-invoke-python/calculator.zip"); + //! [set working_dir] + } + { + //! [set existed conda environ] + YR::RuntimeEnv env; + runtimeEnv.Set("conda", "pytorch_p39"); + //! [set existed conda environ] + } + { + //! [set conda environ with dependency] + // If name is not specified, the name will start with 'virtual_env-' followed by a randomly generated suffix. + runtimeEnv.Set("conda", {{"name", "pytorch_p39"},{"channels", {"conda-forge"}}, {"dependencies", {"python=3.9", "matplotlib", "msgpack-python=1.0.5", "protobuf", "libgcc-ng", "numpy", "pandas", "cloudpickle=2.0.0", "cython=3.0.10", "pyyaml=6.0.2"}}}); + runtimeEnv.Set>("env_vars", {{"OMP_NUM_THREADS", "32"}, {"TF_WARNINGS", "none"}, {"YR_CONDA_HOME", "/home/snuser/.conda"}}); + //! [set conda environ with dependency] + } + { + //! [set conda environ with yaml file] + YR::RuntimeEnv runtimeEnv; + /* yaml file demo + * name: myenv3 + * channels: + * - conda-forge + * - defaults + * dependencies: + * - python=3.9 + * - numpy + * - pandas + * - cloudpickle=2.2.1 + * - msgpack-python=1.0.5 + * - protobuf + * - cython=3.0.10 + * - pyyaml=6.0.2 + */ + runtimeEnv.Set("conda", "/opt/conda/env-xpf.yaml"); + //! [set conda environ with yaml file] + } + runtimeEnv.Set>("env_vars", {{"OMP_NUM_THREADS", "32"}, {"TF_WARNINGS", "none"}, {"YR_CONDA_HOME", "/home/snuser/.conda"}}); + opts.runtimeEnv = runtimeEnv; + auto resFutureSquare = YR::PyFunction("calculator", "square").SetUrn(pyFunctionUrn).Options(opts).Invoke(2); + auto resSquare = *YR::Get(resFutureSquare); + std::cout << resSquare << std::endl; + return 0; +} diff --git a/api/cpp/example/ssl_example.cpp b/api/cpp/example/ssl_example.cpp index 6a7acee..2116b6d 100644 --- a/api/cpp/example/ssl_example.cpp +++ b/api/cpp/example/ssl_example.cpp @@ -31,6 +31,7 @@ int main() conf.certificateFilePath = tls_file_path + "/module.crt"; conf.verifyFilePath = tls_file_path + "/ca.crt"; conf.privateKeyPath = tls_file_path + "/module.key"; + std::strcpy(conf.privateKeyPaaswd, "paaswd"); conf.serverName = "serverName"; YR::Init(conf); diff --git a/api/cpp/example/stream_example.cpp b/api/cpp/example/stream_example.cpp new file mode 100644 index 0000000..56918d8 --- /dev/null +++ b/api/cpp/example/stream_example.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "yr/yr.h" + +int main() +{ + YR::Config conf; + YR::Init(conf); + //! [create producer] + std::string streamName = "streamName"; + // create stream producer + YR::ProducerConf producerConf{}; + std::shared_ptr producer = YR::CreateProducer(streamName, producerConf); + //! [create producer] + // create stream consumer + //! [create consumer] + YR::SubscriptionConfig config("subName", YR::SubscriptionType::STREAM); + std::shared_ptr consumer = YR::Subscribe(streamName, config); + //! [create consumer] + //! [producer send] + // producer send data + std::string str = "hello"; + YR::Element element((uint8_t *)(str.c_str()), str.size()); + producer->Send(element); + producer->Flush(); + //! [producer send] + //! [consumer recv] + // consumer receive data + std::vector elements; + consumer->Receive(1, 6000, elements); // timeout 6s + consumer->Ack(elements[0].id); + std::string actualData0(reinterpret_cast(elements[0].ptr), elements[0].size); + std::cout << "receive: " << actualData0 << std::endl; + //! [consumer recv] + //! [close producer] + producer->Close(); + //! [close producer] + //! [close consumer] + consumer->Close(); + //! [close consumer] + //! [delete stream] + // delete stream + YR::DeleteStream(streamName); + //! [delete stream] + return 0; +} \ No newline at end of file diff --git a/api/cpp/include/faas/Constant.h b/api/cpp/include/faas/Constant.h new file mode 100644 index 0000000..54a8a06 --- /dev/null +++ b/api/cpp/include/faas/Constant.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace Function { +enum ErrorCode { + OK = 0, + ERROR = 1, + + ILLEGAL_ACCESS = 4001, + FUNCTION_EXCEPTION = 4002, + USER_STATE_LARGE_ERROR = 4003, + ILLEGAL_RETURN = 4004, + USER_STATE_UNDEFINED_ERROR = 4005, + USER_INITIALIZATION_FUNCTION_EXCEPTION = 4009, + USER_LOAD_FUNCTION_EXCEPTION = 4014, + NO_SUCH_INSTANCE_NAME_ERROR_CODE = 4026, + INVALID_PARAMETER = 4040, + NO_SUCH_STATE_ERROR_CODE = 4041, + INTERNAL_ERROR = 110500, +}; +// UserErrorMax is the maximum value of user errors +const int USER_ERROR_MAX = 10000; + +const std::string SUCCESS_RESPONSE = "OK"; +const std::string ILLEGAL_ACCESS_MESSAGE = "function entry cannot be found"; +const std::string FUNCTION_EXCEPTION_MESSAGE = "function invocation exception: "; +const std::string USER_STATE_LARGE_ERROR_MESSAGE = "state content is too large"; +const std::string ILLEGAL_RETURN_MESSAGE = "function return value is too large"; +const std::string USER_STATE_UNDEFINED_ERROR_MESSAGE = "state is undefined"; + +const std::string INTERNAL_ERROR_MESSAGE = "internal system error"; +const uint32_t MAX_USER_EXCEPTION_LENGTH = 1024; +const uint32_t MAX_USER_STATE_LENGTH = 3584 * 1024; +} // namespace Function \ No newline at end of file diff --git a/api/cpp/include/faas/Context.h b/api/cpp/include/faas/Context.h new file mode 100644 index 0000000..aee33ed --- /dev/null +++ b/api/cpp/include/faas/Context.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "FunctionLogger.h" + +namespace Function { +class Context { +public: + Context() = default; + + virtual ~Context() = default; + + virtual const std::string GetAccessKey() const = 0; + + virtual const std::string GetSecretKey() const = 0; + + virtual const std::string GetSecurityAccessKey() const = 0; + + virtual const std::string GetSecuritySecretKey() const = 0; + + virtual const std::string GetToken() const = 0; + + virtual const std::string GetAlias() const = 0; + + virtual const std::string GetTraceId() const = 0; + + virtual const std::string GetInvokeId() const = 0; + + virtual const FunctionLogger &GetLogger() = 0; + + virtual const std::string GetState() const = 0; + + virtual const std::string GetInstanceId() const = 0; + + virtual const std::string GetInstanceLabel() const = 0; + + virtual void SetState(const std::string &state) = 0; + + virtual const std::string GetInvokeProperty() const = 0; + + virtual const std::string GetRequestID() const = 0; + + virtual const std::string GetUserData(std::string key) const = 0; + + virtual const std::string GetFunctionName() const = 0; + + virtual int GetRemainingTimeInMilliSeconds() const = 0; + + virtual int GetRunningTimeInSeconds() const = 0; + + virtual const std::string GetVersion() const = 0; + + virtual int GetMemorySize() const = 0; + + virtual int GetCPUNumber() const = 0; + + virtual const std::string GetProjectID() const = 0; + + virtual const std::string GetPackage() const = 0; +}; + +} // namespace Function diff --git a/api/cpp/include/faas/Function.h b/api/cpp/include/faas/Function.h new file mode 100644 index 0000000..5dde789 --- /dev/null +++ b/api/cpp/include/faas/Function.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include "Constant.h" +#include "Context.h" +#include "ObjectRef.h" + +namespace Function { + +struct InvokeOptions { + // unit is ms + int cpu = 0; + // unit is MB + int memory = 0; + + std::unordered_map aliasParams; +}; + +class Function { +public: + explicit Function(Context &context); + explicit Function(Context &context, const std::string &funcName); + explicit Function(Context &context, const std::string &funcName, const std::string &instanceName); + + virtual ~Function() = default; + + Function(const Function &) = delete; + + Function &operator=(const Function &) = delete; + + ObjectRef Invoke(const std::string &payload); + + Function &Options(const InvokeOptions &opt); + + const std::string GetObjectRef(ObjectRef &objectRef); + + void GetInstance(const std::string &functionName, const std::string &instanceName); + + void GetLocalInstance(const std::string &functionName, const std::string &instanceName); + + ObjectRef Terminate(); + + void SaveState(); + + const std::shared_ptr GetContext() const; + + std::string GetInstanceId() const; + +private: + std::shared_ptr context_; + std::string funcName_; + std::string instanceName_; + std::string instanceID_; + InvokeOptions options_; +}; +} // namespace Function diff --git a/api/cpp/include/faas/FunctionError.h b/api/cpp/include/faas/FunctionError.h new file mode 100644 index 0000000..fa1f99b --- /dev/null +++ b/api/cpp/include/faas/FunctionError.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include "Constant.h" + +namespace Function { +class FunctionError : public std::exception { +public: + FunctionError(int code, const std::string message) : errCode((ErrorCode)code), errMsg(message) {} + virtual ~FunctionError() = default; + const char *what() const noexcept override; + ErrorCode GetErrorCode() const; + const std::string GetMessage() const; + const std::string GetJsonString() const; + +private: + ErrorCode errCode; + std::string errMsg; +}; +} // namespace Function diff --git a/api/cpp/include/faas/FunctionLogger.h b/api/cpp/include/faas/FunctionLogger.h new file mode 100644 index 0000000..92a45ce --- /dev/null +++ b/api/cpp/include/faas/FunctionLogger.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace Function { +class FunctionLogger { +public: + FunctionLogger(){}; + virtual ~FunctionLogger(){}; + FunctionLogger(const std::string &traceId, const std::string &invokeId) : traceId(traceId), invokeId(invokeId) + { + } + + void setLevel(const std::string &level); + + void Info(std::string message, ...); + + void Warn(std::string message, ...); + + void Debug(std::string message, ...); + + void Error(std::string message, ...); + + void SetInvokeID(std::string invokeID); + + void SetTraceID(std::string traceID); + +private: + std::string traceId; + std::string invokeId; + std::string logLevel = "INFO"; + void Log(const std::string &level, const std::string &logMessage); + bool sendEmptyLog(const std::string &message, const std::string &level); +}; +} // namespace Function diff --git a/api/cpp/include/faas/ObjectRef.h b/api/cpp/include/faas/ObjectRef.h new file mode 100644 index 0000000..bed47a3 --- /dev/null +++ b/api/cpp/include/faas/ObjectRef.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace Function { +class ObjectRef { +public: + ObjectRef(std::string &futureId, std::string &instanceId) + : objectRefId_(futureId), instanceId_(instanceId), isResultExist_(false) + { + } + + virtual ~ObjectRef() = default; + + const std::string GetObjectRefId() const; + const std::string GetResult() const; + const std::string Get(); + bool GetResultFlag() const; + +private: + std::string objectRefId_; + std::string instanceId_; + std::string result_; + bool isResultExist_; +}; +} // namespace Function diff --git a/api/cpp/include/faas/Runtime.h b/api/cpp/include/faas/Runtime.h new file mode 100644 index 0000000..3a1a142 --- /dev/null +++ b/api/cpp/include/faas/Runtime.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include "Constant.h" +#include "Context.h" +#include "RuntimeHandler.h" + +namespace Function { +class Runtime { +public: + Runtime(); + virtual ~Runtime(); + + void RegisterHandler( + std::function handleRequestFunc); + + void RegisterInitializerFunction(std::function initializerFunc); + + void RegisterPreStopFunction(std::function preStopFunc); + + void InitState(std::function initStateFunc); + + void Start(int argc, char *argv[]); + +private: + std::function handleRequest; + std::function initializerFunction; + std::function preStopFunction; + std::function initStateFunction; + void InitRuntimeLogger() const; + void ReleaseRuntimeLogger() const; + void BuildRegisterRuntimeHandler() const; +}; +} // namespace Function diff --git a/api/cpp/include/faas/RuntimeHandler.h b/api/cpp/include/faas/RuntimeHandler.h new file mode 100644 index 0000000..05c3881 --- /dev/null +++ b/api/cpp/include/faas/RuntimeHandler.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "Context.h" + +namespace Function { +/* + * Customer implement this interface to handle runtime message + */ +class RuntimeHandler { +public: + RuntimeHandler() = default; + virtual ~RuntimeHandler() = default; + + virtual std::string HandleRequest(const std::string &request, Context &context) = 0; + + virtual void InitState(const std::string &request, Context &context) = 0; + + virtual void PreStop(Context &context) = 0; + + virtual void Initializer(Context &context) = 0; +}; +} // namespace Function + diff --git a/api/cpp/include/yr/api/affinity.h b/api/cpp/include/yr/api/affinity.h index dab6dda..d98138b 100644 --- a/api/cpp/include/yr/api/affinity.h +++ b/api/cpp/include/yr/api/affinity.h @@ -178,6 +178,23 @@ public: { } + /*! + * @brief Construct an affinity operation object. + * + * @param kind The affinity kind mainly include `RESOURCE` and `INSTANCE`. `RESOURCE` refers to predefined resource + * label affinity, while `INSTANCE` refers to dynamic instance label affinity. + * @param type The affinity type mainly include `PREFERRED`, `PREFERRED_ANTI`, `REQUIRED` and `REQUIRED_ANTI`, + * which represent weak affinity, weak anti-affinity, strong affinity, and strong anti-affinity, + * respectively. + * @param operators Label operation list, see LabelOperator and its subclasses for details. + * @param affinityScope The instance affinity scope mainly includes two types: `AFFINITYSCOPE_POD` + * and `AFFINITYSCOPE_NODE`. + */ + Affinity(const std::string &kind, const std::string &type, const std::list &operators, + const std::string &affinityScope) + : affinityKind(kind), affinityType(type), labelOperators(operators), affinityScope(affinityScope) { + } + /*! * @brief Default Destructor. */ @@ -213,10 +230,29 @@ public: return this->labelOperators; } + /*! + * @brief Get the scope of instance affinity. + * + * @return The scope of instance affinity. + */ + std::string GetAffinityScope() const + { + return this->affinityScope; + } + + /*! + * @brief Set the scope of instance affinity. + */ + void SetAffinityScope(const std::string &affinityScope) + { + this->affinityScope = affinityScope; + } + private: std::string affinityKind; std::string affinityType; std::list labelOperators; + std::string affinityScope; }; /*! @class ResourcePreferredAffinity affinity.h "include/yr/api/affinity.h" diff --git a/api/cpp/include/yr/api/buffer.h b/api/cpp/include/yr/api/buffer.h index aaf494b..dc99e65 100644 --- a/api/cpp/include/yr/api/buffer.h +++ b/api/cpp/include/yr/api/buffer.h @@ -19,7 +19,26 @@ namespace YR { class Buffer { public: - virtual uint64_t GetSize() const = 0; - virtual const void *ImmutableData() const = 0; + Buffer() = default; + + Buffer(void *data, uint64_t size) + { + data_ = data; + size_ = size; + } + + virtual uint64_t GetSize() const + { + return size_; + } + + virtual const void *ImmutableData() const + { + return data_; + } + +private: + void *data_ = nullptr; + uint64_t size_ = 0; }; } // namespace YR \ No newline at end of file diff --git a/api/cpp/include/yr/api/config.h b/api/cpp/include/yr/api/config.h index 9ce1e7c..71c16fa 100644 --- a/api/cpp/include/yr/api/config.h +++ b/api/cpp/include/yr/api/config.h @@ -87,6 +87,9 @@ struct Config { * defaults to CPU cores count. Default is `10`. */ uint32_t localThreadPoolSize = 10; + + bool inCluster = true; + /** * @brief Maximum idle time for instances. Instances will be terminated if idle beyond this duration. Unit: seconds. * Valid range: `1~3000`. Defaults to `2` if not configured. @@ -108,39 +111,21 @@ struct Config { * @brief Server certificate file path. */ std::string verifyFilePath = ""; - /** - * @brief Server name for TLS. - */ + char privateKeyPaaswd[MAX_PASSWD_LENGTH] = {0}; + std::string encryptPrivateKeyPasswd; std::string serverName = ""; - /** - * @brief Enable data system authentication. Default is `false`. - */ + std::shared_ptr tlsContext; + uint32_t httpIocThreadsNum = DEFAULT_HTTP_IOC_THREADS_NUM; bool enableDsAuth = false; - /** - * @brief `true`: Enable data system encryption (requires public/private key configs). Default is `false`. - */ bool enableDsEncrypt = false; - /** - * @brief The path of worker public key for data system tls authentication, if enableDsEncrypt is ``true`` and the - * dsPublicKeyContextPath is empty, an exception will be thrown. - */ - std::string dsPublicKeyContextPath = ""; - /** - * @brief The path of client public key for data system tls authentication, if enableDsEncrypt is ``true`` and the - * runtimePublicKeyContextPath is empty, an exception will be thrown. - */ - std::string runtimePublicKeyContextPath = ""; - /** - * @brief The path of client private key for data system tls authentication, if enableDsEncrypt is ``true`` and the - * runtimePrivateKeyContextPath is empty, an exception will be thrown. - */ - std::string runtimePrivateKeyContextPath = ""; + std::string dsPublicKeyContext = ""; + std::string runtimePublicKeyContext = ""; + std::string runtimePrivateKeyContext = ""; + std::string encryptDsPublicKeyContext; + std::string encryptRuntimePublicKeyContext; + std::string encryptRuntimePrivateKeyContext; std::string primaryKeyStoreFile; std::string standbyKeyStoreFile; - /** - * @brief Limits the maximum number of stateless function instances. Valid range: `1~65536`. Defaults to `-1` if - * unconfigured. The `Init` interface throws exceptions for invalid values. - */ int maxTaskInstanceNum = -1; /** * @brief Custom path for metrics logs. The corresponding environment variable is `YR_METRICS_LOG_PATH`. @@ -150,7 +135,7 @@ struct Config { * @brief Whether to enable metrics collection. `false` means disabled, `true` means enabled. Only effective within * the cluster. Default is `false`. The corresponding environment variable is `YR_ENABLE_METRICS`. */ - bool enableMetrics = false; + bool enableMetrics = true; uint32_t defaultGetTimeoutSec = 300; // 0 means never time out bool isDriver = true; // internal use only, user do not set it. /** @@ -202,15 +187,11 @@ struct Config { * @brief Custom environment variables for runtime (only `LD_LIBRARY_PATH` supported). */ std::unordered_map customEnvs; - /** - * @brief Enable low-reliability mode for stateless instances (improves creation performance in large-scale - * scenarios). - */ + std::string httpVersion = ""; + bool autodeploy = false; + std::string tenantId = ""; bool isLowReliabilityTask = false; - /** - * @brief Attach `libruntime` instance to existing instances during initialization (only supports KV APIs). - * Default is `false`. - */ bool attach = false; + bool launchUserBinary = false; }; } // namespace YR \ No newline at end of file diff --git a/api/cpp/include/yr/api/constant.h b/api/cpp/include/yr/api/constant.h index b8c2ba5..13b0b71 100644 --- a/api/cpp/include/yr/api/constant.h +++ b/api/cpp/include/yr/api/constant.h @@ -66,6 +66,15 @@ const std::string LABEL_EXISTS = "LabelExists"; * @brief Label operation type, indicating that there is no corresponding label. */ const std::string LABEL_DOES_NOT_EXIST = "LabelDoesNotExist"; +/*! @var AFFINITYSCOPE_POD + * @brief Instance affinity scope, indicating Pod-level affinity. + */ +const std::string AFFINITYSCOPE_POD = "POD"; +/*! @var AFFINITYSCOPE_NODE + * @brief Instance affinity scope, indicating Node-level affinity. + */ +const std::string AFFINITYSCOPE_NODE = "NODE"; +const uint32_t DEFAULT_HTTP_IOC_THREADS_NUM = 200; const size_t MAX_OPTIONS_RETRY_TIME = 10; const int DEFAULT_RECYCLETIME = 2; const std::string FUNCTION_NOT_REGISTERED_ERROR_MSG = "Function to be invoked is not registered by using YR_INVOKE"; diff --git a/api/cpp/include/yr/api/function_handler.h b/api/cpp/include/yr/api/function_handler.h index 23ea54a..49e9274 100644 --- a/api/cpp/include/yr/api/function_handler.h +++ b/api/cpp/include/yr/api/function_handler.h @@ -19,6 +19,8 @@ #include #include +#include + #include "yr/api/args_check.h" #include "yr/api/cross_lang.h" #include "yr/api/function_manager.h" @@ -131,8 +133,12 @@ static void PackInvokeArgsImpl(YR::internal::FunctionLanguage language, std::vec AddPythonPlaceholder(language, invokeArgs); InvokeArg invokeArg{}; localNestedObjList.clear(); - invokeArg.buf = std::move(Serialize(arg)); // Serialize add nested objects to localNestedObjList - invokeArg.nestedObjects.swap(localNestedObjList); + if constexpr (boost::is_same::value) { + invokeArg.yrBuf = std::move(arg); + } else { + invokeArg.buf = std::move(Serialize(arg)); // Serialize add nested objects to localNestedObjList + invokeArg.nestedObjects.swap(localNestedObjList); + } localNestedObjList.clear(); invokeArg.isRef = false; invokeArgs.emplace_back(std::move(invokeArg)); diff --git a/api/cpp/include/yr/api/hetero_manager.h b/api/cpp/include/yr/api/hetero_manager.h index 2239845..9d898b0 100644 --- a/api/cpp/include/yr/api/hetero_manager.h +++ b/api/cpp/include/yr/api/hetero_manager.h @@ -57,13 +57,13 @@ public: } @endcode */ - static void Delete(const std::vector &objectIds, std::vector &failedObjectIds) + static void DevDelete(const std::vector &objectIds, std::vector &failedObjectIds) { CheckInitialized(); if (YR::internal::IsLocalMode()) { throw HeteroException::IncorrectFunctionUsageException("Delete is not supported in local mode"); } - YR::internal::GetRuntime()->Delete(objectIds, failedObjectIds); + YR::internal::GetRuntime()->DevDelete(objectIds, failedObjectIds); } /*! @@ -93,13 +93,13 @@ public: } @endcode */ - static void LocalDelete(const std::vector &objectIds, std::vector &failedObjectIds) + static void DevLocalDelete(const std::vector &objectIds, std::vector &failedObjectIds) { CheckInitialized(); if (YR::internal::IsLocalMode()) { throw HeteroException::IncorrectFunctionUsageException("LocalDelete is not supported in local mode"); } - YR::internal::GetRuntime()->LocalDelete(objectIds, failedObjectIds); + YR::internal::GetRuntime()->DevLocalDelete(objectIds, failedObjectIds); } /*! diff --git a/api/cpp/include/yr/api/instance_creator.h b/api/cpp/include/yr/api/instance_creator.h index 5d4d861..a516656 100644 --- a/api/cpp/include/yr/api/instance_creator.h +++ b/api/cpp/include/yr/api/instance_creator.h @@ -183,8 +183,8 @@ public: handler.SetClassName(funcMeta.className); handler.SetFunctionUrn(funcMeta.funcUrn); handler.SetNeedOrder(opts.needOrder); - handler.SetName(funcMeta.name.value_or("")); - handler.SetNs(funcMeta.ns.value_or("")); + handler.SetName(funcMeta.name); + handler.SetNs(funcMeta.ns); if (InstanceRangeEnabled(this->opts.instanceRange)) { handler.SetGroupName(this->opts.groupName); } diff --git a/api/cpp/include/yr/api/invoke_arg.h b/api/cpp/include/yr/api/invoke_arg.h index 0d6f3e5..38a1a61 100644 --- a/api/cpp/include/yr/api/invoke_arg.h +++ b/api/cpp/include/yr/api/invoke_arg.h @@ -25,6 +25,7 @@ struct InvokeArg { { if (&rhs != this) { buf = std::move(rhs.buf); + yrBuf = std::move(rhs.yrBuf); isRef = rhs.isRef; objId = std::move(rhs.objId); nestedObjects = std::move(rhs.nestedObjects); @@ -35,6 +36,7 @@ struct InvokeArg { { if (&rhs != this) { buf = std::move(rhs.buf); + yrBuf = std::move(rhs.yrBuf); isRef = rhs.isRef; objId = std::move(rhs.objId); nestedObjects = std::move(rhs.nestedObjects); @@ -46,6 +48,7 @@ struct InvokeArg { InvokeArg &operator=(InvokeArg const &) = delete; msgpack::sbuffer buf; + Buffer yrBuf; bool isRef = false; std::string objId; // objId records objs for dependency resolver std::unordered_set nestedObjects; // nestedObjects records objs for ref count diff --git a/api/cpp/include/yr/api/invoke_options.h b/api/cpp/include/yr/api/invoke_options.h index b2399b5..608229e 100644 --- a/api/cpp/include/yr/api/invoke_options.h +++ b/api/cpp/include/yr/api/invoke_options.h @@ -21,6 +21,7 @@ #include "yr/api/affinity.h" #include "yr/api/exception.h" +#include "yr/api/runtime_env.h" namespace YR { /*! @@ -109,6 +110,21 @@ struct InstanceRange { RangeOptions rangeOpts; }; +/** + * @struct DebugConfig + * @brief 结构体 DebugConfig 是所有 debug 相关的配置,用于指定初始化 debug 实例的配置。 + * @note 当指定 enable 为 true 时,runtime 进程会关闭心跳检测,且可以被远程 debug 工具连接调试 + */ +struct DebugConfig { + DebugConfig() = default; + ~DebugConfig() = default; + /** + * @brief 指定是否初始化 debug 实例。当指定 enable 为 true 时,runtime 进程会关闭心跳检测,且可以被远 + * 程 debug 工具连接调试 + */ + bool enable = false; +}; + /*! * @struct InvokeOptions invoke_options.h "include/yr/api/invoke_options.h" * @brief used to set the invoke options. @@ -200,6 +216,12 @@ struct InvokeOptions { */ size_t retryTimes = 0; + /*! + * @var size_t maxRetryTime + * @brief 定义无限重试最大重试次数,默认为-1,表示无限重试。 + */ + int maxRetryTime = -1; + /*! * @var std::function retryChecker * @brief 无状态函数的重试判断钩子,默认为空。当 retryTimes = 0 时,本参数不生效。 @@ -248,6 +270,8 @@ struct InvokeOptions { */ std::unordered_map envVars; + DebugConfig debug; + /*! * @var std::sring traceId * @brief 设置函数调用的 traceId,用于链路追踪。 @@ -260,6 +284,32 @@ struct InvokeOptions { */ int timeout = -1; + /*! + * @var bool preemptedAllowed + * @brief 实例是否可以被抢占,仅在优先级场景下(元戎部署的 maxPriority 配置项大于 0 的场景)生效。默认为 false + */ + bool preemptedAllowed = false; + + /*! + * @var int instancePriority + * @brief 实例的优先级,数值越大优先级越高,高优先级的实例可以抢占低优先级且被配置为(preemptedAllowed = true) + * 的实例。仅在优先级场景下(元戎部署的 maxPriority 配置项大于 0 的场景)生效。instancePriority 的最小值为 0,最 + * 大值为元戎部署的 maxPriority 配置。默认为 0 + */ + int instancePriority = 0; + + /*! + * @var size_t scheduleTimeoutMs + * @brief 实例的调度超时时间,单位毫秒,取值[-1, int64_t 类型最大值]。默认为 30000 + */ + int64_t scheduleTimeoutMs = 30000; + + /*! + * @var RuntimeEnv runtimeEnv + * @brief 设置python的virtual env + */ + RuntimeEnv runtimeEnv; + void CheckOptionsValid() { // check retry time @@ -313,8 +363,8 @@ struct FuncMeta { std::string funcUrn; std::string className; FunctionLanguage language; - std::optional name; - std::optional ns; + std::string name; + std::string ns; bool isAsync; bool isGenerator; }; diff --git a/api/cpp/include/yr/api/kv.h b/api/cpp/include/yr/api/kv.h new file mode 100644 index 0000000..92eb65c --- /dev/null +++ b/api/cpp/include/yr/api/kv.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "yr/api/check_initialized.h" +#include "yr/api/kv_manager.h" + +namespace YR { +/** + * @brief Interface for key-value storage. + * @return the kv manager + */ +inline KVManager &KV() +{ + CheckInitialized(); + return KVManager::Singleton(); +} +} // namespace YR \ No newline at end of file diff --git a/api/cpp/include/yr/api/kv_manager.h b/api/cpp/include/yr/api/kv_manager.h index 8b387da..7ae114e 100644 --- a/api/cpp/include/yr/api/kv_manager.h +++ b/api/cpp/include/yr/api/kv_manager.h @@ -1067,6 +1067,26 @@ public: : internal::GetRuntime()->KVDel(keys, delParam); } + /** + * @brief Check if multiple specified keys exist. + * + * @param keys A list of keys to check. The maximum number of keys is 10000。 + * @return std::vector A vector containing the existence of the corresponding key. The order of values + * corresponds to the order of keys. + * + * @throw Exception Thrown in the following cases: + * - **1001**: Invalid input parameters (e.g., empty keys or invalid characters). + * - **1002**: Internal communication errors. + * - **4299**: DataSystem failed. + */ + static std::vector Exist(const std::vector &keys) + { + CheckInitialized(); + return internal::IsLocalMode() + ? internal::GetLocalModeRuntime()->KVExist(keys) + : internal::GetRuntime()->KVExist(keys); + } + private: template static void CheckMSetTxParams(const std::vector &keys, const std::vector &vals, diff --git a/api/cpp/include/yr/api/local_mode_runtime.h b/api/cpp/include/yr/api/local_mode_runtime.h index fa73f38..61a4bd9 100644 --- a/api/cpp/include/yr/api/local_mode_runtime.h +++ b/api/cpp/include/yr/api/local_mode_runtime.h @@ -55,12 +55,14 @@ public: std::shared_ptr KVRead(const std::string &key, int timeoutMs); std::vector> KVRead(const std::vector &keys, int timeoutMs, - bool allowPartial = false); + bool allowPartial = false); void KVDel(const std::string &key); std::vector KVDel(const std::vector &keys); + std::vector KVExist(const std::vector &keys); + template bool IsAllFail(const std::vector> &results); @@ -80,7 +82,7 @@ private: std::shared_ptr waitRequestManager_; std::shared_ptr stateStore_; std::shared_ptr pool_; - std::atomic initPool_ = false; + std::atomic initPool_{false}; }; template diff --git a/api/cpp/include/yr/api/local_state_store.h b/api/cpp/include/yr/api/local_state_store.h index b06f3e1..db8b0cf 100644 --- a/api/cpp/include/yr/api/local_state_store.h +++ b/api/cpp/include/yr/api/local_state_store.h @@ -47,6 +47,8 @@ public: std::vector Del(const std::vector &keys); + std::vector Exist(const std::vector &keys); + // clear all key-values in kv_map void Clear() noexcept; diff --git a/api/cpp/include/yr/api/mutable_buffer.h b/api/cpp/include/yr/api/mutable_buffer.h new file mode 100644 index 0000000..d306437 --- /dev/null +++ b/api/cpp/include/yr/api/mutable_buffer.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace YR { +template +class ObjectRef; +class MutableBuffer { +public: + virtual ~MutableBuffer() = default; + + virtual void *MutableData(); + + virtual ObjectRef Publish(); + + virtual int64_t GetSize(); +}; +} // namespace YR \ No newline at end of file diff --git a/api/cpp/include/yr/api/named_instance.h b/api/cpp/include/yr/api/named_instance.h index ba52e48..1cc6e7b 100644 --- a/api/cpp/include/yr/api/named_instance.h +++ b/api/cpp/include/yr/api/named_instance.h @@ -175,6 +175,23 @@ public: */ void Terminate(bool isSync); + /** + * @brief Supports synchronous or asynchronous termination. For an instance handle, this indicates deleting an + * already created function instance. When synchronous termination is not enabled, the default timeout for the + * current kill request is 30 seconds. In scenarios such as high disk load or etcd failures, the kill request + * processing time may exceed 30 seconds, causing the interface to throw a timeout exception. Since the kill request + * has a retry mechanism, users can choose to ignore the timeout exception or retry after capturing it. When + * synchronous termination is enabled, this interface will block until the instance completely exits. + * @param isSync Whether to enable synchronous termination. If `true`, it sends a kill request with the signal + * `killInstanceSync` to the function-proxy, and the kernel synchronously kills the instance. If `false`, it sends a + * kill request with the signal `killInstance` to the function-proxy, and the kernel asynchronously kills the + * instance. + * @return A future for the kill result. + * + * @snippet{trimleft} instance_example.cpp terminate instance sync + */ + std::shared_future AsyncTerminate(bool isSync); + /** * @brief Specifies the timeout in seconds for waiting until a set of instances in Range scheduling are scheduled * and their instance IDs are returned, generating a list of `NamedInstance` objects. This parameter is optional. If @@ -453,6 +470,23 @@ void NamedInstance::Terminate(bool isSync) YR::internal::GetRuntime()->TerminateInstance(instanceId); } } + +template +std::shared_future NamedInstance::AsyncTerminate(bool isSync) +{ + CheckInitialized(); + if (internal::IsLocalMode() || this->alwaysLocalMode) { + YR::internal::LocalInstanceManager::Singleton().DelLocalInstance(instanceId); + auto promise = std::make_shared>(); + promise->set_value(); + return promise->get_future().share(); + } + if (!groupName.empty()) { + throw Exception::IncorrectFunctionUsageException("range instance does not support async terminate"); + } + return YR::internal::GetRuntime()->TerminateInstanceAsync(instanceId, isSync); +} + } // namespace YR namespace msgpack { diff --git a/api/cpp/include/yr/api/node.h b/api/cpp/include/yr/api/node.h new file mode 100644 index 0000000..81d7882 --- /dev/null +++ b/api/cpp/include/yr/api/node.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +namespace YR { +struct Node { + std::string id; + bool alive; + std::unordered_map resources; + std::unordered_map> labels; +}; +} // namespace YR \ No newline at end of file diff --git a/api/cpp/include/yr/api/object_ref.h b/api/cpp/include/yr/api/object_ref.h index b96a586..1a260f9 100644 --- a/api/cpp/include/yr/api/object_ref.h +++ b/api/cpp/include/yr/api/object_ref.h @@ -53,7 +53,7 @@ public: if (internal::IsLocalMode()) { this->isLocal = true; } else if (needIncre) { - YR::internal::GetRuntime()->IncreGlobalReference({id}); + YR::internal::GetRuntime()->IncreGlobalReference({id}, false); } } @@ -63,7 +63,7 @@ public: if (isLocal) { future_ = rhs.future_; } else { - YR::internal::GetRuntime()->IncreGlobalReference({rhs.objId}); + YR::internal::GetRuntime()->IncreGlobalReference({rhs.objId}, false); } } @@ -80,7 +80,7 @@ public: if (isLocal) { future_ = rhs.future_; } else { - YR::internal::GetRuntime()->IncreGlobalReference({rhs.objId}); + YR::internal::GetRuntime()->IncreGlobalReference({rhs.objId}, false); if (!objId.empty()) { YR::internal::GetRuntime()->DecreGlobalReference({objId}); } diff --git a/api/cpp/include/yr/api/object_store.h b/api/cpp/include/yr/api/object_store.h index 5405576..17e3ce7 100644 --- a/api/cpp/include/yr/api/object_store.h +++ b/api/cpp/include/yr/api/object_store.h @@ -19,6 +19,9 @@ #include #include #include + +#include + #include "yr/api/buffer.h" #include "yr/api/err_type.h" #include "yr/api/object_ref.h" @@ -55,13 +58,17 @@ void CheckIfObjectRefsHomogeneous(const std::vector> &objs) } template -void CheckObjsAndTimeout(const std::vector> &objs, int timeoutSec) +void CheckObjs(const std::vector> &objs) { CheckInitialized(); if (objs.empty()) { throw Exception::InvalidParamException("Get does not accept empty object list"); } CheckIfObjectRefsHomogeneous(objs); +} + +inline void CheckTimeout(int timeoutSec) +{ if (timeoutSec < NO_TIMEOUT) { std::string msg = "get config timeout (" + std::to_string(timeoutSec) + " s) is invalid"; throw YR::Exception::InvalidParamException(msg); @@ -79,8 +86,12 @@ void ExtractSuccessObjects(std::vector &remainIds, for (size_t i = 0; i < remainIds.size(); i++) { if ((i < remainBuffers.size()) && remainBuffers[i]) { auto &indices = idToIndices[remainIds[i]]; - auto obj = YR::internal::Deserialize>(remainBuffers[i]); - returnObjects[indices.front()] = obj; + if constexpr (boost::is_same::value) { + returnObjects[indices.front()] = remainBuffers[i]; + } else { + auto obj = YR::internal::Deserialize>(remainBuffers[i]); + returnObjects[indices.front()] = obj; + } indices.pop_front(); } else { newRemainIds.emplace_back(std::move(remainIds[i])); diff --git a/api/cpp/include/yr/api/runtime.h b/api/cpp/include/yr/api/runtime.h index 7ee4d36..e2a69ab 100644 --- a/api/cpp/include/yr/api/runtime.h +++ b/api/cpp/include/yr/api/runtime.h @@ -29,6 +29,9 @@ #include "yr/api/future.h" #include "yr/api/invoke_arg.h" #include "yr/api/invoke_options.h" +#include "yr/api/mutable_buffer.h" +#include "yr/api/node.h" +#include "yr/api/stream.h" #include "yr/api/wait_result.h" namespace YR { @@ -316,6 +319,8 @@ public: virtual std::string Put(std::shared_ptr data, const std::unordered_set &nestedObjectIds) = 0; + virtual std::string Put(std::shared_ptr data, const std::unordered_set &nestedObjectIds) = 0; + virtual void Put(const std::string &objId, std::shared_ptr data, const std::unordered_set &nestedId) = 0; @@ -344,7 +349,21 @@ public: virtual std::vector KVDel(const std::vector &keys, const DelParam &delParam = {}) = 0; - virtual void IncreGlobalReference(const std::vector &objectIds) = 0; + virtual std::vector KVExist(const std::vector &keys) = 0; + + virtual std::shared_ptr CreateStreamProducer(const std::string &streamName, + ProducerConf producerConf) = 0; + + virtual std::shared_ptr CreateStreamConsumer(const std::string &streamName, + const SubscriptionConfig &config, bool autoAck = false) = 0; + + virtual void DeleteStream(const std::string &streamName) = 0; + + virtual void QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum) = 0; + + virtual void QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum) = 0; + + virtual void IncreGlobalReference(const std::vector &objectIds, bool toDatasystem = true) = 0; virtual void DecreGlobalReference(const std::vector &objectIds) = 0; @@ -384,9 +403,10 @@ public: virtual void LoadState(const int &timeout) = 0; - virtual void Delete(const std::vector &objectIds, std::vector &failedObjectIds) = 0; + virtual void DevDelete(const std::vector &objectIds, std::vector &failedObjectIds) = 0; - virtual void LocalDelete(const std::vector &objectIds, std::vector &failedObjectIds) = 0; + virtual void DevLocalDelete(const std::vector &objectIds, + std::vector &failedObjectIds) = 0; virtual void DevSubscribe(const std::vector &keys, const std::vector &blob2dList, std::vector> &futureVec) = 0; @@ -403,22 +423,33 @@ public: virtual std::string Put(std::shared_ptr data, const std::unordered_set &nestedObjectIds, const CreateParam &createParam) = 0; + virtual std::string Put(std::shared_ptr data, const std::unordered_set &nestedObjectIds, + const CreateParam &createParam) = 0; + virtual void KVWrite(const std::string &key, std::shared_ptr value, SetParamV2 setParam) = 0; virtual void KVMSetTx(const std::vector &keys, const std::vector> &vals, const MSetParam &mSetParam) = 0; virtual internal::FuncMeta GetInstance(const std::string &name, const std::string &nameSpace, int timeoutSec) = 0; - + virtual std::string GetGroupInstanceIds(const std::string &objectId) = 0; virtual void SaveGroupInstanceIds(const std::string &objectId, const std::string &groupInsIds, const InvokeOptions &opts) = 0; - + virtual std::string GetInstanceRoute(const std::string &objectId) = 0; virtual void SaveInstanceRoute(const std::string &objectId, const std::string &instanceRoute) = 0; virtual void TerminateInstanceSync(const std::string &instanceId) = 0; + + virtual std::vector Nodes() = 0; + + virtual std::shared_ptr CreateMutableBuffer(uint64_t size) = 0; + + virtual std::vector> GetMutableBuffer(const std::vector &ids, + int timeout) = 0; + virtual std::shared_future TerminateInstanceAsync(const std::string &instanceId, bool isSync) = 0; }; } // namespace YR diff --git a/api/cpp/include/yr/api/runtime_env.h b/api/cpp/include/yr/api/runtime_env.h new file mode 100644 index 0000000..5a3a3e5 --- /dev/null +++ b/api/cpp/include/yr/api/runtime_env.h @@ -0,0 +1,137 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "json.hpp" + +#include "yr/api/constant.h" +#include "yr/api/exception.h" + +namespace YR { + +/*! + * @class RuntimeEnv + * @brief This class provides interfaces of setting runtime environments for python actor. + */ +class RuntimeEnv { +public: + /*! + * @brief Set a runtime env field by name and Object. + * + * @tparam T The second param type of the value. + * @param name The runtime env plugin name. + * @param value An object with jsonable type of nlohmann/json. + * + * @note Constraints + * - RuntimeEnv only processes keys named conda, pip, working_dir, and env_vars; other keys neither take effect + * nor trigger errors. + * - [1] **pip** `String[]/Iterable` Specifies Python dependencies (mutually exclusive with conda), + * for example: + * @snippet{trimleft} runtime_env_example1.cpp set pip + * - [2] **working_dir** `str` Specify the code path, but this feature is not enabled yet. for example: + * @snippet{trimleft} runtime_env_example1.cpp set working_dir + * - [3] **env_vars** `JSON` Environment variables (values must be strings), for example: + * @snippet{trimleft} runtime_env_example1.cpp set env_vars + * - [4] **conda** `str/JSON` Conda configuration (requires YR_CONDA_HOME), for example:
+ * (1). Specify scheduling to an existing conda environment. + * @snippet{trimleft} runtime_env_example1.cpp set existed conda environ + * (2). Create a new environment and specify its dependencies and environment name (optional). + * @snippet{trimleft} runtime_env_example1.cpp set conda environ with dependency + * (3). Create a new environment, specify dependencies and environment names through files. + * @snippet{trimleft} runtime_env_example1.cpp set conda environ with yaml file + * + * @snippet{trimleft} runtime_env_example.cpp runtime env demo + */ + template + void Set(const std::string &name, const T &value); + + /** + * @brief Get the object of a runtime env field. + * @tparam T The return type of the function. + * @param name The runtime env plugin name. + * @return A runtime env field with T type. + */ + template + T Get(const std::string &name) const; + + /** + * @brief Set a runtime env field by name and json string. + * @param name The runtime env plugin name. + * @param jsonStr A json string represents the runtime env field. + */ + void SetJsonStr(const std::string &name, const std::string &jsonStr); + + /** + * @brief Get the json string of a runtime env field. + * @param name The runtime env plugin name. + * @return A string type object with runtime env field. + */ + std::string GetJsonStr(const std::string &name) const; + + /** + * @brief Whether a field is contained. + * @param name The runtime env plugin name. + * @return Whether the filed is contained. + */ + bool Contains(const std::string &name) const; + + /** + * @brief Remove a field by name. + * @param name The runtime env plugin name. + * @return true if remove an existing field, otherwise false. + */ + bool Remove(const std::string &name); + + /** + * @brief Whether the runtime env is empty. + * @return Whether the runtime env is empty. + */ + bool Empty() const + { + return fields_.empty(); + } + +private: + nlohmann::json fields_; +}; + +template +inline void RuntimeEnv::Set(const std::string &name, const T &value) +{ + try { + nlohmann::json valueJ = value; + fields_[name] = valueJ; + } catch (std::exception &e) { + throw Exception::InvalidParamException("Failed to set the field " + name + ": " + e.what()); + } +} + +template +inline T RuntimeEnv::Get(const std::string &name) const +{ + if (!Contains(name)) { + throw Exception::InvalidParamException("The field " + name + " not found."); + } + try { + return fields_[name].get(); + } catch (std::exception &e) { + throw Exception::InvalidParamException("Failed to get the field " + name + ": " + e.what()); + } +} +} // namespace YR \ No newline at end of file diff --git a/api/cpp/include/yr/api/serdes.h b/api/cpp/include/yr/api/serdes.h index e5811bd..627c183 100644 --- a/api/cpp/include/yr/api/serdes.h +++ b/api/cpp/include/yr/api/serdes.h @@ -31,12 +31,17 @@ static msgpack::sbuffer Serialize(const T &value) return buffer; } +static bool referenceFunc(msgpack::type::object_type type, std::size_t length, void* user_data) +{ + return true; +} + template static T Deserialize(const msgpack::sbuffer &data) { try { - msgpack::unpacked unpacked = msgpack::unpack(data.data(), data.size(), 0); - return unpacked.get().as(); + msgpack::object_handle oh = msgpack::unpack(data.data(), data.size(), referenceFunc); + return oh.get().as(); } catch (std::exception &e) { std::string msg = "failed to deserialize input argument whose type=" + std::string(typeid(T).name()) + " and len=" + std::to_string(data.size()) + ", original exception message: " + std::string(e.what()); @@ -48,9 +53,9 @@ template static T Deserialize(const std::shared_ptr data) { try { - msgpack::unpacked unpacked = - msgpack::unpack(static_cast(data->ImmutableData()), data->GetSize(), 0); - return unpacked.get().as(); + msgpack::object_handle oh = + msgpack::unpack(static_cast(data->ImmutableData()), data->GetSize(), referenceFunc); + return oh.get().as(); } catch (std::exception &e) { std::string msg = "failed to deserialize input argument whose type=" + std::string(typeid(T).name()) + " and len=" + std::to_string(data->GetSize()) + diff --git a/api/cpp/include/yr/api/stream.h b/api/cpp/include/yr/api/stream.h new file mode 100644 index 0000000..5a78dd4 --- /dev/null +++ b/api/cpp/include/yr/api/stream.h @@ -0,0 +1,209 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace YR { +enum SubscriptionType { STREAM, ROUND_ROBIN, KEY_PARTITIONS }; + +/** + * @struct Element + */ +struct Element { + Element(uint8_t *ptr = nullptr, uint64_t size = 0, uint64_t id = ULONG_MAX) : ptr(ptr), size(size), id(id) {} + + ~Element() = default; + /** + * @brief Pointer to the data + */ + uint8_t *ptr; + + /** + * @brief Size of the data + */ + uint64_t size; + + /** + * @brief ID of the Element + */ + uint64_t id; +}; + +/** + * @struct ProducerConf + */ +struct ProducerConf { + /** + * @brief After sending, the producer will delay for the specified duration before triggering a flush. + * <0: Do not automatically flush; 0: Flush immediately; otherwise, the delay duration in milliseconds. Default: 5. + */ + int64_t delayFlushTime = 5; + + /** + * @brief Specifies the buffer page size for the producer, in bytes (B). When a page is full, it will trigger a + * flush. Default: 1 MB. Must be greater than 0 and a multiple of 4 KB. + */ + int64_t pageSize = 1024 * 1024ul; + + /** + * @brief Specifies the maximum shared memory size that the stream can use on a worker, in bytes (B). + * Default: 100 MB. Range: [64 KB, size of worker shared memory]. + */ + uint64_t maxStreamSize = 100 * 1024 * 1024ul; + + /** + * @brief Specifies whether to enable automatic cleanup for the stream. Default: false (disabled). + * When the last producer/consumer exits, the stream will be automatically cleaned up. + */ + bool autoCleanup = false; + + /** + * @brief Specifies whether to enable content encryption for the stream. Default: false (disabled). + */ + bool encryptStream = false; + + /** + * @brief Specifies how many consumers should retain the producer's data. Default: 0. + * If set to 0, data will not be retained if there are no consumers. When consumers are created later, they may not + * receive the data. This parameter is only effective for the first consumer created, and the current valid range is + * [0, 1]. Multiple consumers are not supported. + */ + uint64_t retainForNumConsumers = 0; + + /** + * @brief Specifies the reserved memory size, in bytes (B). When creating a producer, it will attempt to reserve + * reserveSize bytes of memory. If reservation fails, an exception will be thrown. reserveSize must be an integer + * multiple of pageSize and within the range [0, maxStreamSize]. If reserveSize is 0, it will be set to pageSize by + * default. Default: 0. + */ + uint64_t reserveSize = 0; + + /** + * @brief Extended configuration for the producer. Common configuration items include: + * "STREAM_MODE": The stream mode, which can be "MPMC", "MPSC", or "SPSC". Default: "MPMC". If an unsupported mode + * is specified, an exception will be thrown. MPMC represents multi-producer multi-consumer, MPSC represents + * multi-producer single-consumer, and SPSC represents single-producer single-consumer. If MPSC or SPSC is selected, + * the data system will enable multi-stream shared page functionality internally. + */ + std::unordered_map extendConfig; + + /** + * @brief Custom trace ID for troubleshooting and performance optimization. Only supported in the cloud; settings + * outside the cloud will not take effect. Maximum length: 36. Valid characters must match the regular expression: + * ``^[a-zA-Z0-9\~\.\-\/_!@#%\^\&\*\(\)\+\=\:;]*$``. + */ + std::string traceId; +}; + +/** + * @struct SubscriptionConfig + */ +struct SubscriptionConfig { + /** + * @brief Subscription name + */ + std::string subscriptionName; + + /** + * @brief Subscription type, including three types: STREAM, ROUND_ROBIN, and KEY_PARTITIONS. + * STREAM indicates that a single consumer in a subscription group consumes the stream. + * ROUND_ROBIN indicates that multiple consumers in a subscription group consume the stream in a round-robin + * load-balancing manner. KEY_PARTITIONS indicates that multiple consumers in a subscription group consume the + * stream in a key-partitioned load-balancing manner. Currently, only the STREAM type is supported; other types are + * temporarily unsupported. Default subscription type: STREAM. + */ + SubscriptionType subscriptionType = SubscriptionType::STREAM; + + /** + * @brief Extended configuration for SubscriptionConfig. + */ + std::unordered_map extendConfig; + + /** + * @brief Custom trace ID for troubleshooting and performance optimization. Only supported in the cloud; settings + * outside the cloud will not take effect. Maximum length: 36. Valid characters must match the regular expression: + * ``^[a-zA-Z0-9\~\.\-\/_!@#%\^\&\*\(\)\+\=\:;]*$``. + */ + std::string traceId; + + /** + * @brief Constructor for SubscriptionConfig. + * @param subName Subscription name + * @param subType Subscription type + */ + SubscriptionConfig(std::string subName, const SubscriptionType subType) + : subscriptionName(std::move(subName)), subscriptionType(subType) + { + } + + SubscriptionConfig() = default; + + SubscriptionConfig(const SubscriptionConfig &other) = default; + + SubscriptionConfig &operator=(const SubscriptionConfig &other) = default; + + SubscriptionConfig(SubscriptionConfig &&other) noexcept + { + subscriptionName = std::move(other.subscriptionName); + subscriptionType = other.subscriptionType; + } + + SubscriptionConfig &operator=(SubscriptionConfig &&other) noexcept + { + subscriptionName = std::move(other.subscriptionName); + subscriptionType = other.subscriptionType; + return *this; + } + + bool operator==(const SubscriptionConfig &config) const + { + return subscriptionName == config.subscriptionName && subscriptionType == config.subscriptionType; + } + + bool operator!=(const SubscriptionConfig &config) const + { + return !(*this == config); + } +}; + +class Producer { +public: + virtual void Send(const Element &element) = 0; + + virtual void Send(const Element &element, int64_t timeoutMs) = 0; + + virtual void Flush() = 0; + + virtual void Close() = 0; +}; + +class Consumer { +public: + virtual void Receive(uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements) = 0; + + virtual void Receive(uint32_t timeoutMs, std::vector &outElements) = 0; + + virtual void Ack(uint64_t elementId) = 0; + + virtual void Close() = 0; +}; + +} // namespace YR \ No newline at end of file diff --git a/api/cpp/include/yr/api/wait_request_manager.h b/api/cpp/include/yr/api/wait_request_manager.h index 57b0d67..e5c638f 100644 --- a/api/cpp/include/yr/api/wait_request_manager.h +++ b/api/cpp/include/yr/api/wait_request_manager.h @@ -85,7 +85,7 @@ private: std::unordered_map>> requestStore; std::shared_ptr ioc; - std::unique_ptr work; + std::unique_ptr> work; std::unique_ptr asyncRunner; }; diff --git a/api/cpp/include/yr/api/yr_core.h b/api/cpp/include/yr/api/yr_core.h new file mode 100644 index 0000000..3308421 --- /dev/null +++ b/api/cpp/include/yr/api/yr_core.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "yr/api/client_info.h" +#include "yr/api/config.h" + +namespace YR { +/*! + @brief YuanRong Init API, Configures runtime modes and system parameters. + Refer to the data structure documentation for parameter specifications + struct-Config . + @param conf YuanRong initialization parameter configuration. For parameter specifications, refer to + struct-Config . + @return ClientInfo Refer to struct-Config . + + @note When multi-tenancy is enabled on the YuanRong Cluster, users must configure a tenant ID. For + details on tenant ID configuration, refer to the "About Tenant ID" section in + struct-Config . + + @throws Exception The system will throw an exception when invalid config parameters are detected, such as an invalid + mode type. + + @snippet{trimleft} init_and_finalize_example.cpp Init localMode + + @snippet{trimleft} init_and_finalize_example.cpp Init clusterMode + */ +ClientInfo Init(const Config &conf); + +ClientInfo Init(const Config &conf, int argc, char **argv); + +ClientInfo Init(int argc, char **argv); + +/*! + @brief Finalizes the Yuanrong system + + This function is responsible for releasing resources such as function instances + and data objects that have been created during the execution of the program. + It ensures that no resources are leaked, which could lead to issues in a + production environment. + + @note - In a cluster deployment scenario, if worker processes exit and restart, + it might lead to process residuals. In such cases, it is recommended to + redeploy the cluster. Deployment scenarios like Donau or SGE can rely on + the resource scheduling platform's capability to recycle processes. + + - This function should be called after the system has been initialized + with the appropriate Init function. Calling Finalize before Init will result + in an exception. + + @throws Exception If Finalize is called before the system is initialized, + the exception "Please init YR first" will be thrown. + + @snippet{trimleft} init_and_finalize_example.cpp Init and Finalize + */ +void Finalize(); +} // namespace YR \ No newline at end of file diff --git a/api/cpp/include/yr/parallel/detail/parallel_for_local.h b/api/cpp/include/yr/parallel/detail/parallel_for_local.h index f30c680..85566ad 100644 --- a/api/cpp/include/yr/parallel/detail/parallel_for_local.h +++ b/api/cpp/include/yr/parallel/detail/parallel_for_local.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "barrier.h" #include "native_sem.h" #include "yr/api/function_handler.h" @@ -26,6 +28,8 @@ static inline int GetThreadid() return g_threadid; } +template +constexpr bool is_invocable_v = boost::callable_traits::is_invocable::value; namespace Parallel { /*! @@ -90,9 +94,9 @@ public: static constexpr bool HandlerTypeCheck() { - if constexpr (std::is_invocable_v) { + if constexpr (is_invocable_v) { return true; - } else if constexpr (std::is_invocable_v) { + } else if constexpr (is_invocable_v) { return true; } return false; @@ -101,9 +105,9 @@ public: static void CallBodyHandler(Index start, Index end, const Handler &handler, const Context &ctx) { // match the argument format of Handler - if constexpr (std::is_invocable_v) { + if constexpr (is_invocable_v) { handler(start, end); - } else if constexpr (std::is_invocable_v) { + } else if constexpr (is_invocable_v) { handler(start, end, ctx); } } diff --git a/api/cpp/include/yr/yr.h b/api/cpp/include/yr/yr.h index 5f9c54c..c557733 100644 --- a/api/cpp/include/yr/yr.h +++ b/api/cpp/include/yr/yr.h @@ -31,12 +31,16 @@ #include "yr/api/hetero_exception.h" #include "yr/api/hetero_manager.h" #include "yr/api/instance_creator.h" -#include "yr/api/kv_manager.h" +#include "yr/api/kv.h" +#include "yr/api/mutable_buffer.h" +#include "yr/api/node.h" #include "yr/api/object_ref.h" #include "yr/api/object_store.h" #include "yr/api/runtime_manager.h" #include "yr/api/serdes.h" +#include "yr/api/stream.h" #include "yr/api/wait_result.h" +#include "yr/api/yr_core.h" #include "yr/api/yr_invoke.h" extern thread_local std::unordered_set localNestedObjList; @@ -49,54 +53,8 @@ namespace YR { template using WaitResult = std::pair>, std::vector>>; -/*! - @brief YuanRong Init API, Configures runtime modes and system parameters. - Refer to the data structure documentation for parameter specifications - struct-Config . - @param conf YuanRong initialization parameter configuration. For parameter specifications, refer to - struct-Config . - @return ClientInfo Refer to struct-Config . - - @note When multi-tenancy is enabled on the YuanRong Cluster, users must configure a tenant ID. For - details on tenant ID configuration, refer to the "About Tenant ID" section in - struct-Config . - - @throws Exception The system will throw an exception when invalid config parameters are detected, such as an invalid - mode type. - - @snippet{trimleft} init_and_finalize_example.cpp Init localMode - - @snippet{trimleft} init_and_finalize_example.cpp Init clusterMode - */ -ClientInfo Init(const Config &conf); +void Run(int argc, char *argv[]); -ClientInfo Init(const Config &conf, int argc, char **argv); - -ClientInfo Init(int argc, char **argv); - -/*! - @brief Finalizes the Yuanrong system - - This function is responsible for releasing resources such as function instances - and data objects that have been created during the execution of the program. - It ensures that no resources are leaked, which could lead to issues in a - production environment. - - @note - In a cluster deployment scenario, if worker processes exit and restart, - it might lead to process residuals. In such cases, it is recommended to - redeploy the cluster. Deployment scenarios like Donau or SGE can rely on - the resource scheduling platform's capability to recycle processes. - - - This function should be called after the system has been initialized - with the appropriate Init function. Calling Finalize before Init will result - in an exception. - - @throws Exception If Finalize is called before the system is initialized, - the exception "Please init YR first" will be thrown. - - @snippet{trimleft} init_and_finalize_example.cpp Init and Finalize - */ -void Finalize(); /*! @brief Exit the current function instance @@ -273,6 +231,43 @@ void SaveState(const int &timeout = DEFAULT_SAVE_LOAD_STATE_TIMEOUT); */ void LoadState(const int &timeout = DEFAULT_SAVE_LOAD_STATE_TIMEOUT); +/** + * @brief Creates a producer. + * @param streamName The name of the stream. + * @param producerConf Configuration information for the producer. + * @return A pointer to the created producer. + * @throws Exception **4006**: not support local mode. + * + * @snippet{trimleft} stream_example.cpp create producer + */ +std::shared_ptr CreateProducer(const std::string &streamName, ProducerConf producerConf = {}); + +/** + * @brief Create a consumer. + * @param streamName The name of the stream. Must be less than 256 characters and contain only the following characters: + * (a-zA-Z0-9\~\.\-\/_!@#%\^\&\*\(\)\+\=\:;). + * @param config Configuration information for the consumer. + * @param autoAck If `autoAck` is true, the consumer will automatically send an Acknowledgment (Ack) for received + * messages to the data system. Default value: false. + * @return A pointer to the created consumer. + * @throws Exception **4006**: not support local mode. + * + * @snippet{trimleft} stream_example.cpp create consumer + */ +std::shared_ptr Subscribe(const std::string &streamName, const SubscriptionConfig &config, + bool autoAck = false); + +/** + * @brief Deletes a stream. When the global count of producers and consumers for the stream reaches zero, this stream is + * no longer in use, and all related metadata on workers and master nodes will be cleaned up. This function can be + * called on any Host node. + * @param streamName The name of the stream. Must be less than 256 characters and contain only the following characters: + * (a-zA-Z0-9\~\.\-\/_!@#%\^\&\*\(\)\+\=\:;). + * @throws Exception **4006**: not support local mode. + * + * @snippet{trimleft} stream_example.cpp delete stream + */ +void DeleteStream(const std::string &streamName); /*! @brief Create an InstanceCreator for constructing an instance of a class. @@ -572,16 +567,6 @@ FunctionHandler> JavaFunction(const std::string &classNam return FunctionHandler>(funcMeta, JavaFunctionHandler()); } -/** - * @brief Interface for key-value storage. - * @return the kv manager - */ -inline KVManager &KV() -{ - CheckInitialized(); - return KVManager::Singleton(); -} - /*! @brief Put an object to datasystem @@ -600,9 +585,14 @@ ObjectRef Put(const T &val) if (YR::internal::IsLocalMode()) { return YR::internal::GetLocalModeRuntime()->Put(val); } + std::string objId; localNestedObjList.clear(); - std::shared_ptr data = std::make_shared(YR::internal::Serialize(val)); - std::string objId = YR::internal::GetRuntime()->Put(data, localNestedObjList); + if constexpr (boost::is_same::value) { + objId = YR::internal::GetRuntime()->Put(std::make_shared(val), localNestedObjList); + } else { + std::shared_ptr data = std::make_shared(YR::internal::Serialize(val)); + objId = YR::internal::GetRuntime()->Put(data, localNestedObjList); + } localNestedObjList.clear(); return ObjectRef(objId, false); } @@ -626,38 +616,26 @@ ObjectRef Put(const T &val, const CreateParam &createParam) if (YR::internal::IsLocalMode()) { return YR::internal::GetLocalModeRuntime()->Put(val); } + std::string objId; localNestedObjList.clear(); - std::shared_ptr data = std::make_shared(YR::internal::Serialize(val)); - std::string objId = YR::internal::GetRuntime()->Put(data, localNestedObjList, createParam); + if constexpr (boost::is_same::value) { + objId = YR::internal::GetRuntime()->Put(std::make_shared(val), localNestedObjList, createParam); + } else { + std::shared_ptr data = std::make_shared(YR::internal::Serialize(val)); + objId = YR::internal::GetRuntime()->Put(data, localNestedObjList, createParam); + } localNestedObjList.clear(); return ObjectRef(objId, false); } template -std::shared_ptr Get(const ObjectRef &obj, int timeoutSec) +std::vector> Get(const std::vector &objIds, int timeoutSec, bool allowPartial) { - CheckInitialized(); - if (obj.IsLocal()) { - return internal::GetLocalModeRuntime()->Get(obj, timeoutSec); - } - auto result = Get(std::vector>{obj}, timeoutSec, false); - if (!result.empty()) { - return result[0]; - } - return nullptr; -} - -template -std::vector> Get(const std::vector> &objs, int timeoutSec, bool allowPartial) -{ - internal::CheckObjsAndTimeout(objs, timeoutSec); - if (objs[0].IsLocal()) { - return YR::internal::GetLocalModeRuntime()->Get(objs, timeoutSec, allowPartial); - } + internal::CheckTimeout(timeoutSec); std::vector remainIds; std::unordered_map> idToIndex; - for (size_t i = 0; i < objs.size(); i++) { - remainIds.push_back(objs[i].ID()); + for (size_t i = 0; i < objIds.size(); i++) { + remainIds.push_back(objIds[i]); idToIndex[remainIds[i]].push_back(i); } std::vector> returnObjects; @@ -677,7 +655,9 @@ std::vector> Get(const std::vector> &objs, int t int to = (remainTimeoutMs == NO_TIMEOUT) ? (DEFAULT_TIMEOUT_MS) : (remainTimeoutMs - static_cast(getElapsedTime())); to = to < 0 ? 0 : to; - auto [retryInfo, remainBuffers] = YR::internal::GetRuntime()->Get(remainIds, to, limitedRetryTime); + auto ret = YR::internal::GetRuntime()->Get(remainIds, to, limitedRetryTime); + auto retryInfo = ret.first; + auto remainBuffers = ret.second; auto needRetry = retryInfo.needRetry; err = retryInfo.errorInfo; internal::ExtractSuccessObjects(remainIds, remainBuffers, returnObjects, idToIndex); @@ -686,18 +666,18 @@ std::vector> Get(const std::vector> &objs, int t break; } if (!needRetry) { - status = remainIds.size() == objs.size() ? internal::GetStatus::ALL_FAILED + status = remainIds.size() == objIds.size() ? internal::GetStatus::ALL_FAILED : internal::GetStatus::PARTIAL_SUCCESS; break; } if ((remainTimeoutMs != NO_TIMEOUT && getElapsedTime() > remainTimeoutMs) || (remainTimeoutMs == 0)) { - status = remainIds.size() == objs.size() ? internal::GetStatus::ALL_FAILED_AND_TIMEOUT + status = remainIds.size() == objIds.size() ? internal::GetStatus::ALL_FAILED_AND_TIMEOUT : internal::GetStatus::PARTIAL_SUCCESS_AND_TIMEOUT; break; } std::this_thread::sleep_for(std::chrono::seconds(GET_RETRY_INTERVAL)); if ((remainTimeoutMs != NO_TIMEOUT && getElapsedTime() > remainTimeoutMs)) { - status = remainIds.size() == objs.size() ? internal::GetStatus::ALL_FAILED_AND_TIMEOUT + status = remainIds.size() == objIds.size() ? internal::GetStatus::ALL_FAILED_AND_TIMEOUT : internal::GetStatus::PARTIAL_SUCCESS_AND_TIMEOUT; break; } @@ -706,6 +686,34 @@ std::vector> Get(const std::vector> &objs, int t return returnObjects; } +template +std::shared_ptr Get(const ObjectRef &obj, int timeoutSec) +{ + CheckInitialized(); + if (obj.IsLocal()) { + return internal::GetLocalModeRuntime()->Get(obj, timeoutSec); + } + auto result = Get(std::vector{obj.ID()}, timeoutSec, false); + if (!result.empty()) { + return result[0]; + } + return nullptr; +} + +template +std::vector> Get(const std::vector> &objs, int timeoutSec, bool allowPartial) +{ + internal::CheckObjs(objs); + if (objs[0].IsLocal()) { + return YR::internal::GetLocalModeRuntime()->Get(objs, timeoutSec, allowPartial); + } + std::vector objIds; + for (auto obj: objs) { + objIds.push_back(obj.ID()); + } + return Get(objIds, timeoutSec, allowPartial); +} + template void Wait(const ObjectRef &obj, int timeoutSec) { @@ -837,8 +845,24 @@ NamedInstance GetInstance(const std::string &name, const std::string &nameSpa handler.SetAlwaysLocalMode(false); handler.SetClassName(funcMeta.className); handler.SetFunctionUrn(funcMeta.funcUrn); - handler.SetName(funcMeta.name.value_or("")); - handler.SetNs(funcMeta.ns.value_or("")); + handler.SetName(funcMeta.name); + handler.SetNs(funcMeta.ns); return handler; } + +/** + * @brief Get node information in the cluster. + * @return std::vector: node information, include id, alive, resources and labels. + * @throws Exception if Yuanrong is not initialized or failed to get node information. + */ +std::vector Nodes(); + +std::shared_ptr CreateBuffer(uint64_t size); + +std::vector> Get(const std::vector> &objs, + int timeoutSec = DEFAULT_GET_TIMEOUT_SEC); + +std::string Serialize(ObjectRef &obj); + +ObjectRef Deserialize(const void *value, int size); } // namespace YR \ No newline at end of file diff --git a/api/cpp/src/cluster_mode_runtime.cpp b/api/cpp/src/cluster_mode_runtime.cpp index ec66548..1817826 100644 --- a/api/cpp/src/cluster_mode_runtime.cpp +++ b/api/cpp/src/cluster_mode_runtime.cpp @@ -25,23 +25,41 @@ #include "api/cpp/include/yr/api/hetero_exception.h" #include "api/cpp/include/yr/api/object_store.h" #include "api/cpp/src/cluster_mode_runtime.h" + #include "api/cpp/src/code_manager.h" #include "api/cpp/src/executor/executor_holder.h" #include "api/cpp/src/hetero_future.h" #include "api/cpp/src/read_only_buffer.h" -#include "api/cpp/src/state_loader.h" +#include "api/cpp/src/runtime_env_parse.h" +#include "api/cpp/src/stream_pubsub.h" #include "api/cpp/src/utils/utils.h" +#include "datasystem_buffer.h" #include "src/dto/data_object.h" #include "src/dto/internal_wait_result.h" +#include "src/dto/stream_conf.h" #include "src/libruntime/err_type.h" #include "src/libruntime/libruntime_manager.h" +#include "src/libruntime/streamstore/stream_producer_consumer.h" #include "src/proto/libruntime.pb.h" #include "src/utility/logger/logger.h" #include "src/utility/string_utility.h" +#include "yr/api/mutable_buffer.h" namespace YR { using YR::Libruntime::DataObject; +constexpr uint8_t BYTES = 3; +constexpr uint32_t NORMAL_STATUS = 0; + +std::shared_ptr GetLibRuntime() +{ + auto librt = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(); + if (!librt) { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_FINALIZED, "already finalized"); + } + return librt; +} + internal::FuncMeta convertToInternalFuncMeta(YR::Libruntime::FunctionMeta &libFuncMeta) { internal::FuncMeta funcMeta; @@ -139,6 +157,10 @@ std::list> BuildScheduleAffinities(con libAffinity->SetPreferredPriority(preferredPriority); libAffinity->SetRequiredPriority(requiredPriority); libAffinity->SetPreferredAntiOtherLabels(preferredAntiOtherLabels); + std::string affinityScope = affinity.GetAffinityScope(); + if (!affinityScope.empty()) { + libAffinity->SetAffinityScope(affinityScope); + } libAffinities.push_back(libAffinity); } return libAffinities; @@ -208,6 +230,44 @@ YR::Libruntime::CreateParam BuildCreateParam(const YR::CreateParam &createParam) return dsCreateParam; } +YR::Libruntime::ProducerConf BuildProducerConf(ProducerConf producerConf) +{ + YR::Libruntime::ProducerConf libProducerConf; + libProducerConf.delayFlushTime = producerConf.delayFlushTime; + libProducerConf.maxStreamSize = producerConf.maxStreamSize; + libProducerConf.pageSize = producerConf.pageSize; + libProducerConf.autoCleanup = producerConf.autoCleanup; + libProducerConf.encryptStream = producerConf.encryptStream; + libProducerConf.retainForNumConsumers = producerConf.retainForNumConsumers; + libProducerConf.reserveSize = producerConf.reserveSize; + libProducerConf.extendConfig = producerConf.extendConfig; + return libProducerConf; +} + +static libruntime::SubscriptionType ConvertSubscriptionType(const SubscriptionType type) +{ + if (type == SubscriptionType::STREAM) { + return libruntime::SubscriptionType::STREAM; + } else if (type == SubscriptionType::ROUND_ROBIN) { + return libruntime::SubscriptionType::ROUND_ROBIN; + } else if (type == SubscriptionType::KEY_PARTITIONS) { + return libruntime::SubscriptionType::KEY_PARTITIONS; + } else { + return libruntime::SubscriptionType::UNKNOWN; + } + YRLOG_DEBUG("SubscriptionType not supported, lang: {}", fmt::underlying(type)); + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "SubscriptionType not supported"); +} + +YR::Libruntime::SubscriptionConfig BuildSubscriptionConfig(const SubscriptionConfig &config) +{ + YR::Libruntime::SubscriptionConfig subscriptionConfig; + subscriptionConfig.subscriptionName = config.subscriptionName; + subscriptionConfig.subscriptionType = ConvertSubscriptionType(config.subscriptionType); + subscriptionConfig.extendConfig = config.extendConfig; + return subscriptionConfig; +} + YR::Libruntime::FunctionMeta BuildFunctionMeta(const internal::FuncMeta &funcMeta) { YR::Libruntime::FunctionMeta libFunctionMeta; @@ -219,10 +279,10 @@ YR::Libruntime::FunctionMeta BuildFunctionMeta(const internal::FuncMeta &funcMet if (!funcMeta.funcUrn.empty()) { libFunctionMeta.functionId = ConvertFunctionUrnToId(funcMeta.funcUrn); } - if (funcMeta.name) { + if (!funcMeta.name.empty()) { libFunctionMeta.name = funcMeta.name; } - if (funcMeta.ns) { + if (!funcMeta.ns.empty()) { libFunctionMeta.ns = funcMeta.ns; } libFunctionMeta.apiType = libruntime::ApiType::Function; @@ -234,23 +294,38 @@ std::vector BuildInvokeArgs(std::vector libArgs; for (auto &arg : args) { YR::Libruntime::InvokeArg libArg; - auto size = arg.buf.size(); - libArg.dataObj = std::make_shared(0, size); - WriteDataObject(static_cast(arg.buf.data()), libArg.dataObj, size, {}); + if (arg.yrBuf.ImmutableData() != nullptr) { + auto size = arg.yrBuf.GetSize(); + libArg.dataObj = std::make_shared(0, size); + WriteDataObject(arg.yrBuf.ImmutableData(), libArg.dataObj, size, {}); + libArg.dataObj->SetMetaDataType(BYTES); + } else { + auto size = arg.buf.size(); + libArg.dataObj = std::make_shared(0, size); + WriteDataObject(static_cast(arg.buf.data()), libArg.dataObj, size, {}); + } libArg.isRef = arg.isRef; libArg.objId = arg.objId; libArg.nestedObjects = std::move(arg.nestedObjects); - libArg.tenantId = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GetTenantId(); + libArg.tenantId = GetLibRuntime()->GetTenantId(); libArgs.emplace_back(std::move(libArg)); } return libArgs; } +YR::Libruntime::DebugConfig BuildDebugConfig(const YR::DebugConfig &debug) +{ + YR::Libruntime::DebugConfig ret; + ret.enable = debug.enable; + return ret; +} + YR::Libruntime::InvokeOptions BuildOptions(const YR::InvokeOptions &opts) { YR::Libruntime::InvokeOptions libOpts; libOpts.affinity = opts.affinity; libOpts.retryTimes = opts.retryTimes; + libOpts.maxRetryTime = opts.maxRetryTime; if (opts.retryChecker) { libOpts.retryChecker = [checker = opts.retryChecker](const Libruntime::ErrorInfo &err) -> bool { YR::Exception e(err.Code(), err.MCode(), err.Msg()); @@ -274,7 +349,12 @@ YR::Libruntime::InvokeOptions BuildOptions(const YR::InvokeOptions &opts) libOpts.instanceRange = BuildInstanceRange(opts.instanceRange); libOpts.recoverRetryTimes = opts.recoverRetryTimes; libOpts.envVars = opts.envVars; + libOpts.debug = BuildDebugConfig(opts.debug); libOpts.timeout = opts.timeout; + libOpts.preemptedAllowed = opts.preemptedAllowed; + libOpts.instancePriority = opts.instancePriority; + libOpts.scheduleTimeoutMs = opts.scheduleTimeoutMs; + ParseRuntimeEnv(libOpts, opts.runtimeEnv); return libOpts; } @@ -370,9 +450,15 @@ void ClusterModeRuntime::Init() libConfig.localThreadPoolSize = ConfigManager::Singleton().localThreadPoolSize; libConfig.loadPaths = ConfigManager::Singleton().loadPaths; libConfig.tenantId = ConfigManager::Singleton().tenantId; + libConfig.logToDriver = ConfigManager::Singleton().logToDriver; + libConfig.dedupLogs = ConfigManager::Singleton().dedupLogs; libConfig.libruntimeOptions.functionExecuteCallback = internal::ExecuteFunction; - libConfig.libruntimeOptions.loadFunctionCallback = internal::LoadFunctions; + if (ConfigManager::Singleton().launchUserBinary == true) { + libConfig.libruntimeOptions.loadFunctionCallback = internal::LoadNoneFunctions; + } else { + libConfig.libruntimeOptions.loadFunctionCallback = internal::LoadFunctions; + } libConfig.libruntimeOptions.shutdownCallback = internal::ExecuteShutdownFunction; libConfig.libruntimeOptions.checkpointCallback = internal::Checkpoint; libConfig.libruntimeOptions.recoverCallback = internal::Recover; @@ -382,15 +468,23 @@ void ClusterModeRuntime::Init() libConfig.privateKeyPath = ConfigManager::Singleton().privateKeyPath; libConfig.certificateFilePath = ConfigManager::Singleton().certificateFilePath; libConfig.verifyFilePath = ConfigManager::Singleton().verifyFilePath; + int len = sizeof(ConfigManager::Singleton().privateKeyPaaswd); + memcpy_s(libConfig.privateKeyPaaswd, len, ConfigManager::Singleton().privateKeyPaaswd, len); + libConfig.encryptPrivateKeyPasswd = ConfigManager::Singleton().encryptPrivateKeyPasswd; } libConfig.primaryKeyStoreFile = ConfigManager::Singleton().primaryKeyStoreFile; libConfig.standbyKeyStoreFile = ConfigManager::Singleton().standbyKeyStoreFile; libConfig.encryptEnable = ConfigManager::Singleton().enableDsEncrypt; if (ConfigManager::Singleton().enableDsEncrypt) { - libConfig.runtimePublicKeyPath = ConfigManager::Singleton().runtimePublicKeyContextPath; - libConfig.runtimePrivateKeyPath = ConfigManager::Singleton().runtimePrivateKeyContextPath; - libConfig.dsPublicKeyPath = ConfigManager::Singleton().dsPublicKeyContextPath; - } + libConfig.runtimePublicKey = ConfigManager::Singleton().runtimePublicKeyContext; + libConfig.runtimePrivateKey = ConfigManager::Singleton().runtimePrivateKeyContext; + libConfig.dsPublicKey = ConfigManager::Singleton().dsPublicKeyContext; + libConfig.encryptRuntimePublicKeyContext = ConfigManager::Singleton().encryptRuntimePublicKeyContext; + libConfig.encryptRuntimePrivateKeyContext = ConfigManager::Singleton().encryptRuntimePrivateKeyContext; + libConfig.encryptDsPublicKeyContext = ConfigManager::Singleton().encryptDsPublicKeyContext; + } + libConfig.tlsContext = ConfigManager::Singleton().tlsContext; + libConfig.httpIocThreadsNum = ConfigManager::Singleton().httpIocThreadsNum; libConfig.serverName = ConfigManager::Singleton().serverName; libConfig.ns = ConfigManager::Singleton().ns; libConfig.customEnvs = ConfigManager::Singleton().customEnvs; @@ -409,7 +503,7 @@ void ClusterModeRuntime::Init() std::string ClusterModeRuntime::GetServerVersion() { - return YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GetServerVersion(); + return GetLibRuntime()->GetServerVersion(); } std::string ClusterModeRuntime::CreateInstance(const internal::FuncMeta &funcMeta, @@ -424,9 +518,8 @@ std::string ClusterModeRuntime::CreateInstance(const internal::FuncMeta &funcMet auto invokeOptions = BuildOptions(opts); YRLOG_DEBUG("create instance, function meta, name={}, language={}.", funcMeta.funcName, static_cast(funcMeta.language)); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto [err, instanceId] = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->CreateInstance( - functionMeta, invokeArgs, invokeOptions); + GetLibRuntime()->SetTenantIdWithPriority(); + auto [err, instanceId] = GetLibRuntime()->CreateInstance(functionMeta, invokeArgs, invokeOptions); if (err.Code() != Libruntime::ErrorCode::ERR_OK) { throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } @@ -446,9 +539,8 @@ std::string ClusterModeRuntime::InvokeInstance(const internal::FuncMeta &funcMet auto libFunctionMeta = BuildFunctionMeta(funcMeta); auto libArgs = BuildInvokeArgs(args); auto libOpts = BuildOptions(opts); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->InvokeByInstanceId( - libFunctionMeta, instanceId, libArgs, libOpts, returnObjs); + GetLibRuntime()->SetTenantIdWithPriority(); + auto err = GetLibRuntime()->InvokeByInstanceId(libFunctionMeta, instanceId, libArgs, libOpts, returnObjs); if (!err.OK()) { throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } @@ -468,9 +560,8 @@ std::string ClusterModeRuntime::InvokeByName(const internal::FuncMeta &funcMeta, auto libArgs = BuildInvokeArgs(args); auto libOpts = BuildOptions(opts); std::vector returnObjs{{""}}; - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->InvokeByFunctionName( - libFunctionMeta, libArgs, libOpts, returnObjs); + GetLibRuntime()->SetTenantIdWithPriority(); + auto err = GetLibRuntime()->InvokeByFunctionName(libFunctionMeta, libArgs, libOpts, returnObjs); if (!err.OK()) { throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } @@ -479,35 +570,31 @@ std::string ClusterModeRuntime::InvokeByName(const internal::FuncMeta &funcMeta, void ClusterModeRuntime::TerminateInstance(const std::string &instanceId) { - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->Kill(instanceId); + auto errInfo = GetLibRuntime()->Kill(instanceId); if (!errInfo.OK()) { YR::Exception exception(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); throw exception; } } -std::string ClusterModeRuntime::Put(std::shared_ptr data, - const std::unordered_set &nestedId) -{ - return Put(data, nestedId, {}); -} - -std::string ClusterModeRuntime::Put(std::shared_ptr data, +std::string ClusterModeRuntime::Put(const void *data, uint64_t dataSize, const std::unordered_set &nestedId, const CreateParam &createParam) { + if (data == nullptr || dataSize == 0) { + throw Exception::InvalidParamException("Put val is nullptr"); + } auto param = BuildCreateParam(createParam); auto dataObj = std::make_shared(); std::vector nestedIds(nestedId.begin(), nestedId.end()); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto [err, objId] = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->CreateDataObject( - 0, data->size(), dataObj, nestedIds, param); + GetLibRuntime()->SetTenantIdWithPriority(); + auto [err, objId] = GetLibRuntime()->CreateDataObject(0, dataSize, dataObj, nestedIds, param); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { YRLOG_DEBUG("failed to Create DataObject {}", err.Msg()); YR::Exception e2(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); throw e2; } // copy data to DataObject data - err = WriteDataObject(data->data(), dataObj, data->size(), nestedId); + err = WriteDataObject(data, dataObj, dataSize, nestedId); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { YRLOG_DEBUG("failed to WriteDataObject {}", err.Msg()); YR::Exception e2(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); @@ -516,14 +603,36 @@ std::string ClusterModeRuntime::Put(std::shared_ptr data, return objId; } +std::string ClusterModeRuntime::Put(std::shared_ptr data, + const std::unordered_set &nestedId) +{ + return Put(data, nestedId, {}); +} + +std::string ClusterModeRuntime::Put(std::shared_ptr data, + const std::unordered_set &nestedId, const CreateParam &createParam) +{ + return Put(data->data(), data->size(), nestedId, createParam); +} + +std::string ClusterModeRuntime::Put(std::shared_ptr data, const std::unordered_set &nestedId) +{ + return Put(data, nestedId, {}); +} + +std::string ClusterModeRuntime::Put(std::shared_ptr data, const std::unordered_set &nestedId, + const CreateParam &createParam) +{ + return Put(data->ImmutableData(), data->GetSize(), nestedId, createParam); +} + void ClusterModeRuntime::Put(const std::string &objId, std::shared_ptr data, const std::unordered_set &nestedId) { auto dataObj = std::make_shared(); std::vector nestedIds(nestedId.begin(), nestedId.end()); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->CreateDataObject(objId, 0, data->size(), - dataObj, nestedIds); + GetLibRuntime()->SetTenantIdWithPriority(); + auto err = GetLibRuntime()->CreateDataObject(objId, 0, data->size(), dataObj, nestedIds); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { YRLOG_DEBUG("failed to CreateDataObject {}", err.Msg()); YR::Exception e2(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); @@ -542,9 +651,8 @@ std::pair>> ClusterMode { internal::RetryInfo returnRetryInfo; returnRetryInfo.needRetry = true; - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto [retryInfo, dataObjects] = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GetDataObjectsWithoutWait(ids, timeoutMS); + GetLibRuntime()->SetTenantIdWithPriority(); + auto [retryInfo, dataObjects] = GetLibRuntime()->GetDataObjectsWithoutWait(ids, timeoutMS); std::vector> buffers; buffers.resize(dataObjects.size()); std::vector remainIds; @@ -572,9 +680,8 @@ std::pair>> ClusterMode YR::internal::WaitResult ClusterModeRuntime::Wait(const std::vector &objs, std::size_t waitNum, int timeout) { - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - std::shared_ptr internalWaitResult = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->Wait(objs, waitNum, timeout); + GetLibRuntime()->SetTenantIdWithPriority(); + std::shared_ptr internalWaitResult = GetLibRuntime()->Wait(objs, waitNum, timeout); YR::internal::WaitResult waitResult; waitResult.readyIds = internalWaitResult->readyIds; waitResult.unreadyIds = internalWaitResult->unreadyIds; @@ -591,8 +698,7 @@ YR::internal::WaitResult ClusterModeRuntime::Wait(const std::vector int64_t ClusterModeRuntime::WaitBeforeGet(const std::vector &ids, int timeoutMs, bool allowPartial) { - auto [err, remainedTimeoutMs] = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->WaitBeforeGet(ids, timeoutMs, allowPartial); + auto [err, remainedTimeoutMs] = GetLibRuntime()->WaitBeforeGet(ids, timeoutMs, allowPartial); if (!err.OK()) { throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } @@ -606,8 +712,7 @@ std::vector ClusterModeRuntime::GetInstances(const std::string &obj "invalid GetInstances timeout, timeout: " + std::to_string(timeoutSec) + ", please set the timeout >= -1."; throw YR::Exception(Libruntime::ErrorCode::ERR_PARAM_INVALID, Libruntime::ModuleCode::RUNTIME, msg); } - auto [instanceIds, err] = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GetInstances(objId, timeoutSec); + auto [instanceIds, err] = GetLibRuntime()->GetInstances(objId, timeoutSec); if (!err.OK()) { throw YR::Exception(err.Code(), err.MCode(), err.Msg()); } @@ -616,15 +721,14 @@ std::vector ClusterModeRuntime::GetInstances(const std::string &obj std::string ClusterModeRuntime::GenerateGroupName() { - return YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GenerateGroupName(); + return GetLibRuntime()->GenerateGroupName(); } -void ClusterModeRuntime::IncreGlobalReference(const std::vector &objids) +void ClusterModeRuntime::IncreGlobalReference(const std::vector &objids, bool toDatasystem) { - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); + GetLibRuntime()->SetTenantIdWithPriority(); // Here, LibRuntime return YR::Libruntime::ErrorInfo, should catch and cast type to YR::Exception. - YR::Libruntime::ErrorInfo err = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->IncreaseReference(objids); + YR::Libruntime::ErrorInfo err = GetLibRuntime()->IncreaseReference(objids, toDatasystem); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { YR::Exception e2(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); throw e2; @@ -634,25 +738,26 @@ void ClusterModeRuntime::IncreGlobalReference(const std::vector &ob void ClusterModeRuntime::DecreGlobalReference(const std::vector &objids) { if (YR::Libruntime::LibruntimeManager::Instance().IsInitialized()) { - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->DecreaseReference(objids); + GetLibRuntime()->SetTenantIdWithPriority(); + GetLibRuntime()->DecreaseReference(objids); } } void ClusterModeRuntime::KVWrite(const std::string &key, const char *value, YR::SetParam setParam) { auto dsSetParam = BuildSetParam(setParam); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTraceId(setParam.traceId); + GetLibRuntime()->SetTenantIdWithPriority(); + auto err = GetLibRuntime()->SetTraceId(setParam.traceId); if (!err.OK()) { throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } auto nativeBuffer = std::make_shared(static_cast(const_cast(value)), std::strlen(value)); - err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->KVWrite(key, nativeBuffer, dsSetParam); + err = GetLibRuntime()->KVWrite(key, nativeBuffer, dsSetParam); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("KVWrite err: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("KVWrite err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } } @@ -660,15 +765,15 @@ void ClusterModeRuntime::KVWrite(const std::string &key, const char *value, YR:: void ClusterModeRuntime::KVWrite(const std::string &key, std::shared_ptr value, YR::SetParam setParam) { auto dsSetParam = BuildSetParam(setParam); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTraceId(setParam.traceId); + GetLibRuntime()->SetTenantIdWithPriority(); + auto err = GetLibRuntime()->SetTraceId(setParam.traceId); if (!err.OK()) { throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } - err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->KVWrite( - key, std::make_shared(value), dsSetParam); + err = GetLibRuntime()->KVWrite(key, std::make_shared(value), dsSetParam); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("KVWrite err: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("KVWrite err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } } @@ -677,15 +782,15 @@ void ClusterModeRuntime::KVWrite(const std::string &key, std::shared_ptrSetTenantIdWithPriority(); - auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTraceId(setParam.traceId); + GetLibRuntime()->SetTenantIdWithPriority(); + auto err = GetLibRuntime()->SetTraceId(setParam.traceId); if (!err.OK()) { throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } - err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->KVWrite( - key, std::make_shared(value), dsSetParam); + err = GetLibRuntime()->KVWrite(key, std::make_shared(value), dsSetParam); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("KVWrite err: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("KVWrite err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } } @@ -708,23 +813,23 @@ void ClusterModeRuntime::KVMSetTx(const std::vector &keys, for (size_t i = 0; i < vals.size(); i++) { buffers[i] = std::make_shared(vals[i]); } - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - YR::Libruntime::ErrorInfo err = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->KVMSetTx(keys, buffers, dsMSetParam); + GetLibRuntime()->SetTenantIdWithPriority(); + YR::Libruntime::ErrorInfo err = GetLibRuntime()->KVMSetTx(keys, buffers, dsMSetParam); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("KVMSetTx err: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("KVMSetTx err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } } std::shared_ptr ClusterModeRuntime::KVRead(const std::string &key, int timeoutMs) { - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - YR::Libruntime::SingleReadResult result = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->KVRead(key, timeoutMs); + GetLibRuntime()->SetTenantIdWithPriority(); + YR::Libruntime::SingleReadResult result = GetLibRuntime()->KVRead(key, timeoutMs); YR::Libruntime::ErrorInfo &err = result.second; if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("KVRead err: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("KVRead err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } return std::make_shared(result.first); @@ -733,12 +838,12 @@ std::shared_ptr ClusterModeRuntime::KVRead(const std::string &key, int t std::vector> ClusterModeRuntime::KVRead(const std::vector &keys, int timeoutMs, bool allowPartial) { - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - YR::Libruntime::MultipleReadResult result = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->KVRead(keys, timeoutMs, allowPartial); + GetLibRuntime()->SetTenantIdWithPriority(); + YR::Libruntime::MultipleReadResult result = GetLibRuntime()->KVRead(keys, timeoutMs, allowPartial); YR::Libruntime::ErrorInfo &err = result.second; if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("KVRead err: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("KVRead err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } std::vector> buffers; @@ -756,63 +861,160 @@ std::vector> ClusterModeRuntime::KVGetWithParam(const st const YR::GetParams ¶ms, int timeoutMs) { auto dsParams = BuildGetParam(params); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto res = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTraceId(params.traceId); + GetLibRuntime()->SetTenantIdWithPriority(); + auto res = GetLibRuntime()->SetTraceId(params.traceId); if (res.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("Set trace id err: Code:{}, MCode:{}, Msg:{}", res.Code(), res.MCode(), res.Msg()); + YRLOG_ERROR("Set trace id err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(res.Code()), + fmt::underlying(res.MCode()), res.Msg()); throw YR::Exception(static_cast(res.Code()), static_cast(res.MCode()), res.Msg()); } - YR::Libruntime::MultipleReadResult result = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->KVGetWithParam(keys, dsParams, timeoutMs); + YR::Libruntime::MultipleReadResult result = GetLibRuntime()->KVGetWithParam(keys, dsParams, timeoutMs); YR::Libruntime::ErrorInfo &err = result.second; if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("KVGetWithParam err: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("KVGetWithParam err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } std::vector> buffers; buffers.resize(result.first.size()); for (size_t i = 0; i < result.first.size(); i++) { - if (result.first[i] != nullptr) { - buffers[i] = std::make_shared(result.first[i]); + if (result.first[i] == nullptr) { + continue; } + buffers[i] = std::make_shared(result.first[i]); } return buffers; } void ClusterModeRuntime::KVDel(const std::string &key, const DelParam &delParam) { - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTraceId(delParam.traceId); + GetLibRuntime()->SetTenantIdWithPriority(); + auto err = GetLibRuntime()->SetTraceId(delParam.traceId); if (!err.OK()) { throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } - err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->KVDel(key); + err = GetLibRuntime()->KVDel(key); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("KVDel err: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("KVDel err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } } std::vector ClusterModeRuntime::KVDel(const std::vector &keys, const DelParam &delParam) { - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTenantIdWithPriority(); - auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTraceId(delParam.traceId); + GetLibRuntime()->SetTenantIdWithPriority(); + auto err = GetLibRuntime()->SetTraceId(delParam.traceId); if (!err.OK()) { throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } - YR::Libruntime::MultipleDelResult result = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->KVDel(keys); + YR::Libruntime::MultipleDelResult result = GetLibRuntime()->KVDel(keys); err = result.second; if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("KVDel err: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("KVDel err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } + return result.first; +} + +std::vector ClusterModeRuntime::KVExist(const std::vector &keys) +{ + GetLibRuntime()->SetTenantIdWithPriority(); + YR::Libruntime::MultipleExistResult result = GetLibRuntime()->KVExist(keys); + auto err = result.second; + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("KVExist err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } return result.first; } +std::shared_ptr ClusterModeRuntime::CreateStreamProducer(const std::string &streamName, + ProducerConf producerConf) +{ + GetLibRuntime()->SetTenantIdWithPriority(); + auto err = GetLibRuntime()->SetTraceId(producerConf.traceId); + if (!err.OK()) { + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } + auto libProducerConf = BuildProducerConf(producerConf); + std::shared_ptr streamProducer; + err = GetLibRuntime()->CreateStreamProducer(streamName, libProducerConf, streamProducer); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("CreateStreamProducer err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } + std::shared_ptr producer = std::make_shared(streamProducer, producerConf.traceId); + return producer; +} + +std::shared_ptr ClusterModeRuntime::CreateStreamConsumer(const std::string &streamName, + const SubscriptionConfig &config, bool autoAck) +{ + auto subscriptionConfig = BuildSubscriptionConfig(config); + std::shared_ptr streamConsumer; + GetLibRuntime()->SetTenantIdWithPriority(); + auto err = GetLibRuntime()->SetTraceId(config.traceId); + if (!err.OK()) { + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } + err = GetLibRuntime()->CreateStreamConsumer(streamName, subscriptionConfig, streamConsumer, autoAck); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("CreateStreamConsumer err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } + std::shared_ptr consumer = std::make_shared(streamConsumer, config.traceId); + return consumer; +} + +void ClusterModeRuntime::DeleteStream(const std::string &streamName) +{ + GetLibRuntime()->SetTenantIdWithPriority(); + YR::Libruntime::ErrorInfo err = GetLibRuntime()->DeleteStream(streamName); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("DeleteStream err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } +} + +void ClusterModeRuntime::QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum) +{ + if (ConfigManager::Singleton().IsLocalMode()) { + throw YR::Exception(Libruntime::ErrorCode::ERR_INCORRECT_FUNCTION_USAGE, Libruntime::ModuleCode::RUNTIME, + "local mode does not support QueryGlobalProducersNum\n"); + } + GetLibRuntime()->SetTenantIdWithPriority(); + YR::Libruntime::ErrorInfo err = GetLibRuntime()->QueryGlobalProducersNum(streamName, gProducerNum); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("QueryGlobalProducersNum err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } +} + +void ClusterModeRuntime::QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum) +{ + if (ConfigManager::Singleton().IsLocalMode()) { + throw YR::Exception(Libruntime::ErrorCode::ERR_INCORRECT_FUNCTION_USAGE, Libruntime::ModuleCode::RUNTIME, + "local mode does not support QueryGlobalConsumersNum\n"); + } + GetLibRuntime()->SetTenantIdWithPriority(); + YR::Libruntime::ErrorInfo err = GetLibRuntime()->QueryGlobalConsumersNum(streamName, gConsumerNum); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("QueryGlobalConsumersNum err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } +} + std::string ClusterModeRuntime::GetRealInstanceId(const std::string &objectId) { - return YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GetRealInstanceId(objectId); + return GetLibRuntime()->GetRealInstanceId(objectId); } void ClusterModeRuntime::SaveRealInstanceId(const std::string &objectId, const std::string &instanceId, @@ -820,12 +1022,12 @@ void ClusterModeRuntime::SaveRealInstanceId(const std::string &objectId, const s { YR::Libruntime::InstanceOptions instOpts; instOpts.needOrder = opts.needOrder; - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SaveRealInstanceId(objectId, instanceId, instOpts); + GetLibRuntime()->SaveRealInstanceId(objectId, instanceId, instOpts); } std::string ClusterModeRuntime::GetGroupInstanceIds(const std::string &objectId) { - return YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GetGroupInstanceIds(objectId, NO_TIMEOUT); + return GetLibRuntime()->GetGroupInstanceIds(objectId, NO_TIMEOUT); } void ClusterModeRuntime::SaveGroupInstanceIds(const std::string &objectId, const std::string &groupInsIds, @@ -833,23 +1035,22 @@ void ClusterModeRuntime::SaveGroupInstanceIds(const std::string &objectId, const { YR::Libruntime::InstanceOptions instOpts; instOpts.needOrder = opts.needOrder; - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SaveGroupInstanceIds(objectId, groupInsIds, - instOpts); + GetLibRuntime()->SaveGroupInstanceIds(objectId, groupInsIds, instOpts); } void ClusterModeRuntime::Cancel(const std::vector &objs, bool isForce, bool isRecursive) { - YR::Libruntime::ErrorInfo err = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->Cancel(objs, isForce, isRecursive); + YR::Libruntime::ErrorInfo err = GetLibRuntime()->Cancel(objs, isForce, isRecursive); if (!err.OK()) { - YRLOG_DEBUG("Cancel err: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_DEBUG("Cancel err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } } void ClusterModeRuntime::Exit(void) { - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->Exit(); + GetLibRuntime()->Exit(); } void ClusterModeRuntime::StopRuntime(void) @@ -872,7 +1073,7 @@ void ClusterModeRuntime::GroupCreate(const std::string &name, GroupOptions &opts libOpts.timeout = opts.timeout; libOpts.groupName = name; libOpts.sameLifecycle = opts.sameLifecycle; - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GroupCreate(name, libOpts); + auto errInfo = GetLibRuntime()->GroupCreate(name, libOpts); if (!errInfo.OK()) { throw YR::Exception(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); } @@ -880,12 +1081,12 @@ void ClusterModeRuntime::GroupCreate(const std::string &name, GroupOptions &opts void ClusterModeRuntime::GroupTerminate(const std::string &name) { - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GroupTerminate(name); + GetLibRuntime()->GroupTerminate(name); } void ClusterModeRuntime::GroupWait(const std::string &name) { - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GroupWait(name); + auto errInfo = GetLibRuntime()->GroupWait(name); if (!errInfo.OK()) { throw YR::Exception(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); } @@ -899,7 +1100,7 @@ void ClusterModeRuntime::SaveState(const int &timeout) throw YR::Exception(static_cast(dumpErr.Code()), static_cast(dumpErr.MCode()), dumpErr.Msg()); } int timeoutMS = timeout != NO_TIMEOUT ? timeout * S_TO_MS : NO_TIMEOUT; - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SaveState(data, timeoutMS); + auto errInfo = GetLibRuntime()->SaveState(data, timeoutMS); if (!errInfo.OK()) { throw YR::Exception(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); } @@ -909,7 +1110,7 @@ void ClusterModeRuntime::LoadState(const int &timeout) { std::shared_ptr data; int timeoutMS = timeout != NO_TIMEOUT ? timeout * S_TO_MS : NO_TIMEOUT; - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->LoadState(data, timeoutMS); + auto errInfo = GetLibRuntime()->LoadState(data, timeoutMS); if (!errInfo.OK()) { throw YR::Exception(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); } @@ -919,20 +1120,19 @@ void ClusterModeRuntime::LoadState(const int &timeout) } } -void ClusterModeRuntime::Delete(const std::vector &objectIds, std::vector &failedObjectIds) +void ClusterModeRuntime::DevDelete(const std::vector &objectIds, std::vector &failedObjectIds) { - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->Delete(objectIds, failedObjectIds); + auto errInfo = GetLibRuntime()->DevDelete(objectIds, failedObjectIds); if (!errInfo.OK()) { throw YR::HeteroException(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg(), failedObjectIds); } } -void ClusterModeRuntime::LocalDelete(const std::vector &objectIds, - std::vector &failedObjectIds) +void ClusterModeRuntime::DevLocalDelete(const std::vector &objectIds, + std::vector &failedObjectIds) { - auto errInfo = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->LocalDelete(objectIds, failedObjectIds); + auto errInfo = GetLibRuntime()->DevLocalDelete(objectIds, failedObjectIds); if (!errInfo.OK()) { throw YR::HeteroException(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg(), failedObjectIds); @@ -945,8 +1145,7 @@ void ClusterModeRuntime::DevSubscribe(const std::vector &keys, { auto libDevBlobList = BuildLibDeviceBlobList(blob2dList); std::vector> libHeteroFutureVec; - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->DevSubscribe(keys, libDevBlobList, - libHeteroFutureVec); + auto errInfo = GetLibRuntime()->DevSubscribe(keys, libDevBlobList, libHeteroFutureVec); if (!errInfo.OK()) { throw YR::HeteroException(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); } @@ -960,8 +1159,7 @@ void ClusterModeRuntime::DevPublish(const std::vector &keys, const { auto libDevBlobList = BuildLibDeviceBlobList(blob2dList); std::vector> libHeteroFutureVec; - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->DevPublish(keys, libDevBlobList, - libHeteroFutureVec); + auto errInfo = GetLibRuntime()->DevPublish(keys, libDevBlobList, libHeteroFutureVec); if (!errInfo.OK()) { throw YR::HeteroException(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); } @@ -974,8 +1172,7 @@ void ClusterModeRuntime::DevMSet(const std::vector &keys, const std std::vector &failedKeys) { auto libDevBlobList = BuildLibDeviceBlobList(blob2dList); - auto errInfo = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->DevMSet(keys, libDevBlobList, failedKeys); + auto errInfo = GetLibRuntime()->DevMSet(keys, libDevBlobList, failedKeys); if (!errInfo.OK()) { throw YR::HeteroException(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg(), failedKeys); @@ -986,8 +1183,7 @@ void ClusterModeRuntime::DevMGet(const std::vector &keys, const std std::vector &failedKeys, int32_t timeoutSec) { auto libDevBlobList = BuildLibDeviceBlobList(blob2dList); - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->DevMGet(keys, libDevBlobList, - failedKeys, timeoutSec); + auto errInfo = GetLibRuntime()->DevMGet(keys, libDevBlobList, failedKeys, timeoutSec); if (!errInfo.OK()) { throw YR::HeteroException(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg(), failedKeys); @@ -997,8 +1193,7 @@ void ClusterModeRuntime::DevMGet(const std::vector &keys, const std internal::FuncMeta ClusterModeRuntime::GetInstance(const std::string &name, const std::string &nameSpace, int timeoutSec) { - auto [funcMeta, errInfo] = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GetInstance(name, nameSpace, timeoutSec); + auto [funcMeta, errInfo] = GetLibRuntime()->GetInstance(name, nameSpace, timeoutSec); if (!errInfo.OK()) { throw YR::Exception(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); } @@ -1007,21 +1202,88 @@ internal::FuncMeta ClusterModeRuntime::GetInstance(const std::string &name, cons std::string ClusterModeRuntime::GetInstanceRoute(const std::string &objectId) { - return YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->GetInstanceRoute(objectId); + return GetLibRuntime()->GetInstanceRoute(objectId); } void ClusterModeRuntime::SaveInstanceRoute(const std::string &objectId, const std::string &instanceRoute) { - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SaveInstanceRoute(objectId, instanceRoute); + GetLibRuntime()->SaveInstanceRoute(objectId, instanceRoute); } void ClusterModeRuntime::TerminateInstanceSync(const std::string &instanceId) { - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->Kill( - instanceId, libruntime::Signal::killInstanceSync); + auto errInfo = GetLibRuntime()->Kill(instanceId, libruntime::Signal::killInstanceSync); if (!errInfo.OK()) { - YR::Exception exception(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); - throw exception; + throw YR::Exception(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); + } +} + +std::vector ClusterModeRuntime::Nodes() +{ + auto [errInfo, resourceUnitVector] = GetLibRuntime()->GetResources(); + if (!errInfo.OK()) { + throw YR::Exception(static_cast(errInfo.Code()), static_cast(errInfo.MCode()), errInfo.Msg()); + } + std::vector nodes; + for (const auto &resourceUnit : resourceUnitVector) { + Node node{ + .id = resourceUnit.id, + .alive = (resourceUnit.status == NORMAL_STATUS), + .resources = resourceUnit.capacity, + .labels = resourceUnit.nodeLabels, + }; + nodes.push_back(node); } + return nodes; } + +std::shared_ptr ClusterModeRuntime::CreateMutableBuffer(uint64_t size) +{ + std::shared_ptr buf; + auto res = GetLibRuntime()->CreateBuffer(size, buf); + if (!res.first.OK()) { + throw YR::Exception(static_cast(res.first.Code()), static_cast(res.first.MCode()), res.first.Msg()); + } + return std::make_shared(res.second, buf); +} + +std::vector> ClusterModeRuntime::GetMutableBuffer(const std::vector &ids, + int timeoutSec) +{ + auto timeoutMs = timeoutSec * S_TO_MS; + auto res = GetLibRuntime()->GetBuffers(ids, timeoutMs, false); + if (!res.first.OK() || ids.size() != res.second.size()) { + throw YR::Exception(static_cast(res.first.Code()), static_cast(res.first.MCode()), res.first.Msg()); + } + std::vector> buffers; + buffers.resize(res.second.size()); + for (size_t i = 0; i < res.second.size(); i++) { + if (!res.second[i]) { + continue; + } + buffers[i] = std::make_shared(ids[i], res.second[i]); + } + return buffers; +} + +std::shared_future ClusterModeRuntime::TerminateInstanceAsync(const std::string &instanceId, bool isSync) +{ + int sigNo = libruntime::Signal::KillInstance; + if (isSync) { + sigNo = libruntime::Signal::killInstanceSync; + } + auto promise = std::make_shared>(); + auto f = promise->get_future().share(); + auto cb = [promise](const YR::Libruntime::ErrorInfo &err) { + if (err.OK()) { + promise->set_value(); + } else { + promise->set_exception(std::make_exception_ptr( + YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()))); + } + }; + GetLibRuntime()->KillAsync(instanceId, sigNo, cb); + return f; +} + } // namespace YR diff --git a/api/cpp/src/cluster_mode_runtime.h b/api/cpp/src/cluster_mode_runtime.h index 3c9d6ae..d25d0bd 100644 --- a/api/cpp/src/cluster_mode_runtime.h +++ b/api/cpp/src/cluster_mode_runtime.h @@ -21,10 +21,10 @@ #include #include "config_manager.h" +#include "src/libruntime/libruntime_config.h" #include "src/libruntime/libruntime_options.h" #include "yr/api/runtime.h" #include "yr/api/wait_result.h" - namespace YR { YR::Libruntime::InvokeOptions BuildOptions(const YR::InvokeOptions &opts); class ClusterModeRuntime : public Runtime { @@ -38,9 +38,14 @@ public: // return objid std::string Put(std::shared_ptr data, const std::unordered_set &nestedId); + std::string Put(std::shared_ptr data, const std::unordered_set &nestedId); + std::string Put(std::shared_ptr data, const std::unordered_set &nestedId, const CreateParam &createParam); + std::string Put(std::shared_ptr data, const std::unordered_set &nestedId, + const CreateParam &createParam); + void Put(const std::string &objId, std::shared_ptr data, const std::unordered_set &nestedId); @@ -69,6 +74,19 @@ public: std::vector KVDel(const std::vector &keys, const DelParam &delParam = {}); + std::vector KVExist(const std::vector &keys); + + std::shared_ptr CreateStreamProducer(const std::string &streamName, ProducerConf producerConf); + + std::shared_ptr CreateStreamConsumer(const std::string &streamName, const SubscriptionConfig &config, + bool autoAck = false); + + void DeleteStream(const std::string &streamName); + + void QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum); + + void QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum); + std::string InvokeByName(const internal::FuncMeta &funcMeta, std::vector &args, const InvokeOptions &opt) override; @@ -97,7 +115,7 @@ public: std::string GenerateGroupName() override; // throw YR::Libruntime::Exception - void IncreGlobalReference(const std::vector &objids); + void IncreGlobalReference(const std::vector &objids, bool toDatasystem = true); void DecreGlobalReference(const std::vector &objids); @@ -119,9 +137,9 @@ public: void LoadState(const int &timeout); - void Delete(const std::vector &objectIds, std::vector &failedObjectIds) override; + void DevDelete(const std::vector &objectIds, std::vector &failedObjectIds) override; - void LocalDelete(const std::vector &objectIds, std::vector &failedObjectIds) override; + void DevLocalDelete(const std::vector &objectIds, std::vector &failedObjectIds) override; void DevSubscribe(const std::vector &keys, const std::vector &blob2dList, std::vector> &futureVec) override; @@ -136,7 +154,7 @@ public: std::vector &failedKeys, int32_t timeout) override; internal::FuncMeta GetInstance(const std::string &name, const std::string &nameSpace, int timeoutSec) override; - + std::string GetGroupInstanceIds(const std::string &objectId); void SaveGroupInstanceIds(const std::string &objectId, const std::string &groupInsIds, const InvokeOptions &opts); @@ -146,5 +164,18 @@ public: void SaveInstanceRoute(const std::string &objectId, const std::string &instanceRoute); void TerminateInstanceSync(const std::string &instanceId); + + std::vector Nodes(); + + std::shared_ptr CreateMutableBuffer(uint64_t size) override; + + std::vector> GetMutableBuffer(const std::vector &ids, + int timeout) override; + + std::shared_future TerminateInstanceAsync(const std::string &instanceId, bool isSync) override; + +private: + std::string Put(const void *data, uint64_t dataSize, const std::unordered_set &nestedId, + const CreateParam &createParam); }; } // namespace YR diff --git a/api/cpp/src/config_manager.cpp b/api/cpp/src/config_manager.cpp index eeaf873..78beb31 100644 --- a/api/cpp/src/config_manager.cpp +++ b/api/cpp/src/config_manager.cpp @@ -49,6 +49,7 @@ ABSL_FLAG(std::vector, codePath, std::vector(), namespace { const int MAX_THREADPOOL_SIZE = 64; const int MIN_THREADPOOL_SIZE = 1; +const int MAX_LOCAL_THREAD_POOL_SIZE = 300; } // namespace namespace YR { @@ -98,6 +99,18 @@ bool GetValidLogCompress(bool logCompress) return (logCompress = YR::Libruntime::Config::Instance().YR_LOG_COMPRESS()); } +int GetValidLocalThreadPoolSize(int threadPoolSize) +{ + if (threadPoolSize > MAX_LOCAL_THREAD_POOL_SIZE || threadPoolSize < MIN_THREADPOOL_SIZE) { + // default is the number of CPUs + std::cerr << "Config localThreadPoolSize is invalid; the valid range is " << MIN_THREADPOOL_SIZE << " to " + << MAX_LOCAL_THREAD_POOL_SIZE << "; set to core number " + << static_cast(std::thread::hardware_concurrency()) << " by default" << std::endl; + return static_cast(std::thread::hardware_concurrency()); + } + return threadPoolSize; +} + int GetValidThreadPoolSize(int threadPoolSize) { if (threadPoolSize > MAX_THREADPOOL_SIZE || threadPoolSize < MIN_THREADPOOL_SIZE) { @@ -133,7 +146,10 @@ ClientInfo ConfigManager::GetClientInfo() } void ConfigManager::ClearPasswd() -{} +{ + memset_s(privateKeyPaaswd, MAX_PASSWD_LENGTH, 0, MAX_PASSWD_LENGTH); + encryptPrivateKeyPasswd.clear(); +} void ConfigManager::Init(const Config &conf, int argc, char **argv) { @@ -151,9 +167,15 @@ void ConfigManager::Init(const Config &conf, int argc, char **argv) this->privateKeyPath = conf.privateKeyPath; this->certificateFilePath = conf.certificateFilePath; this->verifyFilePath = conf.verifyFilePath; + int len = sizeof(conf.privateKeyPaaswd); + memcpy_s(this->privateKeyPaaswd, len, conf.privateKeyPaaswd, len); + this->encryptPrivateKeyPasswd = conf.encryptPrivateKeyPasswd; } this->primaryKeyStoreFile = conf.primaryKeyStoreFile; this->standbyKeyStoreFile = conf.standbyKeyStoreFile; + this->tlsContext = conf.tlsContext; + this->inCluster = conf.inCluster; + this->httpIocThreadsNum = conf.httpIocThreadsNum; this->serverName = conf.serverName; this->isDriver = conf.isDriver; this->isLowReliabilityTask = conf.isLowReliabilityTask; @@ -176,9 +198,12 @@ void ConfigManager::Init(const Config &conf, int argc, char **argv) } this->enableDsEncrypt = conf.enableDsEncrypt; if (conf.enableDsEncrypt) { - this->dsPublicKeyContextPath = conf.dsPublicKeyContextPath; - this->runtimePublicKeyContextPath = conf.runtimePublicKeyContextPath; - this->runtimePrivateKeyContextPath = conf.runtimePrivateKeyContextPath; + this->dsPublicKeyContext = conf.dsPublicKeyContext; + this->runtimePublicKeyContext = conf.runtimePublicKeyContext; + this->runtimePrivateKeyContext = conf.runtimePrivateKeyContext; + this->encryptDsPublicKeyContext = conf.encryptDsPublicKeyContext; + this->encryptRuntimePublicKeyContext = conf.encryptRuntimePublicKeyContext; + this->encryptRuntimePrivateKeyContext = conf.encryptRuntimePrivateKeyContext; } if (conf.threadPoolSize > 0) { @@ -187,7 +212,7 @@ void ConfigManager::Init(const Config &conf, int argc, char **argv) this->localThreadPoolSize = conf.localThreadPoolSize; if (conf.localThreadPoolSize > 0) { - this->localThreadPoolSize = static_cast(GetValidThreadPoolSize(conf.localThreadPoolSize)); + this->localThreadPoolSize = static_cast(GetValidLocalThreadPoolSize(conf.localThreadPoolSize)); } this->defaultGetTimeoutSec = conf.defaultGetTimeoutSec; @@ -242,6 +267,7 @@ void ConfigManager::Init(const Config &conf, int argc, char **argv) this->logCompress = GetValidLogCompress(conf.logCompress); this->maxLogFileNum = conf.maxLogFileNum; this->maxLogFileSize = conf.maxLogSizeMb; + this->tenantId = conf.tenantId; if (argc != 0 && argv != nullptr) { absl::ParseCommandLine(argc, argv); this->logFlushInterval = absl::GetFlag(FLAGS_logFlushInterval); @@ -270,20 +296,22 @@ void ConfigManager::Init(const Config &conf, int argc, char **argv) const auto &codePathsValue = absl::GetFlag(FLAGS_codePath); this->loadPaths.insert(this->loadPaths.end(), codePathsValue.begin(), codePathsValue.end()); } - - // parse the info with auto init - YR::Libruntime::ClusterAccessInfo info{ - .serverAddr = this->functionSystemAddr, - .dsAddr = this->dataSystemAddr, - .inCluster = this->inCluster, - }; - info = YR::Libruntime::AutoGetClusterAccessInfo(info); - this->functionSystemAddr = info.serverAddr; // leading protocol will be trimmed, the value would never change - this->dataSystemAddr = info.dsAddr; // changes when this is empty - this->inCluster = info.inCluster; // changes only when read from masterinfo, - // or a protocol is specified in functionSystemAddr + if (conf.mode == Config::CLUSTER_MODE) { + // parse the info with auto init + YR::Libruntime::ClusterAccessInfo info{ + .serverAddr = this->functionSystemAddr, + .dsAddr = this->dataSystemAddr, + .inCluster = this->inCluster, + }; + info = YR::Libruntime::AutoGetClusterAccessInfo(info); + this->functionSystemAddr = info.serverAddr; // leading protocol will be trimmed, the value would never change + this->dataSystemAddr = info.dsAddr; // changes when this is empty + this->inCluster = info.inCluster; // changes only when read from masterinfo, + // or a protocol is specified in functionSystemAddr + } this->customEnvs = conf.customEnvs; + this->launchUserBinary = conf.launchUserBinary; } bool ConfigManager::IsLocalMode() const diff --git a/api/cpp/src/config_manager.h b/api/cpp/src/config_manager.h index 8a56560..8b5cd48 100644 --- a/api/cpp/src/config_manager.h +++ b/api/cpp/src/config_manager.h @@ -78,6 +78,8 @@ public: int maxTaskInstanceNum; + std::string autoFunctionName; + std::string functionId; std::string functionIdPython; @@ -96,20 +98,34 @@ public: std::string verifyFilePath = ""; + char privateKeyPaaswd[MAX_PASSWD_LENGTH] = {0}; + + std::string encryptPrivateKeyPasswd; + bool enableDsAuth = false; bool enableDsEncrypt = false; - std::string dsPublicKeyContextPath = ""; + std::string dsPublicKeyContext = ""; + + std::string encryptDsPublicKeyContext; + + std::string runtimePublicKeyContext = ""; - std::string runtimePublicKeyContextPath = ""; + std::string encryptRuntimePublicKeyContext; - std::string runtimePrivateKeyContextPath = ""; + std::string runtimePrivateKeyContext = ""; + + std::string encryptRuntimePrivateKeyContext; std::string primaryKeyStoreFile; std::string standbyKeyStoreFile; + std::shared_ptr tlsContext; + + uint32_t httpIocThreadsNum; + std::string serverName = ""; std::string ns = ""; @@ -127,5 +143,7 @@ public: bool logToDriver = false; bool dedupLogs = true; + + bool launchUserBinary = false; }; } // namespace YR diff --git a/api/cpp/src/datasystem_buffer.cpp b/api/cpp/src/datasystem_buffer.cpp new file mode 100644 index 0000000..8242e8d --- /dev/null +++ b/api/cpp/src/datasystem_buffer.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "datasystem_buffer.h" + +#include "yr/api/err_type.h" +#include "yr/api/exception.h" +#include "yr/api/object_ref.h" + +namespace YR { +DataSystemBuffer::DataSystemBuffer(std::string objId, std::shared_ptr buffer) + : objId_(std::move(objId)), buffer_(std::move(buffer)) +{ +} + +void *DataSystemBuffer::MutableData() +{ + return buffer_->MutableData(); +} + +ObjectRef DataSystemBuffer::Publish() +{ + if (const YR::Libruntime::ErrorInfo err = buffer_->Publish(); !err.OK()) { + throw YR::Exception(err.Code(), YR::ModuleCode::DATASYSTEM_, "MutableBuffer->Publish() fail"); + } + return ObjectRef(objId_); +} + +int64_t DataSystemBuffer::GetSize() +{ + return buffer_->GetSize(); +} +} // namespace YR diff --git a/api/cpp/src/datasystem_buffer.h b/api/cpp/src/datasystem_buffer.h new file mode 100644 index 0000000..e957d4a --- /dev/null +++ b/api/cpp/src/datasystem_buffer.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/dto/buffer.h" +#include "yr/api/mutable_buffer.h" + +namespace YR { +class DataSystemBuffer : public MutableBuffer { +public: + DataSystemBuffer(std::string objId, std::shared_ptr buffer); + + void *MutableData() override; + + ObjectRef Publish() override; + + int64_t GetSize() override; + +private: + std::string objId_; + std::shared_ptr buffer_; +}; +} // namespace YR \ No newline at end of file diff --git a/api/cpp/src/executor/executor_holder.cpp b/api/cpp/src/executor/executor_holder.cpp index 17df6cb..21b95e8 100644 --- a/api/cpp/src/executor/executor_holder.cpp +++ b/api/cpp/src/executor/executor_holder.cpp @@ -45,6 +45,11 @@ Libruntime::ErrorInfo LoadFunctions(const std::vector &paths) return ExecutorHolder::Singleton().GetExecutor()->LoadFunctions(paths); } +Libruntime::ErrorInfo LoadNoneFunctions(const std::vector &paths) +{ + return Libruntime::ErrorInfo(); +} + Libruntime::ErrorInfo ExecuteFunction(const YR::Libruntime::FunctionMeta &function, const libruntime::InvokeType invokeType, const std::vector> &rawArgs, diff --git a/api/cpp/src/executor/executor_holder.h b/api/cpp/src/executor/executor_holder.h index b3202f0..64379f3 100644 --- a/api/cpp/src/executor/executor_holder.h +++ b/api/cpp/src/executor/executor_holder.h @@ -34,6 +34,8 @@ private: Libruntime::ErrorInfo LoadFunctions(const std::vector &paths); +Libruntime::ErrorInfo LoadNoneFunctions(const std::vector &paths); + Libruntime::ErrorInfo ExecuteFunction(const YR::Libruntime::FunctionMeta &function, const libruntime::InvokeType invokeType, const std::vector> &rawArgs, diff --git a/api/cpp/src/faas/context_env.cpp b/api/cpp/src/faas/context_env.cpp new file mode 100644 index 0000000..6125d11 --- /dev/null +++ b/api/cpp/src/faas/context_env.cpp @@ -0,0 +1,109 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include "api/cpp/src/faas/context_env.h" + +namespace Function { + +int ContextEnv::GetRunningTimeInSeconds() const +{ + return runningTimeInSeconds_; +} + +int ContextEnv::GetCPUNumber() const +{ + return cpuNumber_; +} +int ContextEnv::GetMemorySize() const +{ + return memorySize_; +} + +const std::string ContextEnv::GetInstanceLabel() const +{ + return instanceLabel_; +} + +const std::string ContextEnv::GetUserData(std::string key) const +{ + auto userDataIt = userData_.find(key); + if (userDataIt != userData_.end()) { + return userDataIt->second; + } + return ""; +} + +const std::string ContextEnv::GetFuncPackage() const +{ + return funcPackage_; +} + +const std::string ContextEnv::GetFunctionName() const +{ + return functionName_; +} + +const std::string ContextEnv::GetVersion() const +{ + return version_; +} + +const std::string ContextEnv::GetProjectID() const +{ + return projectID_; +} + +void ContextEnv::SetUserData(std::unordered_map &funcKey) +{ + std::swap(funcKey, userData_); +} + +void ContextEnv::SetRunningTimeInSeconds(int runningTime) +{ + runningTimeInSeconds_ = runningTime; +} +void ContextEnv::SetCPUNumber(int cpuNum) +{ + cpuNumber_ = cpuNum; +} +void ContextEnv::SetMemorySize(int memorySz) +{ + memorySize_ = memorySz; +} + +void ContextEnv::SetInstanceLabel(const std::string &instanceLabel) +{ + instanceLabel_ = instanceLabel; +} + +void ContextEnv::SetFuncPackage(const std::string &package) +{ + funcPackage_ = package; +} +void ContextEnv::SetFunctionName(const std::string &funcName) +{ + functionName_ = funcName; +} +void ContextEnv::SetVersion(const std::string &funcVersion) +{ + version_ = funcVersion; +} +void ContextEnv::SetProjectID(const std::string &proId) +{ + projectID_ = proId; +} + +} // namespace Function \ No newline at end of file diff --git a/api/cpp/src/faas/context_env.h b/api/cpp/src/faas/context_env.h new file mode 100644 index 0000000..f3dcfa9 --- /dev/null +++ b/api/cpp/src/faas/context_env.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +namespace Function { +class ContextEnv { +public: + int GetRunningTimeInSeconds() const; + int GetCPUNumber() const; + int GetMemorySize() const; + const std::string GetInstanceLabel() const; + const std::string GetUserData(std::string key) const; + const std::string GetFuncPackage() const; + const std::string GetFunctionName() const; + const std::string GetVersion() const; + const std::string GetProjectID() const; + + void SetRunningTimeInSeconds(int runningTimeInSeconds); + void SetCPUNumber(int cpuNumber); + void SetMemorySize(int memorySize); + void SetInstanceLabel(const std::string &instanceLabel); + void SetUserData(std::unordered_map &userData); + void SetFuncPackage(const std::string &package); + void SetFunctionName(const std::string &funcName); + void SetVersion(const std::string &version); + void SetProjectID(const std::string &projectID); + +private: + int runningTimeInSeconds_ = 0; + int cpuNumber_ = 0; + int memorySize_ = 0; + std::string instanceLabel_; + std::unordered_map userData_; + std::string funcPackage_; + std::string functionName_; + std::string version_; + std::string projectID_; +}; +} // namespace Function diff --git a/api/cpp/src/faas/context_impl.cpp b/api/cpp/src/faas/context_impl.cpp new file mode 100644 index 0000000..27f1b84 --- /dev/null +++ b/api/cpp/src/faas/context_impl.cpp @@ -0,0 +1,186 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "api/cpp/src/faas/context_impl.h" + +namespace Function { + +ContextImpl::ContextImpl(std::shared_ptr &invokeParams, std::shared_ptr &contextEnv) +{ + this->contextInvokeParams_ = invokeParams; + this->contextEnv_ = contextEnv; +} + +ContextImpl::ContextImpl(const ContextImpl &contextImpl) +{ + this->stateId_ = contextImpl.stateId_; + this->state_ = contextImpl.state_; + this->property_ = contextImpl.property_; + this->funcStartTime_ = contextImpl.funcStartTime_; + this->contextInvokeParams_ = contextImpl.contextInvokeParams_; + this->contextEnv_ = contextImpl.contextEnv_; +} + +ContextImpl &ContextImpl::operator=(const ContextImpl &contextImpl) +{ + this->stateId_ = contextImpl.stateId_; + this->state_ = contextImpl.state_; + this->property_ = contextImpl.property_; + this->funcStartTime_ = contextImpl.funcStartTime_; + this->contextInvokeParams_ = contextImpl.contextInvokeParams_; + this->contextEnv_ = contextImpl.contextEnv_; + return *this; +} + +const std::string ContextImpl::GetTraceId() const +{ + return contextInvokeParams_->GetTraceId(); +} + +const std::string ContextImpl::GetInvokeId() const +{ + return contextInvokeParams_->GetInvokeId(); +} + +const FunctionLogger &ContextImpl::GetLogger() +{ + logger_.SetInvokeID(GetInvokeId()); + logger_.SetTraceID(GetTraceId()); + return logger_; +} + +const std::string ContextImpl::GetInstanceId() const +{ + return stateId_; +} + +const std::string ContextImpl::GetInstanceLabel() const +{ + return contextEnv_->GetInstanceLabel(); +} + +void ContextImpl::SetStateId(const std::string &stateId) +{ + std::lock_guard lock(stateMtx_); + this->stateId_ = stateId; +} + +const std::string ContextImpl::GetState() const +{ + return state_; +} + +void ContextImpl::SetState(const std::string &state) +{ + std::lock_guard lock(stateMtx_); + this->state_ = state; +} + +const std::string ContextImpl::GetInvokeProperty() const +{ + return property_; +} + +void ContextImpl::SetInvokeProperty(const std::string &prop) +{ + this->property_ = prop; +} + +const std::string ContextImpl::GetRequestID() const +{ + return contextInvokeParams_->GetRequestId(); +} + +const std::string ContextImpl::GetUserData(std::string key) const +{ + return contextEnv_->GetUserData(key); +} + +const std::string ContextImpl::GetFunctionName() const +{ + return contextEnv_->GetFunctionName(); +} + +int ContextImpl::GetRunningTimeInSeconds() const +{ + return contextEnv_->GetRunningTimeInSeconds(); +} + +int ContextImpl::GetRemainingTimeInMilliSeconds() const +{ + return 0; +} + +const std::string ContextImpl::GetVersion() const +{ + return contextEnv_->GetVersion(); +} + +int ContextImpl::GetMemorySize() const +{ + return contextEnv_->GetMemorySize(); +} + +int ContextImpl::GetCPUNumber() const +{ + return contextEnv_->GetCPUNumber(); +} + +const std::string ContextImpl::GetProjectID() const +{ + return contextEnv_->GetProjectID(); +} + +const std::string ContextImpl::GetPackage() const +{ + return contextEnv_->GetFuncPackage(); +} + +const std::string ContextImpl::GetAccessKey() const +{ + return contextInvokeParams_->GetAccessKey(); +} + +const std::string ContextImpl::GetAlias() const +{ + return contextInvokeParams_->GetAlias(); +} + +const std::string ContextImpl::GetSecretKey() const +{ + return contextInvokeParams_->GetSecretKey(); +} + +const std::string ContextImpl::GetSecurityAccessKey() const +{ + return contextInvokeParams_->GetSecurityAccessKey(); +} + +const std::string ContextImpl::GetSecuritySecretKey() const +{ + return contextInvokeParams_->GetSecuritySecretKey(); +} + +const std::string ContextImpl::GetToken() const +{ + return contextInvokeParams_->GetToken(); +} + +void ContextImpl::SetFuncStartTime(const long startTime) +{ + this->funcStartTime_ = startTime; +} +} // namespace Function diff --git a/api/cpp/src/faas/context_impl.h b/api/cpp/src/faas/context_impl.h new file mode 100644 index 0000000..8d9f4be --- /dev/null +++ b/api/cpp/src/faas/context_impl.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "Context.h" +#include "api/cpp/src/faas/context_env.h" +#include "api/cpp/src/faas/context_invoke_params.h" + +namespace Function { +class ContextImpl : public Context { +public: + ContextImpl() = default; + explicit ContextImpl(std::shared_ptr &invokeParams, std::shared_ptr &contextEnv); + ContextImpl(const ContextImpl &contextImpl); + ContextImpl &operator=(const ContextImpl &contextImpl); + ~ContextImpl() = default; + const std::string GetAccessKey() const override; + + const std::string GetSecretKey() const override; + + const std::string GetSecurityAccessKey() const override; + + const std::string GetSecuritySecretKey() const override; + + const std::string GetToken() const override; + + const std::string GetAlias() const override; + + const std::string GetTraceId() const override; + + const std::string GetInvokeId() const override; + + const FunctionLogger &GetLogger() override; + + const std::string GetInstanceId() const override; + + const std::string GetInstanceLabel() const override; + + void SetStateId(const std::string &stateId); + + const std::string GetState() const override; + + void SetState(const std::string &state) override; + + const std::string GetInvokeProperty() const override; + + const std::string GetRequestID() const override; + + const std::string GetUserData(std::string key) const override; + + const std::string GetFunctionName() const override; + + int GetRemainingTimeInMilliSeconds() const override; + + int GetRunningTimeInSeconds() const override; + + const std::string GetVersion() const override; + + int GetMemorySize() const override; + + int GetCPUNumber() const override; + + const std::string GetProjectID() const override; + + const std::string GetPackage() const override; + + void SetInvokeProperty(const std::string &property); + + void SetFuncStartTime(const long startTime); + +private: + FunctionLogger logger_; + std::string stateId_; + std::string state_; + std::string property_; + long funcStartTime_; + + std::shared_ptr contextInvokeParams_; + std::shared_ptr contextEnv_; + + mutable std::mutex stateMtx_; +}; +} // namespace Function diff --git a/api/cpp/src/faas/context_invoke_params.cpp b/api/cpp/src/faas/context_invoke_params.cpp new file mode 100644 index 0000000..51cd488 --- /dev/null +++ b/api/cpp/src/faas/context_invoke_params.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "api/cpp/src/faas/context_invoke_params.h" +#include "FunctionLogger.h" + +namespace Function { + +ContextInvokeParams::ContextInvokeParams(const std::unordered_map ¶ms) +{ + auto aliasIt = params.find("X-Invoke-Alias"); + if (aliasIt != params.end()) { + this->alias = aliasIt->second; + } + auto tokenIt = params.find("X-Auth-Token"); + if (tokenIt != params.end()) { + this->token = tokenIt->second; + } + auto securitySecretKeyIt = params.find("X-Security-Secret-Key"); + if (securitySecretKeyIt != params.end()) { + this->securitySecretKey = securitySecretKeyIt->second; + } + auto securityAccessKeyIt = params.find("X-Security-Access-Key"); + if (securityAccessKeyIt != params.end()) { + this->securityAccessKey = securityAccessKeyIt->second; + } + auto accessKeyIt = params.find("X-Access-Key"); + if (accessKeyIt != params.end()) { + this->accessKey = accessKeyIt->second; + } + auto secretKeyIt = params.find("X-Secret-Key"); + if (secretKeyIt != params.end()) { + this->secretKey = secretKeyIt->second; + } + auto securityTokenIt = params.find("X-Security-Token"); + if (securityTokenIt != params.end()) { + this->securityToken = securityTokenIt->second; + } +} + +const std::string ContextInvokeParams::GetAccessKey() const +{ + return accessKey; +} + +const std::string ContextInvokeParams::GetSecretKey() const +{ + return secretKey; +} + +const std::string ContextInvokeParams::GetSecurityAccessKey() const +{ + return securityAccessKey; +} + +const std::string ContextInvokeParams::GetSecuritySecretKey() const +{ + return securitySecretKey; +} + +const std::string ContextInvokeParams::GetRequestId() const +{ + return requestID; +} + +const std::string ContextInvokeParams::GetTraceId() const +{ + return GetRequestId(); +} + +const std::string ContextInvokeParams::GetInvokeId() const +{ + return invokeID; +} + +const std::string ContextInvokeParams::GetToken() const +{ + return token; +} + +const std::string ContextInvokeParams::GetAlias() const +{ + return alias; +} + +void ContextInvokeParams::SetInvokeID(const std::string &id) +{ + this->invokeID = id; +} +void ContextInvokeParams::SetRequestID(const std::string &id) +{ + this->requestID = id; +} + +} // namespace Function \ No newline at end of file diff --git a/api/cpp/src/faas/context_invoke_params.h b/api/cpp/src/faas/context_invoke_params.h new file mode 100644 index 0000000..8778edc --- /dev/null +++ b/api/cpp/src/faas/context_invoke_params.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace Function { + +class ContextInvokeParams { +public: + ContextInvokeParams() = default; + explicit ContextInvokeParams(const std::unordered_map ¶ms); + ~ContextInvokeParams() = default; + const std::string GetAccessKey() const; + + const std::string GetSecretKey() const; + + const std::string GetSecurityAccessKey() const; + + const std::string GetSecuritySecretKey() const; + + const std::string GetRequestId() const; + + const std::string GetTraceId() const; + + const std::string GetInvokeId() const; + + const std::string GetToken() const; + + const std::string GetAlias() const; + + void SetInvokeID(const std::string &id); + void SetRequestID(const std::string &id); + +private: + std::string accessKey = ""; + std::string secretKey = ""; + std::string securityAccessKey = ""; + std::string securitySecretKey = ""; + std::string requestID = ""; + std::string invokeID = ""; + std::string token = ""; + std::string securityToken = ""; + std::string alias = ""; +}; +} // namespace Function diff --git a/api/cpp/src/faas/faas_executor.cpp b/api/cpp/src/faas/faas_executor.cpp new file mode 100644 index 0000000..5892651 --- /dev/null +++ b/api/cpp/src/faas/faas_executor.cpp @@ -0,0 +1,327 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "api/cpp/src/faas/faas_executor.h" + +#include "json.hpp" + +#include "FunctionError.h" +#include "api/cpp/src/faas/context_impl.h" +#include "api/cpp/src/utils/utils.h" +#include "src/libruntime/libruntime_manager.h" +#include "src/utility/json_utility.h" +#include "src/utility/logger/logger.h" +namespace Function { + +using YR::utility::JsonGet; + +const int META_DATA_INDEX = 0; +const int USER_EVENT_INDEX = 1; +const int CREATE_PARAM_INDEX = 1; +const int RUNTIME_MAX_RESP_BODY_SIZE = 6 * 1024 * 1024; + +const char *const PRE_STOP_INVOKE_ID = "preStop"; +const char *const INITIALIZER_INVOKE_ID = "initializer"; +const char *const INSTANCE_LABEL_KEY = "instanceLabel"; + +const std::string LD_LIBRARY_PATH = "LD_LIBRARY_PATH"; + +std::string MakeErrorResult(int code, std::string message) +{ + nlohmann::json returnValJson; + returnValJson["innerCode"] = std::to_string(code); + returnValJson["body"] = message; + return returnValJson.dump(); +} + +static std::unique_ptr runtimeHandler; + +void SetRuntimeHandler(std::unique_ptr r) +{ + runtimeHandler = std::move(r); +} + +YR::Libruntime::ErrorInfo FaasExecutor::LoadFunctions(const std::vector &paths) +{ + return {}; +} + +YR::Libruntime::ErrorInfo FaasExecutor::ExecuteFunction( + const YR::Libruntime::FunctionMeta &function, const libruntime::InvokeType invokeType, + const std::vector> &rawArgs, + std::vector> &returnObjects) +{ + std::string result; + if (invokeType == libruntime::InvokeType::CreateInstance || + invokeType == libruntime::InvokeType::CreateInstanceStateless) { + auto initErr = FaasInitHandler(rawArgs); + if (!initErr.OK()) { + return initErr; + } + } else if (invokeType == libruntime::InvokeType::InvokeFunctionStateless || + invokeType == libruntime::InvokeType::InvokeFunction) { + result = FaasCallHandler(rawArgs); + } + + YRLOG_DEBUG("finish to execute faas function, result: {}", result); + if (result.empty()) { + return {}; + } + + uint64_t totalNativeBufferSize = 0; + auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->AllocReturnObject( + returnObjects[0], 0, result.size(), {}, totalNativeBufferSize); + if (!err.OK()) { + return err; + } + return YR::WriteDataObject(result.data(), returnObjects[0], result.size(), {}); +} + +YR::Libruntime::ErrorInfo FaasExecutor::Checkpoint(const std::string &instanceID, + std::shared_ptr &data) +{ + return {}; +} + +YR::Libruntime::ErrorInfo FaasExecutor::Recover(std::shared_ptr data) +{ + return {}; +} + +YR::Libruntime::ErrorInfo FaasExecutor::ExecuteShutdownFunction(uint64_t gracePeriodSecond) +{ + if (contextInvokeParams_ == nullptr || contextEnv_ == nullptr) { + return YR::Libruntime::ErrorInfo( + YR::Libruntime::ErrorCode::ERR_USER_FUNCTION_EXCEPTION, YR::Libruntime::ModuleCode::RUNTIME, + FunctionError(ErrorCode::FUNCTION_EXCEPTION, "can not call prestop before initialize").GetJsonString()); + } + contextInvokeParams_->SetInvokeID(PRE_STOP_INVOKE_ID); + auto context = ContextImpl(contextInvokeParams_, contextEnv_); + try { + runtimeHandler->PreStop(context); + } catch (FunctionError &functionError) { + YRLOG_ERROR("execute prestop failed, {}", functionError.GetJsonString()); + return YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_USER_FUNCTION_EXCEPTION, + YR::Libruntime::ModuleCode::RUNTIME, functionError.GetJsonString()); + } catch (std::exception &e) { + YRLOG_ERROR("execute prestop failed, {}", e.what()); + return YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_USER_FUNCTION_EXCEPTION, + YR::Libruntime::ModuleCode::RUNTIME, + FunctionError(ErrorCode::FUNCTION_EXCEPTION, e.what()).GetJsonString()); + } + return {}; +} + +YR::Libruntime::ErrorInfo FaasExecutor::Signal(int sigNo, std::shared_ptr payload) +{ + return {}; +} + +void FaasExecutor::ParseContextMeta(const std::string &contextMeta) +{ + auto contextMetaJson = nlohmann::json::parse(contextMeta); + std::string emptyStr = ""; + if (contextMetaJson.contains("funcMetaData")) { + auto funcMetaData = contextMetaJson.at("funcMetaData"); + contextEnv_->SetProjectID(JsonGet(funcMetaData, "tenantId", emptyStr)); + contextEnv_->SetFuncPackage(JsonGet(funcMetaData, "service", emptyStr)); + contextEnv_->SetFunctionName(JsonGet(funcMetaData, "func_name", emptyStr)); + contextEnv_->SetVersion(JsonGet(funcMetaData, "version", emptyStr)); + } + if (contextMetaJson.contains("resourceMetaData")) { + auto resourceMetaData = contextMetaJson.at("resourceMetaData"); + contextEnv_->SetCPUNumber(JsonGet(resourceMetaData, "cpu", 0)); + contextEnv_->SetMemorySize(JsonGet(resourceMetaData, "memory", 0)); + } +} + +void FaasExecutor::ParseCreateParams(const std::string &createParam) +{ + auto createParamJson = nlohmann::json::parse(createParam); + std::string emptyStr = ""; + if (createParamJson.contains(std::string(INSTANCE_LABEL_KEY))) { + contextEnv_->SetInstanceLabel(JsonGet(createParamJson, std::string(INSTANCE_LABEL_KEY), emptyStr)); + } +} + +void FaasExecutor::ParseDelegateDecrypt(const std::string &delegateDecrypt) +{ + if (delegateDecrypt.empty()) { + return; + } + auto delegateDecryptJson = nlohmann::json::parse(delegateDecrypt); + std::unordered_map userData; + if (delegateDecryptJson.contains("environment")) { + std::string envStr = delegateDecryptJson.at("environment"); + if (!envStr.empty()) { + auto env = nlohmann::json::parse(envStr); + for (auto &[key, value] : env.items()) { + YRLOG_DEBUG("setenv {}={}", key, value.dump()); + if (key == LD_LIBRARY_PATH) { + auto newPath = YR::GetEnv(key) + ":" + value.dump(); + userData[key] = newPath; + YR::SetEnv(key, newPath); + } else { + userData[key] = value; + YR::SetEnv(key, value); + } + } + } + } + if (delegateDecryptJson.contains("encrypted_user_data")) { + std::string userDataStr = delegateDecryptJson.at("encrypted_user_data"); + if (!userDataStr.empty()) { + auto userDataJson = nlohmann::json::parse(userDataStr); + for (auto &[key, value] : userDataJson.items()) { + YRLOG_DEBUG("set user data {}={}", key, value.dump()); + if (key == LD_LIBRARY_PATH) { + auto newPath = YR::GetEnv(key) + ":" + value.dump(); + userData[key] = newPath; + } else { + userData[key] = value; + } + } + } + } + + contextEnv_->SetUserData(userData); +} + +void SetTraceID(nlohmann::json callReqJson, std::shared_ptr &contextInvokeParams) +{ + std::string traceId = ""; + std::string headerKey = "header"; + if (callReqJson.contains(headerKey)) { + auto header = callReqJson.at(headerKey); + std::string traceIdKey = "X-Trace-Id"; + if (header.contains(traceIdKey)) { + traceId = header.at(traceIdKey); + } + } + contextInvokeParams->SetRequestID(traceId); +} + +YR::Libruntime::ErrorInfo FaasExecutor::FaasInitHandler( + const std::vector> &rawArgs) +{ + YRLOG_DEBUG("start to call faas init handler, arg size: {} ", rawArgs.size()); + contextEnv_ = std::make_shared(); + std::unordered_map params; + contextInvokeParams_ = std::make_shared(params); + auto contextMeta = std::string(static_cast(rawArgs[META_DATA_INDEX]->data->ImmutableData()), + rawArgs[META_DATA_INDEX]->data->GetSize()); + if (contextMeta.empty()) { + YRLOG_WARN("failed to parse context with empty string."); + return YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "failed to parse context with empty string"); + } + try { + ParseContextMeta(contextMeta); + auto delegateDecrypt = YR::GetEnv("ENV_DELEGATE_DECRYPT"); + ParseDelegateDecrypt(delegateDecrypt); + } catch (nlohmann::detail::exception &e) { + YRLOG_WARN("failed to parse context with exception: {} ", e.what()); + return YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + std::string("faild to parse context with exception: ") + e.what()); + } + if (rawArgs.size() > CREATE_PARAM_INDEX) { + auto createParam = std::string(static_cast(rawArgs[CREATE_PARAM_INDEX]->data->ImmutableData()), + rawArgs[CREATE_PARAM_INDEX]->data->GetSize()); + if (!createParam.empty()) { + try { + ParseCreateParams(createParam); + } catch (nlohmann::detail::exception &e) { + YRLOG_WARN("failed to parse create param with exception: {} ", e.what()); + } + } + } + contextInvokeParams_->SetInvokeID(INITIALIZER_INVOKE_ID); + auto context = ContextImpl(contextInvokeParams_, contextEnv_); + try { + runtimeHandler->Initializer(context); + } catch (FunctionError &functionError) { + YRLOG_WARN("failed to call initializer: {} ", functionError.GetJsonString()); + return YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_USER_FUNCTION_EXCEPTION, + functionError.GetJsonString()); + } catch (std::exception &e) { + YRLOG_WARN("failed to call initializer:: {} ", e.what()); + return YR::Libruntime::ErrorInfo( + YR::Libruntime::ErrorCode::ERR_USER_FUNCTION_EXCEPTION, + std::string("failed to call initializer function with exception: ") + e.what()); + } + YRLOG_DEBUG("success to call faas init handler"); + return {}; +} + +std::string FaasExecutor::FaasCallHandler(const std::vector> &rawArgs) +{ + YRLOG_DEBUG("start to call faas call handler, arg size: {} ", rawArgs.size()); + auto callReq = std::string(static_cast(rawArgs[USER_EVENT_INDEX]->data->ImmutableData()), + rawArgs[USER_EVENT_INDEX]->data->GetSize()); + if (callReq.empty()) { + return MakeErrorResult(ErrorCode::FUNCTION_EXCEPTION, "call req is empty"); + } + auto callReqJson = nlohmann::json::parse(callReq); + std::string bodyKey = "body"; + if (!callReqJson.contains(bodyKey)) { + return MakeErrorResult(ErrorCode::FUNCTION_EXCEPTION, "can not find body"); + } + std::string eventStr; + auto event = callReqJson.at(bodyKey); + if (event.type() != nlohmann::json::value_t::string) { + eventStr = event.dump(); + } else { + eventStr = event; + } + if (contextInvokeParams_ == nullptr || contextEnv_ == nullptr) { + return MakeErrorResult(ErrorCode::FUNCTION_EXCEPTION, "can not call handlerequest before initialize"); + } + SetTraceID(callReqJson, contextInvokeParams_); + auto context = ContextImpl(contextInvokeParams_, contextEnv_); + auto errCode = ErrorCode::OK; + std::string errMsg = ""; + nlohmann::json returnValJson; + std::string result; + try { + result = runtimeHandler->HandleRequest(eventStr, context); + } catch (FunctionError &functionError) { + errCode = functionError.GetErrorCode(); + errMsg = functionError.GetMessage(); + } catch (std::exception &e) { + errCode = ErrorCode::FUNCTION_EXCEPTION; + errMsg = std::string("failed to call function with exception: ") + e.what(); + } + returnValJson["innerCode"] = std::to_string(errCode); + if (errCode != ErrorCode::OK) { + YRLOG_WARN("failed to call user function, err {}", errMsg); + returnValJson["body"] = errMsg; + return returnValJson.dump(); + } + + if (result.size() > RUNTIME_MAX_RESP_BODY_SIZE) { + std::stringstream ss; + ss << "function result size: " << result.size() << ", exceed limit(" << RUNTIME_MAX_RESP_BODY_SIZE << ")"; + YRLOG_WARN(ss.str()); + return MakeErrorResult(ErrorCode::ILLEGAL_RETURN, ss.str()); + } + returnValJson["body"] = result; + returnValJson["billingDuration"] = "this is billing duration TODO"; + returnValJson["logResult"] = "this is user log TODO"; + returnValJson["invokerSummary"] = "this is summary TODO"; + return returnValJson.dump(); +} + +} // namespace Function \ No newline at end of file diff --git a/api/cpp/src/faas/faas_executor.h b/api/cpp/src/faas/faas_executor.h new file mode 100644 index 0000000..fbf747c --- /dev/null +++ b/api/cpp/src/faas/faas_executor.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "RuntimeHandler.h" +#include "api/cpp/src/executor/executor.h" +#include "api/cpp/src/faas/context_env.h" +#include "api/cpp/src/faas/context_invoke_params.h" + +namespace Function { +void SetRuntimeHandler(std::unique_ptr r); +class FaasExecutor : public YR::internal::Executor { +public: + FaasExecutor() = default; + ~FaasExecutor() = default; + + YR::Libruntime::ErrorInfo LoadFunctions(const std::vector &paths) override; + + YR::Libruntime::ErrorInfo ExecuteFunction( + const YR::Libruntime::FunctionMeta &function, const libruntime::InvokeType invokeType, + const std::vector> &rawArgs, + std::vector> &returnObjects) override; + + YR::Libruntime::ErrorInfo Checkpoint(const std::string &instanceID, + std::shared_ptr &data) override; + + YR::Libruntime::ErrorInfo Recover(std::shared_ptr data) override; + + YR::Libruntime::ErrorInfo ExecuteShutdownFunction(uint64_t gracePeriodSecond) override; + + YR::Libruntime::ErrorInfo Signal(int sigNo, std::shared_ptr payload) override; + +private: + // The reason for returning ErrorInfo here is to convey the failure of initialization to the functionsystem. + // The failed instance needs to be reclaimed by the functionsystem. + YR::Libruntime::ErrorInfo FaasInitHandler(const std::vector> &rawArgs); + std::string FaasCallHandler(const std::vector> &rawArgs); + void ParseContextMeta(const std::string &contextMeta); + void ParseCreateParams(const std::string &createParam); + void ParseDelegateDecrypt(const std::string &delegateDecrypt); + + std::shared_ptr contextEnv_; + std::shared_ptr contextInvokeParams_; +}; + +} // namespace Function \ No newline at end of file diff --git a/api/cpp/src/faas/function.cpp b/api/cpp/src/faas/function.cpp new file mode 100644 index 0000000..51b65b4 --- /dev/null +++ b/api/cpp/src/faas/function.cpp @@ -0,0 +1,215 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Function.h" + +#include "json.hpp" + +#include "FunctionError.h" +#include "context_impl.h" +#include "api/cpp/src/utils/utils.h" +#include "src/dto/data_object.h" +#include "src/dto/invoke_options.h" +#include "src/libruntime/libruntime_manager.h" +#include "src/utility/logger/logger.h" + +namespace Function { + +const int NAME_AND_VERSION_SIZE = 2; +const char *const FUNC_NAME_PATTERN_STRING = "^[a-zA-Z]([a-zA-Z0-9._-]*[a-zA-Z0-9])?$"; +const char *const VERSION_PATTERN_STRING = "^[a-zA-Z0-9]([a-zA-Z0-9_-]*\\\\.)*[a-zA-Z0-9_-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$"; +const char *const ALIAS_PATTERN_STRING = "^[a-zA-Z]([a-zA-Z0-9_-]*[a-zA-Z0-9])?$"; +const char *const INSTANCE_NAME_PATTERN_STRING = "^[a-zA-Z0-9]+$"; +const char *const ENV_KEY_RUNTIME_SERVICE_FUNC_VERSION = "RUNTIME_SERVICE_FUNC_VERSION"; +const char *const DEFAULT_VERSION = "latest"; +const char *const DEFAULT_TENANT_ID = "12345678901234561234567890123456"; +const char FUNCTION_URN_SEPERATOR = '@'; +const char FUNCTION_NAME_SEPERATOR = ':'; +const int FUNCTION_NAME_LENGTH = 3; +const int SERVICE_NAME_INDEX_FOR_FUNCTION_NAME = 1; +const char ALIAS_PREFIX = '!'; +const int FUNC_NAME_LENGTH_LIMIT = 128; +const int VERSION_LENGTH_LIMIT = 32; +const int ALIAS_LENGTH_LIMIT = 32; +const int INSTANCE_NAME_LENGTH_LIMIT = 128; + +void CheckFuncName(const std::string &funcName) +{ + const re2::RE2 pattern(FUNC_NAME_PATTERN_STRING); + bool match = RE2::FullMatch(funcName, pattern); + if (!match || funcName.length() > FUNC_NAME_LENGTH_LIMIT) { + YRLOG_WARN("Invalid function name, {}", funcName); + throw FunctionError(ErrorCode::INVALID_PARAMETER, + "Invalid funcName, not match regular expression or length exceeds " + "limit"); + } +} +void CheckFuncVersion(const std::string &version) +{ + const re2::RE2 pattern(VERSION_PATTERN_STRING); + bool match = RE2::FullMatch(version, pattern); + if (!match || version.length() > VERSION_LENGTH_LIMIT) { + YRLOG_WARN("Invalid function version, {}", version); + throw FunctionError(ErrorCode::INVALID_PARAMETER, + "Invalid func version, not match regular expression or length exceeds " + "limit"); + } +} +void CheckFuncAlias(const std::string &alias) +{ + const re2::RE2 pattern(ALIAS_PATTERN_STRING); + bool match = RE2::FullMatch(alias, pattern); + if (!match || alias.length() > ALIAS_LENGTH_LIMIT) { + YRLOG_WARN("Invalid function alias, {}", alias); + throw FunctionError(ErrorCode::INVALID_PARAMETER, + "Invalid func alias, not match regular expression or length exceeds " + "limit"); + } +} + +std::string GetFunctionName(const std::string &funcString) +{ + if (funcString.empty()) { + YRLOG_WARN("function name is invalid: {}", funcString); + throw FunctionError(ErrorCode::INVALID_PARAMETER, "invalid funcName, expect not null"); + } + std::string funcName = ""; + std::string version = DEFAULT_VERSION; + + if (funcString.find(FUNCTION_NAME_SEPERATOR) != std::string::npos) { + std::vector splitRet; + YR::utility::Split(funcString, splitRet, FUNCTION_NAME_SEPERATOR); + if (splitRet.size() != NAME_AND_VERSION_SIZE) { + YRLOG_WARN("function name is invalid: {}", funcString); + throw FunctionError(ErrorCode::INVALID_PARAMETER, "invalid funcName, not match regular expression"); + } + funcName = splitRet[0]; + version = splitRet[1]; + + CheckFuncName(funcName); + if (version[0] == ALIAS_PREFIX) { + CheckFuncAlias(version); + } else { + CheckFuncVersion(version); + } + } else { + funcName = funcString; + CheckFuncName(funcName); + } + return funcName + "/" + version; +} + +Function::Function(Context &context, const std::string &funcName, const std::string &instanceName) + : funcName_(funcName), instanceName_(instanceName) +{ + context_ = std::make_shared(*dynamic_cast(&context)); +} + +Function::Function(Context &context, const std::string &funcName) : funcName_(funcName) +{ + context_ = std::make_shared(*dynamic_cast(&context)); +} + +Function::Function(Context &context) +{ + context_ = std::make_shared(*dynamic_cast(&context)); + funcName_ = context_->GetFunctionName(); +} + +ObjectRef Function::Invoke(const std::string &payload) +{ + std::string tenantId = context_->GetProjectID(); + std::string traceId = context_->GetTraceId(); + if (tenantId.empty()) { + tenantId = DEFAULT_TENANT_ID; + } + std::stringstream ss; + // 12345678901234561234567890123456/0@faas@java:latest + ss << tenantId << "/" + << "0@" << context_->GetPackage() << "@" << GetFunctionName(funcName_); + YR::Libruntime::FunctionMeta libFunctionMeta; + libFunctionMeta.languageType = libruntime::LanguageType::Cpp; + libFunctionMeta.functionId = ss.str(); + libFunctionMeta.apiType = libruntime::ApiType::Faas; + std::vector libArgs; + std::vector argsStr; + argsStr.emplace_back("{}"); + nlohmann::json returnValJson; + returnValJson["body"] = payload; + argsStr.emplace_back(returnValJson.dump()); + for (auto &arg : argsStr) { + YR::Libruntime::InvokeArg libArg; + libArg.dataObj = std::make_shared(0, arg.size()); + YR::WriteDataObject(static_cast(arg.data()), libArg.dataObj, arg.size(), {}); + libArg.tenantId = tenantId; + libArgs.emplace_back(std::move(libArg)); + } + YR::Libruntime::InvokeOptions libOpts; + libOpts.cpu = options_.cpu; + libOpts.memory = options_.memory; + libOpts.aliasParams = options_.aliasParams; + libOpts.traceId = traceId; + std::vector returnObjs{{""}}; + auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->InvokeByFunctionName( + libFunctionMeta, libArgs, libOpts, returnObjs); + if (!err.OK()) { + YRLOG_WARN("failed to invoke function {}, err: {}", funcName_, err.Msg()); + throw FunctionError(ErrorCode::INVALID_PARAMETER, err.Msg()); + } + return ObjectRef(returnObjs[0].id, instanceID_); +} + +Function &Function::Options(const InvokeOptions &opt) +{ + options_ = opt; + return *this; +} + +const std::string Function::GetObjectRef(ObjectRef &objectRef) +{ + return objectRef.Get(); +} + +void Function::GetInstance(const std::string &functionName, const std::string &instanceName) +{ + throw FunctionError(ErrorCode::INVALID_PARAMETER, "not support this function"); +} + +void Function::GetLocalInstance(const std::string &functionName, const std::string &instanceName) +{ + throw FunctionError(ErrorCode::INVALID_PARAMETER, "not support this function"); +} + +ObjectRef Function::Terminate() +{ + throw FunctionError(ErrorCode::INVALID_PARAMETER, "not support this function"); +} + +void Function::SaveState() +{ + throw FunctionError(ErrorCode::INVALID_PARAMETER, "not support this function"); +} + +const std::shared_ptr Function::GetContext() const +{ + return context_; +} + +std::string Function::GetInstanceId() const +{ + throw FunctionError(ErrorCode::INVALID_PARAMETER, "not support this function"); +} +} // namespace Function \ No newline at end of file diff --git a/api/cpp/src/faas/function_error.cpp b/api/cpp/src/faas/function_error.cpp new file mode 100644 index 0000000..dcbe769 --- /dev/null +++ b/api/cpp/src/faas/function_error.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FunctionError.h" + +#include "json.hpp" + +namespace Function { +const char *FunctionError::what() const noexcept +{ + return errMsg.c_str(); +} + +ErrorCode FunctionError::GetErrorCode() const +{ + return errCode; +} + +const std::string FunctionError::GetMessage() const +{ + return errMsg; +} + +const std::string FunctionError::GetJsonString() const +{ + nlohmann::json returnValJson; + returnValJson["code"] = errCode; + returnValJson["message"] = errMsg; + return returnValJson.dump(); +} + +} // namespace Function \ No newline at end of file diff --git a/api/cpp/src/faas/function_logger.cpp b/api/cpp/src/faas/function_logger.cpp new file mode 100644 index 0000000..b5fb19f --- /dev/null +++ b/api/cpp/src/faas/function_logger.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FunctionLogger.h" + +#include +#include + +#include + +using namespace std; +namespace Function { +const std::string LOG_INFO = "INFO"; +const std::string LOG_WARN = "WARN"; +const std::string LOG_DEBUG = "DEBUG"; +const std::string LOG_ERROR = "ERROR"; +const size_t PRINTF_BUFFER_LENGTH = 2048; + +string AdaptPrintf(const string &message, va_list &args) +{ + char buf[PRINTF_BUFFER_LENGTH] = {0}; + int ret = vsprintf_s(buf, sizeof(buf), message.c_str(), args); + if (ret < 0) { + return ""; + } + return string(buf); +} + +void FunctionLogger::Info(string message, ...) +{ + if (logLevel == LOG_WARN || logLevel == LOG_ERROR) { + return; + } + if (sendEmptyLog(message, LOG_INFO)) { + return; + } + va_list args; + va_start(args, message); + string logMessage = AdaptPrintf(message, args); + va_end(args); +} + +void FunctionLogger::Warn(string message, ...) +{ + if (logLevel == LOG_ERROR) { + return; + } + if (sendEmptyLog(message, LOG_WARN)) { + return; + } + va_list args; + va_start(args, message); + string logMessage = AdaptPrintf(message, args); + va_end(args); +} + +void FunctionLogger::Debug(string message, ...) +{ + if (logLevel == LOG_INFO || logLevel == LOG_WARN || logLevel == LOG_ERROR) { + return; + } + if (sendEmptyLog(message, LOG_DEBUG)) { + return; + } + va_list args; + va_start(args, message); + string logMessage = AdaptPrintf(message, args); + va_end(args); +} + +void FunctionLogger::Error(string message, ...) +{ + if (sendEmptyLog(message, LOG_ERROR)) { + return; + } + va_list args; + va_start(args, message); + string logMessage = AdaptPrintf(message, args); + va_end(args); +} + +void FunctionLogger::Log(const string &level, const string &logMessage) +{ + return; +} + +bool FunctionLogger::sendEmptyLog(const string &message, const string &level) +{ + if (message.empty()) { + return true; + } + return false; +} + +void FunctionLogger::setLevel(const string &level) +{ + if (level == LOG_DEBUG || level == LOG_INFO || level == LOG_WARN || level == LOG_ERROR) { + this->logLevel = level; + } +} +void FunctionLogger::SetInvokeID(std::string invokeID) +{ + this->invokeId = invokeID; +} +void FunctionLogger::SetTraceID(std::string traceID) +{ + this->traceId = traceID; +} +} // namespace Function diff --git a/api/cpp/src/faas/object_ref.cpp b/api/cpp/src/faas/object_ref.cpp new file mode 100644 index 0000000..f2301d6 --- /dev/null +++ b/api/cpp/src/faas/object_ref.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ObjectRef.h" + +#include +#include + +#include "Function.h" +#include "FunctionError.h" +#include "src/libruntime/libruntime_manager.h" +#include "src/utility/logger/logger.h" + +namespace Function { +const std::string ObjectRef::GetObjectRefId() const +{ + return objectRefId_; +} + +const std::string ObjectRef::GetResult() const +{ + return result_; +} + +bool ObjectRef::GetResultFlag() const +{ + return isResultExist_; +} + +const std::string ObjectRef::Get() +{ + int defaultTimeout = 300 * 1000; + auto [err, rets] = + YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->Get({objectRefId_}, defaultTimeout, false); + if (!err.OK()) { + YRLOG_WARN("failed to get result {}, err: {}", objectRefId_, err.Msg()); + throw FunctionError(ErrorCode::FUNCTION_EXCEPTION, err.Msg()); + } + isResultExist_ = true; + auto result = std::string(static_cast(rets[0]->data->ImmutableData()), rets[0]->data->GetSize()); + if (result.empty()) { + result_ = ""; + return result_; + } + std::string innerCode; + try { + auto resultJson = nlohmann::json::parse(result); + result_ = resultJson["body"]; + innerCode = resultJson["innerCode"]; + } catch (nlohmann::detail::exception &e) { + std::stringstream ss; + ss << "failed to parse result, err: " << e.what(); + YRLOG_WARN(ss.str()); + throw FunctionError(ErrorCode::FUNCTION_EXCEPTION, ss.str()); + } + if (innerCode != "0") { + throw FunctionError(std::stoi(innerCode), result_); + } + return result_; +} + +} // namespace Function diff --git a/api/cpp/src/faas/register_runtime_handler.cpp b/api/cpp/src/faas/register_runtime_handler.cpp new file mode 100644 index 0000000..f3c814e --- /dev/null +++ b/api/cpp/src/faas/register_runtime_handler.cpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "api/cpp/src/faas/register_runtime_handler.h" + +#include "FunctionError.h" +#include "src/utility/logger/logger.h" + +namespace Function { +void RegisterRuntimeHandler::RegisterHandler( + std::function handleRequestFunc) +{ + this->handleRequest = std::move(handleRequestFunc); +} + +void RegisterRuntimeHandler::RegisterInitializerFunction(std::function initializerFunc) +{ + this->initializerFunction = std::move(initializerFunc); +} + +void RegisterRuntimeHandler::RegisterPreStopFunction(std::function preStopFunc) +{ + this->preStopFunction = std::move(preStopFunc); +} + +void RegisterRuntimeHandler::InitState(std::function initStateFunc) +{ + this->initStateFunction = std::move(initStateFunc); +} + +std::string RegisterRuntimeHandler::HandleRequest(const std::string &request, Context &context) +{ + if (this->handleRequest == nullptr) { + YRLOG_WARN("The HandleRequest Function is null"); + throw FunctionError(ErrorCode::FUNCTION_EXCEPTION, "undefined HandleRequest()"); + } + return this->handleRequest(request, context); +} + +void RegisterRuntimeHandler::InitState(const std::string &request, Context &context) +{ + if (this->initStateFunction == nullptr) { + YRLOG_WARN("The InitState Function is null"); + throw FunctionError(ErrorCode::FUNCTION_EXCEPTION, "undefined InitState()"); + } + this->initStateFunction(request, context); +} + +void RegisterRuntimeHandler::PreStop(Context &context) +{ + if (this->preStopFunction == nullptr) { + return; + } + this->preStopFunction(context); +} + +void RegisterRuntimeHandler::Initializer(Context &context) +{ + if (this->initializerFunction == nullptr) { + return; + } + this->initializerFunction(context); +} +} // namespace Function \ No newline at end of file diff --git a/api/cpp/src/faas/register_runtime_handler.h b/api/cpp/src/faas/register_runtime_handler.h new file mode 100644 index 0000000..c664e0b --- /dev/null +++ b/api/cpp/src/faas/register_runtime_handler.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "Context.h" +#include "Runtime.h" + +namespace Function { +class RegisterRuntimeHandler : public Function::RuntimeHandler { +public: + RegisterRuntimeHandler() + { + } + virtual ~RegisterRuntimeHandler() + { + } + + std::string HandleRequest(const std::string &request, Function::Context &context) override; + + void InitState(const std::string &request, Function::Context &context) override; + + void PreStop(Function::Context &context) override; + + void Initializer(Function::Context &context) override; + + void RegisterHandler( + std::function handleRequestFunc); + + void RegisterInitializerFunction(std::function initializerFunc); + + void RegisterPreStopFunction(std::function preStopFunc); + + void InitState(std::function initStateFunc); + +private: + std::function handleRequest; + std::function initializerFunction; + std::function preStopFunction; + std::function initStateFunction; +}; +} // namespace Function diff --git a/api/cpp/src/faas/runtime.cpp b/api/cpp/src/faas/runtime.cpp new file mode 100644 index 0000000..2ce7bd6 --- /dev/null +++ b/api/cpp/src/faas/runtime.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Runtime.h" + +#include + +#include "api/cpp/src/executor/executor_holder.h" +#include "api/cpp/src/faas/faas_executor.h" +#include "api/cpp/src/faas/register_runtime_handler.h" +#include "yr/yr.h" + +namespace Function { +const std::string DEFAULT_RUNTIME_LOG_PATH = "/home/sn/log/"; +const std::string DEFAULT_RUNTIME_LOG_NAME = "cppbin-runtime"; +Runtime::Runtime() +{ + InitRuntimeLogger(); +} + +Runtime::~Runtime() +{ + ReleaseRuntimeLogger(); +} + +void Runtime::InitRuntimeLogger() const +{ + return; +} + +void Runtime::ReleaseRuntimeLogger() const +{ + return; +} + +void Runtime::RegisterHandler( + std::function handleRequestFunc) +{ + this->handleRequest = std::move(handleRequestFunc); +} +void Runtime::RegisterInitializerFunction(std::function initializerFunc) +{ + this->initializerFunction = std::move(initializerFunc); +} +void Runtime::RegisterPreStopFunction(std::function preStopFunc) +{ + this->preStopFunction = std::move(preStopFunc); +} + +void Runtime::InitState(std::function initStateFunc) +{ + this->initStateFunction = std::move(initStateFunc); +} + +void Runtime::BuildRegisterRuntimeHandler() const +{ + auto handlerPtr = std::make_unique(); + handlerPtr->RegisterHandler(this->handleRequest); + handlerPtr->RegisterInitializerFunction(this->initializerFunction); + handlerPtr->InitState(this->initStateFunction); + handlerPtr->RegisterPreStopFunction(this->preStopFunction); + SetRuntimeHandler(std::move(handlerPtr)); +} + +void Runtime::Start(int argc, char *argv[]) +{ + this->BuildRegisterRuntimeHandler(); + YR::internal::ExecutorHolder::Singleton().SetExecutor(std::make_shared()); + YR::Config conf; + conf.isDriver = false; + YR::Init(conf, argc, argv); + YRLOG_INFO("success to init faas binrary runtime."); + YR::ReceiveRequestLoop(); +} +} // namespace Function diff --git a/api/cpp/src/local_mode_runtime.cpp b/api/cpp/src/local_mode_runtime.cpp index 0a8b5a3..16257d7 100644 --- a/api/cpp/src/local_mode_runtime.cpp +++ b/api/cpp/src/local_mode_runtime.cpp @@ -166,5 +166,10 @@ std::vector LocalModeRuntime::KVDel(const std::vector { return stateStore_->Del(keys); } + +std::vector LocalModeRuntime::KVExist(const std::vector &keys) +{ + return stateStore_->Exist(keys); +} } // namespace internal } // namespace YR \ No newline at end of file diff --git a/api/cpp/src/local_state_store.cpp b/api/cpp/src/local_state_store.cpp index 2f77bbe..7e45ba0 100644 --- a/api/cpp/src/local_state_store.cpp +++ b/api/cpp/src/local_state_store.cpp @@ -126,6 +126,21 @@ std::vector LocalStateStore::Del(const std::vector &ke return failedKeys; } +std::vector LocalStateStore::Exist(const std::vector &keys) +{ + std::vector exists; + exists.reserve(keys.size()); + std::lock_guard lock(mtx); + for (auto &key : keys) { + if (kv_map.count(key) > 0) { + exists.push_back(true); + } else { + exists.push_back(false); + } + } + return exists; +} + bool LocalStateStore::IsEmpty() { std::lock_guard lock(mtx); diff --git a/api/cpp/src/mutable_buffer.cpp b/api/cpp/src/mutable_buffer.cpp new file mode 100644 index 0000000..710a9e9 --- /dev/null +++ b/api/cpp/src/mutable_buffer.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "yr/api/mutable_buffer.h" +#include "yr/api/object_ref.h" + +namespace YR { +void *MutableBuffer::MutableData() +{ + return nullptr; +} + +ObjectRef MutableBuffer::Publish() +{ + return ObjectRef(); +} + +int64_t MutableBuffer::GetSize() +{ + return 0; +} +} // namespace YR \ No newline at end of file diff --git a/api/cpp/src/object_store.cpp b/api/cpp/src/object_store.cpp index db6676f..ee8c812 100644 --- a/api/cpp/src/object_store.cpp +++ b/api/cpp/src/object_store.cpp @@ -35,11 +35,11 @@ void ThrowExceptionBasedOnStatus(const GetStatus status, const ErrorInfo &err, } else { oss << " partial"; } - if (!remainIds.empty()) { - oss << " failed: " << "(" << remainIds.size() << "). "; - oss << "Failed objects: [ "; - oss << remainIds[0] << " ... " << "]"; - } + oss << " failed: " + << "(" << remainIds.size() << "). "; + oss << "Failed objects: [ "; + oss << remainIds[0] << " ... " + << "]"; if (status == GetStatus::ALL_FAILED || status == GetStatus::ALL_FAILED_AND_TIMEOUT) { throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), oss.str()); } diff --git a/api/cpp/src/read_only_buffer.h b/api/cpp/src/read_only_buffer.h index cf41512..ca0168b 100644 --- a/api/cpp/src/read_only_buffer.h +++ b/api/cpp/src/read_only_buffer.h @@ -24,7 +24,9 @@ namespace YR { class ReadOnlyBuffer : public Buffer { public: - ReadOnlyBuffer(std::shared_ptr buf) : buf_(buf) {} + ReadOnlyBuffer(std::shared_ptr buf) : Buffer(nullptr, 0), buf_(buf) { + } + virtual uint64_t GetSize() const override { return buf_->GetSize(); diff --git a/api/cpp/src/runtime_env.cpp b/api/cpp/src/runtime_env.cpp new file mode 100644 index 0000000..996a307 --- /dev/null +++ b/api/cpp/src/runtime_env.cpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "yr/api/runtime_env.h" + +namespace YR { + +void RuntimeEnv::SetJsonStr(const std::string &name, const std::string &jsonStr) +{ + try { + nlohmann::json valueJ = nlohmann::json::parse(jsonStr); + fields_[name] = valueJ; + } catch (std::exception &e) { + throw YR::Exception::InvalidParamException("Failed to set the field " + name + " by json string: " + e.what()); + } +} + +std::string RuntimeEnv::GetJsonStr(const std::string &name) const +{ + if (!Contains(name)) { + throw YR::Exception::InvalidParamException("The field " + name + " not found."); + } + auto j = fields_[name].get(); + return j.dump(); +} + +bool RuntimeEnv::Contains(const std::string &name) const +{ + return fields_.contains(name); +} + +bool RuntimeEnv::Remove(const std::string &name) +{ + if (Contains(name)) { + fields_.erase(name); + return true; + } + return false; +} + +} // namespace YR \ No newline at end of file diff --git a/api/cpp/src/runtime_env_parse.cpp b/api/cpp/src/runtime_env_parse.cpp new file mode 100644 index 0000000..22c22f6 --- /dev/null +++ b/api/cpp/src/runtime_env_parse.cpp @@ -0,0 +1,211 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if __has_include() +#include +#else +#include +#endif +#include + +#include +#include + +#include "yr/api/exception.h" +#include "src/libruntime/err_type.h" +#include "api/cpp/src/utils/utils.h" +#include "yr/api/runtime_env.h" +#include "src/dto/invoke_options.h" +#include "api/cpp/src/runtime_env_parse.h" + +#ifdef __cpp_lib_filesystem +namespace filesystem = std::filesystem; +#elif __cpp_lib_experimental_filesystem +namespace filesystem = std::experimental::filesystem; +#endif + +namespace YR { +const std::string CONDA = "conda"; +const std::string PIP = "pip"; +const std::string YR_CONDA_HOME = "YR_CONDA_HOME"; +const std::string WORKER_DIR = "working_dir"; +const std::string ENV_VARS = "env_vars"; +const std::string SHARED_DIR = "shared_dir"; + +const std::string POST_START_EXEC = "POST_START_EXEC"; +const std::string CONDA_PREFIX = "CONDA_PREFIX"; +const std::string CONDA_CONFIG = "CONDA_CONFIG"; +const std::string CONDA_COMMAND = "CONDA_COMMAND"; +const std::string CONDA_DEFAULT_ENV = "CONDA_DEFAULT_ENV"; + +std::string GetCondaBinExecutable() +{ + if (auto envStr = GetEnv(YR_CONDA_HOME); !envStr.empty()) { + return envStr; + } + if (auto envStr = GetEnv(CONDA_PREFIX); !envStr.empty()) { + return envStr; + } + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please configure YR_CONDA_HOME environment variable which contain a bin subdirectory"); +} + +std::string YamlToJson(YAML::Node &node) +{ + YAML::Emitter emitter; + emitter << YAML::DoubleQuoted << YAML::Flow << YAML::BeginSeq << node; + + return std::string(emitter.c_str() + 1); +} +void HandleCondaConfig(YR::Libruntime::InvokeOptions& invokeOptions, const nlohmann::json& condaConfig); + +void HandleSharedDirConfig(YR::Libruntime::InvokeOptions &invokeOptions, const nlohmann::json &sharedDirConfig) +{ + if (sharedDirConfig.is_object()) { + // 处理JSON对象类型的conda配置 + std::string name = sharedDirConfig.value("name", ""); + if (name.empty()) { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "shared dir name must be string"); + } + int ttl = sharedDirConfig.value("TTL", 0); + invokeOptions.createOptions["DELEGATE_SHARED_DIRECTORY"] = name; + invokeOptions.createOptions["DELEGATE_SHARED_DIRECTORY_TTL"] = std::to_string(ttl); + } else { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "shared dir format must be json"); + } +} + +void ParseRuntimeEnv(YR::Libruntime::InvokeOptions& invokeOptions, const YR::RuntimeEnv& runtimeEnv) +{ + if (runtimeEnv.Empty()) { + return; + } + + if (runtimeEnv.Contains(CONDA) && runtimeEnv.Contains(PIP)) { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "The 'pip' field and 'conda' field of runtime_env cannot both be specified.\n" + "To use pip with conda, please only set the 'conda' " + "field, and specify your pip dependencies " + "within the conda YAML config dict"); + } + if (runtimeEnv.Contains(PIP)) { + try { + const auto &pipPackages = runtimeEnv.Get>(PIP); + std::ostringstream pipCommand; + pipCommand << "pip3 install"; + for (size_t i = 0; i < pipPackages.size(); i++) { + pipCommand << " " << pipPackages[i]; + } + invokeOptions.createOptions[POST_START_EXEC] = pipCommand.str(); + } catch (std::exception &e) { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + std::string("Failed to parse pip field of RuntimeEnv: ") + e.what()); + } + } + if (runtimeEnv.Contains(WORKER_DIR)) { + try { + const std::string &workingDir = runtimeEnv.Get(WORKER_DIR); + invokeOptions.workingDir = workingDir; + } catch (std::exception &e) { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + std::string("`working_dir` must be a string: ") + e.what()); + } + } + if (runtimeEnv.Contains(ENV_VARS)) { + try { + const auto& envVars = runtimeEnv.Get>(ENV_VARS); + for (const auto& pair : envVars) { + if (invokeOptions.envVars.find(pair.first) == invokeOptions.envVars.end()) { + invokeOptions.envVars.insert(pair); + } + } + } catch (std::exception &e) { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + std::string("`envs` must be a map[string: string]: ") + e.what()); + } + } + if (runtimeEnv.Contains(CONDA)) { + nlohmann::json condaJson; + try { + condaJson = runtimeEnv.Get(CONDA); + } catch(std::exception &e) { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + std::string("`get conda to nlohmann:json format failed: ") + e.what()); + } + HandleCondaConfig(invokeOptions, condaJson); + } + if (runtimeEnv.Contains(SHARED_DIR)) { + nlohmann::json sharedDirJson; + try { + sharedDirJson = runtimeEnv.Get(SHARED_DIR); + HandleSharedDirConfig(invokeOptions, sharedDirJson); + } catch (std::exception &e) { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + std::string("`get shared dir to nlohmann:json format failed: ") + e.what()); + } + } +} + +void HandleCondaConfig(YR::Libruntime::InvokeOptions& invokeOptions, const nlohmann::json& condaConfig) +{ + invokeOptions.createOptions[CONDA_PREFIX] = GetCondaBinExecutable(); + + if (condaConfig.is_string()) { + // 处理字符串类型的conda配置(YAML文件路径或环境名称) + const std::string& condaStr = condaConfig.get(); + const filesystem::path condaPath(condaStr); + + if (condaPath.extension() == ".yaml" || condaPath.extension() == ".yml") { + if (!filesystem::exists(condaPath)) { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + std::string("Can't find conda YAML file ") + condaStr); + } + + YAML::Node yamlNode = YAML::LoadFile(condaStr); + std::string envName; + try { + envName = yamlNode["name"].as(); + } catch (std::exception& e) { + // 支持传空,以及类型异常处理,yaml-cpp高低版本之间判断是否是string类型的函数不一致,使用try catch保持一致 + } + if (envName.empty()) { + envName = "virtual_env-" + YR::utility::IDGenerator::GenRequestId(); + } + + invokeOptions.createOptions[CONDA_CONFIG] = YamlToJson(yamlNode); + invokeOptions.createOptions[CONDA_COMMAND] = "conda env create -f env.yaml"; + invokeOptions.createOptions[CONDA_DEFAULT_ENV] = envName; + } else { + // 直接使用环境名称 + invokeOptions.createOptions[CONDA_COMMAND] = "conda activate " + condaStr; + invokeOptions.createOptions[CONDA_DEFAULT_ENV] = condaStr; + } + } else if (condaConfig.is_object()) { + // 处理JSON对象类型的conda配置 + std::string envName = condaConfig.value("name", ""); + if (envName.empty()) { + envName = "virtual_env-" + YR::utility::IDGenerator::GenRequestId(); + } + + invokeOptions.createOptions[CONDA_CONFIG] = condaConfig.dump(); + invokeOptions.createOptions[CONDA_COMMAND] = "conda env create -f env.yaml"; + invokeOptions.createOptions[CONDA_DEFAULT_ENV] = envName; + } else { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "conda format must be string or json"); + } +} +} diff --git a/api/cpp/src/runtime_env_parse.h b/api/cpp/src/runtime_env_parse.h new file mode 100644 index 0000000..671b36a --- /dev/null +++ b/api/cpp/src/runtime_env_parse.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "src/dto/invoke_options.h" +#include "yr/api/runtime_env.h" + +namespace YR { +void ParseRuntimeEnv(YR::Libruntime::InvokeOptions& invokeOptions, const YR::RuntimeEnv& runtimeEnv); +} diff --git a/api/cpp/src/stream_pubsub.cpp b/api/cpp/src/stream_pubsub.cpp new file mode 100644 index 0000000..dd3bca5 --- /dev/null +++ b/api/cpp/src/stream_pubsub.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "api/cpp/src/stream_pubsub.h" +#include "api/cpp/include/yr/api/exception.h" +#include "src/dto/stream_conf.h" +#include "src/libruntime/err_type.h" +#include "src/libruntime/libruntime_manager.h" + +namespace YR { +void StreamProducer::Send(const Element &element) +{ + YR::Libruntime::Element streamElement(element.ptr, element.size, element.id); + YR::Libruntime::ErrorInfo err = producer_->Send(streamElement); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("Send err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } +} + +void StreamProducer::Send(const Element &element, int64_t timeoutMs) +{ + YR::Libruntime::Element streamElement(element.ptr, element.size, element.id); + YR::Libruntime::ErrorInfo err = producer_->Send(streamElement, timeoutMs); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("Send err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } +} + +void StreamProducer::Flush() +{ + YR::Libruntime::ErrorInfo err = producer_->Flush(); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("Flush err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } +} + +void StreamProducer::Close() +{ + auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTraceId(traceId_); + if (!err.OK()) { + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } + err = producer_->Close(); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("Close err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } +} + +void StreamConsumer::Receive(uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements) +{ + std::vector elements; + YR::Libruntime::ErrorInfo err = consumer_->Receive(expectNum, timeoutMs, elements); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("Receive err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } + for (auto &element : elements) { + outElements.push_back(Element(element.ptr, element.size, element.id)); + } +} + +void StreamConsumer::Receive(uint32_t timeoutMs, std::vector &outElements) +{ + std::vector elements; + YR::Libruntime::ErrorInfo err = consumer_->Receive(timeoutMs, elements); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("Receive err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } + for (auto &element : elements) { + outElements.push_back(Element(element.ptr, element.size, element.id)); + } +} + +void StreamConsumer::Ack(uint64_t elementId) +{ + YR::Libruntime::ErrorInfo err = consumer_->Ack(elementId); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("Ack err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } +} + +void StreamConsumer::Close() +{ + auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTraceId(traceId_); + if (!err.OK()) { + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } + err = consumer_->Close(); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR("Close err: Code:{}, MCode:{}, Msg:{}", + fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); + throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); + } +} +} // namespace YR diff --git a/api/cpp/src/stream_pubsub.h b/api/cpp/src/stream_pubsub.h new file mode 100644 index 0000000..5cd5917 --- /dev/null +++ b/api/cpp/src/stream_pubsub.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "api/cpp/include/yr/api/stream.h" +#include "src/libruntime/streamstore/stream_producer_consumer.h" +#include "src/utility/logger/logger.h" + +namespace YR { +/** + * @class StreamProducer + */ +class StreamProducer : public Producer { +public: + StreamProducer(std::shared_ptr producer, const std::string &traceId = "") + { + producer_ = std::move(producer); + traceId_ = traceId; + } + + /** + * @brief Sends data to the producer. The data is first placed in the buffer and then flushed based on the + * configured automatic flush strategy (either after a certain interval or when the buffer is full), or manually via + * Flush to make the data available to consumers. + * @param element The Element data to be sent. + * @throws Exception + * - **4299**: failed to send element. + * + * @snippet{trimleft} stream_example.cpp producer send + */ + void Send(const Element &element); + + /** + * @brief Sends data to the producer. The data is first placed in the buffer and then flushed based on the + * configured automatic flush strategy (either after a certain interval or when the buffer is full), or manually via + * Flush to make the data available to consumers. + * @param element The Element data to be sent. + * @param timeoutMs Optional timeout in milliseconds. + * @throws Exception + * - **4299**: failed to send element. + */ + void Send(const Element &element, int64_t timeoutMs); + + /** + * @brief Manually flushes the buffer to make the data visible to consumers. + * @throws Exception + * - **4299**: producer failed to flush. + * + * @snippet{trimleft} stream_example.cpp producer send + */ + void Flush(); + + /** + * @brief Closes the producer, triggering an automatic flush of the buffer and indicating that the buffer will no + * longer be used. Once closed, the producer cannot be used again. + * @throws Exception + * - **4299**: failed to close producer. + * + * @snippet{trimleft} stream_example.cpp close producer + */ + void Close(); + +private: + std::shared_ptr producer_; + std::string traceId_; +}; + +/** + * @class StreamConsumer + */ +class StreamConsumer : public Consumer { +public: + StreamConsumer(std::shared_ptr consumer, const std::string &traceId = "") + { + consumer_ = std::move(consumer); + traceId_ = traceId; + } + + /** + * @brief Receives data with subscription functionality. The consumer waits for the expected number of elements + * (`expectNum`) to be received. The call returns when the timeout (`timeoutMs`) is reached or the expected number + * of elements is received. + * @param expectNum The expected number of elements to receive. + * @param timeoutMs The timeout in milliseconds. + * @param outElements The actual elements received. + * @throws YR::Exception Thrown in the following cases: + * - **3003**: the total size exceed the uint64_t max value or the total size exceed the limit. + * - **4299**: failed to receive element with expectNum. + * + * @snippet{trimleft} stream_example.cpp consumer recv + */ + void Receive(uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements); + + /** + * @brief Receives data with subscription functionality. The consumer waits for the expected number of elements + * (`expectNum`) to be received. The call returns when the timeout (`timeoutMs`) is reached or the expected number + * of elements is received. + * @param timeoutMs The timeout in milliseconds. + * @param outElements The actual elements received. + * @throws YR::Exception Thrown in the following cases: + * - **3003**: the total size exceed the uint64_t max value or the total size exceed the limit. + * - **4299**: failed to receive element. + */ + void Receive(uint32_t timeoutMs, std::vector &outElements); + + /** + * @brief Acknowledges that a specific element (identified by `elementId`) has been consumed. This allows workers to + * determine if all consumers have finished consuming the element, enabling internal memory回收 mechanisms if all + * consumers have acknowledged it. If not acknowledged, the element will be automatically acknowledged when the + * consumer exits. + * @param elementId The ID of the element to acknowledge. + * @throws YR::Exception + * - **4299**: failed to ack. + * + * @snippet{trimleft} stream_example.cpp consumer recv + */ + void Ack(uint64_t elementId); + + /** + * @brief Closes the consumer. Once closed, the consumer cannot be used again. + * @throws YR::Exception + * - **4299**: failed to Close consumer. + * + * @snippet{trimleft} stream_example.cpp close consumer + */ + void Close(); + +private: + std::shared_ptr consumer_; + std::string traceId_; +}; +} // namespace YR diff --git a/api/cpp/src/wait_request_manager.cpp b/api/cpp/src/wait_request_manager.cpp index 7ecf8b3..c6fd720 100644 --- a/api/cpp/src/wait_request_manager.cpp +++ b/api/cpp/src/wait_request_manager.cpp @@ -75,7 +75,8 @@ void WaitRequest::SetException(const std::exception_ptr &exception) WaitRequestManager::WaitRequestManager() { this->ioc = std::make_shared(); - this->work = std::make_unique(*ioc); + this->work = std::make_unique>( + boost::asio::make_work_guard(*ioc)); this->asyncRunner = std::make_unique([&] { this->ioc->run(); }); pthread_setname_np(this->asyncRunner->native_handle(), "wait_request_handler"); } @@ -154,7 +155,7 @@ void WaitRequestManager::WaitTimer(boost::asio::steady_timer &timer, int timeout const std::shared_ptr &waitRequest) { if (timeout != NO_TIMEOUT) { - timer.expires_from_now(std::chrono::seconds(timeout)); + timer.expires_after(std::chrono::seconds(timeout)); timer.async_wait([this, waitRequest](const boost::system::error_code &ec) { if (ec) { return; diff --git a/api/cpp/src/yr.cpp b/api/cpp/src/yr.cpp index 835bc15..32d4242 100644 --- a/api/cpp/src/yr.cpp +++ b/api/cpp/src/yr.cpp @@ -24,6 +24,7 @@ #include "src/libruntime/libruntime_manager.h" #include "yr/api/runtime.h" #include "yr/api/runtime_manager.h" +#include "yr/api/serdes.h" thread_local std::unordered_set localNestedObjList; namespace YR { @@ -89,7 +90,8 @@ ClientInfo Init(const Config &conf, int argc, char **argv) if (!ConfigManager::Singleton().isDriver) { auto err = internal::CodeManager::LoadFunctions(ConfigManager::Singleton().loadPaths); if (!err.OK()) { - YRLOG_INFO("load function error: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); + YRLOG_INFO("load function error: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); } } SetInitialized(); @@ -113,6 +115,15 @@ ClientInfo Init(int argc, char **argv) return Init(conf, argc, argv); } +void Run(int argc, char *argv[]) +{ + Config conf; + conf.isDriver = false; + conf.launchUserBinary = true; + Init(conf, argc, argv); + ReceiveRequestLoop(); +} + void Finalize(void) { CheckInitialized(); @@ -155,6 +166,38 @@ bool IsLocalMode() return ConfigManager::Singleton().IsLocalMode(); } +std::shared_ptr CreateProducer(const std::string &streamName, ProducerConf producerConf) +{ + CheckInitialized(); + if (ConfigManager::Singleton().IsLocalMode()) { // local mode + throw Exception(YR::Libruntime::ErrorCode::ERR_INCORRECT_FUNCTION_USAGE, YR::Libruntime::ModuleCode::RUNTIME, + "local mode does not support CreateProducer\n"); + } + std::shared_ptr producer = YR::internal::GetRuntime()->CreateStreamProducer(streamName, producerConf); + return producer; +} + +std::shared_ptr Subscribe(const std::string &streamName, const SubscriptionConfig &config, bool autoAck) +{ + CheckInitialized(); + if (ConfigManager::Singleton().IsLocalMode()) { // local mode + throw Exception(YR::Libruntime::ErrorCode::ERR_INCORRECT_FUNCTION_USAGE, YR::Libruntime::ModuleCode::RUNTIME, + "local mode does not support Subscribe\n"); + } + std::shared_ptr consumer = YR::internal::GetRuntime()->CreateStreamConsumer(streamName, config, autoAck); + return consumer; +} + +void DeleteStream(const std::string &streamName) +{ + CheckInitialized(); + if (ConfigManager::Singleton().IsLocalMode()) { // local mode + throw Exception(YR::Libruntime::ErrorCode::ERR_INCORRECT_FUNCTION_USAGE, YR::Libruntime::ModuleCode::RUNTIME, + "local mode does not support DeleteStream\n"); + } + YR::internal::GetRuntime()->DeleteStream(streamName); +} + void SaveState(const int &timeout) { CheckInitialized(); @@ -186,4 +229,52 @@ void LoadState(const int &timeout) YR::internal::GetRuntime()->LoadState(timeout); } +std::vector Nodes() +{ + CheckInitialized(); + std::vector nodes = YR::internal::GetRuntime()->Nodes(); + return nodes; +} + +std::shared_ptr CreateBuffer(uint64_t size) +{ + CheckInitialized(); + if (ConfigManager::Singleton().IsLocalMode()) { + throw Exception(YR::Libruntime::ErrorCode::ERR_INCORRECT_FUNCTION_USAGE, YR::ModuleCode::RUNTIME_, + "local mode does not support CreateBuffer\n"); + } else { + return YR::internal::GetRuntime()->CreateMutableBuffer(size); + } +} + +std::vector> Get(const std::vector> &objs, int timeoutSec) +{ + CheckInitialized(); + if (ConfigManager::Singleton().IsLocalMode()) { + throw Exception(YR::Libruntime::ErrorCode::ERR_INCORRECT_FUNCTION_USAGE, YR::ModuleCode::RUNTIME_, + "local mode does not support Get\n"); + } else { + std::vector ids; + for (size_t i = 0; i < objs.size(); i++) { + ids.push_back(objs[i].ID()); + } + return YR::internal::GetRuntime()->GetMutableBuffer(ids, timeoutSec); + } +} + +std::string Serialize(ObjectRef &obj) +{ + msgpack::sbuffer buffer = YR::internal::Serialize(std::move(obj)); + std::string str(buffer.data(), buffer.size()); + return str; +} + +ObjectRef Deserialize(const void *value, int size) +{ + msgpack::sbuffer buffer; + const char *valueChar = static_cast(value); + buffer.write(valueChar, size); + return YR::internal::Deserialize>(buffer); +} + } // namespace YR diff --git a/api/go/BUILD.bazel b/api/go/BUILD.bazel new file mode 100644 index 0000000..ac0cbf4 --- /dev/null +++ b/api/go/BUILD.bazel @@ -0,0 +1,95 @@ +load("//bazel:yr.bzl", "cc_strip") +load("//bazel:yr_go.bzl", "yr_go_test") + +filegroup( + name = "go_sources", + srcs = glob([ + "**/*.go", + "**/go.mod", + ]), +) + +cc_strip( + name = "go_strip", + srcs = ["//api/go/libruntime/cpplibruntime:libcpplibruntime.so"], +) + +genrule( + name = "yr_go_pkg", + srcs = [ + "go_strip", + "//:grpc_strip", + "@datasystem_sdk//:shared", + "@metrics_sdk//:shared", + ":go_sources", + ], + outs = ["yr_go_pkg.out"], + cmd = """ + BASE_DIR="$$(pwd)" && + GO_OUT_DIR=$$BASE_DIR/build/output/runtime/service/go/bin/ && + DATASYSTEM_FILE=$$(echo $(locations @datasystem_sdk//:shared) | awk '{print $$1}') && + DATASYSTEM_DIR=$$(dirname $$BASE_DIR/$$DATASYSTEM_FILE) && + METRICS_FILE=$$(echo $(locations @metrics_sdk//:shared) | awk '{print $$1}') && + METRICS_DIR=$$(dirname $$BASE_DIR/$$METRICS_FILE) && + mkdir -p $$GO_OUT_DIR && + cd $$BASE_DIR && + chmod +w $(locations :go_strip) $(locations //:grpc_strip) && + chrpath -d $(locations :go_strip) $(locations //:grpc_strip) && + cp -ar $(locations :go_strip) $(locations //:grpc_strip) $$GO_OUT_DIR && + cp -ar $$DATASYSTEM_DIR/* $$GO_OUT_DIR && + cp -ar $$METRICS_DIR/* $$GO_OUT_DIR && + cd $$(realpath $$BASE_DIR/api/go/runtime/) && + go mod tidy && + LD_LIBRARY_PATH=$${LD_LIBRARY_PATH:+$$LD_LIBRARY_PATH}:$$GO_OUT_DIR CC='gcc -fstack-protector-strong -D_FORTIFY_SOURCE=2 -O2' go build -buildmode=pie -ldflags '-extldflags "-fPIC -fstack-protector-strong -Wl,-z,now,-z,relro,-z,noexecstack,-s -Wall -Werror"' -o $$GO_OUT_DIR/goruntime yr_runtime_main.go && + cd $$BASE_DIR && + chmod -R 750 $$GO_OUT_DIR && + GO_SDK_DIR=$$BASE_DIR/build/output/runtime/sdk/go/runtime && + rm -rf $$GO_SDK_DIR && + mkdir -p $$GO_SDK_DIR && + cp -arf $$BASE_DIR/api/go/libruntime $$GO_SDK_DIR && + find $$GO_SDK_DIR/libruntime/ -name *.cpp | xargs rm -rf && + find $$GO_SDK_DIR/libruntime/ -name *.bazel | xargs rm -rf && + cp -arf $$BASE_DIR/api/go/faassdk $$GO_SDK_DIR && + cp -arf $$BASE_DIR/api/go/posixsdk $$GO_SDK_DIR && + cp -arf $$BASE_DIR/api/go/yr $$GO_SDK_DIR && + cp -arf $$BASE_DIR/api/go/go.mod $$GO_SDK_DIR && + cp -arf $$BASE_DIR/api/go/README.md $$GO_SDK_DIR && + echo "$$GO_OUT_DIR" > $@ + """, + local = True, + visibility = ["//visibility:public"], +) + +config_setting( + name = "asan", + values = {"define": "sanitize=address"}, +) + +config_setting( + name = "tsan", + values = {"define": "sanitize=thread"}, +) + +alias( + name = "yr_go_test", + actual = select({ + ":asan": "yr_go_test_asan", + ":tsan": "yr_go_test_tsan", + "//conditions:default": "yr_go_test_default", + }), +) + +yr_go_test( + name = "yr_go_test_asan", + sanitizer = "address", +) + +yr_go_test( + name = "yr_go_test_tsan", + sanitizer = "thread", +) + +yr_go_test( + name = "yr_go_test_default", + sanitizer = "off", +) diff --git a/api/go/README.md b/api/go/README.md new file mode 100644 index 0000000..83285af --- /dev/null +++ b/api/go/README.md @@ -0,0 +1,12 @@ +## runtime go sdk 包使用说明 + +### 准备工作 + +本 sdk 包需要用到的库文件在`yuanrong`包中,请先下载`yuanrong.tar.gz`并解压,然后设置环境变量: + +```bash +export CGO_LDFLAGS="-L${WORKING_DIR}/yuanrong/runtime/service/go/bin -lcpplibruntime" +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${WORKING_DIR}/yuanrong/runtime/service/go/bin +``` + +即可执行`go build`。 \ No newline at end of file diff --git a/api/go/example/actor_example.go b/api/go/example/actor_example.go new file mode 100644 index 0000000..ae9d8c2 --- /dev/null +++ b/api/go/example/actor_example.go @@ -0,0 +1,102 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "os" + "os/signal" + "syscall" + + "yuanrong.org/kernel/runtime/yr" +) + +type Counter struct { + Count int +} + +func NewCounter(init int) *Counter { + return &Counter{Count: init} +} + +func (c *Counter) Add(x int) int { + c.Count += x + return c.Count +} + +func (c *Counter) AddValue(x int) int { + instance := yr.Instance(NewValue).Invoke(30) + defer instance.Terminate() + objs := instance.Function((*Value).Show).Invoke() + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0 + } + return value + x +} + +func (c *Counter) CallAddException(x int) (int, error) { + instance := yr.Instance(NewCounter).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*Counter).AddException).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +func (c *Counter) AddException(x int) (int, error) { + return x + 1, errors.New("error from go counter") +} + +type Value struct { + value int +} + +func NewValue(init int) *Value { + return &Value{value: init} +} + +func (v *Value) Show() int { + fmt.Println("Value is ", v.value) + return v.value +} + +func ActorExample() { + flag.Parse() + config := &yr.Config{ + Mode: yr.ClusterMode, + FunctionUrn: "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest", + ServerAddr: flag.Args()[0], + DataSystemAddr: flag.Args()[1], + InCluster: true, + LogLevel: "DEBUG", + } + info, err := yr.Init(config) + if err != nil { + fmt.Println("Init failed, error: ", err) + } + fmt.Println(info) + + instance := yr.Instance(NewCounter).Invoke(28) + objs := instance.Function((*Counter).Add).Invoke(100) + fmt.Println(yr.Get[int](objs[0], 30000)) + fmt.Println(instance.Terminate()) + + // task call task + instance1 := yr.Instance(NewCounter).Invoke(30) + objs1 := instance1.Function((*Counter).AddValue).Invoke(120) + fmt.Println(yr.Get[int](objs1[0], 30000)) + fmt.Println(instance1.Terminate()) +} + +// compile command: +// CGO_LDFLAGS="-L$(path which contains libcpplibruntime.so)" go build -buildmode=plugin -o yrlib.so actor_example.go +// CGO_LDFLAGS="-L$(path which contains libcpplibruntime.so)" go build -o main actor_example.go +func main() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGINT) + ActorExample() + <-c +} diff --git a/api/go/example/actor_example_exception.go b/api/go/example/actor_example_exception.go new file mode 100644 index 0000000..e02cc88 --- /dev/null +++ b/api/go/example/actor_example_exception.go @@ -0,0 +1,149 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "errors" + + "yuanrong.org/kernel/runtime/yr" +) + +type ExceptionCounter struct { + Count int +} + +func NewCounterException(init int) *ExceptionCounter { + return &ExceptionCounter{Count: init} +} + +func (c *ExceptionCounter) ReturnException(x int) (int, error) { + return x + 1, errors.New("error from go newCounter 1") +} + +func (c *ExceptionCounter) CallAddException10(x int) (int, error) { + instance := yr.Instance(NewCounterException).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*ExceptionCounter).ReturnException).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +func (c *ExceptionCounter) CallAddException9(x int) (int, error) { + instance := yr.Instance(NewCounterException).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*ExceptionCounter).CallAddException10).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +func (c *ExceptionCounter) CallAddException8(x int) (int, error) { + instance := yr.Instance(NewCounterException).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*ExceptionCounter).CallAddException9).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +func (c *ExceptionCounter) CallAddException7(x int) (int, error) { + instance := yr.Instance(NewCounterException).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*ExceptionCounter).CallAddException8).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +func (c *ExceptionCounter) CallAddException6(x int) (int, error) { + instance := yr.Instance(NewCounterException).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*ExceptionCounter).CallAddException7).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +func (c *ExceptionCounter) CallAddException5(x int) (int, error) { + instance := yr.Instance(NewCounterException).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*ExceptionCounter).CallAddException6).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +func (c *ExceptionCounter) CallAddException4(x int) (int, error) { + instance := yr.Instance(NewCounterException).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*ExceptionCounter).CallAddException5).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +func (c *ExceptionCounter) CallAddException3(x int) (int, error) { + instance := yr.Instance(NewCounterException).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*ExceptionCounter).CallAddException4).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +func (c *ExceptionCounter) CallAddException2(x int) (int, error) { + instance := yr.Instance(NewCounterException).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*ExceptionCounter).CallAddException3).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +func (c *ExceptionCounter) CallAddException1(x int) (int, error) { + instance := yr.Instance(NewCounterException).Invoke(x) + defer instance.Terminate() + objs := instance.Function((*ExceptionCounter).CallAddException2).Invoke(200) + value, err := yr.Get[int](objs[0], 30000) + if err != nil { + return 0, err + } + return value, nil +} + +// compile command: +// CGO_LDFLAGS="-L$(path which contains libcpplibruntime.so)" \ +// go build -buildmode=plugin -o yrlibexception.so actor_example_exception.go diff --git a/api/go/example/entrance_example.go b/api/go/example/entrance_example.go new file mode 100644 index 0000000..8cf1b0a --- /dev/null +++ b/api/go/example/entrance_example.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package main for example +package main + +import ( + "yuanrong.org/kernel/runtime/faassdk" + "yuanrong.org/kernel/runtime/faassdk/faas-sdk/go-api/context" +) + +func initEntry(ctx context.RuntimeContext) { + ctx.GetLogger().Infof("info log in initHandler") +} + +func preStopEntry(ctx context.RuntimeContext) { + ctx.GetLogger().Infof("info log in preStopHandler") +} + +func callEntry(_ []byte, _ context.RuntimeContext) (interface{}, error) { + return "hello world", nil +} + +func main() { + faassdk.RegisterInitializerFunction(initEntry) + faassdk.RegisterPreStopFunction(preStopEntry) + faassdk.Register(callEntry) +} diff --git a/api/go/example/kv_example.go b/api/go/example/kv_example.go new file mode 100644 index 0000000..7bd7bf5 --- /dev/null +++ b/api/go/example/kv_example.go @@ -0,0 +1,68 @@ +package main + +import ( + "flag" + "fmt" + "os" + "os/signal" + "syscall" + + "yuanrong.org/kernel/runtime/yr" +) + +func InitYR() { + flag.Parse() + config := &yr.Config{ + Mode: yr.ClusterMode, + FunctionUrn: "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest", + ServerAddr: flag.Args()[0], + DataSystemAddr: flag.Args()[1], + InCluster: true, + LogLevel: "DEBUG", + } + info, err := yr.Init(config) + if err != nil { + fmt.Println("Init failed, error: ", err) + } + fmt.Println(info) +} + +func KVSetAndGetAndDelExample() { + err := yr.SetKV("key", "value") + if err != nil { + fmt.Println("set kv failed, error: ", err) + } + + fmt.Println(yr.GetKV("key", 30000)) + + fmt.Println(yr.DelKV("key")) +} + +func BatchKVSetAndGetAndDelExample() { + err := yr.SetKV("key", "value") + if err != nil { + fmt.Println("set kv failed, error: ", err) + } + + err = yr.SetKV("key1", "value1") + if err != nil { + fmt.Println("set kv failed, error: ", err) + } + + fmt.Println(yr.GetKVs([]string{"key", "key1"}, 30000, false)) + + fmt.Println(yr.DelKVs([]string{"key", "key1"})) + + fmt.Println(yr.GetKVs([]string{"key", "key1"}, 30000, false)) +} + +// compile command: +// CGO_LDFLAGS="-L$(path which contains libcpplibruntime.so)" go build -o main kv_example.go +func main() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGINT) + InitYR() + KVSetAndGetAndDelExample() + BatchKVSetAndGetAndDelExample() + <-c +} diff --git a/api/go/example/put_get_wait_example.go b/api/go/example/put_get_wait_example.go new file mode 100644 index 0000000..d5fc6d6 --- /dev/null +++ b/api/go/example/put_get_wait_example.go @@ -0,0 +1,90 @@ +package main + +import ( + "flag" + "fmt" + "os" + "os/signal" + "syscall" + + "yuanrong.org/kernel/runtime/yr" +) + +func InitYR() { + flag.Parse() + config := &yr.Config{ + Mode: yr.ClusterMode, + FunctionUrn: "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest", + ServerAddr: flag.Args()[0], + DataSystemAddr: flag.Args()[1], + InCluster: true, + LogLevel: "DEBUG", + } + info, err := yr.Init(config) + if err != nil { + fmt.Println("Init failed, error: ", err) + } + fmt.Println(info) +} + +func PutAndGetExample() { + obj, err := yr.Put(250) + if err != nil { + fmt.Println("Put failed, error: ", err) + } + + fmt.Println(yr.Get[int](obj, 30000)) +} + +func PutAndWaitExample() { + obj, err := yr.Put(250) + if err != nil { + fmt.Println("Put failed, error: ", err) + } + + fmt.Println(yr.Wait(obj, 30000)) +} + +func PutAndBatchGetExample() { + obj, err := yr.Put(250) + if err != nil { + fmt.Println("Put failed, error: ", err) + } + + obj1, err := yr.Put(2560) + if err != nil { + fmt.Println("Put failed, error: ", err) + } + + objs := []*yr.ObjectRef{obj, obj1} + fmt.Println(yr.BatchGet[int](objs, 30000, false)) +} + +func PutAndBatchWaitExample() { + obj, err := yr.Put(250) + if err != nil { + fmt.Println("Put failed, error: ", err) + } + + obj1, err := yr.Put(2560) + if err != nil { + fmt.Println("Put failed, error: ", err) + } + + objs := []*yr.ObjectRef{obj, obj1} + fmt.Println(objs) + fmt.Println(yr.WaitNum(objs, 3, 2)) +} + +// compile command: +// CGO_LDFLAGS="-L$(path which contains libcpplibruntime.so)" go build -o main put_get_wait_example.go +func main() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGINT) + InitYR() + PutAndGetExample() + PutAndWaitExample() + PutAndBatchGetExample() + PutAndBatchWaitExample() + <-c +} diff --git a/api/go/example/stream_example.go b/api/go/example/stream_example.go new file mode 100644 index 0000000..6a83b0c --- /dev/null +++ b/api/go/example/stream_example.go @@ -0,0 +1,87 @@ +package main + +import ( + "flag" + "fmt" + "os" + "os/signal" + "syscall" + + "yuanrong.org/kernel/runtime/yr" +) + +func InitYR() { + flag.Parse() + config := &yr.Config{ + Mode: yr.ClusterMode, + FunctionUrn: "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest", + ServerAddr: flag.Args()[0], + DataSystemAddr: flag.Args()[1], + InCluster: true, + LogLevel: "DEBUG", + } + info, err := yr.Init(config) + if err != nil { + fmt.Println("Init failed, error: ", err) + } + fmt.Println(info) +} + +func StreamExample() { + producerConf := yr.ProducerConf{ + DelayFlushTime: 5, + PageSize: 1024 * 1024, + MaxStreamSize: 1024 * 1024 * 1024, + } + producer, err := yr.CreateProducer("teststream", producerConf) + if err != nil { + fmt.Println("create producer failed, err: ", err) + return + } + + consumer, err := yr.Subscribe("teststream", "substreamName", 0) + if err != nil { + fmt.Println("create consumer failed, err: ", err) + return + } + + data, err := msgpack.Marshal("test-message") + if err != nil { + fmt.Println("marshal failed, err: ", err) + return + } + + fmt.Println(producer.Send(data)) + fmt.Println(producer.Flush()) + + subDatas, err := consumer.Receive(1, 30000) + if err != nil { + fmt.Println("receive failed, err: ", err) + } + + fmt.Println(subDatas[0]) + var result string + fmt.Println(msgpack.Unmarshal(subDatas[0].Data, &result)) + fmt.Println("result: ", result) + fmt.Println(consumer.Ack(uint64(subDatas[0].Id))) + + fmt.Println(producer.Close()) + fmt.Println(consumer.Close()) + fmt.Println(yr.DeleteStream("teststream")) +} + +func QueryStreamNumExample() { + fmt.Println(yr.QueryGlobalProducersNum("testStream")) + fmt.Println(yr.QueryGlobalConsumersNum("testStream")) +} + +// compile command: +// CGO_LDFLAGS="-L$(path which contains libcpplibruntime.so)" go build -o main stream_example.go +func main() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGINT) + InitYR() + StreamExample() + QueryStreamNumExample() + <-c +} diff --git a/api/go/example/task_example.go b/api/go/example/task_example.go new file mode 100644 index 0000000..f065383 --- /dev/null +++ b/api/go/example/task_example.go @@ -0,0 +1,60 @@ +package main + +import ( + "flag" + "fmt" + "os" + "os/signal" + "syscall" + + "yuanrong.org/kernel/runtime/yr" +) + +func PlusOne(x int) int { + fmt.Println("Hello PlusOne") + return x + 1 +} + +func PlusOne2(x int) int { + fmt.Println("Hello PlusOne2") + function := yr.Function(PlusOne).Options(yr.NewInvokeOptions()) + ref := function.Invoke(300) + fmt.Println(yr.Get[int](ref[0], 3000)) + return x + 1 +} + +func TaskExample() { + flag.Parse() + config := &yr.Config{ + Mode: yr.ClusterMode, + FunctionUrn: "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest", + ServerAddr: flag.Args()[0], + DataSystemAddr: flag.Args()[1], + InCluster: true, + LogLevel: "DEBUG", + } + info, err := yr.Init(config) + if err != nil { + fmt.Println("Init failed, error: ", err) + } + fmt.Println(info) + + function := yr.Function(PlusOne).Options(yr.NewInvokeOptions()) + ref := function.Invoke(298) + fmt.Println(yr.Get[int](ref[0], 3000)) + + // task call task + function1 := yr.Function(PlusOne2).Options(yr.NewInvokeOptions()) + ref1 := function1.Invoke(298) + fmt.Println(yr.Get[int](ref1[0], 3000)) +} + +// compile command: +// CGO_LDFLAGS="-L$(path which contains libcpplibruntime.so)" go build -buildmode=plugin -o yrlib.so task_example.go +// CGO_LDFLAGS="-L$(path which contains libcpplibruntime.so)" go build -o main task_example.go +func main() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGINT) + TaskExample() + <-c +} diff --git a/api/go/example/test_exception.go b/api/go/example/test_exception.go new file mode 100644 index 0000000..b1117ed --- /dev/null +++ b/api/go/example/test_exception.go @@ -0,0 +1,78 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "os" + "os/signal" + "syscall" + + "yuanrong.org/kernel/runtime/yr" +) + +func Functiond(x int) (int, error) { + if x == 0 { + return 0, errors.New("invalid param") + } + return x - 1, nil +} + +func Functionc(x int) (int, error) { + function := yr.Function(Functiond).Options(yr.NewInvokeOptions()) + ref := function.Invoke(x) + res, err := yr.Get[int](ref[0], 30000) + if err != nil { + fmt.Println("IN FUNCTIONC") + return 0, err + } + return res, nil +} + +func Functionb(x int) (int, error) { + function := yr.Function(Functionc).Options(yr.NewInvokeOptions()) + ref := function.Invoke(x) + res, err := yr.Get[int](ref[0], 30000) + if err != nil { + fmt.Println("IN FUNCTIONB") + return 0, err + } + return res, nil +} + +func Functiona() { + fmt.Println("TaskExceptionExample") + flag.Parse() + config := &yr.Config{ + Mode: yr.ClusterMode, + FunctionUrn: "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest", + ServerAddr: flag.Args()[0], + DataSystemAddr: flag.Args()[1], + InCluster: true, + LogLevel: "DEBUG", + } + info, err := yr.Init(config) + if err != nil { + fmt.Println("Init failed, error: ", err) + } + fmt.Println(info) + + function := yr.Function(Functionb).Options(yr.NewInvokeOptions()) + ref := function.Invoke(0) + res, err := yr.Get[int](ref[0], 30000) + if err != nil { + fmt.Printf(err.Error()) + } else { + fmt.Printf("FunctionA result is %d\n", res) + } +} + +// compile command: +// CGO_LDFLAGS="-L$(path which contains libcpplibruntime.so)" go build -buildmode=plugin -o yrlib.so task_example.go +// CGO_LDFLAGS="-L$(path which contains libcpplibruntime.so)" go build -o main task_example.go +func main() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGINT) + Functiona() + <-c +} diff --git a/api/go/faassdk/common/alarm/logalarm.go b/api/go/faassdk/common/alarm/logalarm.go new file mode 100644 index 0000000..76c11d8 --- /dev/null +++ b/api/go/faassdk/common/alarm/logalarm.go @@ -0,0 +1,242 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package alarm alarm log by filebeat +package alarm + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "strconv" + "sync" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/faassdk/utils" + "yuanrong.org/kernel/runtime/libruntime/common/logger" + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" + "yuanrong.org/kernel/runtime/libruntime/common/logger/log" +) + +const ( + // ConfigKey environment variable key of alarm config + ConfigKey = "ALARM_CONFIG" + // ClusterName define cluster env key + ClusterName = "CLUSTER_NAME" + + cacheLimit = 10 * 1 << 20 // 10 mb + + // Level3 - + Level3 = "critical" + // Level2 - + Level2 = "major" + // Level1 - + Level1 = "minor" + // Level0 - + Level0 = "notice" + + // GenerateAlarmLog - + GenerateAlarmLog = "firing" + // ClearAlarmLog - + ClearAlarmLog = "resolved" + + // InsufficientMinInstance00001 alarm id + InsufficientMinInstance00001 = "InsufficientMinInstance00001" + // MetadataEtcdConnection00001 alarm id + MetadataEtcdConnection00001 = "MetadataEtcdConnection00001" + // RouterEtcdConnection00001 alarm id + RouterEtcdConnection00001 = "RouterEtcdConnection00001" + // InitStsSdkErr00001 alarm id + InitStsSdkErr00001 = "InitStsSdkErr00001" + // PullStsConfiguration00001 alarm id + PullStsConfiguration00001 = "PullStsConfiguration00001" +) + +var ( + alarmLogger *zap.Logger + createLoggerErr error + createLoggerOnce sync.Once +) + +// LogAlarmInfo Custom alarm info +type LogAlarmInfo struct { + AlarmID string + AlarmName string + AlarmLevel string +} + +// Detail alarm detail +type Detail struct { + SourceTag string // 告警来源 + OpType string // 告警操作类型 + Details string // 告警详情 + StartTimestamp int // 产生时间 + EndTimestamp int // 清除时间 +} + +// GetAlarmLogger - +func GetAlarmLogger() (*zap.Logger, error) { + createLoggerOnce.Do(func() { + alarmLogger, createLoggerErr = newAlarmLogger() + if createLoggerErr != nil { + return + } + if alarmLogger == nil { + createLoggerErr = errors.New("failed to new alarmLogger") + return + } + // 祥云四元组 - 站点/租户ID/产品ID/服务ID + alarmLogger = alarmLogger.With(zapcore.Field{ + Key: "site", Type: zapcore.StringType, + String: os.Getenv(constants.WiseCloudSite), + }, zapcore.Field{ + Key: "tenant_id", Type: zapcore.StringType, + String: os.Getenv(constants.TenantID), + }, zapcore.Field{ + Key: "application_id", Type: zapcore.StringType, + String: os.Getenv(constants.ApplicationID), + }, zapcore.Field{ + Key: "service_id", Type: zapcore.StringType, + String: os.Getenv(constants.ServiceID), + }) + }) + return alarmLogger, createLoggerErr +} + +func newAlarmLogger() (*zap.Logger, error) { + coreInfo, err := config.ExtractCoreInfoFromEnv(ConfigKey) + log.GetLogger().Infof("ALARM_CONFIG is: %v", coreInfo) + if err != nil { + log.GetLogger().Errorf("failed to valid log path, err: %s", err.Error()) + return nil, err + } + + coreInfo.FilePath = filepath.Join(coreInfo.FilePath, "alarm.dat") + + sink, err := logger.CreateSink(coreInfo) + if err != nil { + log.GetLogger().Errorf("failed to create sink: %s", err.Error()) + return nil, err + } + + ws := zapcore.AddSync(sink) + priority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= zapcore.DebugLevel + }) + encoderConfig := zapcore.EncoderConfig{} + rollingFileEncoder := zapcore.NewJSONEncoder(encoderConfig) + + return zap.New(zapcore.NewCore(rollingFileEncoder, ws, priority)), nil +} + +func addAlarmLogger(rollingLogger *zap.Logger, alarmInfo *LogAlarmInfo, detail *Detail) *zap.Logger { + return rollingLogger.With(zapcore.Field{ + Key: "id", Type: zapcore.StringType, + String: alarmInfo.AlarmID, + }, zapcore.Field{ + Key: "name", Type: zapcore.StringType, + String: alarmInfo.AlarmName, + }, zapcore.Field{ + Key: "level", Type: zapcore.StringType, + String: alarmInfo.AlarmLevel, + }, zapcore.Field{ + Key: "source_tag", Type: zapcore.StringType, + String: detail.SourceTag, + }, zapcore.Field{ + Key: "op_type", Type: zapcore.StringType, + String: detail.OpType, + }, zapcore.Field{ + Key: "details", Type: zapcore.StringType, + String: detail.Details, + }, zapcore.Field{ + Key: "clear_type", Type: zapcore.StringType, + String: "ADAC", + }, zapcore.Field{ + Key: "start_timestamp", Type: zapcore.StringType, + String: strconv.Itoa(detail.StartTimestamp), + }, zapcore.Field{ + Key: "end_timestamp", Type: zapcore.StringType, + String: strconv.Itoa(detail.EndTimestamp), + }) +} + +// ReportOrClearAlarm - +func ReportOrClearAlarm(alarmInfo *LogAlarmInfo, detail *Detail) { + alarmLog, err := GetAlarmLogger() + if err != nil { + log.GetLogger().Errorf("GetAlarmLogger err %v", err) + return + } + logger := addAlarmLogger(alarmLog, alarmInfo, detail) + logger.Info("") +} + +// SetAlarmEnv - +func SetAlarmEnv(alarmConfigInfo config.CoreInfo) { + alarmConfigBytes, err := json.Marshal(alarmConfigInfo) + if err != nil { + log.GetLogger().Errorf("json marshal alarmConfigInfo err %v", err) + } + if err := os.Setenv(ConfigKey, string(alarmConfigBytes)); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", ConfigKey, err.Error()) + } + log.GetLogger().Infof("succeeded to set env of %s, value: %s", ConfigKey, string(alarmConfigBytes)) +} + +// SetXiangYunFourConfigEnv - +func SetXiangYunFourConfigEnv(xiangYunFourConfig types.XiangYunFourConfig) { + if err := os.Setenv(constants.WiseCloudSite, xiangYunFourConfig.Site); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", constants.WiseCloudSite, err.Error()) + } + if err := os.Setenv(constants.TenantID, xiangYunFourConfig.TenantID); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", constants.TenantID, err.Error()) + } + if err := os.Setenv(constants.ApplicationID, xiangYunFourConfig.ApplicationID); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", constants.ApplicationID, err.Error()) + } + if err := os.Setenv(constants.ServiceID, xiangYunFourConfig.ServiceID); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", constants.ServiceID, err.Error()) + } + log.GetLogger().Infof("succeeded to set env, value: %v", xiangYunFourConfig) +} + +// SetPodIP - +func SetPodIP() error { + ip, err := utils.GetServerIP() + if err != nil { + log.GetLogger().Errorf("failed to get pod ip, err: %s", err.Error()) + return err + } + err = os.Setenv(constants.PodIPEnvKey, ip) + if err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", constants.PodIPEnvKey, err.Error()) + return err + } + return nil +} + +// SetClusterNameEnv - +func SetClusterNameEnv(clusterName string) { + if err := os.Setenv(ClusterName, clusterName); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", ClusterName, err.Error()) + } + log.GetLogger().Infof("succeeded to set env of %s, value: %s", ClusterName, clusterName) +} diff --git a/api/go/faassdk/common/alarm/logalarm_test.go b/api/go/faassdk/common/alarm/logalarm_test.go new file mode 100644 index 0000000..b51349c --- /dev/null +++ b/api/go/faassdk/common/alarm/logalarm_test.go @@ -0,0 +1,123 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package alarm alarm log by filebeat +package alarm + +import ( + "encoding/json" + "io" + "os" + "reflect" + "sync" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/libruntime/common/logger" + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +func TestGetAlarmLogger(t *testing.T) { + convey.Convey("TestGetAlarmLogger", t, func() { + convey.Convey("success", func() { + defer gomonkey.ApplyFunc(config.ExtractCoreInfoFromEnv, func(env string) (config.CoreInfo, error) { + return config.CoreInfo{FilePath: "./"}, nil + }).Reset() + defer gomonkey.ApplyFunc(logger.CreateSink, func(coreInfo config.CoreInfo) (io.Writer, error) { + return nil, nil + }).Reset() + defer gomonkey.ApplyFunc(zapcore.AddSync, func(w io.Writer) zapcore.WriteSyncer { + return nil + }).Reset() + // 设置环境变量 + os.Setenv("WiseCloudSite", "testSite") + os.Setenv("TenantID", "testTenantID") + os.Setenv("ApplicationID", "testApplicationID") + os.Setenv("ServiceID", "testServiceID") + // 测试正常情况 + logger, err := GetAlarmLogger() + convey.So(logger, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestReportOrClearAlarm(t *testing.T) { + convey.Convey("TestReportOrClearAlarm", t, func() { + convey.Convey("success", func() { + defer func() { + alarmLogger = nil + createLoggerErr = nil + createLoggerOnce = sync.Once{} + os.Remove("WiseCloudSite") + os.Remove("TenantID") + os.Remove("ApplicationID") + os.Remove("ServiceID") + }() + defer gomonkey.ApplyMethod(reflect.TypeOf(&zap.Logger{}), "Info", func(_ *zap.Logger, msg string, fields ...zap.Field) { + }).Reset() + ReportOrClearAlarm(&LogAlarmInfo{}, &Detail{}) + }) + }) +} + +func TestSetAlarmEnv(t *testing.T) { + convey.Convey("TestSetAlarmEnv", t, func() { + SetAlarmEnv(config.CoreInfo{FilePath: "./test"}) + bytes := []byte(os.Getenv(ConfigKey)) + var coreInfo config.CoreInfo + _ = json.Unmarshal(bytes, &coreInfo) + convey.So(coreInfo.FilePath, convey.ShouldEqual, "./test") + os.Remove(ConfigKey) + }) +} + +func TestSetClusterNameEnv(t *testing.T) { + convey.Convey("TestSetClusterNameEnv", t, func() { + SetXiangYunFourConfigEnv(types.XiangYunFourConfig{Site: "www", TenantID: "testTenantID", ApplicationID: "app", ServiceID: "service"}) + convey.So(os.Getenv(constants.WiseCloudSite), convey.ShouldEqual, "www") + convey.So(os.Getenv(constants.TenantID), convey.ShouldEqual, "testTenantID") + convey.So(os.Getenv(constants.ApplicationID), convey.ShouldEqual, "app") + convey.So(os.Getenv(constants.ServiceID), convey.ShouldEqual, "service") + os.Remove(constants.WiseCloudSite) + os.Remove(constants.TenantID) + os.Remove(constants.ApplicationID) + os.Remove(constants.ServiceID) + }) +} + +func TestSetPodIP(t *testing.T) { + convey.Convey("TestSetPodIP", t, func() { + SetPodIP() + getenv := os.Getenv(constants.PodIPEnvKey) + convey.So(getenv, convey.ShouldNotEqual, "") + }) +} + +func TestSetXiangYunFourConfigEnv(t *testing.T) { + convey.Convey("TestSetPodIP", t, func() { + SetClusterNameEnv("cluster") + getenv := os.Getenv(constants.ClusterName) + convey.So(getenv, convey.ShouldEqual, "cluster") + os.Remove(constants.ClusterName) + }) +} diff --git a/api/go/faassdk/common/aliasroute/alias.go b/api/go/faassdk/common/aliasroute/alias.go new file mode 100644 index 0000000..9cfb4e1 --- /dev/null +++ b/api/go/faassdk/common/aliasroute/alias.go @@ -0,0 +1,358 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package aliasroute alias routing in busclient +package aliasroute + +import ( + "fmt" + "strings" + "sync" + + "yuanrong.org/kernel/runtime/faassdk/common/loadbalance" + "yuanrong.org/kernel/runtime/faassdk/utils/urnutils" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + weightRatio = 100 // max weight of a node + routingTypeRule = "rule" + // AliasKeySeparator is the separator in an alias key + AliasKeySeparator = "/" + // DefaultAliasSign is the default alias sign + DefaultAliasSign = "aliases" + // DefaultTenantSign is the default tenant sign + DefaultTenantSign = "tenant" + // DefaultFuncSign is the default function sign + DefaultFuncSign = "function" +) + +// example of an aliasKey: +// ///////// +const ( + ProductIDIndex = iota + 1 + AliasSignIndex + BusinessSignIndex + BusinessIDIndex + TenantSignIndex + TenantIDIndex + FunctionSignIndex + FunctionIDIndex + AliasNameIndex + aliasKeyLength +) + +// Aliases map for stateless function alias +type Aliases struct { + AliasMap *sync.Map // Key: aliasURN -- Value: *AliasElement +} + +// aliases for alias routing +var ( + aliases = &Aliases{ + AliasMap: &sync.Map{}, + } +) + +// GetAliases - +func GetAliases() *Aliases { + return aliases +} + +// UpdateAliasesMap - +func UpdateAliasesMap(aliasList []*AliasElement) { + latestAlias := &Aliases{ + AliasMap: &sync.Map{}, + } + for _, alias := range aliasList { + latestAlias.AddAlias(alias) + } + aliases = latestAlias +} + +// AddAlias add alias to Aliases map from etcd +func (a *Aliases) AddAlias(alias *AliasElement) { + existAliasIf, exist := a.AliasMap.Load(alias.AliasURN) + var existAlias *AliasElement + var ok bool + if !exist { + // new alias, initialize RR and Mutex + existAlias = &AliasElement{ + AliasURN: alias.AliasURN, + FunctionURN: alias.FunctionURN, + FunctionVersionURN: alias.FunctionVersionURN, + Name: alias.Name, + Description: alias.Description, + FunctionVersion: alias.FunctionVersion, + RevisionID: alias.RevisionID, + RoutingConfigs: alias.RoutingConfigs, + RoutingRules: alias.RoutingRules, + RoutingType: alias.RoutingType, + + lb: loadbalance.LBFactory(loadbalance.RoundRobinNginx), + aliasLock: &sync.RWMutex{}, + } + existAlias.resetRR() + logger.GetLogger().Debugf("add alias %s", alias.AliasURN) + a.AliasMap.Store(alias.AliasURN, existAlias) + return + } + existAlias, ok = existAliasIf.(*AliasElement) + if ok { + aliasUpdate(existAlias, alias) + existAlias.resetRR() + } +} + +// RemoveAlias remove alias to aliases map +func (a *Aliases) RemoveAlias(aliasURN string) { + a.AliasMap.Delete(aliasURN) +} + +// GetFuncURNFromAlias If the alias exists, the weighted route version is returned. +// If the alias does not exist, the original URN is returned. +func (a *Aliases) GetFuncURNFromAlias(urn string) string { + existAliasIf, exist := a.AliasMap.Load(urn) + if !exist { + return urn + } + existAlias, ok := existAliasIf.(*AliasElement) + if !ok { + logger.GetLogger().Warnf("Failed to convert the alias urn %s", urn) + return "" + } + return existAlias.getFuncVersionURN() +} + +// GetFuncVersionURNWithParams gets the routing version URN of stateless functionName with parmas for rules +func (a *Aliases) GetFuncVersionURNWithParams(aliasURN string, params map[string]string) string { + existAliasIf, exist := a.AliasMap.Load(aliasURN) + if !exist { + return aliasURN + } + existAlias, ok := existAliasIf.(*AliasElement) + if !ok { + logger.GetLogger().Warnf("Failed to convert the alias urn %s", aliasURN) + return "" + } + return existAlias.getFuncVersionURNWithParams(params) +} + +type routingRules struct { + RuleLogic string `json:"ruleLogic"` + Rules []string `json:"rules"` + GrayVersion string `json:"grayVersion"` +} + +// AliasElement struct stores an alias configs of stateless function +type AliasElement struct { + aliasLock *sync.RWMutex + lb loadbalance.LBInterface + AliasURN string `json:"aliasUrn"` + FunctionURN string `json:"functionUrn"` + FunctionVersionURN string `json:"functionVersionUrn"` + Name string `json:"name"` + FunctionVersion string `json:"functionVersion"` + RevisionID string `json:"revisionId"` + Description string `json:"description"` + RoutingType string `json:"routingType"` + RoutingRules routingRules `json:"routingRules"` + RoutingConfigs []*routingConfig `json:"routingconfig"` +} + +type routingConfig struct { + FunctionVersionURN string `json:"functionVersionUrn"` + Weight float64 `json:"weight"` +} + +func (a *AliasElement) getFuncVersionURN() string { + a.aliasLock.RLock() + defer a.aliasLock.RUnlock() + funcVersion := a.lb.Next("", true) + if funcVersion == nil { + return "" + } + res, ok := funcVersion.(string) + if !ok { + return "" + } + return res +} + +func (a *AliasElement) resetRR() { + a.aliasLock.Lock() + defer a.aliasLock.Unlock() + a.lb.RemoveAll() + for _, v := range a.RoutingConfigs { + a.lb.Add(v.FunctionVersionURN, int(v.Weight*weightRatio)) + } +} + +func (a *AliasElement) getFuncVersionURNByRule(params map[string]string) string { + a.aliasLock.RLock() + defer a.aliasLock.RUnlock() + if len(params) == 0 { + logger.GetLogger().Warnf("params is empty, use default func version") + return a.FunctionVersionURN + } + if len(a.RoutingRules.Rules) == 0 { + logger.GetLogger().Warnf("rule len is 0, use default func version") + return a.FunctionVersionURN + } + + matchRules, err := parseRules(a.RoutingRules) + if err != nil { + logger.GetLogger().Warnf("parse rule error, use default func version: %s", err.Error()) + return a.FunctionVersionURN + } + + // To obtain the final matching result by matching each rule and considering the "AND" or "OR"relationship of the rules + matched := matchRule(params, matchRules, a.RoutingRules.RuleLogic) + // got to default version if not matched + if matched { + return a.RoutingRules.GrayVersion + } + return a.FunctionVersionURN +} + +func (a *AliasElement) getFuncVersionURNWithParams(params map[string]string) string { + if a.RoutingType == routingTypeRule { + return a.getFuncVersionURNByRule(params) + } + // default to go weight + return a.getFuncVersionURN() +} + +func aliasUpdate(destAlias, srcAlias *AliasElement) { + destAlias.AliasURN = srcAlias.AliasURN + destAlias.FunctionURN = srcAlias.FunctionURN + destAlias.FunctionVersionURN = srcAlias.FunctionVersionURN + destAlias.Name = srcAlias.Name + destAlias.FunctionVersion = srcAlias.FunctionVersion + destAlias.RevisionID = srcAlias.RevisionID + destAlias.Description = srcAlias.Description + destAlias.RoutingConfigs = srcAlias.RoutingConfigs + destAlias.RoutingRules = srcAlias.RoutingRules + destAlias.RoutingType = srcAlias.RoutingType +} + +// AliasKey contains the elements of an alias key +type AliasKey struct { + ProductID string + AliasSign string + BusinessSign string + BusinessID string + TenantSign string + TenantID string + FunctionSign string + FunctionID string + AliasName string +} + +// ParseFrom parses elements from an alias key +func (a *AliasKey) ParseFrom(aliasKeyStr string) error { + elements := strings.Split(aliasKeyStr, AliasKeySeparator) + urnLen := len(elements) + if urnLen != aliasKeyLength { + return fmt.Errorf("failed to parse an alias key %s, incorrect length", aliasKeyStr) + } + a.ProductID = elements[ProductIDIndex] + a.AliasSign = elements[AliasSignIndex] + a.BusinessSign = elements[BusinessSignIndex] + a.BusinessID = elements[BusinessIDIndex] + a.TenantSign = elements[TenantSignIndex] + a.TenantID = elements[TenantIDIndex] + a.FunctionSign = elements[FunctionSignIndex] + a.FunctionID = elements[FunctionIDIndex] + a.AliasName = elements[AliasNameIndex] + return nil +} + +// FetchInfoFromAliasKey collects alias information from an alias key +func FetchInfoFromAliasKey(aliasKeyStr string) *AliasKey { + var aliasKey AliasKey + if err := aliasKey.ParseFrom(aliasKeyStr); err != nil { + logger.GetLogger().Errorf("error while parsing an URN: %s", err.Error()) + return &AliasKey{} + } + return &aliasKey +} + +// BuildURNFromAliasKey builds a URN from a alias key +func BuildURNFromAliasKey(aliasKeyStr string) string { + aliasKey := FetchInfoFromAliasKey(aliasKeyStr) + productURN := &urnutils.BaseURN{ + ProductID: urnutils.DefaultURNProductID, + RegionID: urnutils.DefaultURNRegion, + BusinessID: aliasKey.BusinessID, + TenantID: aliasKey.TenantID, + TypeSign: urnutils.DefaultURNFuncSign, + Name: aliasKey.FunctionID, + Version: aliasKey.AliasName, + } + return productURN.String() +} + +func parseRules(routingRules routingRules) ([]Expression, error) { + rules := routingRules.Rules + var expressions []Expression + const expressionSize = 3 + for _, value := range rules { + partition := strings.Split(value, ":") + if len(partition) != expressionSize { + return nil, fmt.Errorf("rules (%s) fields size not equal %v", value, expressionSize) + } + expression := Expression{ + leftVal: partition[0], + operator: partition[1], + rightVal: partition[2], + } + expressions = append(expressions, expression) + } + return expressions, nil +} + +func matchRule(params map[string]string, expressions []Expression, ruleLogic string) bool { + var matchResultList []bool + + for _, exp := range expressions { + matchResultList = append(matchResultList, exp.Execute(params)) + } + if len(matchResultList) > 0 { + return isMatch(matchResultList, ruleLogic) + } + return false +} + +func isMatch(matchResultList []bool, ruleLogic string) bool { + matchResult := matchResultList[0] + if len(matchResultList) > 1 { + switch ruleLogic { + case "or": + for _, value := range matchResultList { + matchResult = matchResult || value + } + case "and": + for _, value := range matchResultList { + matchResult = matchResult && value + } + default: + logger.GetLogger().Warnf("unknow rulelogic: %s, return false", ruleLogic) + return false + } + } + return matchResult +} diff --git a/api/go/faassdk/common/aliasroute/alias_test.go b/api/go/faassdk/common/aliasroute/alias_test.go new file mode 100644 index 0000000..bbe98b5 --- /dev/null +++ b/api/go/faassdk/common/aliasroute/alias_test.go @@ -0,0 +1,274 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package aliasroute alias routing +package aliasroute + +import ( + "sync" + "testing" + + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" +) + +const ( + aliasURN = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasv1" +) + +// TestCase init +func GetFakeAliasEle() *AliasElement { + fakeAliasEle := &AliasElement{ + AliasURN: aliasURN, + FunctionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld", + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest", + Name: "myaliasv1", + FunctionVersion: "$latest", + RevisionID: "20210617023315921", + Description: "", + RoutingConfigs: []*routingConfig{ + { + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest", + Weight: 60, + }, + { + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:v1", + Weight: 40, + }, + }, + } + return fakeAliasEle +} + +func GetFakeRuleAliasEle() *AliasElement { + fakeAliasEle := &AliasElement{ + AliasURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasrulev1", + FunctionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld", + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest", + Name: "myaliasrulev1", + FunctionVersion: "$latest", + RevisionID: "20210617023315921", + Description: "", + RoutingType: "rule", + RoutingRules: routingRules{ + RuleLogic: "and", + Rules: []string{"userType:=:VIP", "age:<=:20", "devType:in:P40,P50,MATE40"}, + GrayVersion: "sn:cn:yrk:172120022620195843:function:0@default@test_func:3", + }, + } + return fakeAliasEle +} +func ClearAliasRoute() { + aliases = &Aliases{ + AliasMap: &sync.Map{}, + } +} + +func TestOptAlias(t *testing.T) { + ClearAliasRoute() + defer ClearAliasRoute() + convey.Convey("AddAlias success", t, func() { + fakeAliasEle := GetFakeAliasEle() + aliases.AddAlias(fakeAliasEle) + ele, ok := aliases.AliasMap.Load(fakeAliasEle.AliasURN) + convey.So(ok, convey.ShouldBeTrue) + convey.So(ele, convey.ShouldNotBeNil) + }) + convey.Convey("update Alias success", t, func() { + fakeAliasEle := GetFakeAliasEle() + fakeAliasEle.RoutingConfigs = []*routingConfig{ + { + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest", + Weight: 50, + }, + { + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:v1", + Weight: 50, + }, + } + aliases.AddAlias(fakeAliasEle) + ele, ok := aliases.AliasMap.Load(fakeAliasEle.AliasURN) + convey.So(ok, convey.ShouldBeTrue) + convey.So(ele.(*AliasElement).RoutingConfigs[0].Weight, convey.ShouldEqual, 50) + convey.So(ele.(*AliasElement).RoutingConfigs[1].Weight, convey.ShouldEqual, 50) + }) + convey.Convey("remove Alias success", t, func() { + fakeAliasEle := GetFakeAliasEle() + aliases.AddAlias(fakeAliasEle) + aliases.RemoveAlias(fakeAliasEle.AliasURN) + ele, ok := aliases.AliasMap.Load(fakeAliasEle.AliasURN) + convey.So(ok, convey.ShouldBeFalse) + convey.So(ele, convey.ShouldBeNil) + }) +} + +func TestGetFuncURNFromAlias(t *testing.T) { + ClearAliasRoute() + defer ClearAliasRoute() + convey.Convey("alias does not exist", t, func() { + urn := aliases.GetFuncURNFromAlias(aliasURN) + convey.So(urn, convey.ShouldEqual, aliasURN) + }) + + convey.Convey("alias get error", t, func() { + aliases.AliasMap.Store(aliasURN, "456") + urn := aliases.GetFuncURNFromAlias(aliasURN) + aliases.AliasMap.Delete(aliasURN) + convey.So(urn, convey.ShouldEqual, "") + }) + convey.Convey("alias get error", t, func() { + aliases.AddAlias(GetFakeAliasEle()) + urn := aliases.GetFuncURNFromAlias(aliasURN) + convey.So(urn, convey.ShouldNotEqual, aliasURN) + convey.So(urn, convey.ShouldNotEqual, "") + convey.So(urn, convey.ShouldNotContainSubstring, "myaliasv1") + }) + +} + +func TestFetchInfoFromAliasKey(t *testing.T) { + path := "/sn/aliases/business/yrk/tenant/12345678901234561234567890123456/function/helloworld/myalias" + aliasKey := FetchInfoFromAliasKey(path) + + assert.Equal(t, aliasKey.FunctionID, "helloworld") + assert.Equal(t, aliasKey.AliasName, "myalias") + + path = "/sn/aliases/business/yrk/tenant/12345678901234561234567890123456/function/helloworld" + aliasKey = FetchInfoFromAliasKey(path) + assert.Empty(t, aliasKey) +} + +func TestBuildURNFromAliasKey(t *testing.T) { + path := "/sn/aliases/business/yrk/tenant/12345678901234561234567890123456/function/helloworld/myalias" + urn := BuildURNFromAliasKey(path) + assert.Contains(t, urn, "myalias") +} + +func TestGetFuncVersionURNWithParamsMatch(t *testing.T) { + ClearAliasRoute() + defer ClearAliasRoute() + fakeAliasEle := GetFakeRuleAliasEle() + aliases.AddAlias(fakeAliasEle) + params := map[string]string{} + params["userType"] = "VIP" + params["age"] = "10" + params["devType"] = "P40" + + aliasUrn := "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasrulev1" + wantFuncVer := "sn:cn:yrk:172120022620195843:function:0@default@test_func:3" + got := GetAliases().GetFuncVersionURNWithParams(aliasUrn, params) + assert.Equal(t, wantFuncVer, got) +} + +func TestGetFuncVersionURNWithParamsNotMatch(t *testing.T) { + ClearAliasRoute() + defer ClearAliasRoute() + fakeAliasEle := GetFakeRuleAliasEle() + aliases.AddAlias(fakeAliasEle) + params := map[string]string{} + params["userType"] = "VIP" + params["age"] = "50" + params["devType"] = "P40" + + aliasUrn := "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasrulev1" + wantFuncVer := "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest" + got := GetAliases().GetFuncVersionURNWithParams(aliasUrn, params) + assert.Equal(t, wantFuncVer, got) +} + +func TestUpdateAliasesMap(t *testing.T) { + convey.Convey( + "Test UpdateAliasesMap", t, func() { + convey.Convey( + "UpdateAliasesMap success", func() { + fakeAliasEle := GetFakeAliasEle() + convey.So(func() { + UpdateAliasesMap([]*AliasElement{fakeAliasEle}) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestGetFuncVersionURNWithParams(t *testing.T) { + convey.Convey( + "Test GetFuncVersionURNWithParams", t, func() { + convey.Convey( + "GetFuncVersionURNWithParams success", func() { + params := map[string]string{} + params["userType"] = "VIP" + params["age"] = "50" + params["devType"] = "P40" + str := GetAliases().GetFuncVersionURNWithParams("aliasURN", params) + convey.So(str, convey.ShouldEqual, "aliasURN") + }, + ) + }, + ) +} + +func TestGetFuncVersionURNByRule(t *testing.T) { + convey.Convey( + "Test getFuncVersionURNByRule", t, func() { + convey.Convey( + "getFuncVersionURNByRule success", func() { + params := map[string]string{} + ele := GetFakeAliasEle() + var s sync.RWMutex + ele.aliasLock = &s + str := ele.getFuncVersionURNByRule(params) + convey.So(str, convey.ShouldEqual, ele.FunctionVersionURN) + params["userType"] = "VIP" + str = ele.getFuncVersionURNByRule(params) + convey.So(str, convey.ShouldEqual, ele.FunctionVersionURN) + }, + ) + }, + ) +} + +func TestMatchRule(t *testing.T) { + convey.Convey( + "Test matchRule", t, func() { + convey.Convey( + "matchRule success", func() { + flag := matchRule(map[string]string{}, []Expression{}, "ruleLogic") + convey.So(flag, convey.ShouldBeFalse) + }, + ) + }, + ) +} + +func TestIsMatch(t *testing.T) { + convey.Convey( + "Test isMatch", t, func() { + convey.Convey( + "isMatch success when ruleLogic==or", func() { + flag := isMatch([]bool{false, false}, "or") + convey.So(flag, convey.ShouldBeFalse) + }, + ) + convey.Convey( + "isMatch success when default", func() { + flag := isMatch([]bool{false, false}, "non") + convey.So(flag, convey.ShouldBeFalse) + }, + ) + }, + ) +} diff --git a/api/go/faassdk/common/aliasroute/expression.go b/api/go/faassdk/common/aliasroute/expression.go new file mode 100644 index 0000000..c169db3 --- /dev/null +++ b/api/go/faassdk/common/aliasroute/expression.go @@ -0,0 +1,113 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package aliasroute alias routing in busclient +package aliasroute + +import ( + "strconv" + "strings" + + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + expressionSize = 3 +) + +// Expression rule expression struct +type Expression struct { + leftVal string + operator string + rightVal string +} + +func compareIntegerStrings(a, b string) (int, error) { + numA, err := strconv.Atoi(a) + if err != nil { + return 0, err + } + + numB, err := strconv.Atoi(b) + if err != nil { + return 0, err + } + + if numA < numB { + return -1, nil + } else if numA > numB { + return 1, nil + } else { + return 0, nil + } +} + +// Execute the rule expression +func (exp *Expression) Execute(params map[string]string) bool { + logger.GetLogger().Debugf("params %v, exp.leftVal %v,exp.rightVal %v", params, exp.leftVal, exp.rightVal) + val, exist := params[exp.leftVal] + if !exist { + logger.GetLogger().Warnf("cannot find val for %s in params", exp.leftVal) + return false + } + + switch exp.operator { + case "=": + return strings.TrimSpace(val) == strings.TrimSpace(exp.rightVal) + case "!=": + return strings.TrimSpace(val) != strings.TrimSpace(exp.rightVal) + case ">": + ret, err := compareIntegerStrings(val, exp.rightVal) + return err == nil && ret == 1 + case "<": + ret, err := compareIntegerStrings(val, exp.rightVal) + return err == nil && ret == -1 + case ">=": + ret, err := compareIntegerStrings(val, exp.rightVal) + return err == nil && (ret == 1 || ret == 0) + case "<=": + ret, err := compareIntegerStrings(val, exp.rightVal) + return err == nil && (ret == -1 || ret == 0) + case "in": + return matchStr(val, exp.rightVal) + default: + logger.GetLogger().Warnf("unknown operator(%s), return false", val, exp.operator) + return false + } +} + +func convertStrsToNum(strs []string) ([]int, error) { + var ret []int + for _, str := range strs { + num, err := strconv.Atoi(strings.TrimSpace(str)) + if err != nil { + return nil, err + } + ret = append(ret, num) + } + return ret, nil +} + +func matchStr(str string, targetStr string) bool { + tars := strings.Split(targetStr, ",") + for _, tar := range tars { + // The rvalue of the 'in' operator ignores "" + if tar != "" && strings.TrimSpace(str) == strings.TrimSpace(tar) { + return true + } + } + return false +} diff --git a/api/go/faassdk/common/aliasroute/expression_test.go b/api/go/faassdk/common/aliasroute/expression_test.go new file mode 100644 index 0000000..85e3960 --- /dev/null +++ b/api/go/faassdk/common/aliasroute/expression_test.go @@ -0,0 +1,194 @@ +package aliasroute + +import ( + "fmt" + "strings" + "testing" + + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" +) + +type ExpressionTestSuite struct { + alias AliasElement +} + +func (suite *ExpressionTestSuite) SetupTest() { + +} + +func (suite *ExpressionTestSuite) TearDownTest() { + +} + +func (suite *ExpressionTestSuite) TestEquel() { + +} + +func genExpression(str string) (Expression, error) { + partition := strings.Split(str, ":") + if len(partition) != expressionSize { + return Expression{}, fmt.Errorf("express(#{str}) string format is error") + } + return Expression{ + leftVal: partition[0], + operator: partition[1], + rightVal: partition[2], + }, nil +} + +func ExecuteExp(t *testing.T, expStr string, params map[string]string) bool { + exp, err := genExpression(expStr) + if err != nil { + t.Error("gen expression fail: ", expStr) + return false + } + return exp.Execute(params) +} + +func TestExpEq(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + + got := ExecuteExp(t, "id:=:123", params) + assert.True(t, got) + + got = ExecuteExp(t, "id:=:444", params) + assert.False(t, got) +} + +func TestExpNotEq(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + + got := ExecuteExp(t, "id:!=:200", params) + assert.True(t, got) + + got = ExecuteExp(t, "id:!=:123", params) + assert.False(t, got) +} + +func TestExpLt(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + params["type"] = "p40" + + got := ExecuteExp(t, "id:<:200", params) + assert.True(t, got) + + got = ExecuteExp(t, "id:<:100", params) + assert.False(t, got) + + got = ExecuteExp(t, "type:<:100", params) + assert.False(t, got) +} + +func TestExpLtEq(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + params["type"] = "p40" + + got := ExecuteExp(t, "id:<=:200", params) + assert.True(t, got) + + got = ExecuteExp(t, "id:<=:100", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:<=:123", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:<=:100", params) + assert.False(t, got) +} + +func TestExpGt(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + params["type"] = "p40" + + got := ExecuteExp(t, "id:>:200", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:>:100", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:>:100", params) + assert.False(t, got) +} + +func TestExpGtEq(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + params["type"] = "p40" + + got := ExecuteExp(t, "id:>=:200", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:>=:100", params) + assert.True(t, got) + + got = ExecuteExp(t, "id:>=:123", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:>=:1", params) + assert.False(t, got) +} + +func TestExpIn(t *testing.T) { + params := map[string]string{} + params["type"] = "p40" + + got := ExecuteExp(t, "type:in:p40,mate40", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:in:mate40, p40", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:in:mate40, p40 , p30", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:in:mate40,p30", params) + assert.False(t, got) + + got = ExecuteExp(t, "type:in:", params) + assert.False(t, got) +} + +func TestExpExcept(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + params["type"] = "p40" + + got := ExecuteExp(t, "age:<:30", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:<:", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:<:abc", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:||:123", params) + assert.False(t, got) +} + +func TestConvertStrsToNum(t *testing.T) { + convey.Convey( + "Test convertStrsToNum", t, func() { + convey.Convey( + "convertStrsToNum success when err", func() { + n, err := convertStrsToNum([]string{"a"}) + convey.So(n, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "convertStrsToNum success", func() { + n, err := convertStrsToNum([]string{"1", "2"}) + convey.So(n, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} diff --git a/api/go/faassdk/common/constants/constants.go b/api/go/faassdk/common/constants/constants.go new file mode 100644 index 0000000..41c9862 --- /dev/null +++ b/api/go/faassdk/common/constants/constants.go @@ -0,0 +1,116 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constants gets function services from URNs +package constants + +const ( + // FuncLogLevelDebug - + FuncLogLevelDebug = "DEBUG" + // FuncLogLevelInfo - + FuncLogLevelInfo = "INFO" + // FuncLogLevelWarn - + FuncLogLevelWarn = "WARN" + // FuncLogLevelWarning - + FuncLogLevelWarning = "WARNING" + // FuncLogLevelError - + FuncLogLevelError = "ERROR" +) + +// stage +const ( + InitializeStage = "initialize" + RestoreStage = "restore" + InvokeStage = "invoke" + LoadStage = "load" + ExtensionStage = "extension" + ColdStartStage = "coldstart" +) + +const ( + // DefaultFuncLogIndex default function log's index + DefaultFuncLogIndex = -2 + // RuntimeLogOptTail - + RuntimeLogOptTail = "Tail" + // RuntimeContainerIDEnvKey - + RuntimeContainerIDEnvKey = "DELEGATE_CONTAINER_ID" +) + +const ( + // ValidBasicCreateParamSize - + ValidBasicCreateParamSize = 4 + // ValidCustomImageCreateParamSize - + ValidCustomImageCreateParamSize = 5 + // CustomImageUserArgIndex - + CustomImageUserArgIndex = 4 + // ValidInvokeArgumentSize - + ValidInvokeArgumentSize = 2 +) + +// CaaS alarm +const ( + // WiseCloudSite site + WiseCloudSite = "WISECLOUD_SITE" + // TenantID WiseCloud tenantID + TenantID = "WISECLOUD_TENANTID" + // ApplicationID WiseCloud applicationId + ApplicationID = "WISECLOUD_APPLICATIONID" + // ServiceID WiseCloud serviceId + ServiceID = "WISECLOUD_SERVICEID" + // ClusterName define cluster env key + ClusterName = "CLUSTER_NAME" + // PodNameEnvKey define pod name env key + PodNameEnvKey = "POD_NAME" + // PodIPEnvKey define pod ip env key + PodIPEnvKey = "POD_IP" + // HostIPEnvKey define pod ip env key + HostIPEnvKey = "HOST_IP" +) + +const ( + // DefaultMapSize is the default map size + DefaultMapSize = 16 + // KernelRequestIDKey is the requestID in kernel + KernelRequestIDKey = "requestId" + // RuntimeTypeHttp represents runtime type of HTTP + RuntimeTypeHttp = "http" + // RuntimeTypeCustomContainer represents runtime type of custom container + RuntimeTypeCustomContainer = "custom image" +) + +const ( + // CaaSTraceIDHeaderKey is the key of trace ID + CaaSTraceIDHeaderKey = "X-CaaS-Trace-Id" + // CffRequestIDHeaderKey is the key of trace ID + CffRequestIDHeaderKey = "X-Cff-Request-Id" +) + +const ( + // RuntimeMaxRespBodySize - + RuntimeMaxRespBodySize = 6 * 1024 * 1024 + // RuntimeRoot - + RuntimeRoot = "/home/snuser/runtime" + // RuntimeCodeRoot - + RuntimeCodeRoot = "/opt/function/code" + // RuntimeLogDir - + RuntimeLogDir = "/home/snuser/log" + // RuntimePkgNameSplit - + RuntimePkgNameSplit = 2 + // INT64ToINT - + INT64ToINT = 10 + // LDLibraryPath - + LDLibraryPath = "LD_LIBRARY_PATH" +) diff --git a/api/go/faassdk/common/constants/error.go b/api/go/faassdk/common/constants/error.go new file mode 100644 index 0000000..4075e51 --- /dev/null +++ b/api/go/faassdk/common/constants/error.go @@ -0,0 +1,79 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constants +package constants + +const ( + // NoneError - + NoneError = 0 + // FaaSError - + FaaSError = 500 +) + +// executor error code +const ( + // ExecutorErrCodeInitFail - + ExecutorErrCodeInitFail = 6001 + // AcquireLeaseTrafficLimitErrorCode is reach max limit of acquiring lease concurrently + AcquireLeaseTrafficLimitErrorCode = 6037 +) + +// user error code +const ( + // EntryNotFound user code entry not found + EntryNotFound = 4001 + // FunctionRunError user function failed to run + FunctionRunError = 4002 + // StateContentTooLarge state content is too large + StateContentTooLarge = 4003 + // ResponseExceedLimit response of user function exceeds the platform limit + ResponseExceedLimit = 4004 + // UndefinedState state is undefined + UndefinedState = 4005 + // HeartBeatFunctionInvalid heart beat function of user invalid + HeartBeatFunctionInvalid = 4006 + // FunctionResultInvalid user function result is invalid + FunctionResultInvalid = 4007 + // InitializeFunctionError user initialize function error + InitializeFunctionError = 4009 + // HeartBeatInvokeError failed to invoke heart beat function + HeartBeatInvokeError = 4010 + // InvokeFunctionTimeout user function invoke timeout + InvokeFunctionTimeout = 4010 + + // InstanceCircuitBreakError function is circuit break + InstanceCircuitBreakError = 4011 + + // InitFunctionTimeout user function init timeout + InitFunctionTimeout = 4211 + // RequestBodyExceedLimit request body exceeds limit + RequestBodyExceedLimit = 4140 + // InitFunctionFail function initialization failed + InitFunctionFail = 4201 + // MemoryLimitExceeded runtime memory limit exceeded + MemoryLimitExceeded = 4205 + // DiskUsageExceed disk usage exceed code + DiskUsageExceed = 4207 +) + +// frontend error code +const ( + // FrontendStatusOk ok code + FrontendStatusOk = 200200 + // ClusterIsUpgrading - + ClusterIsUpgrading = 150439 +) diff --git a/api/go/faassdk/common/faasscheduler/proxy.go b/api/go/faassdk/common/faasscheduler/proxy.go new file mode 100644 index 0000000..3e64fcb --- /dev/null +++ b/api/go/faassdk/common/faasscheduler/proxy.go @@ -0,0 +1,115 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faasscheduler - +package faasscheduler + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "yuanrong.org/kernel/runtime/faassdk/common/loadbalance" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + // hashRingSize the concurrent hash ring length + hashRingSize = 100 + limiterTime = 1 * time.Millisecond +) + +// SchedulerInfo is scheduler info +type SchedulerInfo struct { + SchedulerFuncKey string `json:"schedulerFuncKey"` + SchedulerIDList []string `json:"schedulerIDList"` +} + +// Proxy is the singleton proxy +var Proxy *schedulerProxy + +func init() { + Proxy = newSchedulerProxy( + loadbalance.NewLimiterCHGeneric(limiterTime), + ) +} + +// schedulerProxy is used to get instances from FaaSScheduler via a grpc stream +type schedulerProxy struct { + // used to select a FaaSScheduler by the func info Concurrent Consistent Hash + loadBalance loadbalance.LBInterface + schedulerFuncKey string +} + +// Add an FaaSScheduler +func (im *schedulerProxy) Add(faaSSchedulerID string) { + im.loadBalance.Add(faaSSchedulerID, 0) +} + +// Remove a FaaSScheduler +func (im *schedulerProxy) Remove(faaSSchedulerID string) { + im.loadBalance.Remove(faaSSchedulerID) +} + +// Get an instance for this request +func (im *schedulerProxy) Get(funcKey string) (string, error) { + logger.GetLogger().Infof("Getting instance from scheduler for funcKey: %s", funcKey) + // select one FaaSScheduler by the func key + next := im.loadBalance.Next(funcKey, false) + faaSSchedulerID, ok := next.(string) + if !ok { + return "", fmt.Errorf("failed to parse the result of loadbanlance: %+v", next) + } + if strings.TrimSpace(faaSSchedulerID) == "" { + return "", fmt.Errorf("no avaiable faas scheduler was found") + } + return faaSSchedulerID, nil +} + +// GetSchedulerFuncKey - +func (im *schedulerProxy) GetSchedulerFuncKey() string { + return im.schedulerFuncKey +} + +// newSchedulerProxy return an instance pool which get the instance from the remote FaaSScheduler +func newSchedulerProxy(lb loadbalance.LBInterface) *schedulerProxy { + return &schedulerProxy{ + loadBalance: lb, + } +} + +// ParseSchedulerData - +func ParseSchedulerData(args api.Arg) error { + schedulerInfo := &SchedulerInfo{} + err := json.Unmarshal(args.Data, schedulerInfo) + if err != nil { + return err + } + for _, schedulerID := range schedulerInfo.SchedulerIDList { + Proxy.Add(schedulerID) + } + Proxy.schedulerFuncKey = schedulerInfo.SchedulerFuncKey + return nil +} + +// SetStain - +func (im *schedulerProxy) SetStain(funcKey, instanceID string) { + if v, ok := im.loadBalance.(*loadbalance.LimiterCHGeneric); ok { + v.SetStain(funcKey, instanceID) + } +} diff --git a/api/go/faassdk/common/faasscheduler/proxy_test.go b/api/go/faassdk/common/faasscheduler/proxy_test.go new file mode 100644 index 0000000..d6116eb --- /dev/null +++ b/api/go/faassdk/common/faasscheduler/proxy_test.go @@ -0,0 +1,131 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faasscheduler - +package faasscheduler + +import ( + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/faassdk/common/loadbalance" + "yuanrong.org/kernel/runtime/libruntime/api" +) + +func Test_schedulerProxy_Get(t *testing.T) { + convey.Convey("Get", t, func() { + convey.Convey("assert failed", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&loadbalance.ConcurrentCHGeneric{}), "Next", + func(_ *loadbalance.ConcurrentCHGeneric, name string, move bool) interface{} { + return 123 + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + _, err := Proxy.Get("functionKey") + convey.So(err, convey.ShouldBeError) + }) + + convey.Convey("no avaiable faas scheduler was found", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&loadbalance.ConcurrentCHGeneric{}), "Next", + func(_ *loadbalance.ConcurrentCHGeneric, name string, move bool) interface{} { + return "" + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + _, err := Proxy.Get("functionKey") + convey.So(err, convey.ShouldBeError) + }) + + convey.Convey("failed to get the faas scheduler named", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&loadbalance.ConcurrentCHGeneric{}), "Next", + func(_ *loadbalance.ConcurrentCHGeneric, name string, move bool) interface{} { + return "faaSScheduler" + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + _, err := Proxy.Get("functionKey") + convey.So(err, convey.ShouldBeError) + }) + + convey.Convey("success", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&loadbalance.ConcurrentCHGeneric{}), "Next", + func(_ *loadbalance.ConcurrentCHGeneric, name string, move bool) interface{} { + return "instance1" + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + Proxy.Add("instance1") + instanceID, err := Proxy.Get("functionKey") + convey.So(err, convey.ShouldBeNil) + convey.So(instanceID, convey.ShouldEqual, "instance1") + + Proxy.Remove("instance1") + _, err2 := Proxy.Get("functionKey") + convey.So(err2.Error(), convey.ShouldEqual, "no avaiable faas scheduler was found") + }) + }) +} + +func TestParseSchedulerData(t *testing.T) { + convey.Convey("TestParseSchedulerData", t, func() { + convey.Convey("success", func() { + err := ParseSchedulerData(api.Arg{ + Type: api.Value, + Data: []byte(`{"schedulerFuncKey": "faasscheudler", "schedulerIDList": ["123"]}`), + }) + convey.So(err, convey.ShouldBeNil) + key := Proxy.GetSchedulerFuncKey() + convey.So(key, convey.ShouldEqual, "faasscheudler") + Proxy.Remove("faasscheudler") + }) + }) +} + +func TestSetStain(t *testing.T) { + convey.Convey("", t, func() { + Proxy.Add("faasscheduler") + Proxy.SetStain("function", "faasscheduler") + Proxy.Remove("faasscheduler") + }) +} diff --git a/api/go/faassdk/common/functionlog/function_log.go b/api/go/faassdk/common/functionlog/function_log.go new file mode 100644 index 0000000..20e854d --- /dev/null +++ b/api/go/faassdk/common/functionlog/function_log.go @@ -0,0 +1,437 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package functionlog gets function services from URNs +package functionlog + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/faassdk/utils/urnutils" + logger2 "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" + "yuanrong.org/kernel/runtime/libruntime/common/logger" + "yuanrong.org/kernel/runtime/libruntime/common/logger/async" + runtimeLoggerConfig "yuanrong.org/kernel/runtime/libruntime/common/logger/config" + "yuanrong.org/kernel/runtime/libruntime/common/utils" +) + +const ( + innerLogSeparator = "," + funcLogSeparator = ";" + + // Max length of function log in "tail" mode + maxTailLogSizeKB = 2 * 1024 + // Max length of normal instance initializer log + maxInitLogSizeMB = 10 * 1024 * 1024 + // InitializerInvokeID - + InitializerInvokeID = "initializer" + // LoadInvokeID - + LoadInvokeID = "load" + // RestoreHookInvokeID - + RestoreHookInvokeID = "restore" + + // LogTimeFormat - + LogTimeFormat = "2006-01-02T15:04:05Z" + // NanoLogLayout - + NanoLogLayout = "2006-01-02T15:04:05.999999999Z07:00" + + // RuntimeLogFormat - + RuntimeLogFormat = "2006-01-02 15:04:05.999999999Z07:00" + + userLogTimeFormat = "2006-01-02 15:04:05.999" + + logTankServiceLength = 2 + originalLocation = 0 + alternativeLocation = 1 + userFuncLogChanCap = 5000 + funcInfoMapCap = 10 + userLogKey = "log" + fileMode = 0640 + cacheLimit = 10 * 1 << 20 // 10 mb + uint64Width = 8 + + dateIndex = 0 + timeIndex = 1 + traceIDIndex = 2 + minLength = 3 + traceIDLen = 36 +) + +var ( + functionLogger *FunctionLogger + createLoggerErr error + createLoggerOnce sync.Once +) + +// FunctionLogger process function's log +type FunctionLogger struct { + logger *zap.Logger + fieldsPool *sync.Pool + logPool *sync.Pool + // key is invokeID, value is logRecorder + logRecorders sync.Map + logLevel string + logAbsPath string +} + +// NewFunctionLogger - +func NewFunctionLogger(logger *zap.Logger, logLevel string, logPath string) *FunctionLogger { + return &FunctionLogger{ + logger: logger, + fieldsPool: &sync.Pool{New: func() interface{} { + return newLogFields() + }}, + logPool: &sync.Pool{New: func() interface{} { + return NewFunctionLog() + }}, + logLevel: logLevel, + logAbsPath: logPath, + } +} + +type logTankService struct { + originalGroupID string + originalStreamID string + alternativeGroupID string + alternativeStreamID string +} + +// GetFunctionLogger - +func GetFunctionLogger(cfg *config.Configuration) (*FunctionLogger, error) { + createLoggerOnce.Do(func() { + functionLogger, createLoggerErr = newLTSFunctionLogger(cfg) + if createLoggerErr != nil { + return + } + if functionLogger == nil { + createLoggerErr = errors.New("failed to new FunctionLogger") + return + } + }) + return functionLogger, createLoggerErr +} + +// SetLogLevel - +func (f *FunctionLogger) SetLogLevel(level string) { + logger2.GetLogger().Infof("set function log level: %s", level) + f.logLevel = level +} + +// NewLogRecorder new "invokeID:LogRecorder" pair in FunctionLogger +func (f *FunctionLogger) NewLogRecorder(invokeID, traceID, stage string, opts ...LogRecorderOption) *LogRecorder { + r := &LogRecorder{ + f: f, + invokeID: invokeID, + traceID: traceID, + logLevel: f.logLevel, + stage: stage, + separatedWithLineBreak: true, + } + r.logs = NewFixSizeRecorder(maxTailLogSizeKB, r.guessSize, r.handleDropped) + for _, opt := range opts { + opt(r) + } + + f.logRecorders.Store(invokeID, r) + return r +} + +// GetLogRecorder get LogRecorder of invokeID in FunctionLogger +func (f *FunctionLogger) GetLogRecorder(invokeID string) *LogRecorder { + value, ok := f.logRecorders.Load(invokeID) + if !ok { + return nil + } + if recorder, ok := value.(*LogRecorder); ok { + return recorder + } + return nil +} + +// WriteStdLog write std logs shown on all requests +func (f *FunctionLogger) WriteStdLog(logString, timeStamp string, isReserved bool, nanoTimestamp string) { + l := f.AcquireLog() + l.Time = timeStamp + l.NanoTime = nanoTimestamp + l.IsStdLog = true + l.Level = constants.FuncLogLevelInfo + l.Message.WriteString(logString) + var logged bool + f.logRecorders.Range(func(_, value interface{}) bool { + logged = true + if recorder, ok := value.(*LogRecorder); ok { + recorder.Write(l) + } + return true + }) + if !logged { + f.WriteDefaultLog(l) + } +} + +// AcquireLog retrieves a FunctionLog struct from the cached pool +func (f *FunctionLogger) AcquireLog() *FunctionLog { + functionLogIf := f.logPool.Get() + functionLog, ok := functionLogIf.(*FunctionLog) + if !ok { + logger2.GetLogger().Errorf("failed to assert FunctionLog") + return nil + } + functionLog.Reset() + return functionLog +} + +// ReleaseLog puts a FunctionLog struct to the cached pool +func (f *FunctionLogger) ReleaseLog(functionLog *FunctionLog) { + f.logPool.Put(functionLog) +} + +// WriteDefaultLog - +func (f *FunctionLogger) WriteDefaultLog(functionLog *FunctionLog) { + f.write(getDefaultInternalFunctionLog(functionLog)) +} + +func getDefaultInternalFunctionLog(functionLog *FunctionLog) *internalFunctionLog { + return &internalFunctionLog{ + FunctionLog: functionLog, + FunctionName: urnutils.LocalFuncURN.Name, + TraceID: "default", + Stage: constants.InitializeStage, + Index: constants.DefaultFuncLogIndex, + } +} + +func (f *FunctionLogger) deleteLogRecorder(invokeID string) { + f.logRecorders.Delete(invokeID) +} + +// RefreshFileModTime timer of refreshing file modified time in case of the user log file being deleted +func (f *FunctionLogger) RefreshFileModTime(stopCh chan struct{}) { + if stopCh == nil { + logger2.GetLogger().Warnf("empty stop channel") + return + } + ticker := time.NewTicker(timeModInterval) + logger2.GetLogger().Infof("start to regularly modify the file modification time") + for { + select { + case <-ticker.C: + f.refreshWithRetry() + case <-stopCh: + logger2.GetLogger().Warnf("received the runtime exit signal and stopped refreshing the file " + + "modification time") + ticker.Stop() + return + } + } +} + +func (f *FunctionLogger) refreshWithRetry() { + newTime := time.Now() + for i := 0; i < timeModRetry; i++ { + path := logger.GetLogName(f.logAbsPath) + if path == "" { + path = f.logAbsPath + } + if err := os.Chtimes(path, newTime, newTime); err != nil { + logger2.GetLogger().Warnf("failed to change the modification time of the user log file: %s", + err.Error()) + continue + } + logger2.GetLogger().Infof("succeeded to change the modification time of the user log file") + break + } +} + +func newZapLogger(fullPath, messageKey string) (*zap.Logger, error) { + coreInfo, err := runtimeLoggerConfig.GetCoreInfoFromEnv() + if err != nil { + logger2.GetLogger().Errorf("failed to get core info: %s", err.Error()) + coreInfo = runtimeLoggerConfig.GetDefaultCoreInfo() + } + coreInfo.FilePath = fullPath + if messageKey == userLogKey { + coreInfo.IsUserLog = true + } + sink, err := logger.CreateSink(coreInfo) + if err != nil { + logger2.GetLogger().Errorf("failed to create sink: %s", err.Error()) + return nil, err + } + + ws := async.NewAsyncWriteSyncer(sink, async.WithCachedLimit(cacheLimit)) + + encoderConfig := zapcore.EncoderConfig{ + LevelKey: "Level", + NameKey: "Logger", + MessageKey: messageKey, + CallerKey: "CallerKey", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + + rollingFileEncoder := zapcore.NewJSONEncoder(encoderConfig) + + priority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= zapcore.DebugLevel + }) + + return zap.New(zapcore.NewCore(rollingFileEncoder, ws, priority)), nil +} + +func newLTSFunctionLogger(cfg *config.Configuration) (*FunctionLogger, error) { + path := os.Getenv("function-log") + if path == "" { + path = cfg.StartArgs.LogDir + } + if err := utils.ValidateFilePath(path); err != nil { + logger2.GetLogger().Errorf("failed to valid log path, err: %s", err.Error()) + return nil, err + } + + path = filepath.Join(path, urnutils.LocalFuncURN.TenantID) + c := &urnutils.ComplexFuncName{} + logger2.GetLogger().Infof("POD_NAME is: %s", os.Getenv("POD_NAME")) + + c.ParseFrom(urnutils.LocalFuncURN.Name) + l := &logTankService{} + l.extractLogTankService(cfg.RuntimeConfig.LogTankService.GroupID, cfg.RuntimeConfig.LogTankService.StreamID) + + if err := createFunctionInfoFile(c); err != nil { + logger2.GetLogger().Errorf("failed to create function info file, %s", err.Error()) + return nil, err + } + + // functionName@serviceID@version@podName@time#logGroupId#logStreamId + functionLogName := fmt.Sprintf("%s@%s@%s@%s@%s#%s#%s#%s", c.FuncName, c.ServiceID, + urnutils.LocalFuncURN.Version, os.Getenv("POD_NAME"), time.Now().Format("20060102150405"), + l.originalGroupID, l.originalStreamID, cfg.UserLogTag) + + fullPath, err := getAbsFilePath(path, functionLogName) + if err != nil { + return nil, err + } + + rollingLogger, err := newZapLogger(fullPath, userLogKey) + if err != nil { + logger2.GetLogger().Errorf("failed to new zapLogger, %s", err.Error()) + return nil, err + } + + rollingLogger = createLTSLogger(rollingLogger, c, l) + + return NewFunctionLogger(rollingLogger, cfg.RuntimeConfig.FuncLogLevel, fullPath), nil +} + +func createLTSLogger(rollingLogger *zap.Logger, c *urnutils.ComplexFuncName, l *logTankService) *zap.Logger { + return rollingLogger.With(zapcore.Field{ + Key: "projectId", + Type: zapcore.StringType, + String: urnutils.LocalFuncURN.TenantID, + }, zapcore.Field{ + Key: "podName", + Type: zapcore.StringType, + String: os.Getenv("POD_NAME"), + }, zapcore.Field{ + Key: "package", + Type: zapcore.StringType, + String: c.ServiceID, + }, zapcore.Field{ + Key: "function", + Type: zapcore.StringType, + String: c.FuncName, + }, zapcore.Field{ + Key: "version", + Type: zapcore.StringType, + String: urnutils.LocalFuncURN.Version, + }, zapcore.Field{ + Key: "stream", + Type: zapcore.StringType, + String: "stdout", + }, zapcore.Field{ + Key: "instanceId", + Type: zapcore.StringType, + String: os.Getenv("POD_ID"), + }, zapcore.Field{ + Key: "newLogGroupId", + Type: zapcore.StringType, + String: l.alternativeGroupID, + }, zapcore.Field{ + Key: "newLogStreamId", + Type: zapcore.StringType, + String: l.alternativeStreamID, + }) +} + +func (l *logTankService) extractLogTankService(groupID, streamID string) { + if strings.Contains(groupID, funcLogSeparator) && strings.Contains(streamID, funcLogSeparator) { + splitGroupID := strings.Split(groupID, funcLogSeparator) + splitStreamID := strings.Split(streamID, funcLogSeparator) + if len(splitGroupID) != logTankServiceLength || len(splitStreamID) != logTankServiceLength { + l.originalGroupID = groupID + l.originalStreamID = streamID + return + } + l.originalGroupID = splitGroupID[originalLocation] + l.alternativeGroupID = splitGroupID[alternativeLocation] + l.originalStreamID = splitStreamID[originalLocation] + l.alternativeStreamID = splitStreamID[alternativeLocation] + return + } + l.originalGroupID = groupID + l.originalStreamID = streamID +} + +func createFunctionInfoFile(complexFuncName *urnutils.ComplexFuncName) error { + coreInfo, err := runtimeLoggerConfig.GetCoreInfoFromEnv() + if err != nil { + logger2.GetLogger().Errorf("failed to get core info: %s", err.Error()) + coreInfo = runtimeLoggerConfig.GetDefaultCoreInfo() + } + // urn:fss:projectID:function:package:function-name:version.info + infoFileName := strings.Join([]string{"urn:fss", urnutils.LocalFuncURN.TenantID, + "function", complexFuncName.ServiceID, complexFuncName.FuncName, + urnutils.LocalFuncURN.Version + ".info"}, ":") + f, err := os.OpenFile(filepath.Join(coreInfo.FilePath, infoFileName), os.O_RDONLY|os.O_CREATE, fileMode) + if err != nil { + return err + } + defer f.Close() + + return nil +} + +func getAbsFilePath(path, fileName string) (string, error) { + logPath, err := filepath.Abs(path) + if err != nil { + return "", err + } + fullPath := filepath.Join(logPath, fileName+".log") + return fullPath, nil +} diff --git a/api/go/faassdk/common/functionlog/function_log_test.go b/api/go/faassdk/common/functionlog/function_log_test.go new file mode 100644 index 0000000..29422c6 --- /dev/null +++ b/api/go/faassdk/common/functionlog/function_log_test.go @@ -0,0 +1,209 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package functionlog gets function services from URNs +package functionlog + +import ( + "os" + "strings" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/faassdk/utils/urnutils" +) + +func TestGetFunctionLogger(t *testing.T) { + _, err := GetFunctionLogger(&config.Configuration{StartArgs: config.StartArgs{ + LogDir: "/home/sn/logs", + }}) + assert.Nil(t, err) +} + +func TestRefreshFileModTime(t *testing.T) { + convey.Convey("Test RefreshFileModTime", t, func() { + convey.Convey("Test 1", func() { + f := &FunctionLogger{ + logAbsPath: "patch", + } + ch := make(chan struct{}) + close(ch) + f.RefreshFileModTime(ch) + + p := gomonkey.NewPatches() + p.ApplyFunc(time.NewTicker, func(d time.Duration) *time.Ticker { + ch := make(chan time.Time, 1) + ch <- time.Time{} + return &time.Ticker{C: ch} + }) + defer p.Reset() + f.RefreshFileModTime(ch) + f.refreshWithRetry() + }) + }) + convey.Convey("Test RefreshFileModTime err", t, func() { + convey.Convey("Test 2", func() { + var stopCh chan struct{} + f := &FunctionLogger{} + f.RefreshFileModTime(stopCh) + }) + }) +} + +func TestExtractLogTankService(t *testing.T) { + l := &logTankService{} + l.extractLogTankService("groupID;", "streamID;") + l.extractLogTankService("ID", "ID") + convey.Convey("Test extractLogTankService", t, func() { + convey.Convey("extractLogTankService success", func() { + convey.So(func() { + l.extractLogTankService("groupID1;groupID2;groupID3", "streamID;") + }, convey.ShouldNotPanic) + convey.So(func() { + l.extractLogTankService("groupID;", "streamID;") + }, convey.ShouldNotPanic) + convey.So(func() { + l.extractLogTankService("ID", "ID") + }, convey.ShouldNotPanic) + }) + }) +} + +func Test_refreshWithRetry(t *testing.T) { + convey.Convey("Test refreshWithRetry", t, func() { + convey.Convey("Test 1", func() { + var testMsg string + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(os.Chtimes, func(name string, atime time.Time, mtime time.Time) error { + testMsg = "os Chtimes return nil" + return nil + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + f := &FunctionLogger{} + f.refreshWithRetry() + convey.So(testMsg, convey.ShouldEqual, "os Chtimes return nil") + }) + }) +} + +func TestSync(t *testing.T) { + Sync(nil) +} + +func TestGetDefaultInternalFunctionLog(t *testing.T) { + convey.Convey("Test getDefaultInternalFunctionLog", t, func() { + convey.Convey("getDefaultInternalFunctionLog success", func() { + fl := &FunctionLog{} + ifl := getDefaultInternalFunctionLog(fl) + convey.So(ifl, convey.ShouldNotBeNil) + }) + }) +} + +func TestNewLTSFunctionLogger(t *testing.T) { + convey.Convey( + "Test newLTSFunctionLogger", t, func() { + convey.Convey( + "newLTSFunctionLogger success", func() { + cfg := &config.Configuration{StartArgs: config.StartArgs{ + LogDir: "", + }} + fl, err := newLTSFunctionLogger(cfg) + convey.So(fl, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestWriteStdLog(t *testing.T) { + convey.Convey( + "Test WriteStdLog", t, func() { + convey.Convey( + "WriteStdLog success", func() { + convey.So(func() { + functionLogger.WriteStdLog("logString", "timeStamp", false, "nanoTimestamp") + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestWriteDefaultLog(t *testing.T) { + convey.Convey( + "Test WriteDefaultLog", t, func() { + convey.Convey( + "WriteDefaultLog success", func() { + convey.So(func() { + fl := functionLogger.AcquireLog() + functionLogger.WriteDefaultLog(fl) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestNewZapLogger(t *testing.T) { + convey.Convey( + "Test newZapLogger", t, func() { + convey.Convey( + "newZapLogger success", func() { + defer cleanFile("fullPath") + zl, err := newZapLogger("fullPath", "messageKey") + convey.So(zl, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestCreateFunctionInfoFile(t *testing.T) { + convey.Convey( + "Test createFunctionInfoFile", t, func() { + convey.Convey( + "createFunctionInfoFile success", func() { + defer cleanFile("fullPath") + err := createFunctionInfoFile(&urnutils.ComplexFuncName{}) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func cleanFile(fileName string) { + files, _ := os.ReadDir("./") + for _, file := range files { + flag := strings.HasPrefix(file.Name(), fileName) + if flag { + os.Remove(file.Name()) + } + } +} diff --git a/api/go/faassdk/common/functionlog/log_fields.go b/api/go/faassdk/common/functionlog/log_fields.go new file mode 100644 index 0000000..96afd6d --- /dev/null +++ b/api/go/faassdk/common/functionlog/log_fields.go @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package functionlog + +import ( + "strconv" + + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + traceIDPos = iota + timePos + errorTypePos + stagePos + statusPos + finishPos + livedataTraceIDPos + endFields +) + +type logFileds []zapcore.Field + +func newLogFields() logFileds { + return []zapcore.Field{ + { + Key: "requestId", + Type: zapcore.StringType, + }, + { + Key: "time", + Type: zapcore.StringType, + }, + { + Key: "errorType", + Type: zapcore.StringType, + }, + { + Key: "stage", + Type: zapcore.StringType, + }, + { + Key: "status", + Type: zapcore.StringType, + }, + { + Key: "finishLog", + Type: zapcore.StringType, + }, + { + Key: "livedataTraceId", + Type: zapcore.StringType, + }, + } +} + +func (f logFileds) set(funcLog *internalFunctionLog) { + if len(f) < endFields { + logger.GetLogger().Warnf("invalid logFields") + return + } + + f[traceIDPos].String = funcLog.TraceID + f[timePos].String = funcLog.NanoTime + f[errorTypePos].String = funcLog.ErrorType + f[stagePos].String = funcLog.Stage + + statusStr := "success" + if funcLog.ErrorType != "" { + statusStr = "fail" + } + f[statusPos].String = statusStr + f[finishPos].String = strconv.FormatBool(funcLog.IsFinishedLog) + f[livedataTraceIDPos].String = funcLog.LivedataID +} diff --git a/api/go/faassdk/common/functionlog/log_fields_test.go b/api/go/faassdk/common/functionlog/log_fields_test.go new file mode 100644 index 0000000..c57f823 --- /dev/null +++ b/api/go/faassdk/common/functionlog/log_fields_test.go @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package functionlog gets function services from URNs +package functionlog + +import ( + "testing" + + "go.uber.org/zap/zapcore" + + "github.com/smartystreets/goconvey/convey" +) + +func TestSet(t *testing.T) { + convey.Convey("Test set", t, func() { + convey.Convey("set success", func() { + var f logFileds = []zapcore.Field{} + fl := &internalFunctionLog{} + convey.So(func() { + f.set(fl) + }, convey.ShouldNotPanic) + }) + }) +} diff --git a/api/go/faassdk/common/functionlog/log_recoder.go b/api/go/faassdk/common/functionlog/log_recoder.go new file mode 100644 index 0000000..c5dc4a0 --- /dev/null +++ b/api/go/faassdk/common/functionlog/log_recoder.go @@ -0,0 +1,497 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package functionlog + +import ( + "bytes" + "container/list" + "encoding/base64" + "sync" + "time" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/utils/urnutils" +) + +const ( + // user log will be deleted by cffagent for follow reasons: + // 1.irregular log file name + // 2.unchanged for 12 hours when log is not empty + // 3.unchanged for 6 hours when log is empty + timeModInterval = 5 * time.Hour + timeModRetry = 3 +) + +// LogHeaderFooterGenerator - +type LogHeaderFooterGenerator interface { + GenerateLogHeader(*FunctionLog) + GenerateLogFooter(*FunctionLog) +} + +// FixSizeRecorder tries to maintain a number of logs with the total size just over limitSize. When we write some new +// logs to the recorder, old logs will be dropped. +type FixSizeRecorder struct { + logs *list.List + guessSize func(funcLog *FunctionLog) int + dropped func(funcLog *FunctionLog, logTooLarge bool) + size int + limitSize int +} + +// NewFixSizeRecorder - +func NewFixSizeRecorder( + limitSize int, + guessSize func(funcLog *FunctionLog) int, + dropped func(funcLog *FunctionLog, logTooLarge bool)) *FixSizeRecorder { + return &FixSizeRecorder{ + guessSize: guessSize, + dropped: dropped, + limitSize: limitSize, + } +} + +// Reset resets the recorder to its initial state +func (r *FixSizeRecorder) Reset() { + if r.logs != nil { + r.logs.Init() + } + r.size = 0 +} + +// Write - +func (r *FixSizeRecorder) Write(funcLog *FunctionLog) { + if r.logs == nil { + r.logs = list.New() + } + + if r.size > r.limitSize { + for { + elem := r.logs.Front() + if elem == nil { + break + } + funcLog, ok := elem.Value.(*FunctionLog) + if !ok { + break + } + + size := r.guessSize(funcLog) + if r.size-size < r.limitSize { + break + } + + r.size -= size + r.dropped(funcLog, false) + r.logs.Remove(elem) + } + } + + size := r.guessSize(funcLog) + if size > r.limitSize { // special case: the funcLog is too large + r.Range(func(fl *FunctionLog) { + r.dropped(fl, false) + }) + r.logs = nil + r.size = 0 + r.dropped(funcLog, true) + } else { + r.logs.PushBack(funcLog) + r.size += size + } +} + +// RealTimeWrite - +func (r *FixSizeRecorder) RealTimeWrite(funcLog *FunctionLog) { + if r.logs == nil { + r.logs = list.New() + } + + if r.size > r.limitSize { + for { + elem := r.logs.Front() + if elem == nil { + break + } + funcLog, ok := elem.Value.(*FunctionLog) + if !ok { + break + } + + size := r.guessSize(funcLog) + if r.size-size < r.limitSize { + break + } + r.size -= size + r.logs.Remove(elem) + } + } + size := r.guessSize(funcLog) + if size > r.limitSize { + r.dropped(funcLog, true) + return + } + r.logs.PushBack(funcLog.DeepCopy()) + r.size += size + r.dropped(funcLog, false) +} + +// Range - +func (r *FixSizeRecorder) Range(fn func(*FunctionLog)) { + if r.logs == nil { + return + } + + for e := r.logs.Front(); e != nil; e = e.Next() { + funcLog, ok := e.Value.(*FunctionLog) + if !ok { + continue + } + + fn(funcLog) + } +} + +// WriteLimit is to save limitSize log +func (r *FixSizeRecorder) WriteLimit(funcLog *FunctionLog, limitSize int) { + if r.logs == nil { + r.logs = list.New() + } + + if r.size > limitSize { + for { + elem := r.logs.Front() + if elem == nil { + break + } + funcLog, ok := elem.Value.(*FunctionLog) + if !ok { + break + } + + size := r.guessSize(funcLog) + if r.size-size < limitSize { + break + } + r.size -= size + r.logs.Remove(elem) + } + } + + size := r.guessSize(funcLog) + r.logs.PushBack(funcLog) + r.size += size +} + +// LogRecorder record user log of function +type LogRecorder struct { + f *FunctionLogger + header *FunctionLog + footer *FunctionLog + logs *FixSizeRecorder + logsMux sync.Mutex + generator LogHeaderFooterGenerator + idx int + base []byte + syncCh chan struct{} + + invokeID string + traceID string + stage string + logLevel string + livedataID string + logOption string + separatedWithLineBreak bool +} + +type internalFunctionLog struct { + *FunctionLog + FunctionName string + TraceID string + LivedataID string + Stage string + Index int +} + +// FunctionLog is format of function logs which are flushed +type FunctionLog struct { + Message *bytes.Buffer + Time string + NanoTime string + Level string + ErrorType string + IsFinishedLog bool + IsStdLog bool +} + +// NewFunctionLog - +func NewFunctionLog() *FunctionLog { + return &FunctionLog{Message: new(bytes.Buffer)} +} + +// Reset - +func (l *FunctionLog) Reset() { + l.Time = "" + l.Level = "" + l.Message.Reset() + l.ErrorType = "" + l.IsFinishedLog = false + l.IsStdLog = false +} + +// DeepCopy - +func (l *FunctionLog) DeepCopy() *FunctionLog { + x := *l // copy + oldBuf := l.Message.Bytes() + newBuf := make([]byte, len(oldBuf)) + copy(newBuf, oldBuf) + x.Message = bytes.NewBuffer(newBuf) + return &x +} + +// NewLogRecorder - +func NewLogRecorder() *LogRecorder { + r := &LogRecorder{} + r.logs = NewFixSizeRecorder(maxTailLogSizeKB, r.guessSize, r.handleDropped) + return r +} + +func (r *LogRecorder) handleDropped(funcLog *FunctionLog, logTooLarge bool) { + if r.idx == 0 { + if header := r.getHeader(); header != nil { + // Header can still be used in 'MarshalAll'. Pass in a copy to prevent releasing the old header + // to the sync.Pool. + r.f.write(r.generateInternalFunctionLog(header.DeepCopy(), 0)) + } + r.idx++ + } + + // when new drop happens, we no longer need the previous base + if r.base != nil { + r.base = nil + } + + if logTooLarge { + res := r.formatLog(funcLog) + if len(res) > maxTailLogSizeKB { + res = res[len(res)-maxTailLogSizeKB:] + } + r.base = make([]byte, len(res)) + copy(r.base, res) + } + + r.f.write(r.generateInternalFunctionLog(funcLog, r.idx)) + r.idx++ +} + +// guessSize returns the guessed size of bytes when calling 'formatLog' for the same 'functionLog'. The guessed size +// may not be accurate. +func (r *LogRecorder) guessSize(functionLog *FunctionLog) int { + if functionLog.IsStdLog { + return functionLog.Message.Len() + } + // 1 is the ' ' in between + return len(functionLog.Time) + 1 + len(r.traceID) + 1 + len(r.logLevel) + 1 + functionLog.Message.Len() +} + +func (r *LogRecorder) formatLog(functionLog *FunctionLog) []byte { + // user std log without additional information + if functionLog.IsStdLog { + return functionLog.Message.Bytes() + } + level := functionLog.Level + if level == "WARNING" { + level = constants.FuncLogLevelWarn + } + + var b bytes.Buffer + b.Grow(len(functionLog.Time) + 1 + len(r.traceID) + 1 + len(level) + 1 + functionLog.Message.Len()) + b.WriteString(functionLog.Time) + b.WriteString(" ") + b.WriteString(r.traceID) + b.WriteString(" ") + b.WriteString(level) + b.WriteString(" ") + b.Write(functionLog.Message.Bytes()) + + return b.Bytes() +} + +// StartSync - +func (r *LogRecorder) StartSync() { + r.syncCh = make(chan struct{}) +} + +// FinishSync - +func (r *LogRecorder) FinishSync() { + select { + case <-r.syncCh: // If chan is closed, the case is executed immediately. + return + default: // If chan is not closed, the case is executed. + close(r.syncCh) + } +} + +// MarshalAll - +func (r *LogRecorder) MarshalAll() string { + if r.syncCh != nil { + <-r.syncCh + } + + r.logsMux.Lock() + + var buffer bytes.Buffer + if r.base != nil { + buffer.Write(r.base) + if r.separatedWithLineBreak { + buffer.WriteString("\n") + } + } + r.logs.Range(func(piece *FunctionLog) { + buffer.Write(r.formatLog(piece)) + if r.separatedWithLineBreak { + buffer.WriteString("\n") + } + }) + + res := buffer.String() + if len(res) > maxTailLogSizeKB { + res = res[len(res)-maxTailLogSizeKB:] + } + + buffer.Reset() + if header := r.getHeader(); header != nil { + buffer.Write(header.Message.Bytes()) + } + buffer.WriteString("\n") + buffer.WriteString(res) + + if footer := r.getFooter(); footer != nil { + buffer.Write(footer.Message.Bytes()) + } + + encoded := base64.StdEncoding.EncodeToString(buffer.Bytes()) + r.logsMux.Unlock() + return encoded +} + +// Finish finalizes the logRecorder. This MUST be called. +func (r *LogRecorder) Finish() { + if r.syncCh != nil { + <-r.syncCh + } + + r.logsMux.Lock() + + if r.idx == 0 { + if header := r.getHeader(); header != nil { + r.f.write(r.generateInternalFunctionLog(header, 0)) + } + r.idx++ + } + r.logs.Range(func(piece *FunctionLog) { + r.f.write(r.generateInternalFunctionLog(piece, r.idx)) + r.idx++ + }) + if footer := r.getFooter(); footer != nil { + r.f.write(r.generateInternalFunctionLog(footer, -1)) + } + + r.f.deleteLogRecorder(r.invokeID) + + r.logsMux.Unlock() +} + +func (r *LogRecorder) generateInternalFunctionLog(functionLog *FunctionLog, idx int) *internalFunctionLog { + return &internalFunctionLog{ + FunctionLog: functionLog, + FunctionName: urnutils.LocalFuncURN.Name, + TraceID: r.traceID, + LivedataID: r.livedataID, + Stage: r.stage, + Index: idx, + } +} + +func (r *LogRecorder) getHeader() *FunctionLog { + if r.header != nil { + return r.header + } + + if r.generator != nil { + r.header = r.f.AcquireLog() + r.generator.GenerateLogHeader(r.header) + } + return r.header +} + +func (r *LogRecorder) getFooter() *FunctionLog { + if r.footer != nil { + return r.footer + } + + if r.generator != nil { + r.footer = r.f.AcquireLog() + r.generator.GenerateLogFooter(r.footer) + } + return r.footer +} + +// Write log to logRecorder. This can be called from a different goroutine +func (r *LogRecorder) Write(funcLog *FunctionLog) { + r.logsMux.Lock() + r.logs.Write(funcLog) + r.logsMux.Unlock() +} + +// WriteLimit write limited log to logRecorder. not reserved to disk +func (r *LogRecorder) WriteLimit(funcLog *FunctionLog) { + r.logsMux.Lock() + r.logs.WriteLimit(funcLog, maxInitLogSizeMB) + r.logsMux.Unlock() +} + +// RealTimeWrite log to logRecorder. This can be called from a different goroutine +func (r *LogRecorder) RealTimeWrite(funcLog *FunctionLog) { + r.logsMux.Lock() + r.logs.RealTimeWrite(funcLog) + r.logsMux.Unlock() +} + +// LogOption - +func (r *LogRecorder) LogOption() string { + return r.logOption +} + +// InvokeID - +func (r *LogRecorder) InvokeID() string { + return r.invokeID +} + +// TraceID - +func (r *LogRecorder) TraceID() string { + return r.traceID +} + +// Stage - +func (r *LogRecorder) Stage() string { + return r.stage +} + +// Generator - +func (r *LogRecorder) Generator() LogHeaderFooterGenerator { + return r.generator +} diff --git a/api/go/faassdk/common/functionlog/log_recoder_test.go b/api/go/faassdk/common/functionlog/log_recoder_test.go new file mode 100644 index 0000000..e486d61 --- /dev/null +++ b/api/go/faassdk/common/functionlog/log_recoder_test.go @@ -0,0 +1,585 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package functionlog + +import ( + "bytes" + "container/list" + "io" + "os" + "strings" + "sync" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/config" + runtimeLogger "yuanrong.org/kernel/runtime/libruntime/common/logger" + runtimeLoggerConfig "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +const logDir = "/home/sn/log" + +func TestMain(m *testing.M) { + p := gomonkey.NewPatches() + p.ApplyFunc(runtimeLoggerConfig.GetCoreInfoFromEnv, func() (runtimeLoggerConfig.CoreInfo, error) { return runtimeLoggerConfig.GetDefaultCoreInfo(), nil }) + p.ApplyFunc(os.OpenFile, func(name string, flag int, perm os.FileMode) (*os.File, error) { + return &os.File{}, nil + }) + p.ApplyFunc(runtimeLogger.CreateSink, func(coreInfo runtimeLoggerConfig.CoreInfo) (io.Writer, error) { + return io.Discard, nil + }) + + GetFunctionLogger(&config.Configuration{StartArgs: config.StartArgs{ + LogDir: logDir, + }}) + + p.Reset() + m.Run() + createLoggerOnce = sync.Once{} + os.RemoveAll(logDir) +} + +func TestLogRecorder_Fusion(t *testing.T) { + functionLogger, err := GetFunctionLogger(&config.Configuration{ + StartArgs: config.StartArgs{LogDir: logDir}, + }) + if err != nil { + t.Fatalf("failed to get function functionLogger, capability:fusion, %s", err.Error()) + } + functionLogger.SetLogLevel(constants.FuncLogLevelWarn) + invokeID := "1" + traceID := "traceID" + logRecorder := functionLogger.NewLogRecorder(invokeID, traceID, constants.InvokeStage, WithoutLineBreak(), WithLogOptionTail()) + if functionLogger.GetLogRecorder(invokeID) == nil { + t.Error("failed to get log recorder by invokeID") + } + logRecorder.Write(&FunctionLog{ + Level: "INFO", + Message: bytes.NewBufferString("message"), + }) + + functionLogger.WriteStdLog("std logs", "my time", false, "time") + logRecorder.logs.Reset() + logRecorder.generator = &fakeGenerator{} + logRecorder.getFooter() + logRecorder.header = &FunctionLog{} + logRecorder.footer = &FunctionLog{} + logRecorder.Finish() +} + +func TestGetLTSLoggerMessage(m *testing.T) { + functionLog := &internalFunctionLog{FunctionLog: &FunctionLog{}} + functionLog.Index = -1 + functionLog.Time = "" + functionLog.TraceID = "" + functionLog.Message = new(bytes.Buffer) + msg := getLTSLoggerMessage(functionLog) + assert.Equal(m, len(msg), 0) + + functionLog = &internalFunctionLog{FunctionLog: &FunctionLog{}} + functionLog.Index = 1 + functionLog.Time = "test time" + functionLog.TraceID = "test 1" + functionLog.Message = bytes.NewBufferString("time test message") + msg = getLTSLoggerMessage(functionLog) + assert.NotEqual(m, len(msg), 0) + + functionLog = &internalFunctionLog{FunctionLog: &FunctionLog{}} + functionLog.Index = constants.DefaultFuncLogIndex + functionLog.Time = "test time" + functionLog.TraceID = "test 2" + functionLog.Message = bytes.NewBufferString("time test message") + str := functionLog.Time + " " + "time test message" + msg = getLTSLoggerMessage(functionLog) + assert.Equal(m, str, msg) +} + +func TestFormatLog(m *testing.T) { + r := &LogRecorder{} + fl := &FunctionLog{ + Time: "time", + Level: "WARNING", + Message: new(bytes.Buffer), + } + t := r.formatLog(fl) + assert.Equal(m, "time WARN ", string(t)) + fl.IsStdLog = true + t = r.formatLog(fl) + assert.Equal(m, "", string(t)) +} + +func TestMarshalAll(t *testing.T) { + convey.Convey("Test MarshalAll", t, func() { + convey.Convey("Test no log", func() { + r := &LogRecorder{ + f: &FunctionLogger{}, + separatedWithLineBreak: true, + } + + limitSize := maxTailLogSizeKB + droppedNum := 0 + r.logs = NewFixSizeRecorder( + limitSize, + r.guessSize, + func(funcLog *FunctionLog, logTooLarge bool) { + droppedNum++ + r.handleDropped(funcLog, logTooLarge) + }, + ) + + r.MarshalAll() + + var buffer bytes.Buffer + r.logs.Range(func(piece *FunctionLog) { + buffer.Write(r.formatLog(piece)) + if r.separatedWithLineBreak { + buffer.WriteString("\n") + } + }) + + assert.Equal(t, 0, droppedNum) + assert.Equal(t, 0, len(buffer.Bytes())) + assert.Equal(t, 0, r.idx) + + r.Finish() + + assert.Equal(t, 1, r.idx) + }) + convey.Convey("Test too many logs", func() { + r := &LogRecorder{ + f: &FunctionLogger{}, + separatedWithLineBreak: true, + } + + limitSize := maxTailLogSizeKB + droppedNum := 0 + r.logs = NewFixSizeRecorder( + limitSize, + r.guessSize, + func(funcLog *FunctionLog, logTooLarge bool) { + droppedNum++ + r.handleDropped(funcLog, logTooLarge) + }, + ) + + writeNum := 150 + for idx := 0; idx < writeNum; idx++ { + r.Write(&FunctionLog{ + Message: bytes.NewBufferString("Function log test"), + }) + } + r.MarshalAll() + + var buffer bytes.Buffer + r.logs.Range(func(piece *FunctionLog) { + buffer.Write(r.formatLog(piece)) + if r.separatedWithLineBreak { + buffer.WriteString("\n") + } + }) + + assert.True(t, droppedNum > 0) + assert.True(t, len(buffer.Bytes()) > limitSize) + assert.Equal(t, droppedNum+1, r.idx) + + r.Finish() + + assert.Equal(t, writeNum+1, r.idx) + }) + convey.Convey("Test logs", func() { + r := &LogRecorder{ + f: &FunctionLogger{}, + separatedWithLineBreak: true, + } + + limitSize := maxTailLogSizeKB + droppedNum := 0 + r.logs = NewFixSizeRecorder( + limitSize, + r.guessSize, + func(funcLog *FunctionLog, logTooLarge bool) { + droppedNum++ + r.handleDropped(funcLog, logTooLarge) + }, + ) + + writeNum := 20 + for idx := 0; idx < writeNum; idx++ { + r.Write(&FunctionLog{ + Message: bytes.NewBufferString("Function log test"), + }) + } + r.MarshalAll() + + var buffer bytes.Buffer + r.logs.Range(func(piece *FunctionLog) { + buffer.Write(r.formatLog(piece)) + if r.separatedWithLineBreak { + buffer.WriteString("\n") + } + }) + + assert.Equal(t, 0, droppedNum) + assert.True(t, len(buffer.Bytes()) < limitSize) + assert.Equal(t, 0, r.idx) + + r.Finish() + + assert.Equal(t, writeNum+1, r.idx) + }) + }) +} + +func TestMarshalAll2(t *testing.T) { + convey.Convey("Test logTooLarge", t, func() { + r := &LogRecorder{ + f: &FunctionLogger{}, + separatedWithLineBreak: true, + } + + limitSize := maxTailLogSizeKB + droppedNum := 0 + r.logs = NewFixSizeRecorder( + limitSize, + r.guessSize, + func(funcLog *FunctionLog, logTooLarge bool) { + droppedNum++ + r.handleDropped(funcLog, logTooLarge) + }, + ) + + msg := strings.Repeat("a", maxTailLogSizeKB+1) + r.Write(&FunctionLog{ + Message: bytes.NewBufferString(msg), + }) + r.Write(&FunctionLog{ + Message: bytes.NewBufferString("Function log test"), + }) + r.MarshalAll() + + assert.NotNil(t, r.base) + + var buffer bytes.Buffer + r.logs.Range(func(piece *FunctionLog) { + buffer.Write(r.formatLog(piece)) + if r.separatedWithLineBreak { + buffer.WriteString("\n") + } + }) + + assert.Equal(t, 1, droppedNum) + assert.True(t, len(buffer.Bytes()) != 0) + assert.True(t, len(buffer.Bytes()) < limitSize) + assert.Equal(t, 2, r.idx) + + r.Finish() + + assert.Equal(t, 3, r.idx) + }) + convey.Convey("Test logTooLarge then drop", t, func() { + r := &LogRecorder{ + f: &FunctionLogger{}, + separatedWithLineBreak: true, + } + + limitSize := maxTailLogSizeKB + droppedNum := 0 + r.logs = NewFixSizeRecorder( + limitSize, + r.guessSize, + func(funcLog *FunctionLog, logTooLarge bool) { + droppedNum++ + r.handleDropped(funcLog, logTooLarge) + }, + ) + + msg := strings.Repeat("a", maxTailLogSizeKB+1) + r.Write(&FunctionLog{ + Message: bytes.NewBufferString(msg), + }) + + writeNum := 150 + for idx := 0; idx < writeNum; idx++ { + r.Write(&FunctionLog{ + Message: bytes.NewBufferString("Function log test"), + }) + } + r.MarshalAll() + + assert.Nil(t, r.base) + + var buffer bytes.Buffer + r.logs.Range(func(piece *FunctionLog) { + buffer.Write(r.formatLog(piece)) + if r.separatedWithLineBreak { + buffer.WriteString("\n") + } + }) + + assert.True(t, droppedNum > 0) + assert.True(t, len(buffer.Bytes()) > limitSize) + assert.Equal(t, droppedNum+1, r.idx) + + r.Finish() + + assert.Equal(t, writeNum+2, r.idx) + + r.Finish() + }) + convey.Convey("Test drop then logTooLarge", t, func() { + r := &LogRecorder{ + f: &FunctionLogger{}, + separatedWithLineBreak: true, + } + + limitSize := maxTailLogSizeKB + droppedNum := 0 + r.logs = NewFixSizeRecorder( + limitSize, + r.guessSize, + func(funcLog *FunctionLog, logTooLarge bool) { + droppedNum++ + r.handleDropped(funcLog, logTooLarge) + }, + ) + + writeNum := 150 + for idx := 0; idx < writeNum; idx++ { + r.Write(&FunctionLog{ + Message: bytes.NewBufferString("Function log test"), + }) + } + msg := strings.Repeat("a", maxTailLogSizeKB+1) + r.Write(&FunctionLog{ + Message: bytes.NewBufferString(msg), + }) + r.MarshalAll() + + assert.NotNil(t, r.base) + + var buffer bytes.Buffer + r.logs.Range(func(piece *FunctionLog) { + buffer.Write(r.formatLog(piece)) + if r.separatedWithLineBreak { + buffer.WriteString("\n") + } + }) + + assert.True(t, droppedNum > 0) + assert.Equal(t, 0, len(buffer.Bytes())) + assert.Equal(t, droppedNum+1, r.idx) + + r.Finish() + + assert.Equal(t, droppedNum+1, r.idx) + }) +} + +type fakeGenerator struct { + header *FunctionLog +} + +func (f *fakeGenerator) GenerateLogHeader(header *FunctionLog) { f.header = header } +func (f *fakeGenerator) GenerateLogFooter(*FunctionLog) {} + +func TestLogPool(t *testing.T) { + generator := &fakeGenerator{} + r := &LogRecorder{ + f: NewFunctionLogger(zap.NewNop(), "", ""), + separatedWithLineBreak: true, + generator: generator, + } + r.handleDropped(r.f.AcquireLog(), false) + assert.True(t, r.header == generator.header) + assert.True(t, r.f.AcquireLog() != r.header) // header should not be released +} + +func TestFunctionLogDeepCopy(t *testing.T) { + l := NewFunctionLog() + l.Message.WriteString("abc") + l.ErrorType = "xxx" + l.IsFinishedLog = true + l.IsStdLog = true + l.Level = "info" + + x := l.DeepCopy() + assert.True(t, l != x) + assert.Equal(t, l, x) +} + +func TestFixSizeRecorder_WriteLimit(t *testing.T) { + r := &LogRecorder{} + case2Logs := list.New() + case2Logs.PushBack(&FunctionLog{ + Level: "INFO", + Message: bytes.NewBufferString("abc"), + }) + type fields struct { + logs *list.List + guessSize func(funcLog *FunctionLog) int + dropped func(funcLog *FunctionLog, logTooLarge bool) + size int + limitSize int + } + type args struct { + funcLog *FunctionLog + limitSize int + } + tests := []struct { + name string + fields fields + args args + }{ + {"case1 write less than limitSize", fields{logs: nil, guessSize: r.guessSize, limitSize: 20}, + args{funcLog: &FunctionLog{ + Level: "INFO", + Message: bytes.NewBufferString("abc"), + }, limitSize: 10}}, + {"case2 write more than limitSize", fields{logs: case2Logs, guessSize: r.guessSize, size: 20, + limitSize: 20}, args{funcLog: &FunctionLog{ + Level: "INFO", + Message: bytes.NewBufferString("abc"), + }, limitSize: 10}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &FixSizeRecorder{ + logs: tt.fields.logs, + guessSize: tt.fields.guessSize, + dropped: tt.fields.dropped, + size: tt.fields.size, + limitSize: tt.fields.limitSize, + } + r.WriteLimit(tt.args.funcLog, tt.args.limitSize) + if r.size > 20 { + t.Errorf("write limit error, size is: %d", r.size) + } + }) + + } +} + +func TestLogRecorder_WriteLimit(t *testing.T) { + functionLog := &FunctionLog{ + Level: "INFO", + Message: bytes.NewBufferString(strings.Repeat("a", 90*1024)), + } + r := NewLogRecorder() + for i := 0; i < 256; i++ { + r.WriteLimit(functionLog) + } + if r.logs.size > maxInitLogSizeMB+90*1024*2 { + t.Errorf("write limit error, size is: %d", r.logs.size) + } +} + +func TestReset(t *testing.T) { + convey.Convey("Test Reset", t, func() { + convey.Convey("Reset success", func() { + r := &FixSizeRecorder{} + convey.So(r.Reset, convey.ShouldNotPanic) + }) + }) +} + +func TestRealTimeWrite(t *testing.T) { + convey.Convey("Test RealTimeWrite", t, func() { + convey.Convey("RealTimeWrite success", func() { + functionLog := &FunctionLog{ + Level: "INFO", + Message: bytes.NewBufferString("a"), + } + r := NewLogRecorder() + r.logs.size = 2049 + convey.So(func() { + r.RealTimeWrite(functionLog) + }, convey.ShouldPanic) + }) + }) +} + +func TestStartSync(t *testing.T) { + convey.Convey( + "Test StartSync", t, func() { + convey.Convey( + "StartSync success", func() { + r := &LogRecorder{} + convey.So(r.StartSync, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestFinishSync(t *testing.T) { + convey.Convey( + "Test FinishSync", t, func() { + convey.Convey( + "FinishSync success", func() { + r := &LogRecorder{} + r.StartSync() + convey.So(r.FinishSync, convey.ShouldNotPanic) + convey.So(r.FinishSync, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestGetLogRecorderField(t *testing.T) { + convey.Convey( + "Test GetLogRecorderField", t, func() { + r := &LogRecorder{} + var str string + convey.Convey( + "LogOption success", func() { + str = r.LogOption() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "InvokeID success", func() { + str = r.InvokeID() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "TraceID success", func() { + str = r.TraceID() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "Stage success", func() { + str = r.Stage() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "Generator success", func() { + generator := r.Generator() + convey.So(generator, convey.ShouldBeEmpty) + }, + ) + }, + ) +} diff --git a/api/go/faassdk/common/functionlog/log_recorder_option.go b/api/go/faassdk/common/functionlog/log_recorder_option.go new file mode 100644 index 0000000..0739ec8 --- /dev/null +++ b/api/go/faassdk/common/functionlog/log_recorder_option.go @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package functionlog + +import "yuanrong.org/kernel/runtime/faassdk/common/constants" + +// LogRecorderOption - +type LogRecorderOption func(r *LogRecorder) + +// WithLogOption - +func WithLogOption(option string) LogRecorderOption { + return func(r *LogRecorder) { + r.logOption = option + } +} + +// WithLogOptionTail - +func WithLogOptionTail() LogRecorderOption { + return WithLogOption(constants.RuntimeLogOptTail) +} + +// WithoutLineBreak - +func WithoutLineBreak() LogRecorderOption { + return func(r *LogRecorder) { + r.separatedWithLineBreak = false + } +} + +// WithLivedataID - +func WithLivedataID(livedataID string) LogRecorderOption { + return func(r *LogRecorder) { + r.livedataID = livedataID + } +} + +// WithLogHeaderFooterGenerator - +func WithLogHeaderFooterGenerator(g LogHeaderFooterGenerator) LogRecorderOption { + return func(r *LogRecorder) { + r.generator = g + } +} diff --git a/api/go/faassdk/common/functionlog/log_recorder_option_test.go b/api/go/faassdk/common/functionlog/log_recorder_option_test.go new file mode 100644 index 0000000..f62df8e --- /dev/null +++ b/api/go/faassdk/common/functionlog/log_recorder_option_test.go @@ -0,0 +1,49 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package functionlog + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestWithLivedataID(t *testing.T) { + convey.Convey( + "Test WithLivedataID", t, func() { + convey.Convey( + "WithLivedataID success", func() { + opt := WithLivedataID("livedataID") + convey.So(opt, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestWithLogHeaderFooterGenerator(t *testing.T) { + convey.Convey( + "Test WithLogHeaderFooterGenerator", t, func() { + convey.Convey( + "WithLogHeaderFooterGenerator success", func() { + opt := WithLogHeaderFooterGenerator(&fakeGenerator{}) + convey.So(opt, convey.ShouldNotBeNil) + }, + ) + }, + ) +} diff --git a/api/go/faassdk/common/functionlog/std_log.go b/api/go/faassdk/common/functionlog/std_log.go new file mode 100644 index 0000000..15242fc --- /dev/null +++ b/api/go/faassdk/common/functionlog/std_log.go @@ -0,0 +1,85 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package functionlog create a log module +package functionlog + +import ( + "errors" + + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + defaultSplitLimit = 90*1024 - 1 +) + +// STDLogger log reader of stdout and stderr +type STDLogger struct { + // logCallback if it exists, collect log and send back + logCallback func([]byte) + splitLimit int +} + +// CreateSTDLogger new std log reader +func CreateSTDLogger() *STDLogger { + return CreateSTDLoggerWithSplitLimit(defaultSplitLimit) +} + +// CreateSTDLoggerWithSplitLimit - +func CreateSTDLoggerWithSplitLimit(splitLimit int) *STDLogger { + stdLogger := &STDLogger{ + logCallback: nil, + splitLimit: splitLimit, + } + + return stdLogger +} + +// RegisterLogCallback callback function for std log collection +func (l *STDLogger) RegisterLogCallback(cb func([]byte)) { + l.logCallback = cb +} + +// Write covered std log write method +func (l *STDLogger) Write(p []byte) (int, error) { + n := len(p) + if n == 0 { + return 0, nil + } + + if l.logCallback == nil { + logger.GetLogger().Errorf("logCallback is nil") + return 0, errors.New("logCallback is nil") + } + + if n < l.splitLimit { + l.logCallback(p) + } else { + start := 0 + end := l.splitLimit + for end < n { + l.logCallback(p[start:end]) + start = end + end += l.splitLimit + } + l.logCallback(p[start:]) + } + + logger.GetLogger().Debugf("successfully log %d bytes", n) + + return n, nil +} diff --git a/api/go/faassdk/common/functionlog/std_log_test.go b/api/go/faassdk/common/functionlog/std_log_test.go new file mode 100644 index 0000000..d8d377d --- /dev/null +++ b/api/go/faassdk/common/functionlog/std_log_test.go @@ -0,0 +1,97 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package functionlog + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" +) + +type logCallback struct { + res [][]byte +} + +func (l *logCallback) callback(b []byte) { + l.res = append(l.res, b) +} + +func (l *logCallback) reset() { + l.res = nil +} + +func Test_stdLogReader(t *testing.T) { + convey.Convey("Test stdLogReader", t, func() { + convey.Convey("Test basic", func() { + logger := CreateSTDLogger() + cb := logCallback{} + logger.RegisterLogCallback(cb.callback) + log := []byte("123456789") + count, err := logger.Write(log) + assert.NoError(t, err) + assert.Equal(t, len(log), count) + assert.Equal(t, 1, len(cb.res)) + assert.Equal(t, log, cb.res[0]) + }) + convey.Convey("Test no content", func() { + logger := CreateSTDLogger() + cb := logCallback{} + logger.RegisterLogCallback(cb.callback) + count, err := logger.Write([]byte("")) + assert.NoError(t, err) + assert.Equal(t, 0, count) + assert.Equal(t, 0, len(cb.res)) + }) + convey.Convey("Test no callback", func() { + logger := CreateSTDLogger() + _, err := logger.Write([]byte("1234567\n89")) + assert.Error(t, err) + }) + convey.Convey("Test large lines", func() { + logger := CreateSTDLoggerWithSplitLimit(5) + cb := logCallback{} + logger.RegisterLogCallback(cb.callback) + log := []byte("12345") + count, err := logger.Write(log) + assert.NoError(t, err) + assert.Equal(t, len(log), count) + assert.Equal(t, 1, len(cb.res)) + assert.Equal(t, log, cb.res[0]) + cb.reset() + + log = []byte("123456") + count, err = logger.Write(log) + assert.NoError(t, err) + assert.Equal(t, len(log), count) + assert.Equal(t, 2, len(cb.res)) + assert.Equal(t, []byte("12345"), cb.res[0]) + assert.Equal(t, []byte("6"), cb.res[1]) + cb.reset() + + log = []byte("1234567890\n") + count, err = logger.Write(log) + assert.NoError(t, err) + assert.Equal(t, len(log), count) + assert.Equal(t, 3, len(cb.res)) + assert.Equal(t, []byte("12345"), cb.res[0]) + assert.Equal(t, []byte("67890"), cb.res[1]) + assert.Equal(t, []byte("\n"), cb.res[2]) + cb.reset() + }) + }) +} diff --git a/api/go/faassdk/common/functionlog/write.go b/api/go/faassdk/common/functionlog/write.go new file mode 100644 index 0000000..361a614 --- /dev/null +++ b/api/go/faassdk/common/functionlog/write.go @@ -0,0 +1,99 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package functionlog + +import ( + "strings" + "time" + + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +// Sync ensure all log is written in file before faas-executor exit +func Sync(conf *config.Configuration) { + f, err := GetFunctionLogger(conf) + if err != nil { + logger.GetLogger().Warnf("failed to get functionLogger: %s", err.Error()) + return + } + + f.logger.Sync() +} + +func (f *FunctionLogger) write(funcLog *internalFunctionLog) { + if f.logger == nil { + logger.GetLogger().Errorf("invalid logger") + return + } + defer f.ReleaseLog(funcLog.FunctionLog) + + var level zapcore.Level + if err := level.UnmarshalText([]byte(funcLog.Level)); err != nil { + level = zapcore.InfoLevel + } + message := getLTSLoggerMessage(funcLog) + ent := zapcore.Entry{ + LoggerName: "", + Time: time.Now(), + Level: level, + Message: message, + } + + logFieldsIf := f.fieldsPool.Get() + defer f.fieldsPool.Put(logFieldsIf) + + logFileds, ok := logFieldsIf.(logFileds) + if !ok { + logger.GetLogger().Errorf("failed to assert logFields") + return + } + logFileds.set(funcLog) + + err := f.logger.Core().Write(ent, logFileds) + if err != nil { + logger.GetLogger().Errorf("failed to write, %s", err.Error()) + return + } +} + +func getLTSLoggerMessage(funcLog *internalFunctionLog) string { + if !funcLog.IsStdLog { + if funcLog.Index > 0 { + var b strings.Builder + b.WriteString(funcLog.Time) + b.WriteString(" ") + b.WriteString(funcLog.TraceID) + b.WriteString(" ") + b.WriteString(funcLog.Level) + b.WriteString(" ") + b.Write(funcLog.Message.Bytes()) + return b.String() + } + if funcLog.Index == constants.DefaultFuncLogIndex { + var b strings.Builder + b.WriteString(funcLog.Time) + b.WriteString(" ") + b.Write(funcLog.Message.Bytes()) + return b.String() + } + } + return funcLog.Message.String() +} diff --git a/api/go/faassdk/common/loadbalance/hash.go b/api/go/faassdk/common/loadbalance/hash.go new file mode 100644 index 0000000..7e7fe2f --- /dev/null +++ b/api/go/faassdk/common/loadbalance/hash.go @@ -0,0 +1,379 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides consistent hash alogrithm +package loadbalance + +import ( + "hash/crc32" + "sort" + "sync" + "time" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + // MaxInstanceSize is the max instance size be stored in hash ring + MaxInstanceSize = 100 + defaultMapSize = 100 +) + +type uint32Slice []uint32 + +// Len returns the size +func (u uint32Slice) Len() int { + return len(u) +} + +// Swap will swap two elements +func (u uint32Slice) Swap(i, j int) { + if i < 0 || i >= len(u) || j < 0 || j >= len(u) { + return + } + u[i], u[j] = u[j], u[i] +} + +// Less returns true if i less than j +func (u uint32Slice) Less(i, j int) bool { + if i < 0 || i >= len(u) || j < 0 || j >= len(u) { + return false + } + return u[i] < u[j] +} + +type anchorInfo struct { + instanceHash uint32 + instanceKey string +} + +// CHGeneric is the generic consistent hash +type CHGeneric struct { + anchorPoint map[string]*anchorInfo + instanceMap map[uint32]string + hashPool uint32Slice + insMutex sync.RWMutex + anchorMutex sync.Mutex +} + +// NewCHGeneric creates generic consistent hash +func NewCHGeneric() *CHGeneric { + return &CHGeneric{ + hashPool: make([]uint32, 0, MaxInstanceSize), + instanceMap: make(map[uint32]string, defaultMapSize), + anchorPoint: make(map[string]*anchorInfo, defaultMapSize), + } +} + +// Next returns the next scheduled node of a function +func (c *CHGeneric) Next(name string, move bool) interface{} { + c.anchorMutex.Lock() + anchor, exist := c.anchorPoint[name] + if !exist { + anchor = c.addAnchorPoint(name) + c.anchorMutex.Unlock() + return anchor.instanceKey + } + c.insMutex.RLock() + if move { + c.moveAnchorPoint(name, anchor.instanceHash) + } + _, exist = c.instanceMap[anchor.instanceHash] + c.insMutex.RUnlock() + // check if node still exists, no maxReqCount limitation + if !exist { + c.moveAnchorPoint(name, anchor.instanceHash) + } + c.anchorMutex.Unlock() + return anchor.instanceKey +} + +// Add will add a node into hash ring +func (c *CHGeneric) Add(node interface{}, weight int) { + c.insMutex.Lock() + defer c.insMutex.Unlock() + name, ok := node.(string) + if !ok { + logger.GetLogger().Errorf("unable to convert %T to string", node) + return + } + hashKey := getHashKeyCRC32([]byte(name)) + _, exist := c.instanceMap[hashKey] + if exist { + return + } + c.instanceMap[hashKey] = name + c.hashPool = append(c.hashPool, hashKey) + sort.Sort(c.hashPool) + logger.GetLogger().Infof("add node %s to hash ring", name) +} + +// Remove will remove a node from hash ring +func (c *CHGeneric) Remove(node interface{}) { + name, ok := node.(string) + if !ok { + logger.GetLogger().Errorf("unable to convert %T to string", node) + } + hashKey := getHashKeyCRC32([]byte(name)) + c.insMutex.Lock() + delete(c.instanceMap, hashKey) + for i, hash := range c.hashPool { + if hash == hashKey { + copy(c.hashPool[i:], c.hashPool[i+1:]) + c.hashPool[len(c.hashPool)-1] = 0 + c.hashPool = c.hashPool[:len(c.hashPool)-1] + break + } + } + logger.GetLogger().Infof("delete node %s from hash ring", name) + c.insMutex.Unlock() + return +} + +// RemoveAll will remove all nodes from hash ring +func (c *CHGeneric) RemoveAll() { + c.insMutex.Lock() + c.hashPool = make([]uint32, 0, MaxInstanceSize) + c.instanceMap = make(map[uint32]string, defaultMapSize) + c.insMutex.Unlock() + return +} + +// Reset will clean all anchor infos +func (c *CHGeneric) Reset() { + c.anchorMutex.Lock() + c.anchorPoint = make(map[string]*anchorInfo, defaultMapSize) + c.anchorMutex.Unlock() + return +} + +func (c *CHGeneric) addAnchorPoint(name string) *anchorInfo { + // need to be called in a thread safe context + hashKey := getHashKeyCRC32([]byte(name)) + c.insMutex.RLock() + instanceHash := c.getNextHashKey(hashKey) + c.insMutex.RUnlock() + newAnchor := &anchorInfo{ + instanceHash: instanceHash, + instanceKey: c.instanceMap[instanceHash], + } + c.anchorPoint[name] = newAnchor + return newAnchor +} + +func (c *CHGeneric) moveAnchorPoint(name string, curHash uint32) { + c.insMutex.RLock() + instanceHash := c.getNextHashKey(curHash) + c.anchorPoint[name].instanceHash = instanceHash + c.anchorPoint[name].instanceKey = c.instanceMap[instanceHash] + c.insMutex.RUnlock() +} + +func (c *CHGeneric) getNextHashKey(hashKey uint32) uint32 { + // need to be called with insMutex locked + if len(c.hashPool) == 0 { + return 0 + } + nextHashKey := c.hashPool[0] + for _, v := range c.hashPool { + if v > hashKey { + nextHashKey = v + break + } + } + return nextHashKey +} + +func getHashKeyCRC32(key []byte) uint32 { + return crc32.ChecksumIEEE(key) +} + +// NewConcurrentCHGeneric return ConcurrentCHGeneric with given concurrency +func NewConcurrentCHGeneric(concurrency int) *ConcurrentCHGeneric { + return &ConcurrentCHGeneric{ + CHGeneric: NewCHGeneric(), + concurrency: concurrency, + counter: make(map[string]*concurrentCounter, constants.DefaultMapSize), + } +} + +type concurrentCounter struct { + count int + last time.Time +} + +// ConcurrentCHGeneric is concurrency balanced +type ConcurrentCHGeneric struct { + *CHGeneric + counter map[string]*concurrentCounter + countMutex sync.Mutex + concurrency int +} + +// Next returns the next scheduled node +func (c *ConcurrentCHGeneric) Next(name string, move bool) interface{} { + c.countMutex.Lock() + defer c.countMutex.Unlock() + l, ok := c.counter[name] + if !ok { + c.counter[name] = &concurrentCounter{ + last: time.Now(), + } + return c.CHGeneric.Next(name, move) + } + l.count++ + if l.count >= c.concurrency { + now := time.Now() + l.count = 0 + if now.Sub(l.last) < 1*time.Second { + move = true + } + l.last = now + } + return c.CHGeneric.Next(name, move) +} + +// Add a node to hash ring +func (c *ConcurrentCHGeneric) Add(node interface{}, weight int) { + c.CHGeneric.Add(node, weight) +} + +// Remove a node from hash ring +func (c *ConcurrentCHGeneric) Remove(node interface{}) { + c.countMutex.Lock() + defer c.countMutex.Unlock() + delete(c.counter, node.(string)) + c.CHGeneric.Remove(node) +} + +// RemoveAll remove all nodes from hash ring +func (c *ConcurrentCHGeneric) RemoveAll() { + c.countMutex.Lock() + defer c.countMutex.Unlock() + c.counter = make(map[string]*concurrentCounter, constants.DefaultMapSize) + c.CHGeneric.RemoveAll() +} + +// Reset clean all anchor infos and counters +func (c *ConcurrentCHGeneric) Reset() { + c.countMutex.Lock() + defer c.countMutex.Unlock() + c.counter = make(map[string]*concurrentCounter, constants.DefaultMapSize) + c.CHGeneric.Reset() +} + +// NewLimiterCHGeneric return limiterCHGeneric with given concurrency +func NewLimiterCHGeneric(limiterTime time.Duration) *LimiterCHGeneric { + return &LimiterCHGeneric{ + CHGeneric: NewCHGeneric(), + limiterTime: limiterTime, + limiter: make(map[string]*concurrentLimiter, constants.DefaultMapSize), + } +} + +type concurrentLimiter struct { + head *limiterNode +} + +type limiterNode struct { + instanceKey interface{} + lastTime time.Time + next *limiterNode +} + +// LimiterCHGeneric is limiter balanced +type LimiterCHGeneric struct { + *CHGeneric + limiter map[string]*concurrentLimiter + nodeCount int + limiterMutex sync.Mutex + limiterTime time.Duration +} + +// Next returns the next scheduled node +func (c *LimiterCHGeneric) Next(name string, move bool) interface{} { + c.limiterMutex.Lock() + defer c.limiterMutex.Unlock() + if _, ok := c.limiter[name]; !ok { + c.limiter[name] = &concurrentLimiter{ + head: &limiterNode{}, + } + } + + moveFlag := move +label: + for exitFlag := 0; exitFlag <= c.nodeCount; exitFlag++ { + instanceKey := c.CHGeneric.Next(name, moveFlag) + h := c.limiter[name].head + n := h.next + for ; n != nil; n = n.next { + if n.instanceKey == instanceKey && !n.lastTime.IsZero() && time.Now().Sub(n.lastTime) < c.limiterTime { + moveFlag = true + continue label + } + if n.instanceKey == instanceKey && (n.lastTime.IsZero() || time.Now().Sub(n.lastTime) >= c.limiterTime) { + break + } + } + if n == nil { + h.next = &limiterNode{ + instanceKey: instanceKey, + next: h.next, + } + } + return instanceKey + } + return nil +} + +// Add a node to hash ring +func (c *LimiterCHGeneric) Add(node interface{}, weight int) { + c.nodeCount++ + c.CHGeneric.Add(node, weight) +} + +// Remove a node from hash ring +func (c *LimiterCHGeneric) Remove(node interface{}) { + c.nodeCount-- + c.CHGeneric.Remove(node) +} + +// RemoveAll remove all nodes from hash ring +func (c *LimiterCHGeneric) RemoveAll() { + c.nodeCount = 0 + c.CHGeneric.RemoveAll() +} + +// Reset clean all anchor infos and counters +func (c *LimiterCHGeneric) Reset() { + c.nodeCount = 0 + c.CHGeneric.Reset() +} + +// SetStain give the specified function, specify the node to set the stain +func (c *LimiterCHGeneric) SetStain(function string, node interface{}) { + if _, ok := c.limiter[function]; !ok { + return + } + n := c.limiter[function].head + for ; n != nil; n = n.next { + if n.instanceKey == node { + n.lastTime = time.Now() + return + } + } +} diff --git a/api/go/faassdk/common/loadbalance/hash_test.go b/api/go/faassdk/common/loadbalance/hash_test.go new file mode 100644 index 0000000..41cdd68 --- /dev/null +++ b/api/go/faassdk/common/loadbalance/hash_test.go @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides consistent hash alogrithm +package loadbalance + +import ( + "testing" + "time" + + "github.com/smartystreets/goconvey/convey" +) + +func TestUint32Slice(t *testing.T) { + convey.Convey( + "Test uint32Slice", t, func() { + var u uint32Slice = []uint32{1, 2} + convey.Convey( + "Swap success", func() { + convey.So(func() { + u.Swap(-1, 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Less success", func() { + flag := u.Less(-1, 0) + convey.So(flag, convey.ShouldBeFalse) + }, + ) + }, + ) +} + +func TestRemove(t *testing.T) { + convey.Convey( + "Test Remove", t, func() { + convey.Convey( + "Remove success", func() { + c := &CHGeneric{} + convey.So(func() { + c.Remove(1.0) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestConcurrentCHGenericNext(t *testing.T) { + convey.Convey( + "Test ConcurrentCHGenericNext", t, func() { + convey.Convey( + "ConcurrentCHGenericNext success", func() { + c := NewConcurrentCHGeneric(0) + c.counter["name"] = &concurrentCounter{ + count: 1, + } + res := c.Next("name", false) + convey.So(res, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestLimiterCHGenericRemoveAll(t *testing.T) { + convey.Convey( + "Test LimiterCHGenericRemoveAll", t, func() { + convey.Convey( + "LimiterCHGenericRemoveAll success", func() { + c := NewLimiterCHGeneric(time.Second) + convey.So(c.RemoveAll, convey.ShouldNotPanic) + }, + ) + }, + ) +} diff --git a/api/go/faassdk/common/loadbalance/hashcache.go b/api/go/faassdk/common/loadbalance/hashcache.go new file mode 100644 index 0000000..66a0f9d --- /dev/null +++ b/api/go/faassdk/common/loadbalance/hashcache.go @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import "sync" + +type hashCache struct { + hashes sync.Map +} + +func createHashCache() *hashCache { + return &hashCache{ + hashes: sync.Map{}, + } +} + +func (cache *hashCache) getHash(key string) uint32 { + hashIf, ok := cache.hashes.Load(key) + if ok { + hash, ok := hashIf.(uint32) + if ok { + return hash + } + return 0 + } + hash := getHashKeyCRC32([]byte(key)) + cache.hashes.Store(key, hash) + return hash +} diff --git a/api/go/faassdk/common/loadbalance/loadbalance.go b/api/go/faassdk/common/loadbalance/loadbalance.go new file mode 100644 index 0000000..964e090 --- /dev/null +++ b/api/go/faassdk/common/loadbalance/loadbalance.go @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides load balancing algorithm +package loadbalance + +import "time" + +const ( + // RoundRobinNginx represents type of Round Robin Nginx + RoundRobinNginx LBType = iota + // RoundRobinLVS represents type of Round Robin LVS + RoundRobinLVS + // ConsistentHashGeneric represents type of Generic Consistent Hash + ConsistentHashGeneric + // ConcurrentConsistentHashGeneric represents type of concurrent Consistent + ConcurrentConsistentHashGeneric +) + +// Request - +type Request struct { + Name string + TraceID string + Timestamp time.Time +} + +// LBType is the type of load loadbalance algorithm +type LBType int + +const defaultCHGenericConcurrency = 100 + +// LBInterface is the interface of loadbalance algorithm +type LBInterface interface { + Next(name string, move bool) interface{} // move parameter controls whether the hash loop moves + Add(node interface{}, weight int) + Remove(node interface{}) + RemoveAll() + Reset() +} + +// LBFactory is the factory of loadbalance algorithm +func LBFactory(t LBType) LBInterface { + switch t { + case RoundRobinNginx: + return &WNGINX{} + case ConsistentHashGeneric: + return NewCHGeneric() + case ConcurrentConsistentHashGeneric: + return NewConcurrentCHGeneric(defaultCHGenericConcurrency) + default: + return NewCHGeneric() + } +} diff --git a/api/go/faassdk/common/loadbalance/loadbalance_test.go b/api/go/faassdk/common/loadbalance/loadbalance_test.go new file mode 100644 index 0000000..5e083e6 --- /dev/null +++ b/api/go/faassdk/common/loadbalance/loadbalance_test.go @@ -0,0 +1,229 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides consistent hash algorithm +package loadbalance + +import ( + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "strconv" + "sync" + "testing" + "time" +) + +type LBTestSuite struct { + suite.Suite + LBInterface + lbType LBType + m sync.RWMutex + emptyNode interface{} +} + +func (lbs *LBTestSuite) SetupSuite() { + switch lbs.lbType { + case RoundRobinNginx, RoundRobinLVS: + lbs.emptyNode = nil + case ConsistentHashGeneric: + lbs.emptyNode = "" + default: + lbs.emptyNode = "" + } +} + +func (lbs *LBTestSuite) SetupTest() { + lbs.m = sync.RWMutex{} + lbs.LBInterface = LBFactory(lbs.lbType) +} + +func (lbs *LBTestSuite) TearDownTest() { + lbs.LBInterface = nil +} + +func (lbs *LBTestSuite) AddToLB(workerInstance interface{}, weight int) { + switch lbs.lbType { + case RoundRobinNginx, RoundRobinLVS: + lbs.m.Lock() + lbs.Add(workerInstance, weight) + lbs.Reset() + lbs.m.Unlock() + case ConsistentHashGeneric: + lbs.Add(workerInstance, 0) + default: + } +} + +func (lbs *LBTestSuite) DelFromLB(workerInstance interface{}) { + switch lbs.lbType { + case RoundRobinNginx, RoundRobinLVS: + lbs.m.Lock() + lbs.Remove(workerInstance) + lbs.Reset() + defer lbs.m.Unlock() + case ConsistentHashGeneric: + lbs.Remove(workerInstance) + default: + } +} + +func (lbs *LBTestSuite) TestAdd() { + lbs.AddToLB("new-node-01", 0) + lbs.AddToLB("new-node-01", 1) // test duplicate + lbs.AddToLB("new-node-02", 2) + lbs.AddToLB("new-node-03", 5) + lbs.AddToLB("", 6) + lbs.AddToLB(nil, 4) + next := lbs.Next("fn-urn-01", true) + assert.NotEqual(lbs.T(), lbs.emptyNode, next) + lbs.Reset() + next = lbs.Next("fn-urn-01", true) + assert.NotEqual(lbs.T(), lbs.emptyNode, next) +} + +func (lbs *LBTestSuite) TestNext() { + var wg sync.WaitGroup + next := lbs.Next("fn-urn-01", false) + assert.Equal(lbs.T(), lbs.emptyNode, next) + + lbs.AddToLB("new-node-01", 5) + next = lbs.Next("fn-urn-01", true) + assert.Equal(lbs.T(), "new-node-01", next) + + for i := 2; i < 5; i++ { + wg.Add(1) + go func(i int, wg *sync.WaitGroup) { + lbs.AddToLB("new-node-0"+strconv.Itoa(i), 5) + wg.Done() + }(i, &wg) + } + wg.Wait() + next = lbs.Next("fn-urn-01", true) + assert.NotEqual(lbs.T(), lbs.emptyNode, next) +} + +func (lbs *LBTestSuite) TestRemove() { + var wg sync.WaitGroup + for i := 1; i < 5; i++ { + wg.Add(1) + go func(i int, wg *sync.WaitGroup) { + lbs.AddToLB("new-node-0"+strconv.Itoa(i), 5) + wg.Done() + }(i, &wg) + } + wg.Wait() + for i := 1; i < 4; i++ { + wg.Add(1) + go func(i int, wg *sync.WaitGroup) { + lbs.DelFromLB("new-node-0" + strconv.Itoa(i)) + wg.Done() + }(i, &wg) + } + wg.Wait() + next := lbs.Next("fn-urn-01", true) + assert.Equal(lbs.T(), "new-node-04", next) +} + +func (lbs *LBTestSuite) TestRemoveAll() { + var wg sync.WaitGroup + for i := 1; i < 5; i++ { + wg.Add(1) + go func(i int, wg *sync.WaitGroup) { + lbs.Add("new-node-0"+strconv.Itoa(i), 5) + wg.Done() + }(i, &wg) + } + wg.Wait() + lbs.RemoveAll() + next := lbs.Next("fn-urn-01", true) + assert.Equal(lbs.T(), lbs.emptyNode, next) +} + +func TestLBTestSuite(t *testing.T) { + suite.Run(t, &LBTestSuite{lbType: ConsistentHashGeneric}) +} + +func TestConcurrentCHGeneric_Add(t *testing.T) { + con := NewConcurrentCHGeneric(2) + con.Add("n1", 0) + con.Add("n2", 0) + + next := con.Next("n1", false) + assert.Equal(t, "n2", next) + + con.Remove("n2") + con.RemoveAll() + con.Reset() + + next = con.Next("n1", false) + assert.Equal(t, "", next) +} + +func TestLBFactory(t *testing.T) { + convey.Convey("LBFactory", t, func() { + convey.Convey("RoundRobinNginx", func() { + factory := LBFactory(LBType(0)) + convey.So(factory, convey.ShouldNotBeNil) + }) + convey.Convey("ConsistentHashGeneric", func() { + factory := LBFactory(LBType(2)) + convey.So(factory, convey.ShouldNotBeNil) + }) + convey.Convey("ConcurrentConsistentHashGeneric", func() { + factory := LBFactory(LBType(3)) + convey.So(factory, convey.ShouldNotBeNil) + }) + convey.Convey("default", func() { + factory := LBFactory(LBType(1)) + convey.So(factory, convey.ShouldNotBeNil) + }) + }) +} + +func TestLimiterCHGeneric_Next(t *testing.T) { + convey.Convey("TestLimiterCHGeneric_Next", t, func() { + limiter := NewLimiterCHGeneric(1) + next := limiter.Next("n1", false) + convey.So(next, convey.ShouldEqual, "") + limiter.Add("n1", 0) + next = limiter.Next("n1", false) + convey.So(next, convey.ShouldEqual, "n1") + limiter.Add("n2", 0) + next = limiter.Next("n1", true) + convey.So(next, convey.ShouldEqual, "n2") + limiter.Remove("n1") + limiter.Remove("n2") + }) +} + +func TestLimiterCHGeneric_SetStain(t *testing.T) { + convey.Convey("TestLimiterCHGeneric_SetStain", t, func() { + limiter := NewLimiterCHGeneric(1) + limiter.Add("faasscheduelrID-1", 0) + next := limiter.Next("test-function", false) + convey.So(next, convey.ShouldEqual, "faasscheduelrID-1") + + limiter.SetStain("test-function", "faasscheduelrID-1") + next = limiter.Next("test-function", false) + convey.So(next, convey.ShouldEqual, "faasscheduelrID-1") + + time.Sleep(1 * time.Second) + limiter.Reset() + next = limiter.Next("test-function", false) + convey.So(next, convey.ShouldEqual, "faasscheduelrID-1") + }) +} diff --git a/api/go/faassdk/common/loadbalance/nolockconsistenthash.go b/api/go/faassdk/common/loadbalance/nolockconsistenthash.go new file mode 100644 index 0000000..852df80 --- /dev/null +++ b/api/go/faassdk/common/loadbalance/nolockconsistenthash.go @@ -0,0 +1,126 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "errors" + "sort" +) + +// Node - +type Node struct { + Obj interface{} + Key string + hash uint32 +} + +// NoLockLoadBalance - +type NoLockLoadBalance interface { + Add(node *Node) error + Next(key string) *Node + Delete(nodeKey string) *Node +} + +// CreateNoLockLB - +func CreateNoLockLB() NoLockLoadBalance { + return &ConsistentHash{ + nodes: make([]*Node, 0), + cache: createHashCache(), + } +} + +type nodeSlice []*Node + +// Len returns the size +func (s nodeSlice) Len() int { + return len(s) +} + +// Swap will swap two elements +func (s nodeSlice) Swap(i, j int) { + if i < 0 || i >= len(s) || j < 0 || j >= len(s) { + return + } + s[i], s[j] = s[j], s[i] +} + +// Less returns true if i less than j +func (s nodeSlice) Less(i, j int) bool { + if i < 0 || i >= len(s) || j < 0 || j >= len(s) { + return false + } + return s[i].hash < s[j].hash +} + +// ConsistentHash - +type ConsistentHash struct { + cache *hashCache + nodes nodeSlice +} + +// Add - +func (c *ConsistentHash) Add(newNode *Node) error { + newNode.hash = getHashKeyCRC32([]byte(newNode.Key)) + for _, node := range c.nodes { + if node.Key == newNode.Key { + return errors.New("node already exist") + } + if node.hash == newNode.hash { + return errors.New("node hash already exist") + } + } + + c.nodes = append(c.nodes, newNode) + sort.Sort(c.nodes) + return nil +} + +// Next - +func (c *ConsistentHash) Next(key string) *Node { + if len(c.nodes) == 0 { + return nil + } + + keyHash := c.cache.getHash(key) + index := c.search(keyHash) + return c.nodes[index] +} + +func (c *ConsistentHash) search(keyHash uint32) int { + f := func(x int) bool { + if x >= len(c.nodes) { + return false + } + return c.nodes[x].hash > keyHash + } + index := sort.Search(len(c.nodes), f) + if index >= len(c.nodes) { + return 0 + } + return index +} + +// Delete - +func (c *ConsistentHash) Delete(nodeKey string) *Node { + for i, node := range c.nodes { + if node.Key == nodeKey { + c.nodes = append(c.nodes[:i], c.nodes[i+1:]...) + return node + } + } + return nil +} diff --git a/api/go/faassdk/common/loadbalance/nolockconsistenthash_test.go b/api/go/faassdk/common/loadbalance/nolockconsistenthash_test.go new file mode 100644 index 0000000..dbe1c85 --- /dev/null +++ b/api/go/faassdk/common/loadbalance/nolockconsistenthash_test.go @@ -0,0 +1,118 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides consistent hash alogrithm +package loadbalance + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +const ( + nodeKey = "faas-scheduler-6b758c8b74-5zdwv" + funcKeyWithRes = "7e186a/0@base@testresourcepython36768/latest/300-128" +) + +var ( + node1 = &Node{ + Key: nodeKey, + } + + node2 = &Node{ + Key: nodeKey + "1", + } + + node3 = &Node{ + Key: nodeKey + "2", + } +) + +type mockRealNode struct { + state bool +} + +func (node *mockRealNode) IsEnable() bool { + return node.state +} + +func TestStatefulConsistent(t *testing.T) { + convey.Convey("TestStatefulConsistentHashWithOneNode", t, func() { + lb := CreateNoLockLB() + outNode := lb.Next(funcKeyWithRes) + convey.So(outNode, convey.ShouldBeNil) + + lb.Add(node1) + lb.Add(node1) + + outNode = lb.Next(funcKeyWithRes) + convey.So(outNode, convey.ShouldNotBeNil) + convey.So(outNode.Key, convey.ShouldEqual, node1.Key) + + outNode = lb.Delete(nodeKey) + convey.So(outNode, convey.ShouldNotBeNil) + convey.So(outNode.Key, convey.ShouldEqual, node1.Key) + + outNode = lb.Delete(nodeKey) + convey.So(outNode, convey.ShouldBeNil) + + outNode = lb.Next(funcKeyWithRes) + convey.So(outNode, convey.ShouldBeNil) + + lb.Add(node2) + lb.Add(node3) + lb.Add(node1) + + outNode = lb.Next(funcKeyWithRes) + convey.So(outNode, convey.ShouldNotBeNil) + convey.So(outNode.Key, convey.ShouldEqual, node2.Key) + + }) +} + +func BenchmarkStatefulConsistentHashWithThreeNode(b *testing.B) { + lb := CreateNoLockLB() + + lb.Add(node1) + lb.Add(node2) + lb.Add(node3) + + for i := 0; i < b.N; i++ { + lb.Next(funcKeyWithRes + "3") + } +} + +func TestNodeSlice(t *testing.T) { + convey.Convey( + "Test nodeSlice", t, func() { + var s nodeSlice = []*Node{} + convey.Convey( + "Swap success", func() { + convey.So(func() { + s.Swap(-1, 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Less success", func() { + flag := s.Less(-1, 0) + convey.So(flag, convey.ShouldBeFalse) + }, + ) + }, + ) +} diff --git a/api/go/faassdk/common/loadbalance/roundrobin.go b/api/go/faassdk/common/loadbalance/roundrobin.go new file mode 100644 index 0000000..3d28451 --- /dev/null +++ b/api/go/faassdk/common/loadbalance/roundrobin.go @@ -0,0 +1,258 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides roundrobin algorithm +package loadbalance + +import ( + "math/rand" + "time" +) + +// WeightNginx weight nginx +type WeightNginx struct { + Node interface{} + Weight int + CurrentWeight int + EffectiveWeight int +} + +// WNGINX w nginx +type WNGINX struct { + nodes []*WeightNginx +} + +// Add add node +func (w *WNGINX) Add(node interface{}, weight int) { + weightNginx := &WeightNginx{ + Node: node, + Weight: weight, + EffectiveWeight: weight} + w.nodes = append(w.nodes, weightNginx) +} + +// Remove removes a node +func (w *WNGINX) Remove(node interface{}) { + for i, weighted := range w.nodes { + if weighted.Node == node { + w.nodes = append(w.nodes[:i], w.nodes[i+1:]...) + break + } + } +} + +// RemoveAll remove all nodes +func (w *WNGINX) RemoveAll() { + w.nodes = w.nodes[:0] +} + +// Next get next node +func (w *WNGINX) Next(_ string, _ bool) interface{} { + if len(w.nodes) == 0 { + return nil + } + if len(w.nodes) == 1 { + return w.nodes[0].Node + } + return nextWeightedNode(w.nodes).Node +} + +// nextWeightedNode get best next node info +func nextWeightedNode(nodes []*WeightNginx) *WeightNginx { + total := 0 + if len(nodes) == 0 { + return nil + } + best := nodes[0] + for _, w := range nodes { + w.CurrentWeight += w.EffectiveWeight + total += w.EffectiveWeight + if w.CurrentWeight > best.CurrentWeight { + best = w + } + } + best.CurrentWeight -= total + return best +} + +// Reset reset all nodes +func (w *WNGINX) Reset() { + for _, s := range w.nodes { + s.EffectiveWeight = s.Weight + s.CurrentWeight = 0 + } +} + +// Done - +func (w *WNGINX) Done(node interface{}) {} + +// NextWithRequest - +func (w *WNGINX) NextWithRequest(req *Request, move bool) interface{} { + return w.Next(req.Name, move) +} + +// SetConcurrency - +func (w *WNGINX) SetConcurrency(concurrency int) {} + +// Start - +func (w *WNGINX) Start() {} + +// Stop - +func (w *WNGINX) Stop() {} + +// NoLock - +func (w *WNGINX) NoLock() bool { + return false +} + +// WeightLvs weight lv5 +type WeightLvs struct { + Node interface{} + Weight int +} + +// WLVS w lv5 +type WLVS struct { + nodes []*WeightLvs + gcd int // general weight divisor + maxW int // maximum weight + i int // number of times selected + cw int // current weight +} + +// Next get next node +func (w *WLVS) Next(_ string, _ bool) interface{} { + if len(w.nodes) == 0 { + return nil + } + if len(w.nodes) == 1 { + return w.nodes[0].Node + } + for { + if w.updateCwAnsIsReturn() { + return nil + } + + if w.i < len(w.nodes) && w.nodes[w.i].Weight >= w.cw { + return w.nodes[w.i].Node + } + } +} + +// updateCwAnsIsReturn update current weight and return the value whether to return +func (w *WLVS) updateCwAnsIsReturn() bool { + w.i = (w.i + 1) % len(w.nodes) + if w.i == 0 { + if w.cw = w.cw - w.gcd; w.cw <= 0 { + if w.cw = w.maxW; w.cw == 0 { + return true + } + } + } + + return false +} + +// Add add a node +func (w *WLVS) Add(node interface{}, weight int) { + weighted := &WeightLvs{Node: node, Weight: weight} + if weight > 0 { + w.gcd = gcd(w.gcd, weight) + if w.maxW < weight { + w.maxW = weight + } + } + w.nodes = append(w.nodes, weighted) +} + +// Returns the maximum divisor +func gcd(x, y int) int { + for { + if y == 0 { + return x + } + x, y = y, x%y + } +} + +// Remove removes a node +func (w *WLVS) Remove(node interface{}) { + for i, weighted := range w.nodes { + if weighted.Node == node { + w.nodes = append(w.nodes[:i], w.nodes[i+1:]...) + break + } + } +} + +// RemoveAll remove all nodes +func (w *WLVS) RemoveAll() { + w.nodes = w.nodes[:0] + w.gcd = 0 + w.maxW = 0 + w.i = -1 + w.cw = 0 +} + +// Reset reset all nodes +func (w *WLVS) Reset() { + w.i = -1 + w.cw = 0 +} + +// Done - +func (w *WLVS) Done(node interface{}) {} + +// NextWithRequest - +func (w *WLVS) NextWithRequest(req *Request, move bool) interface{} { + return w.Next(req.Name, move) +} + +// SetConcurrency - +func (w *WLVS) SetConcurrency(concurrency int) {} + +// Start - +func (w *WLVS) Start() {} + +// Stop - +func (w *WLVS) Stop() {} + +// NoLock - +func (w *WLVS) NoLock() bool { + return false +} + +// NewRandomRR init RandomRR +func NewRandomRR() *RandomRR { + return &RandomRR{ + rnd: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +// RandomRR is random version of WLVS +// it will shuffle all nodes randomly when reset +type RandomRR struct { + rnd *rand.Rand + WLVS +} + +// Reset reset RR and shuffle all nodes +func (r *RandomRR) Reset() { + r.WLVS.Reset() + r.rnd.Shuffle(len(r.WLVS.nodes), func(i, j int) { + r.WLVS.nodes[i], r.WLVS.nodes[j] = r.WLVS.nodes[j], r.WLVS.nodes[i] + }) +} diff --git a/api/go/faassdk/common/loadbalance/roundrobin_test.go b/api/go/faassdk/common/loadbalance/roundrobin_test.go new file mode 100644 index 0000000..c3c0b48 --- /dev/null +++ b/api/go/faassdk/common/loadbalance/roundrobin_test.go @@ -0,0 +1,290 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides consistent hash alogrithm +package loadbalance + +import ( + "errors" + "testing" + + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" +) + +func TestNext(t *testing.T) { + convey.Convey("node length is 0", t, func() { + node := []*WeightNginx{} + wnginx := WNGINX{node} + + res := wnginx.Next("", true) + convey.So(res, convey.ShouldBeNil) + }) + convey.Convey("node length is 1", t, func() { + node := []*WeightNginx{ + {"Node1", 30, 10, 20}, + } + wnginx := WNGINX{node} + + res := wnginx.Next("", true) + convey.So(res, convey.ShouldNotBeNil) + }) + convey.Convey("node length > 1", t, func() { + node := []*WeightNginx{ + {"Node1", 30, 10, 20}, + {"Node2", 30, 60, 20}, + } + wnginx := WNGINX{node} + + res := wnginx.Next("", true) + resStr, ok := res.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(resStr, convey.ShouldEqual, "Node2") + }) + + convey.Convey("remove", t, func() { + node := []*WeightNginx{ + {"Node1", 30, 10, 20}, + } + wnginx := WNGINX{node} + wnginx.Add("Node2", 60) + res := wnginx.Next("", true) + resStr, ok := res.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(resStr, convey.ShouldEqual, "Node2") + + wnginx.Remove("Node2") + res = wnginx.Next("", true) + resStr, ok = res.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(resStr, convey.ShouldEqual, "Node1") + }) + + convey.Convey("remove", t, func() { + node := []*WeightNginx{ + {"Node1", 30, 10, 20}, + {"Node2", 30, 60, 20}, + } + wnginx := WNGINX{node} + wnginx.RemoveAll() + convey.So(len(wnginx.nodes), convey.ShouldEqual, 0) + }) +} + +func TestReset(t *testing.T) { + convey.Convey("Reset success", t, func() { + weightNginx := &WeightNginx{"Node1", 30, 10, 20} + var node []*WeightNginx + node = append(node, weightNginx) + wnginx := WNGINX{node} + + wnginx.Reset() + convey.So(weightNginx.EffectiveWeight, convey.ShouldEqual, weightNginx.Weight) + }) + +} + +func TestAdd(t *testing.T) { + convey.Convey("add", t, func() { + weightLvs := &WeightLvs{"Node1", 10} + var node []*WeightLvs + node = append(node, weightLvs) + wlvs := &WLVS{node, 10, 0, 10, 10} + wlvs.Add("Node2", 20) + convey.So(wlvs.gcd, convey.ShouldNotEqual, 20) + convey.So(wlvs.maxW, convey.ShouldEqual, 20) + + wlvs.gcd = 0 + wlvs.Add("Node2", 20) + convey.So(wlvs.gcd, convey.ShouldEqual, 20) + convey.So(wlvs.maxW, convey.ShouldEqual, 20) + }) +} + +func TestRemoveAll(t *testing.T) { + weightLvs := &WeightLvs{"Node1", 10} + var node []*WeightLvs + node = append(node, weightLvs) + wlvs := &WLVS{node, 10, 0, 10, 10} + wlvs.RemoveAll() + if wlvs.maxW != 0 || wlvs.gcd != 0 || wlvs.cw != 0 || wlvs.i != -1 { + err := errors.New("Test_RemoveAll error") + t.Fatal(err) + } +} + +func TestWLVSReset(t *testing.T) { + weightLvs := &WeightLvs{"Node1", 10} + var node []*WeightLvs + node = append(node, weightLvs) + wlvs := &WLVS{node, 10, 0, 10, 10} + wlvs.Reset() + if wlvs.cw != 0 || wlvs.i != -1 { + err := errors.New("Test_WLVSReset error") + t.Fatal(err) + } + +} + +func TestWLVSNext(t *testing.T) { + convey.Convey("node length is 0", t, func() { + var node []*WeightLvs + wlvs := &WLVS{node, 0, 0, 10, 10} + + res := wlvs.Next("", true) + convey.So(res, convey.ShouldBeNil) + }) + + convey.Convey("add", t, func() { + node := []*WeightLvs{ + {"Node1", 10}, + {"Node2", 20}, + } + wlvs := &WLVS{node, 0, 0, 10, 10} + + res := wlvs.Next("", true) + convey.So(res, convey.ShouldNotBeNil) + + wlvss := &WLVS{node, 1, 0, 10, 10} + res = wlvss.Next("", true) + resStr, ok := res.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(resStr, convey.ShouldEqual, "Node2") + }) + + convey.Convey("remove all", t, func() { + node := []*WeightLvs{ + {"Node1", 10}, + } + wlvss := &WLVS{node, 1, 0, 10, 10} + res := wlvss.Next("", true) + resStr, ok := res.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(resStr, convey.ShouldEqual, "Node1") + + wlvss.RemoveAll() + res = wlvss.Next("", true) + resStr, ok = res.(string) + convey.So(ok, convey.ShouldBeFalse) + convey.So(resStr, convey.ShouldEqual, "") + }) + convey.Convey("remove", t, func() { + node := []*WeightLvs{ + {"Node1", 10}, + } + wlvss := &WLVS{node, 1, 0, 10, 10} + res := wlvss.Next("", true) + resStr, ok := res.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(resStr, convey.ShouldEqual, "Node1") + + wlvss.Remove(node[0].Node) + res = wlvss.Next("", true) + resStr, ok = res.(string) + convey.So(ok, convey.ShouldBeFalse) + convey.So(resStr, convey.ShouldEqual, "") + }) +} + +func Test_gcd(t *testing.T) { + cases := []struct { + x int + y int + want int + }{ + {9, 0, 9}, + {0, 8, 8}, + {1, 1, 1}, + {2, 9, 1}, + {12, 18, 6}, + {60, 140, 20}, + {3, 5, 1}, + {111, 111, 111}, + } + for _, tt := range cases { + assert.Equal(t, tt.want, gcd(tt.x, tt.y)) + } +} + +func Test_WNGINX(t *testing.T) { + w := &WNGINX{} + w.Done("node") + request := w.NextWithRequest(&Request{}, false) + assert.Equal(t, request, nil) + w.SetConcurrency(0) + w.Start() + w.Stop() + lock := w.NoLock() + assert.Equal(t, lock, false) +} + +func Test_WLVS(t *testing.T) { + w := &WLVS{} + w.Done("node") + request := w.NextWithRequest(&Request{}, false) + assert.Equal(t, request, nil) + w.SetConcurrency(0) + w.Start() + w.Stop() + lock := w.NoLock() + assert.Equal(t, lock, false) +} + +func TestNextWeightedNode(t *testing.T) { + convey.Convey( + "Test nextWeightedNode", t, func() { + convey.Convey( + "nextWeightedNode success", func() { + wn := nextWeightedNode([]*WeightNginx{}) + convey.So(wn, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestUpdateCwAnsIsReturn(t *testing.T) { + convey.Convey( + "Test updateCwAnsIsReturn", t, func() { + convey.Convey( + "updateCwAnsIsReturn success", func() { + weightLvs := &WeightLvs{"Node1", 10} + var node []*WeightLvs + node = append(node, weightLvs) + wlvs := &WLVS{node, 10, 0, 10, 10} + + flag := wlvs.updateCwAnsIsReturn() + convey.So(flag, convey.ShouldBeTrue) + }, + ) + }, + ) +} + +func TestRandomRR(t *testing.T) { + convey.Convey( + "Test RandomRR", t, func() { + convey.Convey( + "RandomRR success", func() { + convey.So(func() { + r := NewRandomRR() + r.Reset() + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} diff --git a/api/go/faassdk/common/monitor/monitor_manager.go b/api/go/faassdk/common/monitor/monitor_manager.go new file mode 100644 index 0000000..0481485 --- /dev/null +++ b/api/go/faassdk/common/monitor/monitor_manager.go @@ -0,0 +1,60 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package monitor +package monitor + +import ( + "errors" + + "yuanrong.org/kernel/runtime/faassdk/common/monitor/oom" + "yuanrong.org/kernel/runtime/faassdk/types" +) + +var ( + // ErrExecuteDiskLimit is the error of execute function disk limit + ErrExecuteDiskLimit = errors.New("execute function disk limit") + // ErrExecuteOOM is the error of exec ute function oom + ErrExecuteOOM = errors.New("execute function oom") +) + +// EnvDelegateContainer is the environment key of delegate-container +const EnvDelegateContainer = "DELEGATE_CONTAINER_ID" + +// FunctionMonitorManager - +type FunctionMonitorManager struct { + MemoryManager *oom.MemoryManager + ErrChan chan error +} + +// CreateFunctionMonitor create function monitor include disk monitor, memory monitor +func CreateFunctionMonitor(funcSpec *types.FuncSpec, stopCh chan struct{}) (*FunctionMonitorManager, error) { + memoryManager := oom.NewMemoryManager(funcSpec.ResourceMetaData.Memory, stopCh) + + m := &FunctionMonitorManager{ + MemoryManager: memoryManager, + ErrChan: make(chan error, 1), + } + go m.receiveMonitorsError() + return m, nil +} + +func (fmm *FunctionMonitorManager) receiveMonitorsError() { + select { + case <-fmm.MemoryManager.OOMChan: + fmm.ErrChan <- ErrExecuteOOM + } +} diff --git a/api/go/faassdk/common/monitor/monitor_manager_test.go b/api/go/faassdk/common/monitor/monitor_manager_test.go new file mode 100644 index 0000000..20589f2 --- /dev/null +++ b/api/go/faassdk/common/monitor/monitor_manager_test.go @@ -0,0 +1,28 @@ +package monitor + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/faassdk/types" +) + +func TestCreateFunctionMonitor(t *testing.T) { + convey.Convey("CreateFunctionMonitor", t, func() { + spec := &types.FuncSpec{ResourceMetaData: types.ResourceMetaData{Memory: 500}} + convey.Convey("create monitor custom image disk", func() { + monitor, err := CreateFunctionMonitor(spec, make(chan struct{})) + convey.So(monitor, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("success", func() { + monitor, err := CreateFunctionMonitor(spec, make(chan struct{})) + convey.So(monitor, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + close(monitor.MemoryManager.OOMChan) + err = <-monitor.ErrChan + convey.So(err, convey.ShouldEqual, ErrExecuteOOM) + }) + }) +} diff --git a/api/go/faassdk/common/monitor/oom/memory_manager.go b/api/go/faassdk/common/monitor/oom/memory_manager.go new file mode 100644 index 0000000..021c82d --- /dev/null +++ b/api/go/faassdk/common/monitor/oom/memory_manager.go @@ -0,0 +1,104 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package oom +package oom + +import ( + "os" + "path/filepath" + "time" + + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + mebibyte = 1024 * 1024 + memoryGap = 5 + + memoryPath = "/runtime/memory" + memoryFile = "memory.stat" + + readMemoryInternal = 2 * time.Millisecond +) + +// MemoryManager - +type MemoryManager struct { + OOMChan chan struct{} + stopCh chan struct{} + PodMemory int +} + +// NewMemoryManager - +func NewMemoryManager(podMemory int, stopCh chan struct{}) *MemoryManager { + mm := &MemoryManager{ + OOMChan: make(chan struct{}, 1), + stopCh: stopCh, + PodMemory: podMemory, + } + go mm.WatchingRuntimeMemoryUsedLoop() + return mm +} + +// WatchingRuntimeMemoryUsedLoop - watch memory.stat until oom or process exit +func (mm *MemoryManager) WatchingRuntimeMemoryUsedLoop() { + lastMemoryUsed := 0.0 + podMemory := mm.PodMemory + limit := float64(podMemory - memoryGap) + path := filepath.Join(memoryPath, os.Getenv("DELEGATE_CONTAINER_ID"), memoryFile) + parser, err := NewCGroupMemoryParserWithPath(path) + if err != nil { + logger.GetLogger().Warnf("failed to create cgroup memory parser: %s", err.Error()) + return + } + defer parser.Close() + logger.GetLogger().Infof("start to watch memory, path:%s, limit %f", path, limit) + for { + select { + case _, ok := <-mm.stopCh: + if !ok { + logger.GetLogger().Warnf("context canceled") + return + } + default: + flag := mm.checkRuntimeMemory(&lastMemoryUsed, limit, parser) + if flag { + logger.GetLogger().Warnf("the runtime oom check process exited") + return + } + time.Sleep(readMemoryInternal) + } + } +} + +func (mm *MemoryManager) checkRuntimeMemory(lastMemoryUsed *float64, limit float64, + parser *Parser) bool { + memoryUsedBytes, err := parser.Read() + if err != nil { + logger.GetLogger().Warnf("failed to parse memory used, err: %s", err.Error()) + return false + } + tmpMemoryUsed := float64(memoryUsedBytes) / mebibyte + // update used memory + if tmpMemoryUsed != *lastMemoryUsed { + } + *lastMemoryUsed = tmpMemoryUsed + if tmpMemoryUsed > limit { + close(mm.OOMChan) + return true + } + return false +} diff --git a/api/go/faassdk/common/monitor/oom/memory_manager_test.go b/api/go/faassdk/common/monitor/oom/memory_manager_test.go new file mode 100644 index 0000000..9c4e5f4 --- /dev/null +++ b/api/go/faassdk/common/monitor/oom/memory_manager_test.go @@ -0,0 +1,84 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package oom +package oom + +import ( + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" +) + +func TestNewMemoryManager(t *testing.T) { + convey.Convey("NewMemoryManager", t, func() { + stopCh := make(chan struct{}) + manager := NewMemoryManager(500, stopCh) + convey.So(manager, convey.ShouldNotBeNil) + close(stopCh) + }) +} + +func createTempMemoryStatFile(path string, memory int) { + file, _ := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) + file.WriteString(fmt.Sprintf("rss %d", memory)) + file.Close() +} + +func TestMemoryManager_WatchingRuntimeMemoryUsedLoop(t *testing.T) { + convey.Convey("WatchingRuntimeMemoryUsedLoop", t, func() { + tmpFileMemoryStat := "./memory.stat" + createTempMemoryStatFile(tmpFileMemoryStat, 400*1024*1024) + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(filepath.Join, func(elem ...string) string { + return tmpFileMemoryStat + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + err := os.Remove(tmpFileMemoryStat) + if err != nil { + fmt.Println(err) + } + }() + convey.Convey("success", func() { + stopCh := make(chan struct{}) + manager := NewMemoryManager(500, stopCh) + convey.So(manager, convey.ShouldNotBeNil) + defer close(stopCh) + + time.Sleep(4 * time.Millisecond) + createTempMemoryStatFile(tmpFileMemoryStat, 500*1024*1024) + _, ok := <-manager.OOMChan + convey.So(ok, convey.ShouldEqual, false) + }) + + convey.Convey("stop channel close", func() { + stopCh := make(chan struct{}) + manager := NewMemoryManager(500, stopCh) + convey.So(manager, convey.ShouldNotBeNil) + time.Sleep(4 * time.Millisecond) + close(stopCh) + }) + }) +} diff --git a/api/go/faassdk/common/monitor/oom/parser.go b/api/go/faassdk/common/monitor/oom/parser.go new file mode 100644 index 0000000..d90d1bf --- /dev/null +++ b/api/go/faassdk/common/monitor/oom/parser.go @@ -0,0 +1,98 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package oom +package oom + +import ( + "bufio" + "bytes" + "io" + "os" + "strconv" + + "yuanrong.org/kernel/runtime/faassdk/utils" +) + +var ( + rssPrefix = []byte("rss ") +) + +const ( + decimalBase = 10 + bitSize = 64 +) + +// NewCGroupMemoryParserWithPath creates parser with cgroupMemoryParserFunc +func NewCGroupMemoryParserWithPath(path string) (*Parser, error) { + return NewParser(path, cgroupMemoryParserFunc) +} + +var cgroupMemoryParserFunc = func(reader *bufio.Reader) (uint64, error) { + for { + lineBytes, _, err := reader.ReadLine() + if err != nil { + return 0, err + } + + if bytes.HasPrefix(lineBytes, rssPrefix) { + lineBytes = bytes.TrimSpace(lineBytes[len(rssPrefix):]) + return strconv.ParseUint(utils.BytesToString(lineBytes), decimalBase, bitSize) + } + } +} + +// ParserFunc func that parser content of reader to uint64 +type ParserFunc func(reader *bufio.Reader) (uint64, error) + +// NewParser creates new Parser with file path and ParserFunc +func NewParser(path string, parserFunc ParserFunc) (*Parser, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + return &Parser{ + f: f, + reader: bufio.NewReader(nil), + parser: parserFunc, + }, nil +} + +// Parser aims to parse file content that updated frequently (such as cgroup file) with high performance. +// It opens file only once and seek to start every time before read. +// NOTICE: Parser is not thread safe +type Parser struct { + reader *bufio.Reader + f io.ReadSeekCloser + parser ParserFunc +} + +// Close closes file to parse +func (p *Parser) Close() error { + p.reader.Reset(nil) + return p.f.Close() +} + +// Read resets reader to the start of the file and parses it. +// This method is not thread safe +func (p *Parser) Read() (uint64, error) { + _, err := p.f.Seek(0, io.SeekStart) + if err != nil { + return 0, err + } + p.reader.Reset(p.f) + return p.parser(p.reader) +} diff --git a/api/go/faassdk/common/tokentosecret/secret_mgr.go b/api/go/faassdk/common/tokentosecret/secret_mgr.go new file mode 100644 index 0000000..cf8572b --- /dev/null +++ b/api/go/faassdk/common/tokentosecret/secret_mgr.go @@ -0,0 +1,72 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package tokentosecret +package tokentosecret + +import ( + "crypto/sha512" + "sync" + + "golang.org/x/crypto/pbkdf2" + + "yuanrong.org/kernel/runtime/libruntime/common" +) + +// SecretMgr - +type SecretMgr struct { + token string + salt string + // expireTime int + sk []byte + lock sync.RWMutex +} + +var mgr = &SecretMgr{ + lock: sync.RWMutex{}, +} + +// GetSecretMgr - +func GetSecretMgr() *SecretMgr { + return mgr +} + +// SetAuthContext - +func (s *SecretMgr) SetAuthContext(auth *common.AuthContext) { + s.lock.Lock() + if auth == nil || auth.Token == s.token || auth.Token == "" { + s.lock.Unlock() + return + } + + s.token = auth.Token + s.salt = auth.Salt + s.lock.Unlock() + s.generateSk(s.token, s.salt) +} + +// GetSk - +func (s *SecretMgr) GetSk() []byte { + s.lock.RLock() + defer s.lock.RUnlock() + return s.sk +} + +func (s *SecretMgr) generateSk(token string, salt string) { + s.lock.Lock() + defer s.lock.Unlock() + s.sk = pbkdf2.Key([]byte(token), []byte(salt), 25000, sha512.Size, sha512.New) // 25000 iter +} diff --git a/api/go/faassdk/common/tokentosecret/secret_mgr_test.go b/api/go/faassdk/common/tokentosecret/secret_mgr_test.go new file mode 100644 index 0000000..f93d66f --- /dev/null +++ b/api/go/faassdk/common/tokentosecret/secret_mgr_test.go @@ -0,0 +1,64 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package common for token +package tokentosecret + +import ( + "fmt" + "testing" + "time" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/common" +) + +func TestSecretMgr(t *testing.T) { + convey.Convey("Test token Mgr: GetToken", t, func() { + ch := make(chan struct{}) + GetSecretMgr().SetAuthContext(&common.AuthContext{ + ServerAuthEnable: false, + RootCertData: nil, + ModuleCertData: nil, + ModuleKeyData: nil, + Token: "fakeToken", + Salt: "134134134314134", + EnableServerMode: false, + ServerNameOverride: "", + }) + tk := GetSecretMgr().token + convey.So(tk, convey.ShouldEqual, "fakeToken") + close(ch) + }) +} + +func TestSecretMgr_UpdateToken_GetSk(t *testing.T) { + convey.Convey("Test SecretMgr TestSecretMgr_UpdateToken_GetSk", t, func() { + timeStamp := time.Now().Second() + 5 // 5秒后过期 + authContext := &common.AuthContext{ + Token: fmt.Sprintf("%d_%s", timeStamp, "token0"), + Salt: "134134134", + } + + GetSecretMgr().SetAuthContext(authContext) + //convey.So(GetSecretMgr().expireTime, convey.ShouldEqual, timeStamp) + convey.So(GetSecretMgr().token, convey.ShouldEqual, fmt.Sprintf("%d_%s", timeStamp, "token0")) + convey.So(GetSecretMgr().salt, convey.ShouldEqual, "134134134") + convey.So(GetSecretMgr().sk, convey.ShouldNotBeNil) + convey.So(GetSecretMgr().GetSk(), convey.ShouldNotBeNil) + }) +} diff --git a/api/go/faassdk/config/config.go b/api/go/faassdk/config/config.go new file mode 100644 index 0000000..30c14a3 --- /dev/null +++ b/api/go/faassdk/config/config.go @@ -0,0 +1,112 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config of plugin faas executor +package config + +import ( + "errors" + "os" + "sync" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/faassdk/utils/urnutils" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +// LogTankService - +type LogTankService struct { + GroupID string `json:"logGroupId" valid:",optional"` + StreamID string `json:"logStreamId" valid:",optional"` +} + +var ( + configSingleton = struct { + sync.Once + config *Configuration + }{} +) + +// DelegateDownloadPath contains downloaded user code packages +const DelegateDownloadPath string = "ENV_DELEGATE_DOWNLOAD" + +// DelegateLayerDownloadPath contains downloaded layer packages +const DelegateLayerDownloadPath string = "DELEGATE_LAYER_DOWNLOAD " + +// MaxPayloadSize 6 MB +const MaxPayloadSize int32 = 6 * 1024 * 1024 + +// MaxReturnSize 6MB +const MaxReturnSize int32 = 6 * 1024 * 1024 + +// DefaultRuntimeJsonFilePath - +const DefaultRuntimeJsonFilePath = "/home/snuser/config/runtime.json" + +// Configuration represents total config for app +type Configuration struct { + StartArgs `yaml:"commandLine" valid:"optional"` + RuntimeConfig `yaml:"runtime" valid:"required"` + UserLogTag string `yaml:"userLogTag" valid:"optional"` + FunctionNameSeparator string `yaml:"functionNameSeparator" valid:"optional"` +} + +// StartArgs represents arguments of start +type StartArgs struct { + LogDir string `yaml:"logDir" valid:"optional"` +} + +// RuntimeConfig represents configuration of runtime configuration +type RuntimeConfig struct { + URN string `yaml:"urn" valid:"optional"` + RuntimeContainerID string `json:"runtimeContainerID"` + FuncLogLevel string `yaml:"loglevel" valid:"optional"` + LogTankService LogTankService `json:"logTankService"` +} + +// GetConfig singleton pattern thread safe +func GetConfig(funcSpec *types.FuncSpec) (*Configuration, error) { + configSingleton.Do(func() { + runtimeContainerID := os.Getenv(constants.RuntimeContainerIDEnvKey) + logger.GetLogger().Infof("runtime container id is: %s", runtimeContainerID) + configSingleton.config = &Configuration{ + StartArgs: StartArgs{ + LogDir: "/home/snuser/log", + }, + RuntimeConfig: RuntimeConfig{ + URN: funcSpec.FuncMetaData.FunctionVersionURN, + FuncLogLevel: constants.FuncLogLevelInfo, + RuntimeContainerID: runtimeContainerID, + LogTankService: LogTankService{ + GroupID: funcSpec.ExtendedMetaData.LogTankService.GroupID, + StreamID: funcSpec.ExtendedMetaData.LogTankService.StreamID, + }}, + UserLogTag: "cff-log", + FunctionNameSeparator: "@", + } + logger.GetLogger().Infof("config is: %v", configSingleton.config) + urnutils.SetSeparator(configSingleton.config.FunctionNameSeparator) + err := urnutils.LocalFuncURN.ParseFrom(configSingleton.config.RuntimeConfig.URN) + if err != nil { + logger.GetLogger().Errorf("failed to ParseFrom urn err is : %v", err.Error()) + } + }) + + if configSingleton.config == nil { + return nil, errors.New("failed to get worker config") + } + return configSingleton.config, nil +} diff --git a/api/go/faassdk/config/config_test.go b/api/go/faassdk/config/config_test.go new file mode 100644 index 0000000..e52aa95 --- /dev/null +++ b/api/go/faassdk/config/config_test.go @@ -0,0 +1,33 @@ +package config + +import ( + "testing" + + "yuanrong.org/kernel/runtime/faassdk/types" +) + +func TestGetConfig(t *testing.T) { + type args struct { + funcSpec *types.FuncSpec + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"case1 succedd to get config", args{funcSpec: &types.FuncSpec{ + FuncMetaData: types.FuncMetaData{ + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:0@yrservice@test-image-env", + }, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := GetConfig(tt.args.funcSpec) + if (err != nil) != tt.wantErr { + t.Errorf("GetConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/api/go/faassdk/entrance.go b/api/go/faassdk/entrance.go new file mode 100644 index 0000000..b0003c5 --- /dev/null +++ b/api/go/faassdk/entrance.go @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faassdk for init and start +package faassdk + +import ( + "fmt" + + "yuanrong.org/kernel/runtime/faassdk/handler/event" +) + +// Register register callEntry +func Register(callEntry event.CallEntry) { + event.SetUserCallEntry(callEntry) + err := InitRuntime() + if err != nil { + fmt.Print("init runtime failed: " + err.Error()) + return + } + Run() +} + +// RegisterInitializerFunction register initEntry +func RegisterInitializerFunction(initEntry event.InitEntry) { + event.SetUserInitEntry(initEntry) +} + +// RegisterPreStopFunction register preStopEntry +func RegisterPreStopFunction(preStopEntry event.PreStopEntry) { +} diff --git a/api/go/faassdk/entrance_test.go b/api/go/faassdk/entrance_test.go new file mode 100644 index 0000000..2273730 --- /dev/null +++ b/api/go/faassdk/entrance_test.go @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faassdk for init and start +package faassdk + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/faassdk/faas-sdk/go-api/context" +) + +func initEntry(ctx context.RuntimeContext) { + ctx.GetLogger().Infof("info log in initHandler") +} + +func preStopEntry(ctx context.RuntimeContext) { + ctx.GetLogger().Infof("info log in preStopHandler") +} + +func callEntry(_ []byte, _ context.RuntimeContext) (interface{}, error) { + return "hello world", nil +} + +func example() { + RegisterInitializerFunction(initEntry) + RegisterPreStopFunction(preStopEntry) + Register(callEntry) +} + +func TestRegister(t *testing.T) { + convey.Convey( + "Test Register", t, func() { + convey.Convey( + "Register success", func() { + convey.So(example, convey.ShouldNotPanic) + }, + ) + }, + ) +} diff --git a/api/go/faassdk/faas-sdk/go-api/context/client_context.go b/api/go/faassdk/faas-sdk/go-api/context/client_context.go new file mode 100644 index 0000000..245665e --- /dev/null +++ b/api/go/faassdk/faas-sdk/go-api/context/client_context.go @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package context for userCode +package context + +import ( + "yuanrong.org/kernel/runtime/faassdk/faas-sdk/pkg/runtime/userlog" +) + +// RuntimeContext for userCode +type RuntimeContext interface { + GetLogger() userlog.RuntimeLogger + + GetRequestID() string + + GetRemainingTimeInMilliSeconds() int + + GetAccessKey() string + + GetSecretKey() string + + GetUserData(string) string + + GetFunctionName() string + + GetRunningTimeInSeconds() int + + GetVersion() string + + GetMemorySize() int + + GetCPUNumber() int + + GetPackage() string + + GetToken() string + + GetProjectID() string + + GetState() string + + SetState(string) error + + GetInvokeProperty() string + + GetTraceID() string + + GetInvokeID() string + + GetAlias() string + + GetSecurityToken() string +} diff --git a/api/go/faassdk/faas-sdk/go-api/function/function.go b/api/go/faassdk/faas-sdk/go-api/function/function.go new file mode 100644 index 0000000..7194701 --- /dev/null +++ b/api/go/faassdk/faas-sdk/go-api/function/function.go @@ -0,0 +1,22 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package function is function sdk api +package function + +// Invoke Invoke - +func Invoke() { +} diff --git a/api/go/faassdk/faas-sdk/pkg/runtime/context/context.go b/api/go/faassdk/faas-sdk/pkg/runtime/context/context.go new file mode 100644 index 0000000..b052854 --- /dev/null +++ b/api/go/faassdk/faas-sdk/pkg/runtime/context/context.go @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package context +package context + +import ( + "yuanrong.org/kernel/runtime/faassdk/faas-sdk/pkg/runtime/userlog" +) + +// Env ContextEnv for InvokeHandler +type Env struct { + rtProjectID string + rtFcName string + rtFcVersion string + rtPackage string + rtMemory int + rtCPU int + rtTimeout int + rtStartTime int + rtUserData map[string]string + rtInitializerTimeout int +} + +// InvokeContext InvokeContext +type InvokeContext struct { + RequestID string `json:"request_id"` + Alias string `json:"alias"` + InvokeProperty string `json:"invoke_property"` + InvokeID string `json:"invoke_id"` + TraceID string `json:"trace_id"` +} + +// Provider ContextProvider for RuntimeContext +type Provider struct { + CtxEnv *Env + CtxHTTPHead *InvokeContext + Logger userlog.RuntimeLogger +} diff --git a/api/go/faassdk/faas-sdk/pkg/runtime/context/contextenv.go b/api/go/faassdk/faas-sdk/pkg/runtime/context/contextenv.go new file mode 100644 index 0000000..0da0878 --- /dev/null +++ b/api/go/faassdk/faas-sdk/pkg/runtime/context/contextenv.go @@ -0,0 +1,142 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package context +package context + +import ( + "os" + "time" + + "yuanrong.org/kernel/runtime/faassdk/faas-sdk/pkg/runtime/userlog" +) + +// GetRemainingTimeInMilliSeconds get remaining time +func (p Provider) GetRemainingTimeInMilliSeconds() int { + timeoutInMilliSeconds := p.CtxEnv.rtTimeout * (int)(time.Millisecond) + remainingTime := p.CtxEnv.rtStartTime + (timeoutInMilliSeconds) - getCurrentTime() + if remainingTime < 0 { + return 0 + } + return remainingTime +} + +// GetFunctionName get this functionName +func (p Provider) GetFunctionName() string { + return p.CtxEnv.rtFcName +} + +// GetRunningTimeInSeconds get timeout interval +func (p Provider) GetRunningTimeInSeconds() int { + return p.CtxEnv.rtTimeout +} + +// GetVersion get runtime version +func (p Provider) GetVersion() string { + return p.CtxEnv.rtFcVersion +} + +// GetMemorySize get memory size of runtime instances +func (p Provider) GetMemorySize() int { + return p.CtxEnv.rtMemory +} + +// GetCPUNumber get CPU usage of runtime instance +func (p Provider) GetCPUNumber() int { + return p.CtxEnv.rtCPU +} + +// GetUserData get userData from env +func (p Provider) GetUserData(key string) string { + if p.CtxEnv.rtUserData != nil { + return p.CtxEnv.rtUserData[key] + } + return "" +} + +// GetLogger get logger +func (p Provider) GetLogger() userlog.RuntimeLogger { + return p.Logger +} + +// GetProjectID get projectId +func (p Provider) GetProjectID() string { + return p.CtxEnv.rtProjectID +} + +// GetPackage get package +func (p Provider) GetPackage() string { + return p.CtxEnv.rtPackage +} + +// GetAccessKey get AccessKey +func (p Provider) GetAccessKey() string { + return "" +} + +// GetSecretKey get SecretKey +func (p Provider) GetSecretKey() string { + return "" +} + +// GetToken get token +func (p Provider) GetToken() string { + return "" +} + +// GetRequestID get requestId +func (p Provider) GetRequestID() string { + return p.CtxHTTPHead.RequestID +} + +// GetState get instance status +func (p Provider) GetState() string { + return os.Getenv("USER_DEFINED_STATE") +} + +// SetState set instance status +func (p Provider) SetState(state string) error { + err := os.Setenv("USER_DEFINED_STATE", state) + if err != nil { + return err + } + return nil +} + +// GetInvokeProperty get invoke property +func (p Provider) GetInvokeProperty() string { + return p.CtxHTTPHead.InvokeProperty +} + +// GetTraceID get traceId +func (p Provider) GetTraceID() string { + return p.CtxHTTPHead.TraceID +} + +// GetInvokeID get invokeId +func (p Provider) GetInvokeID() string { + return p.CtxHTTPHead.InvokeID +} + +// GetAlias get function alias +func (p Provider) GetAlias() string { + return p.CtxHTTPHead.Alias +} + +// GetSecurityToken get token +func (p Provider) GetSecurityToken() string { + return "" +} diff --git a/api/go/faassdk/faas-sdk/pkg/runtime/context/contextenv_test.go b/api/go/faassdk/faas-sdk/pkg/runtime/context/contextenv_test.go new file mode 100644 index 0000000..cf920d1 --- /dev/null +++ b/api/go/faassdk/faas-sdk/pkg/runtime/context/contextenv_test.go @@ -0,0 +1,161 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package context +package context + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestProvider(t *testing.T) { + convey.Convey( + "Test Provider", t, func() { + p := &Provider{ + CtxEnv: &Env{}, + CtxHTTPHead: &InvokeContext{}, + } + convey.Convey( + "GetRemainingTimeInMilliSeconds success", func() { + i := p.GetRemainingTimeInMilliSeconds() + convey.So(i, convey.ShouldBeZeroValue) + }, + ) + convey.Convey( + "GetFunctionName success", func() { + str := p.GetFunctionName() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetRunningTimeInSeconds success", func() { + i := p.GetRunningTimeInSeconds() + convey.So(i, convey.ShouldBeZeroValue) + }, + ) + convey.Convey( + "GetVersion success", func() { + str := p.GetVersion() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetMemorySize success", func() { + i := p.GetMemorySize() + convey.So(i, convey.ShouldBeZeroValue) + }, + ) + convey.Convey( + "GetCPUNumber success", func() { + i := p.GetCPUNumber() + convey.So(i, convey.ShouldBeZeroValue) + }, + ) + convey.Convey( + "GetUserData success", func() { + str := p.GetUserData("key") + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetLogger success", func() { + l := p.GetLogger() + convey.So(l, convey.ShouldBeNil) + }, + ) + convey.Convey( + "GetProjectID success", func() { + str := p.GetProjectID() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetPackage success", func() { + str := p.GetPackage() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetAccessKey success", func() { + str := p.GetAccessKey() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetSecretKey success", func() { + str := p.GetSecretKey() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetToken success", func() { + str := p.GetToken() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetRequestID success", func() { + str := p.GetRequestID() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetState success", func() { + str := p.GetState() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "SetState success", func() { + err := p.SetState("state") + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "GetInvokeProperty success", func() { + str := p.GetInvokeProperty() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetTraceID success", func() { + str := p.GetTraceID() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetInvokeID success", func() { + str := p.GetInvokeID() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetAlias success", func() { + str := p.GetAlias() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + convey.Convey( + "GetSecurityToken success", func() { + str := p.GetSecurityToken() + convey.So(str, convey.ShouldBeEmpty) + }, + ) + }, + ) +} diff --git a/api/go/faassdk/faas-sdk/pkg/runtime/context/runtime_context.go b/api/go/faassdk/faas-sdk/pkg/runtime/context/runtime_context.go new file mode 100644 index 0000000..fe03ae5 --- /dev/null +++ b/api/go/faassdk/faas-sdk/pkg/runtime/context/runtime_context.go @@ -0,0 +1,104 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package context +package context + +import ( + "encoding/json" + "fmt" + "log" + "os" + "strconv" + "time" +) + +const ( + defaultTimeout = "3" +) + +func atoi(input string) int { + result, err := strconv.Atoi(input) + if err != nil { + log.Println("execute strconv.atoi failed") + } + return result +} + +func getCurrentTime() int { + return (int)(time.Now().UnixNano() / int64(time.Millisecond)) +} + +// InitializeContext initialize context from env +func (p *Provider) InitializeContext() error { + if err := p.CtxEnv.initializeContextEnv(); err != nil { + return err + } + + return p.CtxHTTPHead.initializeInvokeContext() +} + +func (ic *InvokeContext) initializeInvokeContext() error { + delegateDecrypt := os.Getenv("ENV_DELEGATE_DECRYPT") + if delegateDecrypt == "" { + return nil + } + if err := json.Unmarshal([]byte(delegateDecrypt), ic); err != nil { + return fmt.Errorf("initializeContext failed to Unmarshal ENV_DELEGATE_DECRYPT, error: %s", err) + } + return nil +} + +func (e *Env) initializeContextEnv() error { + e.rtStartTime = getCurrentTime() + timeout := os.Getenv("RUNTIME_TIMEOUT") + if timeout == "" { + timeout = defaultTimeout + } + e.rtTimeout = atoi(timeout) + rtProjectID := os.Getenv("RUNTIME_PROJECT_ID") + if rtProjectID != "" { + e.rtProjectID = rtProjectID + } + rtPackage := os.Getenv("RUNTIME_PACKAGE") + if rtPackage != "" { + e.rtPackage = rtPackage + } + rtFcName := os.Getenv("RUNTIME_FUNC_NAME") + if rtFcName != "" { + e.rtFcName = rtFcName + } + rtFcVersion := os.Getenv("RUNTIME_FUNC_VERSION") + if rtFcVersion != "" { + e.rtFcVersion = rtFcVersion + } + rtMemory := os.Getenv("RUNTIME_MEMORY") + if rtMemory != "" { + e.rtMemory = atoi(rtMemory) + } + rtCPU := os.Getenv("RUNTIME_CPU") + if rtCPU != "" { + e.rtCPU = atoi(rtCPU) + } + rtUserData := os.Getenv("RUNTIME_USERDATA") + if rtUserData != "" { + err := json.Unmarshal([]byte(rtUserData), &e.rtUserData) + if err != nil { + return fmt.Errorf("initializeContext failed to Unmarshal Userdata, error: %s", err) + } + } + return nil +} diff --git a/api/go/faassdk/faas-sdk/pkg/runtime/context/runtime_context_test.go b/api/go/faassdk/faas-sdk/pkg/runtime/context/runtime_context_test.go new file mode 100644 index 0000000..ecb03b3 --- /dev/null +++ b/api/go/faassdk/faas-sdk/pkg/runtime/context/runtime_context_test.go @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package context +package context + +import ( + "os" + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestInitializeContext(t *testing.T) { + convey.Convey( + "Test InitializeContext", t, func() { + convey.Convey( + "InitializeContext success", func() { + p := &Provider{ + CtxEnv: &Env{}, + CtxHTTPHead: &InvokeContext{}, + } + os.Setenv("RUNTIME_PROJECT_ID", "rtProjectID") + os.Setenv("RUNTIME_PACKAGE", "rtPackage") + os.Setenv("RUNTIME_FUNC_NAME", "rtFcName") + os.Setenv("RUNTIME_FUNC_VERSION", "rtFcVersion") + os.Setenv("RUNTIME_MEMORY", "1") + os.Setenv("RUNTIME_CPU", "1") + err := p.InitializeContext() + convey.So(err, convey.ShouldBeNil) + os.Setenv("RUNTIME_USERDATA", "1") + err = p.InitializeContext() + convey.So(err, convey.ShouldNotBeNil) + os.Setenv("RUNTIME_USERDATA", "") + os.Setenv("ENV_DELEGATE_DECRYPT", "1") + err = p.InitializeContext() + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestAtoi(t *testing.T) { + convey.Convey( + "Test atoi", t, func() { + convey.Convey( + "atoi success", func() { + i := atoi("a") + convey.So(i, convey.ShouldBeZeroValue) + }, + ) + }, + ) +} diff --git a/api/go/faassdk/faas-sdk/pkg/runtime/userlog/user_logger.go b/api/go/faassdk/faas-sdk/pkg/runtime/userlog/user_logger.go new file mode 100644 index 0000000..5d73c34 --- /dev/null +++ b/api/go/faassdk/faas-sdk/pkg/runtime/userlog/user_logger.go @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package userlog for printing go runtime logger +package userlog + +// RuntimeLogger logger sdk for user +type RuntimeLogger interface { + Infof(format string, params ...interface{}) + Debugf(format string, params ...interface{}) + Warnf(format string, params ...interface{}) + Errorf(format string, params ...interface{}) +} diff --git a/api/go/faassdk/faashandler.go b/api/go/faassdk/faashandler.go new file mode 100644 index 0000000..d7de871 --- /dev/null +++ b/api/go/faassdk/faashandler.go @@ -0,0 +1,207 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package faassdk +This package provides methods to obtain the execution interface. +*/ +package faassdk + +import ( + "encoding/json" + "errors" + "fmt" + "sync" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/handler" + "yuanrong.org/kernel/runtime/faassdk/handler/event" + "yuanrong.org/kernel/runtime/faassdk/handler/http" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/libruntime" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" + "yuanrong.org/kernel/runtime/libruntime/config" + "yuanrong.org/kernel/runtime/libruntime/execution" + "yuanrong.org/kernel/runtime/libruntime/libruntimesdkimpl" +) + +type faasHandlers struct { + execType execution.ExecutionType + faasexecutor handler.ExecutorHandler + libruntimeAPI api.LibruntimeAPI +} + +var ( + faasHdlrs faasHandlers = faasHandlers{ + execType: execution.ExecutionTypeFaaS, + libruntimeAPI: libruntimesdkimpl.NewLibruntimeSDKImpl(), + } +) + +// GetExecType get execution type +func (h *faasHandlers) GetExecType() execution.ExecutionType { + return h.execType +} + +// LoadFunction load user function +func (h *faasHandlers) LoadFunction(codePaths []string) error { + // 初始化日志 + logger.SetupLogger(h.libruntimeAPI.GetFormatLogger()) + return nil +} + +// FunctionExecute function execute hook +func (h *faasHandlers) FunctionExecute(funcMeta api.FunctionMeta, invokeType config.InvokeType, + args []api.Arg, returnobjs []config.DataObject) error { + var ret []byte + var err error + switch invokeType { + case config.CreateInstance, config.CreateInstanceStateless: + funcSpec := getFuncSpecFromArgs(args) + if funcSpec == nil { + return errors.New("invalid funcSpec") + } + h.InitFaaSExecutor(funcSpec, h.libruntimeAPI) + if h.faasexecutor != nil { + ret, err = h.faasexecutor.InitHandler(args, h.libruntimeAPI) + } + case config.InvokeInstance, config.InvokeInstanceStateless: + traceID := string(args[0].Data) + context := map[string]string{"traceID": traceID} + ret, err = h.faasexecutor.CallHandler(args, context) + default: + err = fmt.Errorf("no such invokeType %d", invokeType) + } + if err != nil { + return err + } + + var totalNativeBufferSize uint = 0 + var do *config.DataObject + if len(returnobjs) > 0 { + do = &returnobjs[0] + } + if do != nil && do.ID == "returnByMsg" { + libruntime.SetReturnObject(do, uint(len(ret))) + } else { + if err = libruntime.AllocReturnObject(do, uint(len(ret)), []string{}, &totalNativeBufferSize); err != nil { + return err + } + } + + if err = libruntime.WriterLatch(do); err != nil { + return err + } + defer func() { + if err = libruntime.WriterUnlatch(do); err != nil { + } + }() + + if err = libruntime.MemoryCopy(do, ret); err != nil { + return err + } + + if err = libruntime.Seal(do); err != nil { + return err + } + return nil +} + +// Checkpoint check point +func (h *faasHandlers) Checkpoint(checkpointID string) ([]byte, error) { + if h.faasexecutor == nil { + return nil, errors.New("faasexcutor not initialized") + } + return h.faasexecutor.CheckPointHandler(checkpointID) +} + +// Recover recover hook +func (h *faasHandlers) Recover(state []byte) error { + if h.faasexecutor == nil { + return errors.New("faasexcutor not initialized") + } + return h.faasexecutor.RecoverHandler(state) +} + +// Shutdown hook +func (h *faasHandlers) Shutdown(gracePeriod uint64) error { + if h.faasexecutor == nil { + return errors.New("faasexcutor not initialized") + } + return h.faasexecutor.ShutDownHandler(gracePeriod) +} + +// Signal hook +func (h *faasHandlers) Signal(sig int, data []byte) error { + if h.faasexecutor == nil { + return errors.New("faasexcutor not initialized") + } + waitFaaSExecutorDone() + return h.faasexecutor.SignalHandler(int32(sig), data) +} + +// HealthCheck 函数是一个健康检查函数,用于检查 faaexecutor 的状态 +func (h *faasHandlers) HealthCheck() (api.HealthType, error) { + if h.faasexecutor == nil { + return api.Healthy, nil + } + waitFaaSExecutorDone() + return h.faasexecutor.HealthCheckHandler() +} + +func newFaaSFuncExecutionIntfs() execution.FunctionExecutionIntfs { + return &faasHdlrs +} + +func getFuncSpecFromArgs(args []api.Arg) *types.FuncSpec { + if len(args) < 1 { + logger.GetLogger().Errorf("invalid args number %d", len(args)) + return nil + } + funcSpec := &types.FuncSpec{} + err := json.Unmarshal(args[0].Data, funcSpec) + if err != nil { + logger.GetLogger().Errorf("failed to unmarshal funcSpec %s", err.Error()) + return nil + } + return funcSpec +} + +var ( + faasExecutorOnce sync.Once + faaExeecutorDoneChan = make(chan struct{}) +) + +func (h *faasHandlers) InitFaaSExecutor(funcSpec *types.FuncSpec, client api.LibruntimeAPI) { + faasExecutorOnce.Do(func() { + switch funcSpec.FuncMetaData.Runtime { + case constants.RuntimeTypeHttp: + h.faasexecutor = http.NewHttpHandler(funcSpec, client) + case constants.RuntimeTypeCustomContainer: + h.faasexecutor = http.NewCustomContainerHandler(funcSpec, client) + default: + h.faasexecutor = event.NewEventHandler(funcSpec, client) + } + close(faaExeecutorDoneChan) + }) +} + +func waitFaaSExecutorDone() { + select { + case <-faaExeecutorDoneChan: + } +} diff --git a/api/go/faassdk/faashandler_test.go b/api/go/faassdk/faashandler_test.go new file mode 100644 index 0000000..46e11fb --- /dev/null +++ b/api/go/faassdk/faashandler_test.go @@ -0,0 +1,302 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package faassdk +This package provides methods to obtain the execution interface. +*/ +package faassdk + +import ( + "encoding/json" + "reflect" + "testing" + "time" + "yuanrong.org/kernel/runtime/libruntime" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/handler" + "yuanrong.org/kernel/runtime/faassdk/handler/http" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/config" + "yuanrong.org/kernel/runtime/libruntime/execution" +) + +func TestFaasHandler_LoadFunction(t *testing.T) { + convey.Convey( + "Test FaasHandler LoadFunction", t, func() { + convey.Convey( + "Load function success", func() { + intfs := newFaaSFuncExecutionIntfs() + convey.So(intfs.GetExecType(), convey.ShouldEqual, execution.ExecutionTypeFaaS) + codePaths := []string{"/tmp"} + err := intfs.LoadFunction(codePaths) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func Test_faasHandlers_FunctionExecute(t *testing.T) { + convey.Convey("Test faasHandlers_FunctionExecute", t, func() { + convey.Convey("CreateInstance and InvokeInstance", func() { + // 测试创建实例的情况 + h := &faasHandlers{} + funcMeta := api.FunctionMeta{} + invokeType := config.CreateInstance + var args []api.Arg + returnobjs := []config.DataObject{} + + // 设置预期的错误 + err := h.FunctionExecute(funcMeta, invokeType, args, returnobjs) + convey.So(err.Error(), convey.ShouldEqual, "invalid funcSpec") + }) + convey.Convey("InvokeInstance", func() { + // 测试调用实例的情况 + h := &faasHandlers{} + funcMeta := api.FunctionMeta{FuncName: "test function"} + funcSpec := &types.FuncSpec{FuncMetaData: types.FuncMetaData{FunctionName: "test function", Runtime: constants.RuntimeTypeCustomContainer}} + bytes, _ := json.Marshal(funcSpec) + invokeType := config.CreateInstance + args := []api.Arg{ + { + Type: api.Value, + Data: bytes, + }, + {}, + } + returnobjs := []config.DataObject{} + cch := &http.CustomContainerHandler{} + var patches []*gomonkey.Patches + patches = append(patches, + gomonkey.ApplyFunc(http.NewCustomContainerHandler, func(funcSpec *types.FuncSpec, client api.LibruntimeAPI) handler.ExecutorHandler { + return cch + }), + gomonkey.ApplyMethod(reflect.TypeOf(cch), "InitHandler", func(ch *http.CustomContainerHandler, args []api.Arg, rt api.LibruntimeAPI) ([]byte, error) { + return []byte("success"), nil + }), + gomonkey.ApplyFunc(libruntime.AllocReturnObject, func(do *config.DataObject, size uint, nestedIds []string, totalNativeBufferSize *uint) error { + return nil + }), + gomonkey.ApplyFunc(libruntime.WriterLatch, func(do *config.DataObject) error { + return nil + }), + gomonkey.ApplyFunc(libruntime.WriterUnlatch, func(do *config.DataObject) error { + return nil + }), + gomonkey.ApplyFunc(libruntime.MemoryCopy, func(do *config.DataObject, src []byte) error { + return nil + }), + gomonkey.ApplyFunc(libruntime.Seal, func(do *config.DataObject) error { + return nil + }), + ) + defer func() { + for _, patch := range patches { + time.Sleep(50 * time.Millisecond) + patch.Reset() + } + }() + // 设置预期的错误 + err := h.FunctionExecute(funcMeta, invokeType, args, returnobjs) + convey.So(err, convey.ShouldBeNil) + + invokeType = config.InvokeInstance + args = []api.Arg{ + { + Type: api.Value, + Data: []byte("trace ID"), + }, + { + Type: api.Value, + }, + } + defer gomonkey.ApplyMethod(reflect.TypeOf(cch), "CallHandler", func(ch *http.CustomContainerHandler, args []api.Arg, context map[string]string) ([]byte, error) { + return []byte("success"), nil + }).Reset() + err = h.FunctionExecute(funcMeta, invokeType, args, returnobjs) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("InvalidInvokeType", func() { + // 测试无效的调用类型的情况 + h := &faasHandlers{} + funcMeta := api.FunctionMeta{} + invokeType := 99999 + args := []api.Arg{} + returnobjs := []config.DataObject{} + + // 设置预期的错误 + err := h.FunctionExecute(funcMeta, config.InvokeType(invokeType), args, returnobjs) + convey.So(err.Error(), convey.ShouldEqual, "no such invokeType 99999") + }) + }) +} + +func Test_faasHandlers_Checkpoint(t *testing.T) { + convey.Convey("Test_faasHandlers_Checkpoint", t, func() { + convey.Convey("faasexecutor is nil", func() { + // 测试faasexecutor为nil的情况 + h := &faasHandlers{ + faasexecutor: nil, + } + _, err := h.Checkpoint("123") + convey.So(err.Error(), convey.ShouldEqual, "faasexcutor not initialized") + }) + convey.Convey("faasexecutor is not nil", func() { + // 测试faasexecutor不为nil的情况 + executor := &http.CustomContainerHandler{} + h := &faasHandlers{ + faasexecutor: executor, + } + defer gomonkey.ApplyMethod(reflect.TypeOf(executor), "CheckPointHandler", func(_ *http.CustomContainerHandler, checkPointId string) ([]byte, error) { + return nil, nil + }).Reset() + _, err := h.Checkpoint("123") + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func Test_faasHandlers_Recover(t *testing.T) { + convey.Convey("Test_faasHandlers_Recover", t, func() { + convey.Convey("faasexecutor is nil", func() { + // 测试faasexecutor为nil的情况 + h := &faasHandlers{ + faasexecutor: nil, + } + err := h.Recover([]byte("123")) + convey.So(err.Error(), convey.ShouldEqual, "faasexcutor not initialized") + }) + convey.Convey("faasexecutor is not nil", func() { + // 测试faasexecutor不为nil的情况 + executor := &http.CustomContainerHandler{} + h := &faasHandlers{ + faasexecutor: executor, + } + defer gomonkey.ApplyMethod(reflect.TypeOf(executor), "RecoverHandler", func(_ *http.CustomContainerHandler, state []byte) error { + return nil + }).Reset() + err := h.Recover([]byte("123")) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func Test_faasHandlers_Shutdown(t *testing.T) { + convey.Convey("Test_faasHandlers_Shutdown", t, func() { + convey.Convey("faasexecutor is nil", func() { + // 测试faasexecutor为nil的情况 + h := &faasHandlers{ + faasexecutor: nil, + } + err := h.Shutdown(60) + convey.So(err.Error(), convey.ShouldEqual, "faasexcutor not initialized") + }) + convey.Convey("faasexecutor is not nil", func() { + // 测试faasexecutor不为nil的情况 + executor := &http.CustomContainerHandler{} + h := &faasHandlers{ + faasexecutor: executor, + } + defer gomonkey.ApplyMethod(reflect.TypeOf(executor), "ShutDownHandler", func(_ *http.CustomContainerHandler, gracePeriodSecond uint64) error { + return nil + }).Reset() + err := h.Shutdown(60) + convey.So(err, convey.ShouldBeNil) + }) + }) + +} + +func Test_faasHandlers_Signal(t *testing.T) { + convey.Convey("Test_faasHandlers_Signal", t, func() { + convey.Convey("faasexecutor is nil", func() { + // 测试faasexecutor为nil的情况 + h := &faasHandlers{ + faasexecutor: nil, + } + err := h.Signal(65, []byte("123")) + convey.So(err.Error(), convey.ShouldEqual, "faasexcutor not initialized") + }) + convey.Convey("faasexecutor is not nil", func() { + // 测试faasexecutor不为nil的情况 + executor := &http.CustomContainerHandler{} + h := &faasHandlers{ + faasexecutor: executor, + } + defer gomonkey.ApplyMethod(reflect.TypeOf(executor), "SignalHandler", func(_ *http.CustomContainerHandler, signal int32, payload []byte) error { + return nil + }).Reset() + err := h.Signal(65, []byte("123")) + convey.So(err, convey.ShouldBeNil) + }) + }) + +} + +func Test_faasHandlers_HealthCheck(t *testing.T) { + convey.Convey("Test_faasHandlers_HealthCheck", t, func() { + convey.Convey("faasexecutor is nil", func() { + // 测试faasexecutor为nil的情况 + h := &faasHandlers{ + faasexecutor: nil, + } + health, _ := h.HealthCheck() + convey.So(health, convey.ShouldEqual, api.Healthy) + }) + convey.Convey("faasexecutor is not nil", func() { + // 测试faasexecutor不为nil的情况 + faaExeecutorDoneChan = make(chan struct{}) + defer func() { + faaExeecutorDoneChan = make(chan struct{}) + }() + executor := &http.CustomContainerHandler{} + h := &faasHandlers{ + faasexecutor: executor, + } + defer gomonkey.ApplyMethod(reflect.TypeOf(executor), "HealthCheckHandler", func(_ *http.CustomContainerHandler) (api.HealthType, error) { + return api.SubHealth, nil + }).Reset() + go func() { + close(faaExeecutorDoneChan) + }() + health, _ := h.HealthCheck() + convey.So(health, convey.ShouldEqual, api.SubHealth) + }) + }) +} + +func TestInitFaaSExecutor(t *testing.T) { + convey.Convey("Test InitFaaSExecutor", t, func() { + convey.Convey("InitFaaSExecutor when case constants.RuntimeTypeHttp", func() { + funcSpec := &types.FuncSpec{FuncMetaData: types.FuncMetaData{Runtime: "http"}} + convey.So(func() { + faasHdlrs.InitFaaSExecutor(funcSpec, faasHdlrs.libruntimeAPI) + }, convey.ShouldNotPanic) + }) + convey.Convey("InitFaaSExecutor when default", func() { + funcSpec := &types.FuncSpec{FuncMetaData: types.FuncMetaData{Runtime: "default"}} + convey.So(func() { + faasHdlrs.InitFaaSExecutor(funcSpec, faasHdlrs.libruntimeAPI) + }, convey.ShouldNotPanic) + }) + }) +} diff --git a/api/go/faassdk/handler/event/future.go b/api/go/faassdk/handler/event/future.go new file mode 100644 index 0000000..d7ec22e --- /dev/null +++ b/api/go/faassdk/handler/event/future.go @@ -0,0 +1,124 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package event - +package event + +import ( + "context" + "fmt" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/utils" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const defaultResNums = 2 + +type timeoutCallBack func() ([]byte, error) +type respCallBack func(res []interface{}) ([]byte, error) +type handleErrorRespCallBack func(body interface{}, statusCode int) ([]byte, error) + +func (eh *EventHandler) initWithTimeout() ([]byte, error) { + timeout := eh.initTimeout + resCh := make(chan []interface{}, 1) + go func() { + eh.userInitEntry(updateInvokeContext(map[string]string{})) + resCh <- nil + }() + + handleErrorResp := func(body interface{}, statusCode int) ([]byte, error) { + return utils.HandleInitResponse(body, statusCode) + } + timeoutCB := func() ([]byte, error) { + logger.GetLogger().Errorf("runtime initialization timed out after %ds", timeout) + return handleErrorResp(fmt.Sprintf("runtime initialization timed out after %ds", + timeout), constants.InitFunctionTimeout) + } + respCB := func(res []interface{}) ([]byte, error) { + return []byte("success"), nil + } + + return waitRespProcess(time.Duration(timeout), resCh, timeoutCB, respCB, handleErrorResp) +} + +func (eh *EventHandler) invokeWithTimeout(event []byte, context map[string]string, + traceID string, totalTime utils.ExecutionDuration) ([]byte, error) { + resCh := make(chan []interface{}, 1) + logWithID := logger.GetLogger().With(zap.Any("traceID", traceID)) + go func() { + totalTime.UserFuncBeginTime = time.Now() + userRes, userErr := eh.userCallEntry(event, updateInvokeContext(context)) + resAndErr := []interface{}{userRes, userErr} + resCh <- resAndErr + totalTime.UserFuncTotalTime = time.Since(totalTime.UserFuncBeginTime) + }() + + handleErrorResp := func(body interface{}, statusCode int) ([]byte, error) { + return utils.HandleCallResponse(body, statusCode, "", totalTime, nil) + } + timeout := eh.invokeTimeout + timeoutCB := func() ([]byte, error) { + totalTime.UserFuncTotalTime = time.Since(totalTime.UserFuncBeginTime) + logWithID.Errorf("call invoke timeout %d", timeout) + return handleErrorResp(fmt.Sprintf("call invoke timeout %ds", + timeout), constants.InvokeFunctionTimeout) + } + respCB := func(res []interface{}) ([]byte, error) { + if len(res) < defaultResNums { + logWithID.Errorf("invalid res nums") + return handleErrorResp(fmt.Sprintf("invalid res num, traceID: %s", traceID), + constants.FunctionRunError) + } + userRes := res[0] + userErr := res[1] + if userErr != nil { + logWithID.Errorf("faas invoke user function,error: %s", userErr.(error)) + return handleErrorResp(fmt.Sprintf("faas invoke user function error: %s", userErr.(error)), + constants.FunctionRunError) + } + return utils.HandleCallResponse(userRes, constants.NoneError, "", totalTime, nil) + } + + return waitRespProcess(time.Duration(timeout), resCh, timeoutCB, respCB, handleErrorResp) +} + +func waitRespProcess(timeout time.Duration, resCh <-chan []interface{}, timeoutCB timeoutCallBack, + respCallBack respCallBack, handleErrorResp handleErrorRespCallBack) ([]byte, error) { + ctx, cancel := context.WithTimeout(context.TODO(), timeout*time.Second) + defer cancel() + + select { + case <-ctx.Done(): + logger.GetLogger().Errorf("invoke timed out after %d ms", timeout) + return timeoutCB() + case _, ok := <-stopCh: + logger.GetLogger().Warnf("failed to call invoke method, the runtime process exit") + if !ok { + return handleErrorResp("failed to call invoke method, the runtime process exit", + constants.FunctionRunError) + } + case res, ok := <-resCh: + if !ok { + return handleErrorResp("invoke response channel closed error", constants.ExecutorErrCodeInitFail) + } + return respCallBack(res) + } + return []byte{}, nil +} diff --git a/api/go/faassdk/handler/event/future_test.go b/api/go/faassdk/handler/event/future_test.go new file mode 100644 index 0000000..97db333 --- /dev/null +++ b/api/go/faassdk/handler/event/future_test.go @@ -0,0 +1,214 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package event - +package event + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/faas-sdk/go-api/context" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/faassdk/utils" +) + +func newFuncSpec() *types.FuncSpec { + return &types.FuncSpec{ + FuncMetaData: types.FuncMetaData{ + FunctionName: "test-future-fuction", + Runtime: "go1.x", + TenantId: "123456789", + Version: "$latest", + Timeout: 10, + }, + ResourceMetaData: types.ResourceMetaData{}, + ExtendedMetaData: types.ExtendedMetaData{}, + } +} + +func TestInitWithTimeout_OK(t *testing.T) { + eventHandler := &EventHandler{ + funcSpec: newFuncSpec(), + libMap: nil, + functionName: "test-future-fuction", + args: nil, + client: nil, + userInitEntry: nil, + userCallEntry: nil, + } + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&EventHandler{}), "SetHandler", + func(eh *EventHandler, functionLibPath, handler string) error { + if handler == initHandlerName { + eh.userInitEntry = func(ctx context.RuntimeContext) { + fmt.Println("userCode init start") + fmt.Println("userCode init success") + } + } + return nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + eventHandler.SetHandler("", initHandlerName) + _, err := eventHandler.initWithTimeout() + convey.ShouldBeNil(err) +} + +func TestInitWithTimeout_Fail(t *testing.T) { + eventHandler := &EventHandler{ + funcSpec: newFuncSpec(), + libMap: nil, + functionName: "test-future-fuction", + args: nil, + client: nil, + userInitEntry: nil, + userCallEntry: nil, + } + + eventHandler.funcSpec.FuncMetaData.Timeout = 3 + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&EventHandler{}), "SetHandler", + func(eh *EventHandler, functionLibPath, handler string) error { + if handler == initHandlerName { + eh.userInitEntry = func(ctx context.RuntimeContext) { + fmt.Println("userCode init start") + fmt.Println("userCode init success") + time.Sleep(5 * time.Second) + } + } + return nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + eventHandler.SetHandler("", initHandlerName) + _, err := eventHandler.initWithTimeout() + convey.ShouldBeNil(err) + var res *types.CallResponse + _ = json.Unmarshal([]byte(err.Error()), &res) + convey.ShouldEqual(res.InnerCode, constants.InitFunctionTimeout) +} + +func TestInvokeWithTimeout_OK(t *testing.T) { + eventHandler := &EventHandler{ + funcSpec: newFuncSpec(), + libMap: nil, + functionName: "test-future-fuction", + args: nil, + client: nil, + userInitEntry: nil, + userCallEntry: nil, + } + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&EventHandler{}), "SetHandler", + func(eh *EventHandler, functionLibPath, handler string) error { + if handler == callHandlerName { + eh.userCallEntry = func(payload []byte, ctx context.RuntimeContext) (interface{}, error) { + fmt.Println("Handler function") + fmt.Println("payload:", string(payload)) + return "success", nil + } + } + return nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + eventHandler.SetHandler("", callHandlerName) + res, err := eventHandler.invokeWithTimeout(getCallArgs()[1].Data, map[string]string{}, "", utils.ExecutionDuration{}) + convey.ShouldBeNil(err) + convey.ShouldContain(res, "success") +} + +func TestInvokeWithTimeout_Fail(t *testing.T) { + eventHandler := &EventHandler{ + funcSpec: newFuncSpec(), + libMap: nil, + functionName: "test-future-fuction", + args: nil, + client: nil, + userInitEntry: nil, + userCallEntry: nil, + } + + eventHandler.funcSpec.FuncMetaData.Timeout = 3 + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&EventHandler{}), "SetHandler", + func(eh *EventHandler, functionLibPath, handler string) error { + if handler == callHandlerName { + eh.userCallEntry = func(payload []byte, ctx context.RuntimeContext) (interface{}, error) { + fmt.Println("Handler function") + time.Sleep(5 * time.Second) + fmt.Println("payload:", string(payload)) + return "success", nil + } + } + return nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + eventHandler.SetHandler("", callHandlerName) + result, err := eventHandler.invokeWithTimeout(getCallArgs()[1].Data, map[string]string{}, "", utils.ExecutionDuration{}) + convey.ShouldBeError(err) + var res *types.CallResponse + _ = json.Unmarshal(result, &res) + convey.ShouldEqual(res.InnerCode, constants.InvokeFunctionTimeout) +} + +func TestWaitRespProcess(t *testing.T) { + convey.Convey( + "Test waitRespProcess", t, func() { + convey.Convey("waitRespProcess success", func() { + go func() { + stopCh <- struct{}{} + }() + resCh := make(chan []interface{}) + timeoutCB := func() ([]byte, error) { return nil, nil } + respCB := func(res []interface{}) ([]byte, error) { return nil, nil } + handleErrorRespCB := func(body interface{}, statusCode int) ([]byte, error) { return nil, nil } + bytes, err := waitRespProcess(time.Second, resCh, timeoutCB, respCB, handleErrorRespCB) + convey.So(bytes, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + close(resCh) + bytes, err = waitRespProcess(time.Second, resCh, timeoutCB, respCB, handleErrorRespCB) + convey.So(bytes, convey.ShouldBeNil) + convey.So(err, convey.ShouldBeNil) + }) + }, + ) +} diff --git a/api/go/faassdk/handler/event/handler.go b/api/go/faassdk/handler/event/handler.go new file mode 100644 index 0000000..c10868a --- /dev/null +++ b/api/go/faassdk/handler/event/handler.go @@ -0,0 +1,340 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package event - +package event + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "plugin" + "strconv" + "strings" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/config" + contextApi "yuanrong.org/kernel/runtime/faassdk/faas-sdk/go-api/context" + contextImpl "yuanrong.org/kernel/runtime/faassdk/faas-sdk/pkg/runtime/context" + "yuanrong.org/kernel/runtime/faassdk/handler" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/faassdk/utils" + "yuanrong.org/kernel/runtime/libruntime/api" + log "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" + runtimeLog "yuanrong.org/kernel/runtime/libruntime/common/logger/log" +) + +const ( + userEntryIndex = 1 + userLogLevel = "info" + defaultTimeout = 3 +) + +var ( + defaultContext *contextImpl.Provider + stopCh = make(chan struct{}) + preSetUserInitEntry InitEntry + preSetUserCallEntry CallEntry +) + +// InitEntry of user code +type InitEntry = func(ctx contextApi.RuntimeContext) + +// PreStopEntry of user code +type PreStopEntry = func(ctx contextApi.RuntimeContext) + +// CallEntry of user code +type CallEntry = func([]byte, contextApi.RuntimeContext) (interface{}, error) + +// EventHandler Data required by the executor +type EventHandler struct { + funcSpec *types.FuncSpec + libMap map[string]*plugin.Plugin + functionName string + args []*api.Arg + client api.LibruntimeAPI + userInitEntry InitEntry + userCallEntry CallEntry + initTimeout int + invokeTimeout int +} + +// SetUserInitEntry set EventHandler's userInitEntry +func SetUserInitEntry(initEntry InitEntry) { + preSetUserInitEntry = initEntry +} + +// SetUserCallEntry set EventHandler's userCallEntry +func SetUserCallEntry(callEntry CallEntry) { + preSetUserCallEntry = callEntry +} + +// NewEventHandler creates EventHandler +func NewEventHandler(funcSpec *types.FuncSpec, client api.LibruntimeAPI) handler.ExecutorHandler { + var initTimeout, invokeTimeout int + if initTimeout = funcSpec.ExtendedMetaData.Initializer.Timeout; initTimeout == 0 { + initTimeout = defaultTimeout + } + if invokeTimeout = funcSpec.FuncMetaData.Timeout; invokeTimeout == 0 { + invokeTimeout = defaultTimeout + } + return &EventHandler{ + funcSpec: funcSpec, + libMap: make(map[string]*plugin.Plugin), + args: make([]*api.Arg, 1), + client: client, + initTimeout: initTimeout, + invokeTimeout: invokeTimeout, + userInitEntry: preSetUserInitEntry, + userCallEntry: preSetUserCallEntry, + } +} + +// getHandlerName from api.Arg +func getHandlerName(args []api.Arg) (*types.CreateParams, error) { + if len(args) != constants.ValidBasicCreateParamSize { + return nil, errors.New("invalid args number") + } + createParams := &types.CreateParams{} + err := json.Unmarshal(args[userEntryIndex].Data, createParams) + if err != nil { + return nil, err + } + return createParams, nil +} + +// InitHandler - +func (eh *EventHandler) InitHandler(args []api.Arg, rt api.LibruntimeAPI) ([]byte, error) { + log.GetLogger().Infof("start to init user code") + log.SetupUserLogger(userLogLevel) + path, err := handler.GetUserCodePath() + if err != nil { + return utils.HandleInitResponse(err.Error(), constants.ExecutorErrCodeInitFail) + } + if eh.userCallEntry == nil { + userHook, err := getHandlerName(args) + if err != nil { + return utils.HandleInitResponse(err.Error(), constants.ExecutorErrCodeInitFail) + } + if err = eh.SetHandler(path, userHook.InitEntry); err != nil { + return utils.HandleInitResponse(err.Error(), constants.InitFunctionFail) + } + if err = eh.SetHandler(path, userHook.CallEntry); err != nil { + return utils.HandleInitResponse(err.Error(), constants.InitFunctionFail) + } + } + eh.setEnvContext() + if err = initContext(); err != nil { + log.GetLogger().Errorf("init context failed, error: %s", err) + return utils.HandleInitResponse(fmt.Sprintf("init context failed, error: %s", err.Error()), + constants.ExecutorErrCodeInitFail) + } + if eh.userInitEntry != nil { + if result, err := eh.initWithTimeout(); err != nil { + log.GetLogger().Errorf("faas-executor failed to init user code, err: %s", err.Error()) + return result, err + } + } + log.GetLogger().Infof("faas-executor succeed to init user code") + return []byte{}, nil +} + +// CallHandler - +func (eh *EventHandler) CallHandler(args []api.Arg, context map[string]string) ([]byte, error) { + traceID := context["traceID"] + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + totalTime := utils.ExecutionDuration{ + ExecutorBeginTime: time.Now(), + } + logger.Infof("faas-executor call handler, function: %s", eh.funcSpec.FuncMetaData.FunctionName) + if len(args) != constants.ValidInvokeArgumentSize { + return utils.HandleCallResponse("invalid invoke argument", + constants.FaaSError, "", totalTime, nil) + } + if int32(len(args[1].Data)) >= config.MaxPayloadSize { + logger.Errorf("the size of the invoke payload exceeds the limit %d", config.MaxPayloadSize) + return utils.HandleCallResponse( + fmt.Sprintf("the size of the invoke payload exceeds the limit %d", config.MaxPayloadSize), + constants.FaaSError, "", totalTime, nil) + } + if eh.userCallEntry == nil { + logger.Errorf("invoke handler is nil") + return utils.HandleCallResponse("invoke handler is nil", + constants.EntryNotFound, "", totalTime, nil) + } + + userCallRequest := &types.CallResponse{} + if err := json.Unmarshal(args[1].Data, userCallRequest); err != nil { + logger.Errorf("unmarshal invoke call request data: %s, err: %s", string(args[1].Data), err) + return utils.HandleCallResponse(fmt.Sprintf("unmarshal invoke call request data: %s, err: %s", + string(args[1].Data), err), + constants.FaaSError, "", totalTime, nil) + } + return eh.invokeWithTimeout(userCallRequest.Body, context, traceID, totalTime) +} + +// CheckPointHandler - +func (eh *EventHandler) CheckPointHandler(checkPointId string) ([]byte, error) { + return nil, nil +} + +// RecoverHandler - +func (eh *EventHandler) RecoverHandler(state []byte) error { + return nil +} + +// ShutDownHandler - +func (eh *EventHandler) ShutDownHandler(gracePeriodSecond uint64) error { + log.GetLogger().Sync() + runtimeLog.GetLogger().Sync() + return nil +} + +// SignalHandler - +func (eh *EventHandler) SignalHandler(signal int32, payload []byte) error { + return nil +} + +// HealthCheckHandler - +func (eh *EventHandler) HealthCheckHandler() (api.HealthType, error) { + return api.Healthy, nil +} + +// initContext init envContext from env +func initContext() error { + userFunctionLog := log.GetUserLogger() + if userFunctionLog == nil { + log.GetLogger().Errorf("init user function log error: cannot create file user-function.log") + return fmt.Errorf("init user function log error: cannot create file user-function.log") + } + + defaultContext = &contextImpl.Provider{ + CtxEnv: &contextImpl.Env{}, + CtxHTTPHead: &contextImpl.InvokeContext{}, + Logger: userFunctionLog, + } + return defaultContext.InitializeContext() +} + +// SetHandler will set user entries +func (eh *EventHandler) SetHandler(functionLibPath, handler string) error { + symbol, err := eh.getLib(functionLibPath, handler) + if err != nil { + return fmt.Errorf("getLib error: %s", err) + } + return eh.setHandlerByEntry(symbol, handler) +} + +func (eh *EventHandler) setHandlerByEntry(symbol plugin.Symbol, handler string) error { + switch symbol.(type) { + case InitEntry: + if eh.userInitEntry == nil { + eh.userInitEntry = symbol.(InitEntry) + } + case CallEntry: + if eh.userCallEntry == nil { + eh.userCallEntry = symbol.(CallEntry) + } + default: + return fmt.Errorf("%s type error", handler) + } + return nil +} + +// Open plugin and set Symbol from plugin +func (eh *EventHandler) getLib(functionLibPath, handler string) (plugin.Symbol, error) { + path, name := utils.GetLibInfo(functionLibPath, handler) + if path == "" { + return nil, fmt.Errorf("invalid handler name :%s", handler) + } + lib, ok := eh.libMap[path] + if !ok { + log.GetLogger().Infof("start to open lib %s", path) + userCodePlugin, err := plugin.Open(path) + if err != nil { + log.GetLogger().Errorf("failed to open lib %v", err) + return nil, fmt.Errorf("failed to open %s", handler) + } + lib = userCodePlugin + eh.libMap[path] = userCodePlugin + } + symbol, err := lib.Lookup(name) + if err != nil { + log.GetLogger().Errorf("failed to look up %v", err) + return nil, fmt.Errorf("failed to look up %s", handler) + } + return symbol, nil +} + +// updateInvokeContext update invoke context from CallRequest.createOpt +func updateInvokeContext(invokeContext map[string]string) contextApi.RuntimeContext { + if defaultContext == nil || defaultContext.CtxHTTPHead == nil { + log.GetLogger().Warnf("default context is not initialized") + return &contextImpl.Provider{ + CtxEnv: &contextImpl.Env{}, + CtxHTTPHead: &contextImpl.InvokeContext{}, + } + } + defaultContext.CtxHTTPHead.RequestID = invokeContext["requestId"] + defaultContext.CtxHTTPHead.Alias = invokeContext["Alias"] + defaultContext.CtxHTTPHead.InvokeProperty = invokeContext["InvokeProperty"] + defaultContext.CtxHTTPHead.InvokeID = invokeContext["InvokeID"] + defaultContext.CtxHTTPHead.TraceID = invokeContext["TraceID"] + log.GetLogger().Infof("succeed to update context") + return defaultContext +} + +func (eh *EventHandler) setEnvContext() { + var err error + // deal with env + err, envMap := utils.DealEnv() + if err != nil { + log.GetLogger().Errorf("deal env from createOpt failed, err: %s", err) + } + userDataStr, err := json.Marshal(envMap) + if err != nil { + log.GetLogger().Errorf("setEnvContext failed to marshal Userdata, error: %s", err) + } + err = os.Setenv(constants.LDLibraryPath, + os.Getenv(constants.LDLibraryPath)+fmt.Sprintf(":%s", envMap[constants.LDLibraryPath])) + err = os.Setenv("RUNTIME_USERDATA", string(userDataStr)) + err = os.Setenv("RUNTIME_PROJECT_ID", eh.funcSpec.FuncMetaData.TenantId) + err = os.Setenv("RUNTIME_FUNC_NAME", eh.funcSpec.FuncMetaData.FunctionName) + err = os.Setenv("RUNTIME_FUNC_VERSION", eh.funcSpec.FuncMetaData.Version) + err = os.Setenv("RUNTIME_HANDLER", eh.funcSpec.FuncMetaData.Handler) + + err = os.Setenv("RUNTIME_TIMEOUT", strconv.Itoa(eh.funcSpec.FuncMetaData.Timeout)) + nameSplit := strings.Split(eh.funcSpec.FuncMetaData.FunctionName, "@") + if len(nameSplit) >= constants.RuntimePkgNameSplit { + err = os.Setenv("RUNTIME_PACKAGE", nameSplit[1]) + } + err = os.Setenv("RUNTIME_CPU", strconv.Itoa(eh.funcSpec.ResourceMetaData.Cpu)) + err = os.Setenv("RUNTIME_MEMORY", strconv.Itoa(eh.funcSpec.ResourceMetaData.Memory)) + err = os.Setenv("RUNTIME_MAX_RESP_BODY_SIZE", strconv.Itoa(constants.RuntimeMaxRespBodySize)) + err = os.Setenv("RUNTIME_INITIALIZER_HANDLER", eh.funcSpec.ExtendedMetaData.Initializer.Handler) + err = os.Setenv("RUNTIME_INITIALIZER_TIMEOUT", + strconv.FormatInt(int64(eh.funcSpec.ExtendedMetaData.Initializer.Timeout), constants.INT64ToINT)) + err = os.Setenv("RUNTIME_ROOT", constants.RuntimeRoot) + err = os.Setenv("RUNTIME_CODE_ROOT", constants.RuntimeCodeRoot) + err = os.Setenv("RUNTIME_LOG_DIR", constants.RuntimeLogDir) + if err != nil { + log.GetLogger().Errorf("set env from createOpt failed, err: %s", err) + } +} diff --git a/api/go/faassdk/handler/event/handler_test.go b/api/go/faassdk/handler/event/handler_test.go new file mode 100644 index 0000000..e1f6a30 --- /dev/null +++ b/api/go/faassdk/handler/event/handler_test.go @@ -0,0 +1,804 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package event for faas executor api +package event + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "os" + "plugin" + "reflect" + "strconv" + "strings" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/common/faasscheduler" + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/faassdk/faas-sdk/go-api/context" + "yuanrong.org/kernel/runtime/faassdk/handler" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/libruntime/api" + log "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +var userCodePath = "../test/main/user_code_test.so" +var initHandlerName = "user_code_test.InitFunction" +var callHandlerName = "user_code_test.HandlerFunction" +var createParams = `{ + "userInitEntry" : "user_code_test.InitFunction", + "userCallEntry" : "user_code_test.HandlerFunction" + }` +var invokeEnvFunctionParam = `{ + "CreateParams": { + "initEntry" : "user_code_test.InitFunction", + "callEntry" : "user_code_test.GetEnvFunction" + } + }` + +type mockLibruntimeClient struct { +} + +func (m mockLibruntimeClient) CreateInstance(funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) InvokeByInstanceId(funcMeta api.FunctionMeta, instanceID string, args []api.Arg, invokeOpt api.InvokeOptions) (returnObjectID string, err error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) InvokeByFunctionName(funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) AcquireInstance(state string, funcMeta api.FunctionMeta, acquireOpt api.InvokeOptions) (api.InstanceAllocation, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) ReleaseInstance(allocation api.InstanceAllocation, stateID string, abnormal bool, option api.InvokeOptions) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Kill(instanceID string, signal int, payload []byte) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) CreateInstanceRaw(createReqRaw []byte) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) InvokeByInstanceIdRaw(invokeReqRaw []byte) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KillRaw(killReqRaw []byte) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) SaveState(state []byte) (string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) LoadState(checkpointID string) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Exit(code int, message string) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Finalize() { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVSet(key string, value []byte, param api.SetParam) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVSetWithoutKey(value []byte, param api.SetParam) (string, error) { + //TODO implement me + panic("implement me") +} + +func (f *mockLibruntimeClient) KVMSetTx(keys []string, values [][]byte, param api.MSetParam) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVGet(key string, timeoutms uint) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVGetMulti(keys []string, timeoutms uint) ([][]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVDel(key string) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVDelMulti(keys []string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) CreateProducer(streamName string, producerConf api.ProducerConf) (api.StreamProducer, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Subscribe(streamName string, config api.SubscriptionConfig) (api.StreamConsumer, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) DeleteStream(streamName string) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) QueryGlobalProducersNum(streamName string) (uint64, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) QueryGlobalConsumersNum(streamName string) (uint64, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) SetTraceID(traceID string) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) SetTenantID(tenantID string) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Put(objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) PutRaw(objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Get(objectIDs []string, timeoutMs int) ([][]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GetRaw(objectIDs []string, timeoutMs int) ([][]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Wait(objectIDs []string, waitNum uint64, timeoutMs int) ([]string, []string, map[string]error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GIncreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GIncreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GDecreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GDecreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GetAsync(objectID string, cb api.GetAsyncCallback) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GetFormatLogger() api.FormatLogger { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) CreateClient(config api.ConnectArguments) (api.KvClient, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) ReleaseGRefs(remoteClientID string) error { + //TODO implement me + panic("implement me") +} +func (m mockLibruntimeClient) GetCredential() api.Credential { + return api.Credential{} +} + +func (m mockLibruntimeClient) UpdateSchdulerInfo(schedulerName string, schedulerId string, option string) { + return +} + +func (m *mockLibruntimeClient) IsHealth() bool { + return true +} + +func (m *mockLibruntimeClient) IsDsHealth() bool { + return true +} + +// FakeLogger - +type FakeLogger struct{} + +// With - +func (f *FakeLogger) With(fields ...zapcore.Field) api.FormatLogger { + return f +} + +// Infof - +func (f *FakeLogger) Infof(format string, paras ...interface{}) {} + +// Errorf - +func (f *FakeLogger) Errorf(format string, paras ...interface{}) {} + +// Warnf - +func (f *FakeLogger) Warnf(format string, paras ...interface{}) {} + +// Debugf - +func (f *FakeLogger) Debugf(format string, paras ...interface{}) {} + +// Fatalf - +func (f *FakeLogger) Fatalf(format string, paras ...interface{}) {} + +// Info - +func (f *FakeLogger) Info(msg string, fields ...zap.Field) {} + +// Error - +func (f *FakeLogger) Error(msg string, fields ...zap.Field) {} + +// Warn - +func (f *FakeLogger) Warn(msg string, fields ...zap.Field) {} + +// Debug - +func (f *FakeLogger) Debug(msg string, fields ...zap.Field) {} + +// Fatal - +func (f *FakeLogger) Fatal(msg string, fields ...zap.Field) {} + +// Sync - +func (f *FakeLogger) Sync() {} + +func TestMain(m *testing.M) { + log.SetupLogger(&FakeLogger{}) + os.Exit(m.Run()) +} + +func TestNewEventHandler(t *testing.T) { + eventHandler := NewEventHandler(newFuncSpec(), nil) + assert.NotNil(t, eventHandler) +} + +func TestGetUserCodePath(t *testing.T) { + err := os.Setenv(config.DelegateDownloadPath, userCodePath) + assert.Nil(t, err) + functionLibPath, err := handler.GetUserCodePath() + assert.Nil(t, err) + log.GetLogger().Infof("functionLibPath:%s", functionLibPath) + assert.NotNil(t, functionLibPath) +} + +func TestGetHandlerName(t *testing.T) { + args := []api.Arg{ + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(createParams), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + userHook, err := getHandlerName(args) + log.GetLogger().Infof("getInitEntryName:%s", userHook.InitEntry) + log.GetLogger().Infof("getCallEntryName:%s", userHook.CallEntry) + assert.NotNil(t, userHook.CallEntry) + assert.NotNil(t, userHook.InitEntry) + assert.Nil(t, err) +} + +func TestInitHandler(t *testing.T) { + handler := NewEventHandler(&types.FuncSpec{}, nil) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&EventHandler{}), "SetHandler", + func(eh *EventHandler, functionLibPath, handler string) error { + if handler == initHandlerName { + eh.userInitEntry = func(ctx context.RuntimeContext) { + fmt.Println("userCode init start") + fmt.Println("userCode init success") + } + } + return nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + _, err := handler.InitHandler([]api.Arg{}, nil) + var res *types.InitResponse + _ = json.Unmarshal([]byte(err.Error()), &res) + assert.Equal(t, res.ErrorCode, strconv.Itoa(constants.ExecutorErrCodeInitFail)) + assert.Contains(t, string(res.Message), "invalid args number") + args, _ := getInitArgs(createParams) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + result1, err := handler.InitHandler([]api.Arg{ + { + Type: api.Value, + Data: []byte(""), + }, + args[0], + { + Type: api.Value, + Data: schedulerParamsBytes, + }, + { + Type: api.Value, + Data: []byte(""), + }, + }, nil) + assert.Nil(t, err) + res = &types.InitResponse{} + _ = json.Unmarshal(result1, &res) + assert.Equal(t, "", string(res.Message)) +} + +func newHttpSchedulerParams() *faasscheduler.SchedulerInfo { + return &faasscheduler.SchedulerInfo{ + SchedulerFuncKey: "scheduler func key", + SchedulerIDList: []string{"1111"}, + } +} + +// CallHandler, impl of CallHandler +func TestCallHandlerError(t *testing.T) { + handler := NewEventHandler(&types.FuncSpec{}, nil) + + convey.Convey("CallHandlerError", t, func() { + convey.Convey("args length invalid", func() { + result, err := handler.CallHandler([]api.Arg{}, nil) + convey.So(err, convey.ShouldBeNil) + var res *types.CallResponse + _ = json.Unmarshal(result, &res) + convey.So(string(res.Body), convey.ShouldEqual, `"invalid invoke argument"`) + convey.So(res.InnerCode, convey.ShouldEqual, "500") + }) + + convey.Convey("the size of the invoke payload exceeds the limit", func() { + args := []api.Arg{ + {}, + {Type: 0, Data: createLargeBytes(config.MaxPayloadSize)}, + } + result, err := handler.CallHandler(args, nil) + convey.So(err, convey.ShouldBeNil) + var res *types.CallResponse + _ = json.Unmarshal(result, &res) + convey.So(string(res.Body), convey.ShouldEqual, fmt.Sprintf(`"the size of the invoke payload exceeds the limit %d"`, config.MaxPayloadSize)) + convey.So(res.InnerCode, convey.ShouldEqual, "500") + }) + + convey.Convey("invoke handler is nil", func() { + args := []api.Arg{ + {}, + {Type: 0, Data: []byte("hello")}, + } + result, err := handler.CallHandler(args, nil) + convey.So(err, convey.ShouldBeNil) + var res *types.CallResponse + _ = json.Unmarshal(result, &res) + convey.So(string(res.Body), convey.ShouldEqual, `"invoke handler is nil"`) + convey.So(res.InnerCode, convey.ShouldEqual, "4001") + }) + + convey.Convey("unmarshal invoke call request data", func() { + defer ApplyMethod(reflect.TypeOf(&EventHandler{}), "SetHandler", + func(eh *EventHandler, functionLibPath, handler string) error { + if handler == initHandlerName { + eh.userInitEntry = func(ctx context.RuntimeContext) { + fmt.Println("userCode init start") + fmt.Println("userCode init success") + } + } else if handler == callHandlerName { + eh.userCallEntry = func(payload []byte, ctx context.RuntimeContext) (interface{}, error) { + logger := ctx.GetLogger() + if logger == nil { + return nil, fmt.Errorf("user logger not initialized") + } + logger.Infof("payload:%s", string(payload)) + fmt.Println("Handler function") + fmt.Println("payload:", string(payload)) + return "success", nil + } + } + return nil + }).Reset() + os.Setenv("RUNTIME_LOG_DIR", "./") + args, _ := getInitArgs(createParams) + _, err := handler.InitHandler([]api.Arg{ + { + Type: api.Value, + Data: []byte(""), + }, + args[0], + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + }, nil) + convey.So(err, convey.ShouldBeNil) + + args = []api.Arg{ + {}, + {Type: 0, Data: []byte("hello")}, + } + result, err := handler.CallHandler(args, nil) + convey.So(err, convey.ShouldBeNil) + var res *types.CallResponse + _ = json.Unmarshal(result, &res) + convey.So(strings.Contains(string(res.Body), "unmarshal invoke call request data"), convey.ShouldEqual, true) + convey.So(res.InnerCode, convey.ShouldEqual, "500") + }) + }) +} + +func createLargeBytes(size int32) []byte { + arg := bytes.Buffer{} + arg.Grow(int(size)) + for arg.Len() < int(size) { + arg.WriteString("max limit") + } + return arg.Bytes() +} + +func TestCallHandler_OK(t *testing.T) { + handler := NewEventHandler(newFuncSpec(), nil) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&EventHandler{}), "SetHandler", + func(eh *EventHandler, functionLibPath, handler string) error { + if handler == initHandlerName { + eh.userInitEntry = func(ctx context.RuntimeContext) { + fmt.Println("userCode init start") + fmt.Println("userCode init success") + } + } else if handler == callHandlerName { + eh.userCallEntry = func(payload []byte, ctx context.RuntimeContext) (interface{}, error) { + logger := ctx.GetLogger() + if logger == nil { + return nil, fmt.Errorf("user logger not initialized") + } + log.GetLogger().Infof("payload:%s", string(payload)) + fmt.Println("Handler function") + fmt.Println("payload:", string(payload)) + return "success", nil + } + } + return nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + os.Setenv("RUNTIME_LOG_DIR", "./") + args, _ := getInitArgs(createParams) + _, err := handler.InitHandler([]api.Arg{ + { + Type: api.Value, + Data: []byte(""), + }, + args[0], + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + }, nil) + assert.Nil(t, err) + result, err := handler.CallHandler(getCallArgs(), nil) + assert.Nil(t, err) + var res *types.CallResponse + _ = json.Unmarshal(result, &res) + assert.Contains(t, string(res.Body), "success") +} + +func TestSetHandler(t *testing.T) { + convey.Convey("SetHandler", t, func() { + convey.Convey("getLib error: invalid handler name", func() { + handler := &EventHandler{} + err := handler.SetHandler("/xxx", "callHandlerName") + convey.So(err.Error(), convey.ShouldEqual, "getLib error: invalid handler name :callHandlerName") + }) + + convey.Convey("getLib error: failed to open lib", func() { + handler := &EventHandler{} + err := handler.SetHandler("/xxx", "test.callHandlerName") + convey.So(err.Error(), convey.ShouldEqual, "getLib error: failed to open test.callHandlerName") + }) + + convey.Convey("getLib error: failed to look up", func() { + defer ApplyFunc(plugin.Open, func(path string) (*plugin.Plugin, error) { + return &plugin.Plugin{}, nil + }).Reset() + handler := &EventHandler{libMap: make(map[string]*plugin.Plugin)} + err := handler.SetHandler("/xxx", "test.callHandlerName") + convey.So(err.Error(), convey.ShouldNotBeNil) + }) + + convey.Convey("type error", func() { + defer ApplyFunc(plugin.Open, func(path string) (*plugin.Plugin, error) { + return &plugin.Plugin{}, nil + }).Reset() + defer ApplyMethod(reflect.TypeOf(&plugin.Plugin{}), "Lookup", + func(_ *plugin.Plugin, symName string) (plugin.Symbol, error) { + return "type error", nil + }).Reset() + handler := &EventHandler{libMap: make(map[string]*plugin.Plugin)} + err := handler.SetHandler("/xxx", "test.callHandlerName") + convey.So(err.Error(), convey.ShouldNotBeNil) + }) + + convey.Convey("success", func() { + defer ApplyFunc(plugin.Open, func(path string) (*plugin.Plugin, error) { + return &plugin.Plugin{}, nil + }).Reset() + defer ApplyMethod(reflect.TypeOf(&plugin.Plugin{}), "Lookup", + func(_ *plugin.Plugin, symName string) (plugin.Symbol, error) { + var symbol func([]byte, context.RuntimeContext) (interface{}, error) + symbol = func(i []byte, runtimeContext context.RuntimeContext) (interface{}, error) { + return "success", nil + } + return symbol, nil + }).Reset() + handler := &EventHandler{libMap: make(map[string]*plugin.Plugin)} + err := handler.SetHandler("/xxx", "test.callHandlerName") + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestCheckPointHandler(t *testing.T) { + handler := NewEventHandler(newFuncSpec(), nil) + pointHandler, err := handler.CheckPointHandler("checkPointID") + assert.Nil(t, pointHandler) + assert.Nil(t, err) +} + +func TestRecoverHandler(t *testing.T) { + handler := NewEventHandler(newFuncSpec(), nil) + err := handler.RecoverHandler([]byte{}) + assert.Nil(t, err) +} + +// create args and client, to invoke InvokeHandlerService +func getInitArgs(param string) ([]api.Arg, api.LibruntimeAPI) { + err := os.Setenv(config.DelegateDownloadPath, userCodePath) + if err != nil { + return nil, nil + } + args := make([]api.Arg, 2, 2) + args[0] = api.Arg{ + Type: 0, + Data: []byte(param), + } + client := &mockLibruntimeClient{} + return args, client +} + +// create args to invoke CallHandler +func getCallArgs() []api.Arg { + args := make([]api.Arg, 2) + argsData0 := "This is payload argsData0 for userHandler." + argsData1 := make(map[string]interface{}, 2) + argsData1["headers"] = make(map[string]string) + argsData1["body"] = []byte("This is payload argsData1 for userHandler.") + args1, _ := json.Marshal(argsData1) + args[0] = api.Arg{ + Type: 0, + Data: []byte(argsData0), + } + args[1] = api.Arg{ + Type: 0, + Data: args1, + } + return args +} + +func Test_initContext(t *testing.T) { + convey.Convey("initContext", t, func() { + convey.Convey("init user function log error", func() { + defer ApplyFunc(log.GetUserLogger, func() *log.UserFunctionLogger { + return nil + }).Reset() + err := initContext() + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("", func() { + defer ApplyFunc(json.Unmarshal, func(data []byte, v interface{}) error { + return errors.New("json unmarshal error") + }).Reset() + err := initContext() + convey.So(err, convey.ShouldBeError) + }) + convey.Convey("success", func() { + os.Setenv("RUNTIME_TIMEOUT", "10") + os.Setenv("RUNTIME_PROJECT_ID", "123456789") + os.Setenv("RUNTIME_PACKAGE", "service") + os.Setenv("RUNTIME_FUNC_NAME", "test-go-runtime") + os.Setenv("RUNTIME_FUNC_VERSION", "latest") + os.Setenv("RUNTIME_MEMORY", "500") + os.Setenv("RUNTIME_CPU", "500") + os.Setenv("RUNTIME_USERDATA", `{"name":"runtime-go"}`) + log.SetupUserLogger(userLogLevel) + err := initContext() + convey.So(err, convey.ShouldBeNil) + convey.So(defaultContext.GetRunningTimeInSeconds(), convey.ShouldEqual, 10) + convey.So(defaultContext.GetRemainingTimeInMilliSeconds() <= int(10*time.Millisecond), convey.ShouldEqual, true) + convey.So(defaultContext.GetProjectID(), convey.ShouldEqual, "123456789") + convey.So(defaultContext.GetPackage(), convey.ShouldEqual, "service") + convey.So(defaultContext.GetFunctionName(), convey.ShouldEqual, "test-go-runtime") + convey.So(defaultContext.GetVersion(), convey.ShouldEqual, "latest") + convey.So(defaultContext.GetCPUNumber(), convey.ShouldEqual, 500) + convey.So(defaultContext.GetMemorySize(), convey.ShouldEqual, 500) + convey.So(defaultContext.GetUserData("name"), convey.ShouldEqual, "runtime-go") + convey.So(defaultContext.GetAccessKey(), convey.ShouldEqual, "") + convey.So(defaultContext.GetSecretKey(), convey.ShouldEqual, "") + convey.So(defaultContext.GetToken(), convey.ShouldEqual, "") + convey.So(defaultContext.GetRequestID(), convey.ShouldEqual, "") + defaultContext.SetState("new state") + convey.So(defaultContext.GetState(), convey.ShouldEqual, "new state") + convey.So(defaultContext.GetInvokeProperty(), convey.ShouldEqual, "") + convey.So(defaultContext.GetTraceID(), convey.ShouldEqual, "") + convey.So(defaultContext.GetInvokeID(), convey.ShouldEqual, "") + convey.So(defaultContext.GetAlias(), convey.ShouldEqual, "") + convey.So(defaultContext.GetSecurityToken(), convey.ShouldEqual, "") + defaultContext.GetLogger().Infof("info log") + defaultContext.GetLogger().Warnf("warn log") + defaultContext.GetLogger().Debugf("debug log") + defaultContext.GetLogger().Errorf("error log") + }) + }) +} + +func TestSetUserEntry(t *testing.T) { + cleanFile("user-function") + convey.Convey( + "Test SetUserEntry", t, func() { + convey.Convey("SetUserInitEntry success", func() { + initE := func(ctx context.RuntimeContext) {} + convey.So(func() { + SetUserInitEntry(initE) + }, convey.ShouldNotPanic) + }) + convey.Convey("SetUserCallEntry success", func() { + callE := func([]byte, context.RuntimeContext) (interface{}, error) { + return nil, nil + } + convey.So(func() { + SetUserCallEntry(callE) + }, convey.ShouldNotPanic) + }) + }, + ) +} + +func TestEventHandler(t *testing.T) { + convey.Convey( + "Test EventHandler", t, func() { + eh := NewEventHandler(newFuncSpec(), nil) + convey.Convey("ShutDownHandler success", func() { + err := eh.ShutDownHandler(1) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("SignalHandler success", func() { + err := eh.SignalHandler(1, []byte{0}) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("HealthCheckHandler success", func() { + ht, err := eh.HealthCheckHandler() + convey.So(ht, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("setHandlerByEntry success", func() { + err := eh.(*EventHandler).setHandlerByEntry("symbol", "handler") + convey.So(err, convey.ShouldNotBeNil) + symbol1 := func(ctx context.RuntimeContext) {} + eh.(*EventHandler).userInitEntry = nil + err = eh.(*EventHandler).setHandlerByEntry(symbol1, "handler") + convey.So(err, convey.ShouldBeNil) + symbol2 := func([]byte, context.RuntimeContext) (interface{}, error) { + return nil, nil + } + eh.(*EventHandler).userCallEntry = nil + err = eh.(*EventHandler).setHandlerByEntry(symbol2, "handler") + convey.So(err, convey.ShouldBeNil) + }) + }, + ) +} + +func cleanFile(fileName string) { + files, _ := os.ReadDir("./") + for _, file := range files { + flag := strings.HasPrefix(file.Name(), fileName) + if flag { + os.Remove(file.Name()) + } + } +} diff --git a/api/go/faassdk/handler/handler.go b/api/go/faassdk/handler/handler.go new file mode 100644 index 0000000..228609f --- /dev/null +++ b/api/go/faassdk/handler/handler.go @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handler - +package handler + +import ( + "fmt" + "os" + + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/libruntime/api" + log "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +// ExecutorHandler - +type ExecutorHandler interface { + InitHandler(args []api.Arg, rt api.LibruntimeAPI) ([]byte, error) + CallHandler(args []api.Arg, context map[string]string) ([]byte, error) + CheckPointHandler(checkPointId string) ([]byte, error) + RecoverHandler(state []byte) error + ShutDownHandler(gracePeriodSecond uint64) error + SignalHandler(signal int32, payload []byte) error + HealthCheckHandler() (api.HealthType, error) +} + +// GetUserCodePath will get user code downloaded in delegate download path +func GetUserCodePath() (string, error) { + userCodePath := os.Getenv(config.DelegateDownloadPath) + if userCodePath == "" { + log.GetLogger().Errorf("%s not found", config.DelegateDownloadPath) + return "", fmt.Errorf("%s not found", config.DelegateDownloadPath) + } + return userCodePath, nil +} diff --git a/api/go/faassdk/handler/handler_test.go b/api/go/faassdk/handler/handler_test.go new file mode 100644 index 0000000..3a192b5 --- /dev/null +++ b/api/go/faassdk/handler/handler_test.go @@ -0,0 +1,113 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handler for faas executor api +package handler + +import ( + "encoding/json" + "errors" + "os" + "reflect" + "testing" + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/faassdk/utils" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" +) + +func TestGetUserCodePath(t *testing.T) { + convey.Convey("GetUserCodePath", t, func() { + convey.Convey("not found", func() { + path, err := GetUserCodePath() + convey.So(path, convey.ShouldEqual, "") + convey.So(err, convey.ShouldBeError) + }) + convey.Convey("success", func() { + os.Setenv(config.DelegateDownloadPath, "userCodePath") + path, err := GetUserCodePath() + convey.So(path, convey.ShouldEqual, "userCodePath") + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestDealEnv(t *testing.T) { + environment := map[string]string{"test1": "abc"} + environmentStr, _ := json.Marshal(environment) + encryptedUserData := map[string]string{"test2": "def"} + encryptedUserDataStr, _ := json.Marshal(encryptedUserData) + env1 := map[string]string{"environment": string(environmentStr), "encrypted_user_data": string(encryptedUserDataStr)} + env1Str, _ := json.Marshal(env1) + + tests := []struct { + name string + want error + want1 map[string]string + patchesFunc PatchesFunc + }{ + {"case1 env is nil", nil, nil, func() PatchSlice { + patches := InitPatchSlice() + patches.Append(PatchSlice{ + ApplyFunc(os.Getenv, func(key string) string { + return "" + }), + }) + return patches + }}, + {"case2 env has environment and encrypted_user_data", nil, + map[string]string{"test1": "abc", "test2": "def"}, + func() PatchSlice { + patches := InitPatchSlice() + patches.Append(PatchSlice{ + ApplyFunc(os.Getenv, func(key string) string { + return string(env1Str) + }), + ApplyFunc(os.Environ, func() []string { + return []string{"test1=abc"} + }), + }) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + got, got1 := utils.DealEnv() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("dealEnv() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(got1, tt.want1) { + t.Errorf("dealEnv() got1 = %v, want %v", got1, tt.want1) + } + patches.ResetAll() + }) + } + + convey.Convey("failed to Unmarshal environment", t, func() { + patches := InitPatchSlice() + patches.Append(PatchSlice{ + ApplyFunc(json.Unmarshal, func(data []byte, v interface{}) error { + return errors.New("json unmarshal error") + }), + }) + os.Setenv("ENV_DELEGATE_DECRYPT", "{}") + err, _ := utils.DealEnv() + convey.So(err, convey.ShouldBeError) + patches.ResetAll() + }) +} diff --git a/api/go/faassdk/handler/http/apig.go b/api/go/faassdk/handler/http/apig.go new file mode 100644 index 0000000..6390ff6 --- /dev/null +++ b/api/go/faassdk/handler/http/apig.go @@ -0,0 +1,175 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package http - +package http + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + + log "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +// APIGTriggerEvent extern interface of web request +type APIGTriggerEvent struct { + IsBase64Encoded bool `json:"isBase64Encoded"` + HTTPMethod string `json:"httpMethod"` + Path string `json:"path"` + Body string `json:"body"` + PathParameters map[string]string `json:"pathParameters"` + RequestContext APIGRequestContext `json:"requestContext"` + Headers map[string]interface{} `json:"headers"` + QueryStringParameters map[string]interface{} `json:"queryStringParameters"` + UserData string `json:"user_data"` +} + +// APIGRequestContext - +type APIGRequestContext struct { + APIID string `json:"apiId"` + RequestID string `json:"requestId"` + Stage string `json:"stage"` + SourceIP string `json:"sourceIp"` +} + +// APIGTriggerResponse extern interface of web response +type APIGTriggerResponse struct { + Body string `json:"body"` + Headers map[string][]string `json:"headers"` + StatusCode int `json:"statusCode"` + IsBase64Encoded bool `json:"isBase64Encoded"` +} + +func parseAPIGEvent(ctx context.Context, serializedEvent []byte, header map[string]string, + baseURLPath, callRoute string) (*http.Request, error) { + apigEvent := &APIGTriggerEvent{} + if err := json.Unmarshal(serializedEvent, apigEvent); err != nil { + log.GetLogger().Errorf("failed to unmarshal event to APIG event, error: %s", err) + return nil, err + } + if apigEvent.Headers == nil { + apigEvent.Headers = make(map[string]interface{}) + } + for key, val := range header { + apigEvent.Headers[key] = val + } + apigEvent.HTTPMethod = http.MethodPost + return apigToHTTPRequest(ctx, apigEvent, baseURLPath, callRoute) +} + +func constructAPIGResponse(resp *http.Response, body *bytes.Buffer) ([]byte, error) { + apigResponse, err := fromHTTPResponse(resp, body) + if err != nil { + return nil, err + } + + serializedResponse, err := json.Marshal(apigResponse) + if err != nil { + log.GetLogger().Errorf("failed to marshal APIG response, error: %s", err) + return nil, err + } + return serializedResponse, nil +} + +func apigToHTTPRequest(ctx context.Context, event *APIGTriggerEvent, baseURLPath, + callRoute string) (*http.Request, error) { + var ( + requestBody []byte + err error + requestURI string + ) + if event.IsBase64Encoded { + requestBody, err = base64.StdEncoding.DecodeString(event.Body) + if err != nil { + log.GetLogger().Errorf("failed to decode body string, error: %s", err) + return nil, err + } + } else { + requestBody = []byte(event.Body) + } + if event.Path != "" { + requestURI = fmt.Sprintf("%s/%s", baseURLPath, event.Path) + } else { + requestURI = fmt.Sprintf("%s/%s", baseURLPath, callRoute) + } + request, err := http.NewRequestWithContext(ctx, event.HTTPMethod, requestURI, bytes.NewBuffer(requestBody)) + if err != nil { + log.GetLogger().Errorf("failed to construct HTTP request, error: %s", err.Error()) + return nil, err + } + queries := request.URL.Query() + for k, v := range event.QueryStringParameters { + switch valueType := v.(type) { + case string: + queries.Set(k, v.(string)) + case []string: + for _, param := range v.([]string) { + queries.Add(k, param) + } + default: + log.GetLogger().Warnf("invalid type in query parameters: %T", valueType) + } + } + request.URL.RawQuery = queries.Encode() + + initRequestHeader(request, event) + return request, nil +} + +func initRequestHeader(request *http.Request, event *APIGTriggerEvent) { + for k, v := range event.Headers { + if k == "Transfer-Encoding" { + continue + } + switch valueType := v.(type) { + case string: + request.Header.Set(k, v.(string)) + case []string: + for _, param := range v.([]string) { + request.Header.Add(k, param) + } + default: + log.GetLogger().Warnf("invalid type in headers: %T", valueType) + } + } + + request.Header.Set("X-APIG-Api-Id", event.RequestContext.APIID) + request.Header.Set("X-APIG-Request-Id", event.RequestContext.RequestID) + request.Header.Set("X-APIG-Source-Ip", event.RequestContext.SourceIP) + request.Header.Set("X-APIG-Source-Stage", event.RequestContext.Stage) + if host := request.Header.Get("Host"); host != "" { + request.Host = host + } +} + +func fromHTTPResponse(response *http.Response, body *bytes.Buffer) (*APIGTriggerResponse, error) { + apigResponse := &APIGTriggerResponse{ + Body: base64.StdEncoding.EncodeToString(body.Bytes()), + StatusCode: response.StatusCode, + IsBase64Encoded: true, + } + if response.Header != nil { + apigResponse.Headers = make(map[string][]string) + for k, v := range response.Header { + apigResponse.Headers[k] = v + } + } + return apigResponse, nil +} diff --git a/api/go/faassdk/handler/http/apig_test.go b/api/go/faassdk/handler/http/apig_test.go new file mode 100644 index 0000000..f3ebcd7 --- /dev/null +++ b/api/go/faassdk/handler/http/apig_test.go @@ -0,0 +1,63 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package http - +package http + +import ( + "bytes" + "context" + "net/http" + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func Test_parseAPIGEvent(t *testing.T) { + convey.Convey("parseAPIGEvent", t, func() { + event, err := parseAPIGEvent(context.TODO(), []byte(`{`), nil, "/baseURL", "invoke") + convey.So(err, convey.ShouldBeError) + convey.So(event, convey.ShouldBeNil) + }) +} + +func Test_initRequestHeader(t *testing.T) { + convey.Convey("initRequestHeader", t, func() { + req := &http.Request{ + Header: make(map[string][]string), + } + + event := &APIGTriggerEvent{ + Headers: make(map[string]interface{}), + } + event.Headers["test-header"] = []string{"testHeader"} + initRequestHeader(req, event) + convey.So(req.Header["Test-Header"][0], convey.ShouldEqual, "testHeader") + }) +} + +func Test_fromHTTPResponse(t *testing.T) { + convey.Convey("fromHTTPResponse", t, func() { + resp := &http.Response{StatusCode: 200, Header: make(map[string][]string)} + resp.Header["mockHeader"] = []string{"test-header"} + body := &bytes.Buffer{} + + apigResponse, err := fromHTTPResponse(resp, body) + convey.So(err, convey.ShouldBeNil) + convey.So(apigResponse.StatusCode, convey.ShouldEqual, 200) + convey.So(apigResponse.Headers["mockHeader"][0], convey.ShouldEqual, "test-header") + }) +} diff --git a/api/go/faassdk/handler/http/basic_handler.go b/api/go/faassdk/handler/http/basic_handler.go new file mode 100644 index 0000000..ab36128 --- /dev/null +++ b/api/go/faassdk/handler/http/basic_handler.go @@ -0,0 +1,676 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package http handler +package http + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/faassdk/common/aliasroute" + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/common/faasscheduler" + "yuanrong.org/kernel/runtime/faassdk/common/functionlog" + "yuanrong.org/kernel/runtime/faassdk/common/monitor" + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/faassdk/utils" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" + runtimeLog "yuanrong.org/kernel/runtime/libruntime/common/logger/log" +) + +const ( + listenIP = "127.0.0.1" + mebibyte = 1024 * 1024 + maxResponseSize = 6 * mebibyte + defaultTimeout = 3 + healthRoute = "healthz" + killSignalAliasUpdate = 64 + schedulerArgIndex = 2 + shutdownCheckInterval = 100 * time.Millisecond + checkReadyInterval = 10 * time.Millisecond + argsMinLength = 2 +) + +var ( + // ErrExecuteTimeout is the error of execute function timeout + ErrExecuteTimeout = errors.New("execute function timeout") + + stopCh = make(chan struct{}) + closeOnce sync.Once +) + +type basicHandler struct { + funcSpec *types.FuncSpec + httpClient *http.Client + sdkClient api.LibruntimeAPI + functionLogger *functionlog.FunctionLogger + logger *RuntimeContainerLogger + monitor *monitor.FunctionMonitorManager + config *config.Configuration + bootstrapFunc func() error + requestNum int32 + baseURL string + initRoute string + callRoute string + healthRoute string + disableAPIGFormat bool + port int + initTimeout int + callTimeout int + gracefulShutdown bool + circuitBreaker bool + circuitLock sync.RWMutex + once sync.Once + sync.RWMutex +} + +func newBasicHandler(funcSpec *types.FuncSpec, client api.LibruntimeAPI) *basicHandler { + rtmConf, err := config.GetConfig(funcSpec) + if err != nil { + logger.GetLogger().Errorf("failed to get runtime config for function %s error %s", + funcSpec.FuncMetaData.FunctionName, err.Error()) + return nil + } + return &basicHandler{ + funcSpec: funcSpec, + sdkClient: client, + monitor: &monitor.FunctionMonitorManager{}, + config: rtmConf, + gracefulShutdown: false, + disableAPIGFormat: os.Getenv("DISABLE_APIG_FORMAT") == "true", + } +} + +func (bh *basicHandler) getHTTPClient(timeout time.Duration) *http.Client { + bh.once.Do(func() { + defaultTransport := http.DefaultTransport + transport, ok := defaultTransport.(*http.Transport) + if !ok { + logger.GetLogger().Warnf("not a http transport type") + transport = &http.Transport{} + } + // The server can close a previous "keep-alive" TCP connection (mainly because of a read timeout) at any time. + // If we send a request while the connection is closing, the HTTP client will return an EOF error. As a result, + // we disable keep-alive completely. + transport.DisableKeepAlives = true + bh.httpClient = &http.Client{ + Transport: transport, + } + }) + bh.httpClient.Timeout = timeout + return bh.httpClient +} + +func (bh *basicHandler) parseCreateParams(args []api.Arg) error { + if len(args) != constants.ValidCustomImageCreateParamSize && + len(args) != constants.ValidBasicCreateParamSize { + return errors.New("invalid create params number") + } + funcSpec := &types.FuncSpec{} + err := json.Unmarshal(args[0].Data, funcSpec) + if err != nil { + logger.GetLogger().Errorf("failed to unmarshal funcSpec from %s", string(args[0].Data)) + return err + } + bh.initTimeout = funcSpec.ExtendedMetaData.Initializer.Timeout + if bh.initTimeout <= 0 { + bh.initTimeout = defaultTimeout + } + bh.callTimeout = funcSpec.FuncMetaData.Timeout + if bh.callTimeout <= 0 { + bh.callTimeout = defaultTimeout + } + createParams := &types.HttpCreateParams{} + err = json.Unmarshal(args[1].Data, createParams) + if err != nil { + logger.GetLogger().Errorf("failed to unmarshal create params from %s", string(args[1].Data)) + return err + } + err = faasscheduler.ParseSchedulerData(args[schedulerArgIndex]) + if err != nil { + return err + } + bh.port = createParams.Port + bh.baseURL = fmt.Sprintf("http://%s:%d", listenIP, createParams.Port) + bh.initRoute = createParams.InitRoute + bh.callRoute = createParams.CallRoute + bh.healthRoute = healthRoute + return nil +} + +func (bh *basicHandler) setBootstrapFunc(bootstrapFunc func() error) { + bh.bootstrapFunc = bootstrapFunc +} + +func (bh *basicHandler) setHTTPHeader(header http.Header) { + header.Set("X-CFF-Memory", strconv.Itoa(bh.funcSpec.ResourceMetaData.Memory)) + header.Set("X-CFF-Timeout", strconv.Itoa(bh.funcSpec.FuncMetaData.Timeout)) + header.Set("X-CFF-Func-Version", bh.funcSpec.FuncMetaData.Version) + header.Set("X-CFF-Func-Name", bh.funcSpec.FuncMetaData.FunctionName) + header.Set("X-CFF-Project-Id", bh.funcSpec.FuncMetaData.TenantId) + header.Set("X-CFF-Package", "") + header.Set("X-CFF-Region", "") + header.Set("X-CFF-Access-Key", "") + header.Set("X-CFF-Secret-Key", "") + header.Set("X-CFF-Security-Access-Key", "") + header.Set("X-CFF-Security-Secret-Key", "") + header.Set("X-CFF-Auth-Token", "") +} + +func (bh *basicHandler) setEnvContext() { + var err error + // deal with env + err, envMap := utils.DealEnv() + if err != nil { + logger.GetLogger().Errorf("deal env from createOpt failed, err: %s", err) + } + // http handler set environment and encrypted_user_data all to env + for key, value := range envMap { + if key != constants.LDLibraryPath { + err = os.Setenv(key, value) + } else { + err = os.Setenv(key, os.Getenv(constants.LDLibraryPath)+fmt.Sprintf(":%s", value)) + } + } + if err != nil { + logger.GetLogger().Errorf("set env from envMap failed, err: %s", err) + } + userDataStr, err := json.Marshal(envMap) + if err != nil { + logger.GetLogger().Errorf("setEnvContext failed to marshal Userdata, error: %s", err) + } + err = os.Setenv("RUNTIME_USERDATA", string(userDataStr)) + err = os.Setenv("RUNTIME_PROJECT_ID", bh.funcSpec.FuncMetaData.TenantId) + err = os.Setenv("RUNTIME_FUNC_NAME", bh.funcSpec.FuncMetaData.FunctionName) + err = os.Setenv("RUNTIME_FUNC_VERSION", bh.funcSpec.FuncMetaData.Version) + err = os.Setenv("RUNTIME_HANDLER", bh.funcSpec.FuncMetaData.Handler) + + err = os.Setenv("RUNTIME_TIMEOUT", strconv.Itoa(bh.funcSpec.FuncMetaData.Timeout)) + nameSplit := strings.Split(bh.funcSpec.FuncMetaData.FunctionName, "@") + if len(nameSplit) >= constants.RuntimePkgNameSplit { + err = os.Setenv("RUNTIME_PACKAGE", nameSplit[1]) + } + err = os.Setenv("RUNTIME_CPU", strconv.Itoa(bh.funcSpec.ResourceMetaData.Cpu)) + err = os.Setenv("RUNTIME_MEMORY", strconv.Itoa(bh.funcSpec.ResourceMetaData.Memory)) + err = os.Setenv("RUNTIME_MAX_RESP_BODY_SIZE", strconv.Itoa(constants.RuntimeMaxRespBodySize)) + err = os.Setenv("RUNTIME_INITIALIZER_HANDLER", bh.funcSpec.ExtendedMetaData.Initializer.Handler) + err = os.Setenv("RUNTIME_INITIALIZER_TIMEOUT", + strconv.FormatInt(int64(bh.funcSpec.ExtendedMetaData.Initializer.Timeout), constants.INT64ToINT)) + err = os.Setenv("RUNTIME_ROOT", constants.RuntimeRoot) + err = os.Setenv("RUNTIME_CODE_ROOT", constants.RuntimeCodeRoot) + err = os.Setenv("RUNTIME_LOG_DIR", constants.RuntimeLogDir) + if err != nil { + logger.GetLogger().Errorf("set env from createOpt failed, err: %s", err) + } +} + +func (bh *basicHandler) sendRequest(request *http.Request, timeout time.Duration) (*http.Response, error) { + bh.setHTTPHeader(request.Header) + response, err := bh.getHTTPClient(timeout).Do(request) + if err != nil { + logger.GetLogger().Errorf("failed to send request with timeout %d s, error %s", timeout, err.Error()) + return nil, err + } + return response, nil +} + +func (bh *basicHandler) awaitReady(timeout time.Duration) error { + timer := time.NewTimer(timeout) + for { + select { + case <-timer.C: + logger.GetLogger().Errorf("timeout waiting for http server running") + return errors.New("timeout waiting for http server running") + default: + } + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", listenIP, bh.port)) + if err != nil { + time.Sleep(checkReadyInterval) + continue + } + err = conn.Close() + if err != nil { + logger.GetLogger().Warnf("failed to close connection for checking, err: %s", err.Error()) + } + if bh.isCustomHealthCheckReady() { + return nil + } + time.Sleep(checkReadyInterval) + } +} + +func (bh *basicHandler) processInitRequest(timeout time.Duration) ([]byte, error) { + traceID := utils.UniqueID() + logRecorder := bh.functionLogger.NewLogRecorder(traceID, traceID, constants.InitializeStage) + logRecorder.StartSync() + defer func() { + logRecorder.FinishSync() + logRecorder.Finish() + }() + var initURL string + if bh.funcSpec != nil && bh.funcSpec.ExtendedMetaData.Initializer.Handler != "" { + initURL = fmt.Sprintf("%s/%s", bh.baseURL, "init") + } else { + initURL = bh.baseURL + } + requestBody := []byte("{}") + request, err := http.NewRequestWithContext(context.TODO(), http.MethodPost, initURL, bytes.NewBuffer(requestBody)) + if err != nil { + logger.GetLogger().Errorf("failed to create request to init route error %s", err.Error()) + return utils.HandleInitResponse( + fmt.Sprintf("failed to create request to init route error %s", err.Error()), + constants.ExecutorErrCodeInitFail) + } + queries := request.URL.Query() + request.URL.RawQuery = queries.Encode() + request.Header.Add("Content-Type", "application/json") + if (bh.funcSpec != nil && bh.funcSpec.ExtendedMetaData.Initializer.Handler != "") || bh.initRoute != "" { + _, err, _ = executeWithTimeout(func() (interface{}, error) { + return bh.sendRequest(request, timeout) + }, timeout, bh.monitor.ErrChan) + } + if err != nil { + logger.GetLogger().Errorf("init failed request error %s", err.Error()) + switch err { + case monitor.ErrExecuteDiskLimit: + return utils.HandleInitResponse(fmt.Sprintf("init failed request disk limit"), constants.DiskUsageExceed) + case monitor.ErrExecuteOOM: + return utils.HandleInitResponse(fmt.Sprintf("init failed reqeust oom"), constants.MemoryLimitExceeded) + case ErrExecuteTimeout: + return utils.HandleInitResponse(fmt.Sprintf("init failed request timed out after %ds", + int(timeout.Seconds())), constants.InitFunctionTimeout) + default: + return utils.HandleInitResponse(fmt.Sprintf("init failed request error %s", err.Error()), + constants.ExecutorErrCodeInitFail) + } + } + logger.GetLogger().Infof("succeed to process init request for function %s", bh.funcSpec.FuncMetaData.FunctionName) + return []byte{}, nil +} + +func (bh *basicHandler) processCallRequest(args []api.Arg, traceID string, + timeout time.Duration, totalTime utils.ExecutionDuration) ([]byte, error) { + if len(args) < argsMinLength { + return nil, errors.New("args.length should not less than 2") + } + logRecorder := bh.functionLogger.NewLogRecorder(traceID, traceID, constants.InvokeStage) + logRecorder.StartSync() + logger := logger.GetLogger().With(zap.Any("traceID", traceID)) + defer func() { + logRecorder.FinishSync() + logRecorder.Finish() + }() + userCallRequest := &types.CallRequest{} + if err := json.Unmarshal(args[1].Data, userCallRequest); err != nil { + return utils.HandleCallResponse(fmt.Sprintf("unmarshal invoke call request data: %s, err: %s", + string(args[1].Data), err), constants.FaaSError, "", totalTime, nil) + } + request, err := bh.parseRequest(context.TODO(), userCallRequest.Body, userCallRequest.Header) + if err != nil { + logger.Errorf("failed to create request to call route,error: %s", err.Error()) + return utils.HandleCallResponse(fmt.Sprintf("failed to parse request, err:%s", err.Error()), + constants.FaaSError, "", totalTime, nil) + } + totalTime.UserFuncBeginTime = time.Now() + result, err, _ := executeWithTimeout(func() (interface{}, error) { + request.Header.Set("Content-Type", "application/json") + return bh.sendRequest(request, timeout) + }, timeout, bh.monitor.ErrChan) + totalTime.UserFuncTotalTime = time.Since(totalTime.UserFuncBeginTime) + if bh.logger != nil { + bh.logger.syncTo(time.Now()) + } + if err != nil { + logger.Errorf("call request failed error %s", err.Error()) + switch err { + case monitor.ErrExecuteDiskLimit: + return utils.HandleCallResponse(fmt.Sprintf("call request disk limit"), + constants.DiskUsageExceed, "", totalTime, nil) + case monitor.ErrExecuteOOM: + return utils.HandleCallResponse(fmt.Sprintf("call request oom"), + constants.MemoryLimitExceeded, "", totalTime, nil) + case ErrExecuteTimeout: + return utils.HandleCallResponse(fmt.Sprintf("call request timed out after %ds", bh.callTimeout), + constants.InvokeFunctionTimeout, "", totalTime, nil) + default: + return utils.HandleCallResponse(err.Error(), constants.FunctionRunError, + "", totalTime, nil) + } + } + response, ok := result.(*http.Response) + if !ok { + logger.Errorf("call response type error") + return utils.HandleCallResponse("call response type error", constants.InvokeFunctionTimeout, + "", totalTime, nil) + } + responseData, err := buildResponseData(response, traceID) + if err != nil { + return utils.HandleCallResponse(fmt.Sprintf("failed to build response data, err: %s", err.Error()), + constants.FaaSError, "", totalTime, nil) + } + if int32(responseData.Len()) > config.MaxReturnSize { + return utils.HandleCallResponse(fmt.Sprintf("response body size %d exceeds the limit of 6291456", + responseData.Len()), constants.ResponseExceedLimit, "", totalTime, nil) + } + + resp, err := bh.constructResponse(response, responseData) + if err != nil { + return utils.HandleCallResponse(fmt.Sprintf("failed to build apig response, err: %s", err.Error()), + constants.FaaSError, "", totalTime, nil) + } + + innerCode := constants.NoneError + if response.StatusCode/100 != 2 { // 非200,202,则认为失败 + innerCode = constants.FunctionRunError + } + + logger.Infof("check call response data %s, StatusCode %d, innerCode: %d", + string(resp), response.StatusCode, innerCode) + logRecorder.FinishSync() + logResult := logRecorder.MarshalAll() + return utils.HandleCallResponse(resp, innerCode, logResult, totalTime, response.Header) +} + +func (bh *basicHandler) parseRequest(ctx context.Context, serializedEvent []byte, + header map[string]string) (*http.Request, error) { + if bh.disableAPIGFormat { + return parseNormalEvent(ctx, serializedEvent, header, bh.baseURL) + } + return parseAPIGEvent(ctx, serializedEvent, header, bh.baseURL, bh.callRoute) +} + +func (bh *basicHandler) constructResponse(resp *http.Response, body *bytes.Buffer) ([]byte, error) { + if bh.disableAPIGFormat { + return body.Bytes(), nil + } + return constructAPIGResponse(resp, body) +} + +func parseNormalEvent(ctx context.Context, serializedEvent []byte, header map[string]string, + baseURLPath string) (*http.Request, error) { + requestURI := fmt.Sprintf("%s/%s", baseURLPath, customContainerCallPath) + request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURI, bytes.NewBuffer(serializedEvent)) + if err != nil { + logger.GetLogger().Errorf("failed to construct HTTP request, error: %s", err.Error()) + return nil, err + } + for key, val := range header { + request.Header.Set(key, val) + } + return request, nil +} + +func handleBootstrapError(initTimeout int, initErr error) ([]byte, error) { + if initErr == ErrExecuteTimeout { + logger.GetLogger().Errorf("init failed bootstrap timed out after %ds", initTimeout) + return utils.HandleInitResponse(fmt.Sprintf("init failed bootstrap timed out after %ds", initTimeout), + constants.InitFunctionTimeout) + } + logger.GetLogger().Errorf("init failed bootstrap failed error %s", initErr.Error()) + return utils.HandleInitResponse(fmt.Sprintf("init failed bootstrap failed error %s", + initErr.Error()), constants.InitFunctionFail) +} + +// InitHandler will send init request to user's http server +// args[0]: function specification +// args[1]: create params +// args[2]: scheduler info +func (bh *basicHandler) InitHandler(args []api.Arg, apiClient api.LibruntimeAPI) ([]byte, error) { + logger.GetLogger().Infof("start to init function %s", bh.funcSpec.FuncMetaData.FunctionName) + err := bh.parseCreateParams(args) + if err != nil { + return utils.HandleInitResponse(err.Error(), constants.ExecutorErrCodeInitFail) + } + bh.setEnvContext() + bh.functionLogger, err = functionlog.GetFunctionLogger(bh.config) + if err != nil { + return utils.HandleInitResponse(err.Error(), constants.FaaSError) + } + initTimeout := time.Duration(bh.initTimeout) * time.Second + var ( + bootErr error + costTime time.Duration + ) + if bh.bootstrapFunc != nil { + _, bootErr, costTime = executeWithTimeout(func() (interface{}, error) { + return nil, bh.bootstrapFunc() + }, initTimeout, bh.monitor.ErrChan) + initTimeout -= costTime + } + if bootErr != nil { + return handleBootstrapError(bh.initTimeout, bootErr) + } + _, bootErr, costTime = executeWithTimeout(func() (interface{}, error) { + return nil, bh.awaitReady(initTimeout) + }, initTimeout, bh.monitor.ErrChan) + initTimeout -= costTime + if bootErr != nil { + logger.GetLogger().Errorf("failed to bootstrap function %s", bh.funcSpec.FuncMetaData.FunctionName) + return handleBootstrapError(bh.initTimeout, bootErr) + } + logger.GetLogger().Infof("succeed to bootstrap function %s", bh.funcSpec.FuncMetaData.FunctionName) + return bh.processInitRequest(initTimeout) +} + +// CallHandler parses invoke request and convert to http request which will be sent to user's http server +// args[0]: reserved +// args[1]: invoke payload to the target function +func (bh *basicHandler) CallHandler(args []api.Arg, ctx map[string]string) ([]byte, error) { + traceID := ctx["traceID"] + loggerWith := logger.GetLogger().With(zap.Any("traceID", traceID)) + loggerWith.Infof("start to call function %s", bh.funcSpec.FuncMetaData.FunctionVersionURN) + totalTime := utils.ExecutionDuration{ + ExecutorBeginTime: time.Now(), + } + bh.RLock() + isGracefulShutdown := bh.gracefulShutdown + bh.RUnlock() + if isGracefulShutdown { + loggerWith.Infof("instances are gracefully exiting, function: %s", bh.funcSpec.FuncMetaData.FunctionName) + return utils.HandleCallResponse("graceful shutdown, stop processing", constants.FaaSError, + "", totalTime, nil) + } + bh.Lock() + bh.circuitLock.RLock() + if bh.circuitBreaker { + bh.circuitLock.RUnlock() + bh.Unlock() + loggerWith.Infof("instance is circuit breaker, no need processing call request, function: %s", + bh.funcSpec.FuncMetaData.FunctionName) + return utils.HandleCallResponse("function circuit break, stop processing", + constants.InstanceCircuitBreakError, "", totalTime, nil) + } + bh.circuitLock.RUnlock() + atomic.AddInt32(&bh.requestNum, 1) + bh.Unlock() + defer func() { + atomic.AddInt32(&bh.requestNum, -1) + }() + if len(args) != constants.ValidInvokeArgumentSize { + return utils.HandleCallResponse("invalid invoke argument", constants.FaaSError, + "", totalTime, nil) + } + return bh.processCallRequest(args, traceID, time.Duration(bh.callTimeout)*time.Second, totalTime) +} + +// CheckPointHandler handles checkpoint +func (bh *basicHandler) CheckPointHandler(checkPointId string) ([]byte, error) { + logger.GetLogger().Infof("start to checkpoint function %s", bh.funcSpec.FuncMetaData.FunctionName) + return nil, nil +} + +// RecoverHandler handles recover +func (bh *basicHandler) RecoverHandler(state []byte) error { + logger.GetLogger().Infof("start to recover function %s", bh.funcSpec.FuncMetaData.FunctionName) + return nil +} + +// ShutDownHandler handles shutdown +func (bh *basicHandler) ShutDownHandler(gracePeriodSecond uint64) error { + logger.GetLogger().Infof("start to shutdown function %s", bh.funcSpec.FuncMetaData.FunctionName) + bh.RLock() + isGracefulShutdown := bh.gracefulShutdown + bh.RUnlock() + if isGracefulShutdown { + time.Sleep(time.Duration(gracePeriodSecond) * time.Second) + logger.GetLogger().Warnf("start to shutdown function %s second times", bh.funcSpec.FuncMetaData.FunctionName) + return nil + } + bh.Lock() + bh.gracefulShutdown = true + bh.Unlock() + timer := time.NewTimer(time.Duration(gracePeriodSecond) * time.Second) + exitLoop := false + logger.GetLogger().Infof("wait all request done") + for !exitLoop { + select { + case <-timer.C: + logger.GetLogger().Warnf("reach grace period second %d, kill http server now", gracePeriodSecond) + exitLoop = true + break + default: + bh.RLock() + runningRequests := bh.requestNum + bh.RUnlock() + if runningRequests == 0 { + exitLoop = true + logger.GetLogger().Infof("all requests finished") + closeOnce.Do(func() { + close(stopCh) + }) + break + } + time.Sleep(shutdownCheckInterval) + } + } + logger.GetLogger().Infof("http server for function %s is killed", bh.funcSpec.FuncMetaData.FunctionName) + logger.GetLogger().Sync() + runtimeLog.GetLogger().Sync() + return nil +} + +// SignalHandler api for go-runtime, management function instance by semaphore. +// semaphore: 1-63 Yuanrong kernel reservation +// semaphore: 64-1024 User defined +func (bh *basicHandler) SignalHandler(signal int32, payload []byte) error { + logger.GetLogger().Infof("start to signal function %s", bh.funcSpec.FuncMetaData.FunctionName) + var aliasList []*aliasroute.AliasElement + if signal == killSignalAliasUpdate { + err := json.Unmarshal(payload, &aliasList) + if err != nil { + return err + } + aliasroute.UpdateAliasesMap(aliasList) + } + return nil +} + +func executeWithTimeout(function func() (interface{}, error), timeout time.Duration, monitorErrChan <-chan error) ( + interface{}, error, time.Duration) { + startTime := time.Now() + type asyncResult struct { + res interface{} + err error + } + resCh := make(chan asyncResult, 1) + timer := time.NewTimer(timeout) + go func() { + result, err := function() + resCh <- asyncResult{ + res: result, + err: err, + } + }() + select { + case result := <-resCh: + return result.res, result.err, time.Now().Sub(startTime) + case <-timer.C: + return nil, ErrExecuteTimeout, time.Now().Sub(startTime) + case err := <-monitorErrChan: + return nil, err, time.Now().Sub(startTime) + } +} + +func buildResponseData(response *http.Response, traceID string) (*bytes.Buffer, error) { + const base = 1024 + buffer := &bytes.Buffer{} + buf := make([]byte, base) + logger := logger.GetLogger().With(zap.Any("traceID", traceID)) + defer response.Body.Close() + for { + readLength, err := response.Body.Read(buf) + if err != nil && err != io.EOF { + logger.Errorf("found error in response body,error: %s", err.Error()) + if strings.Contains(err.Error(), "context deadline exceeded") { + return nil, errors.New("http function request timeout") + } + return nil, errors.New("http function request error") + } + if readLength > 0 { + buffer.Write(buf[:readLength]) + } else { + break + } + } + return buffer, nil +} + +// HealthCheckHandler custom health check handler +func (bh *basicHandler) HealthCheckHandler() (api.HealthType, error) { + return api.Healthy, nil +} + +func (bh *basicHandler) isCustomHealthCheckReady() bool { + check := bh.funcSpec.ExtendedMetaData.CustomHealthCheck + if check.TimeoutSeconds == 0 || check.PeriodSeconds == 0 || check.FailureThreshold == 0 { + logger.GetLogger().Infof("custom health check is disabled, no need check") + return true + } + logger.GetLogger().Debugf("custom health check is available, waiting for health check success") + healthURL := fmt.Sprintf("%s/%s", bh.baseURL, bh.healthRoute) + request, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, healthURL, nil) + if err != nil { + logger.GetLogger().Errorf("failed to create request to health route error %s", err.Error()) + return false + } + queries := request.URL.Query() + request.URL.RawQuery = queries.Encode() + request.Header.Add("Content-Type", "application/json") + res, err := bh.sendRequest(request, + time.Duration(check.TimeoutSeconds)*time.Second) + if err == nil && res != nil && res.StatusCode == http.StatusOK { + logger.GetLogger().Infof("health check successfully, custom runtime is ready") + return true + } + return false +} diff --git a/api/go/faassdk/handler/http/basic_handler_test.go b/api/go/faassdk/handler/http/basic_handler_test.go new file mode 100644 index 0000000..b40d1f7 --- /dev/null +++ b/api/go/faassdk/handler/http/basic_handler_test.go @@ -0,0 +1,435 @@ +package http + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + urlpkg "net/url" + "reflect" + "strings" + "testing" + "time" + "yuanrong.org/kernel/runtime/faassdk/common/functionlog" + "yuanrong.org/kernel/runtime/faassdk/common/monitor" + "yuanrong.org/kernel/runtime/faassdk/config" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/faassdk/common/aliasroute" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/libruntime/api" +) + +type mockReadCloser struct { +} + +func (m *mockReadCloser) Read(p []byte) (n int, err error) { + bytes := []byte("hello world") + if p[0] == bytes[0] { + return 0, io.EOF + } + for i, b := range bytes { + p[i] = b + } + return len(bytes), nil +} + +func (m *mockReadCloser) Close() error { + return nil +} + +func Test_basicHandler_CallHandler(t *testing.T) { + config1, _ := config.GetConfig(newFuncSpec()) + fl, _ := functionlog.GetFunctionLogger(config1) + convey.Convey("CallHandler", t, func() { + convey.Convey("graceful shutdown, stop processing", func() { + handler := basicHandler{funcSpec: newFuncSpec(), gracefulShutdown: true} + res, err := handler.CallHandler(nil, nil) + convey.So(err, convey.ShouldBeNil) + convey.So(strings.Contains(string(res), "graceful shutdown, stop processing"), convey.ShouldEqual, true) + }) + + convey.Convey("invalid invoke argument", func() { + handler := basicHandler{funcSpec: newFuncSpec()} + args := []api.Arg{ + { + Type: 0, + Data: nil, + }, + } + ctx := make(map[string]string) + res, err := handler.CallHandler(args, ctx) + convey.So(err, convey.ShouldBeNil) + convey.So(strings.Contains(string(res), "invalid invoke argument"), convey.ShouldEqual, true) + }) + + convey.Convey("unmarshal invoke call request data", func() { + handler := basicHandler{ + funcSpec: newFuncSpec(), + baseURL: fmt.Sprintf("http://%s:%d", listenIP, 8080), + port: 8080, + callRoute: "/callRoute", + functionLogger: fl, + } + args := []api.Arg{ + { + Type: 0, + Data: nil, + }, + { + Type: 0, + Data: []byte(`{`), + }, + } + ctx := make(map[string]string) + res, err := handler.CallHandler(args, ctx) + convey.So(err, convey.ShouldBeNil) + convey.So(strings.Contains(string(res), "unmarshal invoke call request data"), convey.ShouldEqual, true) + }) + + convey.Convey("call request timed out", func() { + defer gomonkey.ApplyFunc(http.NewRequestWithContext, + func(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + u, _ := urlpkg.Parse(url) + return &http.Request{URL: u, Header: make(map[string][]string)}, nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&http.Client{}), "Do", + func(_ *http.Client, req *http.Request) (*http.Response, error) { + time.Sleep(3*time.Second + 10*time.Millisecond) + return nil, errors.New("timeout 3s") + }).Reset() + handler := basicHandler{ + funcSpec: newFuncSpec(), + baseURL: fmt.Sprintf("http://%s:%d", listenIP, 8080), + port: 8080, + callRoute: "callRoute", + callTimeout: 3, + functionLogger: fl, + monitor: &monitor.FunctionMonitorManager{}, + } + apigEvent := &APIGTriggerEvent{ + Body: "hello", + Path: "test", + } + apigEventBytes, _ := json.Marshal(apigEvent) + userCallRequest := &types.CallRequest{ + Body: apigEventBytes, + } + userCallRequestBytes, _ := json.Marshal(userCallRequest) + args := []api.Arg{ + { + Type: 0, + Data: nil, + }, + { + Type: 0, + Data: userCallRequestBytes, + }, + } + ctx := make(map[string]string) + res, err := handler.CallHandler(args, ctx) + convey.So(err, convey.ShouldBeNil) + convey.So(strings.Contains(string(res), "call request timed out"), convey.ShouldEqual, true) + }) + + convey.Convey("call request failed error", func() { + defer gomonkey.ApplyFunc(http.NewRequestWithContext, + func(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + u, _ := urlpkg.Parse(url) + return &http.Request{URL: u, Header: make(map[string][]string)}, nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&http.Client{}), "Do", + func(_ *http.Client, req *http.Request) (*http.Response, error) { + return nil, errors.New("user function error") + }).Reset() + handler := basicHandler{ + funcSpec: newFuncSpec(), + baseURL: fmt.Sprintf("http://%s:%d", listenIP, 8080), + port: 8080, + callRoute: "callRoute", + callTimeout: 3, + functionLogger: fl, + monitor: &monitor.FunctionMonitorManager{}, + } + apigEvent := &APIGTriggerEvent{ + Body: "hello", + Path: "test", + } + apigEventBytes, _ := json.Marshal(apigEvent) + userCallRequest := &types.CallRequest{ + Body: apigEventBytes, + } + userCallRequestBytes, _ := json.Marshal(userCallRequest) + args := []api.Arg{ + { + Type: 0, + Data: nil, + }, + { + Type: 0, + Data: userCallRequestBytes, + }, + } + ctx := make(map[string]string) + res, err := handler.CallHandler(args, ctx) + convey.So(err, convey.ShouldBeNil) + convey.So(strings.Contains(string(res), "user function error"), convey.ShouldEqual, true) + }) + + convey.Convey("success", func() { + defer gomonkey.ApplyFunc(http.NewRequestWithContext, + func(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + u, _ := urlpkg.Parse(url) + return &http.Request{URL: u, Header: make(map[string][]string)}, nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&http.Client{}), "Do", + func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{Body: &mockReadCloser{}}, nil + }).Reset() + handler := basicHandler{ + funcSpec: newFuncSpec(), + baseURL: fmt.Sprintf("http://%s:%d", listenIP, 8080), + port: 8080, + callRoute: "callRoute", + callTimeout: 3, + functionLogger: fl, + monitor: &monitor.FunctionMonitorManager{}, + } + apigEvent := &APIGTriggerEvent{ + Body: "hello", + Path: "test", + } + apigEventBytes, _ := json.Marshal(apigEvent) + userCallRequest := &types.CallRequest{ + Body: apigEventBytes, + } + userCallRequestBytes, _ := json.Marshal(userCallRequest) + args := []api.Arg{ + { + Type: 0, + Data: nil, + }, + { + Type: 0, + Data: userCallRequestBytes, + }, + } + ctx := make(map[string]string) + res, err := handler.CallHandler(args, ctx) + convey.So(err, convey.ShouldBeNil) + convey.So(strings.Contains(string(res), base64.StdEncoding.EncodeToString([]byte("hello world"))), + convey.ShouldEqual, true) + }) + + convey.Convey("circuitBreaker is true", func() { + handler := basicHandler{funcSpec: newFuncSpec()} + handler.circuitBreaker = true + args := []api.Arg{ + { + Type: 0, + Data: nil, + }, + } + ctx := make(map[string]string) + res, err := handler.CallHandler(args, ctx) + convey.So(err, convey.ShouldBeNil) + convey.So(strings.Contains(string(res), "function circuit break"), convey.ShouldEqual, true) + }) + }) +} + +func TestCheckPointHandler(t *testing.T) { + convey.Convey("CheckPointHandler", t, func() { + handler := basicHandler{ + funcSpec: newFuncSpec(), + baseURL: fmt.Sprintf("http://%s:%d", listenIP, 8080), + port: 8080, + callRoute: "callRoute", + callTimeout: 3, + } + bytes, err := handler.CheckPointHandler("checkpointID") + convey.So(err, convey.ShouldBeNil) + convey.So(bytes, convey.ShouldBeNil) + }) +} + +func TestRecoverHandler(t *testing.T) { + convey.Convey("RecoverHandler", t, func() { + handler := basicHandler{ + funcSpec: newFuncSpec(), + baseURL: fmt.Sprintf("http://%s:%d", listenIP, 8080), + port: 8080, + callRoute: "callRoute", + callTimeout: 3, + } + err := handler.RecoverHandler([]byte{}) + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestSignalHandler(t *testing.T) { + convey.Convey("SignalHandler", t, func() { + handler := basicHandler{ + funcSpec: newFuncSpec(), + baseURL: fmt.Sprintf("http://%s:%d", listenIP, 8080), + port: 8080, + callRoute: "callRoute", + callTimeout: 3, + } + convey.Convey("handle signal success", func() { + aliasList := []*aliasroute.AliasElement{{AliasURN: "aaa"}} + payload, _ := json.Marshal(aliasList) + err := handler.SignalHandler(64, payload) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("handle signal error", func() { + err := handler.SignalHandler(64, []byte("aaa")) + convey.So(err, convey.ShouldNotBeNil) + }) + + }) +} + +func TestParseCreateParams(t *testing.T) { + convey.Convey("TestParseCreateParams", t, func() { + handler := basicHandler{ + funcSpec: newFuncSpec(), + baseURL: fmt.Sprintf("http://%s:%d", listenIP, 8080), + port: 8080, + callRoute: "callRoute", + callTimeout: 3, + } + funcSpecBytes, _ := json.Marshal(newFuncSpecWithoutTimeout()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + convey.Convey("unmarshal args[0].Data failed", func() { + args := []api.Arg{ + { + Type: 0, + Data: []byte(""), + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: 0, + Data: []byte(""), + }, + } + err := handler.parseCreateParams(args) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("unmarshal args[1].Data failed", func() { + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: []byte(""), + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: 0, + Data: []byte(""), + }, + } + err := handler.parseCreateParams(args) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("unmarshal args[2].Data failed", func() { + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: []byte(""), + }, + { + Type: 0, + Data: []byte(""), + }, + } + err := handler.parseCreateParams(args) + convey.So(err, convey.ShouldNotBeNil) + }) + + }) +} + +func Test_basicHandler_isCustomHealthCheckReady(t *testing.T) { + convey.Convey("test isCustomHealthCheckReady", t, func() { + convey.Convey("baseline", func() { + handler := &basicHandler{ + funcSpec: &types.FuncSpec{ExtendedMetaData: types.ExtendedMetaData{ + CustomHealthCheck: types.CustomHealthCheck{ + TimeoutSeconds: 1, + PeriodSeconds: 1, + FailureThreshold: 10, + }, + }}, + baseURL: "127.0.0.1", + healthRoute: healthRoute, + } + p := gomonkey.ApplyFunc((*basicHandler).sendRequest, func(_ *basicHandler, request *http.Request, + timeout time.Duration) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusOK}, nil + }) + defer p.Reset() + convey.So(handler.isCustomHealthCheckReady(), convey.ShouldBeTrue) + }) + convey.Convey("check failed", func() { + handler := &basicHandler{ + funcSpec: &types.FuncSpec{ExtendedMetaData: types.ExtendedMetaData{ + CustomHealthCheck: types.CustomHealthCheck{ + TimeoutSeconds: 1, + PeriodSeconds: 1, + FailureThreshold: 10, + }, + }}, + baseURL: "127.0.0.1", + healthRoute: healthRoute, + } + p := gomonkey.ApplyFunc((*basicHandler).sendRequest, func(_ *basicHandler, request *http.Request, + timeout time.Duration) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusInternalServerError}, nil + }) + defer p.Reset() + convey.So(handler.isCustomHealthCheckReady(), convey.ShouldBeFalse) + }) + convey.Convey("custom health check disabled", func() { + handler := &basicHandler{ + funcSpec: &types.FuncSpec{ExtendedMetaData: types.ExtendedMetaData{ + CustomHealthCheck: types.CustomHealthCheck{ + TimeoutSeconds: 0, + PeriodSeconds: 0, + FailureThreshold: 0, + }, + }}, + baseURL: "127.0.0.1", + healthRoute: healthRoute, + } + convey.So(handler.isCustomHealthCheckReady(), convey.ShouldBeTrue) + }) + }) +} diff --git a/api/go/faassdk/handler/http/crossclusterinvoke/httpclient.go b/api/go/faassdk/handler/http/crossclusterinvoke/httpclient.go new file mode 100644 index 0000000..e828686 --- /dev/null +++ b/api/go/faassdk/handler/http/crossclusterinvoke/httpclient.go @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crossclusterinvoke - +package crossclusterinvoke + +import ( + "sync" + "sync/atomic" + + "github.com/valyala/fasthttp" +) + +var httpClients []*fasthttp.Client +var httpClientsIndex uint32 +var httpClientsOnce sync.Once +var httpClientNum uint32 = 5 + +// GetHttpClient - +func GetHttpClient() *fasthttp.Client { + httpClientsOnce.Do(func() { + httpClients = make([]*fasthttp.Client, httpClientNum, httpClientNum) + for i := 0; i < int(httpClientNum); i++ { + httpClients[i] = &fasthttp.Client{} + } + httpClientsIndex = 0 + }) + index := atomic.AddUint32(&httpClientsIndex, 1) % httpClientNum + return httpClients[index] +} diff --git a/api/go/faassdk/handler/http/crossclusterinvoke/invoker.go b/api/go/faassdk/handler/http/crossclusterinvoke/invoker.go new file mode 100644 index 0000000..7a13cf7 --- /dev/null +++ b/api/go/faassdk/handler/http/crossclusterinvoke/invoker.go @@ -0,0 +1,370 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crossclusterinvoke - +package crossclusterinvoke + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/valyala/fasthttp" + "go.uber.org/zap" + + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/faassdk/sts" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/faassdk/utils" + "yuanrong.org/kernel/runtime/faassdk/utils/signer" + "yuanrong.org/kernel/runtime/faassdk/utils/urnutils" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + crossHeaderKeyTraceID = "X-Caas-Trace-Id" + crossHeaderKeyCrossCluster = "X-System-Cross-Cluster" + crossHeaderKeyClusterID = "X-System-Clusterid" + crossHeaderKeyTimestamp = "X-System-Timestamp" + crossHeaderKeySignature = "X-System-Signature" +) + +var ( + stsInitOnce sync.Once + stsInitErr error +) + +// InvokeConfig - +type InvokeConfig struct { + Enable bool `json:"enable" valid:"optional"` + CrossClusterAddr string `json:"crossClusterAddr" valid:"optional"` + ErrorCodes string `json:"errorCodes" valid:"optional"` + ErrCodeMap map[int]struct{} + AcquireTimeout int `json:"acquireTimeout" valid:"optional"` +} + +// AuthConfig - +type AuthConfig struct { + AccessKey string `json:"accessKey" valid:"optional"` + SecretKey string `json:"secretKey" valid:"optional"` +} + +// Invoker - +type Invoker struct { + StsServerConfig types.StsServerConfig + FuncInfo urnutils.BaseURN + initErr error + InvokeConfig + AuthConfig +} + +// NewInvoker - +func NewInvoker(urn string) *Invoker { + funcInfo, err := urnutils.GetFunctionInfo(urn) + if err != nil { + logger.GetLogger().Warnf("new cross cluster invoker error, err: %s", err.Error()) + } + return &Invoker{ + InvokeConfig: getCrossClusterInvokeConfig(), + AuthConfig: getCrossClusterAuthConfig(), + FuncInfo: funcInfo, + initErr: err, + } +} + +// NeedCrossClusterInvoke - +func (invoker *Invoker) NeedCrossClusterInvoke(err error) bool { + if !invoker.InvokeConfig.Enable { + return false + } + snErr, ok := err.(api.ErrorInfo) + if !ok { + return false + } + if _, ok := invoker.ErrCodeMap[snErr.Code]; !ok { + return false + } + return true +} + +func (invoker *Invoker) getCalleeUrn(name, version string) string { + calleeInfo := invoker.FuncInfo + calleeInfo.Name = buildCalleeFullFuncName(invoker.FuncInfo.Name, name) + calleeInfo.Version = version + return calleeInfo.String() +} + +// DoInvoke - +func (invoker *Invoker) DoInvoke(request types.InvokeRequest, response *types.GetFutureResponse, timeout time.Duration, + logger api.FormatLogger) bool { + if invoker.initErr != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = "crossClusterInvoker not ready, err: " + invoker.initErr.Error() + return false + } + defer func() { + if response.StatusCode != constants.NoneError { + response.ErrorMessage = "do cross cluster invoke failed, " + response.GetErrorMessage() + } + }() + if timeout <= 0 { + response.StatusCode = constants.FaaSError + response.ErrorMessage = "no time left" + return false + } + calleeUrn := request.FuncUrn + if calleeUrn == "" { + calleeUrn = invoker.getCalleeUrn(request.FuncName, request.FuncVersion) + } + logger = logger.With(zap.Any("calleeUrn", calleeUrn), zap.Any("timeout", timeout.Milliseconds()), + zap.Any("host", invoker.CrossClusterAddr)) + invokeUrl := fmt.Sprintf("/serverless/v1/functions/%s/invocations", calleeUrn) + + httpReq := fasthttp.AcquireRequest() + httpRsp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(httpReq) + defer fasthttp.ReleaseResponse(httpRsp) + + err := invoker.setHeader(httpReq, request, logger) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("setHeader failed, err: %s", err) + return false + } + httpReq.SetRequestURI(invokeUrl) + httpReq.URI().SetScheme("http") + httpReq.SetBody([]byte(request.Payload)) + httpClient := GetHttpClient() + err = httpClient.DoTimeout(httpReq, httpRsp, timeout) + + if needTryLocalCluster(err, httpRsp, logger) { + logger.Infof("cross cluster is not ready or upgrading") + return true + } + + if err != nil { + logger.Errorf("do invoke failed, err: %s", err) + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("do invoke failed, err: %s", err) + return false + } + if httpRsp.StatusCode()/100 != 2 { // 2xx is ok + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("do invoke failed, statuscode:%d is not accepted", httpRsp.StatusCode()) + return false + } + handleHttpResponse(httpRsp, response) + return false +} + +func needTryLocalCluster(err error, resp *fasthttp.Response, logger api.FormatLogger) bool { + if err != nil { + if strings.Contains(err.Error(), fasthttp.ErrDialTimeout.Error()) || utils.ContainsConnRefusedErr(err) { + logger.Errorf("do cross invoke failed, cross cluster is in error: %s", err.Error()) + return true + } + return false + } + if resp == nil { + return false + } + if resp.StatusCode() == fasthttp.StatusNotFound { + return true + } + if resp.StatusCode()/100 != 2 { // 2xx is ok + return false + } + callResponse := &struct { + Code int `json:"code"` + Message string `json:"message"` + UserResponse json.RawMessage `json:"userResponse"` + }{} + err = json.Unmarshal(resp.Body(), callResponse) + if err != nil { + return false + } + if callResponse.Code == constants.ClusterIsUpgrading { + logger.Errorf("do cross invoke failed, cross cluster is in upgrading") + return true + } + return false +} + +func handleHttpResponse(httpResp *fasthttp.Response, response *types.GetFutureResponse) { + callResponse := &struct { + Code int `json:"code"` + Message string `json:"message"` + UserResponse json.RawMessage `json:"userResponse"` + }{} + err := json.Unmarshal(httpResp.Body(), callResponse) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("call response unmarshal error %s", err.Error()) + return + } + if callResponse.Code != constants.NoneError { + response.StatusCode = callResponse.Code + response.ErrorMessage = callResponse.Message + return + } + contentBytes, err := callResponse.UserResponse.MarshalJSON() + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("call response content unmarshal error %s", err.Error()) + return + } + response.StatusCode = constants.NoneError + + response.Content = string(contentBytes) +} + +func (invoker *Invoker) setHeader(httpReq *fasthttp.Request, request types.InvokeRequest, + logger api.FormatLogger) error { + httpReq.Header.ResetConnectionClose() + httpReq.Header.SetHost(invoker.CrossClusterAddr) + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set(crossHeaderKeyTraceID, request.TraceID) + httpReq.Header.Set(crossHeaderKeyCrossCluster, "true") + httpReq.Header.Set(crossHeaderKeyClusterID, os.Getenv("CLUSTER_ID")) + httpReq.Header.SetMethod(fasthttp.MethodPost) + timeStamp := strconv.FormatInt(time.Now().Unix(), 10) // int64 -> string + + httpReq.Header.Set(crossHeaderKeyTimestamp, timeStamp) + ak, sk, err := invoker.parseAuthConfig() + if err != nil { + logger.Errorf("parse secret error %s", err.Error()) + return err + } + signature := buildSignature(timeStamp, []byte(request.Payload), string(ak)) + sign := signer.Sign(sk, signature) + signStr := signer.EncodeHex(sign) + + buildAuth := signer.BuildAuthorization(string(ak), timeStamp, signStr) + httpReq.Header.Set(crossHeaderKeySignature, buildAuth) + return nil +} + +func (invoker *Invoker) parseAuthConfig() ([]byte, []byte, error) { + err := initSTS(invoker.StsServerConfig) + if err != nil { + return []byte{}, []byte{}, fmt.Errorf("init STS error,err:%s", err.Error()) + } + + AKey := invoker.AccessKey + SKey := invoker.SecretKey + + if AKey == "" || SKey == "" { + return []byte{}, []byte{}, fmt.Errorf("AK or SK is nil") + } + accessKey, err := stsgoapi.DecryptSensitiveConfig(AKey) + if err != nil { + return []byte{}, []byte{}, fmt.Errorf("decrypt accessKey failed, err: %s", err) + } + secretKey, err := stsgoapi.DecryptSensitiveConfig(SKey) + if err != nil { + return []byte{}, []byte{}, fmt.Errorf("decrypt secretKey failed , err: %s", err) + } + decodeKey, err := base64.StdEncoding.DecodeString(string(secretKey)) + if err != nil { + return []byte{}, []byte{}, fmt.Errorf("decode secretKey failed , err: %s", err) + } + return accessKey, decodeKey, nil +} + +func buildSignature(timeStamp string, body []byte, tenantId string) []byte { + signValues := [][]byte{ + []byte(tenantId), + []byte(timeStamp), + body, + } + return bytes.Join(signValues, []byte("&")) +} + +func buildCalleeFullFuncName(callerFullFuncName string, calleeFuncName string) string { + defaultPrefix := "0@default@" + splits := strings.Split(callerFullFuncName, "@") // "@" separator + if len(splits) != 3 { // example: 0@default@hello + return defaultPrefix + calleeFuncName + } + splits[2] = calleeFuncName // callerFuncName -> calleeFuncName + return strings.Join(splits, "@") // separator +} + +func initSTS(serverCfg types.StsServerConfig) error { + stsInitOnce.Do(func() { + stsInitErr = sts.InitStsSDK(serverCfg) + if stsInitErr != nil { + logger.GetLogger().Errorf("failed to init sts sdk, err: %s", stsInitErr.Error()) + } else { + logger.GetLogger().Infof("succeeded in initing sts sdk") + } + }) + return stsInitErr +} + +func getCrossClusterInvokeConfig() InvokeConfig { + var crossClusterInvokeConfig InvokeConfig + content, err := os.ReadFile(config.DefaultRuntimeJsonFilePath) + if err != nil { + logger.GetLogger().Errorf("read crossClusterInvokeConfig failed, err: %s, "+ + "filePath: %s", err.Error(), config.DefaultRuntimeJsonFilePath) + return crossClusterInvokeConfig + } + configStruct := struct { + InvokeConfig `json:"crossClusterInvokeConfig"` + }{} + err = json.Unmarshal(content, &configStruct) + if err != nil { + logger.GetLogger().Errorf("unmarshal crossClusterInvokeConfig env failed, "+ + "content: %s, err: %v", string(content), err.Error()) + return crossClusterInvokeConfig + } + crossClusterInvokeConfig = configStruct.InvokeConfig + splits := strings.Split(crossClusterInvokeConfig.ErrorCodes, ",") + crossClusterInvokeConfig.ErrCodeMap = make(map[int]struct{}) + for _, v := range splits { + errCode, err := strconv.Atoi(v) + if err != nil { + logger.GetLogger().Errorf("parse errCode failed, v: %s, err: %v", v, err) + continue + } + crossClusterInvokeConfig.ErrCodeMap[errCode] = struct{}{} + } + logger.GetLogger().Infof("show crossclusterinvokeconfig: %v", crossClusterInvokeConfig) + return crossClusterInvokeConfig +} + +func getCrossClusterAuthConfig() AuthConfig { + var crossClusterAuthConfig AuthConfig + csac := os.Getenv("CROSS_CLUSTER_AUTH_CONFIG") + err := json.Unmarshal([]byte(csac), &crossClusterAuthConfig) + if err != nil { + logger.GetLogger().Errorf("unmarshal crossClusterAuthConfig failed, config:%s err: %s", + csac, err.Error()) + return crossClusterAuthConfig + } + logger.GetLogger().Infof("get crossclusterAuthConfig success") + return crossClusterAuthConfig +} diff --git a/api/go/faassdk/handler/http/crossclusterinvoke/invoker_test.go b/api/go/faassdk/handler/http/crossclusterinvoke/invoker_test.go new file mode 100644 index 0000000..7cbc075 --- /dev/null +++ b/api/go/faassdk/handler/http/crossclusterinvoke/invoker_test.go @@ -0,0 +1,370 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crossclusterinvoke - +package crossclusterinvoke + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "os" + "reflect" + "strings" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/valyala/fasthttp" + + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/sts" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/faassdk/utils/signer" + "yuanrong.org/kernel/runtime/faassdk/utils/urnutils" + "yuanrong.org/kernel/runtime/libruntime/api" + log "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +func Test_NewInvoker(t *testing.T) { + convey.Convey("init Invoker", t, func() { + convey.Convey("init invoker failed", func() { + invoker := NewInvoker("") + convey.So(invoker, convey.ShouldNotBeNil) + convey.So(invoker.initErr, convey.ShouldNotBeNil) + }) + convey.Convey("init invoker ok", func() { + invoker := NewInvoker("sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest") + convey.So(invoker, convey.ShouldNotBeNil) + convey.So(invoker.initErr, convey.ShouldBeNil) + }) + }) +} + +func getMockInvoker() *Invoker { + funcInfo, err := urnutils.GetFunctionInfo("sn:cn:yrk:12345678901234561234567890123456:" + + "function:helloworld:$latest") + return &Invoker{ + InvokeConfig: InvokeConfig{ + Enable: true, + CrossClusterAddr: "127.0.0.1:31222", + ErrorCodes: "1223,12345,345", + ErrCodeMap: map[int]struct{}{ + 1223: {}, + 12345: {}, + 345: {}, + }, + AcquireTimeout: 0, + }, + AuthConfig: AuthConfig{ + AccessKey: "testAK", + SecretKey: "testSK", + }, + FuncInfo: funcInfo, + initErr: err, + } +} + +func Test_NeedCrossClusterInvoke(t *testing.T) { + convey.Convey("test NeedCrossClusterInvoke", t, func() { + convey.Convey("test NeedCrossClusterInvoke failed", func() { + invoker := getMockInvoker() + invoker.InvokeConfig.Enable = false + convey.So(invoker.NeedCrossClusterInvoke(fmt.Errorf("fdafadsfda")), convey.ShouldBeFalse) + }) + convey.Convey("test NeedCrossClusterInvoke failed1", func() { + invoker := getMockInvoker() + invoker.InvokeConfig.Enable = false + convey.So(invoker.NeedCrossClusterInvoke(api.ErrorInfo{ + Code: 1111, + Err: errors.New(""), + }), convey.ShouldBeFalse) + }) + convey.Convey("test NeedCrossClusterInvoke ok", func() { + invoker := getMockInvoker() + invoker.InvokeConfig.Enable = true + convey.So(invoker.NeedCrossClusterInvoke(api.ErrorInfo{ + Code: 12345, + Err: errors.New(""), + }), convey.ShouldBeTrue) + }) + }) +} + +func Test_buildCalleeFullFuncName(t *testing.T) { + tt := []struct { + name string + callerFullFuncName string + calleeFuncName string + expectCalleeFullFuncName string + }{ + { + name: "test buildCalleeFullFuncName case0", + callerFullFuncName: "0@deafult@hello", + calleeFuncName: "world", + expectCalleeFullFuncName: "0@deafult@world", + }, + { + name: "test buildCalleeFullFuncName case1", + callerFullFuncName: "0@service@hello", + calleeFuncName: "world", + expectCalleeFullFuncName: "0@service@world", + }, + { + name: "test buildCalleeFullFuncName case2", + callerFullFuncName: "0@servicehello", + calleeFuncName: "world", + expectCalleeFullFuncName: "0@default@world", + }, + { + name: "test buildCalleeFullFuncName case3", + callerFullFuncName: "1@service@hello", + calleeFuncName: "world", + expectCalleeFullFuncName: "1@service@world", + }, + } + + for _, ttt := range tt { + convey.Convey(ttt.name, t, func() { + actualCalleeFullFuncName := buildCalleeFullFuncName(ttt.callerFullFuncName, ttt.calleeFuncName) + convey.So(actualCalleeFullFuncName, convey.ShouldEqual, ttt.expectCalleeFullFuncName) + }) + } +} + +func TestInvoker_DoInvoke(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(&stsInitOnce), "Do", func(_ *sync.Once, f func()) { + f() + }).Reset() + convey.Convey("invoker init err", t, func() { + invoker := getMockInvoker() + invoker.initErr = fmt.Errorf("error is error") + resp := &types.GetFutureResponse{} + invoker.DoInvoke(types.InvokeRequest{}, resp, 0, log.GetLogger()) + convey.So(resp.StatusCode, convey.ShouldEqual, constants.FaaSError) + convey.So(resp.ErrorMessage, convey.ShouldEqual, "crossClusterInvoker not ready, err: "+invoker.initErr.Error()) + }) + + convey.Convey("timeout < 0 err", t, func() { + invoker := getMockInvoker() + resp := &types.GetFutureResponse{} + invoker.DoInvoke(types.InvokeRequest{}, resp, -1, log.GetLogger()) + convey.So(resp.StatusCode, convey.ShouldEqual, constants.FaaSError) + convey.So(resp.ErrorMessage, convey.ShouldEqual, "do cross cluster invoke failed, no time left") + }) + + convey.Convey("init sts failed", t, func() { + defer gomonkey.ApplyFunc(sts.InitStsSDK, func(serverCfg types.StsServerConfig) error { + return errors.New("mock sts init error") + }).Reset() + invoker := getMockInvoker() + resp := &types.GetFutureResponse{} + + invoker.DoInvoke(types.InvokeRequest{}, resp, 100, log.GetLogger()) + + convey.So(resp.StatusCode, convey.ShouldEqual, constants.FaaSError) + convey.So(strings.Contains(resp.ErrorMessage, "mock sts init error"), convey.ShouldBeTrue) + }) + + convey.Convey("ak/sk is empty", t, func() { + defer gomonkey.ApplyFunc(sts.InitStsSDK, func(serverCfg types.StsServerConfig) error { + return nil + }).Reset() + defer gomonkey.ApplyFunc(stsgoapi.DecryptSensitiveConfig, + func(rawConfigValue string) (plainBytes []byte, err error) { + return []byte{}, errors.New("mock sts error 1111") + }).Reset() + invoker := getMockInvoker() + invoker.AccessKey = "" + invoker.SecretKey = "" + resp := &types.GetFutureResponse{} + invoker.DoInvoke(types.InvokeRequest{}, resp, 100, log.GetLogger()) + convey.So(resp.StatusCode, convey.ShouldEqual, constants.FaaSError) + convey.So(strings.Contains(resp.ErrorMessage, "AK or SK is nil"), convey.ShouldBeTrue) + }) + + convey.Convey("decrypt ak/sk failed", t, func() { + defer gomonkey.ApplyFunc(sts.InitStsSDK, func(serverCfg types.StsServerConfig) error { + return nil + }).Reset() + defer gomonkey.ApplyFunc(stsgoapi.DecryptSensitiveConfig, + func(rawConfigValue string) (plainBytes []byte, err error) { + return []byte{}, errors.New("mock sts error") + }).Reset() + invoker := getMockInvoker() + resp := &types.GetFutureResponse{} + invoker.DoInvoke(types.InvokeRequest{}, resp, 100, log.GetLogger()) + convey.So(resp.StatusCode, convey.ShouldEqual, constants.FaaSError) + convey.So(strings.Contains(resp.ErrorMessage, "mock sts error"), convey.ShouldBeTrue) + }) + + convey.Convey("do invoke cases", t, func() { + invoker := getMockInvoker() + type bodyStruct struct { + Code int `json:"code"` + Message string `json:"message"` + UserResponse json.RawMessage `json:"userResponse"` + } + mockResponse := struct { + response bodyStruct + err error + }{} + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&fasthttp.Client{}), "DoTimeout", + func(_ *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { + body, _ := json.Marshal(mockResponse.response) + resp.SetBody(body) + return mockResponse.err + }), + gomonkey.ApplyFunc(sts.InitStsSDK, func(serverCfg types.StsServerConfig) error { + return nil + }), + gomonkey.ApplyFunc(stsgoapi.DecryptSensitiveConfig, func(rawConfigValue string) (plainBytes []byte, err error) { + return []byte{}, nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + resp := &types.GetFutureResponse{} + convey.Convey("do invoke ok", func() { + mockResponse.response = bodyStruct{ + UserResponse: json.RawMessage("{\"key\":\"hello world\"}"), + Code: 0, + Message: "", + } + mockResponse.err = nil + invoker.DoInvoke(types.InvokeRequest{ + FuncUrn: "aaa", + }, resp, 100*time.Second, log.GetLogger()) + convey.So(resp.StatusCode, convey.ShouldEqual, constants.NoneError) + convey.So(resp.Content, convey.ShouldEqual, "{\"key\":\"hello world\"}") + }) + convey.Convey("do invoke failed 0", func() { + mockResponse.response = bodyStruct{ + UserResponse: json.RawMessage("{\"key\": \"hello world\"}"), + Code: 0, + Message: "", + } + mockResponse.err = fmt.Errorf("error is error") + invoker.DoInvoke(types.InvokeRequest{}, resp, 100*time.Second, log.GetLogger()) + convey.So(resp.StatusCode, convey.ShouldEqual, constants.FaaSError) + convey.So(strings.Contains(resp.ErrorMessage, "error is error"), convey.ShouldBeTrue) + }) + convey.Convey("do invoke failed 1", func() { + mockResponse.response = bodyStruct{ + UserResponse: json.RawMessage("{\"key\":\"hello world\"}"), + Code: constants.FunctionRunError, + Message: "errorMsg is errorMsg", + } + mockResponse.err = nil + invoker.DoInvoke(types.InvokeRequest{}, resp, 100*time.Second, log.GetLogger()) + convey.So(resp.StatusCode, convey.ShouldEqual, constants.FunctionRunError) + convey.So(strings.Contains(resp.ErrorMessage, "errorMsg is errorMsg"), convey.ShouldBeTrue) + }) + }) +} + +func TestNeedTryLocalCluster(t *testing.T) { + convey.Convey( + "Test needTryLocalCluster", t, func() { + convey.Convey("needTryLocalCluster success", func() { + err := errors.New("") + resp := fasthttp.AcquireResponse() + flag := needTryLocalCluster(err, resp, nil) + convey.So(flag, convey.ShouldBeFalse) + flag = needTryLocalCluster(nil, nil, nil) + convey.So(flag, convey.ShouldBeFalse) + flag = needTryLocalCluster(nil, resp, nil) + convey.So(flag, convey.ShouldBeFalse) + }) + }, + ) +} + +func TestHandleHttpResponse(t *testing.T) { + convey.Convey( + "Test handleHttpResponse", t, func() { + convey.Convey("handleHttpResponse success", func() { + resp := fasthttp.AcquireResponse() + convey.So(func() { + handleHttpResponse(resp, &types.GetFutureResponse{}) + }, convey.ShouldNotPanic) + }) + }, + ) +} + +func Test_getCrossClusterInvokeConfig(t *testing.T) { + convey.Convey("getCrossClusterInvokeConfig", t, func() { + defer gomonkey.ApplyFunc(os.ReadFile, func(name string) ([]byte, error) { + return []byte(`{"crossClusterInvokeConfig": {"crossClusterAddr": "127.0.0.1:8080", "errorCodes": "150420,150421"}}`), nil + }).Reset() + config := getCrossClusterInvokeConfig() + _, exist := config.ErrCodeMap[150420] + convey.So(exist, convey.ShouldBeTrue) + }) +} + +func Test_getCrossClusterAuthConfig(t *testing.T) { + convey.Convey("getCrossClusterInvokeConfig", t, func() { + defer gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return `{"accessKey":"testAk","secretKey":"testSK"}` + }).Reset() + config := getCrossClusterAuthConfig() + convey.So(config.AccessKey, convey.ShouldEqual, "testAk") + convey.So(config.SecretKey, convey.ShouldEqual, "testSK") + }) +} + +func Test_SetHeader(t *testing.T) { + convey.Convey("Test SetHeader", t, func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(sts.InitStsSDK, func(serverCfg types.StsServerConfig) error { + return nil + }), + gomonkey.ApplyFunc(stsgoapi.DecryptSensitiveConfig, + func(rawConfigValue string) (plainBytes []byte, err error) { + return []byte(base64.StdEncoding.EncodeToString([]byte("aaa"))), nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + invoker := getMockInvoker() + httpReq := &fasthttp.Request{} + req := types.InvokeRequest{ + Payload: "hello", + } + httpReq.SetBody([]byte(req.Payload)) + err := invoker.setHeader(httpReq, req, log.GetLogger()) + timeStamp := "1736864093" + signature := buildSignature(timeStamp, []byte(req.Payload), invoker.AccessKey) + sign := signer.Sign([]byte(invoker.SecretKey), signature) + signStr := signer.EncodeHex(sign) + buildAuth := signer.BuildAuthorization(invoker.AccessKey, timeStamp, signStr) + convey.So(err, convey.ShouldBeNil) + convey.So(buildAuth, convey.ShouldEqual, "SDK-HMAC-SHA256 accessId=testAK,timestamp="+timeStamp+",signature=0b3558036eaf229573c9390016e2adaaa4b39403bde85e6f075bb0b224658b7f") + }) +} diff --git a/api/go/faassdk/handler/http/custom_container_handler.go b/api/go/faassdk/handler/http/custom_container_handler.go new file mode 100644 index 0000000..eb40271 --- /dev/null +++ b/api/go/faassdk/handler/http/custom_container_handler.go @@ -0,0 +1,1150 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package http - +package http + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/uuid" + + "yuanrong.org/kernel/runtime/faassdk/common/alarm" + "yuanrong.org/kernel/runtime/faassdk/common/aliasroute" + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/common/faasscheduler" + "yuanrong.org/kernel/runtime/faassdk/common/functionlog" + "yuanrong.org/kernel/runtime/faassdk/common/monitor" + "yuanrong.org/kernel/runtime/faassdk/common/tokentosecret" + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/faassdk/handler" + "yuanrong.org/kernel/runtime/faassdk/handler/http/crossclusterinvoke" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/faassdk/utils" + "yuanrong.org/kernel/runtime/faassdk/utils/urnutils" + "yuanrong.org/kernel/runtime/libruntime/common" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + invokeServerAddr = "127.0.0.1:31538" + funcNameSeparator = "@" + defaultInvokeTimeout = 900 * time.Second + customContainerCallPath = "invoke" + // CustomGracefulShutdownClientTimeout - + CustomGracefulShutdownClientTimeout = time.Duration(3) * time.Second + // CustomGracefulShutdownPeriod - + CustomGracefulShutdownPeriod = 5 + maxInvokeRetries = 5 + customContainerSystemTenantID = "12345678901234561234567890123456" // same with systemTenantID in rpc pkg +) + +const ( + invokeRoute = "/invoke" + getFutureRoute = "/future" + stateNewRoute = "/state/new" + stateGetRoute = "/state/get" + stateDeleteRoute = "/state/delete" + circuitBreakRoute = "/circuitbreak" + exitRoute = "/exit" + credentialRequireRoute = "/serverless/v1/credential/require" +) + +var ( + stsInitOnce sync.Once + stsInitErr error +) + +type status int + +const ( + uninitialized status = iota + completed + failed +) + +// CustomContainerHandler - +type CustomContainerHandler struct { + *basicHandler + invokeServer *http.Server + invokeClient *http.Client + logger *RuntimeContainerLogger + futureMap map[string]chan types.GetFutureResponse + alarmConfig types.AlarmConfig + customUserArgs types.CustomUserArgs + remoteDsClient api.KvClient + stateMgr *stateManager + stateMgrState status + stateMgrLock sync.RWMutex + crossClusterInvoker *crossclusterinvoke.Invoker +} + +// NewCustomContainerHandler creates CustomContainerHandler +func NewCustomContainerHandler(funcSpec *types.FuncSpec, client api.LibruntimeAPI) handler.ExecutorHandler { + handler := &CustomContainerHandler{ + basicHandler: newBasicHandler(funcSpec, client), + invokeServer: &http.Server{ + Addr: invokeServerAddr, + }, + invokeClient: &http.Client{ + Transport: http.DefaultTransport, + Timeout: defaultInvokeTimeout, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + } + handler.crossClusterInvoker = crossclusterinvoke.NewInvoker(funcSpec.FuncMetaData.FunctionVersionURN) + return handler +} + +func (ch *CustomContainerHandler) initDsClient(apiClient api.LibruntimeAPI) error { + cr := apiClient.GetCredential() + funcInfo, err := urnutils.GetFunctionInfo(ch.funcSpec.FuncMetaData.FunctionVersionURN) + if err != nil { + return err + } + createConfig := api.ConnectArguments{ + Host: os.Getenv("HOST_IP"), + Port: 31501, + TimeoutMs: 60 * 1000, // 60 seconds + TenantID: funcInfo.TenantID, + AccessKey: cr.AccessKey, + SecretKey: cr.SecretKey, + EnableCrossNodeConnection: true, + } + + client, err := apiClient.CreateClient(createConfig) + if err != nil { + return fmt.Errorf("creating ds client failed, err: %s", err) + } + ch.remoteDsClient = client + return nil +} + +func (ch *CustomContainerHandler) setStateMgr(mgr *stateManager, status status) { + ch.stateMgrLock.Lock() + ch.stateMgr = mgr + ch.stateMgrState = status + ch.stateMgrLock.Unlock() +} + +func (ch *CustomContainerHandler) getStateMgr() *stateManager { + if ch.stateMgrState == failed { + logger.GetLogger().Errorf("stateMgrState is failed!") + return nil + } else if ch.stateMgrState == completed { + return ch.stateMgr + } + for i := 0; i < 60; i++ { // retry times 60 + ch.stateMgrLock.RLock() + if ch.stateMgr != nil { + ch.stateMgrLock.RUnlock() + break + } + ch.stateMgrLock.RUnlock() + logger.GetLogger().Warnf("get state mgr = nil, times %d", i) + time.Sleep(1 * time.Second) // sleep 1 second + } + return ch.stateMgr +} + +// InitHandler will send init request to custom container http server +// there are some features (oom, log) will to be added to this method, so keep this InitHandler for now +func (ch *CustomContainerHandler) InitHandler(args []api.Arg, apiClient api.LibruntimeAPI) ([]byte, error) { + http.HandleFunc(invokeRoute, ch.handleInvoke) + http.HandleFunc(getFutureRoute, ch.handleGetFuture) + http.HandleFunc(stateNewRoute, ch.handleStateNew) + http.HandleFunc(stateGetRoute, ch.handleStateGet) + http.HandleFunc(stateDeleteRoute, ch.handleStateDel) + http.HandleFunc(circuitBreakRoute, ch.handleCircuitBreak) + http.HandleFunc(exitRoute, ch.handleExit) + http.HandleFunc(credentialRequireRoute, ch.handleCredentialRequire) + go func() { + logger.GetLogger().Infof("start invoke server") + err := ch.invokeServer.ListenAndServe() + if err != nil { + logger.GetLogger().Errorf("invoke server exit error %s", err.Error()) + } + }() + + if err := ch.setCustomUserArgs(args); err != nil { + logger.GetLogger().Errorf("set custom user args error:%s", err.Error()) + return utils.HandleInitResponse(fmt.Sprintf("set custom user args error: %s", err.Error()), constants.FaaSError) + } + ch.crossClusterInvoker.StsServerConfig = ch.customUserArgs.StsServerConfig + funcMonitor, err := monitor.CreateFunctionMonitor(ch.funcSpec, stopCh) + if err != nil { + logger.GetLogger().Errorf("create function monitor error:", err.Error()) + return utils.HandleInitResponse(fmt.Sprintf("create function monitor error: %s", err.Error()), constants.FaaSError) + } + ch.monitor = funcMonitor + + err = ch.setAlarmInfo() + if err != nil { + logger.GetLogger().Errorf("set alarm info error:", err.Error()) + } + common.GetTokenMgr().SetCallback(tokentosecret.GetSecretMgr().SetAuthContext) + go func() { + if err = ch.initDsClient(apiClient); err != nil { + logger.GetLogger().Errorf("create ds client fail: %s", err.Error()) + ch.setStateMgr(nil, failed) + } else { + logger.GetLogger().Infof("create ds client successfully") + ch.setStateMgr(GetStateManager(ch.remoteDsClient), completed) + } + }() + resp, err := ch.basicHandler.InitHandler(args, apiClient) + if err != nil { + return resp, err + } + + go ch.checkHealth() + return resp, err +} + +func (ch *CustomContainerHandler) setCustomUserArgs(args []api.Arg) error { + if len(args) != constants.ValidCustomImageCreateParamSize { + return errors.New("invalid create params number") + } + customUserArgs := types.CustomUserArgs{} + if err := json.Unmarshal(args[constants.CustomImageUserArgIndex].Data, &customUserArgs); err != nil { + return err + } + ch.customUserArgs = customUserArgs + return nil +} + +func (ch *CustomContainerHandler) setAlarmInfo() error { + logger.GetLogger().Infof("start to setAlarmInfo") + if &ch.customUserArgs == nil { + return errors.New("custom user args is empty") + } + if !ch.customUserArgs.AlarmConfig.EnableAlarm { + logger.GetLogger().Infof("enableAlarm is false") + return nil + } + ch.alarmConfig = ch.customUserArgs.AlarmConfig + alarm.SetClusterNameEnv(ch.customUserArgs.ClusterName) + alarm.SetAlarmEnv(ch.alarmConfig.AlarmLogConfig) + alarm.SetXiangYunFourConfigEnv(ch.alarmConfig.XiangYunFourConfig) + + if err := alarm.SetPodIP(); err != nil { + return err + } + logger.GetLogger().Infof("set alarm info ok") + return nil +} + +func (ch *CustomContainerHandler) checkHealth() { + if ch.funcSpec.ExtendedMetaData.CustomHealthCheck.TimeoutSeconds == 0 || + ch.funcSpec.ExtendedMetaData.CustomHealthCheck.PeriodSeconds == 0 || + ch.funcSpec.ExtendedMetaData.CustomHealthCheck.FailureThreshold == 0 { + logger.GetLogger().Info("user has not configured custom health check and will not enable it") + return + } + healthURL := fmt.Sprintf("%s/%s", ch.baseURL, ch.healthRoute) + request, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, healthURL, nil) + if err != nil { + logger.GetLogger().Errorf("failed to create request to health route error %s", err.Error()) + return + } + queries := request.URL.Query() + request.URL.RawQuery = queries.Encode() + request.Header.Add("Content-Type", "application/json") + if err := ch.checkRuntime(request, ch.funcSpec.ExtendedMetaData.CustomHealthCheck); err != nil { + logger.GetLogger().Errorf("check runtime health failed, err: %s", err) + ch.sdkClient.Exit(0, "") + } +} + +func (ch *CustomContainerHandler) checkRuntime(request *http.Request, check types.CustomHealthCheck) error { + log := logger.GetLogger().With(zap.Any("timeoutSecond", check.TimeoutSeconds), + zap.Any("periodSecond", check.PeriodSeconds), zap.Any("failureThreshold", check.FailureThreshold)) + log.Infof("start custom health check") + defer log.Infof("end custom health check") + timer := time.NewTimer(time.Duration(check.PeriodSeconds) * time.Second) + defer timer.Stop() + failCount := 0 + for { + res, err := ch.sendRequest(request, time.Duration(check.TimeoutSeconds)*time.Second) + if err != nil { + failCount++ + } else if res != nil && res.StatusCode != http.StatusOK { + err = fmt.Errorf("check health res status code is %d", res.StatusCode) + failCount++ + } + if res != nil && res.StatusCode == http.StatusOK { + failCount = 0 + } + + if failCount >= check.FailureThreshold { + log.Errorf("do custom health check failed, err: %s, failCount: %d", err.Error(), failCount) + return err + } + select { + case <-stopCh: + logger.GetLogger().Warnf("stop channel closed") + return nil + case <-timer.C: + timer.Reset(time.Duration(check.PeriodSeconds) * time.Second) + continue + } + } +} + +// ShutDownHandler handles shutdown +func (ch *CustomContainerHandler) ShutDownHandler(gracePeriodSecond uint64) error { + conf, err := config.GetConfig(ch.funcSpec) + if err != nil { + return err + } + functionlog.Sync(conf) + startGracefulShutdown := time.Now() + gracefulTime := uint64(ch.funcSpec.ExtendedMetaData.CustomGracefulShutdown.MaxShutdownTimeout) + if gracefulTime <= 0 { + gracefulTime = gracePeriodSecond + } + err = ch.basicHandler.ShutDownHandler(gracefulTime) + if err != nil { + logger.GetLogger().Errorf("failed to shutdown handler error %s", err.Error()) + } + leftTime := time.Duration(ch.funcSpec.ExtendedMetaData.CustomGracefulShutdown.MaxShutdownTimeout)*time.Second - + time.Since(startGracefulShutdown) + // custom container's runtime graceful shutdown + if ch.funcSpec.ExtendedMetaData.CustomGracefulShutdown.MaxShutdownTimeout > 0 && leftTime > 0 { + logger.GetLogger().Infof("wait runtime to shutdown gracefully") + ch.customGracefulShutdownNewVer(leftTime) + } + // close invoke server after custom graceful shutdown + // timeout set to 1 second because gracefulshutdown timeout is already handled above + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) // wait for 1 second + defer cancel() + err = ch.invokeServer.Shutdown(ctx) + if err != nil { + logger.GetLogger().Errorf("failed to close invoke server error %s", err.Error()) + } + logger.GetLogger().Infof("exit custom runtime") + return err +} + +func checkGracefulRsp(rsp *http.Response, err error) bool { + if err != nil { + if urlErr, ok := err.(*url.Error); ok { + if urlErr.Err.Error() == "EOF" { + logger.GetLogger().Infof("server is shut down") + return true + } + } + return false + } + logger.GetLogger().Infof("graceful http rsp code is %d", rsp.StatusCode) + if rsp.StatusCode == http.StatusOK || rsp.StatusCode == http.StatusNotFound { + return true + } + return false +} + +func (ch *CustomContainerHandler) customGracefulShutdownNewVer(duration time.Duration) { + t := time.NewTicker(duration) + defer t.Stop() + rspChan := make(chan struct{}) + stopChan := make(chan struct{}) + go func() { + client := http.Client{ + Timeout: time.Duration(duration) * time.Second, + } + for { + select { + case <-stopChan: + return + default: + rsp, err := client.Get("http://127.0.0.1:8000/shutdown") + if checkGracefulRsp(rsp, err) { + rspChan <- struct{}{} + return + } + time.Sleep(CustomGracefulShutdownPeriod * time.Second) + } + } + }() + + select { + case <-t.C: + logger.GetLogger().Infof("Time to graceful shutdown") + close(stopChan) + return + case <-rspChan: + logger.GetLogger().Infof("stop by http") + return + } +} + +// CallHandler handles call +func (ch *CustomContainerHandler) CallHandler(args []api.Arg, ctx map[string]string) ([]byte, error) { + ch.basicHandler.logger = ch.logger + return ch.basicHandler.CallHandler(args, ctx) +} + +func (ch *CustomContainerHandler) handleInvoke(w http.ResponseWriter, r *http.Request) { + logger.GetLogger().Infof("handle function invoke from function %s", ch.funcSpec.FuncMetaData.FunctionName) + response := types.InvokeResponse{ + StatusCode: constants.NoneError, + } + defer handleResponse(&response, w, invokeRoute) + data, err := ioutil.ReadAll(r.Body) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to read request body error %s", err.Error()) + return + } + request := types.InvokeRequest{} + err = json.Unmarshal(data, &request) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to unmarshal request body error %s", err.Error()) + return + } + objectID := uuid.New().String() + ch.Lock() + ch.futureMap[objectID] = make(chan types.GetFutureResponse, 1) + ch.Unlock() + if request.StateID == "" { + err = ch.processInvoke(request, objectID) + } else { + err = ch.processInvokeByState(request, objectID) + } + + if err != nil { + if snErr, ok := err.(api.ErrorInfo); ok { + logger.GetLogger().Errorf("invoke funcKey %s err type is SNError, code is %d, msg is %s, traceid: %s", + request.FuncName, snErr.Code, snErr.Error(), request.TraceID) + response.StatusCode = snErr.Code + response.ErrorMessage = snErr.Error() + } else { + response.StatusCode = constants.FaaSError + } + ch.Lock() + delete(ch.futureMap, objectID) + ch.Unlock() + return + } + response.ObjectID = objectID +} + +func (ch *CustomContainerHandler) processInvoke(request types.InvokeRequest, objectID string) error { + logWith := logger.GetLogger().With(zap.Any("function", ch.funcSpec.FuncMetaData.FunctionName), + zap.Any("objectID", objectID), zap.Any("traceID", request.TraceID)) + logWith.Infof("processing function invoke, request %+v", request) + funcKey, funcUrn, err := ch.getInvokeFuncKey(request.FuncName, request.FuncVersion, request.Params) + if err != nil { + return err + } + request.FuncUrn = funcUrn + logWith = logWith.With(zap.Any("funcKey", funcKey)) + logWith.Infof("main function start to call function:%s", funcKey) + start := time.Now() + acquireTimeout := int(request.AcquireTimeout) + defaultTimeout := 120 // 默认120s + oldAcquireTimeout := acquireTimeout + crossClusterIsUpgrading := false + if acquireTimeout == 0 && ch.crossClusterInvoker.InvokeConfig.Enable { + acquireTimeout = ch.crossClusterInvoker.AcquireTimeout + } else if acquireTimeout == 0 { + acquireTimeout = defaultTimeout + } + functionMeta := api.FunctionMeta{FuncID: funcKey, Api: api.FaaSApi} + arg, err := prepareInvokeArg(request, ch.disableAPIGFormat) + if err != nil { + return err + } + go func() { + wait := make(chan struct{}, 1) + response := types.GetFutureResponse{ + StatusCode: constants.NoneError, + } + defer func() { + ch.RLock() + futureCh, exist := ch.futureMap[objectID] + ch.RUnlock() + if !exist { + logWith.Errorf("future channel doesn't exist") + return + } + if futureCh != nil { + futureCh <- response + } + }() + leftRetryTimes := maxInvokeRetries + for { + trafficLimit := false + shouldRetry := false + shouldCrossClusterInvoke := false + invokeOptions := api.InvokeOptions{ + SchedulerFunctionID: faasscheduler.Proxy.GetSchedulerFuncKey(), + RetryTimes: leftRetryTimes, + TraceID: request.TraceID, + Timeout: int(request.Timeout), + AcquireTimeout: acquireTimeout, + } + returnObjectID, InvokeErr := ch.sdkClient.InvokeByFunctionName(functionMeta, arg, invokeOptions) + if InvokeErr != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("invoke funcKey %s,return error %s", funcKey, InvokeErr.Error()) + return + } + ch.sdkClient.GetAsync(returnObjectID, func(result []byte, err error) { + defer func() { + wait <- struct{}{} + }() + if _, decreaseErr := ch.sdkClient.GDecreaseRef([]string{returnObjectID}); decreaseErr != nil { + fmt.Printf("failed to decrease object ref,err: %s", decreaseErr.Error()) + } + if checkTrafficLimitResp(result) { + trafficLimit = true + return + } + if err != nil { + logWith.Errorf("invoke function err, %s", err.Error()) + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("invoke funcKey %s,return error %s traceID: %s", + funcKey, err.Error(), request.TraceID) + if err != nil && ch.crossClusterInvoker.NeedCrossClusterInvoke(err) && !crossClusterIsUpgrading { + shouldCrossClusterInvoke = true + return + } + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("invoke funcKey %s,return error %s", funcKey, err.Error()) + if snErr, ok := err.(api.ErrorInfo); ok { + logWith.Errorf("invoke err type is SNError, code is %d, msg is %s", + snErr.Code, snErr.Error()) + response.StatusCode = snErr.Code + response.ErrorMessage = snErr.Error() + } else { + logWith.Debugf("invoke err type is not SNError") + } + return + } + handleGetFutureResponse(result, &response, ch.disableAPIGFormat) + }) + <-wait + if trafficLimit && request.StateID == "" { // invoking by stateID don't do trafficlimit + continue + } + if shouldCrossClusterInvoke { + timeout := time.Duration(request.Timeout)*time.Second - time.Now().Sub(start) + crossClusterIsUpgrading = ch.crossClusterInvoker.DoInvoke(request, &response, timeout, logWith) + if !crossClusterIsUpgrading { + break + } + invokeOptions.Timeout = oldAcquireTimeout + shouldRetry = true + leftRetryTimes-- + } + if shouldRetry && leftRetryTimes > 0 { + continue + } + break + } + }() + return nil +} + +func isFaaSSchedulerStateErrorCode(errCode int) bool { + faasschedulerErrorCodes := []int{StateInstanceNotExistedErrCode, StateInstanceNoLease, FaaSSchedulerInternalErrCode} + for _, code := range faasschedulerErrorCodes { + if errCode == code { + return true + } + } + return false +} + +func (ch *CustomContainerHandler) processInvokeByState(request types.InvokeRequest, objectID string) error { + loggerWith := logger.GetLogger().With(zap.Any("traceID", request.TraceID), zap.Any("stateID", request.StateID), + zap.Any("funcName", request.FuncName), zap.Any("objectID", objectID)) + loggerWith.Infof("processing function state invoke from function %s, request %v", + ch.funcSpec.FuncMetaData.FunctionName, request) + funcKey, _, err := ch.getInvokeFuncKey(request.FuncName, request.FuncVersion, request.Params) + if err != nil { + return err + } + arg, err := prepareInvokeArg(request, ch.disableAPIGFormat) + if err != nil { + return err + } + schedulerID, err := faasscheduler.Proxy.Get(funcKey) + if err != nil { + return err + } + functionMeta := api.FunctionMeta{FuncID: funcKey, Api: api.FaaSApi} + option := api.InvokeOptions{ + SchedulerFunctionID: faasscheduler.Proxy.GetSchedulerFuncKey(), + SchedulerInstanceIDs: []string{schedulerID}, + TraceID: request.TraceID, + Timeout: 120, // 120 seconds + AcquireTimeout: 120, + } + stateMgr := ch.getStateMgr() + if stateMgr == nil { + loggerWith.Errorf("stateMgr is nil!!") + return errors.New("stateMgr is nil") + } + lease, err := ch.sdkClient.AcquireInstance(request.StateID, functionMeta, option) + loggerWith.Debugf("get lease err = %v %T", err, err) + if snErr, ok := err.(api.ErrorInfo); ok { + loggerWith.Errorf("invoke funcKey %s err type is SNError, code is %d, msg is %s", + funcKey, snErr.Code, snErr.Error()) + if isFaaSSchedulerStateErrorCode(snErr.Code) { + stateMgr.delInstance(funcKey, request.StateID) + return err + } + } + if err != nil { + leasePtr := stateMgr.getInstance(funcKey, request.StateID) + if leasePtr == nil { + loggerWith.Errorf("failed to get lease, err: %s", err.Error()) + return err + } + lease = *leasePtr + loggerWith.Warnf("failed to get lease, err: %s, to use cached lease %s", err.Error(), lease.InstanceID) + } else { + stateMgr.addInstance(funcKey, request.StateID, &lease) + loggerWith.Infof("aquire lease ok: %s", lease.InstanceID) + } + + go func() { + wait := make(chan struct{}, 1) + response := types.GetFutureResponse{ + StatusCode: constants.NoneError, + } + defer func() { + ch.RLock() + futureCh, exist := ch.futureMap[objectID] + ch.RUnlock() + if !exist { + loggerWith.Errorf("future channel for objectID %s doesn't exist for function %s traceID %s", + ch.funcSpec.FuncMetaData.FunctionName, request.TraceID) + return + } + if futureCh != nil { + futureCh <- response + } else { + loggerWith.Errorf("futureCh is nil") + } + }() + option.Timeout = int(request.Timeout) + returnObjectID, invokeErr := ch.sdkClient.InvokeByInstanceId(functionMeta, lease.InstanceID, arg, option) + if invokeErr != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("invoke funcKey %s,return error %s", funcKey, invokeErr.Error()) + return + } + ch.sdkClient.GetAsync(returnObjectID, func(result []byte, cbErr error) { + defer func() { + wait <- struct{}{} + }() + loggerWith.Infof("finish invoke, err: %v, result len = %d", cbErr, len(result)) + if cbErr != nil { + if snErr, ok := cbErr.(api.ErrorInfo); ok { + loggerWith.Errorf("invoke funcKey %s err type is SNError, code is %d, msg is %s", + funcKey, snErr.Code, snErr.Error()) + response.StatusCode = snErr.Code + response.ErrorMessage = snErr.Error() + } else { + loggerWith.Errorf("invoke funcKey %s err type is not SNError: %s", funcKey, err.Error()) + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("invoke funcKey %s,return error %s", funcKey, err.Error()) + } + return + } + loggerWith.Debugf("inovke result: %v", result) + handleGetFutureResponse(result, &response, ch.disableAPIGFormat) + }) + <-wait + ch.sdkClient.ReleaseInstance(lease, "", false, option) + }() + return nil +} + +func (ch *CustomContainerHandler) handleStateNew(w http.ResponseWriter, r *http.Request) { + logger.GetLogger().Infof("handle state new from function %s", ch.funcSpec.FuncMetaData.FunctionName) + response := types.StateResponse{ + StatusCode: constants.NoneError, + } + defer handleResponse(&response, w, stateNewRoute) + data, err := ioutil.ReadAll(r.Body) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to read request body error %s", err.Error()) + return + } + logger.GetLogger().Debugf("state new req data is %s", data) + request := types.StateRequest{} + err = json.Unmarshal(data, &request) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to unmarshal request body error %s", err.Error()) + return + } + logger := logger.GetLogger().With(zap.Any("traceID", request.TraceID)) + err = ch.processState(request, "new") + if err != nil { + if snErr, ok := err.(api.ErrorInfo); ok { + logger.Errorf("state new rsp err type is SNError, code is %d, msg is %s", + snErr.Code, snErr.Error()) + response.StatusCode = snErr.Code + response.ErrorMessage = snErr.Error() + } else { + logger.Errorf("state new req %v rsp err type is not SNError", request) + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed process state new error %s", err.Error()) + } + return + } + response.StateID = request.StateID +} + +func (ch *CustomContainerHandler) handleStateGet(w http.ResponseWriter, r *http.Request) { + logger.GetLogger().Infof("handle state get from function %s", ch.funcSpec.FuncMetaData.FunctionName) + response := types.StateResponse{ + StatusCode: constants.NoneError, + } + defer handleResponse(&response, w, stateGetRoute) + data, err := ioutil.ReadAll(r.Body) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to read request body error %s", err.Error()) + return + } + logger.GetLogger().Debugf("state get req data is %s", data) + request := types.StateRequest{} + err = json.Unmarshal(data, &request) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to unmarshal request body error %s", err.Error()) + return + } + logger := logger.GetLogger().With(zap.Any("traceID", request.TraceID)) + err = ch.processState(request, "get") + if err != nil { + if snErr, ok := err.(api.ErrorInfo); ok { + logger.Errorf("state get rsp err type is SNError, code is %d, msg is %s", + snErr.Code, snErr.Error()) + response.StatusCode = snErr.Code + response.ErrorMessage = snErr.Error() + } else { + logger.Errorf("state get req %v rsp err type is not SNError", request) + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed process state new error %s", err.Error()) + } + return + } + response.StateID = request.StateID +} + +func (ch *CustomContainerHandler) handleStateDel(w http.ResponseWriter, r *http.Request) { + logger.GetLogger().Infof("handle function invoke from function %s", ch.funcSpec.FuncMetaData.FunctionName) + response := types.TerminateResponse{ + StatusCode: constants.NoneError, + } + defer handleResponse(&response, w, stateDeleteRoute) + data, err := ioutil.ReadAll(r.Body) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to read request body error %s", err.Error()) + return + } + logger.GetLogger().Infof("state del req data is %s", data) + request := types.StateRequest{} + err = json.Unmarshal(data, &request) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to unmarshal request body error %s", err.Error()) + return + } + logger := logger.GetLogger().With(zap.Any("traceID", request.TraceID)) + err = ch.processState(request, "del") + logger.Infof("processState del err: %v, %T", err, err) + if err != nil { + if snErr, ok := err.(api.ErrorInfo); ok { + logger.Errorf("state del rsp err type is SNError, code is %d, msg is %s", + snErr.Code, snErr.Error()) + response.StatusCode = snErr.Code + response.ErrorMessage = snErr.Error() + } else { + logger.Errorf("state del rsp err type is not SNError") + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed process state new error %s", err.Error()) + } + return + } + objectID := uuid.New().String() + furtureRspChan := make(chan types.GetFutureResponse, 1) + ch.Lock() + ch.futureMap[objectID] = furtureRspChan + ch.Unlock() + logger.Infof("set object %s to futureMap", objectID) + response.ObjectID = objectID + furtureRspChan <- types.GetFutureResponse{ + StatusCode: constants.NoneError, + Content: `{"result": "DELETE SUCCESSFULLY"}`, + } +} + +func (ch *CustomContainerHandler) handleExit(w http.ResponseWriter, r *http.Request) { + logger.GetLogger().Infof("handle exit") + body, err := io.ReadAll(r.Body) + if err != nil { + logger.GetLogger().Errorf("read body: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + logger.GetLogger().Infof("get exit body %s", string(body)) + req := &types.ExitRequest{} + if len(body) > 0 { + err = json.Unmarshal(body, req) + if err != nil { + logger.GetLogger().Errorf("unmarshal exit body failed: %v", err) + } + } + response := types.ExitResponse{ + StatusCode: constants.NoneError, + } + handleResponse(&response, w, exitRoute) + go ch.sdkClient.Exit(req.Code, req.Message) +} + +func (ch *CustomContainerHandler) handleCredentialRequire(w http.ResponseWriter, r *http.Request) { + logger.GetLogger().Infof("handle credential require") + response := types.CredentialResponse{ + StatusCode: constants.NoneError, + } + credential := ch.sdkClient.GetCredential() + response.Credential = credential + handleResponse(&response, w, credentialRequireRoute) + +} + +func (ch *CustomContainerHandler) handleCircuitBreak(w http.ResponseWriter, r *http.Request) { + logger.GetLogger().Infof("handle circuit break") + body, err := io.ReadAll(r.Body) + if err != nil { + logger.GetLogger().Errorf("read body: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + req := &types.CircuitBreakRequest{} + err = json.Unmarshal(body, req) + if err != nil { + logger.GetLogger().Errorf("unmarshal body failed: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + response := types.ExitResponse{ + StatusCode: constants.NoneError, + } + ch.circuitLock.Lock() + ch.circuitBreaker = req.Switch + ch.circuitLock.Unlock() + logger.GetLogger().Infof("circuit break flag set to %v", req.Switch) + handleResponse(&response, w, circuitBreakRoute) +} + +func (ch *CustomContainerHandler) releaseState(funcKey, stateID, traceID string) error { + schedulerID, err := faasscheduler.Proxy.Get(funcKey) + if err != nil { + return err + } + option := api.InvokeOptions{ + SchedulerFunctionID: faasscheduler.Proxy.GetSchedulerFuncKey(), + SchedulerInstanceIDs: []string{schedulerID}, + TraceID: traceID, + } + instanceAllocation := api.InstanceAllocation{ + FuncKey: funcKey, + } + ch.sdkClient.ReleaseInstance(instanceAllocation, fmt.Sprintf("%s;%s", funcKey, stateID), + false, option) + return nil +} + +func (ch *CustomContainerHandler) processState(request types.StateRequest, opType string) error { + funcKey, _, err := ch.getInvokeFuncKey(request.FuncName, request.FuncVersion, request.Params) + if err != nil { + return err + } + if request.StateID == "" { + return api.ErrorInfo{ + Code: InvalidState, + Err: fmt.Errorf(InvalidStateErrMsg), + } + } + stateMgr := ch.getStateMgr() + if stateMgr == nil { + logger.GetLogger().Errorf("stateMgr is nil!!") + return errors.New("stateMgr is nil") + } + switch opType { + case "new": + err = stateMgr.newState(funcKey, request.StateID, request.TraceID) + case "get": + err = stateMgr.getState(funcKey, request.StateID, request.TraceID) + case "del": + err = stateMgr.delState(funcKey, request.StateID, request.TraceID) + if err == nil { + logger.GetLogger().Infof("releaseState to faasscheduler") + err = ch.releaseState(funcKey, request.StateID, request.TraceID) + if err != nil { + alarmDetail := &alarm.Detail{ + SourceTag: os.Getenv(constants.PodNameEnvKey) + "|" + os.Getenv(constants.PodIPEnvKey) + + "|" + os.Getenv(constants.ClusterName) + "|" + os.Getenv(constants.HostIPEnvKey), + OpType: alarm.GenerateAlarmLog, + Details: fmt.Sprintf("terminate State failed, faasscheduler error: %v, "+ + "stateKey: %s, statefuncKey: %s", err, request.StateID, funcKey), + StartTimestamp: int(time.Now().Unix()), + EndTimestamp: 0, + } + alarmInfo := &alarm.LogAlarmInfo{ + AlarmID: "TerminateStateForFaasScheduler00001", + AlarmName: "TerminateStateForFaasScheduler", + AlarmLevel: alarm.Level2, + } + alarm.ReportOrClearAlarm(alarmInfo, alarmDetail) + } + logger.GetLogger().Debugf("releaseState to faasscheduler, err: %v, %T", err, err) + } + default: + err = errors.New(fmt.Sprintf("unknow opType %s", opType)) + } + return err +} + +func checkTrafficLimitResp(notifyMsg []byte) bool { + if notifyMsg != nil && len(notifyMsg) != 0 { + var insResponse struct { + ErrorCode int `json:"errorCode"` + ErrorMessage string `json:"errorMessage"` + } + if unMarshalErr := json.Unmarshal(notifyMsg, &insResponse); unMarshalErr != nil { + logger.GetLogger().Errorf("unmarshal notifyMsg error : %s", unMarshalErr.Error()) + } + // current faasscheduler has reached instance limit, should retry and chose another faasscheduler + return insResponse.ErrorCode == constants.AcquireLeaseTrafficLimitErrorCode + } + return false +} + +func handleGetFutureResponse(result []byte, response *types.GetFutureResponse, disableAPIGFormat bool) { + callResponse := &types.CallResponse{} + err := json.Unmarshal(result, callResponse) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("call response unmarshal error %s", err.Error()) + return + } + if len(callResponse.InnerCode) != 0 { + innerCode, err := strconv.Atoi(callResponse.InnerCode) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to get the innerCode, err: %s", err) + return + } + if innerCode != constants.NoneError { + response.StatusCode = innerCode + response.ErrorMessage = string(callResponse.Body) + return + } + } + if disableAPIGFormat { + response.Content = string(callResponse.Body) + logger.GetLogger().Infof("set rsp content: %s", response.Content) + return + } + apigResponse := APIGTriggerResponse{} + err = json.Unmarshal(callResponse.Body, &apigResponse) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to unmarshal http response error %s", err.Error()) + return + } + content := apigResponse.Body + if apigResponse.IsBase64Encoded { + decodeContent, err := base64.StdEncoding.DecodeString(apigResponse.Body) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to decode http response error %s", err.Error()) + return + } + content = string(decodeContent) + } + if apigResponse.StatusCode != http.StatusOK { + response.StatusCode = constants.FunctionRunError + response.ErrorMessage = content + } else { + response.Content = content + logger.GetLogger().Infof("set rsp content: %s", response.Content) + } +} + +func prepareInvokeArg(request types.InvokeRequest, disableAPIGFormat bool) ([]api.Arg, error) { + requestData, err := getCallRequestBody(request, disableAPIGFormat) + if err != nil { + return nil, err + } + callRequest := &types.CallRequest{ + Body: requestData, + Header: map[string]string{ + constants.CaaSTraceIDHeaderKey: request.TraceID, + constants.CffRequestIDHeaderKey: request.TraceID, + "Content-Type": "application/json", + }, + } + data, err := json.Marshal(callRequest) + if err != nil { + return nil, err + } + arg := []api.Arg{ + { + Type: api.Value, + Data: []byte(request.TraceID), + }, + { + Type: api.Value, + Data: data, + }, + } + return arg, nil +} + +func getCallRequestBody(request types.InvokeRequest, disableAPIGFormat bool) ([]byte, error) { + if disableAPIGFormat { + return []byte(request.Payload), nil + } + apigRequestData, err := buildAPIGRequestData(request.Payload) + if err != nil { + return nil, errors.New(fmt.Sprintf("failed to build APIG request error %s", err.Error())) + } + return apigRequestData, nil +} + +func (ch *CustomContainerHandler) getInvokeFuncKey(funcName, funcVer string, + params map[string]string) (string, string, error) { + funcUrnPrefixIndex := strings.LastIndex(ch.funcSpec.FuncMetaData.FunctionVersionURN, funcNameSeparator) + funcUrnPathPrefix := ch.funcSpec.FuncMetaData.FunctionVersionURN[:funcUrnPrefixIndex] + aliasUrn := fmt.Sprintf("%s@%s:%s", funcUrnPathPrefix, funcName, funcVer) + funcUrn := aliasroute.GetAliases().GetFuncVersionURNWithParams(aliasUrn, params) + functionInfo, err := urnutils.GetFunctionInfo(funcUrn) + if err != nil { + return "", "", err + } + funcKey := urnutils.CombineFunctionKey(functionInfo.TenantID, functionInfo.Name, functionInfo.Version) + return funcKey, funcUrn, err +} + +func (ch *CustomContainerHandler) handleGetFuture(w http.ResponseWriter, r *http.Request) { + response := types.GetFutureResponse{} + defer handleResponse(&response, w, getFutureRoute) + data, err := ioutil.ReadAll(r.Body) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to read future request body error %s", err.Error()) + return + } + logger.GetLogger().Infof("state getfuture req data is %s", data) + request := types.GetFutureRequest{} + err = json.Unmarshal(data, &request) + if err != nil { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("failed to read future request body error %s", err.Error()) + return + } + logger.GetLogger().Infof("getfuture objID is %s", request.ObjectID) + ch.RLock() + futureCh, exist := ch.futureMap[request.ObjectID] + ch.RUnlock() + if !exist { + response.StatusCode = constants.FaaSError + response.ErrorMessage = fmt.Sprintf("objectID %s doesn't exit", request.ObjectID) + return + } + if futureCh != nil { + response = <-futureCh + } + ch.Lock() + delete(ch.futureMap, request.ObjectID) + ch.Unlock() +} + +func handleResponse(response types.Response, w http.ResponseWriter, handleType string) { + logger.GetLogger().Infof("handle %s response code %d message %s", handleType, response.GetStatusCode(), + response.GetErrorMessage()) + if response.GetStatusCode() != constants.NoneError { + logger.GetLogger().Errorf("handle %s error %s", handleType, response.GetErrorMessage()) + } + data, err := json.Marshal(response) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + logger.GetLogger().Errorf("failed to marshal %s response error %s", handleType, err.Error()) + if _, err = w.Write([]byte("failed to marshal response")); err != nil { + logger.GetLogger().Errorf("failed to write marshal error %s", err.Error()) + } + return + } + if _, err = w.Write(data); err != nil { + logger.GetLogger().Errorf("failed to write %s response data error %s", handleType, err.Error()) + } +} + +func buildAPIGRequestData(payload string) ([]byte, error) { + apigRequest := APIGTriggerEvent{ + Path: customContainerCallPath, + Body: payload, + IsBase64Encoded: false, + } + data, err := json.Marshal(apigRequest) + if err != nil { + return nil, err + } + return data, nil +} + +// HealthCheckHandler health check +func (ch *CustomContainerHandler) HealthCheckHandler() (api.HealthType, + error) { + code := api.Healthy + ch.circuitLock.RLock() + defer ch.circuitLock.RUnlock() + if ch.circuitBreaker { + code = api.SubHealth + return code, nil + } + return code, nil +} diff --git a/api/go/faassdk/handler/http/custom_container_handler_test.go b/api/go/faassdk/handler/http/custom_container_handler_test.go new file mode 100644 index 0000000..c575622 --- /dev/null +++ b/api/go/faassdk/handler/http/custom_container_handler_test.go @@ -0,0 +1,1937 @@ +package http + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "reflect" + "strings" + "testing" + "time" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + FakeHTTP "github.com/stretchr/testify/http" + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + + "yuanrong.org/kernel/runtime/faassdk/common/alarm" + "yuanrong.org/kernel/runtime/faassdk/common/faasscheduler" + "yuanrong.org/kernel/runtime/faassdk/common/monitor" + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/faassdk/handler/http/crossclusterinvoke" + "yuanrong.org/kernel/runtime/faassdk/sts" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/faassdk/utils/urnutils" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/uuid" +) + +type fakeSnError struct { + ErrorCode int + ErrorMessage string +} + +// Code Returned error code +func (s *fakeSnError) Code() int { + return s.ErrorCode +} + +// Error Implement the native error interface. +func (s *fakeSnError) Error() string { + return s.ErrorMessage +} + +type fakeSDKClient struct { + returnErr bool +} + +func (f *fakeSDKClient) CreateInstance(funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + InstanceID := uuid.New().String() + return InstanceID, nil +} + +func (f *fakeSDKClient) InvokeByInstanceId(funcMeta api.FunctionMeta, instanceID string, args []api.Arg, invokeOpt api.InvokeOptions) (returnObjectID string, err error) { + //TODO implement me + return "", nil +} + +func (f *fakeSDKClient) InvokeByFunctionName(funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + return "success", nil +} + +func (f *fakeSDKClient) AcquireInstance(state string, funcMeta api.FunctionMeta, acquireOpt api.InvokeOptions) (api.InstanceAllocation, error) { + return api.InstanceAllocation{}, nil +} + +func (f *fakeSDKClient) ReleaseInstance(allocation api.InstanceAllocation, stateID string, abnormal bool, option api.InvokeOptions) { + //TODO implement me + return +} + +func (f *fakeSDKClient) Kill(instanceID string, signal int, payload []byte) error { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) CreateInstanceRaw(createReqRaw []byte) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) InvokeByInstanceIdRaw(invokeReqRaw []byte) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) KillRaw(killReqRaw []byte) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) Finalize() { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) KVSet(key string, value []byte, param api.SetParam) error { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) KVSetWithoutKey(value []byte, param api.SetParam) (string, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) KVMSetTx(keys []string, values [][]byte, param api.MSetParam) error { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) KVGet(key string, timeoutms uint) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) KVGetMulti(keys []string, timeoutms uint) ([][]byte, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) KVQuerySize(keys []string) ([]uint64, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) KVDel(key string) error { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) KVDelMulti(keys []string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) CreateProducer(streamName string, producerConf api.ProducerConf) (api.StreamProducer, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) Subscribe(streamName string, config api.SubscriptionConfig) (api.StreamConsumer, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) DeleteStream(streamName string) error { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) QueryGlobalProducersNum(streamName string) (uint64, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) QueryGlobalConsumersNum(streamName string) (uint64, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) SetTraceID(traceID string) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) SetTenantID(tenantID string) error { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) Put(objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string) error { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) PutRaw(objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string) error { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) Get(objectIDs []string, timeoutMs int) ([][]byte, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) GetRaw(objectIDs []string, timeoutMs int) ([][]byte, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) Wait(objectIDs []string, waitNum uint64, timeoutMs int) ([]string, []string, map[string]error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) GIncreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) GIncreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) GDecreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + return nil, nil +} + +func (f *fakeSDKClient) GDecreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + panic("implement me") +} + +func (f *fakeSDKClient) GetAsync(objectID string, cb api.GetAsyncCallback) { + cb([]byte("success"), nil) +} + +func (f *fakeSDKClient) GetFormatLogger() api.FormatLogger { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) CreateClient(config api.ConnectArguments) (api.KvClient, error) { + return &FakeDataSystemClinet{}, nil +} + +func (f *fakeSDKClient) ReleaseGRefs(remoteClientID string) error { + //TODO implement me + panic("implement me") +} + +func (f *fakeSDKClient) SaveState(state []byte) (string, error) { + return "", nil +} + +func (f *fakeSDKClient) LoadState(checkpointID string) ([]byte, error) { + return nil, nil +} + +func (f *fakeSDKClient) Exit(code int, message string) { + return +} + +func (f *fakeSDKClient) GetCredential() api.Credential { + return api.Credential{} +} + +func (f *fakeSDKClient) UpdateSchdulerInfo(schedulerName string, schedulerId string, option string) { + return +} + +func (f *fakeSDKClient) IsHealth() bool { + return true +} + +func (f *fakeSDKClient) IsDsHealth() bool { + return true +} + +func newFuncSpec() *types.FuncSpec { + return &types.FuncSpec{ + FuncMetaData: types.FuncMetaData{ + FunctionName: "test-future-fuction", + Runtime: "go1.x", + TenantId: "123456789", + Version: "$latest", + Timeout: 3, + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:0@yrservice@test-image-env", + }, + ResourceMetaData: types.ResourceMetaData{}, + ExtendedMetaData: types.ExtendedMetaData{ + Initializer: types.Initializer{Timeout: 3}, + LogTankService: types.LogTankService{ + GroupID: "gid", + StreamID: "sid", + }, + CustomGracefulShutdown: types.CustomGracefulShutdown{ + MaxShutdownTimeout: 5, + }, + }, + } +} + +func newFuncSpecWithHealthCheck() *types.FuncSpec { + return &types.FuncSpec{ + FuncMetaData: types.FuncMetaData{ + FunctionName: "test-future-fuction", + Runtime: "go1.x", + TenantId: "123456789", + Version: "$latest", + Timeout: 3, + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:0@yrservice@test-image-env", + }, + ResourceMetaData: types.ResourceMetaData{}, + ExtendedMetaData: types.ExtendedMetaData{ + Initializer: types.Initializer{Timeout: 3}, + CustomHealthCheck: types.CustomHealthCheck{ + TimeoutSeconds: 30, + PeriodSeconds: 5, + FailureThreshold: 1, + }, + }, + } +} + +func newFuncSpecWithoutTimeout() *types.FuncSpec { + return &types.FuncSpec{ + FuncMetaData: types.FuncMetaData{ + FunctionName: "test-future-fuction", + Runtime: "go1.x", + TenantId: "123456789", + Version: "$latest", + Timeout: 0, + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:0@yrservice@test-image-env", + }, + ResourceMetaData: types.ResourceMetaData{}, + ExtendedMetaData: types.ExtendedMetaData{ + Initializer: types.Initializer{Timeout: 0}, + LogTankService: types.LogTankService{ + GroupID: "gid", + StreamID: "sid", + }, + CustomGracefulShutdown: types.CustomGracefulShutdown{ + MaxShutdownTimeout: 5, + }, + }, + } +} + +func newHttpCreateParams() *types.HttpCreateParams { + return &types.HttpCreateParams{ + Port: 33333, + InitRoute: "initRoute", + CallRoute: "callRote", + } +} + +func newHttpSchedulerParams() *faasscheduler.SchedulerInfo { + return &faasscheduler.SchedulerInfo{ + SchedulerFuncKey: "scheduler func key", + SchedulerIDList: []string{"1111"}, + } +} + +func TestNewCustomContainerHandler(t *testing.T) { + convey.Convey("NewCustomContainerHandler", t, func() { + handler := NewCustomContainerHandler(newFuncSpec(), nil) + convey.So(handler, convey.ShouldNotBeNil) + }) +} + +func TestHandleGetFutureResponse(t *testing.T) { + convey.Convey("Test Handle GetFutureResponse", t, func() { + convey.Convey("call response unmarshal error", func() { + response := &types.GetFutureResponse{} + handleGetFutureResponse([]byte("result"), response, false) + convey.So(response.StatusCode, convey.ShouldEqual, constants.FaaSError) + }) + convey.Convey("failed to get the innerCode", func() { + callResponse := &types.CallResponse{ + InnerCode: "aaa", + } + result, _ := json.Marshal(callResponse) + response := &types.GetFutureResponse{} + handleGetFutureResponse(result, response, false) + convey.So(response.StatusCode, convey.ShouldEqual, constants.FaaSError) + }) + convey.Convey("inner code not zero", func() { + callResponse := &types.CallResponse{ + InnerCode: "1", + } + result, _ := json.Marshal(callResponse) + response := &types.GetFutureResponse{} + handleGetFutureResponse(result, response, false) + convey.So(response.StatusCode, convey.ShouldEqual, 1) + }) + convey.Convey("unmarshal http response error", func() { + callResponse := &types.CallResponse{ + InnerCode: "0", + Body: []byte("aaa"), + } + result, _ := json.Marshal(callResponse) + response := &types.GetFutureResponse{} + handleGetFutureResponse(result, response, false) + convey.So(response.StatusCode, convey.ShouldEqual, constants.FaaSError) + }) + convey.Convey("response ok", func() { + apigResponse := APIGTriggerResponse{ + IsBase64Encoded: true, + Body: "b2s=", + StatusCode: http.StatusOK, + } + body, _ := json.Marshal(apigResponse) + callResponse := &types.CallResponse{ + InnerCode: "0", + Body: body, + } + result, _ := json.Marshal(callResponse) + response := &types.GetFutureResponse{} + handleGetFutureResponse(result, response, false) + convey.So(response.Content, convey.ShouldEqual, "ok") + }) + convey.Convey("response not ok", func() { + apigResponse := APIGTriggerResponse{ + IsBase64Encoded: true, + Body: "b2s=", + StatusCode: http.StatusBadRequest, + } + body, _ := json.Marshal(apigResponse) + callResponse := &types.CallResponse{ + InnerCode: "0", + Body: body, + } + result, _ := json.Marshal(callResponse) + response := &types.GetFutureResponse{} + handleGetFutureResponse(result, response, false) + convey.So(response.StatusCode, convey.ShouldEqual, constants.FunctionRunError) + }) + convey.Convey("response base64 decode failed", func() { + apigResponse := APIGTriggerResponse{ + IsBase64Encoded: true, + Body: "????", + StatusCode: http.StatusOK, + } + body, _ := json.Marshal(apigResponse) + callResponse := &types.CallResponse{ + InnerCode: "0", + Body: body, + } + result, _ := json.Marshal(callResponse) + response := &types.GetFutureResponse{} + handleGetFutureResponse(result, response, false) + convey.So(response.StatusCode, convey.ShouldEqual, constants.FaaSError) + }) + }) +} + +func TestHandleInvoke(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + config: &config.Configuration{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + convey.Convey("Test Handle Invoke", t, func() { + convey.Convey("read request body error", func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte{}, errors.New("read error") + }) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(w.StatusCode, convey.ShouldEqual, http.StatusOK) + }) + convey.Convey("unmarshal request body error", func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte("aaa"), nil + }) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(w.StatusCode, convey.ShouldEqual, http.StatusOK) + }) + convey.Convey("process invoke error", func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + invokeRequest := types.InvokeRequest{} + data, _ := json.Marshal(invokeRequest) + return data, nil + }) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(w.StatusCode, convey.ShouldEqual, http.StatusOK) + }) + convey.Convey("process invoke success", func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + invokeRequest := types.InvokeRequest{} + data, _ := json.Marshal(invokeRequest) + return data, nil + }) + faasscheduler.Proxy.Add("faasScheduler1") + defer patch.Reset() + defer faasscheduler.Proxy.Remove("faasScheduler1") + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(w.StatusCode, convey.ShouldEqual, http.StatusOK) + }) + convey.Convey("process invoke kernel error", func() { + ch.sdkClient = &fakeSDKClient{returnErr: true} + defer gomonkey.ApplyMethod(reflect.TypeOf(ch.sdkClient), "InvokeByFunctionName", + func(_ *fakeSDKClient, funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + return "123", errors.New("kernel error") + }).Reset() + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + invokeRequest := types.InvokeRequest{} + data, _ := json.Marshal(invokeRequest) + return data, nil + }) + faasscheduler.Proxy.Add("faasScheduler1") + defer patch.Reset() + defer faasscheduler.Proxy.Remove("faasScheduler1") + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + for id, _ := range ch.futureMap { + rep := <-ch.futureMap[id] + convey.So(rep, convey.ShouldNotBeNil) + } + }) + }) +} + +func TestHandleGetFuture(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + convey.Convey("Test Handle Get Future", t, func() { + convey.Convey("read request body error", func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte{}, errors.New("read error") + }) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleGetFuture(w, r) + convey.So(w.StatusCode, convey.ShouldEqual, http.StatusOK) + }) + convey.Convey("unmarshal request body error", func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte("aaa"), nil + }) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleGetFuture(w, r) + convey.So(w.StatusCode, convey.ShouldEqual, http.StatusOK) + }) + convey.Convey("process futureCh not exit", func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + invokeRequest := types.InvokeRequest{} + data, _ := json.Marshal(invokeRequest) + return data, nil + }) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleGetFuture(w, r) + convey.So(w.StatusCode, convey.ShouldEqual, http.StatusOK) + }) + }) +} + +func TestInitHandler(t *testing.T) { + convey.Convey("InitHandler", t, func() { + handler := NewCustomContainerHandler(newFuncSpec(), nil) + convey.So(handler, convey.ShouldNotBeNil) + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(monitor.CreateFunctionMonitor, func(funcSpec *types.FuncSpec, stopCh chan struct{}) (*monitor.FunctionMonitorManager, error) { + return &monitor.FunctionMonitorManager{}, nil + }), + gomonkey.ApplyFunc(http.HandleFunc, func(pattern string, handler func(http.ResponseWriter, *http.Request)) { + return + }), + gomonkey.ApplyMethod(reflect.TypeOf(&http.Server{}), "ListenAndServe", func(_ *http.Server) error { + return nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + convey.Convey("failed to parse create params: invalid create params number", func() { + args := []api.Arg{} + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldNotBeEmpty) + convey.So(res, convey.ShouldEqual, []byte{}) + }) + + convey.Convey("failed to parse create params: failed to unmarshal funcSpec from", func() { + args := []api.Arg{ + { + Type: 0, + Data: []byte("invalid json"), + }, + { + Type: 0, + Data: []byte{}, + }, + } + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldNotBeEmpty) + convey.So(res, convey.ShouldEqual, []byte{}) + }) + convey.Convey("test initExecuteWithTimeout", func() { + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: 0, + Data: []byte(""), + }, + { + Type: 0, + Data: []byte(""), + }, + } + convey.Convey("failed to test ExecuteTimeout cause timeout", func() { + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldNotBeEmpty) + convey.So(res, convey.ShouldEqual, []byte{}) + }) + convey.Convey("succeed to test ExecuteTimeout", func() { + defer gomonkey.ApplyFunc(net.Dial, func(_, _ string) (net.Conn, error) { + n := tls.Conn{} + return &n, nil + }).Reset() + defer gomonkey.ApplyFunc((*tls.Conn).Close, func(_ *tls.Conn) error { + return nil + }).Reset() + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldNotBeEmpty) + convey.So(res, convey.ShouldBeEmpty) + }) + }) + + convey.Convey("failed to create request to init route error", func() { + defer gomonkey.ApplyFunc(http.NewRequestWithContext, + func(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + return nil, errors.New("NewRequestWithContext error") + }).Reset() + defer gomonkey.ApplyFunc((*CustomContainerHandler).awaitReady, func(_ *CustomContainerHandler, timeout time.Duration) error { + return nil + }).Reset() + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: 0, + Data: []byte(""), + }, + { + Type: 0, + Data: []byte(""), + }, + } + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldNotBeEmpty) + convey.So(res, convey.ShouldBeEmpty) + }) + + convey.Convey("init failed request timed out after 3s", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&http.Client{}), "Do", + func(_ *http.Client, req *http.Request) (*http.Response, error) { + time.Sleep(3*time.Second + 10*time.Millisecond) + return nil, errors.New("timeout 3s") + }).Reset() + defer gomonkey.ApplyFunc((*CustomContainerHandler).awaitReady, func(_ *CustomContainerHandler, timeout time.Duration) error { + return nil + }).Reset() + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: 0, + Data: []byte(""), + }, + { + Type: 0, + Data: []byte(""), + }, + } + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldNotBeEmpty) + convey.So(res, convey.ShouldBeEmpty) + }) + + convey.Convey("init failed request error", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&http.Client{}), "Do", + func(_ *http.Client, req *http.Request) (*http.Response, error) { + return nil, errors.New("mock init request error") + }).Reset() + defer gomonkey.ApplyFunc((*CustomContainerHandler).awaitReady, func(_ *CustomContainerHandler, timeout time.Duration) error { + return nil + }).Reset() + handlerWithHealthCheck := NewCustomContainerHandler(newFuncSpecWithHealthCheck(), &fakeSDKClient{}) + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: 0, + Data: []byte(""), + }, + { + Type: 0, + Data: []byte(""), + }, + } + res, err := handlerWithHealthCheck.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldNotBeEmpty) + convey.So(res, convey.ShouldBeEmpty) + }) + + convey.Convey("success", func() { + dsClient := &fakeSDKClient{} + patches := gomonkey.NewPatches() + patches.ApplyFunc(sts.InitStsSDK, func(serverCfg types.StsServerConfig) error { + return nil + }) + patches.ApplyFunc(stsgoapi.DecryptSensitiveConfig, func(rawConfigValue string) (plainBytes []byte, err error) { + return []byte{}, nil + }) + patches.ApplyMethod(reflect.TypeOf(dsClient), "CreateClient", func(_ *fakeSDKClient, config api.ConnectArguments) (api.KvClient, error) { + return &FakeDataSystemClinet{}, nil + }) + patches.ApplyMethod(reflect.TypeOf(&basicHandler{}), "InitHandler", func(_ *basicHandler, args []api.Arg, apiClient api.LibruntimeAPI) ([]byte, error) { + return nil, nil + }) + defer patches.Reset() + + defer gomonkey.ApplyMethod(reflect.TypeOf(&http.Client{}), "Do", + func(_ *http.Client, req *http.Request) (*http.Response, error) { + return nil, nil + }).Reset() + defer gomonkey.ApplyFunc((*CustomContainerHandler).awaitReady, func(_ *CustomContainerHandler, timeout time.Duration) error { + return nil + }).Reset() + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: 0, + Data: []byte("{}"), + }, + { + Type: 0, + Data: []byte("{}"), + }, + } + res, _ := handler.InitHandler(args, dsClient) + convey.So(res, convey.ShouldBeEmpty) + }) + + convey.Convey("DataSystem client init exception", func() { + dsClient := &fakeSDKClient{} + patches := gomonkey.NewPatches() + patches.ApplyMethod(reflect.TypeOf(&basicHandler{}), "InitHandler", func(_ *basicHandler, args []api.Arg, apiClient api.LibruntimeAPI) ([]byte, error) { + return nil, nil + }) + defer patches.Reset() + + defer gomonkey.ApplyMethod(reflect.TypeOf(&http.Client{}), "Do", + func(_ *http.Client, req *http.Request) (*http.Response, error) { + return nil, nil + }).Reset() + defer gomonkey.ApplyFunc((*CustomContainerHandler).awaitReady, func(_ *CustomContainerHandler, timeout time.Duration) error { + return nil + }).Reset() + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: 0, + Data: []byte("{}"), + }, + { + Type: 0, + Data: []byte("{}"), + }, + } + patch0 := gomonkey.ApplyFunc(sts.InitStsSDK, func(serverCfg types.StsServerConfig) error { + return errors.New("mock error for sts.InitStsSDK") + }) + defer patch0.Reset() + res, _ := handler.InitHandler(args, dsClient) + convey.So(res, convey.ShouldBeEmpty) + + patch0.Reset() + patch0.ApplyFunc(sts.InitStsSDK, func(serverCfg types.StsServerConfig) error { + return nil + }) + + patch1 := gomonkey.ApplyFunc(stsgoapi.DecryptSensitiveConfig, func(rawConfigValue string) (plainBytes []byte, err error) { + return []byte{}, errors.New("mock error for stsgoapi.DecryptSensitiveConfig") + }) + defer patch1.Reset() + + res, _ = handler.InitHandler(args, dsClient) + convey.So(res, convey.ShouldBeEmpty) + + patch1.Reset() + patch1.ApplyFunc(stsgoapi.DecryptSensitiveConfig, func(rawConfigValue string) (plainBytes []byte, err error) { + return []byte{}, nil + }) + + defer gomonkey.ApplyMethod(reflect.TypeOf(&fakeSDKClient{}), "CreateClient", func(_ *fakeSDKClient, config api.ConnectArguments) (api.KvClient, error) { + return &FakeDataSystemClinet{}, errors.New("mock error for FakeDataSystemClinet.CreateClient") + }).Reset() + + res, _ = handler.InitHandler(args, dsClient) + convey.So(res, convey.ShouldBeEmpty) + }) + }) +} + +func TestShutDownHander(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + invokeServer: &http.Server{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + convey.Convey("Test ShutDown Hander", t, func() { + convey.Convey("shutdown success", func() { + h := &basicHandler{} + patch := gomonkey.ApplyMethod(reflect.TypeOf(h), + "ShutDownHandler", func(b *basicHandler, gracePeriodSecond uint64) error { + return nil + }) + defer patch.Reset() + err := ch.ShutDownHandler(30) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestGetCallRequestBody(t *testing.T) { + request := types.InvokeRequest{ + Payload: "hello", + } + convey.Convey("Test GetCallRequestBody", t, func() { + convey.Convey("disableAPIGFormat is ture", func() { + requestData, err := getCallRequestBody(request, true) + convey.So(err, convey.ShouldBeNil) + convey.So(string(requestData), convey.ShouldEqual, "hello") + }) + convey.Convey("disableAPIGFormat is false", func() { + requestData, err := getCallRequestBody(request, false) + APIGData := APIGTriggerEvent{} + _ = json.Unmarshal(requestData, &APIGData) + convey.So(err, convey.ShouldBeNil) + convey.So(APIGData.Body, convey.ShouldEqual, "hello") + }) + }) +} + +func TestHandleResponse(t *testing.T) { + convey.Convey("Test HandleResponse", t, func() { + convey.Convey("json marshal error", func() { + patch := gomonkey.ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return []byte{}, errors.New("marshal error") + }) + defer patch.Reset() + writer := &fakeWriter{} + handleResponse(&types.GetFutureResponse{}, writer, "test") + convey.So(writer.StatusCode, convey.ShouldEqual, http.StatusInternalServerError) + }) + convey.Convey("Write error", func() { + writer := &fakeWriter{} + response := &types.GetFutureResponse{ + StatusCode: constants.NoneError, + } + handleResponse(response, writer, "test") + convey.So(response.GetStatusCode(), convey.ShouldEqual, constants.NoneError) + }) + }) +} + +type fakeWriter struct { + StatusCode int +} + +func (w *fakeWriter) Header() http.Header { + return nil +} + +func (w *fakeWriter) Write([]byte) (int, error) { + return 0, errors.New("write error") +} + +func (w *fakeWriter) WriteHeader(statusCode int) { + w.StatusCode = statusCode +} + +func TestCustomGracefulShutdownNewVer(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + invokeServer: &http.Server{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + convey.Convey("test customGracefulShutdownNewVer", t, func() { + convey.Convey("test stop by cust-image", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&http.Client{}), "Get", func(_ *http.Client, url string) (resp *http.Response, err error) { + return &http.Response{ + StatusCode: 200, + }, nil + }).Reset() + convey.So(func() { + ch.customGracefulShutdownNewVer(1 * time.Second) + }, convey.ShouldNotPanic) + }) + + convey.Convey("test stop by timeout", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&http.Client{}), "Get", func(_ *http.Client, url string) (resp *http.Response, err error) { + t.Logf("sleep 2s") + time.Sleep(2 * time.Second) + return &http.Response{ + StatusCode: 200, + }, nil + }).Reset() + convey.So(func() { + ch.customGracefulShutdownNewVer(1 * time.Second) + }, convey.ShouldNotPanic) + }) + }) +} + +func TestProcessState(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + invokeServer: &http.Server{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + + req := types.StateRequest{ + FuncName: "test-func-name", + FuncVersion: "test-func-version", + Params: nil, + StateID: "test-state-id" + uuid.New().String(), + TraceID: "test-trace-id-1", + } + + convey.Convey("func urn is invalid", t, func() { + defer gomonkey.ApplyFunc(urnutils.GetFunctionInfo, func(urn string) (urnutils.BaseURN, error) { + return urnutils.BaseURN{}, errors.New("mock urn error") + }).Reset() + err := ch.processState(req, "new") + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("state id is empty", t, func() { + req.StateID = "" + err := ch.processState(req, "new") + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("stateMgr cannot be obtained", t, func() { + ch.stateMgrState = failed + err := ch.processState(req, "new") + convey.So(err, convey.ShouldNotBeNil) + }) + + ch.stateMgrState = completed + ch.stateMgr = &stateManager{} + req.StateID = "test-state-id" + + convey.Convey("new state success", t, func() { + defer gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch.stateMgr), "newState", func(_ *stateManager, funcKey string, stateID string, traceID string) error { + return nil + }).Reset() + err := ch.processState(req, "new") + convey.So(err, convey.ShouldBeNil) + }) + + convey.Convey("new state fail due to stateMgr failure", t, func() { + defer gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch.stateMgr), "newState", func(_ *stateManager, funcKey, stateID, traceID string) error { + return errors.New("mock state mgr err") + }).Reset() + err := ch.processState(req, "new") + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("get state success", t, func() { + defer gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch.stateMgr), "getState", func(_ *stateManager, funcKey, stateID, traceID string) error { + return nil + }).Reset() + err := ch.processState(req, "get") + convey.So(err, convey.ShouldBeNil) + }) + + convey.Convey("get state fail due to stateMgr failure", t, func() { + defer gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch.stateMgr), "getState", func(_ *stateManager, funcKey, stateID, traceID string) error { + return errors.New("mock state mgr err") + }).Reset() + err := ch.processState(req, "get") + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("del state fail due to stateMgr failure", t, func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch.stateMgr), "delState", func(_ *stateManager, funcKey, stateID, traceID string) error { + return errors.New("mock state mgr err") + }) + defer patch.Reset() + err := ch.processState(req, "del") + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("del state fail due to faasscheduler failure", t, func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch.stateMgr), "delState", func(_ *stateManager, funcKey, stateID, traceID string) error { + return nil + }) + defer patch.Reset() + err := ch.processState(req, "del") + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("del state success", t, func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch.stateMgr), "delState", func(_ *stateManager, funcKey, stateID, traceID string) error { + return nil + }) + faasscheduler.Proxy.Add("faasScheduler1") + defer faasscheduler.Proxy.Remove("faasScheduler1") + defer patch.Reset() + err := ch.processState(req, "del") + convey.So(err, convey.ShouldBeNil) + }) +} + +func mockHttpWriter(rspCode *int, patch *gomonkey.Patches) { + patch.ApplyMethod(reflect.TypeOf(&FakeHTTP.TestResponseWriter{}), "Write", func(_ *FakeHTTP.TestResponseWriter, data []byte) (int, error) { + response := types.StateResponse{} + err := json.Unmarshal(data, &response) + if err != nil { + *rspCode = -1 + } + *rspCode = response.StatusCode + return 0, nil + }) +} + +func TestHandleStateNew(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + invokeServer: &http.Server{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + + convey.Convey("read req body error", t, func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte{}, errors.New("read error") + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateNew(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("unmarshal body failed", t, func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte("aaa"), nil + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateNew(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("read req body successfully", t, func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + invokeRequest := types.InvokeRequest{} + data, _ := json.Marshal(invokeRequest) + return data, nil + }) + defer patch.Reset() + convey.Convey("processState return snErr", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return api.ErrorInfo{ + Code: StateInstanceNotExistedErrCode, + Err: fmt.Errorf("mock snerror"), + } + }) + + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateNew(w, r) + convey.So(rspCode, convey.ShouldEqual, StateInstanceNotExistedErrCode) + }) + + convey.Convey("processState return stateError", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return api.ErrorInfo{ + Code: InvalidState, + Err: fmt.Errorf("error"), + } + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateNew(w, r) + convey.So(rspCode, convey.ShouldEqual, InvalidState) + }) + + convey.Convey("processState return normal error", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return errors.New("mock normal error") + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateNew(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("processState success", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return nil + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateNew(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.NoneError) + }) + }) +} + +func TestHandleStateGet(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + invokeServer: &http.Server{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + + convey.Convey("read req body error", t, func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte{}, errors.New("read error") + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateGet(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("unmarshal body error", t, func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte("aaa"), nil + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateGet(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("read req body successfully", t, func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + invokeRequest := types.InvokeRequest{} + data, _ := json.Marshal(invokeRequest) + return data, nil + }) + defer patch.Reset() + convey.Convey("processState return snErr", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return api.ErrorInfo{ + Code: StateInstanceNotExistedErrCode, + Err: fmt.Errorf("mock snerror"), + } + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateGet(w, r) + convey.So(rspCode, convey.ShouldEqual, StateInstanceNotExistedErrCode) + }) + + convey.Convey("processState return stateError", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return api.ErrorInfo{ + Code: InvalidState, + Err: fmt.Errorf("mock snerror"), + } + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateGet(w, r) + convey.So(rspCode, convey.ShouldEqual, InvalidState) + }) + + convey.Convey("processState return normal error", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return errors.New("mock normal error") + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateGet(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("processState success", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return nil + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateGet(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.NoneError) + }) + }) +} + +func TestHandleStateDel(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + invokeServer: &http.Server{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + + convey.Convey("read req body error", t, func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte{}, errors.New("read error") + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateDel(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("unmarshal body error", t, func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte("aaa"), nil + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateDel(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("read req body successfully", t, func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + invokeRequest := types.InvokeRequest{} + data, _ := json.Marshal(invokeRequest) + return data, nil + }) + defer patch.Reset() + convey.Convey("processState return snErr", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return api.ErrorInfo{ + Code: StateInstanceNotExistedErrCode, + Err: fmt.Errorf("mock snerror"), + } + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateDel(w, r) + convey.So(rspCode, convey.ShouldEqual, StateInstanceNotExistedErrCode) + }) + + convey.Convey("processState return stateError", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return api.ErrorInfo{ + Code: InvalidState, + Err: fmt.Errorf("mock snerror"), + } + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateDel(w, r) + convey.So(rspCode, convey.ShouldEqual, InvalidState) + }) + + convey.Convey("processState return normal error", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return errors.New("mock normal error") + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateDel(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("processState success", func() { + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(ch), "processState", func(_ *CustomContainerHandler, request types.StateRequest, opType string) error { + return nil + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleStateDel(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.NoneError) + }) + }) +} + +func TestInitDsClinet(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + invokeServer: &http.Server{}, + remoteDsClient: &FakeDataSystemClinet{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + + convey.Convey("init sts success", t, func() { + convey.Convey("create client fail", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&fakeSDKClient{}), "CreateClient", + func(_ *fakeSDKClient, config api.ConnectArguments) (api.KvClient, error) { + return &FakeDataSystemClinet{}, errors.New("mock create client error") + }).Reset() + err := ch.initDsClient(&fakeSDKClient{}) + t.Logf("error is %v", err) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("create client successufully", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&FakeDataSystemClinet{}), "CreateClient", + func(_ *FakeDataSystemClinet, config api.ConnectArguments) (api.KvClient, error) { + return &FakeDataSystemClinet{}, nil + }).Reset() + err := ch.initDsClient(&fakeSDKClient{}) + t.Logf("error is %v", err) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestGetStateMgr(t *testing.T) { + ch := &CustomContainerHandler{} + convey.Convey("stateMgrState is failed", t, func() { + ch.stateMgrState = failed + sm := ch.getStateMgr() + convey.So(sm, convey.ShouldBeNil) + }) + convey.Convey("stateMgrState is uninitialized", t, func() { + ch.stateMgrState = uninitialized + defer gomonkey.ApplyFunc(time.Sleep, func(d time.Duration) { + }).Reset() + convey.Convey("stateMgr is nil", func() { + ch.stateMgr = nil + sm := ch.getStateMgr() + convey.So(sm, convey.ShouldBeNil) + }) + convey.Convey("stateMgr is not nil", func() { + ch.stateMgr = &stateManager{} + sm := ch.getStateMgr() + convey.So(sm, convey.ShouldNotBeNil) + }) + }) + +} + +func TestHandleInvokeStatePart(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + remoteDsClient: &FakeDataSystemClinet{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + ch.setStateMgr(GetStateManager(ch.remoteDsClient), completed) + faasscheduler.Proxy.Add("faasScheduler1") + defer faasscheduler.Proxy.Remove("faasScheduler1") + + convey.Convey("Test Handle Invoke by StateID", t, func() { + patch := gomonkey.ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + invokeRequest := types.InvokeRequest{ + StateID: "test-state-id", + } + data, _ := json.Marshal(invokeRequest) + return data, nil + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + convey.Convey("process invoke success", func() { + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.NoneError) + }) + + convey.Convey("func urn is invalid", func() { + patch := gomonkey.ApplyFunc(urnutils.GetFunctionInfo, func(urn string) (urnutils.BaseURN, error) { + return urnutils.BaseURN{}, errors.New("mock urn error") + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("stateMgr is nil", func() { + ch.stateMgrState = failed + defer func() { ch.stateMgrState = completed }() + patch := gomonkey.NewPatches() + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.FaaSError) + }) + + convey.Convey("invoke fail return normal error", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&fakeSDKClient{}), "InvokeByInstanceId", func(_ *fakeSDKClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, invokeOpt api.InvokeOptions) (returnObjectID string, err error) { + return "", errors.New("mock invoke err") + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.NoneError) + }) + + convey.Convey("invoke fail return SNError", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&fakeSDKClient{}), "InvokeByInstanceId", func(_ *fakeSDKClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpt api.InvokeOptions) (returnObjectID string, err error) { + return "", api.ErrorInfo{Code: 4000, Err: fmt.Errorf("mock snerror")} + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.NoneError) + }) + + convey.Convey("acquire instance fail return normal error", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&fakeSDKClient{}), "AcquireInstance", func(_ *fakeSDKClient, state string, funcMeta api.FunctionMeta, + acquireOpt api.InvokeOptions) (api.InstanceAllocation, error) { + return api.InstanceAllocation{}, errors.New("mock AcquireInstance normal error") + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.NoneError) + }) + + convey.Convey("acquire instance fail return SNError but no faasscheduler error", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&fakeSDKClient{}), "AcquireInstance", func(_ *fakeSDKClient, + state string, funcMeta api.FunctionMeta, acquireOpt api.InvokeOptions) (api.InstanceAllocation, error) { + return api.InstanceAllocation{}, &fakeSnError{ + ErrorCode: 4999, // 4999 is not faasscheduler errorCodes + ErrorMessage: "mock snerror", + } + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(rspCode, convey.ShouldEqual, constants.NoneError) + }) + + convey.Convey("acquire instance fail return faas error", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&fakeSDKClient{}), "AcquireInstance", func(_ *fakeSDKClient, + state string, funcMeta api.FunctionMeta, acquireOpt api.InvokeOptions) (api.InstanceAllocation, error) { + return api.InstanceAllocation{}, api.ErrorInfo{ + Code: StateInstanceNotExistedErrCode, + Err: fmt.Errorf("mock snerror"), + } + }) + rspCode := 0 + mockHttpWriter(&rspCode, patch) + defer patch.Reset() + w := &FakeHTTP.TestResponseWriter{} + r := &http.Request{} + ch.handleInvoke(w, r) + convey.So(rspCode, convey.ShouldEqual, StateInstanceNotExistedErrCode) + }) + }) +} + +func TestCustomContainerHandler_checkRuntime(t *testing.T) { + convey.Convey("check runtime test", t, func() { + + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + remoteDsClient: &FakeDataSystemClinet{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + check := types.CustomHealthCheck{ + TimeoutSeconds: 1, + PeriodSeconds: 5, + FailureThreshold: 1, + } + convey.Convey("res is 500", func() { + p := gomonkey.ApplyFunc((*basicHandler).sendRequest, func(_ *basicHandler, + request *http.Request, timeout time.Duration) (*http.Response, error) { + res := &http.Response{} + res.StatusCode = http.StatusOK + return res, nil + }) + go func() { + err := ch.checkRuntime(nil, check) + assert.Equal(t, err.Error(), "check health res status code is 500") + }() + time.Sleep(2 * time.Second) + p.Reset() + p = gomonkey.ApplyFunc((*basicHandler).sendRequest, func(_ *basicHandler, + request *http.Request, timeout time.Duration) (*http.Response, error) { + res := &http.Response{} + res.StatusCode = http.StatusInternalServerError + return res, nil + }) + time.Sleep(time.Duration(check.PeriodSeconds) * time.Second) + p.Reset() + }) + + convey.Convey("err not nil", func() { + p := gomonkey.ApplyFunc((*basicHandler).sendRequest, func(_ *basicHandler, + request *http.Request, timeout time.Duration) (*http.Response, error) { + return nil, fmt.Errorf("error") + }) + defer p.Reset() + err := ch.checkRuntime(nil, check) + convey.So(err.Error(), convey.ShouldEqual, "error") + }) + convey.Convey("fail fail success", func() { + failCount := 2 + check.FailureThreshold = failCount + 1 + count := 0 + p := gomonkey.ApplyFunc((*basicHandler).sendRequest, func(_ *basicHandler, + request *http.Request, timeout time.Duration) (*http.Response, error) { + res := &http.Response{} + if count < failCount { + count++ + return nil, fmt.Errorf("error") + } + res.StatusCode = http.StatusOK + return res, nil + }) + defer func() { + close(stopCh) + time.Sleep(100 * time.Millisecond) + stopCh = make(chan struct{}) + p.Reset() + }() + go func() { + _ = ch.checkRuntime(nil, check) + }() + time.Sleep(20 * time.Second) + }) + }) +} + +func TestCustomContainerHandler_checkHealth(t *testing.T) { + convey.Convey("check health", t, func() { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + remoteDsClient: &FakeDataSystemClinet{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + convey.Convey("check runtime health failed", func() { + ch.funcSpec.ExtendedMetaData.CustomHealthCheck.TimeoutSeconds = 1 + ch.funcSpec.ExtendedMetaData.CustomHealthCheck.PeriodSeconds = 1 + ch.funcSpec.ExtendedMetaData.CustomHealthCheck.FailureThreshold = 1 + p := gomonkey.ApplyFunc((*CustomContainerHandler).checkRuntime, func(_ *CustomContainerHandler, + request *http.Request, check types.CustomHealthCheck) error { + return fmt.Errorf("error") + }) + defer p.Reset() + ch.checkHealth() + }) + }) +} + +func TestCustomContainerHandler_handleExit(t *testing.T) { + convey.Convey("handle exit test", t, func() { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + remoteDsClient: &FakeDataSystemClinet{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + cnt := 0 + gomonkey.ApplyMethod(ch.sdkClient, "Exit", func(_ *fakeSDKClient, code int, message string) { + if code != 0 { + cnt++ + } + return + }) + convey.Convey("baseline", func() { + p := gomonkey.ApplyFunc(handleResponse, func(response types.Response, + w http.ResponseWriter, handleType string) { + return + }) + defer p.Reset() + ch.handleExit(&FakeHTTP.TestResponseWriter{}, &http.Request{Body: io.NopCloser(strings.NewReader(""))}) + }) + convey.Convey("withexit", func() { + p := gomonkey.ApplyFunc(handleResponse, func(response types.Response, + w http.ResponseWriter, handleType string) { + return + }) + defer p.Reset() + ch.handleExit(&FakeHTTP.TestResponseWriter{}, &http.Request{Body: io.NopCloser(strings.NewReader("{\"code\":3015, \"message\":\"\"}"))}) + time.Sleep(5 * time.Millisecond) + convey.So(cnt, convey.ShouldEqual, 1) + }) + }) +} + +type GetCredentialSuccessClient struct { + fakeSDKClient +} + +func (f *GetCredentialSuccessClient) GetCredential() api.Credential { + return api.Credential{ + AccessKey: "ak", + SecretKey: []byte("sk"), + DataKey: []byte("dk"), + } +} + +func TestCustomContainerHandler_handleCredentialRequire(t *testing.T) { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{}, + } + convey.Convey("handle credential require test", t, func() { + convey.Convey("success", func() { + ch.basicHandler.sdkClient = &GetCredentialSuccessClient{} + w := &FakeHTTP.TestResponseWriter{} + ch.handleCredentialRequire(w, nil) + resp := &types.CredentialResponse{} + _ = json.Unmarshal([]byte(w.Output), resp) + convey.So(w.StatusCode, convey.ShouldEqual, http.StatusOK) + convey.So(resp.StatusCode, convey.ShouldEqual, constants.NoneError) + convey.So(resp.AccessKey, convey.ShouldEqual, "ak") + convey.So(string(resp.SecretKey), convey.ShouldEqual, "sk") + convey.So(string(resp.DataKey), convey.ShouldEqual, "dk") + }) + }) +} + +func TestCustomContainerHandler_handleCircuitBreak(t *testing.T) { + convey.Convey("handle circuit break test", t, func() { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + remoteDsClient: &FakeDataSystemClinet{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + convey.Convey("baseline", func() { + p := gomonkey.ApplyFunc(io.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte("{\"switch\":true}"), nil + }) + p2 := gomonkey.ApplyFunc(handleResponse, func(response types.Response, + w http.ResponseWriter, handleType string) { + return + }) + defer func() { + p.Reset() + p2.Reset() + }() + ch.handleCircuitBreak(nil, &http.Request{}) + convey.So(ch.circuitBreaker, convey.ShouldBeTrue) + }) + convey.Convey("io read failed", func() { + p := gomonkey.ApplyFunc(io.ReadAll, func(r io.Reader) ([]byte, error) { + return nil, fmt.Errorf("error") + }) + defer p.Reset() + ch.circuitBreaker = false + ch.handleCircuitBreak(&FakeHTTP.TestResponseWriter{}, &http.Request{}) + convey.So(ch.circuitBreaker, convey.ShouldBeFalse) + }) + convey.Convey("json unmarshal failed", func() { + p := gomonkey.ApplyFunc(io.ReadAll, func(r io.Reader) ([]byte, error) { + return []byte("aaa"), nil + }) + defer p.Reset() + ch.circuitBreaker = true + ch.handleCircuitBreak(&FakeHTTP.TestResponseWriter{}, &http.Request{}) + convey.So(ch.circuitBreaker, convey.ShouldBeTrue) + }) + }) +} + +func TestCustomContainerHandler_HealthCheckHandler(t *testing.T) { + convey.Convey("health check handler test", t, func() { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + remoteDsClient: &FakeDataSystemClinet{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + convey.Convey("circuitBreaker is true", func() { + ch.circuitBreaker = true + res, err := ch.HealthCheckHandler() + convey.So(err, convey.ShouldBeNil) + convey.So(res, convey.ShouldEqual, api.SubHealth) + }) + convey.Convey("circuitBreaker is false", func() { + ch.circuitBreaker = false + res, err := ch.HealthCheckHandler() + convey.So(err, convey.ShouldBeNil) + convey.So(res, convey.ShouldEqual, api.Healthy) + }) + }) +} + +func TestCustomContainerHandler_setAlarmInfo(t *testing.T) { + convey.Convey("set alarmInfo test", t, func() { + ch := &CustomContainerHandler{ + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + sdkClient: &fakeSDKClient{}, + }, + futureMap: make(map[string]chan types.GetFutureResponse, constants.DefaultMapSize), + remoteDsClient: &FakeDataSystemClinet{}, + crossClusterInvoker: &crossclusterinvoke.Invoker{}, + } + convey.Convey("baseline", func() { + ch.customUserArgs.AlarmConfig.EnableAlarm = true + p := gomonkey.ApplyFunc(alarm.SetPodIP, func() error { + return nil + }) + defer p.Reset() + err := ch.setAlarmInfo() + convey.So(err, convey.ShouldBeNil) + }) + + convey.Convey("set podIp failed", func() { + ch.customUserArgs.AlarmConfig.EnableAlarm = true + p := gomonkey.ApplyFunc(alarm.SetPodIP, func() error { + return fmt.Errorf("error") + }) + defer p.Reset() + err := ch.setAlarmInfo() + convey.So(err.Error(), convey.ShouldEqual, "error") + }) + }) +} + +func Test_checkTrafficLimitResp(t *testing.T) { + convey.Convey("check traffic limit resp test", t, func() { + convey.So(checkTrafficLimitResp([]byte("{\"errorCode\":0, \"errorMessage\": \"\"}")), convey.ShouldBeFalse) + convey.So(checkTrafficLimitResp([]byte("{\"errorCode\":6037, \"errorMessage\": \"\"}")), convey.ShouldBeTrue) + convey.So(checkTrafficLimitResp([]byte("aaa")), convey.ShouldBeFalse) + }) +} diff --git a/api/go/faassdk/handler/http/custom_container_log.go b/api/go/faassdk/handler/http/custom_container_log.go new file mode 100644 index 0000000..4c0099e --- /dev/null +++ b/api/go/faassdk/handler/http/custom_container_log.go @@ -0,0 +1,677 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/fsnotify/fsnotify" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/common/functionlog" + "yuanrong.org/kernel/runtime/faassdk/config" + log "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + syncStream = "sync" +) + +var errWatcherStopped = errors.New("watcher is stopped") + +func collectRuntimeContainerLog(conf *config.Configuration) *RuntimeContainerLogger { + runtimeContainerID := conf.RuntimeContainerID + logFileName := runtimeContainerID + "-json.log" + logFilePath := filepath.Join("/var/lib/docker/containers/", runtimeContainerID, logFileName) + + processCB := func(msg RuntimeContainerLog) { + logger, err := functionlog.GetFunctionLogger(conf) + if err != nil { + log.GetLogger().Warnf("failed to get function logger, %s", err.Error()) + return + } + if logger == nil { + log.GetLogger().Warnf("failed to get function logger") + return + } + + logger.WriteStdLog(msg.log, msg.t.UTC().Format(functionlog.NanoLogLayout), + false, + msg.t.UTC().Format(functionlog.NanoLogLayout)) + } + logger, err := NewRuntimeStdLogger(logFilePath, processCB) + if err != nil { + log.GetLogger().Errorf("failed to collect runtime container log: %s", err.Error()) + } + + go func() { + if err := logger.Run(); err != nil { + log.GetLogger().Errorf("failed to collect runtime container log: %s", err.Error()) + } + }() + + return logger +} + +// RuntimeContainerLogProcessCB - +type RuntimeContainerLogProcessCB func(RuntimeContainerLog) + +// RuntimeContainerLog - +type RuntimeContainerLog struct { + log string + level string + t time.Time +} + +type dockerLog struct { + Log string `json:"log"` + Stream string `json:"stream"` + Time string `json:"time"` +} + +type syncPoint struct { + t time.Time + done chan struct{} +} + +// RuntimeContainerLogger - +type RuntimeContainerLogger struct { + t *tail + syncPointCh chan *syncPoint + syncPoints []*syncPoint + logFilePath string + processCB RuntimeContainerLogProcessCB + tick time.Time +} + +// NewRuntimeStdLogger - +func NewRuntimeStdLogger(logFilePath string, processCB RuntimeContainerLogProcessCB) (*RuntimeContainerLogger, error) { + logger := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + processCB: processCB, + logFilePath: logFilePath, + } + + err := tailFile(logger) + if err != nil { + log.GetLogger().Errorf("failed to start tailing file: %s, %s", logFilePath, err.Error()) + return nil, err + } + + return logger, nil +} + +// Run - +func (l *RuntimeContainerLogger) Run() error { + if l.t == nil { + log.GetLogger().Errorf("tail is nil") + return errors.New("tail is nil") + } + + log.GetLogger().Infof("start tailing file: %s", l.logFilePath) + + timeOfNoSync := time.NewTicker(time.Second) + defer timeOfNoSync.Stop() + for { + select { + case line, ok := <-l.t.Lines: + if !ok { + log.GetLogger().Errorf("tail lines chan closed unexpectedly") + return errors.New("tail lines chan closed unexpectedly") + } + l.processLine(line) + l.notifySyncPoints() + case syncPoint, ok := <-l.syncPointCh: + if !ok { + log.GetLogger().Errorf("sync times chan closed unexpectedly") + return errors.New("sync times chan closed unexpectedly") + } + l.addSyncPoint(syncPoint) + l.notifySyncPoints() + case err, ok := <-l.t.Errors: + if !ok || err == nil { + log.GetLogger().Errorf("tail error chan closed unexpectedly") + return errors.New("tail error chan closed unexpectedly") + } + log.GetLogger().Errorf("failed to tail file: %s, %s", l.logFilePath, err.Error()) + return err + case <-timeOfNoSync.C: + l.tick = time.Now() + l.notifySyncPoints() + } + } +} + +// syncTo waits until logger has processed to "t" +func (l *RuntimeContainerLogger) syncTo(t time.Time) error { + // Wait for docker log to catch up. + time.Sleep(1 * time.Millisecond) + + content, err := makeDockerLog("", syncStream, t.UTC()) + if err != nil { + log.GetLogger().Errorf("failed to make docker log: %s", err.Error()) + return err + } + if err := appendFile(l.logFilePath, content); err != nil { + log.GetLogger().Errorf("failed to append file: %s, %s", l.logFilePath, err.Error()) + return err + } + + syncPoint := &syncPoint{ + t: t, + done: make(chan struct{}), + } + + l.syncPointCh <- syncPoint + + <-syncPoint.done + + return nil +} + +func (l *RuntimeContainerLogger) addSyncPoint(syncPoint *syncPoint) { + if !syncPoint.t.After(l.tick) { + close(syncPoint.done) + } else { + l.syncPoints = append(l.syncPoints, syncPoint) + sort.Slice(l.syncPoints, func(i, j int) bool { + if i < 0 || i >= len(l.syncPoints) || j < 0 || j >= len(l.syncPoints) { + return false + } + return l.syncPoints[i].t.Before(l.syncPoints[j].t) + }) + } +} + +func (l *RuntimeContainerLogger) notifySyncPoints() { + idx := -1 + for i, syncPoint := range l.syncPoints { + if syncPoint.t.After(l.tick) { + break + } + idx = i + } + if idx != -1 { + for i := 0; i < idx+1; i++ { + close(l.syncPoints[i].done) + } + l.syncPoints = l.syncPoints[idx+1:] + } +} + +func (l *RuntimeContainerLogger) processLine(line string) { + msg := dockerLog{} + if err := json.Unmarshal([]byte(line), &msg); err != nil { + log.GetLogger().Warnf("failed to parse docker log: %s, %s", line, err.Error()) + return + } + + log.GetLogger().Debugf("process line: %+v", msg) + + var ( + level string + isSync bool + ) + switch msg.Stream { + case "stderr": + level = constants.FuncLogLevelWarn + case syncStream: + isSync = true + default: + level = constants.FuncLogLevelInfo + } + + t, err := time.Parse(time.RFC3339Nano, msg.Time) + if err != nil { + log.GetLogger().Warnf("failed to parse time: %s to RFC3339Nano, %s", msg.Time, err.Error()) + return + } + + if !isSync { + l.processCB(RuntimeContainerLog{log: msg.Log, level: level, t: t}) + } + + if t.After(l.tick) { + l.tick = t + } +} + +type tail struct { + file *os.File + reader *bufio.Reader + watcher *watcher + changeNotifier *changeNotifier + Lines chan string + Errors chan error + fileName string +} + +func tailFile(logger *RuntimeContainerLogger) error { + t := &tail{ + Lines: make(chan string), + Errors: make(chan error), + fileName: logger.logFilePath, + } + + var err error + t.watcher, err = newWatcher() + if err != nil { + return err + } + + go t.sync(logger) + logger.t = t + return nil +} + +func (t *tail) sync(logger *RuntimeContainerLogger) { + defer t.clean() + + if err := t.reopen(); err != nil { + log.GetLogger().Errorf("failed to wait open: %s", err.Error()) + t.Errors <- err + return + } + t.reader = bufio.NewReader(t.file) + + var ( + offset int64 + err error + ) + for { + offset, err = t.tell() + if err != nil { + log.GetLogger().Errorf("failed to tell file: %s, %s", t.fileName, err.Error()) + t.Errors <- err + return + } + + line, err := t.readLine() + if err != nil && errors.Is(err, io.EOF) { + if err := t.handleReadLineEOF(logger, offset, line); err != nil { + log.GetLogger().Errorf("failed to handle eof for file: %s, %s", t.fileName, err.Error()) + t.Errors <- err + return + } + continue + } + if err != nil { + log.GetLogger().Errorf("failed to read file: %s, %s", t.fileName, err.Error()) + t.Errors <- err + return + } + t.Lines <- line + } +} + +func (t *tail) clean() { + if t.watcher != nil { + t.watcher.close() + } + if t.file != nil { + if err := t.file.Close(); err != nil { + log.GetLogger().Warnf("failed to close file: %s, %s", t.fileName, err.Error()) + } + } +} + +func (t *tail) tell() (int64, error) { + offset, err := t.file.Seek(0, os.SEEK_CUR) + if err != nil { + return 0, err + } + offset -= int64(t.reader.Buffered()) + return offset, nil +} + +func (t *tail) readLine() (string, error) { + line, err := t.reader.ReadString('\n') + if err != nil { + return line, err + } + line = strings.TrimRight(line, "\n") + return line, nil +} + +func (t *tail) handleReadLineEOF(logger *RuntimeContainerLogger, offset int64, line string) error { + if line != "" { + if err := t.seekTo(offset, 0); err != nil { + log.GetLogger().Errorf("failed to seek to offset: %s, %s", offset, err.Error()) + return err + } + } + + if t.changeNotifier == nil { + pos, err := t.file.Seek(0, os.SEEK_CUR) + if err != nil { + log.GetLogger().Errorf("failed to seek to current position: %s", err.Error()) + return err + } + changeNotifier, err := t.watcher.watch(t.fileName, pos) + if err != nil { + log.GetLogger().Errorf("failed to watch file: %s, %s", t.fileName, err.Error()) + return err + } + t.changeNotifier = changeNotifier + } + + reopen := func() error { + if err := t.reopen(); err != nil { + log.GetLogger().Errorf("failed to reopen: %s", err.Error()) + return err + } + t.reader = bufio.NewReader(t.file) + return nil + } + + select { + case <-t.changeNotifier.modifyCh: + return nil + case <-t.changeNotifier.deleteCh: + log.GetLogger().Infof("reopen a deleted file...") + return reopen() + case <-t.changeNotifier.truncateCh: + log.GetLogger().Infof("reopen a truncated file...") + close(t.changeNotifier.closeCh) + <-t.changeNotifier.closeDone + go logger.syncTo(time.Now()) + return reopen() + case err, ok := <-t.changeNotifier.errCh: + if !ok || err == nil { + log.GetLogger().Errorf("change notifier error chan closed unexpectedly") + return errors.New("change notifier error chan closed unexpectedly") + } + return err + } +} + +func (t *tail) seekTo(offset int64, whence int) error { + _, err := t.file.Seek(offset, whence) + if err != nil { + return err + } + t.reader.Reset(t.file) + return nil +} + +func (t *tail) reopen() error { + if t.file != nil { + if err := t.file.Close(); err != nil { + log.GetLogger().Warnf("failed to close file: %s, %s", t.fileName, err.Error()) + } + } + t.changeNotifier = nil + + var err error + for { + t.file, err = os.Open(t.fileName) + if err == nil { + break + } + if !os.IsNotExist(err) { + return fmt.Errorf("unable to open file %s: %s", t.fileName, err.Error()) + } + if err := t.watcher.waitUntilCreate(t.fileName); err != nil { + if errors.Is(err, errWatcherStopped) { + return err + } + return fmt.Errorf("unable to wait creation of file: %s, %s", t.fileName, err.Error()) + } + } + return nil +} + +type watcher struct { + watcher *fsnotify.Watcher + stopCh chan struct{} +} + +func newWatcher() (*watcher, error) { + wa, err := fsnotify.NewWatcher() + if err != nil { + log.GetLogger().Errorf("failed to new fsnotify watcher: %s", err.Error()) + return nil, err + } + + w := &watcher{ + watcher: wa, + stopCh: make(chan struct{}), + } + + return w, nil +} + +func (w *watcher) close() { + close(w.stopCh) + if err := w.watcher.Close(); err != nil { + log.GetLogger().Warnf("failed to close watcher: %s", err.Error()) + } +} + +func (w *watcher) waitUntilCreate(fileName string) error { + dir := filepath.Dir(fileName) + if err := w.watcher.Add(dir); err != nil { + log.GetLogger().Errorf("failed to watch dir: %s, %s", dir, err.Error()) + return err + } + defer func() { + if err := w.watcher.Remove(dir); err != nil { + log.GetLogger().Warnf("failed to remove watching dir: %s, %s", dir, err.Error()) + } + }() + + _, err := os.Stat(fileName) + if err == nil { + return nil // file already exists. + } + if !os.IsNotExist(err) { + log.GetLogger().Errorf("failed to stat file: %s, %s", fileName, err.Error()) + return err + } + + for { + select { + case event, ok := <-w.watcher.Events: + if !ok { + log.GetLogger().Errorf("watcher event chan closed unexpectedly") + return errors.New("inotify watcher event chan closed unexpectedly") + } + + same, err := isSameFileName(fileName, event.Name) + if err != nil { + log.GetLogger().Errorf("failed to check is same file name for file %s and %s", + fileName, event.Name, err.Error()) + return err + } + if same { + log.GetLogger().Infof("file %s is created", fileName) + return nil + } + case err, ok := <-w.watcher.Errors: + if !ok || err == nil { + log.GetLogger().Errorf("watcher errors chan closed unexpectedly") + return errors.New("inotify watcher errors chan closed unexpectedly") + } + log.GetLogger().Errorf("failed to watch file, watcher returns error: %s", err.Error()) + return err + case <-w.stopCh: + return errWatcherStopped + } + } +} + +func (w *watcher) watch(fileName string, pos int64) (*changeNotifier, error) { + if err := w.watcher.Add(fileName); err != nil { + log.GetLogger().Errorf("failed to watch file: %s, %s", fileName, err.Error()) + return nil, err + } + + c := newChangeNotifier() + go func() { + for { + if !w.handle(c, fileName, &pos) { + return + } + } + }() + + return c, nil +} + +func (w *watcher) handle(c *changeNotifier, fileName string, size *int64) bool { + prevSize := *size + + select { + case event, ok := <-w.watcher.Events: + if !ok { + log.GetLogger().Errorf("watcher event chan closed unexpectedly") + c.errCh <- errors.New("watcher event chan closed unexpectedly") + return false + } + + switch { + case event.Op&fsnotify.Remove > 0 || event.Op&fsnotify.Rename > 0: + w.handleDelete(c, fileName) + return false + case event.Op&fsnotify.Chmod > 0 || event.Op&fsnotify.Write > 0: + fi, err := os.Stat(fileName) + if err != nil && os.IsNotExist(err) { + w.handleDelete(c, fileName) + return false + } + if err != nil { + log.GetLogger().Errorf("failed to stat file: %s, %s", fileName, err.Error()) + c.errCh <- fmt.Errorf("failed to stat file: %s, %s", fileName, err.Error()) + return false + } + *size = fi.Size() + // if file size is less or prevSize == *size == 0, we think file is truncate + if (prevSize > *size) || (prevSize == 0 && *size == 0) { + w.notify(c.truncateCh) + } else { + w.notify(c.modifyCh) + } + // nothing to do + default: + } + case err, ok := <-w.watcher.Errors: + if !ok || err == nil { + log.GetLogger().Errorf("watcher errors chan closed unexpectedly") + c.errCh <- errors.New("watcher errors chan closed unexpectedly") + return false + } + log.GetLogger().Errorf("failed to watch file, watcher returns error: %s", err.Error()) + c.errCh <- fmt.Errorf("failed to watch file, watcher returns error: %s", err.Error()) + return false + case <-w.stopCh: + return false + case <-c.closeCh: + close(c.closeDone) + return false + } + return true +} + +func (w *watcher) handleDelete(c *changeNotifier, fileName string) { + if err := w.watcher.Remove(fileName); err != nil { + log.GetLogger().Warnf("failed to remove watching file: %s, %s", fileName, err.Error()) + } + w.notify(c.deleteCh) +} + +func (w *watcher) notify(ch chan struct{}) { + if ch == nil { + log.GetLogger().Warnf("nil chan") + return + } + + select { + case ch <- struct{}{}: + default: + } +} + +type changeNotifier struct { + modifyCh chan struct{} + deleteCh chan struct{} + truncateCh chan struct{} + errCh chan error + closeCh chan struct{} + closeDone chan struct{} +} + +func newChangeNotifier() *changeNotifier { + return &changeNotifier{ + modifyCh: make(chan struct{}, 1), + deleteCh: make(chan struct{}, 1), + truncateCh: make(chan struct{}, 1), + errCh: make(chan error), + closeCh: make(chan struct{}), + closeDone: make(chan struct{}), + } +} + +func isSameFileName(a, b string) (bool, error) { + var err error + a, err = filepath.Abs(a) + if err != nil { + return false, err + } + b, err = filepath.Abs(b) + if err != nil { + return false, err + } + return a == b, nil +} + +func makeDockerLog(content, stream string, t time.Time) ([]byte, error) { + strTime := t.Format(time.RFC3339Nano) + d := &dockerLog{ + Log: content, + Stream: stream, + Time: strTime, + } + b, err := json.Marshal(d) + if err != nil { + log.GetLogger().Errorf("failed to marshal docker log: %v %s", d, err.Error()) + return nil, err + } + return append(b, '\n'), nil +} + +func appendFile(fileName string, b []byte) error { + fi, err := os.OpenFile(fileName, os.O_WRONLY|os.O_APPEND, 0) + if err != nil { + return err + } + defer fi.Close() + + if _, err := fi.Write(b); err != nil { + return err + } + + return nil +} diff --git a/api/go/faassdk/handler/http/custom_container_log_test.go b/api/go/faassdk/handler/http/custom_container_log_test.go new file mode 100644 index 0000000..3093adf --- /dev/null +++ b/api/go/faassdk/handler/http/custom_container_log_test.go @@ -0,0 +1,664 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/fsnotify/fsnotify" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "yuanrong.org/kernel/runtime/faassdk/common/functionlog" + "yuanrong.org/kernel/runtime/faassdk/config" +) + +func TestMain(m *testing.M) { + m.Run() +} + +func TestCollect(t *testing.T) { + p := gomonkey.NewPatches() + p.ApplyFunc(functionlog.GetFunctionLogger, func(cfg *config.Configuration) (*functionlog.FunctionLogger, error) { + return functionlog.NewFunctionLogger(nil, "", ""), nil + }) + logger := &RuntimeContainerLogger{} + p.ApplyFunc(NewRuntimeStdLogger, func(logFilePath string, processCB RuntimeContainerLogProcessCB) (*RuntimeContainerLogger, error) { + logger.processCB = processCB + return logger, nil + }) + defer p.Reset() + + l := collectRuntimeContainerLog(&config.Configuration{}) + require.NotNil(t, l) + l.processCB(RuntimeContainerLog{log: "abc", level: "info", t: time.Now()}) +} + +func TestRuntimeContainerLogger(t *testing.T) { + p := gomonkey.NewPatches() + events := make(chan fsnotify.Event) + errors := make(chan error) + p.ApplyFunc(fsnotify.NewWatcher, func() (*fsnotify.Watcher, error) { + return &fsnotify.Watcher{ + Events: events, + Errors: errors, + }, nil + }) + + var watcher *fsnotify.Watcher + p.ApplyMethod(reflect.TypeOf(watcher), "Add", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyMethod(reflect.TypeOf(watcher), "Remove", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyMethod(reflect.TypeOf(watcher), "Close", func(_ *fsnotify.Watcher) error { return nil }) + + dir, err := ioutil.TempDir(os.TempDir(), "test-runtime-container-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + fileName := filepath.Join(dir, "log") + + processCB := func(msg RuntimeContainerLog) { + fmt.Println(msg) + } + logger, err := NewRuntimeStdLogger(fileName, processCB) + if err != nil { + t.Fatal(err) + } + + go func() { + if err := logger.Run(); err != nil { + fmt.Println(err) + } + }() + + before := time.Now() + content, err := makeDockerLog("abc", "stdin", time.Now()) + if err != nil { + t.Fatal(err) + } + ioutil.WriteFile(fileName, content, 0755) + events <- fsnotify.Event{Name: fileName, Op: fsnotify.Write} + logger.syncTo(before) + + os.Remove(fileName) + events <- fsnotify.Event{Name: fileName, Op: fsnotify.Remove} + + pp := gomonkey.NewPatches() + pp.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { return nil, os.ErrNotExist }) + pp.ApplyFunc(os.Open, func(name string) (*os.File, error) { return nil, os.ErrNotExist }) + + content, err = makeDockerLog("abcd", "stdin", time.Now()) + if err != nil { + pp.Reset() + t.Fatal(err) + } + + ioutil.WriteFile(fileName, content, 0755) + + events <- fsnotify.Event{Name: fileName, Op: fsnotify.Create} + + pp.Reset() + + events <- fsnotify.Event{Name: fileName, Op: fsnotify.Write} + + content, err = makeDockerLog("a", "stdin", time.Now()) + if err != nil { + pp.Reset() + t.Fatal(err) + } + ioutil.WriteFile(fileName, content, 0755) + events <- fsnotify.Event{Name: fileName, Op: fsnotify.Write} +} + +func TestRuntimeContainerLogger2(t *testing.T) { + p := gomonkey.NewPatches() + events := make(chan fsnotify.Event) + errorCh := make(chan error) + p.ApplyFunc(fsnotify.NewWatcher, func() (*fsnotify.Watcher, error) { + return &fsnotify.Watcher{ + Events: events, + Errors: errorCh, + }, nil + }) + + var watcher *fsnotify.Watcher + p.ApplyMethod(reflect.TypeOf(watcher), "Add", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyMethod(reflect.TypeOf(watcher), "Remove", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyMethod(reflect.TypeOf(watcher), "Close", func(_ *fsnotify.Watcher) error { return nil }) + + dir, err := ioutil.TempDir(os.TempDir(), "test-runtime-container-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + fileName := filepath.Join(dir, "log") + + processCB := func(msg RuntimeContainerLog) { + fmt.Println(msg) + } + logger, err := NewRuntimeStdLogger(fileName, processCB) + if err != nil { + t.Fatal(err) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if err := logger.Run(); err != nil { + fmt.Println(err) + } + }() + + before := time.Now() + content, err := makeDockerLog("abc", "stdin", time.Now()) + if err != nil { + t.Fatal(err) + } + ioutil.WriteFile(fileName, content, 0755) + events <- fsnotify.Event{Name: fileName, Op: fsnotify.Write} + logger.syncTo(before) + + os.Remove(fileName) + errorCh <- errors.New("my error") + + wg.Wait() +} + +func TestRuntimeContainerLogger3(t *testing.T) { + p := gomonkey.NewPatches() + events := make(chan fsnotify.Event) + errorCh := make(chan error) + p.ApplyFunc(fsnotify.NewWatcher, func() (*fsnotify.Watcher, error) { + return &fsnotify.Watcher{ + Events: events, + Errors: errorCh, + }, nil + }) + + var watcher *fsnotify.Watcher + p.ApplyMethod(reflect.TypeOf(watcher), "Add", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyMethod(reflect.TypeOf(watcher), "Remove", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyMethod(reflect.TypeOf(watcher), "Close", func(_ *fsnotify.Watcher) error { return nil }) + + dir, err := ioutil.TempDir(os.TempDir(), "test-runtime-container-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + fileName := filepath.Join(dir, "log") + + processCB := func(msg RuntimeContainerLog) { + fmt.Println(msg) + } + logger, err := NewRuntimeStdLogger(fileName, processCB) + if err != nil { + t.Fatal(err) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if err := logger.Run(); err != nil { + fmt.Println(err) + } + }() + + before := time.Now() + content, err := makeDockerLog("abc", "stdin", time.Now()) + if err != nil { + t.Fatal(err) + } + ioutil.WriteFile(fileName, content, 0755) + events <- fsnotify.Event{Name: fileName, Op: fsnotify.Write} + logger.syncTo(before) + + os.Remove(fileName) + events <- fsnotify.Event{Name: fileName, Op: fsnotify.Remove} + + pp := gomonkey.NewPatches() + pp.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { return nil, os.ErrNotExist }) + pp.ApplyFunc(os.Open, func(name string) (*os.File, error) { return nil, os.ErrNotExist }) + + content, err = makeDockerLog("abcd", "stdin", time.Now()) + if err != nil { + pp.Reset() + t.Fatal(err) + } + + ioutil.WriteFile(fileName, content, 0755) + + events <- fsnotify.Event{Name: fileName, Op: fsnotify.Create} + + pp.Reset() + + errorCh <- errors.New("my error") + + content, err = makeDockerLog("a", "stdin", time.Now()) + if err != nil { + pp.Reset() + t.Fatal(err) + } + ioutil.WriteFile(fileName, content, 0755) + + wg.Wait() +} + +func TestLoggerRun(t *testing.T) { + { + l := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + t: &tail{ + Lines: make(chan string), + Errors: make(chan error), + }, + } + close(l.t.Lines) + err := l.Run() + assert.NotNil(t, err) + } + + { + l := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + t: &tail{ + Lines: make(chan string), + Errors: make(chan error), + }, + } + close(l.t.Errors) + err := l.Run() + assert.NotNil(t, err) + } + + { + l := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + t: &tail{ + Lines: make(chan string), + Errors: make(chan error), + }, + } + close(l.syncPointCh) + err := l.Run() + assert.NotNil(t, err) + } + + { + finish := make(chan struct{}) + l := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + t: &tail{ + Lines: make(chan string), + Errors: make(chan error), + }, + } + // close of runtime.ProcessExitSignal should not affect logger + close(finish) + go func() { l.Run() }() + l.syncTo(time.Now()) + } +} + +func TestSyncPoint(t *testing.T) { + before := time.Now() + time.Sleep(10 * time.Millisecond) + after := time.Now() + l := &RuntimeContainerLogger{} + l.addSyncPoint(&syncPoint{t: before, done: make(chan struct{})}) + assert.Equal(t, 1, len(l.syncPoints)) + + l.tick = after + l.notifySyncPoints() +} + +func Test_processLine(t *testing.T) { + tt := time.Now() + l := &RuntimeContainerLogger{ + processCB: func(RuntimeContainerLog) {}, + } + l.processLine("abc") + + content, err := makeDockerLog("abc", "stdout", tt) + if err != nil { + t.Fatal(err) + } + l.processLine(string(content)) + + content, err = makeDockerLog("abc", "stderr", tt) + if err != nil { + t.Fatal(err) + } + l.processLine(string(content)) + + content, err = makeDockerLog("abc", "sync", tt) + if err != nil { + t.Fatal(err) + } + l.processLine(string(content)) + + s := `{"log":"abc","stream":"sync","time":"wrong time"}` + l.processLine(s) + + assert.True(t, tt.Equal(l.tick)) +} + +func Test_sync(t *testing.T) { + p := gomonkey.NewPatches() + var f *os.File + p.ApplyMethod(reflect.TypeOf(f), "Close", func(_ *os.File) error { return errors.New("my error") }) + p.ApplyFunc(os.Open, func(name string) (*os.File, error) { return nil, errors.New("my error") }) + defer p.Reset() + + tt := &tail{ + file: &os.File{}, + Errors: make(chan error, 1), + } + + logger := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + processCB: func(RuntimeContainerLog) {}, + logFilePath: "", + } + + tt.sync(logger) + + e := <-tt.Errors + assert.NotNil(t, e) +} + +func Test_sync2(t *testing.T) { + p := gomonkey.NewPatches() + var f *os.File + p.ApplyMethod(reflect.TypeOf(f), "Close", func(_ *os.File) error { return errors.New("my error") }) + p.ApplyMethod(reflect.TypeOf(f), "Seek", func(_ *os.File, offset int64, whence int) (ret int64, err error) { + return 0, errors.New("my error") + }) + p.ApplyFunc(os.Open, func(name string) (*os.File, error) { return &os.File{}, nil }) + defer p.Reset() + + tt := &tail{ + file: &os.File{}, + Errors: make(chan error, 1), + } + logger := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + processCB: func(RuntimeContainerLog) {}, + logFilePath: "", + } + tt.sync(logger) + + e := <-tt.Errors + assert.NotNil(t, e) +} + +func Test_sync3(t *testing.T) { + p := gomonkey.NewPatches() + var f *os.File + p.ApplyMethod(reflect.TypeOf(f), "Close", func(_ *os.File) error { return errors.New("my error") }) + p.ApplyMethod(reflect.TypeOf(f), "Seek", func(_ *os.File, offset int64, whence int) (ret int64, err error) { + return 0, nil + }) + p.ApplyFunc(os.Open, func(name string) (*os.File, error) { return &os.File{}, nil }) + var r *bufio.Reader + p.ApplyMethod(reflect.TypeOf(r), "ReadString", func(_ *bufio.Reader, delim byte) (string, error) { + return "", errors.New("my error") + }) + defer p.Reset() + + tt := &tail{ + file: &os.File{}, + Errors: make(chan error, 1), + } + logger := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + processCB: func(RuntimeContainerLog) {}, + logFilePath: "", + } + + tt.sync(logger) + + e := <-tt.Errors + assert.NotNil(t, e) +} + +func Test_handleReadLineEOF(t *testing.T) { + p := gomonkey.NewPatches() + var f *os.File + p.ApplyMethod(reflect.TypeOf(f), "Seek", func(_ *os.File, offset int64, whence int) (ret int64, err error) { + return 0, nil + }) + var r *bufio.Reader + p.ApplyMethod(reflect.TypeOf(r), "Reset", func(_ *bufio.Reader, r io.Reader) {}) + defer p.Reset() + + changeNotifier := newChangeNotifier() + close(changeNotifier.errCh) + tt := &tail{ + reader: &bufio.Reader{}, + changeNotifier: changeNotifier, + } + logger := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + processCB: func(RuntimeContainerLog) {}, + logFilePath: "", + } + + err := tt.handleReadLineEOF(logger, 0, "abc") + assert.NotNil(t, err) +} + +func Test_handleReadLineEOF2(t *testing.T) { + p := gomonkey.NewPatches() + var f *os.File + p.ApplyMethod(reflect.TypeOf(f), "Seek", func(_ *os.File, offset int64, whence int) (ret int64, err error) { + return 0, nil + }) + p.ApplyFunc(os.Open, func(name string) (*os.File, error) { return &os.File{}, nil }) + var r *bufio.Reader + p.ApplyMethod(reflect.TypeOf(r), "Reset", func(_ *bufio.Reader, r io.Reader) {}) + defer p.Reset() + + changeNotifier := newChangeNotifier() + close(changeNotifier.truncateCh) + tt := &tail{ + reader: &bufio.Reader{}, + changeNotifier: changeNotifier, + } + logger := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + processCB: func(RuntimeContainerLog) {}, + logFilePath: "", + } + go func(tt *tail) { + <-tt.changeNotifier.closeCh + close(tt.changeNotifier.closeDone) + }(tt) + err := tt.handleReadLineEOF(logger, 0, "abc") + assert.Nil(t, err) +} + +func Test_handleReadLineEOF3(t *testing.T) { + p := gomonkey.NewPatches() + var f *os.File + p.ApplyMethod(reflect.TypeOf(f), "Seek", func(_ *os.File, offset int64, whence int) (ret int64, err error) { + return 0, errors.New("my error") + }) + defer p.Reset() + + tt := &tail{ + reader: &bufio.Reader{}, + } + logger := &RuntimeContainerLogger{ + syncPointCh: make(chan *syncPoint), + processCB: func(RuntimeContainerLog) {}, + logFilePath: "", + } + err := tt.handleReadLineEOF(logger, 0, "abc") + assert.NotNil(t, err) +} + +func Test_waitUntilCreate(t *testing.T) { + p := gomonkey.NewPatches() + wat := &fsnotify.Watcher{ + Events: make(chan fsnotify.Event), + Errors: make(chan error), + } + p.ApplyMethod(reflect.TypeOf(wat), "Add", func(_ *fsnotify.Watcher, name string) error { return errors.New("my error") }) + defer p.Reset() + + w := &watcher{ + watcher: wat, + } + err := w.waitUntilCreate("file") + assert.NotNil(t, err) +} + +func Test_waitUntilCreate2(t *testing.T) { + p := gomonkey.NewPatches() + wat := &fsnotify.Watcher{ + Events: make(chan fsnotify.Event), + Errors: make(chan error), + } + p.ApplyMethod(reflect.TypeOf(wat), "Add", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyMethod(reflect.TypeOf(wat), "Remove", func(_ *fsnotify.Watcher, name string) error { return errors.New("my error") }) + p.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { return nil, errors.New("my error") }) + defer p.Reset() + + w := &watcher{ + watcher: wat, + } + err := w.waitUntilCreate("file") + assert.NotNil(t, err) +} + +func Test_waitUntilCreate3(t *testing.T) { + p := gomonkey.NewPatches() + wat := &fsnotify.Watcher{ + Events: make(chan fsnotify.Event), + Errors: make(chan error), + } + close(wat.Events) + p.ApplyMethod(reflect.TypeOf(wat), "Add", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyMethod(reflect.TypeOf(wat), "Remove", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { return nil, os.ErrNotExist }) + defer p.Reset() + + w := &watcher{ + watcher: wat, + } + err := w.waitUntilCreate("file") + assert.NotNil(t, err) +} + +func Test_waitUntilCreate4(t *testing.T) { + p := gomonkey.NewPatches() + wat := &fsnotify.Watcher{ + Events: make(chan fsnotify.Event), + Errors: make(chan error), + } + close(wat.Errors) + p.ApplyMethod(reflect.TypeOf(wat), "Add", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyMethod(reflect.TypeOf(wat), "Remove", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { return nil, os.ErrNotExist }) + defer p.Reset() + + w := &watcher{ + watcher: wat, + } + err := w.waitUntilCreate("file") + assert.NotNil(t, err) +} + +func Test_waitUntilCreate5(t *testing.T) { + p := gomonkey.NewPatches() + wat := &fsnotify.Watcher{ + Events: make(chan fsnotify.Event, 1), + Errors: make(chan error), + } + wat.Events <- fsnotify.Event{Name: "abc"} + p.ApplyMethod(reflect.TypeOf(wat), "Add", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyMethod(reflect.TypeOf(wat), "Remove", func(_ *fsnotify.Watcher, name string) error { return nil }) + p.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { return nil, os.ErrNotExist }) + p.ApplyFunc(filepath.Abs, func(string) (string, error) { return "", errors.New("my error") }) + defer p.Reset() + + w := &watcher{ + watcher: wat, + } + err := w.waitUntilCreate("file") + assert.NotNil(t, err) +} + +func Test_handle(t *testing.T) { + wat := &fsnotify.Watcher{ + Events: make(chan fsnotify.Event, 1), + Errors: make(chan error), + } + close(wat.Events) + + w := &watcher{ + watcher: wat, + } + c := newChangeNotifier() + size := int64(1) + go w.handle(c, "abc", &size) + err := c.errCh + assert.NotNil(t, err) +} + +func Test_handle2(t *testing.T) { + wat := &fsnotify.Watcher{ + Events: make(chan fsnotify.Event, 1), + Errors: make(chan error), + } + close(wat.Errors) + + w := &watcher{ + watcher: wat, + } + c := newChangeNotifier() + size := int64(1) + go w.handle(c, "abc", &size) + err := c.errCh + assert.NotNil(t, err) +} + +func Test_makeDockerLog(t *testing.T) { + p := gomonkey.NewPatches() + p.ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { return nil, errors.New("my error") }) + defer p.Reset() + + _, err := makeDockerLog("abc", "sync", time.Now()) + assert.NotNil(t, err) +} diff --git a/api/go/faassdk/handler/http/http_handler.go b/api/go/faassdk/handler/http/http_handler.go new file mode 100644 index 0000000..5a773e1 --- /dev/null +++ b/api/go/faassdk/handler/http/http_handler.go @@ -0,0 +1,149 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package http - +package http + +import ( + "errors" + "fmt" + "os" + "os/exec" + "path" + "syscall" + + "yuanrong.org/kernel/runtime/faassdk/handler" + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/libruntime/api" + log "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + bootstrapCommand = "/bin/bash" + bootstrapFilename = "bootstrap" +) + +type processState struct { + exitCode int + message string +} + +// HttpHandler handles http function initialization and invocation +type HttpHandler struct { + *basicHandler + process *os.Process + waitChan chan processState +} + +// NewHttpHandler creates HttpHandler +func NewHttpHandler(funcSpec *types.FuncSpec, client api.LibruntimeAPI) handler.ExecutorHandler { + return &HttpHandler{ + basicHandler: newBasicHandler(funcSpec, client), + waitChan: make(chan processState, 1), + } +} + +func (hh *HttpHandler) bootstrap() error { + delegateDownloadPath, err := handler.GetUserCodePath() + if err != nil { + return err + } + bootstrapPath := path.Join(delegateDownloadPath, bootstrapFilename) + _, err = os.Stat(bootstrapPath) + if err != nil { + if os.IsNotExist(err) { + log.GetLogger().Errorf("bootstrap file %s not exist for http function", bootstrapPath) + return errors.New("bootstrap file not exist") + } + log.GetLogger().Errorf("failed to check stat of bootstrap file %s for http function", bootstrapPath) + return errors.New("failed to check stat of bootstrap") + } + cmd := exec.Command(bootstrapCommand, bootstrapPath) + // make suer bootstrap can find http server binary + cmd.Dir = delegateDownloadPath + // make sure subprocess can be killed + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err = cmd.Start() + if err != nil { + log.GetLogger().Errorf("failed to execute bootstrap for http function error %s", err.Error()) + return err + } + hh.process = cmd.Process + go hh.waitProcess() + return nil +} + +func (hh *HttpHandler) waitProcess() { + defer close(hh.waitChan) + if hh.process == nil { + log.GetLogger().Errorf("failed to wait process, process is nil") + return + } + state, err := hh.process.Wait() + if err != nil { + log.GetLogger().Errorf("failed to wait process error %s", err.Error()) + return + } + message := fmt.Sprintf("process exits with %d", state.ExitCode()) + if state.ExitCode() != 0 { + message += fmt.Sprintf(" error: %s", state.String()) + } + log.GetLogger().Warnf(message) + hh.monitor.ErrChan <- errors.New(message) +} + +func (hh *HttpHandler) killAllProcesses() { + log.GetLogger().Warnf("killing all processes of http function %s", hh.funcSpec.FuncMetaData.FunctionName) + if hh.process == nil { + log.GetLogger().Errorf("failed to kill process, process is nil") + return + } + gid, err := syscall.Getpgid(hh.process.Pid) + if err != nil { + log.GetLogger().Errorf("failed to get gid of process %d", hh.process.Pid) + gid = hh.process.Pid + } + err = syscall.Kill(-gid, syscall.SIGKILL) + if err != nil { + log.GetLogger().Errorf("failed to kill http function process error %s", err.Error()) + } +} + +// InitHandler will bring up user's http server and wait until its ready then send init request +// args[0]: function specification +// args[1]: create params +func (hh *HttpHandler) InitHandler(args []api.Arg, dsClient api.LibruntimeAPI) ([]byte, error) { + hh.setBootstrapFunc(hh.bootstrap) + rsp, err := hh.basicHandler.InitHandler(args, dsClient) + if err != nil { + hh.killAllProcesses() + } + return rsp, err +} + +// ShutDownHandler handles shutdown +func (hh *HttpHandler) ShutDownHandler(gracePeriodSecond uint64) error { + err := hh.basicHandler.ShutDownHandler(gracePeriodSecond) + hh.killAllProcesses() + return err +} + +// HealthCheckHandler handles health check +func (hh *HttpHandler) HealthCheckHandler() (api.HealthType, error) { + return api.Healthy, nil +} diff --git a/api/go/faassdk/handler/http/http_handler_test.go b/api/go/faassdk/handler/http/http_handler_test.go new file mode 100644 index 0000000..e09262b --- /dev/null +++ b/api/go/faassdk/handler/http/http_handler_test.go @@ -0,0 +1,386 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package http - +package http + +import ( + "encoding/json" + "errors" + "net" + "net/http" + "os" + "os/exec" + "reflect" + "syscall" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/faassdk/common/functionlog" + "yuanrong.org/kernel/runtime/faassdk/common/monitor" + "yuanrong.org/kernel/runtime/faassdk/config" + "yuanrong.org/kernel/runtime/libruntime/api" +) + +var userCodePath = "../test/main/user_code_test.so" + +func TestHttpHandler_InitHandler(t *testing.T) { + convey.Convey("HttpHandler_InitHandler", t, func() { + handler := NewHttpHandler(newFuncSpec(), nil) + os.Setenv(config.DelegateDownloadPath, userCodePath) + defer gomonkey.ApplyFunc(syscall.Kill, func(gid int, signal syscall.Signal) error { + return nil + }).Reset() + convey.Convey("invalid create params number", func() { + res, err := handler.InitHandler([]api.Arg{}, nil) + convey.So(err.Error(), convey.ShouldEqual, "{\"errorCode\":\"6001\",\"message\":\"invalid create params number\"}") + convey.So(res, convey.ShouldEqual, []byte{}) + }) + + convey.Convey("init failed bootstrap timed out after 3s", func() { + defer gomonkey.ApplyFunc(functionlog.GetFunctionLogger, func(cfg *config.Configuration) (*functionlog.FunctionLogger, error) { + return &functionlog.FunctionLogger{}, nil + }).Reset() + defer gomonkey.ApplyFunc((*HttpHandler).bootstrap, func(_ *HttpHandler) error { + return nil + }).Reset() + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldContainSubstring, `4211`) + convey.So(res, convey.ShouldEqual, []byte{}) + healthType, err := handler.HealthCheckHandler() + convey.So(healthType, convey.ShouldEqual, api.Healthy) + convey.So(err, convey.ShouldBeNil) + }) + + convey.Convey("bootstrap file not exist", func() { + defer gomonkey.ApplyFunc(functionlog.GetFunctionLogger, func(cfg *config.Configuration) (*functionlog.FunctionLogger, error) { + return &functionlog.FunctionLogger{}, nil + }).Reset() + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + } + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldNotBeEmpty) + convey.So(res, convey.ShouldEqual, []byte{}) + }) + + convey.Convey("failed to execute bootstrap for http function", func() { + defer gomonkey.ApplyFunc(functionlog.GetFunctionLogger, func(cfg *config.Configuration) (*functionlog.FunctionLogger, error) { + return &functionlog.FunctionLogger{}, nil + }).Reset() + defer gomonkey.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return nil, nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&exec.Cmd{}), "Start", func(_ *exec.Cmd) error { + return errors.New("run bootstrap error") + }).Reset() + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldEqual, "{\"errorCode\":\"4201\",\"message\":\"init failed bootstrap failed error run bootstrap error\"}") + convey.So(res, convey.ShouldEqual, []byte{}) + }) + + convey.Convey("timeout waiting for http server running", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(functionlog.GetFunctionLogger, func(cfg *config.Configuration) (*functionlog.FunctionLogger, error) { + return &functionlog.FunctionLogger{}, nil + }), + gomonkey.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return nil, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&exec.Cmd{}), "Start", func(c *exec.Cmd) error { + c.Process = &os.Process{} + return nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&os.Process{}), "Wait", func(_ *os.Process) (*os.ProcessState, error) { + // mock process running + time.Sleep(5 * time.Second) + return nil, nil + }), + gomonkey.ApplyFunc(net.Dial, func(network, address string) (net.Conn, error) { + time.Sleep(3*time.Second + 10*time.Millisecond) + return nil, errors.New("timeout") + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + } + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldNotBeEmpty) + convey.So(res, convey.ShouldEqual, []byte{}) + }) + + convey.Convey("http server process exited", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(functionlog.GetFunctionLogger, func(cfg *config.Configuration) (*functionlog.FunctionLogger, error) { + return &functionlog.FunctionLogger{}, nil + }), + gomonkey.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return nil, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&exec.Cmd{}), "Start", func(c *exec.Cmd) error { + c.Process = &os.Process{} + return nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&os.Process{}), "Wait", func(_ *os.Process) (*os.ProcessState, error) { + return &os.ProcessState{}, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&os.ProcessState{}), "ExitCode", func(_ *os.ProcessState) int { + return 0 + }), + gomonkey.ApplyFunc(net.Dial, func(network, address string) (net.Conn, error) { + return nil, errors.New("process has done") + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + } + res, err := handler.InitHandler(args, nil) + convey.So(err.Error(), convey.ShouldNotBeEmpty) + convey.So(res, convey.ShouldEqual, []byte{}) + }) + + convey.Convey("success", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(functionlog.GetFunctionLogger, func(cfg *config.Configuration) (*functionlog.FunctionLogger, error) { + return &functionlog.FunctionLogger{}, nil + }), + gomonkey.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return nil, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&exec.Cmd{}), "Start", func(c *exec.Cmd) error { + c.Process = &os.Process{} + return nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&os.Process{}), "Wait", func(_ *os.Process) (*os.ProcessState, error) { + // mock process running + time.Sleep(5 * time.Second) + return nil, nil + }), + gomonkey.ApplyFunc(net.Dial, func(network, address string) (net.Conn, error) { + return &net.TCPConn{}, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&http.Client{}), "Do", + func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{}, nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + funcSpecBytes, _ := json.Marshal(newFuncSpec()) + createParamsBytes, _ := json.Marshal(newHttpCreateParams()) + schedulerParamsBytes, _ := json.Marshal(newHttpSchedulerParams()) + args := []api.Arg{ + { + Type: 0, + Data: funcSpecBytes, + }, + { + Type: 0, + Data: createParamsBytes, + }, + { + Type: 0, + Data: schedulerParamsBytes, + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + res, err := handler.InitHandler(args, nil) + convey.So(err, convey.ShouldBeNil) + convey.So(res, convey.ShouldEqual, []byte{}) + }) + }) +} + +func TestShutDownHandler(t *testing.T) { + convey.Convey("ShutDownHandler", t, func() { + defer gomonkey.ApplyFunc(syscall.Getpgid, func(pid int) (int, error) { + return 0, errors.New("get gid error") + }).Reset() + defer gomonkey.ApplyFunc(syscall.Kill, func(gid int, signal syscall.Signal) error { + return errors.New("kill error") + }).Reset() + handler := NewHttpHandler(newFuncSpec(), nil) + var err1 error + go func() { + time.Sleep(1 * time.Second) + err1 = handler.ShutDownHandler(30) + }() + err := handler.ShutDownHandler(30) + convey.So(err, convey.ShouldBeNil) + convey.So(err1, convey.ShouldBeNil) + }) +} + +func TestKillAllProcesses(t *testing.T) { + convey.Convey("TestKillAllProcesses", t, func() { + patch1 := gomonkey.ApplyFunc(syscall.Getpgid, func(pid int) (int, error) { + return 0, errors.New("get gid error") + }) + defer patch1.Reset() + patch2 := gomonkey.ApplyFunc(syscall.Kill, func(gid int, signal syscall.Signal) error { + return errors.New("kill error") + }) + defer patch2.Reset() + handler := HttpHandler{ + process: &os.Process{ + Pid: 100, + }, + basicHandler: &basicHandler{ + funcSpec: newFuncSpec(), + }, + } + handler.killAllProcesses() + convey.So(handler.process, convey.ShouldNotBeNil) + }) +} + +func TestWaitProcess(t *testing.T) { + convey.Convey("TestWaitProcess", t, func() { + convey.Convey("precess is nil", func() { + handler := HttpHandler{ + waitChan: make(chan processState, 1), + } + handler.waitProcess() + convey.So(handler.process, convey.ShouldBeNil) + }) + convey.Convey("precess wait error", func() { + d := &os.Process{} + patch := gomonkey.ApplyMethod(reflect.TypeOf(d), + "Wait", func(p *os.Process) (*os.ProcessState, error) { + return nil, errors.New("error") + }) + defer patch.Reset() + handler := HttpHandler{ + process: &os.Process{ + Pid: 100, + }, + waitChan: make(chan processState, 1), + } + handler.waitProcess() + convey.So(handler.process, convey.ShouldNotBeNil) + }) + convey.Convey("wait process success ", func() { + d := &os.Process{} + patch := gomonkey.ApplyMethod(reflect.TypeOf(d), + "Wait", func(p *os.Process) (*os.ProcessState, error) { + return &os.ProcessState{}, nil + }) + defer patch.Reset() + handler := HttpHandler{ + process: &os.Process{ + Pid: 100, + }, + basicHandler: &basicHandler{ + monitor: &monitor.FunctionMonitorManager{ + ErrChan: make(chan error, 1), + }, + }, + waitChan: make(chan processState, 1), + } + handler.waitProcess() + _, ok := <-handler.basicHandler.monitor.ErrChan + convey.So(ok, convey.ShouldBeTrue) + }) + }) +} diff --git a/api/go/faassdk/handler/http/state.go b/api/go/faassdk/handler/http/state.go new file mode 100644 index 0000000..147f54e --- /dev/null +++ b/api/go/faassdk/handler/http/state.go @@ -0,0 +1,232 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package http - +package http + +import ( + "fmt" + "os" + "runtime" + "sync" + "time" + + "yuanrong.org/kernel/runtime/faassdk/common/alarm" + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/logger/log" +) + +const ( + defaultStateMapCapacity = 10 + defStateVal = "1" + cOK = 0 // same with datasystem K_OK + cNotFound = 3 // same with datasystem K_NOT_FOUND +) + +const ( + // StateExistedErrCode - + StateExistedErrCode = 4027 + // StateExistedErrMsg - + StateExistedErrMsg = "state cannot be created repeatedly" + // StateNotExistedErrCode - + StateNotExistedErrCode = 4026 + // StateNotExistedErrMsg - + StateNotExistedErrMsg = "state does not exist" + // StateInstanceNotExistedErrCode - + StateInstanceNotExistedErrCode = 4028 + // StateInstanceNotExistedErrMsg - + StateInstanceNotExistedErrMsg = "state instance not existed" + // DataSystemInternalErrCode - + DataSystemInternalErrCode = 4030 + // DataSystemInternalErrMsg - + DataSystemInternalErrMsg = "internal system error" + // StateInstanceNoLease - + StateInstanceNoLease = 4025 + // StateInstanceNoLeaseMsg - + StateInstanceNoLeaseMsg = "maximum number of leases reached" + // FaaSSchedulerInternalErrCode - + FaaSSchedulerInternalErrCode = 4029 + // FaaSSchedulerInternalErrMsg - + FaaSSchedulerInternalErrMsg = "internal system error" + // InvalidState - + InvalidState = 4040 + // InvalidStateErrMsg - + InvalidStateErrMsg = "invalid state, expect not blank" +) + +var ( + once sync.Once + instance *stateManager +) + +type stateManager struct { + stateMap map[string]*api.InstanceAllocation // key stateKey, value lease + dsClinet api.KvClient + muteMap sync.RWMutex +} + +// GetStateManager - +func GetStateManager(dsClient api.KvClient) *stateManager { + once.Do(func() { + instance = &stateManager{ + stateMap: make(map[string]*api.InstanceAllocation, defaultStateMapCapacity), + dsClinet: dsClient, + muteMap: sync.RWMutex{}, + } + }) + return instance +} + +func (sm *stateManager) genStateKey(funcKey, stateID string) string { + return fmt.Sprintf("/sn/state/function/%s/state/%s", funcKey, stateID) +} + +func (sm *stateManager) newState(funcKey, stateID, traceID string) error { + err := sm.getState(funcKey, stateID, traceID) + if err == nil { + return api.ErrorInfo{ + Code: StateExistedErrCode, + Err: fmt.Errorf(StateExistedErrMsg), + } + } + + stateKey := sm.genStateKey(funcKey, stateID) + param := api.SetParam{ + WriteMode: 1, // WRITE_THROUGH_L2_CACHE + TTLSecond: 0, + Existence: 0, // should be 1 (NX) after datasystem issue is resolved + CacheType: 0, + } + runtime.LockOSThread() + sm.dsClinet.SetTraceID(traceID) + ret := sm.dsClinet.KVSet(stateKey, []byte(defStateVal), param) + runtime.UnlockOSThread() + log.GetLogger().Infof("set state to datasystem %d, %v, traceid(%s)", ret.Code, ret.Err, traceID) + switch ret.Code { + case cOK: + return nil + case cNotFound: + return api.ErrorInfo{ + Code: StateExistedErrCode, + Err: fmt.Errorf(StateExistedErrMsg), + } + default: + alarmDetail := &alarm.Detail{ + SourceTag: os.Getenv(constants.PodNameEnvKey) + "|" + os.Getenv(constants.PodIPEnvKey) + + "|" + os.Getenv(constants.ClusterName) + "|" + os.Getenv(constants.HostIPEnvKey), + OpType: alarm.GenerateAlarmLog, + Details: fmt.Sprintf("new State failed, datasystem error: %v, stateKey: %s, statefuncKey: %s", + err, stateID, funcKey), + StartTimestamp: int(time.Now().Unix()), + EndTimestamp: 0, + } + alarmInfo := &alarm.LogAlarmInfo{ + AlarmID: "NewStateFailedInDataSystem00001", + AlarmName: "NewStateFailedInDataSystem", + AlarmLevel: alarm.Level2, + } + alarm.ReportOrClearAlarm(alarmInfo, alarmDetail) + return api.ErrorInfo{ + Code: DataSystemInternalErrCode, + Err: fmt.Errorf(DataSystemInternalErrMsg), + } + } +} + +func (sm *stateManager) getState(funcKey, stateID, traceID string) error { + stateKey := sm.genStateKey(funcKey, stateID) + runtime.LockOSThread() + sm.dsClinet.SetTraceID(traceID) + _, ret := sm.dsClinet.KVGet(stateKey) + runtime.UnlockOSThread() + log.GetLogger().Infof("get state from datasystem %d, %v, traceid(%s)", ret.Code, ret.Err, traceID) + switch ret.Code { + case cOK: + return nil + case cNotFound: + return api.ErrorInfo{ + Code: StateNotExistedErrCode, + Err: fmt.Errorf(StateNotExistedErrMsg), + } + default: + return api.ErrorInfo{ + Code: DataSystemInternalErrCode, + Err: fmt.Errorf(DataSystemInternalErrMsg), + } + } +} + +func (sm *stateManager) delState(funcKey, stateID, traceID string) error { + err := sm.getState(funcKey, stateID, traceID) + if err != nil { + return api.ErrorInfo{ + Code: StateNotExistedErrCode, + Err: fmt.Errorf(StateNotExistedErrMsg), + } + } + stateKey := sm.genStateKey(funcKey, stateID) + runtime.LockOSThread() + sm.dsClinet.SetTraceID(traceID) + ret := sm.dsClinet.KVDel(stateKey) + runtime.UnlockOSThread() + log.GetLogger().Infof("del state from datasystem %d, %v, traceid(%s)", ret.Code, ret.Err, traceID) + switch ret.Code { + case cOK: + return nil + default: + alarmDetail := &alarm.Detail{ + SourceTag: os.Getenv(constants.PodNameEnvKey) + "|" + os.Getenv(constants.PodIPEnvKey) + + "|" + os.Getenv(constants.ClusterName) + "|" + os.Getenv(constants.HostIPEnvKey), + OpType: alarm.GenerateAlarmLog, + Details: fmt.Sprintf("terminate State failed, datasystem error: %v, stateKey: %s, statefuncKey: %s", + err, stateID, funcKey), + StartTimestamp: int(time.Now().Unix()), + EndTimestamp: 0, + } + alarmInfo := &alarm.LogAlarmInfo{ + AlarmID: "TerminateStateFailedInDataSystem00001", + AlarmName: "TerminateStateFailedInDataSystem", + AlarmLevel: alarm.Level2, + } + alarm.ReportOrClearAlarm(alarmInfo, alarmDetail) + return api.ErrorInfo{ + Code: DataSystemInternalErrCode, + Err: fmt.Errorf(DataSystemInternalErrMsg), + } + } +} + +func (sm *stateManager) getInstance(funcKey, stateID string) *api.InstanceAllocation { + sm.muteMap.RLock() + defer sm.muteMap.RUnlock() + stateKey := sm.genStateKey(funcKey, stateID) + return sm.stateMap[stateKey] +} + +func (sm *stateManager) addInstance(funcKey, stateID string, lease *api.InstanceAllocation) { + sm.muteMap.Lock() + defer sm.muteMap.Unlock() + stateKey := sm.genStateKey(funcKey, stateID) + sm.stateMap[stateKey] = lease +} + +func (sm *stateManager) delInstance(funcKey, stateID string) { + sm.muteMap.Lock() + defer sm.muteMap.Unlock() + stateKey := sm.genStateKey(funcKey, stateID) + delete(sm.stateMap, stateKey) +} diff --git a/api/go/faassdk/handler/http/state_test.go b/api/go/faassdk/handler/http/state_test.go new file mode 100644 index 0000000..e1eac17 --- /dev/null +++ b/api/go/faassdk/handler/http/state_test.go @@ -0,0 +1,290 @@ +package http + +import ( + "github.com/agiledragon/gomonkey/v2" + "reflect" + "testing" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/uuid" + + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" +) + +type FakeDataSystemClinet struct{} + +func (f *FakeDataSystemClinet) CreateClient(config api.ConnectArguments) (api.KvClient, error) { + return &FakeDataSystemClinet{}, nil +} + +func (f *FakeDataSystemClinet) KVSet(key string, value []byte, param api.SetParam) api.ErrorInfo { + return api.ErrorInfo{ + Code: 0, + Err: nil, + } +} + +func (f *FakeDataSystemClinet) KVSetWithoutKey(value []byte, param api.SetParam) (string, api.ErrorInfo) { + return string(value), api.ErrorInfo{ + Code: 0, + Err: nil, + } +} + +func (f *FakeDataSystemClinet) KVGet(key string, timeoutms ...uint32) ([]byte, api.ErrorInfo) { + return []byte{}, api.ErrorInfo{ + Code: 0, + Err: nil, + } +} + +func (f *FakeDataSystemClinet) KVGetMulti(keys []string, timeoutms ...uint32) ([][]byte, api.ErrorInfo) { + return [][]byte{}, api.ErrorInfo{ + Code: 0, + Err: nil, + } +} + +func (f *FakeDataSystemClinet) KVQuerySize(keys []string) ([]uint64, api.ErrorInfo) { + return []uint64{}, api.ErrorInfo{ + Code: 0, + Err: nil, + } +} + +func (f *FakeDataSystemClinet) KVDel(key string) api.ErrorInfo { + return api.ErrorInfo{ + Code: 0, + Err: nil, + } +} + +func (f *FakeDataSystemClinet) KVDelMulti(keys []string) ([]string, api.ErrorInfo) { + return []string{}, api.ErrorInfo{ + Code: 0, + Err: nil, + } +} + +func (f *FakeDataSystemClinet) GenerateKey() string { + return "" +} + +func (f *FakeDataSystemClinet) SetTraceID(traceID string) { +} + +func (f *FakeDataSystemClinet) DestroyClient() { +} + +func (f *FakeDataSystemClinet) Put(objectId string, value []byte, param api.PutParam, + nestedObjectIds ...[]string) error { + return nil +} + +func (f *FakeDataSystemClinet) Get(objectIds []string, timeoutMs ...int64) ([][]byte, error) { + return [][]byte{}, nil +} + +func (f *FakeDataSystemClinet) GIncreaseRef(objectIds []string, remoteClientId ...string) ([]string, error) { + return []string{}, nil +} + +func (f *FakeDataSystemClinet) GDecreaseRef(objectIds []string, remoteClientId ...string) ([]string, error) { + return []string{}, nil +} + +func (f *FakeDataSystemClinet) ReleaseGRefs(remoteClientId string) error { + return nil +} + +func TestGetStateManager(t *testing.T) { + dsClient := &FakeDataSystemClinet{} + + sm1 := GetStateManager(dsClient) + assert.NotNil(t, sm1, "Expected non-nil stateManager instance") + + sm2 := GetStateManager(dsClient) + assert.Equal(t, sm1, sm2, "Expected the same instance of stateManager") +} + +func TestNewState(t *testing.T) { + convey.Convey("state is already existed", t, func() { + dsClient := &FakeDataSystemClinet{} + sm := GetStateManager(dsClient) + patches := gomonkey.NewPatches() + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVGet", func(_ *FakeDataSystemClinet, key string, timeoutms ...uint32) (string, api.ErrorInfo) { + return "", api.ErrorInfo{ + Code: cOK, + Err: nil, + } + }) + defer patches.Reset() + err := sm.newState("testFunc", "testStateID", "testTraceID") + if stateErr, ok := err.(api.ErrorInfo); !ok { + t.Errorf("Expected error type stateError") + } else { + convey.So(stateErr.Code, convey.ShouldEqual, StateExistedErrCode) + } + }) + + convey.Convey("standard flow", t, func() { + dsClient := &FakeDataSystemClinet{} + sm := GetStateManager(dsClient) + patches := gomonkey.NewPatches() + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVGet", func(_ *FakeDataSystemClinet, key string, timeoutms ...uint32) (string, api.ErrorInfo) { + return "", api.ErrorInfo{ + Code: cNotFound, + Err: nil, + } + }) + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVSet", func(_ *FakeDataSystemClinet, key string, value []byte, param api.SetParam) api.ErrorInfo { + return api.ErrorInfo{ + Code: cOK, + Err: nil, + } + }) + defer patches.Reset() + err := sm.newState("testFunc", "testStateID", "testTraceID") + convey.So(err, convey.ShouldBeNil) + }) + + convey.Convey("datasystem put failed", t, func() { + dsClient := &FakeDataSystemClinet{} + sm := GetStateManager(dsClient) + patches := gomonkey.NewPatches() + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVGet", func(_ *FakeDataSystemClinet, key string, timeoutms ...uint32) (string, api.ErrorInfo) { + return "", api.ErrorInfo{ + Code: cNotFound, + Err: nil, + } + }) + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVSet", func(_ *FakeDataSystemClinet, key string, value []byte, param api.SetParam) api.ErrorInfo { + return api.ErrorInfo{ + Code: 4, // todo + Err: nil, + } + }) + defer patches.Reset() + err := sm.newState("testFunc", "testStateID", "testTraceID") + if stateErr, ok := err.(api.ErrorInfo); !ok { + t.Errorf("Expected error type stateError") + } else { + convey.So(stateErr.Code, convey.ShouldEqual, DataSystemInternalErrCode) + } + }) +} + +func TestGetState(t *testing.T) { + convey.Convey("datesystem InternalErrCode", t, func() { + dsClient := &FakeDataSystemClinet{} + sm := GetStateManager(dsClient) + patches := gomonkey.NewPatches() + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVGet", func(_ *FakeDataSystemClinet, key string, timeoutms ...uint32) (string, api.ErrorInfo) { + return "", api.ErrorInfo{ + Code: 100, + Err: nil, + } + }) + defer patches.Reset() + err := sm.getState("testFunc", "testStateID", "testTraceID") + if stateErr, ok := err.(api.ErrorInfo); !ok { + t.Errorf("Expected error type stateError") + } else { + convey.So(stateErr.Code, convey.ShouldEqual, DataSystemInternalErrCode) + } + }) +} + +func TestDelState(t *testing.T) { + convey.Convey("state is not existed", t, func() { + dsClient := &FakeDataSystemClinet{} + sm := GetStateManager(dsClient) + patches := gomonkey.NewPatches() + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVGet", func(_ *FakeDataSystemClinet, key string, timeoutms ...uint32) (string, api.ErrorInfo) { + return "", api.ErrorInfo{ + Code: cNotFound, + Err: nil, + } + }) + defer patches.Reset() + err := sm.delState("testFunc", "testStateID", "testTraceID") + if stateErr, ok := err.(api.ErrorInfo); !ok { + t.Errorf("Expected error type stateError") + } else { + convey.So(stateErr.Code, convey.ShouldEqual, StateNotExistedErrCode) + } + }) + + convey.Convey("standard flow", t, func() { + dsClient := &FakeDataSystemClinet{} + sm := GetStateManager(dsClient) + patches := gomonkey.NewPatches() + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVGet", func(_ *FakeDataSystemClinet, key string, timeoutms ...uint32) (string, api.ErrorInfo) { + return "", api.ErrorInfo{ + Code: cOK, + Err: nil, + } + }) + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVDel", func(_ *FakeDataSystemClinet, key string) api.ErrorInfo { + return api.ErrorInfo{ + Code: cOK, + Err: nil, + } + }) + defer patches.Reset() + err := sm.delState("testFunc", "testStateID", "testTraceID") + convey.So(err, convey.ShouldBeNil) + }) + + convey.Convey("datesystem InternalErrCode", t, func() { + dsClient := &FakeDataSystemClinet{} + sm := GetStateManager(dsClient) + patches := gomonkey.NewPatches() + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVGet", func(_ *FakeDataSystemClinet, key string, timeoutms ...uint32) (string, api.ErrorInfo) { + return "", api.ErrorInfo{ + Code: cOK, + Err: nil, + } + }) + patches.ApplyMethod(reflect.TypeOf(dsClient), "KVDel", func(_ *FakeDataSystemClinet, key string) api.ErrorInfo { + return api.ErrorInfo{ + Code: 100, + Err: nil, + } + }) + defer patches.Reset() + err := sm.delState("testFunc", "testStateID", "testTraceID") + if stateErr, ok := err.(api.ErrorInfo); !ok { + t.Errorf("Expected error type stateError") + } else { + convey.So(stateErr.Code, convey.ShouldEqual, DataSystemInternalErrCode) + } + }) +} + +func TestInstanceFuncs(t *testing.T) { + dsClient := &FakeDataSystemClinet{} + sm := GetStateManager(dsClient) + testFuncKey := "test-func-key" + testStateID := "test-state-id" + uuid.New().String() + testLease := api.InstanceAllocation{ + LeaseID: "test-lease-id", + } + getLease := sm.getInstance(testFuncKey, testStateID) + if getLease != nil { + t.Errorf("Expected getLease to be nil, but it was not") + } + + sm.addInstance(testFuncKey, testStateID, &testLease) + getLease = sm.getInstance(testFuncKey, testStateID) + + if testLease != *getLease { + t.Errorf("Expected getLease and testLease are same") + } + + sm.delInstance(testFuncKey, testStateID) + getLease = sm.getInstance(testFuncKey, testStateID) + if getLease != nil { + t.Errorf("Expected getLease to be nil, but it was not") + } +} diff --git a/api/go/faassdk/handler/mock_utils_test.go b/api/go/faassdk/handler/mock_utils_test.go new file mode 100644 index 0000000..fa30eff --- /dev/null +++ b/api/go/faassdk/handler/mock_utils_test.go @@ -0,0 +1,302 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package event - +package handler + +import ( + "github.com/agiledragon/gomonkey/v2" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "yuanrong.org/kernel/runtime/libruntime/api" +) + +// PatchSlice - +type PatchSlice []*gomonkey.Patches + +// PatchesFunc - +type PatchesFunc func() PatchSlice + +// InitPatchSlice - +func InitPatchSlice() PatchSlice { + return make([]*gomonkey.Patches, 0) +} + +type mockLibruntimeClient struct { +} + +func (m mockLibruntimeClient) CreateInstance(funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) InvokeByInstanceId(funcMeta api.FunctionMeta, instanceID string, args []api.Arg, invokeOpt api.InvokeOptions) (returnObjectID string, err error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) InvokeByFunctionName(funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) AcquireInstance(state string, funcMeta api.FunctionMeta, acquireOpt api.InvokeOptions) (api.InstanceAllocation, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) ReleaseInstance(allocation api.InstanceAllocation, stateID string, abnormal bool, option api.InvokeOptions) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Kill(instanceID string, signal int, payload []byte) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) CreateInstanceRaw(createReqRaw []byte) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) InvokeByInstanceIdRaw(invokeReqRaw []byte) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KillRaw(killReqRaw []byte) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) SaveState(state []byte) (string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) LoadState(checkpointID string) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Exit(code int, message string) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Finalize() { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVSet(key string, value []byte, param api.SetParam) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVSetWithoutKey(value []byte, param api.SetParam) (string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVGet(key string, timeoutms uint) ([]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVGetMulti(keys []string, timeoutms uint) ([][]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVDel(key string) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) KVDelMulti(keys []string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) CreateProducer(streamName string, producerConf api.ProducerConf) (api.StreamProducer, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Subscribe(streamName string, config api.SubscriptionConfig) (api.StreamConsumer, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) DeleteStream(streamName string) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) QueryGlobalProducersNum(streamName string) (uint64, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) QueryGlobalConsumersNum(streamName string) (uint64, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) SetTraceID(traceID string) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) SetTenantID(tenantID string) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Put(objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) PutRaw(objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string) error { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Get(objectIDs []string, timeoutMs int) ([][]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GetRaw(objectIDs []string, timeoutMs int) ([][]byte, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) Wait(objectIDs []string, waitNum uint64, timeoutMs int) ([]string, []string, map[string]error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GIncreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GIncreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GDecreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GDecreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GetAsync(objectID string, cb api.GetAsyncCallback) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) GetFormatLogger() api.FormatLogger { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) CreateClient(config api.ConnectArguments) (api.KvClient, error) { + //TODO implement me + panic("implement me") +} + +func (m mockLibruntimeClient) ReleaseGRefs(remoteClientID string) error { + //TODO implement me + panic("implement me") +} + +func (m *mockLibruntimeClient) UpdateSchdulerInfo(schedulerName string, option string) { + //TODO implement me + panic("implement me") +} + +func (m *mockLibruntimeClient) IsHealth() bool { + return true +} + +func (m *mockLibruntimeClient) IsDsHealth() bool { + return true +} + +// Append - +func (p *PatchSlice) Append(patches PatchSlice) { + if len(patches) > 0 { + *p = append(*p, patches...) + } +} + +// ResetAll - +func (p PatchSlice) ResetAll() { + for _, item := range p { + item.Reset() + } +} + +// FakeLogger - +type FakeLogger struct{} + +// With - +func (f *FakeLogger) With(fields ...zapcore.Field) api.FormatLogger { + return f +} + +// Infof - +func (f *FakeLogger) Infof(format string, paras ...interface{}) {} + +// Errorf - +func (f *FakeLogger) Errorf(format string, paras ...interface{}) {} + +// Warnf - +func (f *FakeLogger) Warnf(format string, paras ...interface{}) {} + +// Debugf - +func (f *FakeLogger) Debugf(format string, paras ...interface{}) {} + +// Fatalf - +func (f *FakeLogger) Fatalf(format string, paras ...interface{}) {} + +// Info - +func (f *FakeLogger) Info(msg string, fields ...zap.Field) {} + +// Error - +func (f *FakeLogger) Error(msg string, fields ...zap.Field) {} + +// Warn - +func (f *FakeLogger) Warn(msg string, fields ...zap.Field) {} + +// Debug - +func (f *FakeLogger) Debug(msg string, fields ...zap.Field) {} + +// Fatal - +func (f *FakeLogger) Fatal(msg string, fields ...zap.Field) {} + +// Sync - +func (f *FakeLogger) Sync() {} diff --git a/api/go/faassdk/runtime.go b/api/go/faassdk/runtime.go new file mode 100644 index 0000000..ad933b3 --- /dev/null +++ b/api/go/faassdk/runtime.go @@ -0,0 +1,86 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faassdk for init and start +package faassdk + +import ( + "fmt" + "os" + + "yuanrong.org/kernel/runtime/libruntime" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common" + "yuanrong.org/kernel/runtime/libruntime/config" + "yuanrong.org/kernel/runtime/libruntime/pool" +) + +// Run begins loop processing the received request. +func Run() { + libruntime.ReceiveRequestLoop() +} + +// InitRuntime init runtime +func InitRuntime() error { + conf := common.GetConfig() + intfs := newFaaSFuncExecutionIntfs() + runtimeConf := config.Config{ + GrpcAddress: conf.GrpcAddress, + FunctionSystemAddress: conf.FSAddress, + DataSystemAddress: os.Getenv("DATASYSTEM_ADDR"), + JobID: conf.JobID, + RuntimeID: conf.RuntimeID, + InstanceID: conf.InstanceID, + FunctionName: conf.FunctionName, + LogDir: conf.LogPath, + LogLevel: conf.LogLevel, + InCluster: true, + IsDriver: conf.DriverMode, + EnableMTLS: conf.EnableMTLS, + PrivateKeyPath: conf.PrivateKeyPath, + CertificateFilePath: conf.CertificateFilePath, + VerifyFilePath: conf.VerifyFilePath, + PrivateKeyPaaswd: conf.PrivateKeyPaaswd, + Api: api.FaaSApi, + Hooks: config.HookIntfs{ + LoadFunctionCb: intfs.LoadFunction, + FunctionExecutionCb: intfs.FunctionExecute, + CheckpointCb: intfs.Checkpoint, + RecoverCb: intfs.Recover, + ShutdownCb: intfs.Shutdown, + SignalCb: intfs.Signal, + HealthCheckCb: intfs.HealthCheck, + }, + FunctionExectionPool: pool.NewPool(pool.DefaultFuncExecPoolSize), + SystemAuthAccessKey: conf.SystemAuthAccessKey, + SystemAuthSecretKey: conf.SystemAuthSecretKey, + EncryptPrivateKeyPasswd: conf.EncryptPrivateKeyPasswd, + PrimaryKeyStoreFile: conf.PrimaryKeyStoreFile, + StandbyKeyStoreFile: conf.StandbyKeyStoreFile, + EnableDsEncrypt: conf.EnableDsEncrypt, + RuntimePublicKeyContext: conf.RuntimePublicKeyContext, + RuntimePrivateKeyContext: conf.RuntimePrivateKeyContext, + DsPublicKeyContext: conf.DsPublicKeyContext, + EncryptRuntimePublicKeyContext: conf.EncryptRuntimePublicKeyContext, + EncryptRuntimePrivateKeyContext: conf.EncryptRuntimePrivateKeyContext, + EncryptDsPublicKeyContext: conf.EncryptDsPublicKeyContext, + } + if err := libruntime.Init(runtimeConf); err != nil { + fmt.Printf("failed to init libruntime, error %s\n", err.Error()) + return err + } + return nil +} diff --git a/api/go/faassdk/runtime_test.go b/api/go/faassdk/runtime_test.go new file mode 100644 index 0000000..12f4295 --- /dev/null +++ b/api/go/faassdk/runtime_test.go @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faassdk for init and start +package faassdk + +import ( + "os" + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/common" +) + +func TestInitRuntimeAndRun(t *testing.T) { + convey.Convey( + "Test InitRuntimeAndRun", t, func() { + cfg := common.GetConfig() + cfg.DriverMode = true + os.Create("./test.json") + defer os.Remove("./test.json") + os.Setenv("FUNCTION_LIB_PATH", "/tmp") + convey.Convey("Test InitRuntime Failed", func() { + err := InitRuntime() + Run() + convey.So(err, convey.ShouldBeNil) + }) + }) +} diff --git a/api/go/faassdk/sts/sts.go b/api/go/faassdk/sts/sts.go new file mode 100644 index 0000000..30650a7 --- /dev/null +++ b/api/go/faassdk/sts/sts.go @@ -0,0 +1,65 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package sts used for init sts +package sts + +import ( + "os" + + "github.com/magiconair/properties" + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + + "yuanrong.org/kernel/runtime/faassdk/types" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +// EnvSTSEnable flag +const EnvSTSEnable = "STS_ENABLE" +const fileMode = 0640 + +// InitStsSDK - Configure sts go sdk +func InitStsSDK(serverCfg types.StsServerConfig) error { + initStsSdkLog() + logger.GetLogger().Infof("finished to init sts sdk log") + stsProperties := properties.LoadMap( + map[string]string{ + "sts.server.domain": serverCfg.Domain, + "sts.config.path": serverCfg.Path, + "sts.connect.timeout": "20000", + "sts.handshake.timeout": "20000", + }, + ) + err := stsgoapi.InitWith(*stsProperties) + return err +} + +func initStsSdkLog() { + coreInfo, err := config.GetCoreInfoFromEnv() + if err != nil { + coreInfo = config.GetDefaultCoreInfo() + } + stsSdkLogFilePath := coreInfo.FilePath + "/sts.sdk.log" + stsgoapi.SetLogFile(stsSdkLogFilePath) + file, err := os.OpenFile(stsSdkLogFilePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fileMode) + if err != nil { + logger.GetLogger().Errorf("failed to open stsSdkLogFile") + return + } + defer file.Close() + return +} diff --git a/api/go/faassdk/sts/sts_test.go b/api/go/faassdk/sts/sts_test.go new file mode 100644 index 0000000..b3fae7f --- /dev/null +++ b/api/go/faassdk/sts/sts_test.go @@ -0,0 +1,22 @@ +package sts + +import ( + "github.com/agiledragon/gomonkey/v2" + "github.com/magiconair/properties" + "github.com/smartystreets/goconvey/convey" + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + "testing" + "yuanrong.org/kernel/runtime/faassdk/types" +) + +func TestInitStsSDK(t *testing.T) { + convey.Convey("InitStsSDK", t, func() { + convey.Convey("success", func() { + defer gomonkey.ApplyFunc(stsgoapi.InitWith, func(property properties.Properties) error { + return nil + }).Reset() + err := InitStsSDK(types.StsServerConfig{}) + convey.So(err, convey.ShouldBeNil) + }) + }) +} diff --git a/api/go/faassdk/types/types.go b/api/go/faassdk/types/types.go new file mode 100644 index 0000000..c6ff1d6 --- /dev/null +++ b/api/go/faassdk/types/types.go @@ -0,0 +1,321 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package types - +package types + +import ( + "encoding/json" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +// CreateParams for RequestArgs +type CreateParams struct { + InitEntry string `json:"userInitEntry"` + CallEntry string `json:"userCallEntry"` +} + +// HttpCreateParams is used for http function creation +type HttpCreateParams struct { + Port int `json:"port"` + InitRoute string `json:"initRoute"` + CallRoute string `json:"callRoute"` +} + +// CreateOptions context +type CreateOptions struct { + FuncMetaDataContext string + ResourceMetaDataContext string +} + +// FuncSpec - +type FuncSpec struct { + FuncMetaData FuncMetaData `json:"funcMetaData"` + ResourceMetaData ResourceMetaData `json:"resourceMetaData"` + ExtendedMetaData ExtendedMetaData `json:"extendedMetaData"` +} + +// FuncMetaData - +type FuncMetaData struct { + FunctionName string `json:"name"` + Service string `json:"service"` + Runtime string `json:"runtime"` + TenantId string `json:"tenantId"` + Version string `json:"version"` + Timeout int `json:"timeout"` + Handler string `json:"handler"` + FunctionVersionURN string `json:"functionVersionUrn"` +} + +// ResourceMetaData - +type ResourceMetaData struct { + Cpu int `json:"cpu"` + Memory int `json:"memory"` + GpuMemory int64 `json:"gpu_memory"` + EnableDynamicMemory bool `json:"enable_dynamic_memory" valid:",optional"` + CustomResources string `json:"customResources" valid:",optional"` + EnableTmpExpansion bool `json:"enable_tmp_expansion" valid:",optional"` + EphemeralStorage int `json:"ephemeral_storage" valid:"int,optional"` +} + +// ExtendedMetaData - +type ExtendedMetaData struct { + Initializer Initializer `json:"initializer" valid:",optional"` + LogTankService LogTankService `json:"log_tank_service" valid:",optional"` + CustomHealthCheck CustomHealthCheck `json:"custom_health_check" valid:",optional"` + CustomGracefulShutdown CustomGracefulShutdown `json:"runtime_graceful_shutdown"` +} + +// LogTankService - +type LogTankService struct { + GroupID string `json:"logGroupId" valid:",optional"` + StreamID string `json:"logStreamId" valid:",optional"` +} + +// Initializer - +type Initializer struct { + Handler string `json:"initializer_handler" valid:",optional"` + Timeout int `json:"initializer_timeout" valid:",optional"` +} + +// CustomHealthCheck - +type CustomHealthCheck struct { + TimeoutSeconds int `json:"timeoutSeconds" valid:",optional"` + PeriodSeconds int `json:"periodSeconds" valid:",optional"` + FailureThreshold int `json:"failureThreshold" valid:",optional"` +} + +// CustomGracefulShutdown define the option of custom container's runtime graceful shutdown +type CustomGracefulShutdown struct { + MaxShutdownTimeout int `json:"maxShutdownTimeout"` +} + +// CallRequest - +type CallRequest struct { + Header map[string]string `json:"header"` + Path string `json:"path"` + Method string `json:"method"` + Query string `json:"query"` + Body json.RawMessage `json:"body"` +} + +// CallResponse - +type CallResponse struct { + Headers map[string]string `json:"headers"` + Body json.RawMessage `json:"body"` + BillingDuration string `json:"billingDuration"` + InnerCode string `json:"innerCode"` + InvokerSummary string `json:"invokerSummary"` + LogResult string `json:"logResult"` + UserFuncTime float64 `json:"userFuncTime"` + ExecutorTime float64 `json:"executorTime"` +} + +// InitResponse - +type InitResponse struct { + ErrorCode string `json:"errorCode"` + Message json.RawMessage `json:"message"` +} + +// InvokeRequest - +type InvokeRequest struct { + FuncName string `json:"funcName"` + FuncVersion string `json:"funcVersion"` + Payload string `json:"payload"` + TraceID string `json:"traceID"` + Timeout int64 `json:"timeout"` + AcquireTimeout int64 `json:"acquireTimeout"` + StateID string `json:"stateID"` + Params map[string]string `json:"params"` + FuncUrn string +} + +// InvokeResponse - +type InvokeResponse struct { + ObjectID string `json:"objectID"` + StatusCode int `json:"statusCode"` + ErrorMessage string `json:"errorMessage"` +} + +// GetStatusCode InvokeResponse get status code +func (i *InvokeResponse) GetStatusCode() int { + return i.StatusCode + +} + +// GetErrorMessage InvokeResponse get error message +func (i *InvokeResponse) GetErrorMessage() string { + return i.ErrorMessage + +} + +// ExitRequest - +type ExitRequest struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// GetFutureRequest - +type GetFutureRequest struct { + ObjectID string `json:"objectID"` + TraceID string `json:"traceID"` +} + +// CircuitBreakRequest - +type CircuitBreakRequest struct { + Switch bool `json:"switch"` +} + +// GetFutureResponse - +type GetFutureResponse struct { + ObjectID string `json:"objectID"` + Content string `json:"content"` + StatusCode int `json:"statusCode"` + ErrorMessage string `json:"errorMessage"` +} + +// GetStatusCode GetFutureResponse get status code +func (f *GetFutureResponse) GetStatusCode() int { + return f.StatusCode +} + +// GetErrorMessage GetFutureResponse get error message +func (f *GetFutureResponse) GetErrorMessage() string { + return f.ErrorMessage +} + +// StateRequest - +type StateRequest struct { + FuncName string `json:"funcName"` + FuncVersion string `json:"funcVersion"` + Params map[string]string `json:"params"` + StateID string `json:"stateID"` + TraceID string `json:"traceID"` +} + +// StateResponse - +type StateResponse struct { + StateID string `json:"stateID"` + StatusCode int `json:"statusCode"` + ErrorMessage string `json:"errorMessage"` +} + +// GetStatusCode StateNewResponse get status code +func (i *StateResponse) GetStatusCode() int { + return i.StatusCode +} + +// GetErrorMessage StateNewResponse get error message +func (f *StateResponse) GetErrorMessage() string { + return f.ErrorMessage +} + +// TerminateResponse - +type TerminateResponse struct { + ObjectID string `json:"objectID"` + StatusCode int `json:"statusCode"` + ErrorMessage string `json:"errorMessage"` +} + +// GetStatusCode StateNewResponse get status code +func (i *TerminateResponse) GetStatusCode() int { + return i.StatusCode +} + +// GetErrorMessage StateNewResponse get error message +func (f *TerminateResponse) GetErrorMessage() string { + return f.ErrorMessage +} + +// ExitResponse - +type ExitResponse struct { + StatusCode int `json:"statusCode"` + ErrorMessage string `json:"errorMessage"` +} + +// GetStatusCode ExitResponse get status code +func (i *ExitResponse) GetStatusCode() int { + return i.StatusCode +} + +// GetErrorMessage ExitResponse get error message +func (f *ExitResponse) GetErrorMessage() string { + return f.ErrorMessage +} + +// Response interface +type Response interface { + GetStatusCode() int + GetErrorMessage() string +} + +// AuthConfig represents configurations of local auth +type AuthConfig struct { + AKey string `json:"aKey" yaml:"aKey" valid:"optional"` + SKey string `json:"sKey" yaml:"sKey" valid:"optional"` + Duration int `json:"duration" yaml:"duration" valid:"optional"` +} + +// CustomUserArgs - +type CustomUserArgs struct { + AlarmConfig AlarmConfig `json:"alarmConfig" valid:"optional"` + StsServerConfig StsServerConfig `json:"stsServerConfig"` + ClusterName string `json:"clusterName"` + DiskMonitorEnable bool `json:"diskMonitorEnable"` + LocalAuth AuthConfig `json:"localAuth"` +} + +// StsServerConfig - +type StsServerConfig struct { + Domain string `json:"domain,omitempty" validate:"max=255"` + Path string `json:"path,omitempty" validate:"max=255"` +} + +// AlarmConfig - +type AlarmConfig struct { + EnableAlarm bool `json:"enableAlarm"` + AlarmLogConfig config.CoreInfo `json:"alarmLogConfig" valid:"optional"` + XiangYunFourConfig XiangYunFourConfig `json:"xiangYunFourConfig" valid:"optional"` + MinInsStartInterval int `json:"minInsStartInterval"` + MinInsCheckInterval int `json:"minInsCheckInterval"` +} + +// XiangYunFourConfig - +type XiangYunFourConfig struct { + Site string `json:"site"` + TenantID string `json:"tenantID"` + ApplicationID string `json:"applicationID"` + ServiceID string `json:"serviceID"` +} + +// CredentialResponse - +type CredentialResponse struct { + StatusCode int `json:"statusCode"` + ErrorMessage string `json:"errorMessage"` + api.Credential +} + +// GetStatusCode CredentialResponse get status code +func (c *CredentialResponse) GetStatusCode() int { + return c.StatusCode +} + +// GetErrorMessage CredentialResponse get error message +func (c *CredentialResponse) GetErrorMessage() string { + return c.ErrorMessage +} diff --git a/api/go/faassdk/types/types_test.go b/api/go/faassdk/types/types_test.go new file mode 100644 index 0000000..82b10cc --- /dev/null +++ b/api/go/faassdk/types/types_test.go @@ -0,0 +1,94 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package types - +package types + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestInvokeResponse(t *testing.T) { + convey.Convey("Test InvokeResponse", t, func() { + i := &InvokeResponse{} + convey.Convey("GetStatusCode success", func() { + code := i.GetStatusCode() + convey.So(code, convey.ShouldBeZeroValue) + }) + convey.Convey("GetErrorMessage success", func() { + msg := i.GetErrorMessage() + convey.So(msg, convey.ShouldBeEmpty) + }) + }) +} + +func TestGetFutureResponse(t *testing.T) { + convey.Convey("Test GetFutureResponse", t, func() { + f := &GetFutureResponse{} + convey.Convey("GetStatusCode success", func() { + code := f.GetStatusCode() + convey.So(code, convey.ShouldBeZeroValue) + }) + convey.Convey("GetErrorMessage success", func() { + msg := f.GetErrorMessage() + convey.So(msg, convey.ShouldBeEmpty) + }) + }) +} + +func TestStateResponse(t *testing.T) { + convey.Convey("Test StateResponse", t, func() { + s := &StateResponse{} + convey.Convey("GetStatusCode success", func() { + code := s.GetStatusCode() + convey.So(code, convey.ShouldBeZeroValue) + }) + convey.Convey("GetErrorMessage success", func() { + msg := s.GetErrorMessage() + convey.So(msg, convey.ShouldBeEmpty) + }) + }) +} + +func TestTerminateResponse(t *testing.T) { + convey.Convey("Test TerminateResponse", t, func() { + t := &TerminateResponse{} + convey.Convey("GetStatusCode success", func() { + code := t.GetStatusCode() + convey.So(code, convey.ShouldBeZeroValue) + }) + convey.Convey("GetErrorMessage success", func() { + msg := t.GetErrorMessage() + convey.So(msg, convey.ShouldBeEmpty) + }) + }) +} + +func TestExitResponse(t *testing.T) { + convey.Convey("Test ExitResponse", t, func() { + e := &ExitResponse{} + convey.Convey("GetStatusCode success", func() { + code := e.GetStatusCode() + convey.So(code, convey.ShouldBeZeroValue) + }) + convey.Convey("GetErrorMessage success", func() { + msg := e.GetErrorMessage() + convey.So(msg, convey.ShouldBeEmpty) + }) + }) +} diff --git a/api/go/faassdk/utils/handle_response.go b/api/go/faassdk/utils/handle_response.go new file mode 100644 index 0000000..bdad456 --- /dev/null +++ b/api/go/faassdk/utils/handle_response.go @@ -0,0 +1,120 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils +package utils + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/faassdk/types" +) + +// ExecutionDuration records the execution time of Executor and user functions. +type ExecutionDuration struct { + ExecutorBeginTime time.Time + UserFuncBeginTime time.Time + UserFuncTotalTime time.Duration +} + +const responseHeaderSize = 2 + +// HandleInitResponse format init response processing error return. +func HandleInitResponse(body interface{}, statusCode int) ([]byte, error) { + var err error + initResponse := &types.InitResponse{} + + initResponse.ErrorCode = strconv.Itoa(statusCode) + if initResponse.Message, err = HandleResponseBody(body); err != nil { + errStr := fmt.Sprintf(`{"errorCode":"%d","message":"%s"}`, + constants.ExecutorErrCodeInitFail, err.Error()) + return []byte{}, errors.New(errStr) + } + + result, err := json.Marshal(initResponse) + if err != nil { + errStr := fmt.Sprintf(`{"errorCode":"%d","message":"json marsh response error:%s"}`, + constants.ExecutorErrCodeInitFail, err.Error()) + return []byte{}, errors.New(errStr) + } + return []byte{}, errors.New(string(result)) +} + +// HandleCallResponse format call response +func HandleCallResponse(body interface{}, statusCode int, logResult string, + totalTime ExecutionDuration, headers map[string][]string) ([]byte, error) { + callResponse := &types.CallResponse{} + var err error + callResponse.Headers = HandleCallResponseHeaders(headers) + if callResponse.Body, err = HandleResponseBody(body); err != nil { + return []byte{}, fmt.Errorf("handle response header error: %s", err) + } + callResponse.BillingDuration = "this is billing duration TODO" + callResponse.InnerCode = strconv.Itoa(statusCode) + callResponse.InvokerSummary = "this is summary TODO" + callResponse.LogResult = encodeBase64([]byte("this is user log TODO")) + if !totalTime.ExecutorBeginTime.IsZero() { + callResponse.ExecutorTime = float64(time.Since(totalTime.ExecutorBeginTime).Milliseconds()) + } + callResponse.UserFuncTime = float64(totalTime.UserFuncTotalTime.Milliseconds()) + if logResult != "" { + callResponse.LogResult = logResult + } + result, err := json.Marshal(callResponse) + if err != nil { + return []byte{}, fmt.Errorf("json marsh response error:%s", err) + } + + return result, nil +} + +// HandleResponseBody convert interface body to []byte or handle error message +func HandleResponseBody(body interface{}) ([]byte, error) { + if _, ok := body.([]byte); ok { + return body.([]byte), nil + } + if result, err := json.Marshal(body); err != nil { + return []byte{}, fmt.Errorf("json marshal response error:%s", err) + } else { + return result, nil + } +} + +// HandleCallResponseHeaders deal with response header +func HandleCallResponseHeaders(rawHeaders map[string][]string) map[string]string { + headers := make(map[string]string, responseHeaderSize) + headers["Content-Type"] = "application/json" + headers["X-Log-Type"] = "base64" + if rawHeaders != nil { + for k, v := range rawHeaders { + if len(v) > 0 { + headers[k] = v[0] + } + } + } + return headers +} + +func encodeBase64(data []byte) string { + base64Result := base64.StdEncoding.EncodeToString(data) + return base64Result +} diff --git a/api/go/faassdk/utils/handle_response_test.go b/api/go/faassdk/utils/handle_response_test.go new file mode 100644 index 0000000..1cad7c3 --- /dev/null +++ b/api/go/faassdk/utils/handle_response_test.go @@ -0,0 +1,71 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils +package utils + +import ( + "encoding/json" + "errors" + "strings" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" +) + +func TestHandleInitResponse(t *testing.T) { + convey.Convey("HandleInitResponse", t, func() { + convey.Convey("ExecutorErrCodeInitFail", func() { + defer gomonkey.ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, errors.New("json marshal error") + }).Reset() + resp, err := HandleInitResponse("hello world", 0) + convey.So(resp, convey.ShouldEqual, "") + convey.So(err, convey.ShouldBeError) + }) + }) +} + +func TestHandleCallResponse(t *testing.T) { + convey.Convey("HandleCallResponse", t, func() { + totalTime := ExecutionDuration{ + ExecutorBeginTime: time.Now(), + UserFuncBeginTime: time.Now(), + UserFuncTotalTime: 0, + } + response, err := HandleCallResponse(createLargeStr(), 200, "", totalTime, nil) + convey.So(err, convey.ShouldBeNil) + convey.So(len(response) > 6291456, convey.ShouldEqual, true) + }) +} + +func createLargeStr() string { + builder := strings.Builder{} + for i := 0; i < 1024*1024; i++ { + builder.WriteString("aaaaaa") + } + return builder.String() +} + +func TestHandleResponseBody(t *testing.T) { + convey.Convey("HandleResponseBody", t, func() { + body, err := HandleResponseBody([]byte{}) + convey.So(body, convey.ShouldEqual, []byte{}) + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/api/go/faassdk/utils/signer/sksigner.go b/api/go/faassdk/utils/signer/sksigner.go new file mode 100644 index 0000000..2138342 --- /dev/null +++ b/api/go/faassdk/utils/signer/sksigner.go @@ -0,0 +1,86 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package signer - +package signer + +import ( + "crypto/hmac" + "crypto/sha256" + "strings" +) + +const ( + // TimestampKey timestamp + TimestampKey = "timestamp=" + + // AccessIDKey accessId + AccessIDKey = "accessId=" + + // SignatureKey signature + SignatureKey = "signature=" + + // Algorithm Signature algorithm + Algorithm = "SDK-HMAC-SHA256" +) + +var hexCodes = []rune("0123456789abcdef") + +// Sign hmac sha256 signature +func Sign(sk, content []byte) []byte { + h := hmac.New(sha256.New, sk) + _, err := h.Write(content) + if err != nil { + return nil + } + return h.Sum(nil) +} + +// EncodeHex encode hex +func EncodeHex(data []byte) string { + if data == nil || len(data) == 0 { + return "" + } + l := len(data) + out := make([]rune, l<<1) + j := 0 + for i := 0; i < l; i++ { + if j >= l<<1 { + return "" + } + out[j] = hexCodes[(data[i]>>4)&0xF] // magic number + j++ + if j >= l<<1 { + return "" + } + out[j] = hexCodes[(data[i] & 0xF)] // magic number + j++ + } + return string(out) +} + +// BuildAuthorization Splicing Signature +func BuildAuthorization(ak, ts, sign string) string { + var builder strings.Builder + builder.WriteString(Algorithm) + builder.WriteString(" ") + builder.WriteString(AccessIDKey + ak) + builder.WriteString(",") + builder.WriteString(TimestampKey + ts) + builder.WriteString(",") + builder.WriteString(SignatureKey + sign) + return builder.String() +} diff --git a/api/go/faassdk/utils/signer/sksigner_test.go b/api/go/faassdk/utils/signer/sksigner_test.go new file mode 100644 index 0000000..4f9aac3 --- /dev/null +++ b/api/go/faassdk/utils/signer/sksigner_test.go @@ -0,0 +1,90 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package signer + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestSign(t *testing.T) { + convey.Convey("Sign", t, func() { + value := Sign([]byte("abcdefgh"), []byte("afdhkajhfkjasdhfkjasdhfkjasdfh")) + convey.So(value, convey.ShouldEqual, []byte{233, 43, 186, 45, 115, 41, 252, 5, 25, 20, 102, 213, 47, 18, 30, 161, 100, 248, 235, 212, 108, 19, 103, 58, 242, 244, 24, 54, 20, 176, 247, 230}) + }) +} + +func TestEncodeHex(t *testing.T) { + convey.Convey("EncodeHex", t, func() { + tests := []struct { + Input string + Expect string + }{{ + "afasdfadsfdasf", + "6166617364666164736664617366", + }, { + "1234567890-=", + "313233343536373839302d3d", + }, { + "1234567890-\n\t\r \t", + "313233343536373839302d0a090d20202009", + }, + } + for _, tt := range tests { + convey.So(EncodeHex([]byte(tt.Input)), convey.ShouldEqual, tt.Expect) + } + }) +} + +func TestBuildAuthorization(t *testing.T) { + convey.Convey("BuildAuthorization", t, func() { + tests := []struct { + Ak string + Ts string + Sign string + Expect string + }{ + { + Ak: "tenatnId", + Ts: "123456789", + Sign: "adfasfasdfadsfadf", + Expect: "SDK-HMAC-SHA256 accessId=tenatnId,timestamp=123456789,signature=adfasfasdfadsfadf", + }, { + Ak: "", + Ts: "123456789", + Sign: "qerqreqwrqewr", + Expect: "SDK-HMAC-SHA256 accessId=,timestamp=123456789,signature=qerqreqwrqewr", + }, { + Ak: "tenatnId", + Ts: "", + Sign: "zvvzcvzcvc", + Expect: "SDK-HMAC-SHA256 accessId=tenatnId,timestamp=,signature=zvvzcvzcvc", + }, + { + Ak: "tenantId", + Ts: "", + Sign: "", + Expect: "SDK-HMAC-SHA256 accessId=tenantId,timestamp=,signature=", + }, + } + + for _, tt := range tests { + convey.So(BuildAuthorization(tt.Ak, tt.Ts, tt.Sign), convey.ShouldEqual, tt.Expect) + } + }) +} diff --git a/api/go/faassdk/utils/urnutils/gadgets.go b/api/go/faassdk/utils/urnutils/gadgets.go new file mode 100644 index 0000000..56467f7 --- /dev/null +++ b/api/go/faassdk/utils/urnutils/gadgets.go @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package urnutils + +import ( + "strings" +) + +// Separator is the current system special +var Separator = "-" + +const ( + // ServiceIDPrefix is the prefix of the function with serviceID. + ServiceIDPrefix = "0" + + // DefaultSeparator is a character that separates functions and services. + DefaultSeparator = "-" + + // ServicePrefix is the prefix of the function with serviceID. + ServicePrefix = "0@" + + // TenantProductSplitStr separator between a tenant and a product + TenantProductSplitStr = "@" + + minEleSize = 3 +) + +// ComplexFuncName contains service ID and raw function name +type ComplexFuncName struct { + prefix string + ServiceID string + FuncName string +} + +// NewComplexFuncName - +func NewComplexFuncName(svcID, funcName string) *ComplexFuncName { + return &ComplexFuncName{ + prefix: ServiceIDPrefix, + ServiceID: svcID, + FuncName: funcName, + } +} + +// ParseFrom parse ComplexFuncName from string +func (c *ComplexFuncName) ParseFrom(name string) *ComplexFuncName { + fields := strings.Split(name, Separator) + if len(fields) < minEleSize || fields[0] != ServiceIDPrefix { + c.prefix = "" + c.ServiceID = "" + c.FuncName = name + return c + } + idx := 0 + c.prefix = fields[idx] + idx++ + c.ServiceID = fields[idx] + // $prefix$separator$ServiceID$separator$FuncName equals name + c.FuncName = name[(len(c.prefix) + len(Separator) + len(c.ServiceID) + len(Separator)):] + return c +} + +// String - +func (c *ComplexFuncName) String() string { + return strings.Join([]string{c.prefix, c.ServiceID, c.FuncName}, Separator) +} + +// GetSvcIDWithPrefix get serviceID with prefix from function name +func (c *ComplexFuncName) GetSvcIDWithPrefix() string { + return c.prefix + Separator + c.ServiceID +} + +// SetSeparator - +func SetSeparator(separator string) { + if separator != "" { + Separator = separator + } +} diff --git a/api/go/faassdk/utils/urnutils/gadgets_test.go b/api/go/faassdk/utils/urnutils/gadgets_test.go new file mode 100644 index 0000000..133cdd3 --- /dev/null +++ b/api/go/faassdk/utils/urnutils/gadgets_test.go @@ -0,0 +1,159 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package urnutils + +import ( + "reflect" + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestComplexFuncName_GetSvcIDWithPrefix(t *testing.T) { + tests := []struct { + name string + fields ComplexFuncName + want string + }{ + { + name: "normal", + fields: ComplexFuncName{ + prefix: ServiceIDPrefix, + ServiceID: "absserviceid", + FuncName: "absFuncName", + }, + want: "0-absserviceid", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ComplexFuncName{ + prefix: tt.fields.prefix, + ServiceID: tt.fields.ServiceID, + FuncName: tt.fields.FuncName, + } + if got := c.GetSvcIDWithPrefix(); got != tt.want { + t.Errorf("GetSvcIDWithPrefix() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestComplexFuncName_ParseFrom(t *testing.T) { + type args struct { + name string + } + tests := []struct { + name string + args args + want *ComplexFuncName + }{ + { + name: "normal", + args: args{ + name: "0-absserviceid-absFunc-Name", + }, + want: &ComplexFuncName{ + prefix: ServiceIDPrefix, + ServiceID: "absserviceid", + FuncName: "absFunc-Name", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ComplexFuncName{} + if got := c.ParseFrom(tt.args.name); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseFrom() = %v, want %v", got, tt.want) + } + if got := c.ParseFrom("name-0"); !reflect.DeepEqual(got.FuncName, "name-0") { + t.Errorf("ParseFrom() = %v, want %v", got.FuncName, "name-0") + } + }) + } +} + +func TestComplexFuncName_String(t *testing.T) { + tests := []struct { + name string + fields ComplexFuncName + want string + }{ + { + name: "normal", + fields: ComplexFuncName{ + prefix: ServiceIDPrefix, + ServiceID: "absserviceid", + FuncName: "absFunc-Name", + }, + want: "0-absserviceid-absFunc-Name", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ComplexFuncName{ + prefix: tt.fields.prefix, + ServiceID: tt.fields.ServiceID, + FuncName: tt.fields.FuncName, + } + if got := c.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewComplexFuncName(t *testing.T) { + type args struct { + svcID string + funcName string + } + tests := []struct { + name string + args args + want *ComplexFuncName + }{ + { + name: "normal", + args: args{ + svcID: "absserviceid", + funcName: "absFunc-Name", + }, + want: &ComplexFuncName{ + prefix: ServiceIDPrefix, + ServiceID: "absserviceid", + FuncName: "absFunc-Name", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewComplexFuncName(tt.args.svcID, tt.args.funcName); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewComplexFuncName() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSetSeparator(t *testing.T) { + convey.Convey("Test SetSeparator", t, func() { + convey.Convey("SetSeparator success", func() { + SetSeparator("-") + convey.So(Separator, convey.ShouldEqual, "-") + }) + }) +} diff --git a/api/go/faassdk/utils/urnutils/urn_utils.go b/api/go/faassdk/utils/urnutils/urn_utils.go new file mode 100644 index 0000000..cd4b9a5 --- /dev/null +++ b/api/go/faassdk/utils/urnutils/urn_utils.go @@ -0,0 +1,124 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package urnutils +package urnutils + +import ( + "fmt" + "strings" + + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +// An example of a function URN: :::::: +// Indices of elements in BaseURN +const ( + // ProductIDIndex is the index of the product ID in a URN + ProductIDIndex = iota + // RegionIDIndex is the index of the region ID in a URN + RegionIDIndex + // BusinessIDIndex is the index of the business ID in a URN + BusinessIDIndex + // TenantIDIndex is the index of the tenant ID in a URN + TenantIDIndex + // FunctionSignIndex is the index of the product ID in a URN + FunctionSignIndex + // FunctionNameIndex is the index of the product name in a URN + FunctionNameIndex + // VersionIndex is the index of the version in a URN + VersionIndex + // URNLenWithVersion is the normal URN length with a version + URNLenWithVersion +) + +const ( + urnLenWithoutVersion = URNLenWithVersion - 1 + // URNSep is a URN separator of functions + URNSep = ":" + // DefaultURNProductID is the default product ID of a URN + DefaultURNProductID = "sn" + // DefaultURNRegion is the default region of a URN + DefaultURNRegion = "cn" + // DefaultURNFuncSign is the default function sign of a URN + DefaultURNFuncSign = "function" +) + +var ( + // LocalFuncURN is URN of local function + LocalFuncURN = &BaseURN{} +) + +// BaseURN contains elements of a product URN. It can expand to FunctionURN, LayerURN and WorkerURN +type BaseURN struct { + ProductID string + RegionID string + BusinessID string + TenantID string + TypeSign string + Name string + Version string +} + +// String serializes elements of function URN struct to string +func (p *BaseURN) String() string { + urn := fmt.Sprintf("%s:%s:%s:%s:%s:%s", p.ProductID, p.RegionID, + p.BusinessID, p.TenantID, p.TypeSign, p.Name) + if p.Version != "" { + return fmt.Sprintf("%s:%s", urn, p.Version) + } + return urn +} + +// ParseFrom parses elements from a function URN +func (p *BaseURN) ParseFrom(urn string) error { + elements := strings.Split(urn, URNSep) + urnLen := len(elements) + if urnLen < urnLenWithoutVersion || urnLen > URNLenWithVersion { + return fmt.Errorf("failed to parse urn from %s, invalid length", urn) + } + p.ProductID = elements[ProductIDIndex] + p.RegionID = elements[RegionIDIndex] + p.BusinessID = elements[BusinessIDIndex] + p.TenantID = elements[TenantIDIndex] + p.TypeSign = elements[FunctionSignIndex] + p.Name = elements[FunctionNameIndex] + if urnLen == URNLenWithVersion { + p.Version = elements[VersionIndex] + } + return nil +} + +// StringWithoutVersion return string without version +func (p *BaseURN) StringWithoutVersion() string { + return fmt.Sprintf("%s:%s:%s:%s:%s:%s", p.ProductID, p.RegionID, + p.BusinessID, p.TenantID, p.TypeSign, p.Name) +} + +// GetFunctionInfo collects function information from a URN +func GetFunctionInfo(urn string) (BaseURN, error) { + var parsedURN BaseURN + if err := parsedURN.ParseFrom(urn); err != nil { + logger.GetLogger().Errorf("error while parsing an URN: %s", err.Error()) + return BaseURN{}, fmt.Errorf("parsing an URN error: %s", err) + } + return parsedURN, nil +} + +// CombineFunctionKey will generate funcKey from three IDs +func CombineFunctionKey(tenantID, funcName, version string) string { + return fmt.Sprintf("%s/%s/%s", tenantID, funcName, version) +} diff --git a/api/go/faassdk/utils/urnutils/urn_utils_test.go b/api/go/faassdk/utils/urnutils/urn_utils_test.go new file mode 100644 index 0000000..aff409a --- /dev/null +++ b/api/go/faassdk/utils/urnutils/urn_utils_test.go @@ -0,0 +1,145 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package urnutils + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestProductUrn_String(t *testing.T) { + tests := []struct { + name string + fields BaseURN + want string + }{ + { + "stringify with version", + BaseURN{ + "absPrefix", + "absZone", + "absBusinessID", + "absTenantID", + "absProductID", + "absName", + "latest", + }, + "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName:latest", + }, + { + "stringify without version", + BaseURN{ + ProductID: "absPrefix", + RegionID: "absZone", + BusinessID: "absBusinessID", + TenantID: "absTenantID", + TypeSign: "absProductID", + Name: "absName", + }, + "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &BaseURN{ + ProductID: tt.fields.ProductID, + RegionID: tt.fields.RegionID, + BusinessID: tt.fields.BusinessID, + TenantID: tt.fields.TenantID, + TypeSign: tt.fields.TypeSign, + Name: tt.fields.Name, + Version: tt.fields.Version, + } + if got := p.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestProductUrn_StringWithoutVersion(t *testing.T) { + tests := []struct { + name string + fields BaseURN + want string + }{ + { + "stringify without version", + BaseURN{ + "absPrefix", + "absZone", + "absBusinessID", + "absTenantID", + "absProductID", + "absName", + "latest", + }, + "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &BaseURN{ + ProductID: tt.fields.ProductID, + RegionID: tt.fields.RegionID, + BusinessID: tt.fields.BusinessID, + TenantID: tt.fields.TenantID, + TypeSign: tt.fields.TypeSign, + Name: tt.fields.Name, + Version: tt.fields.Version, + } + if got := p.StringWithoutVersion(); got != tt.want { + t.Errorf("StringWithoutVersion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetFunctionInfo(t *testing.T) { + convey.Convey("Test GetFunctionInfo", t, func() { + convey.Convey("GetFunctionInfo when parsedURN error", func() { + baseURN, err := GetFunctionInfo("urn") + convey.So(baseURN, convey.ShouldEqual, BaseURN{}) + convey.So(err, convey.ShouldNotBeNil) + }) + absURN := BaseURN{ + "absPrefix", + "absZone", + "absBusinessID", + "absTenantID", + "absProductID", + "absName", + "latest", + } + absURNStr := "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName:latest" + convey.Convey("GetFunctionInfo success", func() { + baseURN, err := GetFunctionInfo(absURNStr) + convey.So(baseURN, convey.ShouldEqual, absURN) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestCombineFunctionKey(t *testing.T) { + convey.Convey("Test CombineFunctionKey", t, func() { + convey.Convey("CombineFunctionKey success", func() { + str := CombineFunctionKey("tenantID1", "funcName1", "version1") + convey.So(str, convey.ShouldEqual, "tenantID1/funcName1/version1") + }) + }) +} diff --git a/api/go/faassdk/utils/utils.go b/api/go/faassdk/utils/utils.go new file mode 100644 index 0000000..a093759 --- /dev/null +++ b/api/go/faassdk/utils/utils.go @@ -0,0 +1,149 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils +package utils + +import ( + "encoding/json" + "fmt" + "net" + "os" + "path/filepath" + "strings" + "sync" + "unsafe" + + "yuanrong.org/kernel/runtime/libruntime/common/uuid" + + "yuanrong.org/kernel/runtime/faassdk/common/constants" + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" +) + +const ( + minPathPartLength = 2 +) + +var ( + once sync.Once + serverIP = "" +) + +// UniqueID get unique ID +func UniqueID() string { + return uuid.New().String() +} + +// BytesToString convert []byte to string without memory alloc +func BytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// GetServerIP - +func GetServerIP() (string, error) { + var err error + once.Do(func() { + addr, errMsg := GetHostAddr() + if errMsg != nil { + err = errMsg + return + } + serverIP = addr[0] + }) + return serverIP, err +} + +// GetHostAddr - +func GetHostAddr() ([]string, error) { + name, err := os.Hostname() + if err != nil { + logger.GetLogger().Errorf("get hostname failed: %v", err) + return nil, err + } + + addrs, err := net.LookupHost(name) + if err != nil || len(addrs) == 0 { + logger.GetLogger().Errorf("look up host by name failed") + return nil, fmt.Errorf("look up host by name failed") + } + return addrs, nil +} + +// GetLibInfo will parse handler info, example: +// libPath="/tmp" libName="example.init" --> fileName="/tmp/example.so" handlerName="init" +// libPath="/tmp" libName="test.example.init" --> fileName="/tmp/test/example.so" handlerName="init" +func GetLibInfo(libPath, libName string) (string, string) { + path := libPath + parts := strings.Split(libName, ".") + length := len(parts) + handlerName := parts[length-1] + if length < minPathPartLength { + return "", "" + } else if length > minPathPartLength { + tmpPath := filepath.Join(parts[:length-minPathPartLength]...) + path = filepath.Join(path, tmpPath) + } + fileName := filepath.Join(path, parts[length-minPathPartLength]+".so") + return fileName, handlerName +} + +// DealEnv deal with environment and encrypted_user_data +func DealEnv() (error, map[string]string) { + var err error + delegateDecryptMap := make(map[string]string) + environmentMap := make(map[string]string) + rtUserDataMap := make(map[string]string) + envMap := make(map[string]string) + delegateDecrypt := os.Getenv("ENV_DELEGATE_DECRYPT") + if delegateDecrypt == "" { + return nil, nil + } + if err = json.Unmarshal([]byte(delegateDecrypt), &delegateDecryptMap); err != nil { + return err, nil + } + environment, ok := delegateDecryptMap["environment"] + if ok && environment != "" { + if err = json.Unmarshal([]byte(environment), &environmentMap); err != nil { + return err, nil + } + } + for key, value := range environmentMap { + if key != constants.LDLibraryPath { + err = os.Setenv(key, value) + } + envMap[key] = value + } + if err != nil { + return err, nil + } + rtUserData, ok := delegateDecryptMap["encrypted_user_data"] + if ok && rtUserData != "" { + if err = json.Unmarshal([]byte(rtUserData), &rtUserDataMap); err != nil { + return err, nil + } + } + + for key, value := range rtUserDataMap { + envMap[key] = value + } + return err, envMap +} + +// ContainsConnRefusedErr - +func ContainsConnRefusedErr(err error) bool { + const connRefusedStr = "connection refused" + return strings.Contains(err.Error(), connRefusedStr) +} diff --git a/api/go/faassdk/utils/utils_test.go b/api/go/faassdk/utils/utils_test.go new file mode 100644 index 0000000..444936b --- /dev/null +++ b/api/go/faassdk/utils/utils_test.go @@ -0,0 +1,117 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils +package utils + +import ( + "errors" + "os" + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestUniqueID(t *testing.T) { + convey.Convey("Test UniqueID", t, func() { + convey.Convey("UniqueID success", func() { + convey.So(UniqueID(), convey.ShouldNotBeEmpty) + }) + }) +} + +func TestBytesToString(t *testing.T) { + convey.Convey("Test BytesToString", t, func() { + convey.Convey("BytesToString success", func() { + convey.So(BytesToString([]byte("str")), convey.ShouldEqual, "str") + }) + }) +} + +func TestGetServerIP(t *testing.T) { + convey.Convey("Test GetServerIP", t, func() { + convey.Convey("GetServerIP success", func() { + ip, err := GetServerIP() + convey.So(ip, convey.ShouldNotBeEmpty) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestGetLibInfo(t *testing.T) { + convey.Convey("Test GetLibInfo", t, func() { + convey.Convey("GetLibInfo when length < minPathPartLength", func() { + fileName, handlerName := GetLibInfo("", "") + convey.So(fileName, convey.ShouldBeEmpty) + convey.So(handlerName, convey.ShouldBeEmpty) + }) + convey.Convey("GetLibInfo success", func() { + fileName, handlerName := GetLibInfo("/tmp", "example.init") + convey.So(fileName, convey.ShouldEqual, "/tmp/example.so") + convey.So(handlerName, convey.ShouldEqual, "init") + }) + convey.Convey("GetLibInfo when length > minPathPartLength", func() { + fileName, handlerName := GetLibInfo("/tmp", "test.example.init") + convey.So(fileName, convey.ShouldEqual, "/tmp/test/example.so") + convey.So(handlerName, convey.ShouldEqual, "init") + }) + }) +} + +func TestDealEnv(t *testing.T) { + convey.Convey("Test DealEnv", t, func() { + convey.Convey("DealEnv when delegateDecrypt == \"\"", func() { + err, envMap := DealEnv() + convey.So(err, convey.ShouldBeNil) + convey.So(len(envMap), convey.ShouldBeZeroValue) + }) + os.Setenv("ENV_DELEGATE_DECRYPT", "delegateDecrypt1") + convey.Convey("DealEnv when json.Unmarshal delegateDecrypt error", func() { + err, envMap := DealEnv() + convey.So(err, convey.ShouldNotBeNil) + convey.So(envMap, convey.ShouldBeNil) + }) + os.Setenv("ENV_DELEGATE_DECRYPT", `{"environment":"env1"}`) + convey.Convey("DealEnv when json.Unmarshal environment error", func() { + err, envMap := DealEnv() + convey.So(err, convey.ShouldNotBeNil) + convey.So(envMap, convey.ShouldBeNil) + }) + str := `{"environment":"{\"env1\":\"value1\"}","encrypted_user_data":"data1"}` + os.Setenv("ENV_DELEGATE_DECRYPT", str) + convey.Convey("DealEnv when json.Unmarshal encrypted_user_data error", func() { + err, envMap := DealEnv() + convey.So(err, convey.ShouldNotBeNil) + convey.So(envMap, convey.ShouldBeNil) + }) + str = `{"environment":"{\"env1\":\"value1\"}","encrypted_user_data":"{\"env2\":\"value2\"}"}` + os.Setenv("ENV_DELEGATE_DECRYPT", str) + convey.Convey("DealEnv success", func() { + err, envMap := DealEnv() + convey.So(err, convey.ShouldBeNil) + convey.So(envMap, convey.ShouldNotBeNil) + }) + }) +} + +func TestContainsConnRefusedErr(t *testing.T) { + convey.Convey("Test ContainsConnRefusedErr", t, func() { + convey.Convey("ContainsConnRefusedErr success", func() { + flag := ContainsConnRefusedErr(errors.New("connection refused")) + convey.So(flag, convey.ShouldBeTrue) + }) + }) +} diff --git a/api/go/go.mod b/api/go/go.mod new file mode 100644 index 0000000..50084fb --- /dev/null +++ b/api/go/go.mod @@ -0,0 +1,57 @@ +module yuanrong.org/kernel/runtime + +go 1.24.1 + +require ( + github.com/agiledragon/gomonkey/v2 v2.11.0 + github.com/asaskevich/govalidator/v11 v11.0.1-0.20250122183457-e11347878e23 + github.com/fsnotify/fsnotify v1.7.0 + github.com/magiconair/properties v1.8.7 + github.com/panjf2000/ants/v2 v2.10.0 + github.com/smartystreets/goconvey v1.8.1 + github.com/stretchr/testify v1.8.3 + github.com/valyala/fasthttp v1.58.0 + github.com/vmihailenco/msgpack v4.0.4+incompatible + github.com/vmihailenco/msgpack/v5 v5.4.1 + go.uber.org/zap v1.27.0 + golang.org/x/crypto v0.24.0 + huawei.com/wisesecurity/sts-sdk v1.0.1-20250319171100-c6b279f3bac +) + +require ( + github.com/andybalholm/brotli v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-jwt/jwt/v4 v4.4.3 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/klauspost/compress v1.17.9 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/sirupsen/logrus v1.9.0 // indirect + github.com/smarty/assertions v1.15.0 // indirect + github.com/stretchr/objx v0.1.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + golang.org/x/sync v0.3.0 // indirect + golang.org/x/sys v0.21.0 // indirect + google.golang.org/appengine v1.6.8 // indirect + google.golang.org/protobuf v1.36.6 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace ( + github.com/asaskevich/govalidator/v11 => github.com/asaskevich/govalidator/v11 v11.0.1-0.20250122183457-e11347878e23 + github.com/fsnotify/fsnotify v1.4.9 => github.com/fsnotify/fsnotify v1.7.0 + github.com/golang/mock => github.com/golang/mock v1.3.1 + github.com/stretchr/testify => github.com/stretchr/testify v1.7.1 + github.com/valyala/fasthttp => github.com/valyala/fasthttp v1.58.0 + go.uber.org/zap => go.uber.org/zap v1.27.0 + golang.org/x/crypto => golang.org/x/crypto v0.24.0 + golang.org/x/net => golang.org/x/net v0.26.0 + golang.org/x/sync => golang.org/x/sync v0.0.0-20190423024810-112230192c58 + golang.org/x/sys => golang.org/x/sys v0.21.0 + google.golang.org/protobuf => google.golang.org/protobuf v1.36.6 +) diff --git a/api/go/libruntime/api/api.go b/api/go/libruntime/api/api.go new file mode 100644 index 0000000..2a00255 --- /dev/null +++ b/api/go/libruntime/api/api.go @@ -0,0 +1,100 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package api for libruntime +package api + +// LibruntimeAPI defines libruntime's interfaces, supporting function-system data-system and bridge-system operations +type LibruntimeAPI interface { + CreateInstance(funcMeta FunctionMeta, args []Arg, invokeOpt InvokeOptions) (string, error) + InvokeByInstanceId( + funcMeta FunctionMeta, instanceID string, args []Arg, invokeOpt InvokeOptions, + ) (returnObjectID string, err error) + InvokeByFunctionName(funcMeta FunctionMeta, args []Arg, invokeOpt InvokeOptions) (string, error) + AcquireInstance(state string, funcMeta FunctionMeta, acquireOpt InvokeOptions) (InstanceAllocation, error) + ReleaseInstance(allocation InstanceAllocation, stateID string, abnormal bool, option InvokeOptions) + Kill(instanceID string, signal int, payload []byte) error + + CreateInstanceRaw(createReqRaw []byte) ([]byte, error) + InvokeByInstanceIdRaw(invokeReqRaw []byte) ([]byte, error) + KillRaw(killReqRaw []byte) ([]byte, error) + + SaveState(state []byte) (string, error) + LoadState(checkpointID string) ([]byte, error) + + Exit(code int, message string) + Finalize() + + KVSet(key string, value []byte, param SetParam) error + KVSetWithoutKey(value []byte, param SetParam) (string, error) + KVMSetTx(keys []string, values [][]byte, param MSetParam) error + KVGet(key string, timeoutms uint) ([]byte, error) + KVGetMulti(keys []string, timeoutms uint) ([][]byte, error) + KVDel(key string) error + KVDelMulti(keys []string) ([]string, error) + + CreateProducer(streamName string, producerConf ProducerConf) (StreamProducer, error) + Subscribe(streamName string, config SubscriptionConfig) (StreamConsumer, error) + DeleteStream(streamName string) error + QueryGlobalProducersNum(streamName string) (uint64, error) + QueryGlobalConsumersNum(streamName string) (uint64, error) + + SetTraceID(traceID string) + SetTenantID(tenantID string) error + + Put(objectID string, value []byte, param PutParam, nestedObjectIDs ...string) error + PutRaw(objectID string, value []byte, param PutParam, nestedObjectIDs ...string) error + Get(objectIDs []string, timeoutMs int) ([][]byte, error) + GetRaw(objectIDs []string, timeoutMs int) ([][]byte, error) + Wait(objectIDs []string, waitNum uint64, timeoutMs int) ([]string, []string, map[string]error) + GIncreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) + GIncreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) + GDecreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) + GDecreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) + GetAsync(objectID string, cb GetAsyncCallback) + + GetFormatLogger() FormatLogger + + CreateClient(config ConnectArguments) (KvClient, error) + ReleaseGRefs(remoteClientID string) error + GetCredential() Credential + UpdateSchdulerInfo(schedulerName string, schedulerId string, option string) + IsHealth() bool + IsDsHealth() bool +} + +// KvClient - +type KvClient interface { + KVSet(key string, value []byte, param SetParam) ErrorInfo + + KVSetWithoutKey(value []byte, param SetParam) (string, ErrorInfo) + + KVGet(key string, timeoutMs ...uint32) ([]byte, ErrorInfo) + + KVGetMulti(keys []string, timeoutMs ...uint32) ([][]byte, ErrorInfo) + + KVQuerySize(keys []string) ([]uint64, ErrorInfo) + + KVDel(key string) ErrorInfo + + KVDelMulti(keys []string) ([]string, ErrorInfo) + + GenerateKey() string + + SetTraceID(traceID string) + + DestroyClient() +} diff --git a/api/go/libruntime/api/types.go b/api/go/libruntime/api/types.go new file mode 100644 index 0000000..17e4c31 --- /dev/null +++ b/api/go/libruntime/api/types.go @@ -0,0 +1,480 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package api for libruntime +package api + +import ( + "fmt" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// ProducerConf represents configuration information for a producer. +type ProducerConf struct { + DelayFlushTime int64 // Trigger Flush after the maximum Delay duration after Send. + PageSize int64 // Size of the buffer page corresponding to the producer. + MaxStreamSize uint64 // Maximum size of the shared memory that can be used by a stream on a worker. + TraceId string // trace id of producer config +} + +// SubscriptionType subscribe type +type SubscriptionType int + +const ( + // Stream Single consumer consumption + Stream SubscriptionType = iota + // RoundRobin Multi-consumer load-balanced consumption round robin + RoundRobin + // KeyPartitions Multi-consumer load-balanced consumption by key partition + KeyPartitions +) + +// SubscriptionConfig represents configuration information for a consumer. +type SubscriptionConfig struct { + SubscriptionName string + SubscriptionType SubscriptionType + TraceId string +} + +// Element data to be sent +type Element struct { + Ptr *uint8 + Size uint64 + Id uint64 +} + +// StreamProducer producer interface +type StreamProducer interface { + Send(element Element) error + SendWithTimeout(element Element, timeoutMs int64) error + Flush() error + Close() error +} + +// StreamConsumer consumer interface +type StreamConsumer interface { + ReceiveExpectNum(expectNum uint32, timeoutMs uint32) ([]Element, error) + Receive(timeoutMs uint32) ([]Element, error) + Ack(elementId uint64) error + Close() error +} + +// ApiType api type, for example, actor +type ApiType int32 + +const ( + // ActorApi actor api + ActorApi ApiType = 0 + // FaaSApi faas api + FaaSApi ApiType = 1 + // PosixApi posix api + PosixApi ApiType = 2 + // ServeApi posix api + ServeApi ApiType = 3 +) + +// LanguageType language type, for example, cpp +type LanguageType int32 + +const ( + // Cpp cpp language + Cpp LanguageType = iota + // Python python language + Python + // Java java language + Java + // Golang go language + Golang + // NodeJS node js language + NodeJS + // CSharp c# language + CSharp + // Php php language + Php +) + +// FunctionMeta function meta +type FunctionMeta struct { + AppName string + FuncName string + FuncID string + Sig string + PoolLabel string + Name *string + Namespace *string + Api ApiType + Language LanguageType +} + +// StackTracesInfo stack info +type StackTracesInfo struct { + Code int + MCode int + Message string + StackTraces []StackTrace +} + +// StackTrace stack trace +type StackTrace struct { + ClassName string + MethodName string + FileName string + LineNumber int64 + ExtensionInfo map[string]string +} + +// ArgType value or objectRef +type ArgType int32 + +const ( + // Value value of object + Value ArgType = 0 + // ObjectRef ref of object + ObjectRef ArgType = 1 +) + +// Arg function arg +type Arg struct { + Type ArgType + Data []byte + ObjectID string + TenantID string + NestedObjectIDs []string +} + +// OperatorType operator type +type OperatorType int32 + +const ( + // LabelOpIn in + LabelOpIn OperatorType = 0 + // LabelOpNotIn not in + LabelOpNotIn OperatorType = 1 + // LabelOpExists exists + LabelOpExists OperatorType = 2 + // LabelOpNotExists not exists + LabelOpNotExists OperatorType = 3 +) + +// LabelOperator affinity label operator +type LabelOperator struct { + Type OperatorType + LabelKey string + LabelValues []string +} + +// AffinityKindType affinity type +type AffinityKindType int32 + +const ( + // AffinityKindResource resource + AffinityKindResource AffinityKindType = 0 + // AffinityKindInstance instance + AffinityKindInstance AffinityKindType = 1 +) + +// AffinityType affinity type +type AffinityType int32 + +const ( + // PreferredAffinity prefer + PreferredAffinity AffinityType = 0 + // PreferredAntiAffinity prefer anti + PreferredAntiAffinity AffinityType = 1 + // RequiredAffinity required + RequiredAffinity AffinityType = 2 + // RequiredAntiAffinity required anti + RequiredAntiAffinity AffinityType = 3 +) + +// Affinity - +type Affinity struct { + Kind AffinityKindType + Affinity AffinityType + PreferredPriority bool + PreferredAntiOtherLabels bool + LabelOps []LabelOperator +} + +// InstanceSessionConfig contains session config for instance +type InstanceSessionConfig struct { + SessionID string `json:"sessionID"` + SessionTTL int `json:"sessionTTL"` + Concurrency int `json:"concurrency"` +} + +// InvokeOptions invoke option +type InvokeOptions struct { + Cpu int + Memory int + InvokeLabels map[string]string + CustomResources map[string]float64 + CustomExtensions map[string]string + CreateOpt map[string]string + InstanceSession *InstanceSessionConfig + Labels []string + Affinities map[string]string // deprecated + ScheduleAffinities []Affinity + ScheduleTimeoutMs int64 + RetryTimes int + RecoverRetryTimes int + Priority int + CodePaths []string + SchedulerFunctionID string + SchedulerInstanceIDs []string + TraceID string + Timeout int + AcquireTimeout int + TrafficLimited bool +} + +// InstanceAllocation holds the allocation of instance +type InstanceAllocation struct { + FuncKey string + FuncSig string + InstanceID string + LeaseID string + LeaseInterval int64 +} + +// WriteModeEnum kv write mode +type WriteModeEnum int + +// constants of WriteModeEnum +const ( + NoneL2Cache WriteModeEnum = iota + WriteThroughL2Cache + WriteBackL2Cache + NoneL2CacheEvict +) + +// CacheTypeEnum kv cache type +type CacheTypeEnum int + +// constants of CacheTypeEnum +const ( + MEMORY CacheTypeEnum = iota + DISK +) + +// SetParam structure is used to transfer parameters during the set operation. +type SetParam struct { + WriteMode WriteModeEnum + TTLSecond uint32 + Existence int32 + CacheType CacheTypeEnum +} + +// MSetParam structure is used to transfer parameters during the mSetTx operation. +type MSetParam struct { + WriteMode WriteModeEnum + TTLSecond uint32 + Existence int32 + CacheType CacheTypeEnum +} + +// ConsistencyTypeEnum is the new type defined. +// Use const Param and Causal together to simulate the enumeration type of C++. +type ConsistencyTypeEnum int + +// The constant Pram and Causal is of the ConsistencyTypeEnum type. +// They is used as two enumeration constants of ConsistencyTypeEnum to simulates the enumeration type of C++. +const ( + PRAM ConsistencyTypeEnum = iota + CAUSAL +) + +// PutParam structure is used to transfer parameters during the Put operation. +// It takes parameters including WriteMode and ConsistencyType. +// The WriteMode is a WriteModeEnum "enumeration" type, +// The ConsistencyType is a ConsistencyTypeEnum "enumeration" type, +type PutParam struct { + WriteMode WriteModeEnum + ConsistencyType ConsistencyTypeEnum + CacheType CacheTypeEnum +} + +// HealthType - +type HealthType int32 + +// HealthCheckRequest - +type HealthCheckRequest struct { +} + +// HealthCheckResponse - +type HealthCheckResponse struct { + Code HealthType +} + +// constants of HealthType +const ( + Healthy HealthType = iota + HealthCheckFailed + SubHealth +) + +// Credential - +type Credential struct { + AccessKey string + SecretKey []byte + DataKey []byte +} + +// ErrorInfo error info +type ErrorInfo struct { + Code int + Err error + StackTracesInfo StackTracesInfo +} + +// IsOk - +func (e ErrorInfo) IsOk() bool { + return e.Code == Ok +} + +// IsError - +func (e ErrorInfo) IsError() bool { + return !e.IsOk() +} + +func (e ErrorInfo) Error() string { + if len(e.StackTracesInfo.Message) > 0 { + errMsg := e.StackTracesInfo.Message + stackTraces := e.getStackTracesInfo() + return fmt.Sprintf("%s\n%s", errMsg, stackTraces) + } + return e.Err.Error() +} + +// AddStack add stack to ErrorInfo +func AddStack(err error, stack StackTrace) error { + // if stack is empty, it is not appended to the stack. + if stack.ClassName == "" { + return err + } + errInfo := TurnErrInfo(err) + // if errInfo.StackTracesInfo.StackTraces is not empty, this stack has been collected + if len(errInfo.StackTracesInfo.StackTraces) == 0 { + errInfo.StackTracesInfo.StackTraces = []StackTrace{stack} + } + return errInfo +} + +// NewErrorInfoWithStackInfo add StackInfo to ErrorInfo +func NewErrorInfoWithStackInfo(err error, stacks []StackTrace) error { + if len(stacks) == 0 { + return err + } + errInfo := TurnErrInfo(err) + errInfo.StackTracesInfo.StackTraces = stacks + return errInfo +} + +// TurnErrInfo turn err To ErrorInfo +func TurnErrInfo(err error) ErrorInfo { + errInfo, ok := err.(ErrorInfo) + if ok { + return errInfo + } + errInfo = ErrorInfo{ + Err: err, + StackTracesInfo: StackTracesInfo{ + Message: err.Error(), + }, + } + return errInfo +} + +// NewErrInfo new ErrorInfo +func NewErrInfo(code int, message string, stackTracesInfo StackTracesInfo) ErrorInfo { + return ErrorInfo{ + Code: code, + Err: fmt.Errorf(message), + StackTracesInfo: stackTracesInfo, + } +} + +func (e ErrorInfo) getStackTracesInfo() string { + var info string + for _, v := range e.StackTracesInfo.StackTraces { + var funcInfo, fileInfo, parameters, offset string + if v.ExtensionInfo != nil { + parameters = v.ExtensionInfo["parameters"] + offset = v.ExtensionInfo["offset"] + } + funcInfo = fmt.Sprintf("%s.%s%s\n", v.ClassName, v.MethodName, parameters) + if v.LineNumber == 0 { + fileInfo = fmt.Sprintf(" %s\n", v.FileName) + } else { + fileInfo = fmt.Sprintf(" %s:%d %s\n", v.FileName, v.LineNumber, offset) + } + info = info + funcInfo + fileInfo + } + return info +} + +// GetAsyncCallback define the get async callback function type. +type GetAsyncCallback func(result []byte, err error) + +// FormatLogger format logger interface +type FormatLogger interface { + With(fields ...zapcore.Field) FormatLogger + + Infof(format string, paras ...interface{}) + Errorf(format string, paras ...interface{}) + Warnf(format string, paras ...interface{}) + Debugf(format string, paras ...interface{}) + Fatalf(format string, paras ...interface{}) + + Info(msg string, fields ...zap.Field) + Error(msg string, fields ...zap.Field) + Warn(msg string, fields ...zap.Field) + Debug(msg string, fields ...zap.Field) + Fatal(msg string, fields ...zap.Field) + + Sync() +} + +const ( + // Ok indicates the operation was successful. + Ok = 0 + + // InvalidParam indicates that an invalid parameter was provided. + InvalidParam = 2 + + // DsClientNilError indicates that dsclient is destructed. + DsClientNilError = 11001 +) + +// ConnectArguments - +type ConnectArguments struct { + Host string + Port int + TimeoutMs int + Token []byte + ClientPublicKey string + ClientPrivateKey []byte + ServerPublicKey string + AccessKey string + SecretKey []byte + AuthclientID string + AuthclientSecret []byte + AuthURL string + TenantID string + EnableCrossNodeConnection bool +} diff --git a/api/go/libruntime/api/types_test.go b/api/go/libruntime/api/types_test.go new file mode 100644 index 0000000..d02ba55 --- /dev/null +++ b/api/go/libruntime/api/types_test.go @@ -0,0 +1,158 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package api for libruntime +package api + +import ( + "errors" + "fmt" + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestNewErrInfo(t *testing.T) { + info := StackTracesInfo{} + err := NewErrInfo(0, "", info) + if err.Err == nil { + t.Errorf("test NewErrInfo failed") + } +} + +func TestNewErrorInfoWithStackInfo(t *testing.T) { + stack := StackTrace{ + ClassName: "main", + MethodName: "(*Manager.AddValue)", + FileName: "a.go", + LineNumber: 12, + ExtensionInfo: nil, + } + err := errors.New("error info") + err = NewErrorInfoWithStackInfo(err, []StackTrace{stack}) + msg := fmt.Sprintf("%s\n%s", "error info", TurnErrInfo(err).getStackTracesInfo()) + if err.Error() != msg { + t.Errorf("failed to new err to error with stack, %s %s", err.Error(), msg) + } + + convey.Convey( + "Test NewErrorInfoWithStackInfo", t, func() { + convey.Convey( + "NewErrorInfoWithStackInfo success when len(stacks)==0", func() { + e := NewErrorInfoWithStackInfo(err, []StackTrace{}) + convey.So(e, convey.ShouldEqual, err) + }, + ) + }, + ) +} + +func TestGetStackTracesInfo(t *testing.T) { + className := "main" + methodName := "(*Manager.AddValue)" + fileName := "a.go" + linNumber := 12 + stack := StackTrace{ + ClassName: className, + MethodName: methodName, + FileName: "a.go", + LineNumber: 12, + ExtensionInfo: nil, + } + err := errors.New("error info") + errInfo := ErrorInfo{ + Code: 2002, + Err: err, + StackTracesInfo: StackTracesInfo{ + Code: 2002, + MCode: 0, + Message: err.Error(), + StackTraces: []StackTrace{stack}, + }, + } + msg := errInfo.getStackTracesInfo() + funcInfo := fmt.Sprintf("%s.%s%s\n", className, methodName, "") + fileInfo := fmt.Sprintf(" %s:%d %s\n", fileName, linNumber, "") + if msg != funcInfo+fileInfo { + t.Errorf("error message %s shoud be %s", msg, funcInfo+fileInfo) + } + + convey.Convey( + "Test getStackTracesInfo", t, func() { + convey.Convey( + "getStackTracesInfo success", func() { + st := StackTrace{LineNumber: 0} + st.ExtensionInfo = make(map[string]string) + var e ErrorInfo + e.StackTracesInfo.StackTraces = []StackTrace{st} + str := e.getStackTracesInfo() + convey.So(str, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestIsError(t *testing.T) { + convey.Convey( + "Test IsError", t, func() { + convey.Convey( + "IsError success", func() { + var e ErrorInfo + flag := e.IsError() + convey.So(flag, convey.ShouldBeFalse) + }, + ) + }, + ) +} + +func TestError(t *testing.T) { + convey.Convey( + "Test Error", t, func() { + convey.Convey( + "Error success", func() { + var e ErrorInfo + e.Err = errors.New("error info") + str := e.Error() + convey.So(str, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestAddStack(t *testing.T) { + convey.Convey( + "Test AddStack", t, func() { + var stack StackTrace + err := errors.New("error info") + convey.Convey( + "AddStack success when stack is empty", func() { + e := AddStack(err, stack) + convey.So(e, convey.ShouldEqual, err) + }, + ) + convey.Convey( + "AddStack success", func() { + stack.ClassName = "stackClassName" + e := AddStack(err, stack) + convey.So(e, convey.ShouldNotBeNil) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/clibruntime/clibruntime.go b/api/go/libruntime/clibruntime/clibruntime.go new file mode 100644 index 0000000..29d5414 --- /dev/null +++ b/api/go/libruntime/clibruntime/clibruntime.go @@ -0,0 +1,2539 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +This package clibruntime encapsulates all cgo invocations. +*/ +package clibruntime + +/* +#cgo CFLAGS: -I../cpplibruntime +#cgo LDFLAGS: -L../../../../build/output/runtime/service/go/bin -lcpplibruntime +#include +#include +#include "clibruntime.h" +*/ +import "C" +import ( + "errors" + "fmt" + "reflect" + "sync" + "time" + "unsafe" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/uuid" + "yuanrong.org/kernel/runtime/libruntime/config" +) + +// KvClientImpl - +type KvClientImpl struct { + stateStore *StateStore +} + +// StateStore - +type StateStore struct { + mutex *sync.RWMutex + stateStorePtr C.CStateStorePtr +} + +func kvClientCheckNil(ptr *KvClientImpl) api.ErrorInfo { + if ptr == nil { + return api.ErrorInfo{Code: api.DsClientNilError, Err: errors.New("client is nil")} + } + if ptr.stateStore == nil { + return api.ErrorInfo{Code: api.DsClientNilError, Err: errors.New("stateStore is nil")} + } + return api.ErrorInfo{Code: api.Ok, Err: nil} +} + +func stateStorePtrCheckNil(ptr *StateStore) api.ErrorInfo { + if ptr.stateStorePtr == nil { + return api.ErrorInfo{Code: api.DsClientNilError, Err: errors.New("cStateStorePtr is nil")} + } + return api.ErrorInfo{Code: api.Ok, Err: nil} +} + +// CreateClient - +func CreateClient(config api.ConnectArguments) (api.KvClient, error) { + arguments := C.CConnectArguments{ + host: CSafeString(config.Host), + port: C.int(config.Port), + timeoutMs: C.int(config.TimeoutMs), + token: CSafeBytes(config.Token), + tokenLen: C.int(len(config.Token)), + clientPublicKey: CSafeString(config.ClientPublicKey), + clientPublicKeyLen: C.int(len(config.ClientPublicKey)), + clientPrivateKey: CSafeBytes(config.ClientPrivateKey), + clientPrivateKeyLen: C.int(len(config.ClientPrivateKey)), + serverPublicKey: CSafeString(config.ServerPublicKey), + serverPublicKeyLen: C.int(len(config.ServerPublicKey)), + accessKey: CSafeString(config.AccessKey), + accessKeyLen: C.int(len(config.AccessKey)), + secretKey: CSafeBytes(config.SecretKey), + secretKeyLen: C.int(len(config.SecretKey)), + authClientID: CSafeString(config.AuthclientID), + authClientIDLen: C.int(len(config.AuthclientID)), + authClientSecret: CSafeBytes(config.AuthclientSecret), + authClientSecretLen: C.int(len(config.AuthclientSecret)), + authUrl: CSafeString(config.AuthURL), + authUrlLen: C.int(len(config.AuthURL)), + tenantID: CSafeString(config.TenantID), + tenantIDLen: C.int(len(config.TenantID)), + enableCrossNodeConnection: C.char(btoi(config.EnableCrossNodeConnection)), + } + defer freeCArguments(&arguments) + var s StateStore + cErr := C.CCreateStateStore(&arguments, &s.stateStorePtr) + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroErr(code, cErr, "create kv client: ") + } + s.mutex = new(sync.RWMutex) + return &KvClientImpl{stateStore: &s}, nil +} + +func messageFree(cErr C.CErrorInfo) string { + msg := CSafeGoString(cErr.message) + CSafeFree(cErr.message) + return msg +} + +func codeNotZeroErr(code int, cErr C.CErrorInfo, str string) error { + msg := messageFree(cErr) + return api.ErrorInfo{Code: code, Err: fmt.Errorf(str+"%s", msg)} +} + +func codeNotZeroDsErr(code int, cErr C.CErrorInfo, str string) api.ErrorInfo { + msg := messageFree(cErr) + return api.ErrorInfo{ + Code: int(cErr.dsStatusCode), Err: api.ErrorInfo{Code: code, Err: fmt.Errorf(str+"%s", msg)}, + } +} + +// KVSet - +func (c *KvClientImpl) KVSet(key string, value []byte, param api.SetParam) api.ErrorInfo { + status := kvClientCheckNil(c) + if status.IsError() { + return status + } + cKey := CSafeString(key) + defer CSafeFree(cKey) + + cValue, cValueLen := ByteSliceToCBinaryDataNoCopy(value) + cBuf := C.CBuffer{ + buffer: cValue, + size_buffer: C.int64_t(cValueLen), + selfSharedPtrBuffer: nil, + } + cParam := cSetParam(param) + + c.stateStore.mutex.RLock() + defer c.stateStore.mutex.RUnlock() + status = stateStorePtrCheckNil(c.stateStore) + if status.IsError() { + return status + } + + cErr := C.CSetByStateStore(c.stateStore.stateStorePtr, cKey, cBuf, cParam) + code := int(cErr.code) + if code != 0 { + return codeNotZeroDsErr(code, cErr, "kv set: ") + } + return api.ErrorInfo{Code: api.Ok, Err: nil} +} + +// KVSetWithoutKey - +func (c *KvClientImpl) KVSetWithoutKey(value []byte, param api.SetParam) (string, api.ErrorInfo) { + status := kvClientCheckNil(c) + if status.IsError() { + return "", status + } + cValue, cValueLen := ByteSliceToCBinaryDataNoCopy(value) + cBuf := C.CBuffer{ + buffer: cValue, + size_buffer: C.int64_t(cValueLen), + selfSharedPtrBuffer: nil, + } + var cKey *C.char = nil + defer func() { + CSafeFree(cKey) + }() + var cKeyLen C.int + cParam := cSetParam(param) + + c.stateStore.mutex.RLock() + defer c.stateStore.mutex.RUnlock() + status = stateStorePtrCheckNil(c.stateStore) + if status.IsError() { + return "", status + } + + cErr := C.CSetValueByStateStore(c.stateStore.stateStorePtr, cBuf, cParam, &cKey, &cKeyLen) + code := int(cErr.code) + if code != 0 { + return "", codeNotZeroDsErr(code, cErr, "kv set value: ") + } + return CSafeGoStringN(cKey, cKeyLen), api.ErrorInfo{Code: api.Ok, Err: nil} +} + +// KVGet - +func (c *KvClientImpl) KVGet(key string, timeoutMs ...uint32) ([]byte, api.ErrorInfo) { + status := kvClientCheckNil(c) + if status.IsError() { + return nil, status + } + cTimeoutMs := C.int(0) + if len(timeoutMs) > 0 { + cTimeoutMs = C.int(timeoutMs[0]) + } + var cData C.CBuffer + cKey := CSafeString(key) + defer CSafeFree(cKey) + + c.stateStore.mutex.RLock() + defer c.stateStore.mutex.RUnlock() + status = stateStorePtrCheckNil(c.stateStore) + if status.IsError() { + return nil, status + } + + cErr := C.CGetByStateStore(c.stateStore.stateStorePtr, cKey, &cData, cTimeoutMs) + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroDsErr(code, cErr, "kv get: ") + } + defer CSafeFree((*C.char)(cData.buffer)) + return C.GoBytes(cData.buffer, C.int(cData.size_buffer)), api.ErrorInfo{Code: api.Ok, Err: nil} +} + +// GetCredential - +func GetCredential() api.Credential { + cCredential := C.CGetCredential() + defer CSafeFree(cCredential.ak) + defer CSafeFree(cCredential.dk) + defer CSafeFree(cCredential.sk) + return GoCredential(cCredential) +} + +// KVGetMulti - +func (c *KvClientImpl) KVGetMulti(keys []string, timeoutMs ...uint32) ([][]byte, api.ErrorInfo) { + status := kvClientCheckNil(c) + if status.IsError() { + return nil, status + } + cTimeoutMs := C.int(0) + if len(timeoutMs) > 0 { + cTimeoutMs = C.int(timeoutMs[0]) + } + cKeys, cKeysLen := CStrings(keys) + defer freeCStrings(cKeys, cKeysLen) + cBuffersSlice := make([]C.CBuffer, len(keys)) + cBufferPtr := (*C.CBuffer)(unsafe.Pointer(&cBuffersSlice[0])) + + c.stateStore.mutex.RLock() + defer c.stateStore.mutex.RUnlock() + status = stateStorePtrCheckNil(c.stateStore) + if status.IsError() { + return nil, status + } + + cErr := C.CGetArrayByStateStore(c.stateStore.stateStorePtr, cKeys, cKeysLen, cBufferPtr, cTimeoutMs) + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroDsErr(code, cErr, "kv get array: ") + } + values := make([][]byte, len(keys)) + for i, val := range cBuffersSlice { + values[i] = C.GoBytes(val.buffer, C.int(val.size_buffer)) + CSafeFree((*C.char)(val.buffer)) + } + return values, api.ErrorInfo{Code: api.Ok, Err: nil} +} + +// KVQuerySize - +func (c *KvClientImpl) KVQuerySize(keys []string) ([]uint64, api.ErrorInfo) { + status := kvClientCheckNil(c) + if status.IsError() { + return nil, status + } + cKeys, cKeysLen := CStrings(keys) + defer freeCStrings(cKeys, cKeysLen) + sizes := make([]uint64, len(keys)) + cSizes := (*C.uint64_t)(unsafe.Pointer(&sizes[0])) + + c.stateStore.mutex.RLock() + defer c.stateStore.mutex.RUnlock() + status = stateStorePtrCheckNil(c.stateStore) + if status.IsError() { + return nil, status + } + + cErr := C.CQuerySizeByStateStore(c.stateStore.stateStorePtr, cKeys, cKeysLen, cSizes) + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroDsErr(code, cErr, "kv query size: ") + } + return sizes, api.ErrorInfo{Code: api.Ok, Err: nil} +} + +// KVDel - +func (c *KvClientImpl) KVDel(key string) api.ErrorInfo { + status := kvClientCheckNil(c) + if status.IsError() { + return status + } + cKey := CSafeString(key) + defer CSafeFree(cKey) + + c.stateStore.mutex.RLock() + defer c.stateStore.mutex.RUnlock() + status = stateStorePtrCheckNil(c.stateStore) + if status.IsError() { + return status + } + + cErr := C.CDelByStateStore(c.stateStore.stateStorePtr, cKey) + code := int(cErr.code) + if code != 0 { + return codeNotZeroDsErr(code, cErr, "kv del: ") + } + return api.ErrorInfo{Code: api.Ok, Err: nil} +} + +// KVDelMulti - +func (c *KvClientImpl) KVDelMulti(keys []string) ([]string, api.ErrorInfo) { + status := kvClientCheckNil(c) + if status.IsError() { + return keys, status + } + cKeys, cKeysLen := CStrings(keys) + defer freeCStrings(cKeys, cKeysLen) + var cFailedKeys **C.char + var cFailedKeysLen C.int + + c.stateStore.mutex.RLock() + defer c.stateStore.mutex.RUnlock() + status = stateStorePtrCheckNil(c.stateStore) // double check + if status.IsError() { + return keys, status + } + + cErr := C.CDelArrayByStateStore(c.stateStore.stateStorePtr, cKeys, cKeysLen, &cFailedKeys, &cFailedKeysLen) + var failedKeys []string + if int(cFailedKeysLen) > 0 { + failedKeys = GoStrings(cFailedKeysLen, cFailedKeys) + } + code := int(cErr.code) + if code != 0 { + return failedKeys, codeNotZeroDsErr(code, cErr, "kv del array: ") + } + return failedKeys, api.ErrorInfo{Code: api.Ok, Err: nil} +} + +// GenerateKey - +func (c *KvClientImpl) GenerateKey() string { + status := kvClientCheckNil(c) + if status.IsError() { + return "" + } + var cKey *C.char = nil + defer func() { + CSafeFree(cKey) + }() + var cKeyLen C.int + c.stateStore.mutex.RLock() + defer c.stateStore.mutex.RUnlock() + status = stateStorePtrCheckNil(c.stateStore) + if status.IsError() { + return "" + } + + cErr := C.CGenerateKey(c.stateStore.stateStorePtr, &cKey, &cKeyLen) + if code := int(cErr.code); code != 0 { + messageFree(cErr) + return "" + } + return CSafeGoStringN(cKey, cKeyLen) +} + +// SetTraceID - +func (c *KvClientImpl) SetTraceID(traceID string) { + cTraceID := CSafeString(traceID) + defer CSafeFree(cTraceID) + traceIDLen := C.int(len(traceID)) + cErr := C.CSetTraceId(cTraceID, traceIDLen) + code := int(cErr.code) + if code != 0 { + messageFree(cErr) + } +} + +// DestroyClient - +func (c *KvClientImpl) DestroyClient() { + status := kvClientCheckNil(c) + if status.IsError() { + return + } + c.stateStore.mutex.Lock() + defer c.stateStore.mutex.Unlock() + C.CDestroyStateStore(c.stateStore.stateStorePtr) + c.stateStore.stateStorePtr = nil +} + +func freeCArguments(arguments *C.CConnectArguments) { + CSafeFree(arguments.host) + CSafeFree(arguments.token) + CSafeFree(arguments.clientPublicKey) + CSafeFree(arguments.clientPrivateKey) + CSafeFree(arguments.serverPublicKey) + CSafeFree(arguments.accessKey) + CSafeFree(arguments.secretKey) + CSafeFree(arguments.authClientID) + CSafeFree(arguments.authClientSecret) + CSafeFree(arguments.authUrl) + CSafeFree(arguments.tenantID) +} + +// StreamProducerImpl struct represents a producer of streaming data. +type StreamProducerImpl struct { + producer C.Producer_p +} + +// StreamConsumerImpl struct represents a consumer of streaming data. +type StreamConsumerImpl struct { + consumer C.Consumer_p +} + +// CSafeString - +func CSafeString(s string) *C.char { + if len(s) == 0 { + return nil + } + return C.CString(s) +} + +// CSafeFree - +func CSafeFree(s *C.char) { + if s == nil { + return + } + C.free(unsafe.Pointer(s)) +} + +// CSafeBytes - +func CSafeBytes(b []byte) *C.char { + if len(b) == 0 { + return nil + } + return (*C.char)(C.CBytes(b)) +} + +// CSafeGoBytes - +func CSafeGoBytes(cStr *C.char, length C.int) []byte { + if cStr == nil { + return nil + } + byteArr := C.GoBytes(unsafe.Pointer(cStr), length) + return byteArr +} + +// CSafeGoString safely converts a *C.char to a Go string. +func CSafeGoString(message *C.char) string { + if message == nil { + return "" + } else { + return C.GoString(message) + } +} + +// CSafeGoStringN - +func CSafeGoStringN(message *C.char, length C.int) string { + if message == nil || length == 0 { + return "" + } else { + return C.GoStringN(message, length) + } +} + +// Send sends an element. +// This method can be used to send data to consumers. +func (p *StreamProducerImpl) Send(element api.Element) error { + cPtr := (*C.uint8_t)(unsafe.Pointer(element.Ptr)) + cErr := C.CProducerSend(p.producer, cPtr, C.uint64_t(element.Size), C.uint64_t(element.Id)) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "stream producer send: ") + } + return nil +} + +// SendWithTimeout sends an element with a specified timeout. +// If the element cannot be sent within the timeout duration, it will be discarded. +// Parameters: +// - element: the element to be sent. +// - timeoutMs: the duration to wait before discarding the element. +func (p *StreamProducerImpl) SendWithTimeout(element api.Element, timeoutMs int64) error { + cPtr := (*C.uint8_t)(unsafe.Pointer(element.Ptr)) + cErr := C.CProducerSendWithTimeout( + p.producer, cPtr, C.uint64_t(element.Size), C.uint64_t(element.Id), C.int64_t(timeoutMs), + ) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "stream producer send with timeout: ") + } + return nil +} + +// Flush ensure flush buffered data so that it is visible to the consumer. +func (p *StreamProducerImpl) Flush() error { + cErr := C.CProducerFlush(p.producer) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "stream producer flush: ") + } + return nil +} + +// Close signals the producer to stop accepting new data and automatically flushes +// any pending data in the buffer. Once closed, the producer is no longer available. +func (p *StreamProducerImpl) Close() error { + cErr := C.CProducerClose(p.producer) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "stream producer close: ") + } + return nil +} + +// Receive retrieves data from the consumer with an optional timeout. +// Parameters: +// - timeoutMs: Maximum time in milliseconds to wait for data before timing out. +// Returns: +// - []api.Element: The received data. +// - error: nil if data was received within the timeout, error otherwise. +func (c *StreamConsumerImpl) Receive(timeoutMs uint32) ([]api.Element, error) { + var count C.uint64_t + var pEles *C.CElement + cErr := C.CConsumerReceive(c.consumer, C.uint32_t(timeoutMs), &pEles, &count) + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroErr(code, cErr, "stream consumer receive: ") + } + eles := GoEles(pEles, count) + return eles, nil +} + +// ReceiveExpectNum waits to receive the expected number of elements from the consumer, +// either until the timeout is reached or the expected number of elements are received. +// Parameters: +// - expectNum: The expected number of elements to receive. +// - timeoutMs: Maximum time in milliseconds to wait before timing out. +// Returns: +// - []api.Element: The received data. +// - error: nil if data was received within the timeout, error otherwise. +func (c *StreamConsumerImpl) ReceiveExpectNum(expectNum uint32, timeoutMs uint32) ([]api.Element, error) { + var count C.uint64_t + var pEles *C.CElement + cErr := C.CConsumerReceiveExpectNum(c.consumer, C.uint32_t(expectNum), C.uint32_t(timeoutMs), &pEles, &count) + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroErr(code, cErr, fmt.Sprintf("stream consumer receive expect num: %d", expectNum)) + } + eles := GoEles(pEles, count) + return eles, nil +} + +// Ack confirms that the consumer has completed processing the element identified by elementID. +// This function signals to other workers whether the consumer has finished processing the element. +// If all consumers have acknowledged processing the element, it triggers internal memory reclamation +// for the corresponding page. +// Parameters: +// - elementID: The identifier of the element that has been consumed. +func (c *StreamConsumerImpl) Ack(elementId uint64) error { + cErr := C.CConsumerAck(c.consumer, C.uint64_t(elementId)) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "stream consumer ack: ") + } + return nil +} + +// Close closes the consumer, unsubscribing it from further data consumption. +// This method also acknowledges any unacknowledged elements on the consumer, +// ensuring that they are marked as processed before shutting down. +func (c *StreamConsumerImpl) Close() error { + cErr := C.CConsumerClose(c.consumer) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "stream close: ") + } + return nil +} + +var ( + cfg config.Config + getAsyncCallbacks sync.Map + rawCallbacks sync.Map +) + +func btoi(b bool) int { + if b { + return 1 + } + return 0 +} + +func itob(i int) bool { + if i > 0 { + return true + } + return false +} + +func checkIfRef(t api.ArgType) C.char { + if t == api.ObjectRef { + return C.char(1) + } + return C.char(0) +} + +func apiTypeToCApiType(apiType api.ApiType) C.CApiType { + switch apiType { + case api.ActorApi: + return C.ACTOR + case api.FaaSApi: + return C.FAAS + case api.PosixApi: + return C.POSIX + default: + return C.POSIX + } +} + +// Init Initialization entry, which is used to initialize the data system and function system. +func Init(conf config.Config) error { + cfg = conf + + cFunctionSystemAddress := C.CString(conf.FunctionSystemAddress) + defer C.free(unsafe.Pointer(cFunctionSystemAddress)) + cGrpcAddress := C.CString(conf.GrpcAddress) + defer C.free(unsafe.Pointer(cGrpcAddress)) + cDataSystemAddress := C.CString(conf.DataSystemAddress) + defer C.free(unsafe.Pointer(cDataSystemAddress)) + + cJobId := C.CString(conf.JobID) + defer C.free(unsafe.Pointer(cJobId)) + cRuntimeId := C.CString(conf.RuntimeID) + defer C.free(unsafe.Pointer(cRuntimeId)) + cInstanceId := C.CString(conf.InstanceID) + defer C.free(unsafe.Pointer(cInstanceId)) + cFunctionName := C.CString(conf.FunctionName) + defer C.free(unsafe.Pointer(cFunctionName)) + cLogLevel := C.CString(conf.LogLevel) + defer C.free(unsafe.Pointer(cLogLevel)) + cLogDir := C.CString(conf.LogDir) + defer C.free(unsafe.Pointer(cLogDir)) + cServerName := C.CString(conf.ServerName) + defer C.free(unsafe.Pointer(cServerName)) + cNs := C.CString(conf.Namespace) + defer C.free(unsafe.Pointer(cNs)) + cPrivateKeyPath := C.CString(conf.PrivateKeyPath) + defer C.free(unsafe.Pointer(cPrivateKeyPath)) + cCertificateFilePath := C.CString(conf.CertificateFilePath) + defer C.free(unsafe.Pointer(cCertificateFilePath)) + cVerifyFilePath := C.CString(conf.VerifyFilePath) + defer C.free(unsafe.Pointer(cVerifyFilePath)) + cPrivateKeyPaaswd := C.CString(conf.PrivateKeyPaaswd) + defer C.free(unsafe.Pointer(cPrivateKeyPaaswd)) + cSystemAuthAccessKey := C.CString(conf.SystemAuthAccessKey) + defer C.free(unsafe.Pointer(cSystemAuthAccessKey)) + cSystemAuthSecretKey := C.CString(conf.SystemAuthSecretKey) + defer C.free(unsafe.Pointer(cSystemAuthSecretKey)) + cSystemAuthSecretKeySize := C.int(len(conf.SystemAuthSecretKey)) + cEncryptPrivateKeyPasswd := C.CString(conf.EncryptPrivateKeyPasswd) + defer C.free(unsafe.Pointer(cEncryptPrivateKeyPasswd)) + cPrimaryKeyStoreFile := C.CString(conf.PrimaryKeyStoreFile) + defer C.free(unsafe.Pointer(cPrimaryKeyStoreFile)) + cStandbyKeyStoreFile := C.CString(conf.StandbyKeyStoreFile) + defer C.free(unsafe.Pointer(cStandbyKeyStoreFile)) + cRuntimePublicKeyContext := C.CString(conf.RuntimePublicKeyContext) + defer C.free(unsafe.Pointer(cRuntimePublicKeyContext)) + cRuntimePrivateKeyContext := C.CString(conf.RuntimePrivateKeyContext) + defer C.free(unsafe.Pointer(cRuntimePrivateKeyContext)) + cDsPublicKeyContext := C.CString(conf.DsPublicKeyContext) + defer C.free(unsafe.Pointer(cDsPublicKeyContext)) + cEncryptRuntimePublicKeyContext := C.CString(conf.EncryptRuntimePublicKeyContext) + defer C.free(unsafe.Pointer(cEncryptRuntimePublicKeyContext)) + cEncryptRuntimePrivateKeyContext := C.CString(conf.EncryptRuntimePrivateKeyContext) + defer C.free(unsafe.Pointer(cEncryptRuntimePrivateKeyContext)) + cEncryptDsPublicKeyContext := C.CString(conf.EncryptDsPublicKeyContext) + defer C.free(unsafe.Pointer(cEncryptDsPublicKeyContext)) + cMaxConcurrencyCreateNum := C.int(conf.MaxConcurrencyCreateNum) + + cFunctionId := C.CString(conf.FunctionId) + defer C.free(unsafe.Pointer(cFunctionId)) + cConf := C.CLibruntimeConfig{ + functionSystemAddress: cFunctionSystemAddress, + grpcAddress: cGrpcAddress, + dataSystemAddress: cDataSystemAddress, + jobId: cJobId, + runtimeId: cRuntimeId, + instanceId: cInstanceId, + functionName: cFunctionName, + logLevel: cLogLevel, + logDir: cLogDir, + functionId: cFunctionId, + apiType: apiTypeToCApiType(conf.Api), + inCluster: C.char(btoi(conf.InCluster)), + isDriver: C.char(btoi(conf.IsDriver)), + enableMTLS: C.char(btoi(conf.EnableMTLS)), + privateKeyPath: cPrivateKeyPath, + certificateFilePath: cCertificateFilePath, + verifyFilePath: cVerifyFilePath, + privateKeyPaaswd: cPrivateKeyPaaswd, + systemAuthAccessKey: cSystemAuthAccessKey, + systemAuthSecretKey: cSystemAuthSecretKey, + systemAuthSecretKeySize: cSystemAuthSecretKeySize, + encryptPrivateKeyPasswd: cEncryptPrivateKeyPasswd, + primaryKeyStoreFile: cPrimaryKeyStoreFile, + standbyKeyStoreFile: cStandbyKeyStoreFile, + enableDsEncrypt: C.char(btoi(conf.EnableDsEncrypt)), + runtimePublicKeyContext: cRuntimePublicKeyContext, + runtimePrivateKeyContext: cRuntimePrivateKeyContext, + dsPublicKeyContext: cDsPublicKeyContext, + encryptRuntimePublicKeyContext: cEncryptRuntimePublicKeyContext, + encryptRuntimePrivateKeyContext: cEncryptRuntimePrivateKeyContext, + encryptDsPublicKeyContext: cEncryptDsPublicKeyContext, + maxConcurrencyCreateNum: cMaxConcurrencyCreateNum, + enableSigaction: C.char(btoi(conf.EnableSigaction)), + } + cErr := C.CInit(&cConf) + code := int(cErr.code) + if code != 0 { + return fmt.Errorf("failed to init libruntime, code: %d, message: %s", code, messageFree(cErr)) + } + return nil +} + +// ReceiveRequestLoop begins loop processing the received request. +func ReceiveRequestLoop() { + C.CReceiveRequestLoop() +} + +// ExecShutdownHandler exec shutdown handler. +func ExecShutdownHandler(signum int) { + cSignum := C.int(signum) + C.CExecShutdownHandler(cSignum) +} + +// CreateInstance Golang posix function. +func CreateInstance(funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + cFuncMeta := cFunctionMeta(funcMeta) + defer freeCFunctionMeta(cFuncMeta) + + cArgs, cArgsLen := cArgs(args) + defer freeCArgs(cArgs, cArgsLen) + + cInvokeOpt := cInvokeOptions(invokeOpt) + defer freeCInvokeOptions(cInvokeOpt) + + var cInstanceID *C.char + cErr := C.CCreateInstance(cFuncMeta, cArgs, cArgsLen, cInvokeOpt, &cInstanceID) + code := int(cErr.code) + if code != 0 { + return "", codeNotZeroErr(code, cErr, "") + } + defer C.free(unsafe.Pointer(cInstanceID)) + + objectID := CSafeGoString(cInstanceID) + wait := make(chan error, 1) + WaitAsync( + objectID, func(result []byte, err error) { + wait <- err + }, + ) + defer func() { + if _, err := GDecreaseRef([]string{objectID}); err != nil { + fmt.Printf("failed to decrease object ref,err: %s", err.Error()) + } + }() + + var createErr error + timer := time.NewTimer(time.Duration(invokeOpt.Timeout) * time.Second) + select { + case <-timer.C: + createErr = api.ErrorInfo{Code: 3002, Err: fmt.Errorf("create instance timeout")} + case err, ok := <-wait: + if !ok { + createErr = api.ErrorInfo{Code: 3002, Err: fmt.Errorf("failed to create instance")} + } else { + createErr = err + } + } + var instanceID string + cRealInstanceID := C.CGetRealInstanceId(cInstanceID, C.int(invokeOpt.Timeout)) + instanceID = CSafeGoString(cRealInstanceID) + C.free(unsafe.Pointer(cRealInstanceID)) + if instanceID == "" || instanceID == objectID { + if createErr == nil { + return "", api.ErrorInfo{Code: 1003, Err: fmt.Errorf("real instance id not exist, get failed")} + } else { + return "", createErr // avoid sending an fake insId(i.e. objectId) to scheduler + } + } + return instanceID, createErr +} + +// InvokeByInstanceId Golang posix function +func InvokeByInstanceId( + funcMeta api.FunctionMeta, instanceID string, args []api.Arg, invokeOpt api.InvokeOptions, +) (string, error) { + cFuncMeta := cFunctionMeta(funcMeta) + defer freeCFunctionMeta(cFuncMeta) + + cInstanceID := C.CString(instanceID) + defer C.free(unsafe.Pointer(cInstanceID)) + + cArgs, cArgsLen := cArgs(args) + defer freeCArgs(cArgs, cArgsLen) + + cInvokeOpt := cInvokeOptions(invokeOpt) + defer freeCInvokeOptions(cInvokeOpt) + + var cRetObjID *C.char + cErr := C.CInvokeByInstanceId(cFuncMeta, cInstanceID, cArgs, cArgsLen, cInvokeOpt, &cRetObjID) + code := int(cErr.code) + if code != 0 { + return "", codeNotZeroErr(code, cErr, "invoke by instance id: ") + } + retObjID := CSafeGoString(cRetObjID) + C.free(unsafe.Pointer(cRetObjID)) + return retObjID, nil +} + +// InvokeByFunctionName Supports system functions and faas. +func InvokeByFunctionName(funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + cFuncMeta := cFunctionMeta(funcMeta) + + defer freeCFunctionMeta(cFuncMeta) + + cArgs, cArgsLen := cArgs(args) + defer freeCArgs(cArgs, cArgsLen) + + cInvokeOpt := cInvokeOptions(invokeOpt) + defer freeCInvokeOptions(cInvokeOpt) + + var cRetObjID *C.char + cErr := C.CInvokeByFunctionName(cFuncMeta, cArgs, cArgsLen, cInvokeOpt, &cRetObjID) + code := int(cErr.code) + if code != 0 { + return "", codeNotZeroErr(code, cErr, "") + } + retObjID := CSafeGoString(cRetObjID) + C.free(unsafe.Pointer(cRetObjID)) + return retObjID, nil +} + +// AcquireInstance Supports system functions and faas. +func AcquireInstance(stateID string, funcMeta api.FunctionMeta, acquireOpt api.InvokeOptions) (api.InstanceAllocation, + error) { + cFuncMeta := cFunctionMeta(funcMeta) + defer freeCFunctionMeta(cFuncMeta) + + cInvokeOpt := cAcquireOptions(acquireOpt) + defer freeCInvokeOptions(cInvokeOpt) + + cStateID := C.CString(stateID) + defer C.free(unsafe.Pointer(cStateID)) + instanceAllocation := new(C.CInstanceAllocation) + defer freeCInstanceAllocation(instanceAllocation) + cErr := C.CAcquireInstance(cStateID, cFuncMeta, cInvokeOpt, instanceAllocation) + code := int(cErr.code) + if code != 0 { + return api.InstanceAllocation{}, codeNotZeroErr(code, cErr, "") + } + return api.InstanceAllocation{ + FuncSig: CSafeGoString(instanceAllocation.funcSig), + FuncKey: CSafeGoString(instanceAllocation.functionId), + InstanceID: CSafeGoString(instanceAllocation.instanceId), + LeaseID: CSafeGoString(instanceAllocation.leaseId), + LeaseInterval: int64(instanceAllocation.tLeaseInterval), + }, nil +} + +// ReleaseInstance release lease +func ReleaseInstance(allocation api.InstanceAllocation, stateID string, abnormal bool, option api.InvokeOptions) { + cInstanceAllocation := cCInstanceAllocation(allocation) + cOptions := cInvokeOptions(option) + defer freeCInstanceAllocation(cInstanceAllocation) + defer freeCInvokeOptions(cOptions) + cStateID := C.CString(stateID) + defer C.free(unsafe.Pointer(cStateID)) + cErr := C.CReleaseInstance(cInstanceAllocation, cStateID, C.char(btoi(abnormal)), cOptions) + if code := int(cErr.code); code != 0 { + fmt.Printf("failed to release lease %s error: %s\n", allocation.LeaseID, messageFree(cErr)) + } +} + +// Kill instances +func Kill(instanceID string, signo int, data []byte) error { + cInstanceID := C.CString(instanceID) + defer C.free(unsafe.Pointer(cInstanceID)) + + cSigno := C.int(signo) + + cData, cDataLen := ByteSliceToCBinaryData(data) + if cData != nil { + defer C.free(cData) + } + + cBuf := C.CBuffer{ + buffer: cData, + size_buffer: C.int64_t(cDataLen), + } + cErr := C.CKill(cInstanceID, cSigno, cBuf) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "kill instance: ") + } + return nil +} + +// UpdateSchdulerInfo update libruntime scheduler info +func UpdateSchdulerInfo(schedulerName string, schedulerId string, option string) { + cSchedulerName := C.CString(schedulerName) + defer C.free(unsafe.Pointer(cSchedulerName)) + cSchedulerId := C.CString(schedulerId) + defer C.free(unsafe.Pointer(cSchedulerId)) + cOption := C.CString(option) + defer C.free(unsafe.Pointer(cOption)) + + C.CUpdateSchdulerInfo(cSchedulerName, cSchedulerId, cOption) +} + +// GetAsync with a callback +func GetAsync(objectID string, cb api.GetAsyncCallback) { + getAsyncCallbacks.Store(objectID, cb) + cObjectID := C.CString(objectID) + defer C.free(unsafe.Pointer(cObjectID)) + C.CGetAsync(cObjectID, nil) +} + +// WaitAsync with a callback +func WaitAsync(objectID string, cb api.GetAsyncCallback) { + getAsyncCallbacks.Store(objectID, cb) + cObjectID := C.CString(objectID) + defer C.free(unsafe.Pointer(cObjectID)) + C.CWaitAsync(cObjectID, nil) +} + +// CreateStreamProducer creates and returns a new Producer instance +func CreateStreamProducer(streamName string, producerConf api.ProducerConf) (api.StreamProducer, error) { + cStreamName := C.CString(streamName) + defer C.free(unsafe.Pointer(cStreamName)) + cProducerConf := cProducerConfig(producerConf) + defer C.free(unsafe.Pointer(cProducerConf.traceId)) + var streamProducer StreamProducerImpl + cErr := C.CCreateStreamProducer(cStreamName, cProducerConf, &streamProducer.producer) + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroErr(code, cErr, "create stream producer: ") + } + return &streamProducer, nil +} + +// CreateStreamConsumer creates and returns a new Consumer instance +func CreateStreamConsumer(streamName string, config api.SubscriptionConfig) (api.StreamConsumer, error) { + cStreamName := C.CString(streamName) + defer C.free(unsafe.Pointer(cStreamName)) + cSubscriptConf := cSubscriptionConfig(config) + defer C.free(unsafe.Pointer(cSubscriptConf.subscriptionName)) + defer C.free(unsafe.Pointer(cSubscriptConf.traceId)) + var streamConsumer StreamConsumerImpl + cErr := C.CCreateStreamConsumer(cStreamName, cSubscriptConf, &streamConsumer.consumer) + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroErr(code, cErr, "create stream consumer: ") + } + return &streamConsumer, nil +} + +// DeleteStream Delete a data flow. +// When the number of global producers and consumers is 0, +// the data flow is no longer used and the metadata related to the data flow is deleted from each worker and the +// master. This function can be invoked on any host node. +func DeleteStream(streamName string) error { + cStreamName := C.CString(streamName) + defer C.free(unsafe.Pointer(cStreamName)) + cErr := C.CDeleteStream(cStreamName) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "delete stream: ") + } + return nil +} + +// QueryGlobalProducersNum Specifies the flow name to query the number of all producers of the flow. +func QueryGlobalProducersNum(streamName string) (uint64, error) { + cStreamName := C.CString(streamName) + defer C.free(unsafe.Pointer(cStreamName)) + + var num C.uint64_t + cErr := C.CQueryGlobalProducersNum(cStreamName, &num) + code := int(cErr.code) + if code != 0 { + return 0, codeNotZeroErr(code, cErr, "qurey global producers num: ") + } + return uint64(num), nil +} + +// QueryGlobalConsumersNum Specifies the flow name to query the number of all consumers of the flow. +func QueryGlobalConsumersNum(streamName string) (uint64, error) { + cStreamName := C.CString(streamName) + defer C.free(unsafe.Pointer(cStreamName)) + + var num C.uint64_t + cErr := C.CQueryGlobalConsumersNum(cStreamName, &num) + code := int(cErr.code) + if code != 0 { + return 0, codeNotZeroErr(code, cErr, "qurey global consumers num: ") + } + return uint64(num), nil +} + +// GoGetAsyncCallback is exported as a C function for calling from C/C++ code. +// The purpose is to execute the get callback function of go. +// +//export GoGetAsyncCallback +func GoGetAsyncCallback(cObjectID *C.char, cBuf C.CBuffer, cErr *C.CErrorInfo, userData unsafe.Pointer) { + objectID := CSafeGoString((*C.char)(cObjectID)) + cb, ok := getGetAsyncCallback(objectID) + if !ok { + return + } + + code := int(cErr.code) + if code != 0 { + cb([]byte{}, codeNotZeroErr(code, *cErr, "")) + return + } + + cb(C.GoBytes(cBuf.buffer, C.int(cBuf.size_buffer)), nil) + C.free(cBuf.buffer) +} + +// GoWaitAsyncCallback is exported as a C function for calling from C/C++ code. +// The purpose is to execute the wait callback function of go. +// +//export GoWaitAsyncCallback +func GoWaitAsyncCallback(cObjectID *C.char, cErr *C.CErrorInfo, userData unsafe.Pointer) { + objectID := CSafeGoString((*C.char)(cObjectID)) + cb, ok := getGetAsyncCallback(objectID) + if !ok { + return + } + + if code := int(cErr.code); code != 0 { + cb([]byte{}, codeNotZeroErr(code, *cErr, "")) + return + } + cb([]byte{}, nil) +} + +func getGetAsyncCallback(objectID string) (api.GetAsyncCallback, bool) { + value, ok := getAsyncCallbacks.LoadAndDelete(objectID) + if ok { + cb, ok := value.(api.GetAsyncCallback) + return cb, ok + } + return nil, false +} + +// RawCallback define the raw callback function type. +type RawCallback func(result []byte, err error) + +// GoRawCallback is exported as a C function for calling from C/C++ code. +// The purpose is to execute the raw callback function of go. +// +//export GoRawCallback +func GoRawCallback(cKey *C.char, cErr C.CErrorInfo, cResultRaw C.CBuffer) { + key := CSafeGoString((*C.char)(cKey)) + cb, ok := getRawCallback(key) + if !ok { + return + } + + code := int(cErr.code) + if code != 0 { + cb([]byte{}, codeNotZeroErr(code, cErr, "raw callback error: ")) + return + } + + cb(C.GoBytes(cResultRaw.buffer, C.int(cResultRaw.size_buffer)), nil) +} + +func getRawCallback(key string) (RawCallback, bool) { + value, ok := rawCallbacks.LoadAndDelete(key) + if ok { + cb, ok := value.(RawCallback) + return cb, ok + } + return nil, false +} + +// CreateInstanceRaw Raw interface provided for the frontend. +func CreateInstanceRaw(createReqRaw []byte) ([]byte, error) { + createReqRawPtr, createReqRawLen := ByteSliceToCBinaryDataNoCopy(createReqRaw) + cCreateReqRaw := C.CBuffer{ + buffer: createReqRawPtr, + size_buffer: C.int64_t(createReqRawLen), + selfSharedPtrBuffer: nil, + } + + errChan := make(chan error, 1) + key := uuid.New().String() + var result []byte + var rawCallback RawCallback = func(resultRaw []byte, err error) { + if err == nil { + result = resultRaw + } + errChan <- err + } + rawCallbacks.Store(key, rawCallback) + + cKey := C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + C.CCreateInstanceRaw(cCreateReqRaw, cKey) + + resultErr, ok := <-errChan + if !ok { + return []byte{}, api.ErrorInfo{Code: -1, Err: fmt.Errorf("create instance raw: channel closed")} + } + if resultErr != nil { + return []byte{}, resultErr + } + + return result, nil +} + +// InvokeByInstanceIdRaw Raw interface provided for the frontend. +func InvokeByInstanceIdRaw(invokeReqRaw []byte) ([]byte, error) { + invokeReqRawPtr, invokeReqRawLen := ByteSliceToCBinaryDataNoCopy(invokeReqRaw) + cInvokeReqRaw := C.CBuffer{ + buffer: invokeReqRawPtr, + size_buffer: C.int64_t(invokeReqRawLen), + selfSharedPtrBuffer: nil, + } + + errChan := make(chan error, 1) + var result []byte + key := uuid.New().String() + var callback RawCallback = func(resultRaw []byte, err error) { + if err == nil { + result = resultRaw + } + errChan <- err + } + rawCallbacks.Store(key, callback) + + cKey := C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + C.CInvokeByInstanceIdRaw(cInvokeReqRaw, cKey) + + resultErr, ok := <-errChan + if !ok { + return []byte{}, api.ErrorInfo{Code: -1, Err: fmt.Errorf("invoke raw: channel closed")} + } + if resultErr != nil { + return []byte{}, resultErr + } + + return result, nil +} + +// KillRaw Raw interface provided for the frontend. +func KillRaw(killReqRaw []byte) ([]byte, error) { + killReqRawPtr, killReqRawLen := ByteSliceToCBinaryDataNoCopy(killReqRaw) + cKillReqRaw := C.CBuffer{ + buffer: killReqRawPtr, + size_buffer: C.int64_t(killReqRawLen), + selfSharedPtrBuffer: nil, + } + + errChan := make(chan error, 1) + var result []byte + key := uuid.New().String() + var rawCallback RawCallback = func(resultRaw []byte, e error) { + if e == nil { + result = resultRaw + } + errChan <- e + } + rawCallbacks.Store(key, rawCallback) + cKey := C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + C.CKillRaw(cKillReqRaw, cKey) + + resultErr, ok := <-errChan + if !ok { + return []byte{}, api.ErrorInfo{Code: -1, Err: fmt.Errorf("kill raw: channel closed")} + } + if resultErr != nil { + return []byte{}, resultErr + } + + return result, nil +} + +// Exit send exit request. +func Exit(code int, message string) { + cMessage := C.CString(message) + defer C.free(unsafe.Pointer(cMessage)) + C.CExit(C.int(code), cMessage) +} + +// Finalize func for go sdk +// This API is used to release resources, such as created function instances and data objects, +// to prevent residual resources. +func Finalize() { + C.CFinalize() +} + +// KVSet save binary data to the data system. +func KVSet(key string, value []byte, param api.SetParam) error { + cKey := C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + + cValue, cValueLen := ByteSliceToCBinaryDataNoCopy(value) + cBuf := C.CBuffer{ + buffer: cValue, + size_buffer: C.int64_t(cValueLen), + selfSharedPtrBuffer: nil, + } + cParam := cSetParam(param) + cErr := C.CKVWrite(cKey, cBuf, cParam) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "kv set: ") + } + return nil +} + +// KVMSetTx save binary datas to the data system. +func KVMSetTx(keys []string, values [][]byte, param api.MSetParam) error { + cKeys, cKeysLen := CStrings(keys) + defer freeCStrings(cKeys, cKeysLen) + + cBuffersSlice := make([]C.CBuffer, len(values)) + for idx, val := range values { + cValue, cValueLen := ByteSliceToCBinaryDataNoCopy(val) + cBuf := C.CBuffer{ + buffer: cValue, + size_buffer: C.int64_t(cValueLen), + selfSharedPtrBuffer: nil, + } + cBuffersSlice[idx] = cBuf + } + cBufferPtr := (*C.CBuffer)(unsafe.Pointer(&cBuffersSlice[0])) + cParam := cMSetParam(param) + cErr := C.CKVMSetTx(cKeys, cKeysLen, cBufferPtr, cParam) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "kv multi set tx: ") + } + return nil +} + +// KVGet get binary data from data system. +func KVGet(key string, timeoutms uint) ([]byte, error) { + cKey := C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + + var cData C.CBuffer + cErr := C.CKVRead(cKey, C.int(timeoutms), &cData) + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroErr(code, cErr, "kv get: ") + } + value := C.GoBytes(cData.buffer, C.int(cData.size_buffer)) + C.free(cData.buffer) + return value, nil +} + +// KVGetMulti get multi binary data from data system. +func KVGetMulti(keys []string, timeoutms uint) ([][]byte, error) { + cKeys, cKeysLen := CStrings(keys) + defer freeCStrings(cKeys, cKeysLen) + + cBuffersSlice := make([]C.CBuffer, len(keys)) + cBufferPtr := (*C.CBuffer)(unsafe.Pointer(&cBuffersSlice[0])) + cErr := C.CKVMultiRead(cKeys, cKeysLen, C.int(timeoutms), 1, cBufferPtr) + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroErr(code, cErr, "kv get multi: ") + } + + values := make([][]byte, len(keys)) + for idx, val := range cBuffersSlice { + values[idx] = C.GoBytes(val.buffer, C.int(val.size_buffer)) + C.free(val.buffer) + } + return values, nil +} + +// KVDel del data from data system. +func KVDel(key string) error { + cKey := C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + cErr := C.CKVDel(cKey) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "kv del: ") + } + return nil +} + +// KVDelMulti del multi data from data system. +func KVDelMulti(keys []string) ([]string, error) { + cKeys, cKeysLen := CStrings(keys) + defer freeCStrings(cKeys, cKeysLen) + + var cFailedKeys **C.char + var cFailedKeysLen C.int + cErr := C.CKMultiVDel(cKeys, cKeysLen, &cFailedKeys, &cFailedKeysLen) + var failedKeys []string + if int(cFailedKeysLen) > 0 { + failedKeys = GoStrings(cFailedKeysLen, cFailedKeys) + } + code := int(cErr.code) + if code != 0 { + return failedKeys, codeNotZeroErr(code, cErr, "kv del multi: ") + } + return failedKeys, nil +} + +// SetTenantID - +func SetTenantID(tenantId string) error { + cTenantId := CSafeString(tenantId) + defer CSafeFree(cTenantId) + tenantIdLen := C.int(len(tenantId)) + cErr := C.CSetTenantId(cTenantId, tenantIdLen) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "set tenantID failed: ") + } + return nil +} + +// PutCommon put obj data to data system, +// Parameters: +// - objectID: object id. +// - value: object binary data. +// - param: put param. +// - nestedObjectIDs: Indicates that the current object will be used by objects in nestedIds. +// If the object in nestedIds is not deleted, +// the current object will be held by the system until all objects in nestedIds are deleted. +// +// Returns: +// - error: nil if data put success, error otherwise. +func PutCommon(objectID string, value []byte, param api.PutParam, isRaw bool, nestedObjectIDs ...string) error { + cObjID := C.CString(objectID) + defer C.free(unsafe.Pointer(cObjID)) + cValuePtr, cValueLen := ByteSliceToCBinaryDataNoCopy(value) + cData := C.CBuffer{ + buffer: cValuePtr, + size_buffer: C.int64_t(cValueLen), + selfSharedPtrBuffer: nil, + } + cNestedIDs, cNestedIDsLen := CStrings(nestedObjectIDs) + defer freeCStrings(cNestedIDs, cNestedIDsLen) + + var cErr C.CErrorInfo + cIsRaw := C.char(0) + if isRaw { + cIsRaw = C.char(1) + } + createParam := cCreateParam(param) + cErr = C.CPutCommon(cObjID, cData, cNestedIDs, cNestedIDsLen, cIsRaw, createParam) + + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "put: ") + } + return nil +} + +// Put put obj data to data system. +func Put(objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string) error { + return PutCommon(objectID, value, param, false, nestedObjectIDs...) +} + +// PutRaw put obj data to data system. +func PutRaw(objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string) error { + return PutCommon(objectID, value, param, true, nestedObjectIDs...) +} + +// GetCommon get objs from data system. +func GetCommon(objectIDs []string, timeoutMs int, isRaw bool) ([][]byte, error) { + cObjIDs, cObjIDsLen := CStrings(objectIDs) + defer freeCStrings(cObjIDs, cObjIDsLen) + + cBuffersSlice := make([]C.CBuffer, len(objectIDs)) + cBufferPtr := (*C.CBuffer)(unsafe.Pointer(&cBuffersSlice[0])) + + var cErr C.CErrorInfo + cIsRaw := C.char(0) + if isRaw { + cIsRaw = C.char(1) + } + cErr = C.CGetMultiCommon(cObjIDs, cObjIDsLen, C.int(timeoutMs), 1, cBufferPtr, cIsRaw) + + code := int(cErr.code) + if code != 0 { + return nil, codeNotZeroErr(code, cErr, "") + } + + values := make([][]byte, len(objectIDs)) + for idx, val := range cBuffersSlice { + values[idx] = C.GoBytes(val.buffer, C.int(val.size_buffer)) + C.free(val.buffer) + } + return values, nil +} + +// Get to get objs from data system. +func Get(objectIDs []string, timeoutMs int) ([][]byte, error) { + return GetCommon(objectIDs, timeoutMs, false) +} + +// GetRaw get objs from data system. +func GetRaw(objectIDs []string, timeoutMs int) ([][]byte, error) { + return GetCommon(objectIDs, timeoutMs, true) +} + +// Wait until result return or timeout +func Wait(objectIDs []string, waitNum uint64, timeout int) ( + readyObjectIds, unReadyObjectIds []string, errors map[string]error, +) { + cObjIDs, cObjIDsLen := CStrings(objectIDs) + defer freeCStrings(cObjIDs, cObjIDsLen) + + waitResult := C.CWaitResult{ + readyIds: nil, + size_readyIds: 0, + unreadyIds: nil, + size_unreadyIds: 0, + errorIds: nil, + size_errorIds: 0, + } + C.CWait(cObjIDs, cObjIDsLen, C.int(waitNum), C.int(timeout), &waitResult) + + if int(waitResult.size_readyIds) > 0 { + readyObjectIds = GoStrings(waitResult.size_readyIds, waitResult.readyIds) + } + + if int(waitResult.size_unreadyIds) > 0 { + unReadyObjectIds = GoStrings(waitResult.size_unreadyIds, waitResult.unreadyIds) + } + + if int(waitResult.size_errorIds) > 0 { + errors = make(map[string]error) + errorIds := unsafe.Slice(waitResult.errorIds, int(waitResult.size_errorIds)) + defer C.free(unsafe.Pointer(waitResult.errorIds)) + for _, errorId := range errorIds { + var errCode C.int + var errorMessage *C.char + var objectId *C.char + var stackTracesInfo C.CStackTracesInfo + C.CParseCErrorObjectPointer(errorId, &errCode, &errorMessage, &objectId, &stackTracesInfo) + code := int(errCode) + msg := CSafeGoString(errorMessage) + C.free(unsafe.Pointer(errorMessage)) + id := CSafeGoString(objectId) + C.free(unsafe.Pointer(objectId)) + sinfo := goStackTraces(stackTracesInfo) + errors[id] = api.NewErrInfo(code, msg, sinfo) + C.free(unsafe.Pointer(errorId)) + } + } + return +} + +// GIncreaseRefCommon increase object reference count +func GIncreaseRefCommon(objectIDs []string, isRaw bool, remoteClientID ...string) ([]string, error) { + cObjIDs, cObjIDsLen := CStrings(objectIDs) + defer freeCStrings(cObjIDs, cObjIDsLen) + + var cRemoteID *C.char = nil + if len(remoteClientID) > 0 { + cRemoteID = C.CString(remoteClientID[0]) + defer C.free(unsafe.Pointer(cRemoteID)) + } + var cFailedIDs **C.char + var cFailedIDsLen C.int + cIsRaw := C.char(0) + if isRaw { + cIsRaw = C.char(1) + } + + var cErr C.CErrorInfo + cErr = C.CIncreaseReferenceCommon(cObjIDs, cObjIDsLen, cRemoteID, &cFailedIDs, &cFailedIDsLen, cIsRaw) + + var failedIDs []string + if int(cFailedIDsLen) > 0 { + failedIDs = GoStrings(cFailedIDsLen, cFailedIDs) + } + code := int(cErr.code) + if code != 0 { + return failedIDs, codeNotZeroErr(code, cErr, "global increase ref: ") + } + return failedIDs, nil +} + +// GIncreaseRef increase object reference count +func GIncreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + return GIncreaseRefCommon(objectIDs, false, remoteClientID...) +} + +// GIncreaseRefRaw increase object reference count +func GIncreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + return GIncreaseRefCommon(objectIDs, true, remoteClientID...) +} + +// GDecreaseRefCommon decrease object reference count +func GDecreaseRefCommon(objectIDs []string, isRaw bool, remoteClientID ...string) ([]string, error) { + cObjIDs, cObjIDsLen := CStrings(objectIDs) + defer freeCStrings(cObjIDs, cObjIDsLen) + + var cRemoteID *C.char = nil + if len(remoteClientID) > 0 { + cRemoteID = C.CString(remoteClientID[0]) + defer C.free(unsafe.Pointer(cRemoteID)) + } + var cFailedIDs **C.char + var cFailedIDsLen C.int + + var cErr C.CErrorInfo + + cIsRaw := C.char(0) + if isRaw { + cIsRaw = C.char(1) + } + cErr = C.CDecreaseReferenceCommon(cObjIDs, cObjIDsLen, cRemoteID, &cFailedIDs, &cFailedIDsLen, cIsRaw) + + var failedIDs []string + if int(cFailedIDsLen) > 0 { + failedIDs = GoStrings(cFailedIDsLen, cFailedIDs) + } + code := int(cErr.code) + if code != 0 { + return failedIDs, codeNotZeroErr(code, cErr, "global decrease ref: ") + } + return failedIDs, nil +} + +// GDecreaseRef decrease object reference count +func GDecreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + return GDecreaseRefCommon(objectIDs, false, remoteClientID...) +} + +// GDecreaseRefRaw decrease object reference count +func GDecreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + return GDecreaseRefCommon(objectIDs, true, remoteClientID...) +} + +// ReleaseGRefs release object refs by remote client id +func ReleaseGRefs(remoteClientID string) error { + var cRemoteID *C.char = nil + cRemoteID = C.CString(remoteClientID) + defer C.free(unsafe.Pointer(cRemoteID)) + cErr := C.CReleaseGRefs(cRemoteID) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "global decrease ref: ") + } + return nil +} + +// AllocReturnObject Creates an object and applies for a memory block. +// Computing operations can be performed on the memory block. +// will return a 'Buffer' that will be used to manipulate the memory +func AllocReturnObject(do *config.DataObject, size uint, nestedIds []string, totalNativeBufferSize *uint) error { + cDataObject := C.CDataObject{ + id: C.CString(do.ID), + selfSharedPtr: do.CSharedPtr, + nestedObjIds: nil, + size_nestedObjIds: 0, + } + defer CSafeFree(cDataObject.id) + cNestedIDs, cNestedIDsLen := CStrings(nestedIds) + defer freeCStrings(cNestedIDs, cNestedIDsLen) + + var cTotalNativeBufferSize C.uint64_t = C.uint64_t(*totalNativeBufferSize) + cErr := C.CAllocReturnObject(&cDataObject, C.int(size), cNestedIDs, cNestedIDsLen, &cTotalNativeBufferSize) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "alloc return object: ") + } + do.Buffer = config.DataBuffer{ + Ptr: cDataObject.buffer.buffer, + Size: int(cDataObject.buffer.size_buffer), + SharedPtrBuffer: cDataObject.buffer.selfSharedPtrBuffer, + } + *totalNativeBufferSize = uint(cTotalNativeBufferSize) + return nil +} + +// SetReturnObject if return by message, set return object +func SetReturnObject(do *config.DataObject, size uint) { + cDataObject := C.CDataObject{ + id: C.CString(do.ID), + selfSharedPtr: do.CSharedPtr, + nestedObjIds: nil, + size_nestedObjIds: 0, + } + defer CSafeFree(cDataObject.id) + C.CSetReturnObject(&cDataObject, C.int(size)) + do.Buffer = config.DataBuffer{ + Ptr: cDataObject.buffer.buffer, + Size: int(cDataObject.buffer.size_buffer), + SharedPtrBuffer: cDataObject.buffer.selfSharedPtrBuffer, + } +} + +// WriterLatch Obtains the write lock of the buffer object. +func WriterLatch(do *config.DataObject) error { + cBuf := C.CBuffer{ + buffer: do.Buffer.Ptr, + size_buffer: C.int64_t(do.Buffer.Size), + selfSharedPtrBuffer: do.Buffer.SharedPtrBuffer, + } + cErr := C.CWriterLatch(&cBuf) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "buffer writer latch: ") + } + return nil +} + +// MemoryCopy Writes data to a buffer object. +func MemoryCopy(do *config.DataObject, src []byte) error { + cBuf := C.CBuffer{ + buffer: do.Buffer.Ptr, + size_buffer: C.int64_t(do.Buffer.Size), + selfSharedPtrBuffer: do.Buffer.SharedPtrBuffer, + } + if len(src) == 0 { + return nil + } + cSrcPtr, cSrcLen := ByteSliceToCBinaryDataNoCopy(src) + cErr := C.CMemoryCopy(&cBuf, cSrcPtr, C.uint64_t(cSrcLen)) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "buffer memory copy: ") + } + return nil +} + +// Seal Publish the object and seal it. Sealed objects cannot be modified again. +func Seal(do *config.DataObject) error { + cBuf := C.CBuffer{ + buffer: do.Buffer.Ptr, + size_buffer: C.int64_t(do.Buffer.Size), + selfSharedPtrBuffer: do.Buffer.SharedPtrBuffer, + } + cErr := C.CSeal(&cBuf) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "buffer seal: ") + } + return nil +} + +// WriterUnlatch release the write lock of the buffer object. +func WriterUnlatch(do *config.DataObject) error { + cBuf := C.CBuffer{ + buffer: do.Buffer.Ptr, + size_buffer: C.int64_t(do.Buffer.Size), + selfSharedPtrBuffer: do.Buffer.SharedPtrBuffer, + } + cErr := C.CWriterUnlatch(&cBuf) + code := int(cErr.code) + if code != 0 { + return codeNotZeroErr(code, cErr, "buffer writer unlatch: ") + } + return nil +} + +func cSetParam(param api.SetParam) C.CSetParam { + return C.CSetParam{ + writeMode: cWriteMode(param.WriteMode), + ttlSecond: C.uint32_t(param.TTLSecond), + existence: cExistenceOpt(param.Existence), + cacheType: cCacheType(param.CacheType), + } +} + +func cMSetParam(param api.MSetParam) C.CMSetParam { + return C.CMSetParam{ + writeMode: cWriteMode(param.WriteMode), + ttlSecond: C.uint32_t(param.TTLSecond), + existence: cExistenceOpt(param.Existence), + cacheType: cCacheType(param.CacheType), + } +} + +func cCreateParam(param api.PutParam) C.CCreateParam { + return C.CCreateParam{ + writeMode: cWriteMode(param.WriteMode), + consistencyType: cConsistencyType(param.ConsistencyType), + cacheType: cCacheType(param.CacheType), + } +} + +func cWriteMode(wm api.WriteModeEnum) C.CWriteMode { + switch wm { + case api.NoneL2Cache: + return C.NONE_L2_CACHE + case api.WriteThroughL2Cache: + return C.WRITE_THROUGH_L2_CACHE + case api.WriteBackL2Cache: + return C.WRITE_BACK_L2_CACHE + case api.NoneL2CacheEvict: + return C.NONE_L2_CACHE_EVICT + default: + return C.NONE_L2_CACHE + } +} + +func cExistenceOpt(o int32) C.CExistenceOpt { + if o == 0 { + return C.NONE + } + return C.NX +} + +func cCacheType(ct api.CacheTypeEnum) C.CCacheType { + switch ct { + case api.MEMORY: + return C.MEMORY + case api.DISK: + return C.DISK + default: + return C.MEMORY + } +} + +func cConsistencyType(ct api.ConsistencyTypeEnum) C.CConsistencyType { + switch ct { + case api.PRAM: + return C.PRAM + case api.CAUSAL: + return C.CAUSAL + default: + return C.PRAM + } +} + +func cStringOptional(str *string) (*C.char, C.char) { + if str == nil { + return nil, 0 + } + return C.CString(*str), 1 +} + +func cFunctionMeta(funcMeta api.FunctionMeta) *C.CFunctionMeta { + cName, hasName := cStringOptional(funcMeta.Name) + cNamespace, hasNamespace := cStringOptional(funcMeta.Namespace) + cFuncMeta := C.CFunctionMeta{ + appName: C.CString(funcMeta.AppName), + moduleName: nil, + funcName: C.CString(funcMeta.FuncName), + className: nil, + functionId: C.CString(funcMeta.FuncID), + languageType: C.int(funcMeta.Language), + signature: C.CString(funcMeta.Sig), + poolLabel: C.CString(funcMeta.PoolLabel), + apiType: C.CApiType(funcMeta.Api), + hasName: hasName, + name: cName, + hasNs: hasNamespace, + ns: cNamespace, + codeId: nil, + } + return &cFuncMeta +} + +func freeCFunctionMeta(cFuncMeta *C.CFunctionMeta) { + CSafeFree(cFuncMeta.appName) + CSafeFree(cFuncMeta.moduleName) + CSafeFree(cFuncMeta.funcName) + CSafeFree(cFuncMeta.className) + CSafeFree(cFuncMeta.functionId) + CSafeFree(cFuncMeta.signature) + CSafeFree(cFuncMeta.poolLabel) + CSafeFree(cFuncMeta.name) + CSafeFree(cFuncMeta.ns) + CSafeFree(cFuncMeta.codeId) +} + +func cArgs(args []api.Arg) (*C.CInvokeArg, C.int) { + var cArgs *C.CInvokeArg = nil + argsLen := len(args) + if argsLen != 0 { + cArgs = (*C.CInvokeArg)(C.malloc(C.size_t(argsLen) * C.sizeof_CInvokeArg)) + cArgsSlice := unsafe.Slice(cArgs, argsLen) + for idx, val := range args { + ptr, length := ByteSliceToCBinaryDataNoCopy(val.Data) + nestedPtr, nestedLen := CStrings(val.NestedObjectIDs) + cTenantId := C.CString(val.TenantID) + cArgsSlice[idx] = C.CInvokeArg{ + buf: ptr, + size_buf: C.int64_t(length), + isRef: checkIfRef(val.Type), + objId: nil, + nestedObjects: nestedPtr, + size_nestedObjects: nestedLen, + tenantId: cTenantId, + } + } + } + return cArgs, C.int(argsLen) +} + +func freeCArgs(cArgs *C.CInvokeArg, cArgsLen C.int) { + argsLen := int(cArgsLen) + if argsLen == 0 { + return + } + cArgsSlice := unsafe.Slice(cArgs, argsLen) + for idx, _ := range cArgsSlice { + freeCStrings(cArgsSlice[idx].nestedObjects, cArgsSlice[idx].size_nestedObjects) + C.free(unsafe.Pointer(cArgsSlice[idx].tenantId)) + } + C.free(unsafe.Pointer(cArgs)) +} + +func cAcquireOptions(acquireOpt api.InvokeOptions) *C.CInvokeOptions { + cSchedInstIDs, cSchedInstIDsLen := CStrings(acquireOpt.SchedulerInstanceIDs) + cAcquireOpt := C.CInvokeOptions{ + schedulerFunctionId: CSafeString(acquireOpt.SchedulerFunctionID), + schedulerInstanceIds: cSchedInstIDs, + size_schedulerInstanceIds: cSchedInstIDsLen, + traceId: CSafeString(acquireOpt.TraceID), + timeout: C.int(acquireOpt.Timeout), + acquireTimeout: C.int(acquireOpt.AcquireTimeout), + trafficLimited: C.char(btoi(acquireOpt.TrafficLimited)), + } + return &cAcquireOpt +} + +func cCInstanceAllocation(instanceAllocation api.InstanceAllocation) *C.CInstanceAllocation { + cInstanceAlloc := C.CInstanceAllocation{ + functionId: C.CString(instanceAllocation.FuncKey), + funcSig: C.CString(instanceAllocation.FuncSig), + instanceId: C.CString(instanceAllocation.InstanceID), + leaseId: C.CString(instanceAllocation.LeaseID), + tLeaseInterval: C.int(instanceAllocation.LeaseInterval), + } + return &cInstanceAlloc +} + +func freeCInstanceAllocation(cInstanceAllocation *C.CInstanceAllocation) { + CSafeFree(cInstanceAllocation.funcSig) + CSafeFree(cInstanceAllocation.functionId) + CSafeFree(cInstanceAllocation.instanceId) + CSafeFree(cInstanceAllocation.leaseId) +} + +func cInvokeOptions(invokeOpt api.InvokeOptions) *C.CInvokeOptions { + cRes, cResLen := cCustomResources(invokeOpt.CustomResources) + cExts, cExtsLen := cCustomExtensions(invokeOpt.CustomExtensions) + cCreate, cCreateLen := cCreateOpt(invokeOpt.CreateOpt) + cLabels, cLabelsLen := CStrings(invokeOpt.Labels) + cSchedAffs, cSchedAffsLen := cScheduleAffinities(invokeOpt.ScheduleAffinities) + cCodePaths, cCodePathsLen := CStrings(invokeOpt.CodePaths) + cSchedInstIDs, cSchedInstIDsLen := CStrings(invokeOpt.SchedulerInstanceIDs) + cIvkLabel, cIvkLabelLen := cInvokeLabels(invokeOpt.InvokeLabels) + cInvokeOpt := C.CInvokeOptions{ + cpu: C.int(invokeOpt.Cpu), + memory: C.int(invokeOpt.Memory), + customResources: cRes, + size_customResources: cResLen, + customExtensions: cExts, + size_customExtensions: cExtsLen, + createOpt: cCreate, + size_createOpt: cCreateLen, + labels: cLabels, + size_labels: cLabelsLen, + schedAffinities: cSchedAffs, + size_schedAffinities: cSchedAffsLen, + codePaths: cCodePaths, + size_codePaths: cCodePathsLen, + schedulerFunctionId: C.CString(invokeOpt.SchedulerFunctionID), + schedulerInstanceIds: cSchedInstIDs, + size_schedulerInstanceIds: cSchedInstIDsLen, + traceId: C.CString(invokeOpt.TraceID), + timeout: C.int(invokeOpt.Timeout), + acquireTimeout: C.int(invokeOpt.AcquireTimeout), + RetryTimes: C.int(invokeOpt.RetryTimes), + RecoverRetryTimes: C.int(invokeOpt.RecoverRetryTimes), + invokeLabels: cIvkLabel, + size_invokeLabels: cIvkLabelLen, + scheduleTimeoutMs: C.int64_t(invokeOpt.ScheduleTimeoutMs), + } + if invokeOpt.InstanceSession != nil { + cCInstanceSession := (*C.CInstanceSession)(C.malloc(C.sizeof_CInstanceSession)) + cCInstanceSession.sessionId = C.CString(invokeOpt.InstanceSession.SessionID) + cCInstanceSession.sessionTtl = C.int(invokeOpt.InstanceSession.SessionTTL) + cCInstanceSession.concurrency = C.int(invokeOpt.InstanceSession.Concurrency) + cInvokeOpt.instanceSession = cCInstanceSession + } + + return &cInvokeOpt +} + +func freeCInvokeOptions(cInvokeOpt *C.CInvokeOptions) { + freeCCustomResources(cInvokeOpt.customResources, cInvokeOpt.size_customResources) + freeCCustomExtensions(cInvokeOpt.customExtensions, cInvokeOpt.size_customExtensions) + freeCCreateOpt(cInvokeOpt.createOpt, cInvokeOpt.size_createOpt) + freeCStrings(cInvokeOpt.labels, cInvokeOpt.size_labels) + freeCScheduleAffinities(cInvokeOpt.schedAffinities, cInvokeOpt.size_schedAffinities) + freeCStrings(cInvokeOpt.codePaths, cInvokeOpt.size_codePaths) + CSafeFree(cInvokeOpt.schedulerFunctionId) + freeCStrings(cInvokeOpt.schedulerInstanceIds, cInvokeOpt.size_schedulerInstanceIds) + CSafeFree(cInvokeOpt.traceId) + freeCInvokeLabels(cInvokeOpt.invokeLabels, cInvokeOpt.size_invokeLabels) + if unsafe.Pointer(cInvokeOpt.instanceSession) != nil { + CSafeFree(cInvokeOpt.instanceSession.sessionId) + C.free(unsafe.Pointer(cInvokeOpt.instanceSession)) + } +} + +func cScheduleAffinities(schedAffinities []api.Affinity) (*C.CAffinity, C.int) { + length := len(schedAffinities) + if length == 0 { + return nil, 0 + } + cSchedAffs := (*C.CAffinity)(C.malloc(C.size_t(length) * C.sizeof_CAffinity)) + cSchedAffsSlice := unsafe.Slice(cSchedAffs, length) + for idx, val := range schedAffinities { + cSchedAffsSlice[idx].affKind = cAffinityKind(val.Kind) + cSchedAffsSlice[idx].affType = cAffinityType(val.Affinity) + cSchedAffsSlice[idx].preferredPrio = C.char(btoi(val.PreferredPriority)) + cSchedAffsSlice[idx].preferredAntiOtherLabels = C.char(btoi(val.PreferredAntiOtherLabels)) + cLabelOps, cLabelOpsLen := cLabelOperators(val.LabelOps) + cSchedAffsSlice[idx].labelOps = cLabelOps + cSchedAffsSlice[idx].size_labelOps = cLabelOpsLen + } + return cSchedAffs, (C.int)(length) +} + +func freeCScheduleAffinities(cSchedAffs *C.CAffinity, cLen C.int) { + length := int(cLen) + if length == 0 { + return + } + cSchedAffsSlice := unsafe.Slice(cSchedAffs, length) + for idx, _ := range cSchedAffsSlice { + freeCLabelOperators(cSchedAffsSlice[idx].labelOps, cSchedAffsSlice[idx].size_labelOps) + } + C.free(unsafe.Pointer(cSchedAffs)) +} + +func cLabelOperators(labelOps []api.LabelOperator) (*C.CLabelOperator, C.int) { + length := len(labelOps) + if length == 0 { + return nil, 0 + } + cLabelOps := (*C.CLabelOperator)(C.malloc(C.size_t(length) * C.sizeof_CLabelOperator)) + cLabelOpsSlice := unsafe.Slice(cLabelOps, length) + for idx, val := range labelOps { + cLabelOpsSlice[idx].opType = cLabelOpType(val.Type) + if val.LabelKey != "" { + cLabelOpsSlice[idx].labelKey = C.CString(val.LabelKey) + } else { + cLabelOpsSlice[idx].labelKey = nil + } + cLabelVals, cLabelValsLen := CStrings(val.LabelValues) + cLabelOpsSlice[idx].labelValues = cLabelVals + cLabelOpsSlice[idx].size_labelValues = cLabelValsLen + } + return cLabelOps, (C.int)(length) +} + +func freeCLabelOperators(cLabelOps *C.CLabelOperator, cLen C.int) { + length := int(cLen) + if length == 0 { + return + } + cLabelOpsSlice := unsafe.Slice(cLabelOps, length) + for idx := range length { + CSafeFree(cLabelOpsSlice[idx].labelKey) + freeCStrings(cLabelOpsSlice[idx].labelValues, cLabelOpsSlice[idx].size_labelValues) + } + C.free(unsafe.Pointer(cLabelOps)) +} + +func cLabelOpType(opType api.OperatorType) C.CLabelOpType { + switch opType { + case api.LabelOpIn: + return C.IN + case api.LabelOpNotIn: + return C.NOT_IN + case api.LabelOpExists: + return C.EXISTS + case api.LabelOpNotExists: + return C.NOT_EXISTS + default: + return C.IN + } +} + +func cAffinityKind(kind api.AffinityKindType) C.CAffinityKind { + switch kind { + case api.AffinityKindResource: + return C.RESOURCE + case api.AffinityKindInstance: + return C.INSTANCE + default: + return C.INSTANCE + } +} + +func cAffinityType(affType api.AffinityType) C.CAffinityType { + switch affType { + case api.PreferredAffinity: + return C.PREFERRED + case api.PreferredAntiAffinity: + return C.PREFERRED_ANTI + case api.RequiredAffinity: + return C.REQUIRED + case api.RequiredAntiAffinity: + return C.REQUIRED_ANTI + default: + return C.PREFERRED + } +} + +func cCustomResources(customResources map[string]float64) (*C.CCustomResource, C.int) { + length := len(customResources) + if length == 0 { + return nil, 0 + } + cCustomRscs := (*C.CCustomResource)(C.malloc(C.size_t(length) * C.sizeof_CCustomResource)) + cCustomRscsSlice := unsafe.Slice(cCustomRscs, length) + idx := 0 + for k, v := range customResources { + cCustomRscsSlice[idx] = C.CCustomResource{ + name: C.CString(k), + scalar: C.float(v), + } + idx++ + } + return cCustomRscs, C.int(length) +} + +func freeCCustomResources(cCustomRscs *C.CCustomResource, cLen C.int) { + length := int(cLen) + if length == 0 { + return + } + cCustomRscsSlice := unsafe.Slice(cCustomRscs, length) + for idx := range length { + CSafeFree(cCustomRscsSlice[idx].name) + } + C.free(unsafe.Pointer(cCustomRscs)) +} + +func cCustomExtensions(customExtensions map[string]string) (*C.CCustomExtension, C.int) { + length := len(customExtensions) + if length == 0 { + return nil, 0 + } + cCustomExts := (*C.CCustomExtension)(C.malloc(C.size_t(length) * C.sizeof_CCustomExtension)) + cCustomExtsSlice := unsafe.Slice(cCustomExts, length) + idx := 0 + for k, v := range customExtensions { + cCustomExtsSlice[idx] = C.CCustomExtension{ + key: C.CString(k), + value: C.CString(v), + } + idx++ + } + return cCustomExts, C.int(length) +} + +func freeCCustomExtensions(cCustomExts *C.CCustomExtension, cLen C.int) { + length := int(cLen) + if length == 0 { + return + } + cCustomExtsSlice := unsafe.Slice(cCustomExts, length) + for idx := range length { + CSafeFree(cCustomExtsSlice[idx].key) + CSafeFree(cCustomExtsSlice[idx].value) + } + C.free(unsafe.Pointer(cCustomExts)) +} + +func cInvokeLabels(customInvokeLabels map[string]string) (*C.CInvokeLabels, C.int) { + length := len(customInvokeLabels) + if length == 0 { + return nil, 0 + } + cCustomILs := (*C.CInvokeLabels)(C.malloc(C.size_t(length) * C.sizeof_CInvokeLabels)) + cCustomILsSlice := unsafe.Slice(cCustomILs, length) + idx := 0 + for k, v := range customInvokeLabels { + cCustomILsSlice[idx] = C.CInvokeLabels{ + key: C.CString(k), + value: C.CString(v), + } + idx++ + } + return cCustomILs, C.int(length) +} + +func freeCInvokeLabels(cCustomILs *C.CInvokeLabels, cLen C.int) { + length := int(cLen) + if length == 0 { + return + } + cCustomILsSlice := unsafe.Slice(cCustomILs, length) + for idx := range length { + CSafeFree(cCustomILsSlice[idx].key) + CSafeFree(cCustomILsSlice[idx].value) + } + C.free(unsafe.Pointer(cCustomILs)) +} + +func cCreateOpt(createOpt map[string]string) (*C.CCreateOpt, C.int) { + length := len(createOpt) + if length == 0 { + return nil, 0 + } + cCreate := (*C.CCreateOpt)(C.malloc(C.size_t(length) * C.sizeof_CCreateOpt)) + cCreateSlice := unsafe.Slice(cCreate, length) + idx := 0 + for k, v := range createOpt { + cCreateSlice[idx] = C.CCreateOpt{ + key: C.CString(k), + value: C.CString(v), + } + idx++ + } + return cCreate, C.int(length) +} + +func freeCCreateOpt(cCreate *C.CCreateOpt, cLen C.int) { + length := int(cLen) + if length == 0 { + return + } + cCreateSlice := unsafe.Slice(cCreate, length) + for idx := range length { + CSafeFree(cCreateSlice[idx].key) + CSafeFree(cCreateSlice[idx].value) + } + C.free(unsafe.Pointer(cCreate)) +} + +func cProducerConfig(producerConf api.ProducerConf) *C.CProducerConfig { + cProducerConf := C.CProducerConfig{ + delayFlushTime: C.int64_t(producerConf.DelayFlushTime), + pageSize: C.int64_t(producerConf.PageSize), + maxStreamSize: C.uint64_t(producerConf.MaxStreamSize), + traceId: C.CString(producerConf.TraceId), + } + return &cProducerConf +} + +func cSubscriptionConfig(subscriptionConf api.SubscriptionConfig) *C.CSubscriptionConfig { + cSubscriptionConf := C.CSubscriptionConfig{ + subscriptionName: C.CString(subscriptionConf.SubscriptionName), + traceId: C.CString(subscriptionConf.TraceId), + } + return &cSubscriptionConf +} + +// ByteSliceToCBinaryData Copy go byte slice to c binary data +func ByteSliceToCBinaryData(data []byte) (unsafe.Pointer, int) { + if len(data) == 0 { + return nil, 0 + } + return C.CBytes(data), len(data) +} + +// ByteSliceToCBinaryDataNoCopy convert go byte slice to c binary data with no copy +func ByteSliceToCBinaryDataNoCopy(data []byte) (unsafe.Pointer, int) { + if len(data) == 0 { + return nil, 0 + } + return unsafe.Pointer(&data[0]), len(data) +} + +// StringToCBinaryDataNoCopy - +func StringToCBinaryDataNoCopy(data string) (unsafe.Pointer, int) { + if len(data) == 0 { + return nil, 0 + } + p := (*reflect.StringHeader)(unsafe.Pointer(&data)) + return unsafe.Pointer(p.Data), len(data) +} + +// GoStrings Copy **C.char to go []string and free **C.char +func GoStrings(count C.int, values **C.char) []string { + length := int(count) + if length == 0 { + return make([]string, 0) + } + defer C.free(unsafe.Pointer(values)) + charSlice := unsafe.Slice(values, length) + strSlice := make([]string, length) + for idx, val := range charSlice { + strSlice[idx] = CSafeGoString(val) + defer C.free(unsafe.Pointer(val)) + } + return strSlice +} + +// GoStringsWithoutFree Copy **C.char to go []string without freeing **C.char +func GoStringsWithoutFree(count C.int, values **C.char) []string { + length := int(count) + if length == 0 { + return make([]string, 0) + } + charSlice := unsafe.Slice(values, length) + strSlice := make([]string, length) + for idx, val := range charSlice { + strSlice[idx] = CSafeGoString(val) + } + return strSlice +} + +// CStrings Copy go []string to **C.char +func CStrings(strs []string) (**C.char, C.int) { + strsLen := len(strs) + if strsLen == 0 { + return nil, 0 + } + cStrs := (**C.char)(C.malloc(C.size_t(strsLen) * C.size_t(unsafe.Sizeof(uintptr(0))))) + cStrsSlice := unsafe.Slice(cStrs, strsLen) + for idx, val := range strs { + cStrsSlice[idx] = C.CString(val) + } + return cStrs, C.int(strsLen) +} + +func freeCStrings(cStrs **C.char, cStrsLen C.int) { + strsLen := int(cStrsLen) + if strsLen == 0 { + return + } + cStrsSlice := unsafe.Slice(cStrs, strsLen) + for idx, _ := range cStrsSlice { + C.free(unsafe.Pointer(cStrsSlice[idx])) + } + C.free(unsafe.Pointer(cStrs)) +} + +func errtoCerr(e error) *C.CErrorInfo { + if e == nil { + ce := (*C.CErrorInfo)(C.malloc(C.sizeof_CErrorInfo)) + ce.code = 0 + ce.message = nil + ce.size_stackTracesInfo = 0 + return ce + } + return getCErrwithStackTrace(e) +} + +func goStackTraces(stackTracesInfo C.CStackTracesInfo) api.StackTracesInfo { + stackTracesLen := int(stackTracesInfo.size_stackTraces) + stacksSlice := unsafe.Slice(stackTracesInfo.stackTraces, stackTracesLen) + goStackTraces := make([]api.StackTrace, stackTracesLen) + for k, v := range stacksSlice { + goStackTraces[k].ClassName = CSafeGoString(v.className) + goStackTraces[k].MethodName = CSafeGoString(v.methodName) + goStackTraces[k].FileName = CSafeGoString(v.fileName) + goStackTraces[k].LineNumber = int64(v.lineNumber) + C.free(unsafe.Pointer(v.className)) + C.free(unsafe.Pointer(v.methodName)) + C.free(unsafe.Pointer(v.fileName)) + extensionLen := int(v.size_extensions) + if extensionLen <= 0 { + continue + } + goStackTraces[k].ExtensionInfo = make(map[string]string, extensionLen) + extensionsSlice := unsafe.Slice(v.extensions, extensionLen) + for _, val := range extensionsSlice { + key := CSafeGoString(val.key) + if key == "" { + continue + } + goStackTraces[k].ExtensionInfo[key] = CSafeGoString(val.value) + C.free(unsafe.Pointer(val.key)) + C.free(unsafe.Pointer(val.value)) + } + C.free(unsafe.Pointer(v.extensions)) + } + info := api.StackTracesInfo{ + Code: int(stackTracesInfo.code), + MCode: int(stackTracesInfo.mcode), + Message: CSafeGoString(stackTracesInfo.message), + StackTraces: goStackTraces, + } + C.free(unsafe.Pointer(stackTracesInfo.message)) + return info +} + +func getCErrwithStackTrace(e error) *C.CErrorInfo { + ce := (*C.CErrorInfo)(C.malloc(C.sizeof_CErrorInfo)) + ce.code = 2002 + ce.message = C.CString(e.Error()) + errInfo := api.TurnErrInfo(e) + if len(errInfo.StackTracesInfo.StackTraces) == 0 { + ce.size_stackTracesInfo = 0 + return ce + } + stsInfo := []api.StackTracesInfo{errInfo.StackTracesInfo} + stackTracesInfoLen := len(stsInfo) + cStackTracesInfo := (*C.CStackTracesInfo)(C.malloc(C.size_t(stackTracesInfoLen) * C.sizeof_CStackTracesInfo)) + cStackTracesInfoSlice := unsafe.Slice(cStackTracesInfo, stackTracesInfoLen) + for idx, val := range stsInfo { + length := len(val.StackTraces) + if length == 0 { + continue + } + cStackTraces := (*C.CStackTrace)(C.malloc(C.size_t(length) * C.sizeof_CStackTrace)) + cStackTracesSlice := unsafe.Slice(cStackTraces, length) + for n, stack := range val.StackTraces { + cStackTracesSlice[n].className = C.CString(stack.ClassName) + cStackTracesSlice[n].methodName = C.CString(stack.MethodName) + cStackTracesSlice[n].fileName = C.CString(stack.FileName) + cStackTracesSlice[n].lineNumber = C.int64_t(stack.LineNumber) + cExtensions, cExtensionsLen := cCustomExtensions(stack.ExtensionInfo) + cStackTracesSlice[n].extensions = cExtensions + cStackTracesSlice[n].size_extensions = cExtensionsLen + } + cStackTracesInfoSlice[idx].stackTraces = &cStackTracesSlice[0] + cStackTracesInfoSlice[idx].size_stackTraces = C.int(length) + cStackTracesInfoSlice[idx].message = C.CString(val.Message) + } + ce.stackTracesInfo = &cStackTracesInfoSlice[0] + ce.size_stackTracesInfo = C.int(stackTracesInfoLen) + return ce +} + +// GoCredential transform *C.CCredential to api.Credential +func GoCredential(cCredential C.CCredential) api.Credential { + return api.Credential{ + AccessKey: CSafeGoString(cCredential.ak), + SecretKey: CSafeGoBytes(cCredential.sk, cCredential.sizeSk), + DataKey: CSafeGoBytes(cCredential.dk, cCredential.sizeDk), + } +} + +// GoFunctionMeta transform *C.CFunctionMeta to api.FunctionMeta +func GoFunctionMeta(funcMeta *C.CFunctionMeta) api.FunctionMeta { + return api.FunctionMeta{ + AppName: CSafeGoString(funcMeta.appName), + FuncName: CSafeGoString(funcMeta.funcName), + FuncID: CSafeGoString(funcMeta.functionId), + } +} + +// GoInvokeType transform C.CInvokeType to config.InvokeType +func GoInvokeType(invokeType C.CInvokeType) config.InvokeType { + return config.InvokeType(invokeType) +} + +// GoArgs transform *C.CArg to []api.Arg +func GoArgs(args *C.CArg, argsSize C.int) []api.Arg { + length := int(argsSize) + argsSlice := unsafe.Slice(args, length) + goArgs := make([]api.Arg, length) + for idx, val := range argsSlice { + goArgs[idx].Type = api.Value + goArgs[idx].Data = unsafe.Slice((*byte)(unsafe.Pointer(val.data)), int(val.size)) + } + return goArgs +} + +// GoDataObject transform *C.CDataObject to []config.DataObject +func GoDataObject(returnObjs *C.CDataObject, returnObjsSize C.int) []config.DataObject { + length := int(returnObjsSize) + goSlice := make([]config.DataObject, length, length) + if length == 0 { + return goSlice + } + returnObjsSlice := unsafe.Slice(returnObjs, length) + for idx, val := range returnObjsSlice { + goSlice[idx].ID = CSafeGoString(val.id) + goSlice[idx].Buffer = config.DataBuffer{ + Ptr: val.buffer.buffer, + Size: int(val.buffer.size_buffer), + } + goSlice[idx].CSharedPtr = val.selfSharedPtr + } + return goSlice +} + +// GoLoadFunctions is exported as a C function for calling from C/C++ code. +// The purpose is to execute the load callback function of go. +// +//export GoLoadFunctions +func GoLoadFunctions(codePaths **C.char, codePathsSize C.int) *C.CErrorInfo { + if int(codePathsSize) == 0 { + return errtoCerr(fmt.Errorf("codePaths empty")) + } + + err := cfg.Hooks.LoadFunctionCb(GoStringsWithoutFree(codePathsSize, codePaths)) + cErr := errtoCerr(err) + if err != nil { + fmt.Println("failed to load function ", err.Error()) + } else { + fmt.Println("succeed to load function ") + } + return cErr +} + +// GoFunctionExecution is exported as a C function for calling from C/C++ code. +// The purpose is to execute the function execution callback function of go. +// +//export GoFunctionExecution +func GoFunctionExecution( + funcMeta *C.CFunctionMeta, invokeType C.CInvokeType, args *C.CArg, argsSize C.int, returnObjs *C.CDataObject, + returnObjsSize C.int, +) *C.CErrorInfo { + goFuncMeta := GoFunctionMeta(funcMeta) + goInvokeType := GoInvokeType(invokeType) + goArgs := GoArgs(args, argsSize) + goRetObjs := GoDataObject(returnObjs, returnObjsSize) + err := cfg.Hooks.FunctionExecutionCb(goFuncMeta, goInvokeType, goArgs, goRetObjs) + return errtoCerr(err) +} + +// GoCheckpoint is exported as a C function for calling from C/C++ code. +// The purpose is to execute the checkpoint callback function of go. +// +//export GoCheckpoint +func GoCheckpoint(checkpointID *C.char, buffer *C.CBuffer) *C.CErrorInfo { + goCkptId := CSafeGoString(checkpointID) + data, err := cfg.Hooks.CheckpointCb(goCkptId) + buf, length := ByteSliceToCBinaryData(data) + buffer.buffer = buf + buffer.size_buffer = C.int64_t(length) + return errtoCerr(err) +} + +// GoRecover is exported as a C function for calling from C/C++ code. +// The purpose is to execute the recover callback function of go. +// +//export GoRecover +func GoRecover(buffer *C.CBuffer) *C.CErrorInfo { + data := C.GoBytes(buffer.buffer, C.int(buffer.size_buffer)) + err := cfg.Hooks.RecoverCb(data) + return errtoCerr(err) +} + +// GoShutdown is exported as a C function for calling from C/C++ code. +// The purpose is to execute the shutdown callback function of go. +// +//export GoShutdown +func GoShutdown(gracePeriodSec C.uint64_t) *C.CErrorInfo { + err := cfg.Hooks.ShutdownCb(uint64(gracePeriodSec)) + return errtoCerr(err) +} + +// GoSignal is exported as a C function for calling from C/C++ code. +// The purpose is to execute the signal callback function of go. +// +//export GoSignal +func GoSignal(sigNo C.int, payload *C.CBuffer) *C.CErrorInfo { + payloadSlice := C.GoBytes(payload.buffer, C.int(payload.size_buffer)) + err := cfg.Hooks.SignalCb(int(sigNo), payloadSlice) + return errtoCerr(err) +} + +// GoHealthCheck is exported as a C function for calling from C/C++ code. +// The purpose is to execute the health check callback function of go. +// +//export GoHealthCheck +func GoHealthCheck() C.CHealthCheckCode { + if cfg.Hooks.HealthCheckCb != nil { + status, _ := cfg.Hooks.HealthCheckCb() + if status == api.Healthy { + return C.HEALTHY + } else if status == api.HealthCheckFailed { + return C.HEALTH_CHECK_FAILED + } else if status == api.SubHealth { + return C.SUB_HEALTH + } + } + return C.HEALTHY +} + +// GoHasHealthCheck is to used for check if health check handler is valid +// +//export GoHasHealthCheck +func GoHasHealthCheck() C.char { + if cfg.Hooks.HealthCheckCb != nil { + return C.char(1) + } + return C.char(0) +} + +// GoFunctionExecutionPoolSubmit is used to submit task to golang routines pool from libruntime +// +//export GoFunctionExecutionPoolSubmit +func GoFunctionExecutionPoolSubmit(f unsafe.Pointer) { + cfg.FunctionExectionPool.Submit( + func() { + C.CFunctionExecution(f) + }, + ) +} + +// GoEles - +func GoEles(pEles *C.CElement, num C.uint64_t) []api.Element { + length := uint64(num) + elesSlice := unsafe.Slice(pEles, length) + defer C.free(unsafe.Pointer(pEles)) + goEles := make([]api.Element, length) + for idx, val := range elesSlice { + // The content in the ptr will be released after ack. + goEles[idx].Ptr = (*uint8)(unsafe.Pointer(val.ptr)) + goEles[idx].Size = uint64(val.size) + goEles[idx].Id = uint64(val.id) + } + return goEles +} + +// IsHealth - +func IsHealth() bool { + res := int(C.CIsHealth()) + return itob(res) +} + +// IsDsHealth - +func IsDsHealth() bool { + res := int(C.CIsDsHealth()) + return itob(res) +} diff --git a/api/go/libruntime/clibruntime/clibruntime_test.go b/api/go/libruntime/clibruntime/clibruntime_test.go new file mode 100644 index 0000000..4c46c90 --- /dev/null +++ b/api/go/libruntime/clibruntime/clibruntime_test.go @@ -0,0 +1,1213 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package clibruntime +This package encapsulates all cgo invocations. +*/ +package clibruntime + +import ( + "errors" + "fmt" + "strings" + "sync" + "testing" + "unsafe" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/uuid" + "yuanrong.org/kernel/runtime/libruntime/config" +) + +func GetConnectArguments() api.ConnectArguments { + var conf api.ConnectArguments + conf.Host = "127.0.0.1" + conf.Port = 11111 + conf.TimeoutMs = 500 + conf.Token = []byte{'1', '2', '3'} + conf.ClientPublicKey = "client pub key" + conf.ClientPrivateKey = []byte{'1', '2', '3'} + conf.ServerPublicKey = "server pub key" + conf.AccessKey = "access key" + conf.SecretKey = []byte{'1', '2', '3'} + conf.AuthclientID = "auth client id" + conf.AuthclientSecret = []byte{'1', '2', '3'} + conf.AuthURL = "auth url" + conf.TenantID = "tenant id" + conf.EnableCrossNodeConnection = true + return conf +} + +func TestInit(t *testing.T) { + conf := config.Config{} + id := uuid.New() + conf.JobID = fmt.Sprintf("job-%s", strings.ReplaceAll(id.String(), "-", "")[:8]) + err := Init(conf) + if err != nil { + t.Errorf("test Init failed") + } +} + +func TestCheckNil(t *testing.T) { + convey.Convey( + "Test kvClientImplCheckNil", t, func() { + convey.Convey( + "kvClientCheckNil success", func() { + var ptr *KvClientImpl + err := kvClientCheckNil(ptr) + convey.So(err, convey.ShouldNotBeNil) + ptr = &KvClientImpl{} + err = kvClientCheckNil(ptr) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestKvClientImpl_CreateClient(t *testing.T) { + convey.Convey( + "Test CreateClient", t, func() { + convey.Convey( + "create client should success", func() { + conf := GetConnectArguments() + newClient, err := CreateClient(conf) + defer newClient.DestroyClient() + convey.So(err, convey.ShouldBeNil) + convey.So(newClient, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestKvClientImpl_SetTraceID(t *testing.T) { + convey.Convey( + "Test SetTraceID", t, func() { + conf := GetConnectArguments() + client, _ := CreateClient(conf) + defer client.DestroyClient() + convey.Convey( + "set traceID should success", func() { + convey.So(func() { client.SetTraceID("traceId") }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestKvClientImpl_GenerateKey(t *testing.T) { + convey.Convey( + "Test GenerateKey", t, func() { + client, _ := CreateClient(GetConnectArguments()) + defer client.DestroyClient() + convey.Convey( + "generate key should return a key", func() { + key := client.GenerateKey() + convey.So(key, convey.ShouldBeEmpty) + }, + ) + }, + ) +} + +func TestKvClientImpl_Set(t *testing.T) { + convey.Convey( + "Test Set", t, func() { + client, _ := CreateClient(GetConnectArguments()) + defer client.DestroyClient() + value := "value" + var param api.SetParam + convey.Convey( + "set key value", func() { + rightKey := "rightKey" + convey.Convey( + "mock a right key", func() { + status := client.KVSet(rightKey, []byte(value), param) + convey.So(status.IsOk(), convey.ShouldBeTrue) + }, + ) + }, + ) + }, + ) +} + +func TestGetCredential(t *testing.T) { + convey.Convey( + "Test GoFunctionExecutionPoolSubmit", t, func() { + credential := GetCredential() + convey.So(credential.AccessKey, convey.ShouldBeEmpty) + convey.So(credential.SecretKey, convey.ShouldBeEmpty) + convey.So(credential.DataKey, convey.ShouldBeEmpty) + }, + ) +} + +func TestKvClientImpl_SetValue(t *testing.T) { + convey.Convey( + "Test SetValue", t, func() { + client, _ := CreateClient(GetConnectArguments()) + defer client.DestroyClient() + + var param api.SetParam + convey.Convey( + "set value return key", func() { + rightValue := "value" + convey.Convey( + "set a not empty value", func() { + key, status := client.KVSetWithoutKey([]byte(rightValue), param) + convey.So(key, convey.ShouldBeEmpty) + convey.So(status.IsOk(), convey.ShouldBeTrue) + }, + ) + + emptyValue := "" + convey.Convey( + "set emptyValue return empty key", func() { + key, _ := client.KVSetWithoutKey([]byte(emptyValue), param) + convey.So(key, convey.ShouldBeEmpty) + }, + ) + }, + ) + }, + ) +} + +func TestKvClientImpl_Get(t *testing.T) { + convey.Convey( + "Test Get", t, func() { + client, _ := CreateClient(GetConnectArguments()) + defer client.DestroyClient() + + convey.Convey( + "use a key to get value", func() { + rightKey := "rightKey" + convey.Convey( + "get a rightKey", func() { + _, status := client.KVGet(rightKey, 1) + convey.So(status.IsOk(), convey.ShouldBeTrue) + }, + ) + }, + ) + }, + ) +} + +func TestKvClientImpl_GetArray(t *testing.T) { + convey.Convey( + "Test GetArray", t, func() { + client, _ := CreateClient(GetConnectArguments()) + defer client.DestroyClient() + + convey.Convey( + "use a keys to get values", func() { + + rightKeys := []string{"rightKey"} + convey.Convey( + "get values with rightKey", func() { + values, status := client.KVGetMulti(rightKeys, 1) + convey.So(status.IsOk(), convey.ShouldBeTrue) + convey.So(len(values), convey.ShouldEqual, len(rightKeys)) + }, + ) + }, + ) + }, + ) +} + +func TestKvClientImpl_QuerySize(t *testing.T) { + convey.Convey( + "Test QuerySize", t, func() { + client, _ := CreateClient(GetConnectArguments()) + defer client.DestroyClient() + + convey.Convey( + "use keys to query value sizes", func() { + + queryKeys := []string{"queryKeys"} + convey.Convey( + "get values with queryKeys", func() { + values, status := client.KVQuerySize(queryKeys) + convey.So(status.IsOk(), convey.ShouldBeTrue) + convey.So(len(values), convey.ShouldEqual, len(queryKeys)) + }, + ) + }, + ) + }, + ) +} + +func TestKvClientImpl_Del(t *testing.T) { + convey.Convey( + "Test Del", t, func() { + client, _ := CreateClient(GetConnectArguments()) + defer client.DestroyClient() + + convey.Convey( + "del key value", func() { + rightKey := "rightKey" + convey.Convey( + "del a rightKey", func() { + status := client.KVDel(rightKey) + convey.So(status.IsOk(), convey.ShouldBeTrue) + }, + ) + }, + ) + }, + ) +} + +func TestKvClientImpl_complex(t *testing.T) { + convey.Convey("concurrency", t, func() { + convey.So(func() { + client, _ := CreateClient(GetConnectArguments()) + + wg := sync.WaitGroup{} + + wg.Add(4) + go func() { + for i := 0; i < 10; i++ { + client.KVDel("1") + } + wg.Done() + }() + go func() { + for i := 0; i < 10; i++ { + client.KVGet("1", 10) + } + wg.Done() + }() + + go func() { + for i := 0; i < 10; i++ { + client.GenerateKey() + } + wg.Done() + }() + go func() { + client.DestroyClient() + client.DestroyClient() + wg.Done() + }() + wg.Wait() + }, convey.ShouldNotPanic) + }) +} + +func TestKvClientImpl_DelArray(t *testing.T) { + convey.Convey( + "Test del array", t, func() { + client, _ := CreateClient(GetConnectArguments()) + defer client.DestroyClient() + + convey.Convey( + "del keys", func() { + rightKeys := []string{"key1", "key2"} + convey.Convey( + "delete right keys", func() { + values, status := client.KVDelMulti(rightKeys) + convey.So(status.IsOk(), convey.ShouldBeTrue) + convey.So(len(values), convey.ShouldEqual, 0) + }, + ) + }, + ) + }, + ) +} + +func TestKvClientImpl_DestroyClient(t *testing.T) { + convey.Convey( + "Test destroy client", t, func() { + client, _ := CreateClient(GetConnectArguments()) + client.DestroyClient() + + convey.Convey( + "after destroy use client should be safe", func() { + status := client.KVSet("", []byte{}, api.SetParam{}) + convey.So(status.Code, convey.ShouldEqual, api.DsClientNilError) + key, status := client.KVSetWithoutKey([]byte{}, api.SetParam{}) + convey.So(status.Code, convey.ShouldEqual, api.DsClientNilError) + convey.So(key, convey.ShouldBeEmpty) + }, + ) + + convey.Convey( + "repeat destroy client should not panic", func() { + client.DestroyClient() + }, + ) + + }, + ) +} + +func TestCreateStreamProducer(t *testing.T) { + convey.Convey( + "Test create stream producer", t, func() { + convey.Convey( + "create stream producer success", func() { + producer, err := CreateStreamProducer("stream_001", api.ProducerConf{}) + convey.So(err, convey.ShouldBeNil) + convey.So(producer, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestProducerSendAndFlush(t *testing.T) { + convey.Convey( + "Test producer send", t, func() { + producer, err := CreateStreamProducer("stream_001", api.ProducerConf{}) + convey.So(err, convey.ShouldBeNil) + convey.So(producer, convey.ShouldNotBeNil) + convey.Convey( + "producer send", func() { + ele := api.Element{ + Ptr: nil, + Size: 0, + Id: 0, + } + err = producer.Send(ele) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "producer sendWithTimeout", func() { + ele := api.Element{ + Ptr: nil, + Size: 0, + Id: 0, + } + err = producer.SendWithTimeout(ele, 1) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "producer flush", func() { + err = producer.Flush() + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "producer close", func() { + err = producer.Close() + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestStreamConsumerImpl(t *testing.T) { + convey.Convey( + "Test StreamConsumerImpl", t, func() { + consumer, err := CreateStreamConsumer("stream_001", api.SubscriptionConfig{}) + convey.Convey( + "CreateStreamConsumer success", func() { + convey.So(consumer, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Receive success", func() { + eles, err := consumer.Receive(3000) + convey.So(eles, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "ReceiveExpectNum success", func() { + eles, err := consumer.ReceiveExpectNum(1, 3000) + convey.So(eles, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Ack success", func() { + err = consumer.Ack(1) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Close success", func() { + err = consumer.Close() + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestSetTenantID(t *testing.T) { + convey.Convey( + "Test SetTraceID", t, func() { + SetTenantID("tenantId") + convey.SkipSo() + }, + ) +} + +func TestKillInstance(t *testing.T) { + convey.Convey( + "Test KillInstance successfully", t, func() { + err := Kill("instanceid", 1, []byte{}) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Test KillInstance failed", t, func() { + err := Kill("instanceid", 128, []byte{}) + convey.So(err.Error(), convey.ShouldEqual, "kill instance: failed to kill") + }, + ) +} + +func TestCScheduleAffinities(t *testing.T) { + convey.Convey( + "Test cScheduleAffinities", t, func() { + convey.Convey( + "Test convert []api.Affinity to *C.CAffinity", func() { + goLength := 2 + affinities := make([]api.Affinity, goLength) + affinities[0].Affinity = api.PreferredAffinity + cAffinities, length := cScheduleAffinities(affinities) + convey.So(length, convey.ShouldEqual, 2) + cSchedAffsSlice := unsafe.Slice(cAffinities, 2) + convey.So(int(cSchedAffsSlice[0].size_labelOps), convey.ShouldEqual, 0) + freeCScheduleAffinities(cAffinities, length) + }, + ) + }, + ) +} + +func TestCheckIfRef(t *testing.T) { + convey.Convey( + "Test checkIfRef", t, func() { + convey.Convey( + "checkIfRef success when t==api.ObjectRef", func() { + cChar := checkIfRef(api.ObjectRef) + convey.So(cChar, convey.ShouldEqual, 1) + }, + ) + }, + ) +} + +func TestApiTypeToCApiType(t *testing.T) { + convey.Convey( + "Test apiTypeToCApiType", t, func() { + convey.Convey( + "apiTypeToCApiType success when apiType ==api.FaaSApi", func() { + cApiType := apiTypeToCApiType(api.FaaSApi) + convey.So(cApiType, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "apiTypeToCApiType success when apiType ==api.PosixApi", func() { + cApiType := apiTypeToCApiType(api.PosixApi) + convey.So(cApiType, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "apiTypeToCApiType success when default", func() { + cApiType := apiTypeToCApiType(api.ApiType(9)) + convey.So(cApiType, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestReceiveRequestLoop(t *testing.T) { + convey.Convey( + "Test ReceiveRequestLoop", t, func() { + ReceiveRequestLoop() + convey.So(ReceiveRequestLoop, convey.ShouldNotPanic) + }, + ) +} + +func TestCreateInstance(t *testing.T) { + convey.Convey( + "Test CreateInstance", t, func() { + str, err := CreateInstance(api.FunctionMeta{}, []api.Arg{}, api.InvokeOptions{}) + convey.So(str, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldNotBeNil) + }, + ) +} + +func TestInvokeByInstanceId(t *testing.T) { + convey.Convey( + "Test InvokeByInstanceId", t, func() { + str, err := InvokeByInstanceId(api.FunctionMeta{}, "", []api.Arg{}, api.InvokeOptions{}) + convey.So(str, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestInvokeByFunctionName(t *testing.T) { + convey.Convey( + "Test InvokeByFunctionName", t, func() { + str, err := InvokeByFunctionName(api.FunctionMeta{}, []api.Arg{}, api.InvokeOptions{}) + convey.So(str, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestAcquireInstance(t *testing.T) { + convey.Convey( + "Test AcquireInstance", t, func() { + allocation, err := AcquireInstance("", api.FunctionMeta{}, api.InvokeOptions{}) + convey.So(allocation, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestReleaseInstance(t *testing.T) { + convey.Convey( + "Test ReleaseInstance", t, func() { + convey.So(func() { + ReleaseInstance(api.InstanceAllocation{}, "", false, api.InvokeOptions{}) + }, convey.ShouldNotPanic) + }, + ) +} + +func TestGetAsync(t *testing.T) { + convey.Convey( + "Test GetAsync", t, func() { + convey.So(func() { + GetAsync("", nil) + }, convey.ShouldNotPanic) + }, + ) +} + +func TestDeleteStream(t *testing.T) { + convey.Convey( + "Test DeleteStream", t, func() { + err := DeleteStream("streamName") + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestQueryGlobalProducersNum(t *testing.T) { + convey.Convey( + "Test QueryGlobalProducersNum", t, func() { + n, err := QueryGlobalProducersNum("streamName") + convey.So(n, convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestQueryGlobalConsumersNum(t *testing.T) { + convey.Convey( + "Test QueryGlobalConsumersNum", t, func() { + n, err := QueryGlobalConsumersNum("streamName") + convey.So(n, convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestGetGetAsyncCallback(t *testing.T) { + convey.Convey( + "Test getGetAsyncCallback", t, func() { + convey.Convey( + "getGetAsyncCallback success", func() { + cb, err := getGetAsyncCallback("objectID") + convey.So(cb, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "getGetAsyncCallback success when ok==true", func() { + cb, err := getGetAsyncCallback("") + convey.So(cb, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestGetRawCallback(t *testing.T) { + convey.Convey( + "Test getRawCallback", t, func() { + convey.Convey( + "getGetAsyncCallback success", func() { + cb, flag := getRawCallback("key") + convey.So(cb, convey.ShouldBeNil) + convey.So(flag, convey.ShouldBeFalse) + }, + ) + }, + ) +} + +func TestCreateInstanceRaw(t *testing.T) { + convey.Convey( + "Test InstanceRaw", t, func() { + convey.Convey( + "InstanceRaw success", func() { + convey.So(func() { + go CreateInstanceRaw([]byte{0}) + go InvokeByInstanceIdRaw([]byte{0}) + go KillRaw([]byte{0}) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestFinalize(t *testing.T) { + convey.Convey( + "Test Finalize", t, func() { + convey.So(Finalize, convey.ShouldNotPanic) + }, + ) +} + +func TestKVSet(t *testing.T) { + convey.Convey( + "Test KVSet", t, func() { + err := KVSet("key", []byte{0}, api.SetParam{}) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestKVGet(t *testing.T) { + convey.Convey( + "Test KVGet", t, func() { + bytes, err := KVGet("key", 3000) + convey.So(bytes, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestKVGetMulti(t *testing.T) { + convey.Convey( + "Test KVGetMulti", t, func() { + bytesArr, err := KVGetMulti([]string{"key"}, 3000) + convey.So(bytesArr, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestKVDel(t *testing.T) { + convey.Convey( + "Test KVDel", t, func() { + err := KVDel("key") + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestKVDelMulti(t *testing.T) { + convey.Convey( + "Test KVDelMulti", t, func() { + strs, err := KVDelMulti([]string{"key"}) + convey.So(strs, convey.ShouldBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestPut(t *testing.T) { + convey.Convey( + "Test Put", t, func() { + err := Put("objectID ", []byte("value"), api.PutParam{}, "nestedObjectIDs") + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestPutRaw(t *testing.T) { + convey.Convey( + "Test PutRaw", t, func() { + err := PutRaw("objectID ", []byte("value"), api.PutParam{}, "nestedObjectIDs") + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestGet(t *testing.T) { + convey.Convey( + "Test Get", t, func() { + bytesArr, err := Get([]string{"key"}, 3000) + convey.So(bytesArr, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestGetRaw(t *testing.T) { + convey.Convey( + "Test GetRaw", t, func() { + bytesArr, err := GetRaw([]string{"key"}, 3000) + convey.So(bytesArr, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestWait(t *testing.T) { + convey.Convey( + "Test Wait", t, func() { + readyObjectIds, unReadyObjectIds, errs := Wait([]string{"objectIDs"}, 1, 3000) + convey.So(readyObjectIds, convey.ShouldBeEmpty) + convey.So(unReadyObjectIds, convey.ShouldBeEmpty) + convey.So(errs, convey.ShouldBeNil) + }, + ) +} + +func TestGIncreaseRef(t *testing.T) { + convey.Convey( + "Test GIncreaseRef", t, func() { + strs, err := GIncreaseRef([]string{"objectID"}, "remoteClientID") + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestGIncreaseRefRaw(t *testing.T) { + convey.Convey( + "Test GIncreaseRefRaw", t, func() { + strs, err := GIncreaseRefRaw([]string{"objectID"}, "remoteClientID") + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestGDecreaseRef(t *testing.T) { + convey.Convey( + "Test GDecreaseRef", t, func() { + strs, err := GDecreaseRef([]string{"objectID"}, "remoteClientID") + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestGDecreaseRefRaw(t *testing.T) { + convey.Convey( + "Test GDecreaseRefRaw", t, func() { + strs, err := GDecreaseRefRaw([]string{"objectID"}, "remoteClientID") + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestReleaseGRefs(t *testing.T) { + convey.Convey( + "Test ReleaseGRefs", t, func() { + err := ReleaseGRefs("remoteClientID") + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestAllocReturnObject(t *testing.T) { + convey.Convey( + "Test AllocReturnObject", t, func() { + var size uint = 8 + err := AllocReturnObject(&config.DataObject{}, 8, []string{"nestedIds"}, &size) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestSetReturnObject(t *testing.T) { + convey.Convey( + "Test SetReturnObject", t, func() { + convey.So(func() { + SetReturnObject(&config.DataObject{}, 8) + }, convey.ShouldNotPanic) + }, + ) +} + +func TestWriterLatch(t *testing.T) { + convey.Convey( + "Test WriterLatch", t, func() { + err := WriterLatch(&config.DataObject{}) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestMemoryCopy(t *testing.T) { + convey.Convey( + "Test MemoryCopy", t, func() { + err := MemoryCopy(&config.DataObject{}, []byte{0}) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestSeal(t *testing.T) { + convey.Convey( + "Test Seal", t, func() { + err := Seal(&config.DataObject{}) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestWriterUnlatch(t *testing.T) { + convey.Convey( + "Test WriterUnlatch", t, func() { + err := WriterUnlatch(&config.DataObject{}) + convey.So(err, convey.ShouldBeNil) + }, + ) +} + +func TestCWriteMode(t *testing.T) { + convey.Convey( + "Test cWriteMode", t, func() { + convey.Convey( + "WriteThroughL2Cache success", func() { + cWriteMode := cWriteMode(api.WriteThroughL2Cache) + convey.So(cWriteMode, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "WriteBackL2Cache success", func() { + cWriteMode := cWriteMode(api.WriteBackL2Cache) + convey.So(cWriteMode, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "NoneL2CacheEvict success", func() { + cWriteMode := cWriteMode(api.NoneL2CacheEvict) + convey.So(cWriteMode, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "default success", func() { + cWriteMode := cWriteMode(api.WriteModeEnum(9)) + convey.So(cWriteMode, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestCExistenceOpt(t *testing.T) { + convey.Convey( + "Test cExistenceOpt", t, func() { + convey.Convey( + "cExistenceOpt success", func() { + cExistenceOpt := cExistenceOpt(1) + convey.So(cExistenceOpt, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestCStringOptional(t *testing.T) { + convey.Convey( + "Test cStringOptional", t, func() { + convey.Convey( + "cStringOptional success", func() { + str := "str" + charPtr, char := cStringOptional(&str) + convey.So(charPtr, convey.ShouldNotBeNil) + convey.So(char, convey.ShouldEqual, 1) + }, + ) + }, + ) +} + +func TestCArgs(t *testing.T) { + convey.Convey( + "Test cArgs", t, func() { + convey.Convey( + "cArgs success", func() { + cInvokeArg, cArgsLen := cArgs([]api.Arg{api.Arg{}, api.Arg{}}) + convey.So(cInvokeArg, convey.ShouldNotBeNil) + convey.So(cArgsLen, convey.ShouldNotBeZeroValue) + convey.So(func() { + freeCArgs(cInvokeArg, cArgsLen) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestCLabelOperators(t *testing.T) { + convey.Convey( + "Test cLabelOperators", t, func() { + convey.Convey( + "cLabelOperators success", func() { + cLabelOperator, cLen := cLabelOperators([]api.LabelOperator{api.LabelOperator{}}) + convey.So(cLabelOperator, convey.ShouldNotBeNil) + convey.So(cLen, convey.ShouldNotBeZeroValue) + convey.So(func() { + freeCLabelOperators(cLabelOperator, cLen) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestCLabelOpType(t *testing.T) { + convey.Convey( + "Test cLabelOpType", t, func() { + convey.Convey( + "LabelOpNotIn success", func() { + cLabelOpType := cLabelOpType(api.LabelOpNotIn) + convey.So(cLabelOpType, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "LabelOpExists success", func() { + cLabelOpType := cLabelOpType(api.LabelOpExists) + convey.So(cLabelOpType, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "LabelOpNotExists success", func() { + cLabelOpType := cLabelOpType(api.LabelOpNotExists) + convey.So(cLabelOpType, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "default success", func() { + cLabelOpType := cLabelOpType(api.OperatorType(9)) + convey.So(cLabelOpType, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestCAffinityKind(t *testing.T) { + convey.Convey( + "Test cAffinityKind", t, func() { + convey.Convey( + "AffinityKindInstance success", func() { + cAffinityKind := cAffinityKind(api.AffinityKindInstance) + convey.So(cAffinityKind, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "default success", func() { + cAffinityKind := cAffinityKind(api.AffinityKindType(9)) + convey.So(cAffinityKind, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestCAffinityType(t *testing.T) { + convey.Convey( + "Test cAffinityType", t, func() { + convey.Convey( + "PreferredAntiAffinity success", func() { + cAffinityType := cAffinityType(api.PreferredAntiAffinity) + convey.So(cAffinityType, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "RequiredAffinity success", func() { + cAffinityType := cAffinityType(api.RequiredAffinity) + convey.So(cAffinityType, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "RequiredAntiAffinity success", func() { + cAffinityType := cAffinityType(api.RequiredAntiAffinity) + convey.So(cAffinityType, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "default success", func() { + cAffinityType := cAffinityType(api.AffinityType(9)) + convey.So(cAffinityType, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestCCustomResources(t *testing.T) { + convey.Convey( + "Test cCustomResources", t, func() { + convey.Convey( + "cCustomResources success", func() { + cCustomResource, cLen := cCustomResources(map[string]float64{"k": 1.5}) + convey.So(cCustomResource, convey.ShouldNotBeNil) + convey.So(cLen, convey.ShouldNotBeZeroValue) + }, + ) + }, + ) +} + +func TestCCustomExtensions(t *testing.T) { + convey.Convey( + "Test cCustomExtensions", t, func() { + convey.Convey( + "cCustomExtensions success", func() { + cCustomExtension, cLen := cCustomExtensions(map[string]string{"k": "v"}) + convey.So(cCustomExtension, convey.ShouldNotBeNil) + convey.So(cLen, convey.ShouldNotBeZeroValue) + }, + ) + }, + ) +} + +func TestCCreateOpt(t *testing.T) { + convey.Convey( + "Test cCreateOpt", t, func() { + convey.Convey( + "cCreateOpt success", func() { + cCreateOpt, cLen := cCreateOpt(map[string]string{"k": "v"}) + convey.So(cCreateOpt, convey.ShouldNotBeNil) + convey.So(cLen, convey.ShouldNotBeZeroValue) + }, + ) + }, + ) +} + +func TestByteSliceToCBinaryData(t *testing.T) { + convey.Convey( + "Test ByteSliceToCBinaryData", t, func() { + convey.Convey( + "ByteSliceToCBinaryData success", func() { + ptr, len := ByteSliceToCBinaryData([]byte{0}) + convey.So(ptr, convey.ShouldNotBeNil) + convey.So(len, convey.ShouldNotBeZeroValue) + }, + ) + }, + ) +} + +func TestStringToCBinaryDataNoCopy(t *testing.T) { + convey.Convey( + "Test StringToCBinaryDataNoCopy", t, func() { + convey.Convey( + "StringToCBinaryDataNoCopy success", func() { + ptr, len := StringToCBinaryDataNoCopy("data") + convey.So(ptr, convey.ShouldNotBeNil) + convey.So(len, convey.ShouldNotBeZeroValue) + }, + ) + convey.Convey( + "StringToCBinaryDataNoCopy success when len(data)==0", func() { + ptr, len := StringToCBinaryDataNoCopy("") + convey.So(ptr, convey.ShouldEqual, unsafe.Pointer(nil)) + convey.So(len, convey.ShouldBeZeroValue) + }, + ) + }, + ) +} + +func TestErrtoCerr(t *testing.T) { + convey.Convey( + "Test errtoCerr", t, func() { + convey.Convey( + "errtoCerr success", func() { + cErrorInfo := errtoCerr(errors.New("errtoCerr")) + convey.So(cErrorInfo, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "errtoCerr success when e==nil", func() { + cErrorInfo := errtoCerr(nil) + convey.So(cErrorInfo, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestGoHealthCheck(t *testing.T) { + convey.Convey( + "Test GoHealthCheck", t, func() { + convey.Convey( + "GoHealthCheck success", func() { + cHealthCheckCode := GoHealthCheck() + convey.So(cHealthCheckCode, convey.ShouldBeZeroValue) + }, + ) + }, + ) +} + +func TestGoHasHealthCheck(t *testing.T) { + convey.Convey( + "Test GoHasHealthCheck", t, func() { + convey.Convey( + "GoHasHealthCheck success", func() { + cChar := GoHasHealthCheck() + convey.So(cChar, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestGoFunctionExecutionPoolSubmit(t *testing.T) { + convey.Convey( + "Test GoFunctionExecutionPoolSubmit", t, func() { + convey.Convey( + "GoFunctionExecutionPoolSubmit success", func() { + convey.So(func() { + GoFunctionExecutionPoolSubmit(unsafe.Pointer(nil)) + }, convey.ShouldPanic) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/config.go b/api/go/libruntime/common/config.go new file mode 100644 index 0000000..546ba6b --- /dev/null +++ b/api/go/libruntime/common/config.go @@ -0,0 +1,180 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package common for tools +package common + +import ( + "encoding/json" + "flag" + "os" + "sync" + + "github.com/magiconair/properties" + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + + "yuanrong.org/kernel/runtime/libruntime/common/faas/logger" + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +const ( + defaultConfigPath = "/home/sn/config/runtime.json" + defaultMaxConcurrencyCreateNum = 100 +) + +var ( + configSingleton = struct { + sync.Once + path string + ac *AuthContext + cfg *Configuration + }{} +) + +// Configuration to save config +type Configuration struct { + RuntimeID string + InstanceID string + FunctionName string + LogLevel string + GrpcAddress string + FSAddress string + LogPath string + JobID string + DriverMode bool + EnableMTLS bool + PrivateKeyPath string + CertificateFilePath string + VerifyFilePath string + PrivateKeyPaaswd string + SystemAuthAccessKey string + SystemAuthSecretKey string + EncryptPrivateKeyPasswd string + PrimaryKeyStoreFile string + StandbyKeyStoreFile string + EnableDsEncrypt bool + RuntimePublicKeyContext string + RuntimePrivateKeyContext string + DsPublicKeyContext string + EncryptRuntimePublicKeyContext string + EncryptRuntimePrivateKeyContext string + EncryptDsPublicKeyContext string + MaxConcurrencyCreateNum int + EnableSigaction bool +} + +func initConfig() { + configSingleton.Do( + func() { + configSingleton.cfg = &Configuration{} + flag.StringVar(&configSingleton.cfg.RuntimeID, "runtimeId", "", "") + flag.StringVar(&configSingleton.cfg.InstanceID, "instanceId", "", "") + flag.StringVar(&configSingleton.cfg.FunctionName, "functionName", "", "") + flag.StringVar(&configSingleton.cfg.LogLevel, "logLevel", "", "") + flag.StringVar(&configSingleton.cfg.GrpcAddress, "grpcAddress", "", "") + flag.StringVar(&configSingleton.cfg.FSAddress, "functionSystemAddress", "", "") + flag.StringVar(&configSingleton.path, "runtimeConfigPath", defaultConfigPath, "") + flag.StringVar(&configSingleton.cfg.LogPath, "logPath", "", "") + flag.StringVar(&configSingleton.cfg.JobID, "jobId", "12345678", "") + flag.BoolVar(&configSingleton.cfg.DriverMode, "driverMode", false, "") + flag.BoolVar(&configSingleton.cfg.EnableMTLS, "enableMTLS", false, "") + flag.StringVar(&configSingleton.cfg.PrivateKeyPath, "privateKeyPath", "", "") + flag.StringVar(&configSingleton.cfg.CertificateFilePath, "certificateFilePath", "", "") + flag.StringVar(&configSingleton.cfg.VerifyFilePath, "verifyFilePath", "", "") + flag.StringVar(&configSingleton.cfg.EncryptPrivateKeyPasswd, "encryptPrivateKeyPasswd", "", "") + flag.StringVar(&configSingleton.cfg.PrimaryKeyStoreFile, "primaryKeyStoreFile", "", "") + flag.StringVar(&configSingleton.cfg.StandbyKeyStoreFile, "standbyKeyStoreFile", "", "") + flag.BoolVar(&configSingleton.cfg.EnableDsEncrypt, "enableDsEncrypt", false, "") + flag.StringVar(&configSingleton.cfg.EncryptRuntimePublicKeyContext, "encryptRuntimePublicKeyContext", "", "") + flag.StringVar(&configSingleton.cfg.EncryptRuntimePrivateKeyContext, "encryptRuntimePrivateKeyContext", "", "") + flag.StringVar(&configSingleton.cfg.EncryptDsPublicKeyContext, "encryptDsPublicKeyContext", "", "") + flag.IntVar(&configSingleton.cfg.MaxConcurrencyCreateNum, "maxConcurrencyCreateNum", + defaultMaxConcurrencyCreateNum, "") + flag.Parse() + setConfigSingletonCfg(&configSingleton.cfg.RuntimeID, "YR_RUNTIME_ID") + setConfigSingletonCfg(&configSingleton.cfg.InstanceID, "INSTANCE_ID") + setConfigSingletonCfg(&configSingleton.cfg.FunctionName, "FUNCTION_NAME") + setConfigSingletonCfg(&configSingleton.cfg.LogLevel, "YR_LOG_LEVEL") + setConfigSingletonCfg(&configSingleton.cfg.GrpcAddress, "POSIX_LISTEN_ADDR") + setConfigSingletonCfg(&configSingleton.cfg.LogPath, "GLOG_log_dir") + setConfigSingletonCfg(&configSingleton.cfg.JobID, "YR_JOB_ID") + loadSTSConfig(configSingleton.path) + configSingleton.cfg.EnableSigaction = true + }, + ) +} + +func setConfigSingletonCfg(fieldValue *string, envValue string) { + if *fieldValue == "" { + *fieldValue = os.Getenv(envValue) + } +} + +// GetConfig to parse args +func GetConfig() *Configuration { + initConfig() + config.LogLevelFromFlag = configSingleton.cfg.LogLevel + return configSingleton.cfg +} + +func loadSTSConfig(configPath string) { + if configPath == "" { + return + } + data, err := os.ReadFile(configPath) + if err != nil { + logger.GetLogger().Warnf("read config failed, err %s", err.Error()) + return + } + c := &GlobalConfig{} + err = json.Unmarshal(data, c) + if err != nil { + logger.GetLogger().Warnf("unmarshal config failed, err %s", err.Error()) + return + } + if !c.RawStsConfig.StsEnable { + return + } + stsProperties := properties.LoadMap( + map[string]string{ + "sts.server.domain": c.RawStsConfig.ServerConfig.Domain, + "sts.config.path": c.RawStsConfig.ServerConfig.Path, + }, + ) + err = stsgoapi.InitWith(*stsProperties) + if err != nil { + logger.GetLogger().Warnf("failed to init sts sdk, error %s\n", err.Error()) + return + } + enableIam, ok := c.RawStsConfig.SensitiveConfigs.Auth["enableIam"] + if !ok || enableIam != "true" { + logger.GetLogger().Warnf("enable iam is not true") + return + } + ak, err := stsgoapi.DecryptSensitiveConfig(c.RawStsConfig.SensitiveConfigs.Auth["accessKey"]) + if err != nil { + logger.GetLogger().Warnf("failed to get accessKey, error %s\n", err.Error()) + return + } + sk, err := stsgoapi.DecryptSensitiveConfig(c.RawStsConfig.SensitiveConfigs.Auth["secretKey"]) + if err != nil { + logger.GetLogger().Warnf("failed to get secretKey, error %s\n", err.Error()) + return + } + logger.GetLogger().Infof("init system auth info success, ak: %s", string(ak)) + configSingleton.cfg.SystemAuthAccessKey = string(ak) + configSingleton.cfg.SystemAuthSecretKey = string(sk) +} diff --git a/api/go/libruntime/common/config_test.go b/api/go/libruntime/common/config_test.go new file mode 100644 index 0000000..c3897a4 --- /dev/null +++ b/api/go/libruntime/common/config_test.go @@ -0,0 +1,113 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package common for tools +package common + +import ( + "encoding/json" + "os" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/magiconair/properties" + "github.com/smartystreets/goconvey/convey" + + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" +) + +func TestGetConfig(t *testing.T) { + convey.Convey( + "Test Get config", t, func() { + convey.Convey( + "Test get config success", func() { + conf := GetConfig() + convey.So(conf.LogPath, convey.ShouldBeEmpty) + }, + ) + }, + ) +} + +func TestLoadSTSConfig(t *testing.T) { + convey.Convey( + "Test loadSTSConfig", t, func() { + convey.Convey( + "Test loadSTSConfig when configPath == \"\"", func() { + convey.So(func() { + loadSTSConfig("") + }, convey.ShouldNotPanic) + }, + ) + file, _ := os.Create("config.json") + convey.Convey( + "Test loadSTSConfig when json.Unmarshal error", func() { + convey.So(func() { + loadSTSConfig("config.json") + }, convey.ShouldNotPanic) + }, + ) + c := &GlobalConfig{ + RawStsConfig: StsConfig{ + StsEnable: false, + ServerConfig: ServerConfig{ + Domain: "244", + Path: "244", + }, + SensitiveConfigs: SensitiveConfigs{ + Auth: map[string]string{ + "enableIam": "false", + }, + }, + }, + } + bytes, _ := json.Marshal(c) + file.Write(bytes) + convey.Convey( + "Test loadSTSConfig when c.RawStsConfig.StsEnable == false", func() { + convey.So(func() { + loadSTSConfig("config.json") + }, convey.ShouldNotPanic) + }, + ) + c.RawStsConfig.StsEnable = true + bytes, _ = json.Marshal(c) + file, _ = os.OpenFile("config.json", os.O_WRONLY|os.O_TRUNC, 0644) + file.Write(bytes) + convey.Convey( + "Test loadSTSConfig when stsgoapi.InitWith error", func() { + convey.So(func() { + loadSTSConfig("config.json") + }, convey.ShouldNotPanic) + }, + ) + + convey.Convey( + "Test enableIam is false", func() { + convey.So(func() { + defer gomonkey.ApplyFunc(stsgoapi.InitWith, func(property properties.Properties) error { + return nil + }).Reset() + loadSTSConfig("config.json") + }, convey.ShouldNotPanic) + }, + ) + + file.Close() + os.Remove("config.json") + }, + ) +} diff --git a/api/go/libruntime/common/config_types.go b/api/go/libruntime/common/config_types.go new file mode 100644 index 0000000..2b255c2 --- /dev/null +++ b/api/go/libruntime/common/config_types.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package common for config type +package common + +// StsConfig - +type StsConfig struct { + StsEnable bool `json:"stsEnable,omitempty"` + ServerConfig ServerConfig `json:"serverConfig,omitempty"` + SensitiveConfigs SensitiveConfigs `json:"sensitiveConfigs,omitempty"` +} + +// SensitiveConfigs - +type SensitiveConfigs struct { + Auth map[string]string `json:"auth"` +} + +// ServerConfig - +type ServerConfig struct { + Domain string `json:"domain,omitempty" validate:"max=255"` + Path string `json:"path,omitempty" validate:"max=255"` +} + +// GlobalConfig - +type GlobalConfig struct { + RawStsConfig StsConfig `json:"rawStsConfig"` +} diff --git a/api/go/libruntime/common/constants/constants.go b/api/go/libruntime/common/constants/constants.go new file mode 100644 index 0000000..f7ee012 --- /dev/null +++ b/api/go/libruntime/common/constants/constants.go @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constants implements vars of all +package constants + +import "time" + +const ( + // DefaultMapSize is the default map size + DefaultMapSize = 16 + // MonitorFileName monitor file name + MonitorFileName = "monitor-disk" + + // HostNameEnvKey defines the hostname env key + HostNameEnvKey = "HOSTNAME" + + // PodNameEnvKey define pod name env key + PodNameEnvKey = "POD_NAME" + + // MaxMsgSize grpc client max message size(bit) + MaxMsgSize = 1024 * 1024 * 10 + // MaxWindowSize grpc flow control window size(bit) + MaxWindowSize = 1024 * 1024 * 10 + // MaxBufferSize grpc read/write buffer size(bit) + MaxBufferSize = 1024 * 1024 * 10 + // DialBaseDelay - + DialBaseDelay = 300 * time.Millisecond + // DialMultiplier - + DialMultiplier = 1.2 + // DialJitter - + DialJitter = 0.1 + // DialMaxDelay - + DialMaxDelay = 15 * time.Second + // RuntimeDialMaxDelay - + RuntimeDialMaxDelay = 100 * time.Second +) diff --git a/api/go/libruntime/common/constants/status_code.go b/api/go/libruntime/common/constants/status_code.go new file mode 100644 index 0000000..baf6403 --- /dev/null +++ b/api/go/libruntime/common/constants/status_code.go @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constants implements vars of all +package constants + +const ( + // InsReqSuccessCode is the return code when instance request succeeds + InsReqSuccessCode = 6030 + // UnsupportedOperationErrorCode is the return code when operation is not supported + UnsupportedOperationErrorCode = 6031 + // FuncNotExistErrorCode is the return code when function does not exist + FuncNotExistErrorCode = 6032 + // InsNotExistErrorCode is the return code when instance does not exist + InsNotExistErrorCode = 6033 + // InsAcquireFailedErrorCode is the return code when acquire instance fails + InsAcquireFailedErrorCode = 6034 + // LeaseExpireOrDeletedErrorCode is the return code when lease expires or be deleted + LeaseExpireOrDeletedErrorCode = 6036 +) diff --git a/api/go/libruntime/common/faas/logger/logger.go b/api/go/libruntime/common/faas/logger/logger.go new file mode 100644 index 0000000..f50c4f9 --- /dev/null +++ b/api/go/libruntime/common/faas/logger/logger.go @@ -0,0 +1,64 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger for printing go runtime logger +package logger + +import ( + "fmt" + "sync" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/logger/log" +) + +var ( + once sync.Once + zapLogger api.FormatLogger +) + +const ( + logName = "faas-executor" +) + +// InitLogger - +func InitLogger() error { + runLog, err := log.InitRunLog(logName, true) + if err != nil { + return fmt.Errorf("failed to init faas log, err:%s", err.Error()) + } + SetupLogger(runLog) + return nil +} + +// SetupLogger to new a logger handler +func SetupLogger(formatLogger api.FormatLogger) { + if formatLogger == nil { + _ = InitLogger() + return + } + zapLogger = formatLogger +} + +// GetLogger get logger directly +func GetLogger() api.FormatLogger { + if zapLogger == nil { + once.Do(func() { + zapLogger = log.NewConsoleLogger() + }) + } + return zapLogger +} diff --git a/api/go/libruntime/common/faas/logger/logger_test.go b/api/go/libruntime/common/faas/logger/logger_test.go new file mode 100644 index 0000000..5656138 --- /dev/null +++ b/api/go/libruntime/common/faas/logger/logger_test.go @@ -0,0 +1,39 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger for printing go runtime logger +package logger + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestInitLogger(t *testing.T) { + convey.Convey( + "Test InitLogger", t, func() { + convey.Convey( + "InitLogger success", func() { + GetLogger() + err := InitLogger() + SetupLogger(nil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/faas/logger/user_logger.go b/api/go/libruntime/common/faas/logger/user_logger.go new file mode 100644 index 0000000..b320400 --- /dev/null +++ b/api/go/libruntime/common/faas/logger/user_logger.go @@ -0,0 +1,109 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger for printing user runtime logger +package logger + +import ( + "os" + "path/filepath" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +const ( + fileMode = 0750 + logFileMode = 0644 + skipLevel = 1 + loglevel = zapcore.InfoLevel + + logFile = "user-function.log" + defaultLogPath = "/home/snuser/log" +) + +var ( + userLogger *UserFunctionLogger +) + +// SetupUserLogger to new a logger handler +func SetupUserLogger(level string) error { + var l zapcore.Level + + if level == "" { + l = loglevel + } else { + _ = l.Set(level) + } + userLogPath := os.Getenv("RUNTIME_LOG_DIR") + if userLogPath == "" { + userLogPath = defaultLogPath + } + if err := os.MkdirAll(userLogPath, fileMode); err != nil && !os.IsExist(err) { + return err + } + fileName := filepath.Join(userLogPath, logFile) + writeSyncer, err := os.OpenFile(fileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, logFileMode) + if err != nil { + return err + } + Config := zapcore.EncoderConfig{ + NameKey: "userLogger", + CallerKey: "C", + TimeKey: "T", + LevelKey: "L", + MessageKey: "M", + EncodeCaller: zapcore.ShortCallerEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeLevel: zapcore.CapitalLevelEncoder, + LineEnding: zapcore.DefaultLineEnding, + } + + encoder := zapcore.NewConsoleEncoder(Config) + core := zapcore.NewCore(encoder, writeSyncer, l) + userLogger = &UserFunctionLogger{log: zap.New(core, zap.AddCaller(), zap.AddCallerSkip(skipLevel)).Sugar()} + return nil +} + +// UserFunctionLogger user log struct ,witch context use it +type UserFunctionLogger struct { + log *zap.SugaredLogger +} + +// GetUserLogger get user logger +func GetUserLogger() *UserFunctionLogger { + return userLogger +} + +// Infof to record info log +func (u *UserFunctionLogger) Infof(format string, params ...interface{}) { + u.log.Infof(format, params...) +} + +// Debugf to record info log +func (u *UserFunctionLogger) Debugf(format string, params ...interface{}) { + u.log.Debugf(format, params...) +} + +// Warnf to record info log +func (u *UserFunctionLogger) Warnf(format string, params ...interface{}) { + u.log.Warnf(format, params...) +} + +// Errorf to record info log +func (u *UserFunctionLogger) Errorf(format string, params ...interface{}) { + u.log.Errorf(format, params...) +} diff --git a/api/go/libruntime/common/faas/logger/user_logger_test.go b/api/go/libruntime/common/faas/logger/user_logger_test.go new file mode 100644 index 0000000..c407f70 --- /dev/null +++ b/api/go/libruntime/common/faas/logger/user_logger_test.go @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger for printing go runtime logger +package logger + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestSetupUserLogger(t *testing.T) { + convey.Convey( + "Test SetupUserLogger", t, func() { + convey.Convey( + "SetupUserLogger success", func() { + err := SetupUserLogger("") + convey.So(err, convey.ShouldBeNil) + err = SetupUserLogger("1") + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestUserFunctionLogger(t *testing.T) { + convey.Convey( + "Test UserFunctionLogger ", t, func() { + u := GetUserLogger() + convey.Convey( + "UserFunctionLogger success", func() { + convey.So(u, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "Infof success", func() { + convey.So(func() { + u.Infof("format", "params") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Debugf success", func() { + convey.So(func() { + u.Debugf("format", "params") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Warnf success", func() { + convey.So(func() { + u.Warnf("format", "params") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Errorf success", func() { + convey.So(func() { + u.Errorf("format", "params") + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/logger/async/writer.go b/api/go/libruntime/common/logger/async/writer.go new file mode 100644 index 0000000..bc2e984 --- /dev/null +++ b/api/go/libruntime/common/logger/async/writer.go @@ -0,0 +1,176 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package async makes io.Writer write async +package async + +import ( + "bytes" + "fmt" + "io" + "sync/atomic" + "time" + + "go.uber.org/zap/buffer" +) + +const ( + diskBufferSize = 1024 * 1024 + diskFlushSize = diskBufferSize >> 1 + diskFlushTime = 500 * time.Millisecond + defaultChannelSize = 200000 + softLimitFactor = 0.8 // must be smaller than 1 +) + +var ( + linePool = buffer.NewPool() +) + +// Opt - +type Opt func(*Writer) + +// WithCachedLimit - +func WithCachedLimit(limit int) Opt { + return func(w *Writer) { + w.cachedLimit = limit + w.cachedSoftLimit = int(float64(limit) * softLimitFactor) + w.cachedLow = w.cachedSoftLimit >> 1 + } +} + +// NewAsyncWriteSyncer wrappers io.Writer to async zapcore.WriteSyncer +func NewAsyncWriteSyncer(w io.Writer, opts ...Opt) *Writer { + writer := &Writer{ + w: w, + diskBuf: bytes.NewBuffer(make([]byte, 0, diskBufferSize)), + lines: make(chan *buffer.Buffer, defaultChannelSize), + sync: make(chan struct{}), + syncDone: make(chan struct{}), + } + for _, opt := range opts { + opt(writer) + } + go writer.logConsumer() + return writer +} + +// Writer - +type Writer struct { + diskBuf *bytes.Buffer + lines chan *buffer.Buffer + w io.Writer + sync chan struct{} + syncDone chan struct{} + + cachedLimit int + cachedSoftLimit int + cachedLow int + cached int64 // atomic +} + +// Write sends data to channel non-blocking +func (w *Writer) Write(data []byte) (int, error) { + // note: data will be put back to zap's inner pool after Write, so we couldn't send it to channel directly + lp := linePool.Get() + lp.Write(data) + select { + case w.lines <- lp: + if w.cachedLimit != 0 && atomic.AddInt64(&w.cached, int64(len(data))) > int64(w.cachedLimit) { + w.doSync() + } + default: + fmt.Println("failed to push log to channel, skip") + lp.Free() + } + return len(data), nil +} + +// Sync implements zapcore.WriteSyncer. Current do nothing. +func (w *Writer) Sync() error { + w.doSync() + return nil +} + +func (w *Writer) doSync() { + w.sync <- struct{}{} + <-w.syncDone +} + +func (w *Writer) logConsumer() { + ticker := time.NewTicker(diskFlushTime) +loop: + for { + select { + case line := <-w.lines: + w.write(line) + if w.cachedLimit != 0 && atomic.LoadInt64(&w.cached) > int64(w.cachedSoftLimit) { + w.flushLines(len(w.lines), w.cachedLow) + } + case <-ticker.C: + if w.diskBuf.Len() == 0 { + continue + } + if _, err := w.w.Write(w.diskBuf.Bytes()); err != nil { + fmt.Println("failed to write", err.Error()) + } + w.diskBuf.Reset() + case _, ok := <-w.sync: + if !ok { + close(w.syncDone) + break loop + } + nLines := len(w.lines) + if nLines == 0 && w.diskBuf.Len() == 0 { + w.syncDone <- struct{}{} + continue + } + w.flushLines(nLines, -1) + if _, err := w.w.Write(w.diskBuf.Bytes()); err != nil { + fmt.Println("failed to write", err.Error()) + } + w.diskBuf.Reset() + w.syncDone <- struct{}{} + } + } + ticker.Stop() +} + +func (w *Writer) flushLines(nLines int, upTo int) { + nBytes := 0 + for i := 0; i < nLines; i++ { + line := <-w.lines + nBytes += line.Len() + w.write(line) + if upTo >= 0 && nBytes > upTo { + break + } + } +} + +func (w *Writer) write(line *buffer.Buffer) { + w.diskBuf.Write(line.Bytes()) + if w.cachedLimit != 0 { + atomic.AddInt64(&w.cached, -int64(line.Len())) + } + line.Free() + if w.diskBuf.Len() < diskFlushSize { + return + } + if _, err := w.w.Write(w.diskBuf.Bytes()); err != nil { + fmt.Println("failed to write", err.Error()) + } + w.diskBuf.Reset() +} diff --git a/api/go/libruntime/common/logger/async/writer_test.go b/api/go/libruntime/common/logger/async/writer_test.go new file mode 100644 index 0000000..a342683 --- /dev/null +++ b/api/go/libruntime/common/logger/async/writer_test.go @@ -0,0 +1,94 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package async makes io.Writer write async +package async + +import ( + "bytes" + "io" + "testing" + "time" + + "go.uber.org/zap/buffer" + + "github.com/smartystreets/goconvey/convey" +) + +func TestWriter(t *testing.T) { + convey.Convey( + "Test Writer", t, func() { + opt := WithCachedLimit(0) + var ioW io.Writer = &bytes.Buffer{} + w := NewAsyncWriteSyncer(ioW, opt) + convey.Convey( + "WithCachedLimit success", func() { + convey.So(opt, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "NewAsyncWriteSyncer success", func() { + convey.So(w, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "Write success", func() { + w.cachedLimit = -1 + n, err := w.Write([]byte("data")) + convey.So(n, convey.ShouldEqual, 4) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Write success when default", func() { + w.lines = make(chan *buffer.Buffer, 0) + n, err := w.Write([]byte("data")) + convey.So(n, convey.ShouldEqual, 4) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Sync success", func() { + err := w.Sync() + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "logConsumer success", func() { + lp := linePool.Get() + lp.Write([]byte("data")) + w.lines <- lp + convey.So(func() { + go w.logConsumer() + time.Sleep(500 * time.Millisecond) + w.sync <- struct{}{} + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "write success", func() { + lp := linePool.Get() + for lp.Len() < diskFlushSize { + lp.Write([]byte("data")) + } + convey.So(func() { + w.write(lp) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/logger/config/config.go b/api/go/libruntime/common/logger/config/config.go new file mode 100644 index 0000000..a3a23e0 --- /dev/null +++ b/api/go/libruntime/common/logger/config/config.go @@ -0,0 +1,132 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config is common logger client +package config + +import ( + "encoding/json" + "errors" + "os" + "sync" + + "github.com/asaskevich/govalidator/v11" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/common/utils" +) + +const ( + configPath = "/home/sn/config/log.json" + fileMode = 0750 + defaultLogLevel = "INFO" + // LogConfigKey environment variable key of log config + LogConfigKey = "LOG_CONFIG" +) + +var ( + defaultCoreInfo CoreInfo + onceSetDefault sync.Once + // LogLevel - + LogLevel zapcore.Level = zapcore.InfoLevel + // LogLevelFromFlag - + LogLevelFromFlag string +) + +func initDefaultCoreInfo() { + defaultFilePath := os.Getenv("GLOG_log_dir") + if defaultFilePath == "" { + defaultFilePath = "/home/snuser/log" + } + logLevel := os.Getenv("YR_LOG_LEVEL") + if logLevel == "" { + logLevel = defaultLogLevel + } + // defaultCoreInfo default logger config + defaultCoreInfo = CoreInfo{ + FilePath: defaultFilePath, + Level: logLevel, + Tick: 0, // Unit: Second + First: 0, // Unit: Number of logs + Thereafter: 0, // Unit: Number of logs + SingleSize: 100, + Threshold: 10, + Tracing: false, // tracing log switch + Disable: false, // Disable file logger + } +} + +// CoreInfo contains the core info +type CoreInfo struct { + FilePath string `json:"filepath" valid:",required"` + Level string `json:"level" valid:",required"` + Tick int `json:"tick" valid:"range(0|86400),optional"` + First int `json:"first" valid:"range(0|20000),optional"` + Thereafter int `json:"thereafter" valid:"range(0|1000),optional"` + Tracing bool `json:"tracing" valid:",optional"` + Disable bool `json:"disable" valid:",optional"` + SingleSize int64 `json:"singlesize" valid:",optional"` + Threshold int `json:"threshold" valid:",optional"` + IsUserLog bool `json:"-"` +} + +// GetDefaultCoreInfo get defaultCoreInfo +func GetDefaultCoreInfo() CoreInfo { + onceSetDefault.Do(func() { + initDefaultCoreInfo() + }) + return defaultCoreInfo +} + +// GetCoreInfoFromEnv extracts the logger config and ensures that the log file is available +func GetCoreInfoFromEnv() (CoreInfo, error) { + coreInfo, err := ExtractCoreInfoFromEnv(LogConfigKey) + if err != nil { + return defaultCoreInfo, err + } + if err = utils.ValidateFilePath(coreInfo.FilePath); err != nil { + return defaultCoreInfo, err + } + if err = os.MkdirAll(coreInfo.FilePath, fileMode); err != nil && !os.IsExist(err) { + return defaultCoreInfo, err + } + + return coreInfo, nil +} + +// ExtractCoreInfoFromEnv extracts the logger config from ENV +func ExtractCoreInfoFromEnv(env string) (CoreInfo, error) { + var coreInfo CoreInfo + config := os.Getenv(env) + if config == "" { + return defaultCoreInfo, errors.New(env + " is empty") + } + err := json.Unmarshal([]byte(config), &coreInfo) + if err != nil { + return defaultCoreInfo, err + } + + // if the file path is empty, return error + // if the log file is not writable, zap will create a new file with the configured file path and file name + if coreInfo.FilePath == "" { + return defaultCoreInfo, errors.New("log file path is empty") + } + if _, err = govalidator.ValidateStruct(coreInfo); err != nil { + return defaultCoreInfo, err + } + + return coreInfo, nil +} diff --git a/api/go/libruntime/common/logger/config/config_test.go b/api/go/libruntime/common/logger/config/config_test.go new file mode 100644 index 0000000..f6440ed --- /dev/null +++ b/api/go/libruntime/common/logger/config/config_test.go @@ -0,0 +1,65 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config is common logger client +package config + +import ( + "encoding/json" + "os" + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestGetDefaultCoreInfo(t *testing.T) { + convey.Convey( + "Test GetDefaultCoreInfo", t, func() { + coreInfo := GetDefaultCoreInfo() + convey.Convey( + "GetDefaultCoreInfo success", func() { + convey.So(coreInfo.FilePath, convey.ShouldNotBeEmpty) + }, + ) + convey.Convey( + "GetCoreInfoFromEnv success when config==nil", func() { + coreInfo, err := GetCoreInfoFromEnv() + convey.So(coreInfo.FilePath, convey.ShouldNotBeEmpty) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "GetCoreInfoFromEnv success", func() { + bytes, _ := json.Marshal(coreInfo) + os.Setenv(LogConfigKey, string(bytes)) + coreInfo, err := GetCoreInfoFromEnv() + convey.So(coreInfo.FilePath, convey.ShouldNotBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "GetCoreInfoFromEnv success when coreInfo.FilePath is empty", func() { + coreInfo.FilePath = "" + bytes, _ := json.Marshal(coreInfo) + os.Setenv(LogConfigKey, string(bytes)) + coreInfo, err := GetCoreInfoFromEnv() + convey.So(coreInfo.FilePath, convey.ShouldNotBeEmpty) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/logger/custom_encoder.go b/api/go/libruntime/common/logger/custom_encoder.go new file mode 100644 index 0000000..f9a1837 --- /dev/null +++ b/api/go/libruntime/common/logger/custom_encoder.go @@ -0,0 +1,412 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger log +package logger + +import ( + "math" + "os" + "regexp" + "strings" + "sync" + "time" + + "go.uber.org/zap/buffer" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/common/constants" +) + +const ( + float64bitSize = 64 + float32bitSize = 32 + headerSeparator = ' ' + elementSeparator = " " + customDefaultLineEnding = "\n" + logMsgMaxLen = 1024 * 2 + fieldSeparator = " | " +) + +var ( + _customBufferPool = buffer.NewPool() + + _customPool = sync.Pool{New: func() interface{} { + return &customEncoder{} + }} + + replComp = regexp.MustCompile(`\s+`) + + clusterName = os.Getenv("CLUSTER_ID") +) + +// customEncoder represents the encoder for zap logger +// project's interface log +type customEncoder struct { + *zapcore.EncoderConfig + buf *buffer.Buffer + podName string + traceID string +} + +func getPodName() string { + podName := os.Getenv(constants.HostNameEnvKey) + if os.Getenv(constants.PodNameEnvKey) != "" { + podName = os.Getenv(constants.PodNameEnvKey) + } + return podName +} + +// NewConsoleEncoder new custom console encoder to zap log module +func NewConsoleEncoder(cfg zapcore.EncoderConfig) (zapcore.Encoder, error) { + return &customEncoder{ + EncoderConfig: &cfg, + buf: _customBufferPool.Get(), + podName: getPodName(), + }, nil +} + +// NewCustomEncoder new custom encoder to zap log module +func NewCustomEncoder(cfg *zapcore.EncoderConfig) zapcore.Encoder { + return &customEncoder{ + EncoderConfig: cfg, + buf: _customBufferPool.Get(), + podName: getPodName(), + } +} + +// Clone return zap core Encoder +func (enc *customEncoder) Clone() zapcore.Encoder { + clone := enc.clone() + if enc.buf.Len() > 0 { + _, _ = clone.buf.Write(enc.buf.Bytes()) + } + return clone +} + +// EncodeEntry - +func (enc *customEncoder) EncodeEntry(ent zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) { + final := enc.clone() + // add time + final.AppendString(ent.Time.UTC().Format("2006-01-02 15:04:05.000")) + final.buf.AppendString(fieldSeparator) + + final.EncodeLevel(ent.Level, final) + final.buf.AppendString(fieldSeparator) + + // add caller + if ent.Caller.Defined { + final.EncodeCaller(ent.Caller, final) + final.buf.AppendString(fieldSeparator) + } + // add podName + if enc.podName != "" { + final.buf.AppendString(enc.podName) + } + final.buf.AppendString(fieldSeparator) + // add clusterName + if clusterName != "" { + final.buf.AppendString(clusterName) + } + final.buf.AppendString(fieldSeparator) + if enc.traceID != "" { + final.buf.AppendString(enc.traceID) + } + final.buf.AppendString(fieldSeparator) + if enc.buf.Len() > 0 { + final.buf.Write(enc.buf.Bytes()) + } + // add msg + if len(ent.Message) > logMsgMaxLen { + final.AppendString(ent.Message[0:logMsgMaxLen]) + } else { + final.AppendString(ent.Message) + } + if ent.Stack != "" && final.StacktraceKey != "" { + final.buf.AppendString(elementSeparator) + final.AddString(final.StacktraceKey, ent.Stack) + } + for _, field := range fields { + field.AddTo(final) + } + final.buf.AppendString(customDefaultLineEnding) + ret := final.buf + putCustomEncoder(final) + return ret, nil +} + +func putCustomEncoder(enc *customEncoder) { + enc.EncoderConfig = nil + enc.buf = nil + _customPool.Put(enc) +} + +func getCustomEncoder() *customEncoder { + return _customPool.Get().(*customEncoder) +} + +func (enc *customEncoder) clone() *customEncoder { + clone := getCustomEncoder() + clone.buf = _customBufferPool.Get() + clone.EncoderConfig = enc.EncoderConfig + clone.podName = enc.podName + clone.traceID = enc.traceID + return clone +} + +func (enc *customEncoder) writeField(k string, writeVal func()) *customEncoder { + enc.buf.AppendString("(" + k + ":") + writeVal() + enc.buf.AppendString(")") + return enc +} + +// AddArray Add Array +func (enc *customEncoder) AddArray(k string, marshaler zapcore.ArrayMarshaler) error { + return nil +} + +// AddObject Add Object +func (enc *customEncoder) AddObject(k string, marshaler zapcore.ObjectMarshaler) error { + return nil +} + +// AddBinary Add Binary +func (enc *customEncoder) AddBinary(k string, v []byte) { + enc.AddString(k, string(v)) +} + +// AddByteString Add Byte String +func (enc *customEncoder) AddByteString(k string, v []byte) { + enc.AddString(k, string(v)) +} + +// AddBool Add Bool +func (enc *customEncoder) AddBool(k string, v bool) { + enc.writeField(k, func() { + enc.AppendBool(v) + }) +} + +// AddComplex128 Add Complex128 +func (enc *customEncoder) AddComplex128(k string, val complex128) {} + +// AddComplex64 Add Complex64 +func (enc *customEncoder) AddComplex64(k string, v complex64) {} + +// AddDuration Add Duration +func (enc *customEncoder) AddDuration(k string, val time.Duration) { + enc.writeField(k, func() { + enc.AppendString(val.String()) + }) +} + +// AddFloat64 Add Float64 +func (enc *customEncoder) AddFloat64(k string, val float64) { + enc.writeField(k, func() { + enc.AppendFloat64(val) + }) +} + +// AddFloat32 Add Float32 +func (enc *customEncoder) AddFloat32(k string, v float32) { + enc.writeField(k, func() { + enc.AppendFloat64(float64(v)) + }) +} + +// AddInt Add Int +func (enc *customEncoder) AddInt(k string, v int) { + enc.writeField(k, func() { + enc.AppendInt64(int64(v)) + }) +} + +// AddInt64 Add Int64 +func (enc *customEncoder) AddInt64(k string, val int64) { + enc.writeField(k, func() { + enc.AppendInt64(val) + }) +} + +// AddInt32 Add Int32 +func (enc *customEncoder) AddInt32(k string, v int32) { + enc.writeField(k, func() { + enc.AppendInt64(int64(v)) + }) +} + +// AddInt16 Add Int16 +func (enc *customEncoder) AddInt16(k string, v int16) { + enc.writeField(k, func() { + enc.AppendInt64(int64(v)) + }) +} + +// AddInt8 Add Int8 +func (enc *customEncoder) AddInt8(k string, v int8) { + enc.writeField(k, func() { + enc.AppendInt64(int64(v)) + }) +} + +// AddString Append String +func (enc *customEncoder) AddString(k, v string) { + if k == "traceID" || k == "traceId" { + enc.traceID = v + return + } + enc.writeField(k, func() { + v = replComp.ReplaceAllString(v, " ") + if strings.Contains(v, " ") { + enc.buf.AppendString("(" + v + ")") + return + } + enc.AppendString(v) + }) +} + +// AddTime Add Time +func (enc *customEncoder) AddTime(k string, v time.Time) { + enc.writeField(k, func() { + enc.AppendString(v.UTC().Format("2006-01-02 15:04:05.000")) + }) +} + +// AddUint Add Uint +func (enc *customEncoder) AddUint(k string, v uint) { + enc.writeField(k, func() { + enc.AppendUint64(uint64(v)) + }) +} + +// AddUint64 Add Uint64 +func (enc *customEncoder) AddUint64(k string, v uint64) { + enc.writeField(k, func() { + enc.AppendUint64(v) + }) +} + +// AddUint32 Add Uint32 +func (enc *customEncoder) AddUint32(k string, v uint32) { + enc.writeField(k, func() { + enc.AppendUint64(uint64(v)) + }) +} + +// AddUint16 Add Uint16 +func (enc *customEncoder) AddUint16(k string, v uint16) { + enc.writeField(k, func() { + enc.AppendUint64(uint64(v)) + }) +} + +// AddUint8 Add Uint8 +func (enc *customEncoder) AddUint8(k string, v uint8) { + enc.writeField(k, func() { + enc.AppendUint64(uint64(v)) + }) +} + +// AddUintptr Add Uint ptr +func (enc *customEncoder) AddUintptr(k string, v uintptr) { + enc.writeField(k, func() { + enc.AppendUint64(uint64(v)) + }) +} + +// AddReflected uses reflection to serialize arbitrary objects, so it's slow +// and allocation-heavy. +func (enc *customEncoder) AddReflected(k string, v interface{}) error { + return nil +} + +// OpenNamespace opens an isolated namespace where all subsequent fields will +// be added. Applications can use namespaces to prevent key collisions when +// injecting loggers into sub-components or third-party libraries. +func (enc *customEncoder) OpenNamespace(k string) {} + +// AppendBool Append Bool +func (enc *customEncoder) AppendBool(v bool) { enc.buf.AppendBool(v) } + +// AppendByteString Append Byte String +func (enc *customEncoder) AppendByteString(v []byte) { enc.AppendString(string(v)) } + +// AppendComplex128 Append Complex128 +func (enc *customEncoder) AppendComplex128(v complex128) {} + +// AppendComplex64 Append Complex64 +func (enc *customEncoder) AppendComplex64(v complex64) {} + +// AppendFloat64 Append Float64 +func (enc *customEncoder) AppendFloat64(v float64) { enc.appendFloat(v, float64bitSize) } + +// AppendFloat32 Append Float32 +func (enc *customEncoder) AppendFloat32(v float32) { enc.appendFloat(float64(v), float32bitSize) } + +func (enc *customEncoder) appendFloat(v float64, bitSize int) { + switch { + // If the condition is not met, a string is returned to prevent blankness. + // IsNaN reports whether f is an IEEE 754 ``not-a-number'' value. + case math.IsNaN(v): + enc.buf.AppendString(`"NaN"`) + // IsInf reports whether f is an infinity, according to sign + case math.IsInf(v, 1): + // IsInf reports whether f is positive infinity + enc.buf.AppendString(`"+Inf"`) + // IsInf reports whether f is negative infinity + case math.IsInf(v, -1): + enc.buf.AppendString(`"-Inf"`) + default: + enc.buf.AppendFloat(v, bitSize) + } +} + +// AppendInt Append Int +func (enc *customEncoder) AppendInt(v int) { enc.buf.AppendInt(int64(v)) } + +// AppendInt64 Append Int64 +func (enc *customEncoder) AppendInt64(v int64) { enc.buf.AppendInt(v) } + +// AppendInt32 Append Int32 +func (enc *customEncoder) AppendInt32(v int32) { enc.buf.AppendInt(int64(v)) } + +// AppendInt16 Append Int16 +func (enc *customEncoder) AppendInt16(v int16) { enc.buf.AppendInt(int64(v)) } + +// AppendInt8 Append Int8 +func (enc *customEncoder) AppendInt8(v int8) { enc.buf.AppendInt(int64(v)) } + +// AppendString Append String +func (enc *customEncoder) AppendString(val string) { enc.buf.AppendString(val) } + +// AppendUint Append Uint +func (enc *customEncoder) AppendUint(v uint) { enc.buf.AppendUint(uint64(v)) } + +// AppendUint64 Append Uint64 +func (enc *customEncoder) AppendUint64(v uint64) { enc.buf.AppendUint(v) } + +// AppendUint32 Append Uint32 +func (enc *customEncoder) AppendUint32(v uint32) { enc.buf.AppendUint(uint64(v)) } + +// AppendUint16 Append Uint16 +func (enc *customEncoder) AppendUint16(v uint16) { enc.buf.AppendUint(uint64(v)) } + +// AppendUint8 Append Uint8 +func (enc *customEncoder) AppendUint8(v uint8) { enc.buf.AppendUint(uint64(v)) } + +// AppendUintptr Append Uint ptr +func (enc *customEncoder) AppendUintptr(v uintptr) { enc.buf.AppendUint(uint64(v)) } diff --git a/api/go/libruntime/common/logger/custom_encoder_test.go b/api/go/libruntime/common/logger/custom_encoder_test.go new file mode 100644 index 0000000..86ffecd --- /dev/null +++ b/api/go/libruntime/common/logger/custom_encoder_test.go @@ -0,0 +1,447 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger for log +package logger + +import ( + "math" + "os" + "testing" + "time" + "unsafe" + + "github.com/smartystreets/goconvey/convey" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/common/constants" +) + +func TestNewConsoleEncoder(t *testing.T) { + convey.Convey( + "Test NewConsoleEncoder", t, func() { + convey.Convey( + "NewConsoleEncoder success", func() { + enc, err := NewConsoleEncoder(zapcore.EncoderConfig{}) + convey.So(enc, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestNewCustomEncoder(t *testing.T) { + convey.Convey( + "Test NewCustomEncoder", t, func() { + convey.Convey( + "NewCustomEncoder success", func() { + enc := NewCustomEncoder(&zapcore.EncoderConfig{}) + convey.So(enc, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestPutCustomEncoder(t *testing.T) { + convey.Convey( + "Test putCustomEncoder", t, func() { + convey.Convey( + "putCustomEncoder success", func() { + enc := &customEncoder{ + buf: _customBufferPool.Get(), + } + convey.So(func() { + putCustomEncoder(enc) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestClone(t *testing.T) { + convey.Convey( + "Test Clone", t, func() { + convey.Convey( + "Clone success", func() { + ce := &customEncoder{ + buf: _customBufferPool.Get(), + } + ce.buf.Write([]byte{0}) + clone := ce.Clone() + convey.So(clone, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestEncodeEntry(t *testing.T) { + convey.Convey( + "Test EncodeEntry", t, func() { + convey.Convey( + "EncodeEntry success", func() { + os.Setenv(constants.PodNameEnvKey, "pod_name1") + enc := &customEncoder{ + EncoderConfig: &zapcore.EncoderConfig{}, + buf: _customBufferPool.Get(), + podName: getPodName(), + traceID: "traceID1", + } + enc.buf.Write([]byte{0}) + // enc.EncodeCaller = func(caller zapcore.EntryCaller, aEnc zapcore.PrimitiveArrayEncoder){ + // return + // } + field := zapcore.Field{ + Key: "key", + Type: zapcore.StringType, + String: "value", + } + ent := zapcore.Entry{ + Level: zapcore.InfoLevel, + Time: time.Now(), + LoggerName: "loggerName1", + Message: "msg1", + Caller: zapcore.EntryCaller{Defined: false}, + Stack: "Stack1", + } + convey.So(func() { + enc.EncodeEntry(ent, []zapcore.Field{field}) + }, convey.ShouldPanic) + }, + ) + }, + ) +} + +func TestEncoderAdd(t *testing.T) { + convey.Convey( + "Test EncoderAdd", t, func() { + enc := &customEncoder{ + buf: _customBufferPool.Get(), + } + convey.Convey( + "AddArray success", func() { + err := enc.AddArray("key", nil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "AddObject success", func() { + err := enc.AddObject("key", nil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "AddBinary success", func() { + convey.So(func() { + enc.AddBinary("key", []byte("")) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddByteString success", func() { + convey.So(func() { + enc.AddByteString("traceID", []byte("")) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddBool success", func() { + convey.So(func() { + enc.AddBool("key", true) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddComplex128 success", func() { + convey.So(func() { + enc.AddComplex128("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddComplex64 success", func() { + convey.So(func() { + enc.AddComplex64("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddDuration success", func() { + convey.So(func() { + enc.AddDuration("key", time.Duration(1)) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddFloat64 success", func() { + convey.So(func() { + enc.AddFloat64("key", 1.0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddFloat32 success", func() { + convey.So(func() { + enc.AddFloat32("key", 1.0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddInt success", func() { + convey.So(func() { + enc.AddInt("key", 1) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddInt64 success", func() { + convey.So(func() { + enc.AddInt64("key", 1) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddInt32 success", func() { + convey.So(func() { + enc.AddInt32("key", 1) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddInt16 success", func() { + convey.So(func() { + enc.AddInt16("key", 1) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddInt8 success", func() { + convey.So(func() { + enc.AddInt8("key", 1) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddString success", func() { + convey.So(func() { + enc.AddString("key", " ") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddTime success", func() { + convey.So(func() { + enc.AddTime("key", time.Now()) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUint success", func() { + convey.So(func() { + enc.AddUint("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUint64 success", func() { + convey.So(func() { + enc.AddUint64("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUint32 success", func() { + convey.So(func() { + enc.AddUint32("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUint16 success", func() { + convey.So(func() { + enc.AddUint16("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUint8 success", func() { + convey.So(func() { + enc.AddUint8("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUintptr success", func() { + convey.So(func() { + s := "hello" + addr := unsafe.Pointer(&s) + ptr := uintptr(addr) + enc.AddUintptr("key", ptr) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddReflected success", func() { + err := enc.AddReflected("key", struct{}{}) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestOpenNamespace(t *testing.T) { + convey.Convey( + "Test OpenNamespace", t, func() { + convey.Convey( + "OpenNamespace success", func() { + enc := NewCustomEncoder(&zapcore.EncoderConfig{}) + convey.So(func() { + enc.OpenNamespace("key") + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestEncoderAppend(t *testing.T) { + convey.Convey( + "Test EncoderAppend", t, func() { + enc := &customEncoder{ + buf: _customBufferPool.Get(), + } + convey.Convey( + "AppendBool success", func() { + convey.So(func() { + enc.AppendBool(true) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendByteString success", func() { + convey.So(func() { + enc.AppendByteString([]byte("value")) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendComplex128 success", func() { + convey.So(func() { + enc.AppendComplex128(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendComplex64 success", func() { + convey.So(func() { + enc.AppendComplex64(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendFloat32 success", func() { + convey.So(func() { + enc.AppendFloat32(1.0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "appendFloat success", func() { + convey.So(func() { + enc.appendFloat(math.NaN(), 8) + }, convey.ShouldNotPanic) + convey.So(func() { + enc.appendFloat(math.Inf(1), 8) + }, convey.ShouldNotPanic) + convey.So(func() { + enc.appendFloat(math.Inf(-1), 8) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendInt success", func() { + convey.So(func() { + enc.AppendInt(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendInt32 success", func() { + convey.So(func() { + enc.AppendInt32(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendInt16 success", func() { + convey.So(func() { + enc.AppendInt16(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendInt8 success", func() { + convey.So(func() { + enc.AppendInt8(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUint success", func() { + convey.So(func() { + enc.AppendUint(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUint32 success", func() { + convey.So(func() { + enc.AppendUint32(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUint16 success", func() { + convey.So(func() { + enc.AppendUint16(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUint8 success", func() { + convey.So(func() { + enc.AppendUint8(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUintptr success", func() { + convey.So(func() { + s := "hello" + addr := unsafe.Pointer(&s) + ptr := uintptr(addr) + enc.AppendUintptr(ptr) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/logger/interface_encoder.go b/api/go/libruntime/common/logger/interface_encoder.go new file mode 100644 index 0000000..6fce8b8 --- /dev/null +++ b/api/go/libruntime/common/logger/interface_encoder.go @@ -0,0 +1,349 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger log +package logger + +import ( + "errors" + "math" + "os" + "sync" + "time" + + "go.uber.org/zap/buffer" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/common/constants" +) + +var ( + _bufferPool = buffer.NewPool() + + _interfacePool = sync.Pool{New: interfaceFunc} +) + +// InterfaceEncoderConfig holds interface log encoder config +type InterfaceEncoderConfig struct { + ModuleName string + HTTPMethod string + ModuleFrom string + TenantID string + FuncName string + FuncVer string + EncodeCaller zapcore.CallerEncoder +} + +// interfaceEncoder represents the encoder for interface log +// project's interface log +type interfaceEncoder struct { + *InterfaceEncoderConfig + buf *buffer.Buffer + podName string + spaced bool +} + +func interfaceFunc() interface{} { + return &interfaceEncoder{} +} + +func getInterfaceEncoder() *interfaceEncoder { + return _interfacePool.Get().(*interfaceEncoder) +} + +func putInterfaceEncoder(enc *interfaceEncoder) { + enc.InterfaceEncoderConfig = nil + enc.spaced = false + enc.buf = nil + _interfacePool.Put(enc) +} + +// NewInterfaceEncoder create a new interface log encoder +func NewInterfaceEncoder(cfg InterfaceEncoderConfig, spaced bool) zapcore.Encoder { + return newInterfaceEncoder(cfg, spaced) +} + +func newInterfaceEncoder(cfg InterfaceEncoderConfig, spaced bool) *interfaceEncoder { + return &interfaceEncoder{ + InterfaceEncoderConfig: &cfg, + buf: _bufferPool.Get(), + spaced: spaced, + podName: os.Getenv(constants.HostNameEnvKey), + } +} + +// Clone return zap core Encoder +func (enc *interfaceEncoder) Clone() zapcore.Encoder { + return enc.clone() +} + +func (enc *interfaceEncoder) clone() *interfaceEncoder { + clone := getInterfaceEncoder() + clone.InterfaceEncoderConfig = enc.InterfaceEncoderConfig + clone.spaced = enc.spaced + clone.buf = _bufferPool.Get() + return clone +} + +// EncodeEntry Encode Entry +func (enc *interfaceEncoder) EncodeEntry(ent zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) { + final := enc.clone() + final.buf.AppendByte('[') + // add time + final.AppendString(ent.Time.UTC().Format("2006-01-02 15:04:05.000")) + // add level + // Level of interfaceLog is eternally INFO + final.buf.AppendByte(headerSeparator) + final.AppendString("INFO") + // add caller + if ent.Caller.Defined { + final.buf.AppendByte(headerSeparator) + final.EncodeCaller(ent.Caller, final) + } + final.buf.AppendByte(']') + final.buf.AppendByte(headerSeparator) + final.buf.AppendByte('[') + // add podName + if enc.podName != "" { + final.buf.AppendString(enc.podName) + } + final.buf.AppendByte(']') + final.buf.AppendByte(headerSeparator) + if enc.buf.Len() > 0 { + final.buf.Write(enc.buf.Bytes()) + } + // add msg + final.AppendString(ent.Message) + for _, field := range fields { + field.AddTo(final) + } + final.buf.AppendString(customDefaultLineEnding) + ret := final.buf + putInterfaceEncoder(final) + return ret, nil +} + +// AddString Append String +func (enc *interfaceEncoder) AddString(key, val string) { + enc.buf.AppendString(val) +} + +// AppendString Append String +func (enc *interfaceEncoder) AppendString(val string) { + enc.buf.AppendString(val) +} + +// AddDuration Add Duration +func (enc *interfaceEncoder) AddDuration(key string, val time.Duration) { + enc.AppendDuration(val) +} + +func (enc *interfaceEncoder) addElementSeparator() { + last := enc.buf.Len() - 1 + if last < 0 { + return + } + switch enc.buf.Bytes()[last] { + case headerSeparator: + return + default: + enc.buf.AppendByte(headerSeparator) + if enc.spaced { + enc.buf.AppendByte(' ') + } + } +} + +// AppendTime Append Time +func (enc *interfaceEncoder) AppendTime(val time.Time) { + cur := enc.buf.Len() + interfaceTimeEncode(val, enc) + if cur == enc.buf.Len() { + // User-supplied EncodeTime is a no-op. Fall back to nanos since epoch to keep + // output JSON valid. + enc.AppendInt64(val.UnixNano()) + } +} + +// AddArray Add Array +func (enc *interfaceEncoder) AddArray(key string, marshaler zapcore.ArrayMarshaler) error { + return errors.New("unsupported method") +} + +// AddObject Add Object +func (enc *interfaceEncoder) AddObject(key string, marshaler zapcore.ObjectMarshaler) error { + return errors.New("unsupported method") +} + +// AddBinary Add Binary +func (enc *interfaceEncoder) AddBinary(key string, value []byte) {} + +// AddByteString Add Byte String +func (enc *interfaceEncoder) AddByteString(key string, val []byte) { + enc.AppendByteString(val) +} + +// AddBool Add Bool +func (enc *interfaceEncoder) AddBool(key string, value bool) {} + +// AddComplex64 Add Complex64 +func (enc *interfaceEncoder) AddComplex64(k string, v complex64) { enc.AddComplex128(k, complex128(v)) } + +// AddFloat32 Add Float32 +func (enc *interfaceEncoder) AddFloat32(k string, v float32) { enc.AddFloat64(k, float64(v)) } + +// AddInt Add Int +func (enc *interfaceEncoder) AddInt(k string, v int) { enc.AddInt64(k, int64(v)) } + +// AddInt32 Add Int32 +func (enc *interfaceEncoder) AddInt32(k string, v int32) { enc.AddInt64(k, int64(v)) } + +// AddInt16 Add Int16 +func (enc *interfaceEncoder) AddInt16(k string, v int16) { enc.AddInt64(k, int64(v)) } + +// AddInt8 Add Int8 +func (enc *interfaceEncoder) AddInt8(k string, v int8) { enc.AddInt64(k, int64(v)) } + +// AddUint Add Uint +func (enc *interfaceEncoder) AddUint(k string, v uint) { enc.AddUint64(k, uint64(v)) } + +// AddUint32 Add Uint32 +func (enc *interfaceEncoder) AddUint32(k string, v uint32) { enc.AddUint64(k, uint64(v)) } + +// AddUint16 Add Uint16 +func (enc *interfaceEncoder) AddUint16(k string, v uint16) { enc.AddUint64(k, uint64(v)) } + +// AddUint8 Add Uint8 +func (enc *interfaceEncoder) AddUint8(k string, v uint8) { enc.AddUint64(k, uint64(v)) } + +// AddUintptr Add Uint ptr +func (enc *interfaceEncoder) AddUintptr(k string, v uintptr) { enc.AddUint64(k, uint64(v)) } + +// AddComplex128 Add Complex128 +func (enc *interfaceEncoder) AddComplex128(key string, val complex128) { + enc.AppendComplex128(val) +} + +// AddFloat64 Add Float64 +func (enc *interfaceEncoder) AddFloat64(key string, val float64) { + enc.AppendFloat64(val) +} + +// AddInt64 Add Int64 +func (enc *interfaceEncoder) AddInt64(key string, val int64) { + enc.AppendInt64(val) +} + +// AddTime Add Time +func (enc *interfaceEncoder) AddTime(key string, value time.Time) { + enc.AppendTime(value) +} + +// AddUint64 Add Uint64 +func (enc *interfaceEncoder) AddUint64(key string, value uint64) {} + +// AddReflected uses reflection to serialize arbitrary objects, so it's slow +// and allocation-heavy. +func (enc *interfaceEncoder) AddReflected(key string, value interface{}) error { + return nil +} + +// OpenNamespace opens an isolated namespace where all subsequent fields will +// be added. Applications can use namespaces to prevent key collisions when +// injecting loggers into sub-components or third-party libraries. +func (enc *interfaceEncoder) OpenNamespace(key string) {} + +// AppendComplex128 Append Complex128 +func (enc *interfaceEncoder) AppendComplex128(val complex128) {} + +// AppendInt64 Append Int64 +func (enc *interfaceEncoder) AppendInt64(val int64) { + enc.addElementSeparator() + enc.buf.AppendInt(val) +} + +// AppendBool Append Bool +func (enc *interfaceEncoder) AppendBool(val bool) { + enc.addElementSeparator() + enc.buf.AppendBool(val) +} + +func (enc *interfaceEncoder) appendFloat(val float64, bitSize int) { + enc.addElementSeparator() + switch { + case math.IsNaN(val): + enc.buf.AppendString(`"NaN"`) + case math.IsInf(val, 1): + enc.buf.AppendString(`"+Inf"`) + case math.IsInf(val, -1): + enc.buf.AppendString(`"-Inf"`) + default: + enc.buf.AppendFloat(val, bitSize) + } +} + +// AppendUint64 Append Uint64 +func (enc *interfaceEncoder) AppendUint64(val uint64) { + enc.addElementSeparator() + enc.buf.AppendUint(val) +} + +// AppendByteString Append Byte String +func (enc *interfaceEncoder) AppendByteString(val []byte) {} + +// AppendDuration Append Duration +func (enc *interfaceEncoder) AppendDuration(val time.Duration) {} + +// AppendComplex64 Append Complex64 +func (enc *interfaceEncoder) AppendComplex64(v complex64) { enc.AppendComplex128(complex128(v)) } + +// AppendFloat64 Append Float64 +func (enc *interfaceEncoder) AppendFloat64(v float64) { enc.appendFloat(v, float64bitSize) } + +// AppendFloat32 Append Float32 +func (enc *interfaceEncoder) AppendFloat32(v float32) { enc.appendFloat(float64(v), float32bitSize) } + +// AppendInt Append Int +func (enc *interfaceEncoder) AppendInt(v int) { enc.AppendInt64(int64(v)) } + +// AppendInt32 Append Int32 +func (enc *interfaceEncoder) AppendInt32(v int32) { enc.AppendInt64(int64(v)) } + +// AppendInt16 Append Int16 +func (enc *interfaceEncoder) AppendInt16(v int16) { enc.AppendInt64(int64(v)) } + +// AppendInt8 Append Int8 +func (enc *interfaceEncoder) AppendInt8(v int8) { enc.AppendInt64(int64(v)) } + +// AppendUint Append Uint +func (enc *interfaceEncoder) AppendUint(v uint) { enc.AppendUint64(uint64(v)) } + +// AppendUint32 Append Uint32 +func (enc *interfaceEncoder) AppendUint32(v uint32) { enc.AppendUint64(uint64(v)) } + +// AppendUint16 Append Uint16 +func (enc *interfaceEncoder) AppendUint16(v uint16) { enc.AppendUint64(uint64(v)) } + +// AppendUint8 Append Uint8 +func (enc *interfaceEncoder) AppendUint8(v uint8) { enc.AppendUint64(uint64(v)) } + +// AppendUintptr Append Uint ptr +func (enc *interfaceEncoder) AppendUintptr(v uintptr) { enc.AppendUint64(uint64(v)) } + +func interfaceTimeEncode(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + t = t.UTC() + enc.AppendString(t.Format("2006-01-02 15:04:05.000")) +} diff --git a/api/go/libruntime/common/logger/interface_encoder_test.go b/api/go/libruntime/common/logger/interface_encoder_test.go new file mode 100644 index 0000000..a4e6146 --- /dev/null +++ b/api/go/libruntime/common/logger/interface_encoder_test.go @@ -0,0 +1,433 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger for log +package logger + +import ( + "math" + "testing" + "time" + "unsafe" + + "github.com/smartystreets/goconvey/convey" + "go.uber.org/zap/buffer" + "go.uber.org/zap/zapcore" +) + +func TestInterfaceFunc(t *testing.T) { + convey.Convey( + "Test interfaceFunc", t, func() { + convey.Convey( + "interfaceFunc success", func() { + enc := interfaceFunc() + convey.So(enc, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestGetPutInterfaceEncoder(t *testing.T) { + convey.Convey( + "Test GetPutInterfaceEncoder", t, func() { + enc := getInterfaceEncoder() + convey.Convey( + "getInterfaceEncoder success", func() { + convey.So(enc, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "putInterfaceEncoder success", func() { + convey.So(func() { + putInterfaceEncoder(enc) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestNewInterfaceEncoder(t *testing.T) { + convey.Convey( + "Test NewInterfaceEncoder", t, func() { + convey.Convey( + "NewInterfaceEncoder success", func() { + enc := NewInterfaceEncoder(InterfaceEncoderConfig{}, false) + convey.So(enc, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestInterfaceEncoderClone(t *testing.T) { + convey.Convey( + "Test InterfaceEncoderClone", t, func() { + convey.Convey( + "Clone success", func() { + enc := NewInterfaceEncoder(InterfaceEncoderConfig{}, false) + clone := enc.Clone() + convey.So(clone, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestInterfaceEncodeEntry(t *testing.T) { + convey.Convey( + "Test InterfaceEncodeEntry", t, func() { + convey.Convey( + "EncodeEntry success", func() { + enc := newInterfaceEncoder(InterfaceEncoderConfig{}, false) + enc.buf.Write([]byte{0}) + enc.EncodeCaller = func(caller zapcore.EntryCaller, aEnc zapcore.PrimitiveArrayEncoder) { + return + } + field := zapcore.Field{ + Key: "key", + Type: zapcore.StringType, + String: "value", + } + ent := zapcore.Entry{Caller: zapcore.EntryCaller{Defined: true}} + buf, err := enc.EncodeEntry(ent, []zapcore.Field{field}) + convey.So(buf, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestInterfaceEncoderAdd(t *testing.T) { + convey.Convey( + "Test InterfaceEncoderAdd", t, func() { + enc := NewInterfaceEncoder(InterfaceEncoderConfig{}, false) + convey.Convey( + "AddString success", func() { + convey.So(func() { + enc.AddString("key", "val") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddDuration success", func() { + convey.So(func() { + enc.AddDuration("key", time.Second) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "addElementSeparator success", func() { + ent := &interfaceEncoder{ + buf: &buffer.Buffer{}, + } + ent.buf.Write([]byte("hello")) + ent.spaced = true + convey.So(ent.addElementSeparator, convey.ShouldNotPanic) + ent.buf.Write([]byte{' '}) + convey.So(ent.addElementSeparator, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddArray success", func() { + var f zapcore.ArrayMarshalerFunc = func(zapcore.ArrayEncoder) error { + return nil + } + convey.So(func() { + enc.AddArray("key", f) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddObject success", func() { + var f zapcore.ObjectMarshalerFunc = func(zapcore.ObjectEncoder) error { + return nil + } + convey.So(func() { + enc.AddObject("key", f) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddBinary success", func() { + convey.So(func() { + enc.AddBinary("key", []byte("val")) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddByteString success", func() { + convey.So(func() { + enc.AddByteString("key", []byte("val")) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddBool success", func() { + convey.So(func() { + enc.AddBool("key", true) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddComplex64 success", func() { + convey.So(func() { + enc.AddComplex64("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddFloat32 success", func() { + convey.So(func() { + enc.AddFloat32("key", 1.0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddInt success", func() { + convey.So(func() { + enc.AddInt("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddInt32 success", func() { + convey.So(func() { + enc.AddInt32("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddInt16 success", func() { + convey.So(func() { + enc.AddInt16("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddInt8 success", func() { + convey.So(func() { + enc.AddInt8("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUint success", func() { + convey.So(func() { + enc.AddUint("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUint32 success", func() { + convey.So(func() { + enc.AddUint32("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUint16 success", func() { + convey.So(func() { + enc.AddUint16("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUint8 success", func() { + convey.So(func() { + enc.AddUint8("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUintptr success", func() { + convey.So(func() { + s := "hello" + addr := unsafe.Pointer(&s) + ptr := uintptr(addr) + enc.AddUintptr("key", ptr) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddComplex128 success", func() { + convey.So(func() { + enc.AddComplex128("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddFloat64 success", func() { + convey.So(func() { + enc.AddFloat64("key", 1.0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddInt64 success", func() { + convey.So(func() { + enc.AddInt64("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddTime success", func() { + convey.So(func() { + enc.AddTime("key", time.Now()) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddUint64 success", func() { + convey.So(func() { + enc.AddUint64("key", 0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AddReflected success", func() { + convey.So(func() { + enc.AddReflected("key", struct{}{}) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestInterfaceEncoderAppend(t *testing.T) { + convey.Convey( + "Test InterfaceEncoderAppend", t, func() { + enc := newInterfaceEncoder(InterfaceEncoderConfig{}, false) + convey.Convey( + "OpenNamespace success", func() { + convey.So(func() { + enc.OpenNamespace("key") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendBool success", func() { + convey.So(func() { + enc.AppendBool(false) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "appendFloat success", func() { + convey.So(func() { + enc.appendFloat(math.NaN(), 8) + }, convey.ShouldNotPanic) + convey.So(func() { + enc.appendFloat(math.Inf(1), 8) + }, convey.ShouldNotPanic) + convey.So(func() { + enc.appendFloat(math.Inf(-1), 8) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUint64 success", func() { + convey.So(func() { + enc.AppendUint64(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendFloat32 success", func() { + convey.So(func() { + enc.AppendFloat32(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendInt success", func() { + convey.So(func() { + enc.AppendInt(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendInt32 success", func() { + convey.So(func() { + enc.AppendInt32(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendInt16 success", func() { + convey.So(func() { + enc.AppendInt16(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendInt8 success", func() { + convey.So(func() { + enc.AppendInt8(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUint success", func() { + convey.So(func() { + enc.AppendUint(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUint32 success", func() { + convey.So(func() { + enc.AppendUint32(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUint16 success", func() { + convey.So(func() { + enc.AppendUint16(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendComplex64 success", func() { + convey.So(func() { + enc.AppendComplex64(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUint8 success", func() { + convey.So(func() { + enc.AppendUint8(0) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "AppendUintptr success", func() { + convey.So(func() { + s := "hello" + addr := unsafe.Pointer(&s) + ptr := uintptr(addr) + enc.AppendUintptr(ptr) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/logger/interfacelogger.go b/api/go/libruntime/common/logger/interfacelogger.go new file mode 100644 index 0000000..3bf36a0 --- /dev/null +++ b/api/go/libruntime/common/logger/interfacelogger.go @@ -0,0 +1,100 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger log +package logger + +import ( + "fmt" + "io" + "os" + "path/filepath" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +const defaultPerm = 0666 + +// NewInterfaceLogger returns a new interface logger +func NewInterfaceLogger(logPath, fileName string, cfg InterfaceEncoderConfig) (*InterfaceLogger, error) { + coreInfo, err := config.GetCoreInfoFromEnv() + if err != nil { + coreInfo = config.GetDefaultCoreInfo() + } + filePath := filepath.Join(coreInfo.FilePath, fileName+".log") + + coreInfo.FilePath = filePath + cfg.EncodeCaller = zapcore.ShortCallerEncoder + // skip level to print caller line of origin log + const skipLevel = 5 + core, err := newCore(coreInfo, cfg) + if err != nil { + return nil, err + } + logger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(skipLevel)) + + return &InterfaceLogger{log: logger}, nil +} + +// InterfaceLogger interface logger which implements by zap logger +type InterfaceLogger struct { + log *zap.Logger +} + +// Write writes message information +func (logger *InterfaceLogger) Write(msg string) { + logger.log.Debug(msg) +} + +func newCore(coreInfo config.CoreInfo, cfg InterfaceEncoderConfig) (zapcore.Core, error) { + w, err := CreateSink(coreInfo) + if err != nil { + return nil, err + } + syncer := zapcore.AddSync(w) + + encoder := NewInterfaceEncoder(cfg, false) + + priority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + var customLevel zapcore.Level + if err := customLevel.UnmarshalText([]byte(coreInfo.Level)); err != nil { + customLevel = zapcore.InfoLevel + } + return lvl >= customLevel + }) + + return zapcore.NewCore(encoder, syncer, priority), nil +} + +// CreateSink creates a new zap log sink +func CreateSink(coreInfo config.CoreInfo) (io.Writer, error) { + // create directory if not already exist + dir := filepath.Dir(coreInfo.FilePath) + err := os.MkdirAll(dir, os.ModePerm) + if err != nil { + fmt.Printf("failed to mkdir: %s", dir) + return nil, err + } + w, err := initRollingLog(coreInfo, os.O_WRONLY|os.O_APPEND|os.O_CREATE, defaultPerm) + if err != nil { + fmt.Printf("failed to open log file: %s, err: %s\n", coreInfo.FilePath, err.Error()) + return nil, err + } + return w, nil +} diff --git a/api/go/libruntime/common/logger/interfacelogger_test.go b/api/go/libruntime/common/logger/interfacelogger_test.go new file mode 100644 index 0000000..b56f558 --- /dev/null +++ b/api/go/libruntime/common/logger/interfacelogger_test.go @@ -0,0 +1,97 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger for log +package logger + +import ( + "os" + "strings" + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +func TestNewInterfaceLogger(t *testing.T) { + convey.Convey( + "Test NewInterfaceLogger", t, func() { + convey.Convey( + "NewInterfaceLogger success", func() { + logger, err := NewInterfaceLogger("logPath", "fileName", InterfaceEncoderConfig{}) + convey.So(logger, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestWrite(t *testing.T) { + convey.Convey( + "Test Write", t, func() { + convey.Convey( + "Write success", func() { + logger, _ := NewInterfaceLogger("logPath", "fileName", InterfaceEncoderConfig{}) + convey.So(func() { + logger.Write("msg") + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestNewCore(t *testing.T) { + convey.Convey( + "Test newCore", t, func() { + convey.Convey( + "newCore success", func() { + core, err := newCore(config.CoreInfo{}, InterfaceEncoderConfig{}) + convey.So(core, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestCreateSink(t *testing.T) { + convey.Convey( + "Test CreateSink", t, func() { + convey.Convey( + "CreateSink success", func() { + fp := "test" + w, err := CreateSink(config.CoreInfo{FilePath: fp}) + convey.So(w, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + cleanFile(fp) + }, + ) + }, + ) +} + +func cleanFile(fileName string) { + files, _ := os.ReadDir("./") + for _, file := range files { + flag := strings.HasPrefix(file.Name(), fileName) + if flag { + os.Remove(file.Name()) + } + } +} diff --git a/api/go/libruntime/common/logger/log/log.go b/api/go/libruntime/common/logger/log/log.go new file mode 100644 index 0000000..e73bb61 --- /dev/null +++ b/api/go/libruntime/common/logger/log/log.go @@ -0,0 +1,243 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package log is common logger client +package log + +import ( + "fmt" + "path/filepath" + "strings" + "sync" + + "github.com/asaskevich/govalidator/v11" + uberZap "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/constants" + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" + "yuanrong.org/kernel/runtime/libruntime/common/logger/zap" +) + +const snuserLogPath = "/home/snuser/log" + +var ( + once sync.Once + onceRuntimeLog sync.Once + formatLogger api.FormatLogger + runtimeLogger api.FormatLogger + defaultLogger, _ = uberZap.NewProduction() +) + +// InitRunLog init run log with log.json file +func InitRunLog(fileName string, isAsync bool) (api.FormatLogger, error) { + var err error + onceRuntimeLog.Do( + func() { + coreInfo := config.GetDefaultCoreInfo() + formatLoggerImpl, newErr := newFormatLogger(fileName, isAsync, coreInfo) + if newErr != nil { + err = newErr + } + formatLogger = formatLoggerImpl + runtimeLogger = &loggerWrapper{real: formatLogger} + }, + ) + return formatLogger, err +} + +// InitRunLogWithConfig init run log with config +func InitRunLogWithConfig(fileName string, isAsync bool, coreInfo config.CoreInfo) (api.FormatLogger, error) { + if _, err := govalidator.ValidateStruct(coreInfo); err != nil { + return nil, err + } + return newFormatLogger(fileName, isAsync, coreInfo) +} + +// zapLoggerWithFormat define logger +type zapLoggerWithFormat struct { + Logger *uberZap.Logger + SLogger *uberZap.SugaredLogger +} + +// newFormatLogger new formatLogger with log config info +func newFormatLogger(fileName string, isAsync bool, coreInfo config.CoreInfo) (api.FormatLogger, error) { + if strings.Compare(constants.MonitorFileName, fileName) == 0 { + coreInfo.FilePath = snuserLogPath + } + coreInfo.FilePath = filepath.Join(coreInfo.FilePath, fileName+"-run.log") + logger, err := zap.NewWithLevel(coreInfo, isAsync) + if err != nil { + return nil, err + } + + return &zapLoggerWithFormat{ + Logger: logger, + SLogger: logger.Sugar(), + }, nil +} + +// NewConsoleLogger returns a console logger +func NewConsoleLogger() api.FormatLogger { + logger, err := zap.NewConsoleLog() + if err != nil { + fmt.Println("new console log error", err) + logger = defaultLogger + } + return &zapLoggerWithFormat{ + Logger: logger, + SLogger: logger.Sugar(), + } +} + +// GetLogger get logger directly +func GetLogger() api.FormatLogger { + if runtimeLogger == nil { + once.Do( + func() { + formatLogger = NewConsoleLogger() + runtimeLogger = &loggerWrapper{real: formatLogger} + }, + ) + } + return runtimeLogger +} + +// With add fields to log header +func (z *zapLoggerWithFormat) With(fields ...zapcore.Field) api.FormatLogger { + logger := z.Logger.With(fields...) + return &zapLoggerWithFormat{ + Logger: logger, + SLogger: logger.Sugar(), + } +} + +// Infof stdout format and params +func (z *zapLoggerWithFormat) Infof(format string, params ...interface{}) { + z.SLogger.Infof(format, params...) +} + +// Errorf stdout format and params +func (z *zapLoggerWithFormat) Errorf(format string, params ...interface{}) { + z.SLogger.Errorf(format, params...) +} + +// Warnf stdout format and params +func (z *zapLoggerWithFormat) Warnf(format string, params ...interface{}) { + z.SLogger.Warnf(format, params...) +} + +// Debugf stdout format and params +func (z *zapLoggerWithFormat) Debugf(format string, params ...interface{}) { + if config.LogLevel > zapcore.DebugLevel { + return + } + z.SLogger.Debugf(format, params...) +} + +// Fatalf stdout format and params +func (z *zapLoggerWithFormat) Fatalf(format string, params ...interface{}) { + z.SLogger.Fatalf(format, params...) +} + +// Info stdout format and params +func (z *zapLoggerWithFormat) Info(msg string, fields ...uberZap.Field) { + z.Logger.Info(msg, fields...) +} + +// Error stdout format and params +func (z *zapLoggerWithFormat) Error(msg string, fields ...uberZap.Field) { + z.Logger.Error(msg, fields...) +} + +// Warn stdout format and params +func (z *zapLoggerWithFormat) Warn(msg string, fields ...uberZap.Field) { + z.Logger.Warn(msg, fields...) +} + +// Debug stdout format and params +func (z *zapLoggerWithFormat) Debug(msg string, fields ...uberZap.Field) { + if config.LogLevel > zapcore.DebugLevel { + return + } + z.Logger.Debug(msg, fields...) +} + +// Fatal stdout format and params +func (z *zapLoggerWithFormat) Fatal(msg string, fields ...uberZap.Field) { + z.Logger.Fatal(msg, fields...) +} + +// Sync calls the underlying Core's Sync method, flushing any buffered log +// entries. Applications should take care to call Sync before exiting. +func (z *zapLoggerWithFormat) Sync() { + z.Logger.Sync() +} + +type loggerWrapper struct { + real api.FormatLogger +} + +func (l *loggerWrapper) With(fields ...zapcore.Field) api.FormatLogger { + return &loggerWrapper{ + real: l.real.With(fields...), + } +} + +func (l *loggerWrapper) Infof(format string, params ...interface{}) { + l.real.Infof(format, params...) +} + +func (l *loggerWrapper) Errorf(format string, params ...interface{}) { + l.real.Errorf(format, params...) +} + +func (l *loggerWrapper) Warnf(format string, params ...interface{}) { + l.real.Warnf(format, params...) +} + +func (l *loggerWrapper) Debugf(format string, params ...interface{}) { + l.real.Debugf(format, params...) +} + +func (l *loggerWrapper) Fatalf(format string, params ...interface{}) { + l.real.Fatalf(format, params...) +} + +func (l *loggerWrapper) Info(msg string, fields ...uberZap.Field) { + l.real.Info(msg, fields...) +} + +func (l *loggerWrapper) Error(msg string, fields ...uberZap.Field) { + l.real.Error(msg, fields...) +} + +func (l *loggerWrapper) Warn(msg string, fields ...uberZap.Field) { + l.real.Warn(msg, fields...) +} + +func (l *loggerWrapper) Debug(msg string, fields ...uberZap.Field) { + l.real.Debug(msg, fields...) +} + +func (l *loggerWrapper) Fatal(msg string, fields ...uberZap.Field) { + l.real.Fatal(msg, fields...) +} + +func (l *loggerWrapper) Sync() { + l.real.Sync() +} diff --git a/api/go/libruntime/common/logger/log/log_test.go b/api/go/libruntime/common/logger/log/log_test.go new file mode 100644 index 0000000..5f1ca68 --- /dev/null +++ b/api/go/libruntime/common/logger/log/log_test.go @@ -0,0 +1,241 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package log is common logger client +package log + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +func TestLog(t *testing.T) { + convey.Convey( + "Test Log", t, func() { + convey.Convey( + "Test console log", func() { + consoleLog := "this is a console log" + logWrapper := GetLogger() + logWrapper.Info(consoleLog) + }, + ) + convey.Convey( + "Test GetLogger", func() { + lg, err := InitRunLog("runtime-go", false) + convey.So(lg, convey.ShouldNotBeNil) + wrapperLogger := loggerWrapper{ + real: lg, + } + convey.So(err, convey.ShouldBeNil) + convey.Convey( + "Test log warn", func() { + wrapperLogger.Info("this is a warn log 1") + convey.SkipSo() + }, + ) + convey.Convey( + "Test log error", func() { + wrapperLogger.Info("this is a error log 1") + convey.SkipSo() + }, + ) + convey.Convey( + "Test log sync", func() { + wrapperLogger.Sync() + convey.SkipSo() + }, + ) + }, + ) + }, + ) +} + +func TestInitRunLogWithConfig(t *testing.T) { + convey.Convey( + "Test InitRunLogWithConfig", t, func() { + convey.Convey( + "InitRunLogWithConfig success", func() { + coreInfo := config.GetDefaultCoreInfo() + fl, err := InitRunLogWithConfig("monitor-disk", false, coreInfo) + convey.So(fl, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "InitRunLogWithConfig when ValidateStruct error", func() { + fl, err := InitRunLogWithConfig("monitor-disk", false, config.CoreInfo{}) + convey.So(fl, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestConsoleLogger(t *testing.T) { + convey.Convey( + "Test ConsoleLogger", t, func() { + fl := NewConsoleLogger() + field := zapcore.Field{ + Key: "key", + Type: zapcore.StringType, + String: "value", + } + convey.Convey( + "NewConsoleLogger success", func() { + convey.So(fl, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "With success", func() { + newFl := fl.With(field) + convey.So(newFl, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "Infof success", func() { + convey.So(func() { + fl.Infof("msg:", "ok", "err") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Errorf success", func() { + convey.So(func() { + fl.Errorf("msg:", "ok", "err") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Warnf success", func() { + convey.So(func() { + fl.Warnf("msg:", "ok", "err") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Debugf success", func() { + convey.So(func() { + fl.Debugf("msg:", "ok", "err") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Error success", func() { + convey.So(func() { + fl.Error("msg:", field) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Warn success", func() { + convey.So(func() { + fl.Warn("msg:", field) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Debugf when config.LogLevel == zapcore.DebugLevel", func() { + convey.So(func() { + fl.Debugf("msg:", field) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Debug when config.LogLevel == zapcore.DebugLevel", func() { + convey.So(func() { + fl.Debug("msg:", field) + }, convey.ShouldNotPanic) + }, + ) + config.LogLevel = zapcore.DebugLevel + + }, + ) +} + +func TestLoggerWrapper(t *testing.T) { + convey.Convey( + "Test loggerWrapper", t, func() { + lw := GetLogger() + field := zapcore.Field{ + Key: "key", + Type: zapcore.StringType, + String: "value", + } + convey.Convey( + "With success", func() { + fl := lw.With(field) + convey.So(fl, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "Infof success", func() { + convey.So(func() { + lw.Infof("msg:", "ok", "err") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Errorf success", func() { + convey.So(func() { + lw.Errorf("msg:", "ok", "err") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Warnf success", func() { + convey.So(func() { + lw.Warnf("msg:", "ok", "err") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Debugf success", func() { + convey.So(func() { + lw.Debugf("msg:", "ok", "err") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Error success", func() { + convey.So(func() { + lw.Error("msg:", field) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Warn success", func() { + convey.So(func() { + lw.Warn("msg:", field) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Debug success", func() { + convey.So(func() { + lw.Debug("msg:", field) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/logger/rollinglog.go b/api/go/libruntime/common/logger/rollinglog.go new file mode 100644 index 0000000..fae58ba --- /dev/null +++ b/api/go/libruntime/common/logger/rollinglog.go @@ -0,0 +1,269 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger rollingLog +package logger + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + "sync" + "time" + + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +const ( + megabyte = 1024 * 1024 + defaultFileSize = 100 + defaultBackups = 20 +) + +var logNameCache = struct { + m map[string]string + sync.Mutex +}{ + m: make(map[string]string, 1), + Mutex: sync.Mutex{}, +} + +type rollingLog struct { + file *os.File + reg *regexp.Regexp + mu sync.RWMutex + sinks []string + dir string + nameTemplate string + maxSize int64 + size int64 + maxBackups int + flag int + perm os.FileMode + isUserLog bool +} + +func initRollingLog(coreInfo config.CoreInfo, flag int, perm os.FileMode) (*rollingLog, error) { + if coreInfo.FilePath == "" { + return nil, errors.New("empty log file path") + } + log := &rollingLog{ + dir: filepath.Dir(coreInfo.FilePath), + nameTemplate: filepath.Base(coreInfo.FilePath), + flag: flag, + perm: perm, + maxSize: coreInfo.SingleSize * megabyte, + maxBackups: coreInfo.Threshold, + isUserLog: coreInfo.IsUserLog, + } + if log.maxBackups < 1 { + log.maxBackups = defaultBackups + } + if log.maxSize < megabyte { + log.maxSize = defaultFileSize * megabyte + } + if log.isUserLog { + return log, log.tidySinks() + } + extension := filepath.Ext(log.nameTemplate) + regExp := fmt.Sprintf(`^%s(?:(?:-|\.)\d*)?\%s$`, + log.nameTemplate[:len(log.nameTemplate)-len(extension)], extension) + reg, err := regexp.Compile(regExp) + if err != nil { + return nil, err + } + log.reg = reg + return log, log.tidySinks() +} + +func (r *rollingLog) tidySinks() error { + if r.isUserLog || r.file != nil { + return r.newSink() + } + // scan and reuse past log file when service restarted + r.scanLogFiles() + if len(r.sinks) > 0 { + fullName := r.sinks[len(r.sinks)-1] + info, err := os.Stat(fullName) + if err != nil || info.Size() >= r.maxSize { + return r.newSink() + } + file, err := os.OpenFile(fullName, r.flag, r.perm) + if err == nil { + r.file = file + r.size = info.Size() + return nil + } + } + return r.newSink() +} + +func (r *rollingLog) scanLogFiles() { + dirEntrys, err := os.ReadDir(r.dir) + if err != nil { + fmt.Printf("failed to read dir: %s\n", r.dir) + return + } + infos := make([]os.FileInfo, 0, r.maxBackups) + for _, entry := range dirEntrys { + if r.reg.MatchString(entry.Name()) { + info, err := entry.Info() + if err == nil { + infos = append(infos, info) + } + } + } + if len(infos) > 0 { + sort.Slice(infos, func(i, j int) bool { + return infos[i].ModTime().Before(infos[j].ModTime()) + }) + for i := range infos { + r.sinks = append(r.sinks, filepath.Join(r.dir, infos[i].Name())) + } + r.cleanRedundantSinks() + } +} + +func (r *rollingLog) cleanRedundantSinks() { + if len(r.sinks) < r.maxBackups { + return + } + curSinks := make([]string, 0, len(r.sinks)) + for _, name := range r.sinks { + if isAvailable(name) { + curSinks = append(curSinks, name) + } + + } + r.sinks = curSinks + sinkNum := len(r.sinks) + if sinkNum > r.maxBackups { + removes := r.sinks[:sinkNum-r.maxBackups] + go removeFiles(removes) + r.sinks = r.sinks[sinkNum-r.maxBackups:] + } + return +} + +func removeFiles(paths []string) { + for _, path := range paths { + err := os.Remove(path) + if err != nil && !os.IsNotExist(err) { + fmt.Printf("failed remove file %s\n", path) + } + } +} + +func isAvailable(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func (r *rollingLog) newSink() error { + fullName := filepath.Join(r.dir, r.newName()) + if isAvailable(fullName) && r.file != nil && r.file.Name() == filepath.Base(fullName) { + return errors.New("log file already opened: " + fullName) + } + file, err := os.OpenFile(fullName, r.flag, r.perm) + if err != nil { + return err + } + if r.file != nil { + err = r.file.Close() + } + if err != nil { + fmt.Printf("failed to close file: %s\n", err.Error()) + } + r.file = file + info, err := file.Stat() + if err != nil { + r.size = 0 + } else { + r.size = info.Size() + } + r.sinks = append(r.sinks, fullName) + r.cleanRedundantSinks() + if r.isUserLog { + logNameCache.Lock() + logNameCache.m[r.nameTemplate] = fullName + logNameCache.Unlock() + } + return nil +} + +func (r *rollingLog) newName() string { + if !r.isUserLog { + timeNow := time.Now().Format("2006010215040506") + ext := filepath.Ext(r.nameTemplate) + return fmt.Sprintf("%s.%s%s", r.nameTemplate[:len(r.nameTemplate)-len(ext)], timeNow, ext) + } + if r.file == nil { + return r.nameTemplate + } + timeNow := time.Now().Format("2006010215040506") + var prefix, suffix string + if index := strings.LastIndex(r.nameTemplate, "@") + 1; index <= len(r.nameTemplate) { + prefix = r.nameTemplate[:index] + } + if index := strings.Index(r.nameTemplate, "#"); index >= 0 { + suffix = r.nameTemplate[index:] + } + if prefix == "" || suffix == "" { + return "" + } + return fmt.Sprintf("%s%s%s", prefix, timeNow, suffix) +} + +// Write data to file and check whether to rotate log +func (r *rollingLog) Write(data []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r == nil || r.file == nil { + return 0, errors.New("log file is nil") + } + n, err := r.file.Write(data) + r.size += int64(n) + if r.size > r.maxSize { + r.tryRotate() + } + if syncErr := r.file.Sync(); syncErr != nil { + fmt.Printf("failed to sync log err: %s\n", syncErr.Error()) + } + return n, err +} + +func (r *rollingLog) tryRotate() { + if info, err := r.file.Stat(); err == nil && info.Size() < r.maxSize { + return + } + err := r.tidySinks() + if err != nil { + fmt.Printf("failed to rotate log err: %s\n", err.Error()) + } + return +} + +// GetLogName get current log name when refreshing user log mod time +func GetLogName(nameTemplate string) string { + logNameCache.Lock() + name := logNameCache.m[nameTemplate] + logNameCache.Unlock() + return name +} diff --git a/api/go/libruntime/common/logger/rollinglog_test.go b/api/go/libruntime/common/logger/rollinglog_test.go new file mode 100644 index 0000000..71c308f --- /dev/null +++ b/api/go/libruntime/common/logger/rollinglog_test.go @@ -0,0 +1,152 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger for log +package logger + +import ( + "os" + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +func TestInitRollingLog(t *testing.T) { + convey.Convey( + "Test initRollingLog", t, func() { + convey.Convey( + "initRollingLog success", func() { + coreInfo := config.CoreInfo{FilePath: "test", IsUserLog: true} + r, err := initRollingLog(coreInfo, os.O_WRONLY|os.O_APPEND|os.O_CREATE, defaultPerm) + convey.So(r, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + cleanFile("test") + }, + ) + }, + ) +} + +func TestRollingLogHandler(t *testing.T) { + convey.Convey( + "Test RollingLogHandler", t, func() { + fp := "test" + coreInfo := config.CoreInfo{FilePath: fp, IsUserLog: true} + r, _ := initRollingLog(coreInfo, os.O_WRONLY|os.O_APPEND|os.O_CREATE, defaultPerm) + defer cleanFile(fp) + convey.Convey( + "tidySinks success", func() { + err := r.tidySinks() + convey.So(err, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "cleanRedundantSinks success", func() { + r.maxBackups = 0 + convey.So(func() { + r.cleanRedundantSinks() + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "newName success", func() { + str := r.newName() + convey.So(str, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestNewName(t *testing.T) { + convey.Convey( + "Test newName", t, func() { + fp := "test@#" + coreInfo := config.CoreInfo{FilePath: fp, IsUserLog: true} + r, _ := initRollingLog(coreInfo, os.O_WRONLY|os.O_APPEND|os.O_CREATE, defaultPerm) + defer cleanFile(fp) + convey.Convey( + "newName success", func() { + str := r.newName() + convey.So(str, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestRollingLogWrite(t *testing.T) { + convey.Convey( + "Test RollingLogWrite", t, func() { + + convey.Convey( + "Write success when file==nil", func() { + r := &rollingLog{} + i, err := r.Write([]byte{}) + convey.So(i, convey.ShouldEqual, 0) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "Write success", func() { + fp := "test" + coreInfo := config.CoreInfo{FilePath: fp, IsUserLog: true} + r, _ := initRollingLog(coreInfo, os.O_WRONLY|os.O_APPEND|os.O_CREATE, defaultPerm) + r.maxSize = 0 + defer cleanFile(fp) + + i, err := r.Write([]byte{0}) + convey.So(i, convey.ShouldNotEqual, 0) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestTryRotate(t *testing.T) { + convey.Convey( + "Test tryRotate", t, func() { + fp := "test" + coreInfo := config.CoreInfo{FilePath: fp, IsUserLog: true} + r, _ := initRollingLog(coreInfo, os.O_WRONLY|os.O_APPEND|os.O_CREATE, defaultPerm) + r.maxSize = 0 + cleanFile(fp) + convey.Convey( + "newName success", func() { + convey.So(func() { + r.tryRotate() + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestGetLogName(t *testing.T) { + convey.Convey( + "Test GetLogName", t, func() { + convey.Convey( + "GetLogName success", func() { + str := GetLogName("templateTest") + convey.So(str, convey.ShouldEqual, "") + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/logger/zap/zaplog.go b/api/go/libruntime/common/logger/zap/zaplog.go new file mode 100644 index 0000000..6f06f10 --- /dev/null +++ b/api/go/libruntime/common/logger/zap/zaplog.go @@ -0,0 +1,162 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package zap zapper log +package zap + +import ( + "fmt" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/common/logger" + "yuanrong.org/kernel/runtime/libruntime/common/logger/async" + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +const ( + skipLevel = 2 +) + +func init() { + zap.RegisterEncoder("custom_console", logger.NewConsoleEncoder) +} + +// NewDevelopmentLog returns a development logger based on uber zap and it output entry to stdout and stderr +func NewDevelopmentLog() (*zap.Logger, error) { + cfg := zap.NewDevelopmentConfig() + cfg.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder + return cfg.Build() +} + +// NewConsoleLog returns a console logger based on uber zap +func NewConsoleLog() (*zap.Logger, error) { + outputPaths := []string{"stdout"} + cfg := zap.Config{ + Level: zap.NewAtomicLevelAt(zap.InfoLevel), + Development: false, + DisableCaller: false, + DisableStacktrace: true, + Encoding: "custom_console", + OutputPaths: outputPaths, + ErrorOutputPaths: outputPaths, + EncoderConfig: zapcore.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "Logger", + MessageKey: "M", + CallerKey: "C", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + }, + } + consoleLogger, err := cfg.Build() + if err != nil { + return nil, err + } + return consoleLogger.WithOptions(zap.AddCaller(), zap.AddCallerSkip(skipLevel)), nil +} + +// NewWithLevel returns a log based on zap with Level +func NewWithLevel(coreInfo config.CoreInfo, isAsync bool) (*zap.Logger, error) { + core, err := newCore(coreInfo, isAsync) + if err != nil { + return nil, err + } + + return zap.New(core, zap.AddCaller(), zap.AddCallerSkip(skipLevel)), nil +} + +func newCore(coreInfo config.CoreInfo, isAsync bool) (zapcore.Core, error) { + w, err := logger.CreateSink(coreInfo) + if err != nil { + return nil, err + } + + var syncer zapcore.WriteSyncer + if isAsync { + syncer = async.NewAsyncWriteSyncer(w) + } else { + syncer = zapcore.AddSync(w) + } + + encoderConfig := zapcore.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "Logger", + MessageKey: "M", + CallerKey: "C", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + + fileEncoder := logger.NewCustomEncoder(&encoderConfig) + + if err := config.LogLevel.UnmarshalText([]byte(coreInfo.Level)); err != nil { + config.LogLevel = zapcore.InfoLevel + } + + priority := zap.LevelEnablerFunc(priorityHandler) + + if coreInfo.Tick == 0 || coreInfo.First == 0 || coreInfo.Thereafter == 0 { + return zapcore.NewCore(fileEncoder, syncer, priority), nil + } + return zapcore.NewSamplerWithOptions(zapcore.NewCore(fileEncoder, syncer, priority), + time.Duration(coreInfo.Tick)*time.Second, coreInfo.First, coreInfo.Thereafter), nil +} + +func priorityHandler(lvl zapcore.Level) bool { + return lvl >= config.LogLevel +} + +// LoggerWithFormat zap logger +type LoggerWithFormat struct { + *zap.Logger +} + +// Infof stdout format and paras +func (z *LoggerWithFormat) Infof(format string, paras ...interface{}) { + z.Logger.Info(fmt.Sprintf(format, paras...)) +} + +// Errorf stdout format and paras +func (z *LoggerWithFormat) Errorf(format string, paras ...interface{}) { + z.Logger.Error(fmt.Sprintf(format, paras...)) +} + +// Warnf stdout format and paras +func (z *LoggerWithFormat) Warnf(format string, paras ...interface{}) { + z.Logger.Warn(fmt.Sprintf(format, paras...)) +} + +// Debugf stdout format and paras +func (z *LoggerWithFormat) Debugf(format string, paras ...interface{}) { + if config.LogLevel > zapcore.DebugLevel { + return + } + z.Logger.Debug(fmt.Sprintf(format, paras...)) +} + +// Fatalf stdout format and paras +func (z *LoggerWithFormat) Fatalf(format string, paras ...interface{}) { + z.Logger.Fatal(fmt.Sprintf(format, paras...)) +} diff --git a/api/go/libruntime/common/logger/zap/zaplog_test.go b/api/go/libruntime/common/logger/zap/zaplog_test.go new file mode 100644 index 0000000..bd5da75 --- /dev/null +++ b/api/go/libruntime/common/logger/zap/zaplog_test.go @@ -0,0 +1,159 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package zap zapper log +package zap + +import ( + "os" + "strings" + "testing" + + "github.com/smartystreets/goconvey/convey" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/common/logger/config" +) + +func TestNewDevelopmentLog(t *testing.T) { + convey.Convey( + "Test NewDevelopmentLog", t, func() { + convey.Convey( + "NewDevelopmentLog success", func() { + logger, err := NewDevelopmentLog() + convey.So(logger, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestNewConsoleLog(t *testing.T) { + convey.Convey( + "Test ConsoleLog", t, func() { + logger, err := NewConsoleLog() + z := &LoggerWithFormat{logger} + convey.Convey( + "NewConsoleLog success", func() { + convey.So(logger.Name(), convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Infof success", func() { + convey.So(func() { + z.Infof("msg:%s", "ok") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Errorf success", func() { + convey.So(func() { + z.Errorf("msg:%s", "err") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Warnf success", func() { + convey.So(func() { + z.Warnf("msg:%s", "warning") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Debugf success", func() { + convey.So(func() { + z.Debugf("msg:%s", "ok") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Debugf success when config.LogLevel==zapcore.DebugLevel", func() { + convey.So(func() { + config.LogLevel = zapcore.DebugLevel + z.Debugf("msg:%s", "ok") + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestNewWithLevel(t *testing.T) { + convey.Convey( + "Test NewWithLevel", t, func() { + convey.Convey( + "NewWithLevel success", func() { + logger, err := NewWithLevel(config.CoreInfo{FilePath: "./"}, false) + convey.So(logger, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "NewWithLevel success when isAsync==true", func() { + logger, err := NewWithLevel(config.CoreInfo{FilePath: "./"}, true) + convey.So(logger, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "NewWithLevel success when CoreInfo is empty", func() { + logger, err := NewWithLevel(config.CoreInfo{}, false) + convey.So(logger, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "NewWithLevel success when CoreInfo is empty1", func() { + coreInfo := config.CoreInfo{ + FilePath: "./", + Tick: 1, + First: 1, + Thereafter: 1, + } + logger, err := NewWithLevel(coreInfo, false) + convey.So(logger, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + cleanFile(".") + }, + ) +} + +func cleanFile(fileName string) { + files, _ := os.ReadDir("./") + for _, file := range files { + flag := strings.HasPrefix(file.Name(), fileName) + if flag { + os.Remove(file.Name()) + } + } +} + +func TestPriorityHandler(t *testing.T) { + convey.Convey( + "Test priorityHandler", t, func() { + convey.Convey( + "priorityHandler success", func() { + flag := priorityHandler(zapcore.InfoLevel) + convey.So(flag, convey.ShouldBeTrue) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/signal.go b/api/go/libruntime/common/signal.go new file mode 100644 index 0000000..6b0a345 --- /dev/null +++ b/api/go/libruntime/common/signal.go @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package common for tools +package common + +import ( + "os" + "os/signal" + "syscall" + + "yuanrong.org/kernel/runtime/libruntime/common/logger/log" +) + +var ( + shutdownSignals = []os.Signal{os.Interrupt, syscall.SIGTERM, syscall.SIGKILL} + shutdownHandler chan os.Signal + stopCh = make(chan struct{}) +) + +const channelCount = 2 + +func init() { + // 2 is the length of shutdown Handler channel + shutdownHandler = make(chan os.Signal, channelCount) + + signal.Notify(shutdownHandler, shutdownSignals...) + + go func() { + <-shutdownHandler + close(stopCh) + <-shutdownHandler + log.GetLogger().Sync() + os.Exit(1) + }() +} + +// WaitForSignal defines signal handler process. +func WaitForSignal() <-chan struct{} { + return stopCh +} diff --git a/api/go/libruntime/common/token_mgr.go b/api/go/libruntime/common/token_mgr.go new file mode 100644 index 0000000..d41bbe0 --- /dev/null +++ b/api/go/libruntime/common/token_mgr.go @@ -0,0 +1,100 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package common for token +package common + +import ( + "sync" + "time" + + "yuanrong.org/kernel/runtime/libruntime/common/logger/log" +) + +const ( + updateTokenInterval = 60 * time.Second +) + +// TokenMgr - +type TokenMgr struct { + token string + salt string + lock sync.RWMutex + + callback func(ctx *AuthContext) +} + +var tokenMgr = &TokenMgr{ + lock: sync.RWMutex{}, +} + +// SetCallback - +func (t *TokenMgr) SetCallback(f func(ctx *AuthContext)) { + t.lock.Lock() + t.callback = f + + if t.token != "" { + f(&AuthContext{ + Token: t.token, + Salt: t.salt, + }) + } + t.lock.Unlock() + log.GetLogger().Infof("set callback ok") +} + +// GetTokenMgr - +func GetTokenMgr() *TokenMgr { + return tokenMgr +} + +// GetToken - +func (t *TokenMgr) GetToken() string { + t.lock.RLock() + defer t.lock.RUnlock() + return t.token +} + +// UpdateToken - +func (t *TokenMgr) UpdateToken(auth *AuthContext) { + t.lock.Lock() + defer t.lock.Unlock() + if auth == nil || auth.Token == t.token || auth.Token == "" { + return + } + log.GetLogger().Infof("recv token") + t.token = auth.Token + t.salt = auth.Salt + if t.callback != nil { + t.callback(auth) + } +} + +// UpdateTokenLoop - +func (t *TokenMgr) UpdateTokenLoop(stopCh <-chan struct{}) { + ticker := time.NewTicker(updateTokenInterval) + defer ticker.Stop() + for { + select { + case _, ok := <-stopCh: + if !ok { + log.GetLogger().Errorf("updateToken is stopping") + return + } + case <-ticker.C: + } + } +} diff --git a/api/go/libruntime/common/token_mgr_test.go b/api/go/libruntime/common/token_mgr_test.go new file mode 100644 index 0000000..4005058 --- /dev/null +++ b/api/go/libruntime/common/token_mgr_test.go @@ -0,0 +1,79 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package common for token +package common + +import ( + "testing" + "time" + + "github.com/smartystreets/goconvey/convey" +) + +func TestTokenMgr(t *testing.T) { + convey.Convey("Test token Mgr: GetToken", t, func() { + ch := make(chan struct{}) + t := GetTokenMgr() + t.callback = func(ctx *AuthContext) {} + t.UpdateToken(&AuthContext{ + ServerAuthEnable: false, + RootCertData: nil, + ModuleCertData: nil, + ModuleKeyData: nil, + Token: "fakeToken", + Salt: "134134134314134", + EnableServerMode: false, + ServerNameOverride: "", + }) + tk := GetTokenMgr().GetToken() + convey.So(tk, convey.ShouldEqual, "fakeToken") + GetTokenMgr().UpdateToken(nil) + close(ch) + }) +} + +func TestSetCallback(t *testing.T) { + convey.Convey( + "Test SetCallback", t, func() { + convey.Convey( + "SetCallback success", func() { + f := func(ctx *AuthContext) {} + convey.So(func() { + GetTokenMgr().SetCallback(f) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestUpdateTokenLoop(t *testing.T) { + convey.Convey( + "Test UpdateTokenLoop", t, func() { + convey.Convey( + "UpdateTokenLoop success", func() { + ch := WaitForSignal() + convey.So(func() { + go GetTokenMgr().UpdateTokenLoop(ch) + time.Sleep(300 * time.Millisecond) + close(stopCh) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/common/types.go b/api/go/libruntime/common/types.go new file mode 100644 index 0000000..ab98849 --- /dev/null +++ b/api/go/libruntime/common/types.go @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package common for types +package common + +// AuthContext - +type AuthContext struct { + ServerAuthEnable bool `json:"serverAuthEnable" valid:"optional"` + RootCertData []byte `json:"rootCertData" valid:"optional"` + ModuleCertData []byte `json:"moduleCertData" valid:"optional"` + ModuleKeyData []byte `json:"moduleKeyData" valid:"optional"` + Token string `json:"token" valid:"optional"` + Salt string `json:"salt" valid:"optional"` + EnableServerMode bool `json:"enableServerMode" valid:"optional"` + ServerNameOverride string `json:"serverNameOverride" valid:"optional"` +} + +// TLSConfOptions defines the struct of TLS config options, which is used to initialize grpc transport credentials. +type TLSConfOptions struct { + TLSEnable bool + RootCAData []byte + CertData []byte + KeyData []byte +} diff --git a/api/go/libruntime/common/utils/utils.go b/api/go/libruntime/common/utils/utils.go new file mode 100644 index 0000000..ae2a56c --- /dev/null +++ b/api/go/libruntime/common/utils/utils.go @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils for common functions +package utils + +import ( + "errors" + "path/filepath" + "strings" +) + +// ValidateFilePath verify the legitimacy of the file path +func ValidateFilePath(path string) error { + absPath, err := filepath.Abs(path) + if err != nil || !strings.HasPrefix(path, absPath) { + return errors.New("invalid file path, expect to be configured as an absolute path") + } + return nil +} + +// GetFuncNameFromFuncLibPath parses function name out of FUNCTION_LIB_PATH +func GetFuncNameFromFuncLibPath(funcLibPath string) string { + if funcLibPath == "" { + return "" + } + funcLibPathSplits := strings.Split(funcLibPath, "/") + return funcLibPathSplits[len(funcLibPathSplits)-1] +} diff --git a/api/go/libruntime/common/utils/utils_test.go b/api/go/libruntime/common/utils/utils_test.go new file mode 100644 index 0000000..087271f --- /dev/null +++ b/api/go/libruntime/common/utils/utils_test.go @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils for common functions +package utils + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestValidateFilePath1(t *testing.T) { + convey.Convey("TestValidateFilePath", t, func() { + convey.Convey("ValidateFilePath success", func() { + err := ValidateFilePath("./") + convey.So(err.Error(), convey.ShouldContainSubstring, "invalid file path") + }) + convey.Convey("ValidateFilePath success when path==/home/test", func() { + err := ValidateFilePath("/home/test") + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestGetFuncNameFromFuncLibPath(t *testing.T) { + convey.Convey("TestGetFuncNameFromFuncLibPath", t, func() { + convey.Convey("GetFuncNameFromFuncLibPath success", func() { + str := GetFuncNameFromFuncLibPath("/home/test/function_lib") + convey.So(str, convey.ShouldEqual, "function_lib") + }) + convey.Convey("ValidateFilePath success when funcLibPath==\"\"", func() { + str := GetFuncNameFromFuncLibPath("") + convey.So(str, convey.ShouldBeEmpty) + }) + }) +} diff --git a/api/go/libruntime/common/uuid/uuid.go b/api/go/libruntime/common/uuid/uuid.go new file mode 100644 index 0000000..a49685a --- /dev/null +++ b/api/go/libruntime/common/uuid/uuid.go @@ -0,0 +1,93 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package uuid for common functions +package uuid + +import ( + "crypto/rand" + "encoding/hex" + "io" +) + +const ( + defaultByteNum = 16 + indexFour = 4 + indexSix = 6 + indexEight = 8 + indexNine = 9 + indexTen = 10 + indexThirteen = 13 + indexFourteen = 14 + indexEighteen = 18 + indexNineteen = 19 + indexTwentyThree = 23 + indexTwentyFour = 24 + indexThirtySix = 36 +) + +// RandomUUID - +type RandomUUID [defaultByteNum]byte + +var ( + rander = rand.Reader // random function +) + +// New - +func New() RandomUUID { + return mustUUID(newRandom()) +} + +func mustUUID(uuid RandomUUID, err error) RandomUUID { + if err != nil { + return RandomUUID{} + } + return uuid +} + +func newRandom() (RandomUUID, error) { + return newRandomFromReader(rander) +} + +func newRandomFromReader(r io.Reader) (RandomUUID, error) { + var randomUUID RandomUUID + _, err := io.ReadFull(r, randomUUID[:]) + if err != nil { + return RandomUUID{}, err + } + randomUUID[indexSix] = (randomUUID[indexSix] & 0x0f) | 0x40 // Version 4 + randomUUID[indexEight] = (randomUUID[indexEight] & 0x3f) | 0x80 // Variant is 10 + return randomUUID, nil +} + +// String- +func (uuid RandomUUID) String() string { + var buf [indexThirtySix]byte + encodeHex(buf[:], uuid) + return string(buf[:]) +} + +func encodeHex(dstBuf []byte, uuid RandomUUID) { + hex.Encode(dstBuf, uuid[:indexFour]) + dstBuf[indexEight] = '-' + hex.Encode(dstBuf[indexNine:indexThirteen], uuid[indexFour:indexSix]) + dstBuf[indexThirteen] = '-' + hex.Encode(dstBuf[indexFourteen:indexEighteen], uuid[indexSix:indexEight]) + dstBuf[indexEighteen] = '-' + hex.Encode(dstBuf[indexNineteen:indexTwentyThree], uuid[indexEight:indexTen]) + dstBuf[indexTwentyThree] = '-' + hex.Encode(dstBuf[indexTwentyFour:], uuid[indexTen:]) +} diff --git a/api/go/libruntime/common/uuid/uuid_test.go b/api/go/libruntime/common/uuid/uuid_test.go new file mode 100644 index 0000000..81c9d14 --- /dev/null +++ b/api/go/libruntime/common/uuid/uuid_test.go @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package uuid for common functions +package uuid + +import ( + "errors" + "reflect" + "strings" + "testing" +) + +func TestNew(t *testing.T) { + uid1 := New() + uid2 := New() + if reflect.DeepEqual(uid1.String(), uid2.String()) { + t.Errorf("uid1 == uid2 failed") + } +} + +func TestMustUUID(t *testing.T) { + uuid := mustUUID(RandomUUID{}, errors.New("uuid is empty")) + newUUID := RandomUUID{} + if uuid != newUUID { + t.Errorf("mustUUID when err != nil failed") + } +} + +func TestNewRandomFromReader(t *testing.T) { + r := strings.NewReader("") + uuid, err := newRandomFromReader(r) + newUUID := RandomUUID{} + if uuid != newUUID || err == nil { + t.Errorf("newRandomFromReader when r == \"\" failed") + } +} diff --git a/api/go/libruntime/config/config.go b/api/go/libruntime/config/config.go new file mode 100644 index 0000000..c67d7e0 --- /dev/null +++ b/api/go/libruntime/config/config.go @@ -0,0 +1,159 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config for init +package config + +import ( + "unsafe" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +// InvokeType invoke type +type InvokeType int32 + +const ( + // CreateInstance create actor + CreateInstance InvokeType = 0 + // InvokeInstance invoke actor + InvokeInstance InvokeType = 1 + // CreateInstanceStateless create task + CreateInstanceStateless InvokeType = 2 + // InvokeInstanceStateless invoke task + InvokeInstanceStateless InvokeType = 3 +) + +const ( + // DefaultSignal default signal + DefaultSignal int = 0 + // KillInstance kill single signal + KillInstance int = 1 + // KillAllInstances kill all signal + KillAllInstances int = 2 + // Exit exit signal + Exit int = 64 + // Cancel cancel signal + Cancel int = 65 +) + +// DataBuffer save the data pointer +type DataBuffer struct { + Ptr unsafe.Pointer + Size int + SharedPtrBuffer unsafe.Pointer +} + +// DataObject data obj +type DataObject struct { + ID string + Buffer DataBuffer + NestedObjectIds []string + CSharedPtr unsafe.Pointer +} + +// LoadFunctionCallback load function callback +type LoadFunctionCallback func([]string) error + +// FunctionExecuteCallback function execute callback +type FunctionExecuteCallback func(api.FunctionMeta, InvokeType, []api.Arg, []DataObject) error + +// CheckpointCallback checkpoint callback +type CheckpointCallback func(string) ([]byte, error) + +// RecoverCallback recover callback +type RecoverCallback func([]byte) error + +// ShutdownCallback shutdown callback +type ShutdownCallback func(uint64) error + +// SignalCallback signal callback +type SignalCallback func(int, []byte) error + +// HealthCheckCallback health check callback +type HealthCheckCallback func() (api.HealthType, error) + +// HookIntfs hook struct +type HookIntfs struct { + LoadFunctionCb LoadFunctionCallback + FunctionExecutionCb FunctionExecuteCallback + CheckpointCb CheckpointCallback + RecoverCb RecoverCallback + ShutdownCb ShutdownCallback + SignalCb SignalCallback + HealthCheckCb HealthCheckCallback +} + +// Pool execution pool abstract interface +type Pool interface { + Submit(task func()) error +} + +// Config init config +type Config struct { + FunctionSystemAddress string + FunctionSystemRtServerIPAddr string + FunctionSystemRtServerPort int + GrpcAddress string + DataSystemAddress string + DataSystemIPAddr string + DataSystemPort int + IsDriver bool + JobID string + RuntimeID string + InstanceID string + FunctionName string + LogLevel string + LogDir string + LogFileSizeMax uint32 + LogFileNumMax uint32 + LogFlushInterval int + Hooks HookIntfs + FunctionExectionPool Pool + RecycleTime int + MaxTaskInstanceNum int + MaxConcurrencyCreateNum int + EnableSigaction bool + EnableMetrics bool + ThreadPoolSize uint32 + LoadPaths []string + EnableMTLS bool + PrivateKeyPath string + CertificateFilePath string + VerifyFilePath string + PrivateKeyPaaswd string + HttpIocThreadsNum uint32 + ServerName string + InCluster bool + Namespace string + Api api.ApiType + FunctionId string + SystemAuthAccessKey string + SystemAuthSecretKey string + // EnableCallStack whether to enable distribute call stack + EnableCallStack bool + CallStackLayerNum int + EncryptPrivateKeyPasswd string + PrimaryKeyStoreFile string + StandbyKeyStoreFile string + EnableDsEncrypt bool + RuntimePublicKeyContext string + RuntimePrivateKeyContext string + DsPublicKeyContext string + EncryptRuntimePublicKeyContext string + EncryptRuntimePrivateKeyContext string + EncryptDsPublicKeyContext string +} diff --git a/api/go/libruntime/cpplibruntime/BUILD.bazel b/api/go/libruntime/cpplibruntime/BUILD.bazel new file mode 100644 index 0000000..9fcfff2 --- /dev/null +++ b/api/go/libruntime/cpplibruntime/BUILD.bazel @@ -0,0 +1,53 @@ +load("//bazel:yr.bzl", "COPTS", "LOPTS") + +cc_library( + name = "cpplibruntime_lib", + deps = [ + "//:runtime_lib", + "@securec//:securec", + ], + hdrs = [ + "clibruntime.h", + ], + srcs = [ + "cpplibruntime.cpp", + ], + alwayslink = True, + visibility = ["//visibility:public"], +) + +cc_binary( + name = "libcpplibruntime.so", + dynamic_deps = ["//:grpc_dynamic"], + deps = select( + { + "test_mode": [":mockcpplibruntime_lib"], + "//conditions:default": [":cpplibruntime_lib"], + } + ), + copts = COPTS, + linkopts = LOPTS, + linkshared = True, + linkstatic = True, + visibility = ["//visibility:public"], +) + +config_setting( + name = "test_mode", + values = { "define": "mode=test" } +) + +cc_library( + name = "mockcpplibruntime_lib", + deps = [ + "//:runtime_lib", + ], + hdrs = [ + "clibruntime.h", + ], + srcs = [ + "mock/mock_cpplibruntime.cpp", + ], + alwayslink = True, + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/api/go/libruntime/cpplibruntime/clibruntime.h b/api/go/libruntime/cpplibruntime/clibruntime.h new file mode 100644 index 0000000..aa758ec --- /dev/null +++ b/api/go/libruntime/cpplibruntime/clibruntime.h @@ -0,0 +1,493 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "stdint.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef struct tagCInvokeLabels { + char *key; + char *value; +} CInvokeLabels; + +typedef struct tagCCustomExtension { + char *key; + char *value; +} CCustomExtension; + +typedef struct tagStackTrace { + char *className; + char *methodName; + char *fileName; + int64_t lineNumber; + CCustomExtension *extensions; + int size_extensions; +} CStackTrace; + +typedef struct tagStackTracesInfo { + int code; + int mcode; + char *message; + CStackTrace *stackTraces; + int size_stackTraces; +} CStackTracesInfo; + +typedef struct tagCErrorInfo { + int code; + char *message; + CStackTracesInfo *stackTracesInfo; + int size_stackTracesInfo; + int dsStatusCode; +} CErrorInfo; + +typedef enum tagCArgType { + VALUE = 0, + OBJECT_REF, +} CArgType; + +typedef enum tagCApiType { + ACTOR = 0, + FAAS = 1, + POSIX = 2, +} CApiType; + +typedef struct tagCArg { + CArgType type; + char *data; + uint64_t size; +} CArg; + +typedef struct tagCFunctionMeta { + char *appName; + char *moduleName; + char *funcName; + char *className; + int languageType; + char *codeId; + char *signature; + char *poolLabel; + CApiType apiType; + char *functionId; + char hasName; + char *name; + char hasNs; + char *ns; +} CFunctionMeta; + +typedef struct tagCInstanceAllocation { + char *functionId; + char *funcSig; + char *instanceId; + char *leaseId; + int tLeaseInterval; +} CInstanceAllocation; + +typedef struct tagCInstanceSession { + char hasSessionId; + char *sessionId; + int sessionTtl; + int concurrency; +} CInstanceSession; + +typedef int CInvokeType; + +typedef struct tagCBuffer { + void *buffer; + int64_t size_buffer; + void *selfSharedPtrBuffer; +} CBuffer; + +typedef struct tagCDataObject { + char *id; + CBuffer buffer; + char **nestedObjIds; + int size_nestedObjIds; + void *selfSharedPtr; +} CDataObject; + +typedef struct tagCLibruntimeConfig { + char *functionSystemAddress; + char *dataSystemAddress; + char *grpcAddress; + char *jobId; + char *runtimeId; + char *instanceId; + char *functionName; + char *logLevel; + char *logDir; + CApiType apiType; + char inCluster; + char isDriver; + char enableMTLS; + char *privateKeyPath; + char *certificateFilePath; + char *verifyFilePath; + char *privateKeyPaaswd; + char *functionId; + char *systemAuthAccessKey; + char *systemAuthSecretKey; + int systemAuthSecretKeySize; + char *encryptPrivateKeyPasswd; + char *primaryKeyStoreFile; + char *standbyKeyStoreFile; + char enableDsEncrypt; + char *runtimePublicKeyContext; + char *runtimePrivateKeyContext; + char *dsPublicKeyContext; + char *encryptRuntimePublicKeyContext; + char *encryptRuntimePrivateKeyContext; + char *encryptDsPublicKeyContext; + int maxConcurrencyCreateNum; + char enableSigaction; +} CLibruntimeConfig; + +typedef struct tagCInvokeArg { + // go memory, no need to free + void *buf; + int64_t size_buf; + char isRef; + char *objId; + char *tenantId; + char **nestedObjects; + int size_nestedObjects; +} CInvokeArg; + +typedef struct tagCCustomResource { + char *name; + float scalar; +} CCustomResource; + +typedef struct tagCCreateOpt { + char *key; + char *value; +} CCreateOpt; + +typedef enum tagCAffinityKind { + RESOURCE = 0, + INSTANCE = 1, +} CAffinityKind; + +typedef enum tagCAffinityType { + PREFERRED = 0, + PREFERRED_ANTI = 1, + REQUIRED = 2, + REQUIRED_ANTI = 3, +} CAffinityType; + +typedef enum tagCLabelOpType { + IN = 0, + NOT_IN = 1, + EXISTS = 2, + NOT_EXISTS = 3, +} CLabelOpType; + +typedef struct tagCLabelOperator { + CLabelOpType opType; + char *labelKey; + char **labelValues; + int size_labelValues; +} CLabelOperator; + +typedef struct tagCAffinity { + CAffinityKind affKind; + CAffinityType affType; + char preferredPrio; + char preferredAntiOtherLabels; + CLabelOperator *labelOps; + int size_labelOps; +} CAffinity; + +typedef struct TagCCredential { + char *ak; + char *sk; + int sizeSk; + char *dk; + int sizeDk; +} CCredential; + +typedef struct tagCInvokeOptions { + int cpu; + int memory; + CCustomResource *customResources; + int size_customResources; + CCustomExtension *customExtensions; + int size_customExtensions; + CCreateOpt *createOpt; + int size_createOpt; + char **labels; + int size_labels; + CAffinity *schedAffinities; + int RetryTimes; + int RecoverRetryTimes; + int size_schedAffinities; + char **codePaths; + int size_codePaths; + char *schedulerFunctionId; + char **schedulerInstanceIds; + int size_schedulerInstanceIds; + char *traceId; + int timeout; + int acquireTimeout; + char trafficLimited; + CInvokeLabels *invokeLabels; + int size_invokeLabels; + CInstanceSession *instanceSession; + int64_t scheduleTimeoutMs; +} CInvokeOptions; + +typedef struct tagCErrorObject { + char *objectId; + CErrorInfo *errorInfo; +} CErrorObject; + +typedef struct tagCWaitResult { + char **readyIds; + int size_readyIds; + char **unreadyIds; + int size_unreadyIds; + CErrorObject **errorIds; + int size_errorIds; +} CWaitResult; + +typedef enum tagCWriteMode { + NONE_L2_CACHE = 0, + WRITE_THROUGH_L2_CACHE = 1, // sync write + WRITE_BACK_L2_CACHE = 2, // async write + NONE_L2_CACHE_EVICT = 3, // evictable objects +} CWriteMode; + +typedef enum tagCExistenceOpt { + NONE = 0, + NX = 1, +} CExistenceOpt; + +typedef enum tagCCacheType { + MEMORY = 0, + DISK = 1, +} CCacheType; + +typedef struct tagCSetParam { + CWriteMode writeMode; + uint32_t ttlSecond; + CExistenceOpt existence; + CCacheType cacheType; +} CSetParam; + +typedef struct tagCMSetParam { + CWriteMode writeMode; + uint32_t ttlSecond; + CExistenceOpt existence; + CCacheType cacheType; +} CMSetParam; + +typedef enum tagCConsistencyType { + PRAM = 0, + CAUSAL = 1, +} CConsistencyType; + +typedef struct tagCCreateParam { + CWriteMode writeMode; + CConsistencyType consistencyType; + CCacheType cacheType; +} CCreateParam; + +typedef struct tagCProducerConfig { + int64_t delayFlushTime; + int64_t pageSize; + uint64_t maxStreamSize; + char *traceId; +} CProducerConfig; + +typedef enum tagCSubscriptionType { + STREAM = 0, + ROUND_ROBIN, + KEY_PARTITIONS, + UNKNOWN, +} CSubscriptionType; + +typedef struct tagCSubscriptionConfig { + char *subscriptionName; + CSubscriptionType type; + char *traceId; +} CSubscriptionConfig; + +typedef void (*CGetAsyncCallback)(char *cObjectID, CErrorInfo *cErr, void *userData); + +typedef struct tagCElement { + uint8_t *ptr; + uint64_t size; + uint64_t id; +} CElement; + +typedef struct tagConnectArguments { + char *host; + int port; + int timeoutMs; + char *token; + int tokenLen; + char *clientPublicKey; + int clientPublicKeyLen; + char *clientPrivateKey; + int clientPrivateKeyLen; + char *serverPublicKey; + int serverPublicKeyLen; + char *accessKey; + int accessKeyLen; + char *secretKey; + int secretKeyLen; + char *authClientID; + int authClientIDLen; + char *authClientSecret; + int authClientSecretLen; + char *authUrl; + int authUrlLen; + char *tenantID; + int tenantIDLen; + char enableCrossNodeConnection; +} CConnectArguments; + +typedef void *Consumer_p; +typedef void *Producer_p; +typedef void *CStateStorePtr; + +typedef enum tagCHealthCheckCode { + HEALTHY = 0, + HEALTH_CHECK_FAILED, + SUB_HEALTH, +} CHealthCheckCode; + +void CParseCErrorObjectPointer(CErrorObject *object, int *code, char **errMessage, char **objectId, + CStackTracesInfo *stackTracesInfo); + +// function +CErrorInfo CCreateInstance(CFunctionMeta *funcMeta, CInvokeArg *invokeArgs, int size_invokeArgs, CInvokeOptions *opts, + char **instanceId); +CErrorInfo CInvokeByInstanceId(CFunctionMeta *funcMeta, char *instanceId, CInvokeArg *invokeArgs, int size_invokeArgs, + CInvokeOptions *opts, char **cReturnObjectId); +CErrorInfo CInvokeByFunctionName(CFunctionMeta *funcMeta, CInvokeArg *invokeArgs, int size_invokeArgs, + CInvokeOptions *opts, char **cReturnObjectId); +CErrorInfo CAcquireInstance(char *stateId, CFunctionMeta *cFuncMeta, CInvokeOptions *cInvokeOpts, + CInstanceAllocation *cInsAlloc); + +CErrorInfo CReleaseInstance(CInstanceAllocation *insAlloc, char *cStateID, char cAbnormal, CInvokeOptions *cInvokeOpts); +void CCreateInstanceRaw(CBuffer cReqRaw, char *cContext); +void CInvokeByInstanceIdRaw(CBuffer cReqRaw, char *cContext); +void CKillRaw(CBuffer cReqRaw, char *cContext); + +extern void GoRawCallback(char *cKey, CErrorInfo cErr, CBuffer cResultRaw); + +// object +CErrorInfo CSetTenantId(const char *cTenantId, int cTenantIdLen); +void CWait(char **objIds, int size_objIds, int waitNum, int timeoutSec, CWaitResult *result); +CErrorInfo CPutCommon(char *objectId, CBuffer data, char **nestedIds, int sizeNestedIds, char isPutRaw, + CCreateParam param); +CErrorInfo CGetMultiCommon(char **cObjIds, int size_cObjIds, int timeoutMs, char allowPartial, CBuffer *cData, + char isRaw); +CErrorInfo CGet(char *objId, int timeoutSec, CBuffer *data); +extern void GoGetAsyncCallback(char *cObjectID, CBuffer cBuf, CErrorInfo *cErr, void *userData); +extern void GoWaitAsyncCallback(char *cObjectID, CErrorInfo *cErr, void *userData); +void CUpdateSchdulerInfo(char *scheduleName, char *schedulerId, char *option); +void CGetAsync(char *objectId, void *userData); +void CWaitAsync(char *objectId, void *userData); +CErrorInfo CIncreaseReferenceCommon(char **cObjIds, int size_cObjIds, char *cRemoteId, char ***cFailedIds, + int *size_cFailedIds, char isRaw); +CErrorInfo CDecreaseReferenceCommon(char **cObjIds, int size_cObjIds, char *cRemoteId, char ***cFailedIds, + int *size_cFailedIds, char isRaw); +CErrorInfo CReleaseGRefs(char *cRemoteId); +CErrorInfo CAllocReturnObject(CDataObject *object, int dataSize, char **nestedIds, int sizeNestedIds, + uint64_t *totalNativeBufferSize); +void CSetReturnObject(CDataObject *cObject, int dataSize); +void CCancel(char **objIds, int size_objIds, char force, char recursive); + +// KV +CErrorInfo CKVWrite(char *key, CBuffer data, CSetParam param); +CErrorInfo CKVMSetTx(char **key, int sizeKeys, CBuffer *data, CMSetParam param); +CErrorInfo CKVRead(char *key, int timeoutMs, CBuffer *data); +CErrorInfo CKVMultiRead(char **keys, int size_keys, int timeoutMs, char allowPartial, CBuffer *data); +CErrorInfo CKVDel(char *key); +CErrorInfo CKMultiVDel(char **keys, int size_keys, char ***cFailedKeys, int *size_cFailedKeys); + +// stream +CErrorInfo CCreateStreamProducer(char *streamName, CProducerConfig *config, Producer_p *producer); +CErrorInfo CCreateStreamConsumer(char *streamName, CSubscriptionConfig *config, Consumer_p *consumer); +CErrorInfo CDeleteStream(char *streamName); +CErrorInfo CQueryGlobalProducersNum(char *streamName, uint64_t *num); +CErrorInfo CQueryGlobalConsumersNum(char *streamName, uint64_t *num); +CErrorInfo CProducerSend(Producer_p producerPtr, uint8_t *ptr, uint64_t size, uint64_t id); +CErrorInfo CProducerSendWithTimeout(Producer_p producerPtr, uint8_t *ptr, uint64_t size, uint64_t id, + int64_t timeoutMs); +CErrorInfo CProducerFlush(Producer_p producerPtr); +CErrorInfo CProducerClose(Producer_p producerPtr); +CErrorInfo CConsumerReceive(Consumer_p consumerPtr, uint32_t timeoutMs, CElement **elements, uint64_t *count); +CErrorInfo CConsumerReceiveExpectNum(Consumer_p consumerPtr, uint32_t expectNum, uint32_t timeoutMs, + CElement **elements, uint64_t *count); +CErrorInfo CConsumerAck(Consumer_p consumerPtr, uint64_t elementId); +CErrorInfo CConsumerClose(Consumer_p consumerPtr); + +// management +void CExit(int code, char *message); +CErrorInfo CKill(char *instanceId, int sigNo, CBuffer cData); +void CFinalize(void); +CErrorInfo CInit(CLibruntimeConfig *config); +void CReceiveRequestLoop(void); +void CExecShutdownHandler(int sigNum); +char *CGetRealInstanceId(char *objectId, int timeout); +void CSaveRealInstanceId(char *objectId, char *instanceId); + +// buffer +CErrorInfo CWriterLatch(CBuffer *cBuffer); +CErrorInfo CMemoryCopy(CBuffer *cBuffer, void *cSrc, uint64_t size_cSrc); +CErrorInfo CSeal(CBuffer *cBuffer); +CErrorInfo CWriterUnlatch(CBuffer *cBuffer); + +// handlers +extern CErrorInfo *GoLoadFunctions(char **codePaths, int size_codePaths); +extern CErrorInfo *GoFunctionExecution(CFunctionMeta *, CInvokeType, CArg *, int, CDataObject *, int); +extern CErrorInfo *GoCheckpoint(char *checkpointId, CBuffer *buffer); +extern CErrorInfo *GoRecover(CBuffer *buffer); +extern CErrorInfo *GoShutdown(uint64_t gracePeriodSeconds); +extern CErrorInfo *GoSignal(int sigNo, CBuffer *payload); +extern CHealthCheckCode GoHealthCheck(void); +extern char GoHasHealthCheck(void); + +// pool +extern void GoFunctionExecutionPoolSubmit(void *ptr); +void CFunctionExecution(void *ptr); + +// kv client +extern CErrorInfo CCreateStateStore(CConnectArguments *arguments, CStateStorePtr *stateStorePtr); +extern CErrorInfo CSetTraceId(const char *cTraceId, int traceId); +extern CErrorInfo CGenerateKey(CStateStorePtr stateStorePtr, char **cKey, int *cKeyLen); +extern void CDestroyStateStore(CStateStorePtr stateStorePtr); +extern CErrorInfo CSetByStateStore(CStateStorePtr stateStorePtr, char *key, CBuffer data, CSetParam param); +extern CErrorInfo CSetValueByStateStore(CStateStorePtr stateStorePtr, CBuffer data, CSetParam param, char **cKey, + int *cKeyLen); +extern CErrorInfo CGetByStateStore(CStateStorePtr stateStorePtr, char *key, CBuffer *data, int timeoutMs); +extern CErrorInfo CGetArrayByStateStore(CStateStorePtr stateStorePtr, char **cKeys, int cKeysLen, CBuffer *data, + int timeoutMs); +extern CErrorInfo CQuerySizeByStateStore(CStateStorePtr stateStorePtr, char **cKeys, int sizeCKeys, uint64_t *cSize); +extern CErrorInfo CDelByStateStore(CStateStorePtr stateStorePtr, char *key); +extern CErrorInfo CDelArrayByStateStore(CStateStorePtr stateStorePtr, char **cKeys, int cKeysLen, char ***cFailedKeys, + int *cFailedKeysLen); +extern CCredential CGetCredential(); +extern int CIsHealth(); +extern int CIsDsHealth(); +#ifdef __cplusplus +} +#endif diff --git a/api/go/libruntime/cpplibruntime/cpplibruntime.cpp b/api/go/libruntime/cpplibruntime/cpplibruntime.cpp new file mode 100644 index 0000000..026017b --- /dev/null +++ b/api/go/libruntime/cpplibruntime/cpplibruntime.cpp @@ -0,0 +1,1761 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "clibruntime.h" + +#include +#include "securec.h" +#include "src/libruntime/libruntime_manager.h" +#include "src/libruntime/utils/utils.h" +using namespace std; +using libruntime::SubscriptionType; +using YR::Libruntime::Affinity; +using namespace std::placeholders; +using YR::Libruntime::Buffer; +using YR::Libruntime::CacheType; +using YR::Libruntime::ConsistencyType; +using YR::Libruntime::CreateParam; +using YR::Libruntime::DataObject; +using YR::Libruntime::Element; +using YR::Libruntime::ErrorCode; +using YR::Libruntime::ErrorInfo; +using YR::Libruntime::ExistenceOpt; +using YR::Libruntime::FunctionMeta; +using YR::Libruntime::InstanceAllocation; +using YR::Libruntime::InvokeArg; +using YR::Libruntime::InvokeOptions; +using YR::Libruntime::LabelDoesNotExistOperator; +using YR::Libruntime::LabelExistsOperator; +using YR::Libruntime::LabelInOperator; +using YR::Libruntime::LabelNotInOperator; +using YR::Libruntime::LabelOperator; +using YR::Libruntime::LibruntimeConfig; +using YR::Libruntime::LibruntimeManager; +using YR::Libruntime::MSetParam; +using YR::Libruntime::NativeBuffer; +using YR::Libruntime::ProducerConf; +using YR::Libruntime::SetParam; +using YR::Libruntime::StackTraceElement; +using YR::Libruntime::StackTraceInfo; +using YR::Libruntime::StreamConsumer; +using YR::Libruntime::StreamProducer; +using YR::Libruntime::SubscriptionConfig; +using YR::Libruntime::WriteMode; + +#ifdef __cplusplus +extern "C" { +#endif + +#define RETURN_ERR_WHEN_CONSUMER_ISNULL(consumer) \ + if (consumer == nullptr || *consumer == nullptr) { \ + if (consumer != nullptr) { \ + delete consumer; \ + } \ + return ErrorInfoToCError(ErrorInfo(ErrorCode::ERR_PARAM_INVALID, "failed to get valid consumer.")); \ + } + +std::tuple, ErrorInfo> getLibRuntime() +{ + auto lrt = LibruntimeManager::Instance().GetLibRuntime(); + if (lrt == nullptr) { + YRLOG_ERROR("GetLibRuntime empty"); + return {lrt, ErrorInfo(YR::Libruntime::ErrorCode::ERR_FINALIZED, YR::Libruntime::ModuleCode::RUNTIME, + "libRuntime empty")}; + } + + return {lrt, ErrorInfo(YR::Libruntime::ErrorCode::ERR_OK, YR::Libruntime::ModuleCode::RUNTIME, "")}; +} + +ErrorInfo CErrorToErrorInfo(CErrorInfo *cerr) +{ + ErrorInfo err; + if (cerr == nullptr) { + return err; + } + err.SetErrorCode(static_cast(cerr->code)); + if (cerr->message != nullptr) { + err.SetErrorMsg(cerr->message); + free(cerr->message); + } + if (cerr->size_stackTracesInfo == 0) { + free(cerr); + return err; + } + std::vector stackTraces; + for (int i = 0; i < cerr->size_stackTracesInfo; i++) { + std::vector elements; + for (int j = 0; j < cerr->stackTracesInfo[i].size_stackTraces; j++) { + StackTraceElement element; + element.className = cerr->stackTracesInfo[i].stackTraces[j].className; + element.methodName = cerr->stackTracesInfo[i].stackTraces[j].methodName; + element.fileName = cerr->stackTracesInfo[i].stackTraces[j].fileName; + element.lineNumber = cerr->stackTracesInfo[i].stackTraces[j].lineNumber; + for (int k = 0; k < cerr->stackTracesInfo[i].stackTraces[j].size_extensions; k++) { + element.extensions.emplace(cerr->stackTracesInfo[i].stackTraces[j].extensions[k].key, + cerr->stackTracesInfo[i].stackTraces[j].extensions[k].value); + free(cerr->stackTracesInfo[i].stackTraces[j].extensions[k].key); + free(cerr->stackTracesInfo[i].stackTraces[j].extensions[k].value); + } + elements.emplace_back(std::move(element)); + free(cerr->stackTracesInfo[i].stackTraces[j].className); + free(cerr->stackTracesInfo[i].stackTraces[j].methodName); + free(cerr->stackTracesInfo[i].stackTraces[j].fileName); + } + StackTraceInfo stackTrace; + stackTrace.SetStackTraceElements(elements); + stackTrace.SetMsg(cerr->stackTracesInfo[i].message); + free(cerr->stackTracesInfo[i].message); + free(cerr->stackTracesInfo[i].stackTraces); + stackTraces.emplace_back(std::move(stackTrace)); + } + err.SetStackTraceInfos(stackTraces); + free(cerr->stackTracesInfo->stackTraces); + free(cerr->stackTracesInfo); + free(cerr); + return err; +} + +ErrorInfo CHealthCheckCodeToErrorInfo(CHealthCheckCode code) +{ + ErrorInfo err; + if (code == CHealthCheckCode::HEALTHY) { + err.SetErrorCode(ErrorCode::ERR_HEALTH_CHECK_HEALTHY); + } else if (code == CHealthCheckCode::HEALTH_CHECK_FAILED) { + err.SetErrorCode(ErrorCode::ERR_HEALTH_CHECK_FAILED); + } else if (code == CHealthCheckCode::SUB_HEALTH) { + err.SetErrorCode(ErrorCode::ERR_HEALTH_CHECK_SUBHEALTH); + } else { + err.SetErrorCode(ErrorCode::ERR_HEALTH_CHECK_HEALTHY); + } + return err; +} + +char *CString(const std::string &str) +{ + char *cStr = (char *)malloc(str.size() + 1); + (void)memcpy_s(cStr, str.size(), str.data(), str.size()); + cStr[str.size()] = 0; + return cStr; +} + +CErrorInfo ErrorInfoToCError(const ErrorInfo &err) +{ + CErrorInfo cErr{}; + cErr.message = nullptr; + cErr.code = static_cast(err.Code()); + cErr.dsStatusCode = err.GetDsStatusCode(); + if (!err.Msg().empty()) { + cErr.message = CString(err.Msg()); + } + std::vector stackTracesInfo = err.GetStackTraceInfos(); + if (stackTracesInfo.size() == 0) { + return cErr; + } + cErr.size_stackTracesInfo = stackTracesInfo.size(); + cErr.stackTracesInfo = new (std::nothrow) CStackTracesInfo[stackTracesInfo.size()]; + + for (size_t i = 0; i < stackTracesInfo.size(); i++) { + std::vector elements = stackTracesInfo[i].StackTraceElements(); + cErr.stackTracesInfo[i].stackTraces = new (std::nothrow) CStackTrace[elements.size()]; + cErr.stackTracesInfo[i].size_stackTraces = elements.size(); + cErr.stackTracesInfo[i].message = CString(stackTracesInfo[i].Message()); + for (size_t j = 0; j < elements.size(); j++) { + cErr.stackTracesInfo[i].stackTraces[j].className = CString(elements[j].className); + cErr.stackTracesInfo[i].stackTraces[j].fileName = CString(elements[j].fileName); + cErr.stackTracesInfo[i].stackTraces[j].methodName = CString(elements[j].methodName); + cErr.stackTracesInfo[i].stackTraces[j].lineNumber = static_cast(elements[j].lineNumber); + cErr.stackTracesInfo[i].stackTraces[j].size_extensions = elements[j].extensions.size(); + cErr.stackTracesInfo[i].stackTraces[j].extensions = + new (std::nothrow) CCustomExtension[elements[j].extensions.size()]; + int k = 0; + for (auto iter = elements[j].extensions.begin(); iter != elements[j].extensions.end(); iter++) { + cErr.stackTracesInfo[i].stackTraces[j].extensions[k].key = CString(iter->first); + cErr.stackTracesInfo[i].stackTraces[j].extensions[k].value = CString(iter->second); + k++; + } + } + } + return cErr; +} + +SetParam CSetParamToSetParam(CSetParam param) +{ + SetParam setParam; + setParam.writeMode = static_cast(param.writeMode); + setParam.ttlSecond = param.ttlSecond; + setParam.existence = static_cast(param.existence); + setParam.cacheType = static_cast(param.cacheType); + return setParam; +} + +MSetParam CMSetParamToMSetParam(CMSetParam param) +{ + MSetParam mSetParam; + mSetParam.writeMode = static_cast(param.writeMode); + mSetParam.ttlSecond = param.ttlSecond; + mSetParam.existence = static_cast(param.existence); + mSetParam.cacheType = static_cast(param.cacheType); + return mSetParam; +} + +CreateParam CCreateParamToCreateParam(CCreateParam param) +{ + CreateParam createParam; + createParam.writeMode = static_cast(param.writeMode); + createParam.consistencyType = static_cast(param.consistencyType); + createParam.cacheType = static_cast(param.cacheType); + return createParam; +} + +void CheckNullAndAssignValue(const char *str, const int len, std::string &returnValue) +{ + if (str != nullptr && len > 0) { + returnValue = std::string(str, len); + } +} + +char *StringToCString(const std::string &input) +{ + char *ret = nullptr; + if (input.empty()) { + return ret; + } + size_t destSize = input.size() + 1; + ret = static_cast(malloc(destSize)); + if (ret != nullptr) { + int err = memcpy_s(ret, destSize, input.data(), input.size()); + if (err == EOK) { + ret[destSize - 1] = '\0'; + } else { + free(ret); + ret = nullptr; + YRLOG_ERROR("StringToCString memcpy_s failed: {}", err); + return nullptr; + } + } + return ret; +} + +CCredential ConverToCCredential(YR::Libruntime::Credential &credential) +{ + CCredential ccred{}; + ccred.ak = StringToCString(credential.ak); + ccred.sk = StringToCString(credential.sk); + ccred.sizeSk = credential.sk.length(); + ccred.dk = StringToCString(credential.dk); + ccred.sizeDk = credential.dk.length(); + return ccred; +} + +CCredential CredentialToCCre(YR::Libruntime::Credential &credential) +{ + return ConverToCCredential(credential); +} + +void StringsToCStrings(const std::vector &stringVec, char ***cStrings, int *cStringsLen) +{ + if (!stringVec.empty()) { + *cStrings = (char **)malloc(stringVec.size() * sizeof(char **)); + for (size_t i = 0; i < stringVec.size(); i++) { + (*cStrings)[i] = StringToCString(stringVec[i]); + } + *cStringsLen = stringVec.size(); + } +} + +std::vector CStringsToStrings(char **cKeys, int cKeysLen) +{ + std::vector keys(cKeysLen); + for (int i = 0; i < cKeysLen; i++) { + keys[i] = cKeys[i]; + } + return keys; +} + +ErrorInfo LoadFunctionsWrapper(const std::vector &codePaths) +{ + char **cPaths = new (std::nothrow) char *[codePaths.size()]; + for (size_t i = 0; i < codePaths.size(); i++) { + cPaths[i] = const_cast(codePaths[i].c_str()); + } + CErrorInfo *cerr = GoLoadFunctions(cPaths, codePaths.size()); + delete[] cPaths; + return CErrorToErrorInfo(cerr); +} + +void InsAllocationToCInsAllocation(const InstanceAllocation &alloc, CInstanceAllocation *cInsAlloc) +{ + cInsAlloc->functionId = CString(alloc.functionId); + cInsAlloc->funcSig = CString(alloc.funcSig); + cInsAlloc->instanceId = CString(alloc.instanceId); + cInsAlloc->leaseId = CString(alloc.leaseId); + cInsAlloc->tLeaseInterval = alloc.tLeaseInterval; +} + +CFunctionMeta FunctionMetaToCFunctionMeta(const FunctionMeta &function) +{ + CFunctionMeta cFuncMeta{}; + cFuncMeta.appName = const_cast(function.appName.c_str()); + cFuncMeta.funcName = const_cast(function.funcName.c_str()); + cFuncMeta.functionId = const_cast(function.functionId.c_str()); + return cFuncMeta; +} + +ErrorInfo FunctionExecutionWrapper(const FunctionMeta &function, const libruntime::InvokeType invokeType, + const std::vector> &rawArgs, + std::vector> &returnValues) +{ + CFunctionMeta cFuncMeta = FunctionMetaToCFunctionMeta(function); + + CArg *args = nullptr; + if (!rawArgs.empty()) { + args = new (std::nothrow) CArg[rawArgs.size()]; + for (size_t i = 0; i < rawArgs.size(); i++) { + args[i].type = CArgType::VALUE; + args[i].data = static_cast(rawArgs[i]->data->MutableData()); + args[i].size = rawArgs[i]->data->GetSize(); + } + } + + CDataObject *retObjs = nullptr; + if (!returnValues.empty()) { + retObjs = new (std::nothrow) CDataObject[returnValues.size()]; + for (size_t i = 0; i < returnValues.size(); i++) { + if (returnValues[i]->id.empty()) { + retObjs[i].id = CString("empty"); + } else { + retObjs[i].id = CString(returnValues[i]->id); + } + retObjs[i].selfSharedPtr = static_cast(returnValues[i].get()); + retObjs[i].nestedObjIds = nullptr; + retObjs[i].size_nestedObjIds = 0; + retObjs[i].buffer.buffer = nullptr; + retObjs[i].buffer.selfSharedPtrBuffer = nullptr; + retObjs[i].buffer.size_buffer = 0; + } + } + + CErrorInfo *cerr = + GoFunctionExecution(&cFuncMeta, CInvokeType(invokeType), args, rawArgs.size(), retObjs, returnValues.size()); + for (size_t i = 0; i < returnValues.size(); i++) { + free(retObjs[i].id); + } + delete[] args; + delete[] retObjs; + return CErrorToErrorInfo(cerr); +} + +ErrorInfo CheckpointWrapper(const std::string &checkpointId, std::shared_ptr &data) +{ + char *cChkptId = const_cast(checkpointId.c_str()); + CBuffer buf = {0}; + CErrorInfo *cerr = GoCheckpoint(cChkptId, &buf); + data = std::make_shared(buf.buffer, buf.size_buffer, true); + return CErrorToErrorInfo(cerr); +} + +ErrorInfo RecoverWrapper(std::shared_ptr data) +{ + CBuffer buf = {0}; + buf.buffer = data->MutableData(); + buf.size_buffer = data->GetSize(); + CErrorInfo *cerr = GoRecover(&buf); + return CErrorToErrorInfo(cerr); +} + +ErrorInfo ShutdownWrapper(uint64_t gracePeriodSeconds) +{ + YRLOG_INFO("start execute go shutdown handler"); + CErrorInfo *cerr = GoShutdown(gracePeriodSeconds); + YRLOG_INFO("end to execute shutdown handler"); + return CErrorToErrorInfo(cerr); +} + +ErrorInfo SignalWrapper(int sigNo, std::shared_ptr payload) +{ + CBuffer buf = {0}; + buf.buffer = payload->MutableData(); + buf.size_buffer = payload->GetSize(); + CErrorInfo *cerr = GoSignal(sigNo, &buf); + return CErrorToErrorInfo(cerr); +} + +ErrorInfo HealthCheckWrapper(void) +{ + auto code = GoHealthCheck(); + return CHealthCheckCodeToErrorInfo(code); +} + +void FuncExecSubmitHook(std::function &&f) +{ + auto funcPtr = new (std::nothrow) std::shared_ptr>(); + *funcPtr = std::make_shared>([f = std::move(f), funcPtr]() { + f(); + delete funcPtr; + }); + GoFunctionExecutionPoolSubmit(funcPtr); +} + +CErrorInfo CInit(CLibruntimeConfig *config) +{ + LibruntimeConfig librtCfg{}; + YR::ParseIpAddr(config->functionSystemAddress, librtCfg.functionSystemIpAddr, librtCfg.functionSystemPort); + YR::ParseIpAddr(config->grpcAddress, librtCfg.functionSystemRtServerIpAddr, librtCfg.functionSystemRtServerPort); + YR::ParseIpAddr(config->dataSystemAddress, librtCfg.dataSystemIpAddr, librtCfg.dataSystemPort); + librtCfg.runtimeId = config->runtimeId; + librtCfg.instanceId = config->instanceId; + librtCfg.functionName = config->functionName; + librtCfg.logDir = config->logDir; + librtCfg.logLevel = config->logLevel; + librtCfg.selfApiType = static_cast(config->apiType); + librtCfg.inCluster = config->inCluster != 0; + librtCfg.isDriver = config->isDriver != 0; + + librtCfg.enableMTLS = config->enableMTLS != 0; + librtCfg.privateKeyPath = config->privateKeyPath; + librtCfg.certificateFilePath = config->certificateFilePath; + librtCfg.verifyFilePath = config->verifyFilePath; + librtCfg.encryptPrivateKeyPasswd = config->encryptPrivateKeyPasswd; + librtCfg.primaryKeyStoreFile = config->primaryKeyStoreFile; + librtCfg.standbyKeyStoreFile = config->standbyKeyStoreFile; + librtCfg.encryptEnable = config->enableDsEncrypt != 0; + librtCfg.runtimePublicKey = config->runtimePublicKeyContext; + librtCfg.runtimePrivateKey = config->runtimePrivateKeyContext; + librtCfg.dsPublicKey = config->dsPublicKeyContext; + librtCfg.encryptRuntimePublicKeyContext = config->encryptRuntimePublicKeyContext; + librtCfg.encryptRuntimePrivateKeyContext = config->encryptRuntimePrivateKeyContext; + librtCfg.encryptDsPublicKeyContext = config->encryptDsPublicKeyContext; + librtCfg.ak_ = config->systemAuthAccessKey; + librtCfg.sk_ = datasystem::SensitiveValue(config->systemAuthSecretKey, config->systemAuthSecretKeySize); + auto len = sizeof(config->privateKeyPaaswd); + (void)memcpy_s(librtCfg.privateKeyPaaswd, len, config->privateKeyPaaswd, len); + auto decryptErr = librtCfg.Decrypt(); + if (!decryptErr.OK()) { + return ErrorInfoToCError(decryptErr); + } + + librtCfg.functionIds[libruntime::LanguageType::Golang] = config->functionId; + librtCfg.selfLanguage = libruntime::LanguageType::Golang; + librtCfg.libruntimeOptions.loadFunctionCallback = LoadFunctionsWrapper; + librtCfg.libruntimeOptions.functionExecuteCallback = FunctionExecutionWrapper; + librtCfg.libruntimeOptions.checkpointCallback = CheckpointWrapper; + librtCfg.libruntimeOptions.recoverCallback = RecoverWrapper; + librtCfg.libruntimeOptions.shutdownCallback = ShutdownWrapper; + librtCfg.libruntimeOptions.signalCallback = SignalWrapper; + if (GoHasHealthCheck() != 0) { + librtCfg.libruntimeOptions.healthCheckCallback = HealthCheckWrapper; + } else { + librtCfg.libruntimeOptions.healthCheckCallback = nullptr; + } + librtCfg.jobId = config->jobId; + librtCfg.funcExecSubmitHook = FuncExecSubmitHook; + librtCfg.maxConcurrencyCreateNum = config->maxConcurrencyCreateNum; + librtCfg.enableSigaction = config->enableSigaction; + auto err = LibruntimeManager::Instance().Init(librtCfg); + return ErrorInfoToCError(err); +} + +void CReceiveRequestLoop(void) +{ + LibruntimeManager::Instance().ReceiveRequestLoop(); +} + +void CExecShutdownHandler(int sigNum) +{ + LibruntimeManager::Instance().ExecShutdownCallback(sigNum, false); +} + +static FunctionMeta BuildFunctionMeta(CFunctionMeta *cFuncMeta) +{ + FunctionMeta funcMeta; + funcMeta.apiType = static_cast(cFuncMeta->apiType); + funcMeta.funcName = cFuncMeta->funcName; + funcMeta.functionId = cFuncMeta->functionId; + funcMeta.languageType = static_cast(cFuncMeta->languageType); + funcMeta.signature = cFuncMeta->signature; + funcMeta.poolLabel = cFuncMeta->poolLabel; + if (cFuncMeta->hasName) { + funcMeta.name = cFuncMeta->name; + } + if (cFuncMeta->hasNs) { + funcMeta.ns = cFuncMeta->ns; + } + return funcMeta; +} + +static std::vector BuildInvokeArgs(CInvokeArg *cInvokeArgs, int size_invokeArgs) +{ + std::vector invokeArgs; + for (int i = 0; i < size_invokeArgs; i++) { + CInvokeArg &cArg = cInvokeArgs[i]; + InvokeArg arg; + if (cArg.isRef) { + // need to process + } else { + arg.isRef = false; + arg.dataObj = std::make_shared(0, cArg.size_buf); + // copy arg to DataObject->data + arg.dataObj->data->MemoryCopy(cArg.buf, cArg.size_buf); + } + for (int i = 0; i < cArg.size_nestedObjects; i++) { + arg.nestedObjects.emplace(cArg.nestedObjects[i]); + } + arg.tenantId = cArg.tenantId; + invokeArgs.emplace_back(std::move(arg)); + } + return invokeArgs; +} + +static std::list> BuildLabelOperators(CLabelOperator *cLabelOps, int cLabelOpsLen) +{ + std::list> labelOps; + for (int i = 0; i < cLabelOpsLen; i++) { + std::shared_ptr labelOp; + CLabelOperator &cLabelOp = cLabelOps[i]; + switch (cLabelOp.opType) { + case CLabelOpType::IN: + labelOp = std::make_shared(); + break; + case CLabelOpType::NOT_IN: + labelOp = std::make_shared(); + break; + case CLabelOpType::EXISTS: + labelOp = std::make_shared(); + break; + case CLabelOpType::NOT_EXISTS: + labelOp = std::make_shared(); + break; + default: + labelOp = std::make_shared(); + break; + } + labelOp->SetKey(cLabelOp.labelKey); + std::list labelValues; + for (int j = 0; j < cLabelOp.size_labelValues; j++) { + labelValues.push_back(cLabelOp.labelValues[j]); + } + labelOp->SetValues(labelValues); + + labelOps.push_back(labelOp); + } + return labelOps; +} + +static std::shared_ptr BuildScheduleAffinity(CAffinity &cAffinity) +{ + std::string kind; + switch (cAffinity.affKind) { + case CAffinityKind::RESOURCE: + kind = YR::Libruntime::RESOURCE; + break; + case CAffinityKind::INSTANCE: + kind = YR::Libruntime::INSTANCE; + break; + default: + kind = YR::Libruntime::RESOURCE; + break; + } + + std::string type; + switch (cAffinity.affType) { + case CAffinityType::PREFERRED: + type = YR::Libruntime::PREFERRED; + break; + case CAffinityType::PREFERRED_ANTI: + type = YR::Libruntime::PREFERRED_ANTI; + break; + case CAffinityType::REQUIRED: + type = YR::Libruntime::REQUIRED; + break; + case CAffinityType::REQUIRED_ANTI: + type = YR::Libruntime::REQUIRED_ANTI; + break; + default: + type = YR::Libruntime::PREFERRED; + break; + } + using AffnitiyCreator = std::function()>; + const std::unordered_map affinityMap{ + {"ResourcePreferredAffinity", [] { return std::make_shared(); }}, + {"ResourcePreferredAntiAffinity", + [] { return std::make_shared(); }}, + {"ResourceRequiredAffinity", [] { return std::make_shared(); }}, + {"ResourceRequiredAntiAffinity", + [] { return std::make_shared(); }}, + {"InstancePreferredAffinity", [] { return std::make_shared(); }}, + {"InstancePreferredAntiAffinity", + [] { return std::make_shared(); }}, + {"InstanceRequiredAffinity", [] { return std::make_shared(); }}, + {"InstanceRequiredAntiAffinity", + [] { return std::make_shared(); }}}; + std::string key = kind + type; + std::shared_ptr aff; + auto it = affinityMap.find(key); + if (it != affinityMap.end()) { + aff = it->second(); + } else { + aff = std::make_shared(kind, type); + } + aff->SetLabelOperators(BuildLabelOperators(cAffinity.labelOps, cAffinity.size_labelOps)); + aff->SetPreferredPriority(cAffinity.preferredPrio == 0 ? false : true); + aff->SetPreferredAntiOtherLabels(cAffinity.preferredAntiOtherLabels == 0 ? false : true); + return aff; +} + +static InvokeOptions BuildInvokeOptions(CInvokeOptions *cInvokeOpts) +{ + InvokeOptions invokeOpts; + invokeOpts.cpu = cInvokeOpts->cpu; + invokeOpts.memory = cInvokeOpts->memory; + invokeOpts.retryTimes = cInvokeOpts->RetryTimes; + invokeOpts.recoverRetryTimes = cInvokeOpts->RecoverRetryTimes; + invokeOpts.timeout = cInvokeOpts->timeout; + invokeOpts.acquireTimeout = cInvokeOpts->acquireTimeout; + invokeOpts.scheduleTimeoutMs = cInvokeOpts->scheduleTimeoutMs; + for (int i = 0; i < cInvokeOpts->size_customResources; i++) { + invokeOpts.customResources.emplace(cInvokeOpts->customResources[i].name, + cInvokeOpts->customResources[i].scalar); + } + for (int i = 0; i < cInvokeOpts->size_customExtensions; i++) { + invokeOpts.customExtensions.emplace(cInvokeOpts->customExtensions[i].key, + cInvokeOpts->customExtensions[i].value); + } + for (int i = 0; i < cInvokeOpts->size_createOpt; i++) { + invokeOpts.createOptions.emplace(cInvokeOpts->createOpt[i].key, cInvokeOpts->createOpt[i].value); + } + for (int i = 0; i < cInvokeOpts->size_schedulerInstanceIds; i++) { + invokeOpts.schedulerInstanceIds.emplace_back(cInvokeOpts->schedulerInstanceIds[i]); + } + for (int i = 0; i < cInvokeOpts->size_labels; i++) { + invokeOpts.labels.emplace_back(cInvokeOpts->labels[i]); + } + for (int i = 0; i < cInvokeOpts->size_schedAffinities; i++) { + invokeOpts.scheduleAffinities.push_back(BuildScheduleAffinity(cInvokeOpts->schedAffinities[i])); + } + for (int i = 0; i < cInvokeOpts->size_codePaths; i++) { + invokeOpts.codePaths.emplace_back(cInvokeOpts->codePaths[i]); + } + if (cInvokeOpts->schedulerFunctionId != nullptr) { + invokeOpts.schedulerFunctionId = cInvokeOpts->schedulerFunctionId; + } + if (cInvokeOpts->traceId != nullptr) { + invokeOpts.traceId = cInvokeOpts->traceId; + } + if (cInvokeOpts->trafficLimited != 0) { + invokeOpts.trafficLimited = true; + } + for (int i = 0; i < cInvokeOpts->size_invokeLabels; i++) { + invokeOpts.invokeLabels.emplace(cInvokeOpts->invokeLabels[i].key, cInvokeOpts->invokeLabels[i].value); + } + if (cInvokeOpts->instanceSession != nullptr) { + invokeOpts.instanceSession = std::make_shared(); + invokeOpts.instanceSession->sessionID = cInvokeOpts->instanceSession->sessionId; + invokeOpts.instanceSession->sessionTTL = cInvokeOpts->instanceSession->sessionTtl; + invokeOpts.instanceSession->concurrency = cInvokeOpts->instanceSession->concurrency; + } + return invokeOpts; +} + +static ProducerConf BuildProducerConfig(CProducerConfig *config) +{ + ProducerConf producerConfig; + producerConfig.delayFlushTime = config->delayFlushTime; + producerConfig.pageSize = config->pageSize; + producerConfig.maxStreamSize = config->maxStreamSize; + std::string traceId(config->traceId); + producerConfig.traceId = traceId; + return producerConfig; +} + +static SubscriptionConfig BuildSubscriptionConfig(CSubscriptionConfig *config) +{ + SubscriptionConfig subscriptionConfig; + std::string subName(config->subscriptionName); + subscriptionConfig.subscriptionName = subName; + subscriptionConfig.subscriptionType = SubscriptionType::STREAM; + std::string traceId(config->traceId); + subscriptionConfig.traceId = traceId; + return subscriptionConfig; +} + +static Element BuildElement(uint8_t *ptr, uint64_t size, uint64_t id) +{ + Element ele; + ele.ptr = ptr; + ele.size = size; + ele.id = id; + return ele; +} + +void FreeBuffers(CBuffer *cData, size_t size) +{ + for (size_t j = 0; j < size; j++) { + cData[j].size_buffer = 0; + if (cData[j].buffer != nullptr) { + free(cData[j].buffer); + cData[j].buffer = nullptr; + } + } +} + +CErrorInfo CCreateInstance(CFunctionMeta *cFuncMeta, CInvokeArg *cInvokeArgs, int size_invokeArgs, + CInvokeOptions *cInvokeOpts, char **instanceId) +{ + auto funcMeta = BuildFunctionMeta(cFuncMeta); + auto invokeArgs = BuildInvokeArgs(cInvokeArgs, size_invokeArgs); + auto invokeOpts = BuildInvokeOptions(cInvokeOpts); + + auto [lrt, errorInfo] = getLibRuntime(); + if (!errorInfo.OK()) { + return ErrorInfoToCError(errorInfo); + } + auto [err, instId] = lrt->CreateInstance(funcMeta, invokeArgs, invokeOpts); + if (err.OK()) { + *instanceId = CString(instId); + } + return ErrorInfoToCError(err); +} +CErrorInfo CInvokeByInstanceId(CFunctionMeta *cFuncMeta, char *cInstanceId, CInvokeArg *cInvokeArgs, + int size_invokeArgs, CInvokeOptions *cInvokeOpts, char **cReturnObjectId) +{ + auto funcMeta = BuildFunctionMeta(cFuncMeta); + std::string instanceId(cInstanceId); + auto invokeArgs = BuildInvokeArgs(cInvokeArgs, size_invokeArgs); + auto invokeOpts = BuildInvokeOptions(cInvokeOpts); + std::vector returnObjs{{""}}; + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->InvokeByInstanceId(funcMeta, instanceId, invokeArgs, invokeOpts, returnObjs); + if (err.OK()) { + *cReturnObjectId = CString(returnObjs[0].id); + } + return ErrorInfoToCError(err); +} + +CErrorInfo CInvokeByFunctionName(CFunctionMeta *cFuncMeta, CInvokeArg *cInvokeArgs, int size_invokeArgs, + CInvokeOptions *cInvokeOpts, char **cReturnObjectId) +{ + auto funcMeta = BuildFunctionMeta(cFuncMeta); + auto invokeArgs = BuildInvokeArgs(cInvokeArgs, size_invokeArgs); + auto invokeOpts = BuildInvokeOptions(cInvokeOpts); + std::vector returnObjs{{""}}; + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->InvokeByFunctionName(funcMeta, invokeArgs, invokeOpts, returnObjs); + if (err.OK()) { + *cReturnObjectId = CString(returnObjs[0].id); + } + return ErrorInfoToCError(err); +} + +CErrorInfo CAcquireInstance(char *stateId, CFunctionMeta *cFuncMeta, CInvokeOptions *cInvokeOpts, + CInstanceAllocation *cInsAlloc) +{ + auto funcMeta = BuildFunctionMeta(cFuncMeta); + auto invokeOpts = BuildInvokeOptions(cInvokeOpts); + std::string state(stateId); + auto [lrt, err1] = getLibRuntime(); + if (!err1.OK()) { + return ErrorInfoToCError(err1); + } + auto [instanceAllocation, err2] = lrt->AcquireInstance(state, funcMeta, invokeOpts); + InsAllocationToCInsAllocation(instanceAllocation, cInsAlloc); + return ErrorInfoToCError(err2); +} + +CErrorInfo CReleaseInstance(CInstanceAllocation *insAlloc, char *cStateID, char cAbnormal, CInvokeOptions *cInvokeOpts) +{ + std::string leaseID(insAlloc->leaseId); + std::string stateID(cStateID); + auto invokeOpts = BuildInvokeOptions(cInvokeOpts); + bool abnormal = cAbnormal == 0 ? false : true; + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->ReleaseInstance(leaseID, stateID, abnormal, invokeOpts); + return ErrorInfoToCError(err); +} + +void RawCallbackWrapper(const std::string context, const ErrorInfo &err, std::shared_ptr resultRaw) +{ + CBuffer cResult = {0}; + cResult.buffer = resultRaw->MutableData(); + cResult.size_buffer = resultRaw->GetSize(); + GoRawCallback(const_cast(context.c_str()), ErrorInfoToCError(err), cResult); +} + +void CCreateInstanceRaw(CBuffer cReqRaw, char *cContext) +{ + auto reqRaw = std::make_shared(cReqRaw.buffer, cReqRaw.size_buffer); + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return; // 以后把报错抛出去 + } + lrt->CreateInstanceRaw(reqRaw, std::bind(RawCallbackWrapper, std::string(cContext), _1, _2)); +} + +void CInvokeByInstanceIdRaw(CBuffer cReqRaw, char *cContext) +{ + auto reqRaw = std::make_shared(cReqRaw.buffer, cReqRaw.size_buffer); + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return; // 以后把报错抛出去 + } + lrt->InvokeByInstanceIdRaw(reqRaw, std::bind(RawCallbackWrapper, std::string(cContext), _1, _2)); +} + +void CKillRaw(CBuffer cReqRaw, char *cContext) +{ + auto reqRaw = std::make_shared(cReqRaw.buffer, cReqRaw.size_buffer); + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return; // 以后把报错抛出去 + } + lrt->KillRaw(reqRaw, std::bind(RawCallbackWrapper, std::string(cContext), _1, _2)); +} + +ErrorInfo ToCBuffer(std::shared_ptr buf, CBuffer *data) +{ + data->size_buffer = buf->GetSize(); + if (data->size_buffer == 0) { + data->buffer = nullptr; + return ErrorInfo(); + } + data->buffer = malloc(data->size_buffer); + if (data->buffer != nullptr) { + int err = memcpy_s(data->buffer, data->size_buffer, buf->ImmutableData(), buf->GetSize()); + if (err != EOK) { + free(data->buffer); + data->buffer = nullptr; + YRLOG_ERROR("CGet memcpy_s failed: {}", err); + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "CGet memcpy_s failed"); + } + } else { + YRLOG_ERROR("CGet malloc failed"); + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "CGet malloc failed"); + } + return ErrorInfo(); +} + +CErrorInfo CGet(char *objId, int timeoutSec, CBuffer *data) +{ + auto [lrt, err1] = getLibRuntime(); + if (!err1.OK()) { + return ErrorInfoToCError(err1); + } + auto [err2, res] = lrt->Get({objId}, timeoutSec, false); + if (err2.OK()) { + return ErrorInfoToCError(ToCBuffer(res[0]->data, data)); + } + return ErrorInfoToCError(err2); +} + +void CUpdateSchdulerInfo(char *scheduleName, char *schedulerId, char *option) +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return; // 以后把报错抛出去 + } + lrt->UpdateSchdulerInfo(scheduleName, schedulerId, option); +} + +void CGetAsync(char *objectId, void *userData) +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return; // 以后把报错抛出去 + } + lrt->GetAsync( + objectId, + [](std::shared_ptr data, const ErrorInfo &err, void *userData) { + auto cErr = ErrorInfoToCError(err); + CBuffer cBuf = {0}; + if (err.OK()) { + cErr = ErrorInfoToCError(ToCBuffer(data->data, &cBuf)); + } + auto cObjectId = const_cast(data->id.c_str()); + GoGetAsyncCallback(cObjectId, cBuf, &cErr, userData); + }, + userData); +} + +void CWaitAsync(char *objectId, void *userData) +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return; // 以后把报错抛出去 + } + lrt->WaitAsync( + objectId, + [](const std::string &objId, const ErrorInfo &err, void *userData) { + auto cErr = ErrorInfoToCError(err); + auto cObjectId = const_cast(objId.c_str()); + GoWaitAsyncCallback(cObjectId, &cErr, userData); + }, + userData); +} + +CErrorInfo CKill(char *instanceId, int sigNo, CBuffer cData) +{ + std::shared_ptr data; + if (cData.buffer != nullptr) { + data = std::make_shared(cData.buffer, cData.size_buffer); + } + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + if (data) { + err = lrt->Kill(instanceId, sigNo, data); + } else { + err = lrt->Kill(instanceId, sigNo); + } + return ErrorInfoToCError(err); +} + +char *CGetRealInstanceId(char *objectId, int timeout) +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return CString(""); + } + auto instId = lrt->GetRealInstanceId(objectId, timeout); + return CString(instId); +} + +void CExit(int code, char *message) +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return; // 以后把报错抛出去 + } + lrt->Exit(code, message); +} + +void CFinalize(void) +{ + LibruntimeManager::Instance().Finalize(); +} + +CErrorInfo CSetTenantId(const char *cTenantId, int cTenantIdLen) +{ + std::string tenantId = ""; + CheckNullAndAssignValue(cTenantId, cTenantIdLen, tenantId); + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->SetTenantId(tenantId, true); + return ErrorInfoToCError(err); +} + +CErrorInfo CPutCommon(char *cObjectId, CBuffer cData, char **cNestedIds, int size_cNestedIds, char isPutRaw, + CCreateParam param) +{ + std::unordered_set nestedIds; + for (int i = 0; i < size_cNestedIds; i++) { + nestedIds.emplace(cNestedIds[i]); + } + CreateParam createParam = CCreateParamToCreateParam(param); + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + if (isPutRaw) { + auto data = std::make_shared(cData.size_buffer); + data->MemoryCopy(cData.buffer, cData.size_buffer); + err = lrt->PutRaw(cObjectId, data, nestedIds, createParam); + return ErrorInfoToCError(err); + } + auto dataObj = std::make_shared(0, cData.size_buffer); + dataObj->data->MemoryCopy(cData.buffer, cData.size_buffer); + err = lrt->Put(cObjectId, dataObj, nestedIds, createParam); + return ErrorInfoToCError(err); +} + +CErrorInfo CGetMultiCommon(char **cObjIds, int size_cObjIds, int timeoutMs, char allowPartial, CBuffer *cData, + char isRaw) +{ + std::vector objectIds(size_cObjIds); + for (size_t i = 0; i < objectIds.size(); i++) { + objectIds[i] = cObjIds[i]; + } + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + + bool allowPart = allowPartial == 0 ? false : true; + std::pair>> getRet; + if (isRaw) { + getRet = lrt->GetRaw(objectIds, timeoutMs, allowPart); + } else { + std::pair>> getDataObjRet; + getDataObjRet = lrt->Get(objectIds, timeoutMs, allowPart); + getRet.second.resize(getDataObjRet.second.size()); + for (size_t i = 0; i < getDataObjRet.second.size(); i++) { + if (getDataObjRet.second[i]) { + getRet.second[i] = getDataObjRet.second[i]->data; + } + } + } + auto [err1, data] = getRet; + + for (size_t i = 0; i < data.size(); i++) { + if (data[i] == nullptr) { + continue; + } + auto errInfo = ToCBuffer(data[i], &cData[i]); + if (!errInfo.OK()) { + FreeBuffers(cData, i); + return ErrorInfoToCError(errInfo); + } + } + return ErrorInfoToCError(err1); +} + +void CWait(char **objIds, int size_objIds, int waitNum, int timeoutSec, CWaitResult *result) +{ + std::vector objectIds(size_objIds); + for (size_t i = 0; i < objectIds.size(); i++) { + objectIds[i] = objIds[i]; + } + + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return; // 以后把报错抛出去 + } + auto waitResults = lrt->Wait(objectIds, waitNum, timeoutSec); + if (!waitResults->readyIds.empty()) { + StringsToCStrings(waitResults->readyIds, &(result->readyIds), &(result->size_readyIds)); + } + + if (!waitResults->unreadyIds.empty()) { + StringsToCStrings(waitResults->unreadyIds, &(result->unreadyIds), &(result->size_unreadyIds)); + } + if (!waitResults->exceptionIds.empty()) { + result->errorIds = new (std::nothrow) CErrorObject *[waitResults->exceptionIds.size()]; + int index = 0; + for (auto iter = waitResults->exceptionIds.begin(); iter != waitResults->exceptionIds.end(); iter++) { + result->errorIds[index] = new (std::nothrow) CErrorObject(); + + result->errorIds[index]->objectId = new (std::nothrow) char[iter->first.size()]; + memcpy_s(result->errorIds[index]->objectId, iter->first.size(), iter->first.c_str(), iter->first.size()); + + result->errorIds[index]->errorInfo = new (std::nothrow) CErrorInfo(); + auto cError = ErrorInfoToCError(iter->second); + result->errorIds[index]->errorInfo->code = cError.code; + result->errorIds[index]->errorInfo->message = cError.message; + if (cError.size_stackTracesInfo == 0) { + return; + } + result->errorIds[index]->errorInfo->stackTracesInfo = cError.stackTracesInfo; + result->errorIds[index]->errorInfo->size_stackTracesInfo = cError.size_stackTracesInfo; + result->size_errorIds++; + index++; + } + } +} + +CErrorInfo CIncreaseReferenceCommon(char **cObjIds, int size_cObjIds, char *cRemoteId, char ***cFailedIds, + int *size_cFailedIds, char isRaw) +{ + std::vector objectIds(size_cObjIds); + for (size_t i = 0; i < objectIds.size(); i++) { + objectIds[i] = cObjIds[i]; + } + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + if (cRemoteId == nullptr) { + if (isRaw) { + err = lrt->IncreaseReferenceRaw(objectIds); + return ErrorInfoToCError(err); + } else { + err = lrt->IncreaseReference(objectIds); + return ErrorInfoToCError(err); + } + } + + std::pair> increRet; + if (isRaw) { + increRet = lrt->IncreaseReferenceRaw(objectIds, cRemoteId); + } else { + increRet = lrt->IncreaseReference(objectIds, std::string(cRemoteId)); + } + auto [err1, failedIds] = increRet; + StringsToCStrings(failedIds, cFailedIds, size_cFailedIds); + return ErrorInfoToCError(err1); +} + +CErrorInfo CDecreaseReferenceCommon(char **cObjIds, int size_cObjIds, char *cRemoteId, char ***cFailedIds, + int *size_cFailedIds, char isRaw) +{ + std::vector objectIds(size_cObjIds); + for (size_t i = 0; i < objectIds.size(); i++) { + objectIds[i] = cObjIds[i]; + } + + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + std::pair> decreRet; + if (cRemoteId) { + if (isRaw) { + decreRet = lrt->DecreaseReferenceRaw(objectIds, cRemoteId); + } else { + decreRet = lrt->DecreaseReference(objectIds, cRemoteId); + } + } else { + if (isRaw) { + lrt->DecreaseReferenceRaw(objectIds); + } else { + lrt->DecreaseReference(objectIds); + } + } + auto [err1, failedIds] = decreRet; + StringsToCStrings(failedIds, cFailedIds, size_cFailedIds); + return ErrorInfoToCError(err1); +} + +CErrorInfo CReleaseGRefs(char *cRemoteId) +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->ReleaseGRefs(cRemoteId); + return ErrorInfoToCError(err); +} + +CErrorInfo CKVWrite(char *key, CBuffer data, CSetParam param) +{ + auto mData = std::make_shared(data.size_buffer); + mData->MemoryCopy(data.buffer, data.size_buffer); + SetParam setParam = CSetParamToSetParam(param); + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->KVWrite(key, mData, setParam); + return ErrorInfoToCError(err); +} + +CErrorInfo CKVMSetTx(char **cKeys, int sizeCKeys, CBuffer *data, CMSetParam param) +{ + std::vector keys(sizeCKeys); + for (size_t i = 0; i < keys.size(); i++) { + keys[i] = cKeys[i]; + } + std::vector> vals(sizeCKeys); + for (size_t i = 0; i < keys.size(); i++) { + auto mData = std::make_shared(data[i].size_buffer); + mData->MemoryCopy(data[i].buffer, data[i].size_buffer); + vals[i] = mData; + } + MSetParam mSetParam = CMSetParamToMSetParam(param); + + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->KVMSetTx(keys, vals, mSetParam); + return ErrorInfoToCError(err); +} + +CErrorInfo CKVRead(char *key, int timeoutMs, CBuffer *cData) +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + auto [value, err1] = lrt->KVRead(key, timeoutMs); + if (err1.OK()) { + return ErrorInfoToCError(ToCBuffer(value, cData)); + } + return ErrorInfoToCError(err1); +} + +CErrorInfo CKVMultiRead(char **cKeys, int size_cKeys, int timeoutMs, char allowPartial, CBuffer *cData) +{ + std::vector keys(size_cKeys); + for (size_t i = 0; i < keys.size(); i++) { + keys[i] = cKeys[i]; + } + bool allowPart = allowPartial == 0 ? false : true; + auto [lrt, err0] = getLibRuntime(); + if (!err0.OK()) { + return ErrorInfoToCError(err0); + } + auto [values, err] = lrt->KVRead(keys, timeoutMs, allowPart); + for (int i = 0; i < size_cKeys; i++) { + if (values[i] == nullptr) { + continue; + } + auto errInfo = ToCBuffer(values[i], &cData[i]); + if (!errInfo.OK()) { + FreeBuffers(cData, i); + return ErrorInfoToCError(errInfo); + } + } + return ErrorInfoToCError(err); +} + +CErrorInfo CKVDel(char *key) +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->KVDel(key); + return ErrorInfoToCError(err); +} + +CErrorInfo CKMultiVDel(char **cKeys, int size_cKeys, char ***cFailedKeys, int *size_cFailedKeys) +{ + std::vector keys(size_cKeys); + for (size_t i = 0; i < keys.size(); i++) { + keys[i] = cKeys[i]; + } + auto [lrt, err0] = getLibRuntime(); + if (!err0.OK()) { + return ErrorInfoToCError(err0); + } + auto [failedKeys, err] = lrt->KVDel(keys); + StringsToCStrings(failedKeys, cFailedKeys, size_cFailedKeys); + return ErrorInfoToCError(err); +} + +CErrorInfo CCreateStreamProducer(char *cStreamName, CProducerConfig *config, Producer_p *producer) +{ + std::string streamName(cStreamName); + auto producerConfig = BuildProducerConfig(config); + std::shared_ptr outProducer; + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->CreateStreamProducer(streamName, producerConfig, outProducer); + if (err.OK()) { + std::unique_ptr> pOutProducer = + std::make_unique>(std::move(outProducer)); + *producer = reinterpret_cast(pOutProducer.release()); + } + return ErrorInfoToCError(err); +} + +CErrorInfo CCreateStreamConsumer(char *cStreamName, CSubscriptionConfig *config, Consumer_p *consumer) +{ + std::string streamName(cStreamName); + auto subscriptionConfig = BuildSubscriptionConfig(config); + std::shared_ptr outConsumer; + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->CreateStreamConsumer(streamName, subscriptionConfig, outConsumer); + if (err.OK()) { + *consumer = reinterpret_cast(new shared_ptr(std::move(outConsumer))); + } + return ErrorInfoToCError(err); +} + +CErrorInfo CDeleteStream(char *cStreamName) +{ + std::string streamName(cStreamName); + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->DeleteStream(streamName); + return ErrorInfoToCError(err); +} + +CErrorInfo CQueryGlobalProducersNum(char *cStreamName, uint64_t *num) +{ + std::string streamName(cStreamName); + uint64_t gProducerNum = 0; + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->QueryGlobalProducersNum(streamName, gProducerNum); + *num = gProducerNum; + return ErrorInfoToCError(err); +} + +CErrorInfo CQueryGlobalConsumersNum(char *cStreamName, uint64_t *num) +{ + std::string streamName(cStreamName); + uint64_t gConsumerNum = 0; + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->QueryGlobalConsumersNum(streamName, gConsumerNum); + *num = gConsumerNum; + return ErrorInfoToCError(err); +} + +CErrorInfo CProducerSend(Producer_p producerPtr, uint8_t *ptr, uint64_t size, uint64_t id) +{ + auto element = BuildElement(ptr, size, id); + auto producer = *reinterpret_cast *>(producerPtr); + auto err = producer->Send(element); + return ErrorInfoToCError(err); +} + +CErrorInfo CProducerSendWithTimeout(Producer_p producerPtr, uint8_t *ptr, uint64_t size, uint64_t id, int64_t timeoutMs) +{ + auto element = BuildElement(ptr, size, id); + auto producer = *reinterpret_cast *>(producerPtr); + auto err = producer->Send(element, timeoutMs); + return ErrorInfoToCError(err); +} + +CErrorInfo CProducerFlush(Producer_p producerPtr) +{ + auto producer = *reinterpret_cast *>(producerPtr); + auto err = producer->Flush(); + return ErrorInfoToCError(err); +} + +CErrorInfo CProducerClose(Producer_p producerPtr) +{ + auto producer = reinterpret_cast *>(producerPtr); + RETURN_ERR_WHEN_CONSUMER_ISNULL(producer); + auto err = (*producer)->Close(); + delete producer; + return ErrorInfoToCError(err); +} + +CErrorInfo CConsumerReceive(Consumer_p consumerPtr, uint32_t timeoutMs, CElement **elements, uint64_t *count) +{ + std::vector eles; + auto consumer = reinterpret_cast *>(consumerPtr); + RETURN_ERR_WHEN_CONSUMER_ISNULL(consumer); + auto err = (*consumer)->Receive(timeoutMs, eles); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + *count = eles.size(); + std::unique_ptr pEles = std::make_unique(eles.size()); + for (size_t i = 0; i < eles.size(); i++) { + pEles[i].id = eles[i].id; + pEles[i].ptr = eles[i].ptr; + pEles[i].size = eles[i].size; + } + *elements = pEles.release(); + return ErrorInfoToCError(err); +} + +CErrorInfo CConsumerReceiveExpectNum(Consumer_p consumerPtr, uint32_t expectNum, uint32_t timeoutMs, + CElement **elements, uint64_t *count) +{ + std::vector eles; + auto consumer = reinterpret_cast *>(consumerPtr); + RETURN_ERR_WHEN_CONSUMER_ISNULL(consumer); + auto err = (*consumer)->Receive(expectNum, timeoutMs, eles); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + *count = eles.size(); + std::unique_ptr pEles = std::make_unique(eles.size()); + for (size_t i = 0; i < eles.size(); i++) { + pEles[i].id = eles[i].id; + pEles[i].ptr = eles[i].ptr; + pEles[i].size = eles[i].size; + } + *elements = pEles.release(); + return ErrorInfoToCError(err); +} + +CErrorInfo CConsumerAck(Consumer_p consumerPtr, uint64_t elementId) +{ + auto consumer = reinterpret_cast *>(consumerPtr); + RETURN_ERR_WHEN_CONSUMER_ISNULL(consumer); + auto err = (*consumer)->Ack(elementId); + return ErrorInfoToCError(err); +} + +CErrorInfo CConsumerClose(Consumer_p consumerPtr) +{ + auto consumer = reinterpret_cast *>(consumerPtr); + RETURN_ERR_WHEN_CONSUMER_ISNULL(consumer); + auto err = (*consumer)->Close(); + delete consumer; + return ErrorInfoToCError(err); +} + +CErrorInfo CAllocReturnObject(CDataObject *cObject, int dataSize, char **cNestedIds, int size_cNestedIds, + uint64_t *totalNativeBufferSize) +{ + auto dataObject = static_cast(cObject->selfSharedPtr); + dataObject->alwaysNative = true; + std::vector nestedIds(size_cNestedIds); + for (size_t i = 0; i < nestedIds.size(); i++) { + nestedIds[i] = cNestedIds[i]; + } + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->AllocReturnObject(dataObject, 0, (size_t)dataSize, nestedIds, *totalNativeBufferSize); + if (err.OK()) { + cObject->buffer.buffer = dataObject->data->MutableData(); + cObject->buffer.size_buffer = dataObject->data->GetSize(); + cObject->buffer.selfSharedPtrBuffer = static_cast(dataObject->data.get()); + } + return ErrorInfoToCError(err); +} + +void CSetReturnObject(CDataObject *cObject, int dataSize) +{ + auto dataObject = static_cast(cObject->selfSharedPtr); + std::shared_ptr dataBuf = std::make_shared(dataSize); + dataObject->SetDataBuf(dataBuf); + cObject->buffer.buffer = dataObject->data->MutableData(); + cObject->buffer.size_buffer = dataObject->data->GetSize(); + cObject->buffer.selfSharedPtrBuffer = static_cast(dataObject->data.get()); +} + +CErrorInfo CWriterLatch(CBuffer *cBuffer) +{ + Buffer *buffer = static_cast(cBuffer->selfSharedPtrBuffer); + auto err = buffer->WriterLatch(); + return ErrorInfoToCError(err); +} + +CErrorInfo CMemoryCopy(CBuffer *cBuffer, void *cSrc, uint64_t size_cSrc) +{ + Buffer *buffer = static_cast(cBuffer->selfSharedPtrBuffer); + auto err = buffer->MemoryCopy(cSrc, size_cSrc); + return ErrorInfoToCError(err); +} + +CErrorInfo CSeal(CBuffer *cBuffer) +{ + Buffer *buffer = static_cast(cBuffer->selfSharedPtrBuffer); + std::unordered_set nestedIds; // support Seal with nestedIDs + auto err = buffer->Seal(nestedIds); + return ErrorInfoToCError(err); +} + +CErrorInfo CWriterUnlatch(CBuffer *cBuffer) +{ + Buffer *buffer = static_cast(cBuffer->selfSharedPtrBuffer); + auto err = buffer->WriterUnlatch(); + return ErrorInfoToCError(err); +} + +void CParseCErrorObjectPointer(CErrorObject *object, int *code, char **errMessage, char **objectId, + CStackTracesInfo *stackTracesInfo) +{ + if (object == nullptr) { + return; + } + if (object->errorInfo == nullptr) { + return; + } + *code = object->errorInfo->code; + *errMessage = object->errorInfo->message; + stackTracesInfo->size_stackTraces = object->errorInfo->stackTracesInfo[0].size_stackTraces; + stackTracesInfo->stackTraces = object->errorInfo->stackTracesInfo[0].stackTraces; + stackTracesInfo->message = object->errorInfo->stackTracesInfo->message; + *objectId = object->objectId; +} + +void CFunctionExecution(void *ptr) +{ + auto f = static_cast> *>(ptr); + if (f && *f && **f) { + (**f)(); + } +} + +CErrorInfo CCreateStateStore(CConnectArguments *arguments, CStateStorePtr *stateStorePtr) +{ + YR::Libruntime::DsConnectOptions opts; + size_t hostLen = 0; + if (arguments->host != nullptr) { + hostLen = strlen(arguments->host); + } + CheckNullAndAssignValue(arguments->host, hostLen, opts.host); + CheckNullAndAssignValue(arguments->token, arguments->tokenLen, opts.token); + CheckNullAndAssignValue(arguments->clientPublicKey, arguments->clientPublicKeyLen, opts.clientPublicKey); + CheckNullAndAssignValue(arguments->clientPrivateKey, arguments->clientPrivateKeyLen, opts.clientPrivateKey); + CheckNullAndAssignValue(arguments->serverPublicKey, arguments->serverPublicKeyLen, opts.serverPublicKey); + CheckNullAndAssignValue(arguments->accessKey, arguments->accessKeyLen, opts.accessKey); + CheckNullAndAssignValue(arguments->secretKey, arguments->secretKeyLen, opts.secretKey); + CheckNullAndAssignValue(arguments->authClientID, arguments->authClientIDLen, opts.oAuthClientId); + CheckNullAndAssignValue(arguments->authClientSecret, arguments->authClientSecretLen, opts.oAuthClientSecret); + CheckNullAndAssignValue(arguments->authUrl, arguments->authUrlLen, opts.oAuthUrl); + CheckNullAndAssignValue(arguments->tenantID, arguments->tenantIDLen, opts.tenantId); + opts.port = arguments->port; + opts.connectTimeoutMs = arguments->timeoutMs; + opts.enableCrossNodeConnection = arguments->enableCrossNodeConnection != 0; + std::shared_ptr sharedStore = nullptr; + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->CreateStateStore(opts, sharedStore); + if (err.OK()) { + std::unique_ptr> uniqueStorePtr( + new std::shared_ptr(std::move(sharedStore))); + *stateStorePtr = reinterpret_cast(uniqueStorePtr.release()); + } + return ErrorInfoToCError(err); +} + +CErrorInfo CSetTraceId(const char *cTraceId, int cTraceIdLen) +{ + std::string traceId = ""; + CheckNullAndAssignValue(cTraceId, cTraceIdLen, traceId); + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->SetTraceId(traceId); + return ErrorInfoToCError(err); +} + +inline std::shared_ptr *GetStateStoreClient(CStateStorePtr stateStorePtr) +{ + if (stateStorePtr == nullptr) { + return nullptr; + } + auto stateStore = reinterpret_cast *>(stateStorePtr); + if (stateStore == nullptr || *stateStore == nullptr) { + return nullptr; + } + return stateStore; +} + +CErrorInfo CGenerateKey(CStateStorePtr stateStorePtr, char **cKey, int *cKeyLen) +{ + auto stateStore = GetStateStoreClient(stateStorePtr); + if (stateStore == nullptr) { + return ErrorInfoToCError(ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, + YR::Libruntime::ModuleCode::RUNTIME, "the state store is empty")); + } + std::string key = ""; + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->GenerateKeyByStateStore(*stateStore, key); + *cKey = StringToCString(key); + *cKeyLen = key.size(); + return ErrorInfoToCError(err); +} + +void CDestroyStateStore(CStateStorePtr stateStorePtr) +{ + auto stateStore = GetStateStoreClient(stateStorePtr); + if (stateStore != nullptr) { + delete stateStore; + stateStore = nullptr; + } +} + +CErrorInfo CSetByStateStore(CStateStorePtr stateStorePtr, char *key, CBuffer data, CSetParam param) +{ + auto stateStore = GetStateStoreClient(stateStorePtr); + if (stateStore == nullptr) { + return ErrorInfoToCError(ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, + YR::Libruntime::ModuleCode::RUNTIME, "the state store is empty")); + } + SetParam setParam = CSetParamToSetParam(param); + auto nativeBuffer = std::make_shared(data.buffer, data.size_buffer); + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->SetByStateStore(*stateStore, key, nativeBuffer, setParam); + return ErrorInfoToCError(err); +} + +CErrorInfo CSetValueByStateStore(CStateStorePtr stateStorePtr, CBuffer data, CSetParam param, char **cKey, int *cKeyLen) +{ + auto stateStore = GetStateStoreClient(stateStorePtr); + if (stateStore == nullptr) { + return ErrorInfoToCError(ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, + YR::Libruntime::ModuleCode::RUNTIME, "the state store is empty")); + } + SetParam setParam = CSetParamToSetParam(param); + std::string key = ""; + auto nativeBuffer = std::make_shared(data.buffer, data.size_buffer); + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->SetValueByStateStore(*stateStore, nativeBuffer, setParam, key); + *cKey = StringToCString(key); + *cKeyLen = key.size(); + return ErrorInfoToCError(err); +} + +CErrorInfo CGetByStateStore(CStateStorePtr stateStorePtr, char *key, CBuffer *data, int timeoutMs) +{ + auto stateStore = GetStateStoreClient(stateStorePtr); + if (stateStore == nullptr) { + return ErrorInfoToCError(ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, + YR::Libruntime::ModuleCode::RUNTIME, "the state store is empty")); + } + auto [lrt, err0] = getLibRuntime(); + if (!err0.OK()) { + return ErrorInfoToCError(err0); + } + auto [value, err] = lrt->GetByStateStore(*stateStore, key, timeoutMs); + if (err.OK()) { + return ErrorInfoToCError(ToCBuffer(value, data)); + } + return ErrorInfoToCError(err); +} + +CErrorInfo CGetArrayByStateStore(CStateStorePtr stateStorePtr, char **cKeys, int cKeysLen, CBuffer *data, int timeoutMs) +{ + auto stateStore = GetStateStoreClient(stateStorePtr); + if (stateStore == nullptr) { + return ErrorInfoToCError(ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, + YR::Libruntime::ModuleCode::RUNTIME, "the state store is empty")); + } + auto keys = CStringsToStrings(cKeys, cKeysLen); + auto [lrt, err0] = getLibRuntime(); + if (!err0.OK()) { + return ErrorInfoToCError(err0); + } + auto [values, err] = lrt->GetArrayByStateStore(*stateStore, keys, timeoutMs); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + for (int i = 0; i < cKeysLen; i++) { + if (values[i] == nullptr) { + continue; + } + auto errInfo = ToCBuffer(values[i], &data[i]); + if (!errInfo.OK()) { + FreeBuffers(data, i); + return ErrorInfoToCError(errInfo); + } + } + return ErrorInfoToCError(err); +} + +CErrorInfo CQuerySizeByStateStore(CStateStorePtr stateStorePtr, char **cKeys, int sizeCKeys, uint64_t *cSize) +{ + auto stateStore = GetStateStoreClient(stateStorePtr); + if (stateStore == nullptr) { + return ErrorInfoToCError(ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, + YR::Libruntime::ModuleCode::RUNTIME, "the state store is empty")); + } + std::vector keys(sizeCKeys); + for (size_t i = 0; i < keys.size(); i++) { + keys[i] = cKeys[i]; + } + std::vector results{}; + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->QuerySizeByStateStore(*stateStore, keys, results); + for (size_t i = 0; i < results.size(); ++i) { + cSize[i] = results[i]; + } + return ErrorInfoToCError(err); +} + +CErrorInfo CDelByStateStore(CStateStorePtr stateStorePtr, char *key) +{ + auto stateStore = GetStateStoreClient(stateStorePtr); + if (stateStore == nullptr) { + return ErrorInfoToCError(ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, + YR::Libruntime::ModuleCode::RUNTIME, "the state store is empty")); + } + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return ErrorInfoToCError(err); + } + err = lrt->DelByStateStore(*stateStore, key); + return ErrorInfoToCError(err); +} + +CErrorInfo CDelArrayByStateStore(CStateStorePtr stateStorePtr, char **cKeys, int cKeysLen, char ***cFailedKeys, + int *cFailedKeysLen) +{ + auto stateStore = GetStateStoreClient(stateStorePtr); + if (stateStore == nullptr) { + return ErrorInfoToCError(ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, + YR::Libruntime::ModuleCode::RUNTIME, "the state store is empty")); + } + auto keys = CStringsToStrings(cKeys, cKeysLen); + auto [lrt, err0] = getLibRuntime(); + if (!err0.OK()) { + return ErrorInfoToCError(err0); + } + auto [failedKeys, err] = lrt->DelArrayByStateStore(*stateStore, keys); + StringsToCStrings(failedKeys, cFailedKeys, cFailedKeysLen); + return ErrorInfoToCError(err); +} + +CCredential CGetCredential() +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + auto credential = YR::Libruntime::Credential{}; + return CredentialToCCre(credential); + } + auto credential = lrt->GetCredential(); + return CredentialToCCre(credential); +} + +int CIsHealth() +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return 0; + } + if (lrt->IsHealth()) { + return 1; + } + return 0; +} + +int CIsDsHealth() +{ + auto [lrt, err] = getLibRuntime(); + if (!err.OK()) { + return 0; + } + if (lrt->IsDsHealth()) { + return 1; + } + return 0; +} + +#ifdef __cplusplus +} +#endif diff --git a/api/go/libruntime/cpplibruntime/mock/mock_cpplibruntime.cpp b/api/go/libruntime/cpplibruntime/mock/mock_cpplibruntime.cpp new file mode 100644 index 0000000..a478ae0 --- /dev/null +++ b/api/go/libruntime/cpplibruntime/mock/mock_cpplibruntime.cpp @@ -0,0 +1,446 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../clibruntime.h" +#include "src/libruntime/libruntime_manager.h" +#include "src/libruntime/utils/utils.h" +using namespace std; +using namespace std::placeholders; +using YR::Libruntime::ErrorCode; +using YR::Libruntime::ErrorInfo; +using YR::Libruntime::StreamConsumer; +using YR::Libruntime::StreamProducer; +#ifdef __cplusplus +extern "C" { +#endif + +char *CString(const std::string &str) +{ + char *cStr = (char *)malloc(str.size() + 1); + memcpy_s(cStr, str.size(), str.data(), str.size()); + cStr[str.size()] = 0; + return cStr; +} + +CErrorInfo *GoLoadFunctions(char **codePaths, int size_codePaths) +{ + return nullptr; +} +CErrorInfo *GoFunctionExecution(CFunctionMeta *, CInvokeType, CArg *, int, CDataObject *, int) +{ + return nullptr; +} +CErrorInfo *GoCheckpoint(char *checkpointId, CBuffer *buffer) +{ + return nullptr; +} +CErrorInfo *GoRecover(CBuffer *buffer) +{ + return nullptr; +} +CErrorInfo *GoShutdown(uint64_t gracePeriodSeconds) +{ + return nullptr; +} +CErrorInfo *GoSignal(int sigNo, CBuffer *payload) +{ + return nullptr; +} +CHealthCheckCode GoHealthCheck(void) +{ + return CHealthCheckCode::HEALTHY; +} +char GoHasHealthCheck(void) +{ + return 0; +} + +void GoFunctionExecutionPoolSubmit(void *ptr) {} + +void GoRawCallback(char *cKey, CErrorInfo cErr, CBuffer cResultRaw) {} + +void GoGetAsyncCallback(char *cObjectID, CBuffer cBuf, CErrorInfo *cErr, void *userData) {} + +void GoWaitAsyncCallback(char *cObjectID, CErrorInfo *cErr, void *userData) {} + +CErrorInfo ErrorInfoToCError(const ErrorInfo &err) +{ + CErrorInfo cErr{}; + cErr.message = nullptr; + if (!err.Msg().empty()) { + cErr.message = CString(err.Msg()); + } + cErr.code = static_cast(err.Code()); + cErr.dsStatusCode = err.GetDsStatusCode(); + return cErr; +} + +CErrorInfo CInit(CLibruntimeConfig *config) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +void CReceiveRequestLoop(void) +{ + return; +} + +void CExecShutdownHandler(int sigNum) +{ + return; +} + +CErrorInfo CCreateInstance(CFunctionMeta *cFuncMeta, CInvokeArg *cInvokeArgs, int size_invokeArgs, + CInvokeOptions *cInvokeOpts, char **instanceId) +{ + return ErrorInfoToCError(ErrorInfo()); +} +CErrorInfo CInvokeByInstanceId(CFunctionMeta *cFuncMeta, char *cInstanceId, CInvokeArg *cInvokeArgs, + int size_invokeArgs, CInvokeOptions *cInvokeOpts, char **cReturnObjectId) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CAcquireInstance(char *stateId, CFunctionMeta *cFuncMeta, CInvokeOptions *cInvokeOpts, + CInstanceAllocation *cInsAlloc) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CReleaseInstance(CInstanceAllocation *insAlloc, char *cStateID, char cAbnormal, CInvokeOptions *cInvokeOpts) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CInvokeByFunctionName(CFunctionMeta *cFuncMeta, CInvokeArg *cInvokeArgs, int size_invokeArgs, + CInvokeOptions *cInvokeOpts, char **cReturnObjectId) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +void CCreateInstanceRaw(CBuffer cReqRaw, char *cContext) +{ + return; +} + +void CInvokeByInstanceIdRaw(CBuffer cReqRaw, char *cContext) +{ + return; +} + +void CKillRaw(CBuffer cReqRaw, char *cContext) +{ + return; +} + +CErrorInfo CGet(char *objId, int timeoutSec, CBuffer *data) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +void CUpdateSchdulerInfo(char *schedulerName, char *schedulerId, char *option) +{ + return; +} + +void CGetAsync(char *objectId, void *userData) +{ + return; +} + +void CWaitAsync(char *objectId, void *userData) +{ + return; +} + +CErrorInfo CKill(char *instanceId, int sigNo, CBuffer cData) +{ + if (sigNo == 128) { + return ErrorInfoToCError(ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, + YR::Libruntime::ModuleCode::RUNTIME, "failed to kill")); + } + return ErrorInfoToCError(ErrorInfo()); +} + +char *CGetRealInstanceId(char *objectId, int timeout) +{ + std::string str = "InstanceID"; + char *cStr = (char *)malloc(str.size() + 1); + memcpy_s(cStr, str.size(), str.data(), str.size()); + cStr[str.size()] = 0; + return cStr; +} + +void CExit(int code, char *message) +{ + return; +} + +void CFinalize(void) +{ + return; +} + +CErrorInfo CPutCommon(char *objectId, CBuffer data, char **nestedIds, int size_nestedIds, char isPutRaw, + CCreateParam param) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CGetMultiCommon(char **cObjIds, int size_cObjIds, int timeoutMs, char allowPartial, CBuffer *cData, + char isRaw) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +void CWait(char **objIds, int size_objIds, int waitNum, int timeoutSec, CWaitResult *result) +{ + return; +} + +CErrorInfo CIncreaseReferenceCommon(char **cObjIds, int size_cObjIds, char *cRemoteId, char ***cFailedIds, + int *size_cFailedIds, char isRaw) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CDecreaseReferenceCommon(char **cObjIds, int size_cObjIds, char *cRemoteId, char ***cFailedIds, + int *size_cFailedIds, char isRaw) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CReleaseGRefs(char *cRemoteId) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CKVWrite(char *key, CBuffer data, CSetParam param) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CKVMSetTx(char **key, int sizeKeys, CBuffer *data, CMSetParam param) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CKVRead(char *key, int timeoutMs, CBuffer *cData) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CKVMultiRead(char **cKeys, int size_cKeys, int timeoutMs, char allowPartial, CBuffer *cData) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CKVDel(char *key) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CKMultiVDel(char **cKeys, int size_cKeys, char ***cFailedKeys, int *size_cFailedKeys) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CCreateStreamProducer(char *cStreamName, CProducerConfig *config, Producer_p *producer) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CCreateStreamConsumer(char *cStreamName, CSubscriptionConfig *config, Consumer_p *consumer) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CDeleteStream(char *cStreamName) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CQueryGlobalProducersNum(char *cStreamName, uint64_t *num) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CQueryGlobalConsumersNum(char *cStreamName, uint64_t *num) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CProducerSend(Producer_p producerPtr, uint8_t *ptr, uint64_t size, uint64_t id) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CProducerSendWithTimeout(Producer_p producerPtr, uint8_t *ptr, uint64_t size, uint64_t id, int64_t timeoutMs) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CProducerFlush(Producer_p producerPtr) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CProducerClose(Producer_p producerPtr) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CConsumerReceive(Consumer_p consumerPtr, uint32_t timeoutMs, CElement **elements, uint64_t *count) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CConsumerReceiveExpectNum(Consumer_p consumerPtr, uint32_t expectNum, uint32_t timeoutMs, + CElement **elements, uint64_t *count) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CConsumerAck(Consumer_p consumerPtr, uint64_t elementId) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CConsumerClose(Consumer_p consumerPtr) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CAllocReturnObject(CDataObject *cObject, int dataSize, char **cNestedIds, int size_cNestedIds, + uint64_t *totalNativeBufferSize) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +void CSetReturnObject(CDataObject *cObject, int dataSize) +{ + return; +} + +CErrorInfo CWriterLatch(CBuffer *cBuffer) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CMemoryCopy(CBuffer *cBuffer, void *cSrc, uint64_t size_cSrc) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CSeal(CBuffer *cBuffer) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CWriterUnlatch(CBuffer *cBuffer) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +void CParseCErrorObjectPointer(CErrorObject *object, int *code, char **errMessage, char **objectId, + CStackTracesInfo *stackTracesInfo) +{ + return; +} + +void CFunctionExecution(void *ptr) {} + +CErrorInfo CCreateStateStore(CConnectArguments *arguments, CStateStorePtr *stateStorePtr) +{ + *stateStorePtr = static_cast(malloc(sizeof(CStateStorePtr))); + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CSetTraceId(const char *cTraceId, int cTraceIdLen) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CSetTenantId(const char *cTenantId, int cTenantIdLen) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CGenerateKey(CStateStorePtr stateStorePtr, char **cKey, int *cKeyLen) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +void CDestroyStateStore(CStateStorePtr stateStorePtr) +{ + free(stateStorePtr); + return; +} + +CErrorInfo CSetByStateStore(CStateStorePtr stateStorePtr, char *key, CBuffer data, CSetParam param) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CSetValueByStateStore(CStateStorePtr stateStorePtr, CBuffer data, CSetParam param, char **cKey, int *cKeyLen) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CGetByStateStore(CStateStorePtr stateStorePtr, char *key, CBuffer *data, int timeoutMs) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CGetArrayByStateStore(CStateStorePtr stateStorePtr, char **cKeys, int cKeysLen, CBuffer *data, int timeoutMs) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CQuerySizeByStateStore(CStateStorePtr stateStorePtr, char **cKeys, int sizeCKeys, uint64_t *cSize) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CDelByStateStore(CStateStorePtr stateStorePtr, char *key) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CErrorInfo CDelArrayByStateStore(CStateStorePtr stateStorePtr, char **cKeys, int cKeysLen, char ***cFailedKeys, + int *cFailedKeysLen) +{ + return ErrorInfoToCError(ErrorInfo()); +} + +CCredential CGetCredential() +{ + return CCredential(); +} + +int CIsHealth() +{ + return 1; +} + +int CIsDsHealth() +{ + return 1; +} + +#ifdef __cplusplus +} +#endif diff --git a/api/go/libruntime/execution/execution.go b/api/go/libruntime/execution/execution.go new file mode 100644 index 0000000..ded45a5 --- /dev/null +++ b/api/go/libruntime/execution/execution.go @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package execution +This package provides methods to obtain the execution interface. +*/ +package execution + +import ( + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/config" +) + +// FunctionExecutionIntfs This interface defines several callback functions, for example, FunctionExecute. +type FunctionExecutionIntfs interface { + GetExecType() ExecutionType + LoadFunction(codePaths []string) error + FunctionExecute( + funcMeta api.FunctionMeta, invokeType config.InvokeType, args []api.Arg, returnobjs []config.DataObject, + ) error + Checkpoint(checkpointID string) ([]byte, error) + Recover(state []byte) error + Shutdown(gracePeriod uint64) error + Signal(sig int, data []byte) error + HealthCheck() (api.HealthType, error) +} + +// ExecutionType execution type, for example, posix +type ExecutionType int32 + +const ( + // ExecutionTypeInvalid invalid type + ExecutionTypeInvalid ExecutionType = 0 + // ExecutionTypeActor actor type + ExecutionTypeActor ExecutionType = 1 + // ExecutionTypeFaaS faas type + ExecutionTypeFaaS ExecutionType = 2 + // ExecutionTypePosix posix type + ExecutionTypePosix ExecutionType = 3 +) diff --git a/api/go/libruntime/libruntime.go b/api/go/libruntime/libruntime.go new file mode 100644 index 0000000..33605e1 --- /dev/null +++ b/api/go/libruntime/libruntime.go @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package libruntime +This package provides a set of functions to interface with the clibruntime. +*/ +package libruntime + +import ( + "yuanrong.org/kernel/runtime/libruntime/clibruntime" + "yuanrong.org/kernel/runtime/libruntime/config" +) + +// Init Initialization entry, which is used to initialize the data system and function system. +func Init(conf config.Config) error { + return clibruntime.Init(conf) +} + +// ReceiveRequestLoop begins loop processing the received request. +func ReceiveRequestLoop() { + clibruntime.ReceiveRequestLoop() +} + +// ExecShutdownHandler exec shutdown handler. +func ExecShutdownHandler(signum int) { + clibruntime.ExecShutdownHandler(signum) +} + +// AllocReturnObject Creates an object and applies for a memory block. +// Computing operations can be performed on the memory block. +// will return a 'Buffer' that will be used to manipulate the memory +func AllocReturnObject(do *config.DataObject, size uint, nestedIds []string, totalNativeBufferSize *uint) error { + return clibruntime.AllocReturnObject(do, size, nestedIds, totalNativeBufferSize) +} + +// SetReturnObject if return by message, set return object +func SetReturnObject(do *config.DataObject, size uint) { + clibruntime.SetReturnObject(do, size) +} + +// WriterLatch Obtains the write lock of the buffer object. +func WriterLatch(do *config.DataObject) error { + return clibruntime.WriterLatch(do) +} + +// MemoryCopy Writes data to a buffer object. +func MemoryCopy(do *config.DataObject, src []byte) error { + return clibruntime.MemoryCopy(do, src) +} + +// Seal Publish the object and seal it. Sealed objects cannot be modified again. +func Seal(do *config.DataObject) error { + return clibruntime.Seal(do) +} + +// WriterUnlatch release the write lock of the buffer object. +func WriterUnlatch(do *config.DataObject) error { + return clibruntime.WriterUnlatch(do) +} + +// IsHealth - +func IsHealth() bool { + return clibruntime.IsHealth() +} diff --git a/api/go/libruntime/libruntime_test.go b/api/go/libruntime/libruntime_test.go new file mode 100644 index 0000000..5adc29c --- /dev/null +++ b/api/go/libruntime/libruntime_test.go @@ -0,0 +1,148 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package libruntime +This package provides a set of functions to interface with the clibruntime. +*/ +package libruntime + +import ( + "os" + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/config" +) + +func TestInitAndReceiveRequestLoop(t *testing.T) { + convey.Convey( + "Test Init and ReceiveRequestLoop", t, func() { + conf := config.Config{ + GrpcAddress: "127.0.0.1", + DataSystemAddress: os.Getenv("DATASYSTEM_ADDR"), + JobID: "jobId", + RuntimeID: "runtimeId", + LogDir: "logDir", + LogLevel: "DEBUG", + InCluster: true, + Api: api.PosixApi, + Hooks: config.HookIntfs{ + LoadFunctionCb: nil, + FunctionExecutionCb: nil, + CheckpointCb: nil, + RecoverCb: nil, + ShutdownCb: nil, + SignalCb: nil, + HealthCheckCb: nil, + }, + } + convey.Convey( + "Test Init Success", func() { + err := Init(conf) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Test ReceiveRequestLoop should not be panic", func() { + convey.So(ReceiveRequestLoop, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestAllocReturnObject(t *testing.T) { + convey.Convey( + "Test AllocReturnObject", t, func() { + convey.Convey( + "AllocReturnObject Success", func() { + var n uint = 8 + err := AllocReturnObject(&config.DataObject{}, 8, []string{""}, &n) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestSetReturnObject(t *testing.T) { + convey.Convey( + "Test SetReturnObject ", t, func() { + convey.Convey( + "SetReturnObject Success", func() { + convey.So(func() { + SetReturnObject(&config.DataObject{}, 8) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestWriterLatch(t *testing.T) { + convey.Convey( + "Test WriterLatch", t, func() { + convey.Convey( + "WriterLatch Success", func() { + err := WriterLatch(&config.DataObject{}) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestMemoryCopy(t *testing.T) { + convey.Convey( + "Test MemoryCopy", t, func() { + convey.Convey( + "MemoryCopy Success", func() { + err := MemoryCopy(&config.DataObject{}, []byte{0}) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestSeal(t *testing.T) { + convey.Convey( + "Test Seal", t, func() { + convey.Convey( + "Seal Success", func() { + err := Seal(&config.DataObject{}) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestWriterUnlatch(t *testing.T) { + convey.Convey( + "Test WriterUnlatch", t, func() { + convey.Convey( + "WriterUnlatch Success", func() { + err := WriterUnlatch(&config.DataObject{}) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/libruntimesdkimpl/libruntimesdkimpl.go b/api/go/libruntime/libruntimesdkimpl/libruntimesdkimpl.go new file mode 100644 index 0000000..41c9569 --- /dev/null +++ b/api/go/libruntime/libruntimesdkimpl/libruntimesdkimpl.go @@ -0,0 +1,252 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// this is Package libruntimesdkimpl implements +package libruntimesdkimpl + +import ( + "fmt" + "os" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/clibruntime" + "yuanrong.org/kernel/runtime/libruntime/common/logger/log" + "yuanrong.org/kernel/runtime/libruntime/common/utils" +) + +type libruntimeSDKImpl struct{} + +// NewLibruntimeSDKImpl creates and returns a new libruntimeSDK instance +func NewLibruntimeSDKImpl() api.LibruntimeAPI { + return &libruntimeSDKImpl{} +} + +func (l *libruntimeSDKImpl) CreateInstance( + funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + return clibruntime.CreateInstance(funcMeta, args, invokeOpt) +} + +func (l *libruntimeSDKImpl) InvokeByInstanceId( + funcMeta api.FunctionMeta, instanceID string, args []api.Arg, invokeOpt api.InvokeOptions, +) (string, error) { + return clibruntime.InvokeByInstanceId(funcMeta, instanceID, args, invokeOpt) +} + +func (l *libruntimeSDKImpl) InvokeByFunctionName( + funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions, +) (string, error) { + return clibruntime.InvokeByFunctionName(funcMeta, args, invokeOpt) +} + +func (l libruntimeSDKImpl) AcquireInstance(state string, funcMeta api.FunctionMeta, + acquireOpt api.InvokeOptions) (api.InstanceAllocation, error) { + return clibruntime.AcquireInstance(state, funcMeta, acquireOpt) +} + +func (l libruntimeSDKImpl) ReleaseInstance(allocation api.InstanceAllocation, stateID string, + abnormal bool, option api.InvokeOptions) { + clibruntime.ReleaseInstance(allocation, stateID, abnormal, option) +} + +func (l *libruntimeSDKImpl) Kill(instanceID string, signal int, payload []byte) error { + return clibruntime.Kill(instanceID, signal, payload) +} + +func (l *libruntimeSDKImpl) CreateInstanceRaw(createReqRaw []byte) ([]byte, error) { + return clibruntime.CreateInstanceRaw(createReqRaw) +} + +func (l *libruntimeSDKImpl) InvokeByInstanceIdRaw(invokeReqRaw []byte) ([]byte, error) { + return clibruntime.InvokeByInstanceIdRaw(invokeReqRaw) +} + +func (l *libruntimeSDKImpl) KillRaw(killReqRaw []byte) ([]byte, error) { + return clibruntime.KillRaw(killReqRaw) +} + +func (l *libruntimeSDKImpl) SaveState(state []byte) (string, error) { + return "", fmt.Errorf("not support") +} + +func (l *libruntimeSDKImpl) LoadState(checkpointID string) ([]byte, error) { + return []byte{}, fmt.Errorf("not support") +} + +func (l *libruntimeSDKImpl) Exit(code int, message string) { + clibruntime.Exit(code, message) +} + +func (l *libruntimeSDKImpl) Finalize() { + clibruntime.Finalize() +} + +func (l *libruntimeSDKImpl) KVSet(key string, value []byte, param api.SetParam) error { + return clibruntime.KVSet(key, value, param) +} + +func (l *libruntimeSDKImpl) KVSetWithoutKey(value []byte, param api.SetParam) (string, error) { + return "", fmt.Errorf("not support") +} + +func (l *libruntimeSDKImpl) KVMSetTx(keys []string, values [][]byte, param api.MSetParam) error { + return clibruntime.KVMSetTx(keys, values, param) +} + +func (l *libruntimeSDKImpl) KVGet(key string, timeoutms uint) ([]byte, error) { + return clibruntime.KVGet(key, timeoutms) +} + +func (l *libruntimeSDKImpl) KVGetMulti(keys []string, timeoutms uint) ([][]byte, error) { + return clibruntime.KVGetMulti(keys, timeoutms) +} + +func (l *libruntimeSDKImpl) KVDel(key string) error { + return clibruntime.KVDel(key) +} + +func (l *libruntimeSDKImpl) KVDelMulti(keys []string) ([]string, error) { + return clibruntime.KVDelMulti(keys) +} + +func (l *libruntimeSDKImpl) CreateProducer( + streamName string, producerConf api.ProducerConf) (api.StreamProducer, error) { + return clibruntime.CreateStreamProducer(streamName, producerConf) +} + +func (l *libruntimeSDKImpl) Subscribe( + streamName string, config api.SubscriptionConfig) (api.StreamConsumer, error) { + return clibruntime.CreateStreamConsumer(streamName, config) +} + +func (l *libruntimeSDKImpl) DeleteStream(streamName string) error { + return clibruntime.DeleteStream(streamName) +} + +func (l *libruntimeSDKImpl) QueryGlobalProducersNum(streamName string) (uint64, error) { + return clibruntime.QueryGlobalProducersNum(streamName) +} + +func (l *libruntimeSDKImpl) QueryGlobalConsumersNum(streamName string) (uint64, error) { + return clibruntime.QueryGlobalConsumersNum(streamName) +} + +func (l *libruntimeSDKImpl) SetTraceID(traceID string) { + return +} + +func (l *libruntimeSDKImpl) SetTenantID(tenantID string) error { + return clibruntime.SetTenantID(tenantID) +} + +func (l *libruntimeSDKImpl) Put( + objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string) error { + return clibruntime.Put(objectID, value, param, nestedObjectIDs...) +} + +func (l *libruntimeSDKImpl) PutRaw( + objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string, +) error { + return clibruntime.PutRaw(objectID, value, param, nestedObjectIDs...) +} + +func (l *libruntimeSDKImpl) Get(objectIDs []string, timeoutMs int) ([][]byte, error) { + return clibruntime.Get(objectIDs, timeoutMs) +} + +func (l *libruntimeSDKImpl) GetRaw(objectIDs []string, timeoutMs int) ([][]byte, error) { + return clibruntime.GetRaw(objectIDs, timeoutMs) +} + +func (l *libruntimeSDKImpl) Wait( + objectIDs []string, waitNum uint64, timeoutMs int, +) ([]string, []string, map[string]error) { + return clibruntime.Wait(objectIDs, waitNum, timeoutMs) +} + +func (l *libruntimeSDKImpl) GIncreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + return clibruntime.GIncreaseRef(objectIDs, remoteClientID...) +} + +func (l *libruntimeSDKImpl) GIncreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + return clibruntime.GIncreaseRefRaw(objectIDs, remoteClientID...) +} + +func (l *libruntimeSDKImpl) GDecreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + return clibruntime.GDecreaseRef(objectIDs, remoteClientID...) +} + +func (l *libruntimeSDKImpl) GDecreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + return clibruntime.GDecreaseRefRaw(objectIDs, remoteClientID...) +} + +// ReleaseGRefs release object refs by remote client id +func (l *libruntimeSDKImpl) ReleaseGRefs(remoteClientID string) error { + err := clibruntime.ReleaseGRefs(remoteClientID) + return err +} + +func (l *libruntimeSDKImpl) GetAsync(objectID string, cb api.GetAsyncCallback) { + clibruntime.GetAsync(objectID, cb) +} + +// UpdateSchdulerInfo - +func (l *libruntimeSDKImpl) UpdateSchdulerInfo(schedulerName string, schedulerId string, option string) { + clibruntime.UpdateSchdulerInfo(schedulerName, schedulerId, option) +} + +func getLogName() string { + logName := "runtime-go" + funcLibPath := os.Getenv("FUNCTION_LIB_PATH") + funcName := utils.GetFuncNameFromFuncLibPath(funcLibPath) + if len(funcName) == 0 { + return logName + } + logName = funcName + if logName == "go1.x" { + logName = "faas-executor" + } + return logName +} + +func (l *libruntimeSDKImpl) GetFormatLogger() api.FormatLogger { + logName := getLogName() + log.GetLogger().Infof("logName is: %s", logName) + logImpl, err := log.InitRunLog(logName, true) + if err != nil { + panic("InitRunLog failed") + } + return logImpl +} + +// CreateClient - +func (l *libruntimeSDKImpl) CreateClient(config api.ConnectArguments) (api.KvClient, error) { + return clibruntime.CreateClient(config) +} + +// GetCredential +func (l *libruntimeSDKImpl) GetCredential() api.Credential { + return clibruntime.GetCredential() +} + +// IsHealth - +func (l *libruntimeSDKImpl) IsHealth() bool { + return clibruntime.IsHealth() +} + +// IsDsHealth - +func (l *libruntimeSDKImpl) IsDsHealth() bool { + return clibruntime.IsDsHealth() +} diff --git a/api/go/libruntime/libruntimesdkimpl/libruntimesdkimpl_test.go b/api/go/libruntime/libruntimesdkimpl/libruntimesdkimpl_test.go new file mode 100644 index 0000000..8ea8604 --- /dev/null +++ b/api/go/libruntime/libruntimesdkimpl/libruntimesdkimpl_test.go @@ -0,0 +1,337 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package libruntimesdkimpl implements +package libruntimesdkimpl + +import ( + "os" + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +func TestCreateClient(t *testing.T) { + convey.Convey( + "Test CreateClient", t, func() { + libruntimeAPI := NewLibruntimeSDKImpl() + convey.Convey( + "create client should success", func() { + var conf api.ConnectArguments + conf.Host = "127.0.0.1" + conf.Port = 11111 + conf.TimeoutMs = 500 + conf.Token = []byte{'1', '2', '3'} + conf.ClientPublicKey = "client pub key" + conf.ClientPrivateKey = []byte{'1', '2', '3'} + conf.ServerPublicKey = "server pub key" + conf.AccessKey = "access key" + conf.SecretKey = []byte{'1', '2', '3'} + conf.AuthclientID = "auth client id" + conf.AuthclientSecret = []byte{'1', '2', '3'} + conf.AuthURL = "auth url" + conf.TenantID = "tenant id" + conf.EnableCrossNodeConnection = true + newClient, err := libruntimeAPI.CreateClient(conf) + defer newClient.DestroyClient() + convey.So(err, convey.ShouldBeNil) + convey.So(newClient, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestLibruntimeSDKImpl(t *testing.T) { + convey.Convey( + "Test libruntimeSDKImpl", t, func() { + libruntimeAPI := NewLibruntimeSDKImpl() + convey.Convey( + "CreateInstance success", func() { + str, err := libruntimeAPI.CreateInstance(api.FunctionMeta{}, []api.Arg{}, api.InvokeOptions{}) + convey.So(str, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "InvokeByInstanceId success", func() { + str, err := libruntimeAPI.InvokeByInstanceId(api.FunctionMeta{}, "instanceID", + []api.Arg{}, api.InvokeOptions{}) + convey.So(str, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "InvokeByFunctionName success", func() { + str, err := libruntimeAPI.InvokeByFunctionName(api.FunctionMeta{}, []api.Arg{}, api.InvokeOptions{}) + convey.So(str, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "AcquireInstance success", func() { + allocation, err := libruntimeAPI.AcquireInstance("state", api.FunctionMeta{}, api.InvokeOptions{}) + convey.So(allocation, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "ReleaseInstance success", func() { + convey.So(func() { + libruntimeAPI.ReleaseInstance(api.InstanceAllocation{}, "state", false, api.InvokeOptions{}) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Kill success", func() { + err := libruntimeAPI.Kill("instanceID", 0, []byte{}) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "CreateInstanceRaw success", func() { + convey.So(func() { + go libruntimeAPI.CreateInstanceRaw([]byte{}) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "InvokeByInstanceIdRaw success", func() { + convey.So(func() { + go libruntimeAPI.InvokeByInstanceIdRaw([]byte{}) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "KillRaw success", func() { + convey.So(func() { + go libruntimeAPI.KillRaw([]byte{}) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "SaveState success", func() { + str, err := libruntimeAPI.SaveState([]byte{}) + convey.So(str, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "LoadState success", func() { + str, err := libruntimeAPI.LoadState("checkpointID") + convey.So(str, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "Finalize success", func() { + convey.So(libruntimeAPI.Finalize, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "KVSet success", func() { + err := libruntimeAPI.KVSet("key", []byte("value"), api.SetParam{}) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "KVSetWithoutKey success", func() { + str, err := libruntimeAPI.KVSetWithoutKey([]byte("value"), api.SetParam{}) + convey.So(str, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "KVGet success", func() { + bytes, err := libruntimeAPI.KVGet("key", 300) + convey.So(bytes, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "KVGetMulti success", func() { + bytesArr, err := libruntimeAPI.KVGetMulti([]string{"key"}, 300) + convey.So(bytesArr, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "KVDel success", func() { + err := libruntimeAPI.KVDel("key") + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "KVDelMulti success", func() { + strs, err := libruntimeAPI.KVDelMulti([]string{"key"}) + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "create producer", func() { + producer, err := libruntimeAPI.CreateProducer("stream_001", api.ProducerConf{}) + convey.ShouldNotBeNil(producer) + convey.ShouldBeNil(err) + }, + ) + convey.Convey( + "create consumer", func() { + consumer, err := libruntimeAPI.Subscribe("stream_001", api.SubscriptionConfig{}) + convey.ShouldNotBeNil(consumer) + convey.ShouldBeNil(err) + }, + ) + convey.Convey( + "DeleteStream success", func() { + err := libruntimeAPI.DeleteStream("streamName") + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "QueryGlobalProducersNum success", func() { + n, err := libruntimeAPI.QueryGlobalProducersNum("streamName") + convey.So(n, convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "QueryGlobalConsumersNum success", func() { + n, err := libruntimeAPI.QueryGlobalConsumersNum("streamName") + convey.So(n, convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "SetTraceID success", func() { + convey.So(func() { + libruntimeAPI.SetTraceID("traceID") + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "SetTenantID success", func() { + err := libruntimeAPI.SetTenantID("tenantID") + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Put success", func() { + err := libruntimeAPI.Put("objectID", []byte("data"), api.PutParam{}, "nestedObjectID") + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "PutRaw success", func() { + err := libruntimeAPI.PutRaw("objectID", []byte("data"), api.PutParam{}, "nestedObjectID") + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Get success", func() { + bytesArr, err := libruntimeAPI.Get([]string{"objectID"}, 300) + convey.So(bytesArr, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "GetRaw success", func() { + bytesArr, err := libruntimeAPI.GetRaw([]string{"objectID"}, 300) + convey.So(bytesArr, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Wait success", func() { + strs1, strs2, m := libruntimeAPI.Wait([]string{"objectID"}, 1, 300) + convey.So(strs1, convey.ShouldBeEmpty) + convey.So(strs2, convey.ShouldBeEmpty) + convey.So(m, convey.ShouldBeNil) + }, + ) + convey.Convey( + "GIncreaseRef success", func() { + strs, err := libruntimeAPI.GIncreaseRef([]string{"objectID"}, "remoteClientID") + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "GIncreaseRefRaw success", func() { + strs, err := libruntimeAPI.GIncreaseRefRaw([]string{"objectID"}, "remoteClientID") + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "GDecreaseRef success", func() { + strs, err := libruntimeAPI.GDecreaseRef([]string{"objectID"}, "remoteClientID") + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "GDecreaseRefRaw success", func() { + strs, err := libruntimeAPI.GDecreaseRefRaw([]string{"objectID"}, "remoteClientID") + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "ReleaseGRefs success", func() { + err := libruntimeAPI.ReleaseGRefs("objectID") + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "GetAsync success", func() { + convey.So(func() { + f := func(result []byte, err error) {} + libruntimeAPI.GetAsync("objectID", f) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "GetFormatLogger success", func() { + fl := libruntimeAPI.GetFormatLogger() + convey.So(fl, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestGetLogName(t *testing.T) { + convey.Convey( + "Test getLogName", t, func() { + convey.Convey( + "CreateInstance success when path is empty", func() { + str := getLogName() + convey.So(str, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "CreateInstance success", func() { + os.Setenv("FUNCTION_LIB_PATH", "./go1.x") + str := getLogName() + convey.So(str, convey.ShouldNotBeNil) + }, + ) + }, + ) +} diff --git a/api/go/libruntime/pool/pool.go b/api/go/libruntime/pool/pool.go new file mode 100644 index 0000000..8bdfd35 --- /dev/null +++ b/api/go/libruntime/pool/pool.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package pool for goroutine pool +package pool + +import "github.com/panjf2000/ants/v2" + +// DefaultFuncExecPoolSize - +const DefaultFuncExecPoolSize = 300000 + +// Pool implements a goroutine pool for execution of function calls +type Pool struct { + p *ants.Pool +} + +// NewPool implements a constructor for `Pool` +func NewPool(size int) *Pool { + tmp := Pool{} + p, _ := ants.NewPool(size) + tmp.p = p + return &tmp +} + +// Submit implements a submitting to pool for function call task +func (p *Pool) Submit(task func()) error { + return p.p.Submit(task) +} diff --git a/api/go/libruntime/pool/pool_test.go b/api/go/libruntime/pool/pool_test.go new file mode 100644 index 0000000..7357c95 --- /dev/null +++ b/api/go/libruntime/pool/pool_test.go @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package pool for goroutine pool +package pool + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestPoolSubmint(t *testing.T) { + convey.Convey("test new pool and submit a task", t, func() { + p := NewPool(1) + convey.So(p, convey.ShouldNotBeNil) + + c := make(chan int, 1) + p.Submit(func() { + c <- 1 + }) + i := <-c + convey.So(i, convey.ShouldEqual, 1) + }) +} diff --git a/api/go/posixsdk/posixhandler.go b/api/go/posixsdk/posixhandler.go new file mode 100644 index 0000000..ff38777 --- /dev/null +++ b/api/go/posixsdk/posixhandler.go @@ -0,0 +1,345 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package posixsdk +This package provides methods to obtain the execution interface. +*/ +package posixsdk + +import ( + "fmt" + "os" + "path/filepath" + "plugin" + "strings" + "sync" + + "yuanrong.org/kernel/runtime/libruntime" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/logger/log" + "yuanrong.org/kernel/runtime/libruntime/config" + "yuanrong.org/kernel/runtime/libruntime/execution" + "yuanrong.org/kernel/runtime/libruntime/libruntimesdkimpl" +) + +type initHandlerType = func([]api.Arg, api.LibruntimeAPI) ([]byte, error) +type callHandlerType = func([]api.Arg) ([]byte, error) +type checkpointHandlerType = func(string) ([]byte, error) +type recoverHandlerType = func([]byte, api.LibruntimeAPI) error +type shutdownHandlerType = func(gracePeriodSecond uint64) error +type signalHandlerType = func(signal int, payload []byte) error +type healthCheckHandlerType = func() (api.HealthType, error) + +// RegisterHandler - +type RegisterHandler struct { + InitHandler initHandlerType + CallHandler callHandlerType + CheckPointHandler checkpointHandlerType + RecoverHandler recoverHandlerType + ShutDownHandler shutdownHandlerType + SignalHandler signalHandlerType + HealthCheckHandler healthCheckHandlerType +} + +type posixHandlers struct { + sync.Once + InitHandler initHandlerType + CallHandler callHandlerType + CheckPointHandler checkpointHandlerType + RecoverHandler recoverHandlerType + ShutDownHandler shutdownHandlerType + SignalHandler signalHandlerType + HealthCheckHandler healthCheckHandlerType + execType execution.ExecutionType + onceType sync.Once + libruntimeAPI api.LibruntimeAPI +} + +const ( + functionLibraryPath string = "YR_FUNCTION_LIB_PATH" + initHandler string = "INIT_HANDLER" + callHandler string = "CALL_HANDLER" + checkpointHandler string = "CHECKPOINT_HANDLER" + recoverHandler string = "RECOVER_HANDLER" + shutdownHandler string = "SHUTDOWN_HANDLER" + signalHandler string = "SIGNAL_HANDLER" + healthCheckHandler string = "HEALTH_CHECK_HANDLER" + actorHandlersString string = "yrlib-executor.so" + faaSHandlersString string = "faas-executor.so" + minPathPartLength = 2 +) + +var ( + posixHandler = &posixHandlers{ + libruntimeAPI: libruntimesdkimpl.NewLibruntimeSDKImpl(), + } + posixHandlersEnvs []string = []string{ + initHandler, callHandler, checkpointHandler, recoverHandler, shutdownHandler, signalHandler, healthCheckHandler, + } + libMap map[string]*plugin.Plugin +) + +// Load getEnvs to load code and setHandler +func Load() (execution.ExecutionType, error) { + functionLibPath := os.Getenv(functionLibraryPath) + if functionLibPath == "" { + return execution.ExecutionTypeInvalid, fmt.Errorf("YR_FUNCTION_LIB_PATH not found") + } + + var err error + posixHandler.Do( + func() { + libMap = make(map[string]*plugin.Plugin) + for _, envKey := range posixHandlersEnvs { + if err = setHandler(functionLibPath, envKey); err != nil { + fmt.Printf("failed to open lib %v\n", err) + } + } + }, + ) + return posixHandler.execType, err +} + +func setHandler(functionLibPath, envKey string) error { + handler := os.Getenv(envKey) + if handler == "" { + return fmt.Errorf("%s not found", envKey) + } + + symbol, err := getLib(functionLibPath, handler) + if err != nil { + return err + } + ok := true + switch envKey { + case initHandler: + posixHandler.InitHandler, ok = symbol.(initHandlerType) + case callHandler: + posixHandler.CallHandler, ok = symbol.(callHandlerType) + case checkpointHandler: + posixHandler.CheckPointHandler, ok = symbol.(checkpointHandlerType) + case recoverHandler: + posixHandler.RecoverHandler, ok = symbol.(recoverHandlerType) + case shutdownHandler: + posixHandler.ShutDownHandler, ok = symbol.(shutdownHandlerType) + case signalHandler: + posixHandler.SignalHandler, ok = symbol.(signalHandlerType) + case healthCheckHandler: + posixHandler.HealthCheckHandler, ok = symbol.(healthCheckHandlerType) + default: + return fmt.Errorf("not support %s", handler) + } + if !ok { + return fmt.Errorf("%s type error", handler) + } + return nil +} + +func getLib(functionLibPath, handler string) (plugin.Symbol, error) { + path, name := getLibInfo(functionLibPath, handler) + if path == "" { + return nil, fmt.Errorf("invalid handler name :%s", handler) + } + posixHandler.onceType.Do( + func() { + posixHandler.execType = parseHandlersType(path) + }, + ) + + lib, ok := libMap[path] + if !ok { + log.GetLogger().Infof("start to open lib %s", path) + plug, err := plugin.Open(path) + if err != nil { + log.GetLogger().Errorf("failed to open lib %v", err) + return nil, fmt.Errorf("failed to open %s", handler) + } + lib = plug + libMap[path] = plug + } + symbol, err := lib.Lookup(name) + if err != nil { + log.GetLogger().Errorf("failed to look up %v", err) + return nil, fmt.Errorf("failed to look up %s", name) + } + return symbol, nil +} + +// getLibInfo will parse handler info, example: +// libPath="/tmp" libName="example.init" --> fileName="/tmp/example.so" handlerName="init" +// libPath="/tmp" libName="test.example.init" --> fileName="/tmp/test/example.so" handlerName="init" +func getLibInfo(libPath, libName string) (string, string) { + path := libPath + parts := strings.Split(libName, ".") + length := len(parts) + handlerName := parts[length-1] + if length < minPathPartLength { + return "", "" + } else if length > minPathPartLength { + tmpPath := filepath.Join(parts[:length-minPathPartLength]...) + path = filepath.Join(path, tmpPath) + } + fileName := filepath.Join(path, parts[length-minPathPartLength]+".so") + + // hack for handlers for runtime-go and libruntime both present + if strings.HasSuffix(fileName, "faasfrontend.so") || + strings.HasSuffix(fileName, "faascontroller.so") || + strings.HasSuffix(fileName, "faasscheduler.so") || + strings.HasSuffix(fileName, "faasmanager.so") { + handlerName = handlerName + "Libruntime" // _Libruntime is not good format in golang + } + return fileName, handlerName +} + +func parseHandlersType(libPath string) execution.ExecutionType { + if strings.HasSuffix(libPath, actorHandlersString) { + return execution.ExecutionTypeActor + } else if strings.HasSuffix(libPath, faaSHandlersString) { + return execution.ExecutionTypeFaaS + } + return execution.ExecutionTypePosix +} + +// GetExecType get execution type +func (h *posixHandlers) GetExecType() execution.ExecutionType { + return h.execType +} + +// LoadFunction load function hook +func (h *posixHandlers) LoadFunction(codePaths []string) error { + return nil +} + +// FunctionExecute function execute hook +func (h *posixHandlers) FunctionExecute( + funcMeta api.FunctionMeta, invokeType config.InvokeType, args []api.Arg, returnobjs []config.DataObject, +) error { + var ret []byte + var err error + switch invokeType { + case config.CreateInstance, config.CreateInstanceStateless: + ret, err = h.InitHandler(args, h.libruntimeAPI) + case config.InvokeInstance, config.InvokeInstanceStateless: + ret, err = h.CallHandler(args) + default: + err = fmt.Errorf("no such invokeType %d", invokeType) + } + + if err != nil { + return err + } + + var totalNativeBufferSize uint = 0 + var do *config.DataObject + if len(returnobjs) > 0 { + do = &returnobjs[0] + } else { + return nil + } + if do.ID == "returnByMsg" { + libruntime.SetReturnObject(do, uint(len(ret))) + } else { + if err = libruntime.AllocReturnObject(do, uint(len(ret)), []string{}, &totalNativeBufferSize); err != nil { + return err + } + } + + if err = libruntime.WriterLatch(do); err != nil { + return err + } + defer func() { + if err = libruntime.WriterUnlatch(do); err != nil { + log.GetLogger().Errorf("%v", err) + } + }() + + if err = libruntime.MemoryCopy(do, ret); err != nil { + return err + } + + if err = libruntime.Seal(do); err != nil { + return err + } + return nil +} + +// Checkpoint check point hook +func (h *posixHandlers) Checkpoint(checkpointID string) ([]byte, error) { + if h.CheckPointHandler == nil { + return nil, nil + } + return h.CheckPointHandler(checkpointID) +} + +// Recover recover hook +func (h *posixHandlers) Recover(state []byte) error { + if h.RecoverHandler == nil { + return nil + } + return h.RecoverHandler(state, h.libruntimeAPI) +} + +// Shutdown graceful shutdown hook +func (h *posixHandlers) Shutdown(gracePeriod uint64) error { + if h.ShutDownHandler == nil { + return nil + } + return h.ShutDownHandler(gracePeriod) +} + +// Signal receive signal hook +func (h *posixHandlers) Signal(sig int, data []byte) error { + if h.SignalHandler == nil { + return nil + } + return h.SignalHandler(sig, data) +} + +// HealthCheck - +func (h *posixHandlers) HealthCheck() (api.HealthType, error) { + if h.HealthCheckHandler == nil { + return api.Healthy, nil + } + return h.HealthCheckHandler() +} + +// NewPosixFuncExecutionIntfs - +func NewPosixFuncExecutionIntfs() (execution.FunctionExecutionIntfs, error) { + if _, err := Load(); err != nil { + fmt.Printf("Load function execution Intfs error:%s\n", err.Error()) + } + if posixHandler.execType != execution.ExecutionTypePosix { + return nil, fmt.Errorf("executionType is not posix") + } + return posixHandler, nil +} + +// NewSDKPosixFuncExecutionWithHandler - +func NewSDKPosixFuncExecutionWithHandler(registerHandler RegisterHandler) execution.FunctionExecutionIntfs { + SDKPosixHandler := &posixHandlers{ + libruntimeAPI: libruntimesdkimpl.NewLibruntimeSDKImpl(), + execType: execution.ExecutionTypePosix, + InitHandler: registerHandler.InitHandler, + CallHandler: registerHandler.CallHandler, + CheckPointHandler: registerHandler.CheckPointHandler, + RecoverHandler: registerHandler.RecoverHandler, + ShutDownHandler: registerHandler.ShutDownHandler, + SignalHandler: registerHandler.SignalHandler, + HealthCheckHandler: registerHandler.HealthCheckHandler, + } + return SDKPosixHandler +} diff --git a/api/go/posixsdk/posixhandler_test.go b/api/go/posixsdk/posixhandler_test.go new file mode 100644 index 0000000..34b4947 --- /dev/null +++ b/api/go/posixsdk/posixhandler_test.go @@ -0,0 +1,170 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package posixsdk +This package provides methods to obtain the execution interface. +*/ +package posixsdk + +import ( + "os" + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/config" + "yuanrong.org/kernel/runtime/libruntime/execution" +) + +func TestLoad(t *testing.T) { + convey.Convey( + "Test Load", t, func() { + convey.Convey("Test invalid handler name", func() { + os.Setenv(functionLibraryPath, "/tmp") + _, err := Load() + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "not found") + }) + }) +} + +func TestPosixHandler_FunctionExecute(t *testing.T) { + convey.Convey( + "Test PosixHandler FunctionExecute", t, func() { + convey.Convey("Test no such invokeType", func() { + err := posixHandler.FunctionExecute(api.FunctionMeta{}, 4, []api.Arg{}, []config.DataObject{}) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "no such invokeType") + }) + posixHandler.InitHandler = func([]api.Arg, api.LibruntimeAPI) ([]byte, error) { return nil, nil } + invokeType := config.CreateInstance + convey.Convey("FunctionExecute when case config.CreateInstance", func() { + err := posixHandler.FunctionExecute(api.FunctionMeta{}, invokeType, []api.Arg{}, []config.DataObject{}) + convey.So(err, convey.ShouldBeNil) + }) + posixHandler.CallHandler = func([]api.Arg) ([]byte, error) { return nil, nil } + invokeType = config.InvokeInstance + convey.Convey("FunctionExecute when case config.InvokeInstance", func() { + err := posixHandler.FunctionExecute(api.FunctionMeta{}, invokeType, []api.Arg{}, []config.DataObject{}) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("FunctionExecute when len(returnobjs) > 0", func() { + err := posixHandler.FunctionExecute(api.FunctionMeta{}, invokeType, []api.Arg{}, []config.DataObject{{}}) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("FunctionExecute when do.ID == \"returnByMsg\"", func() { + returnObjs := []config.DataObject{{ID: "returnByMsg"}} + err := posixHandler.FunctionExecute(api.FunctionMeta{}, invokeType, []api.Arg{}, returnObjs) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestSetHandler(t *testing.T) { + convey.Convey("Test setHandler", t, func() { + functionLibPath := "/tmp" + envKey := "initHandler" + convey.Convey("setHandler success", func() { + os.Setenv(envKey, "example.init") + convey.So(func() { + setHandler(functionLibPath, envKey) + }, convey.ShouldNotPanic) + }) + }) +} + +func TestGetLib(t *testing.T) { + convey.Convey("Test getLib", t, func() { + convey.Convey("getLib when path == \"\"", func() { + symbol, err := getLib("", "handler") + convey.So(symbol, convey.ShouldBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "invalid handler name") + }) + }) +} + +func TestGetLibInfo(t *testing.T) { + convey.Convey("Test getLibInfo", t, func() { + convey.Convey("getLibInfo when length > minPathPartLength", func() { + fileName, handlerName := getLibInfo("/tmp", "test.faasfrontend.init") + convey.So(fileName, convey.ShouldEqual, "/tmp/test/faasfrontend.so") + convey.So(handlerName, convey.ShouldEqual, "initLibruntime") + }) + }) +} + +func TestParseHandlersType(t *testing.T) { + convey.Convey("Test parseHandlersType", t, func() { + convey.Convey("parseHandlersType when hasSuffix actorHandlersString", func() { + executionType := parseHandlersType("lib" + actorHandlersString) + convey.So(executionType, convey.ShouldEqual, execution.ExecutionTypeActor) + }) + convey.Convey("parseHandlersType when hasSuffix faaSHandlersString", func() { + executionType := parseHandlersType("lib" + faaSHandlersString) + convey.So(executionType, convey.ShouldEqual, execution.ExecutionTypeFaaS) + }) + }) +} + +func TestPosixHandlers(t *testing.T) { + convey.Convey("Test posixHandlers", t, func() { + convey.Convey("LoadFunction success", func() { + err := posixHandler.LoadFunction([]string{}) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("Checkpoint success", func() { + bytes, err := posixHandler.Checkpoint("checkpointID1") + convey.So(bytes, convey.ShouldBeNil) + convey.So(err, convey.ShouldBeNil) + posixHandler.CheckPointHandler = func(string) ([]byte, error) { return nil, nil } + bytes, err = posixHandler.Checkpoint("checkpointID1") + convey.So(bytes, convey.ShouldBeNil) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("Recover success", func() { + err := posixHandler.Recover([]byte{}) + convey.So(err, convey.ShouldBeNil) + posixHandler.RecoverHandler = func([]byte, api.LibruntimeAPI) error { return nil } + err = posixHandler.Recover([]byte{}) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("Shutdown success", func() { + err := posixHandler.Shutdown(0) + convey.So(err, convey.ShouldBeNil) + posixHandler.ShutDownHandler = func(gracePeriodSecond uint64) error { return nil } + err = posixHandler.Shutdown(0) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("Signal success", func() { + err := posixHandler.Signal(0, []byte{}) + convey.So(err, convey.ShouldBeNil) + posixHandler.SignalHandler = func(signal int, payload []byte) error { return nil } + err = posixHandler.Signal(0, []byte{}) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("HealthCheck success", func() { + healthType, err := posixHandler.HealthCheck() + convey.So(healthType, convey.ShouldEqual, api.Healthy) + convey.So(err, convey.ShouldBeNil) + posixHandler.HealthCheckHandler = func() (api.HealthType, error) { return api.Healthy, nil } + healthType, err = posixHandler.HealthCheck() + convey.So(healthType, convey.ShouldEqual, api.Healthy) + convey.So(err, convey.ShouldBeNil) + }) + }) +} diff --git a/api/go/posixsdk/runtime.go b/api/go/posixsdk/runtime.go new file mode 100644 index 0000000..d2575dc --- /dev/null +++ b/api/go/posixsdk/runtime.go @@ -0,0 +1,127 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package posixsdk for init and start +package posixsdk + +import ( + "fmt" + "os" + + "yuanrong.org/kernel/runtime/libruntime" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common" + "yuanrong.org/kernel/runtime/libruntime/common/utils" + "yuanrong.org/kernel/runtime/libruntime/config" + "yuanrong.org/kernel/runtime/libruntime/execution" + "yuanrong.org/kernel/runtime/libruntime/pool" +) + +const ( + initArgsFilePathEnvKey = "INIT_ARGS_FILE_PATH" +) + +// Run begins loop processing the received request. +func Run() { + libruntime.ReceiveRequestLoop() +} + +// InitRuntime init runtime +func InitRuntime(conf *common.Configuration, intfs execution.FunctionExecutionIntfs) error { + runtimeConf := config.Config{ + GrpcAddress: conf.GrpcAddress, + FunctionSystemAddress: conf.FSAddress, + DataSystemAddress: os.Getenv("DATASYSTEM_ADDR"), + JobID: conf.JobID, + RuntimeID: conf.RuntimeID, + InstanceID: conf.InstanceID, + FunctionName: conf.FunctionName, + LogDir: conf.LogPath, + LogLevel: conf.LogLevel, + InCluster: true, + IsDriver: conf.DriverMode, + EnableMTLS: conf.EnableMTLS, + PrivateKeyPath: conf.PrivateKeyPath, + CertificateFilePath: conf.CertificateFilePath, + VerifyFilePath: conf.VerifyFilePath, + PrivateKeyPaaswd: conf.PrivateKeyPaaswd, + Api: api.PosixApi, + Hooks: config.HookIntfs{ + LoadFunctionCb: intfs.LoadFunction, + FunctionExecutionCb: intfs.FunctionExecute, + CheckpointCb: intfs.Checkpoint, + RecoverCb: intfs.Recover, + ShutdownCb: intfs.Shutdown, + SignalCb: intfs.Signal, + HealthCheckCb: intfs.HealthCheck, + }, + FunctionExectionPool: pool.NewPool(pool.DefaultFuncExecPoolSize), + SystemAuthAccessKey: conf.SystemAuthAccessKey, + SystemAuthSecretKey: conf.SystemAuthSecretKey, + EncryptPrivateKeyPasswd: conf.EncryptPrivateKeyPasswd, + PrimaryKeyStoreFile: conf.PrimaryKeyStoreFile, + StandbyKeyStoreFile: conf.StandbyKeyStoreFile, + EnableDsEncrypt: conf.EnableDsEncrypt, + RuntimePublicKeyContext: conf.RuntimePublicKeyContext, + RuntimePrivateKeyContext: conf.RuntimePrivateKeyContext, + DsPublicKeyContext: conf.DsPublicKeyContext, + EncryptRuntimePublicKeyContext: conf.EncryptRuntimePublicKeyContext, + EncryptRuntimePrivateKeyContext: conf.EncryptRuntimePrivateKeyContext, + EncryptDsPublicKeyContext: conf.EncryptDsPublicKeyContext, + MaxConcurrencyCreateNum: conf.MaxConcurrencyCreateNum, + EnableSigaction: conf.EnableSigaction, + } + if err := libruntime.Init(runtimeConf); err != nil { + fmt.Printf("failed to init libruntime, error %s\n", err.Error()) + return err + } + initArgsFilePath := os.Getenv(initArgsFilePathEnvKey) + if conf.DriverMode && len(initArgsFilePath) != 0 { + initArgsData, err := os.ReadFile(initArgsFilePath) + if err != nil { + fmt.Printf("failed to read init args file %s, error %s\n", initArgsFilePath, err.Error()) + return err + } + return bootstrapSystemFunction(intfs, initArgsData) + } + return nil +} + +func bootstrapSystemFunction(intfs execution.FunctionExecutionIntfs, initArgsData []byte) error { + funcLibPath := os.Getenv("FUNCTION_LIB_PATH") + funcName := utils.GetFuncNameFromFuncLibPath(funcLibPath) + funcMeta := api.FunctionMeta{ + AppName: funcName, + FuncName: funcName, + FuncID: funcName, + } + args := []api.Arg{ + { + Type: api.Value, + Data: initArgsData, + }, + } + if err := intfs.FunctionExecute(funcMeta, config.CreateInstance, args, []config.DataObject{}); err != nil { + fmt.Printf("failed to call init for function %s, error %s\n", funcName, err.Error()) + return err + } + return nil +} + +// ExecShutdownHandler exec shutdown handler. +func ExecShutdownHandler(signum int) { + libruntime.ExecShutdownHandler(signum) +} diff --git a/api/go/posixsdk/runtime_test.go b/api/go/posixsdk/runtime_test.go new file mode 100644 index 0000000..1a33a9e --- /dev/null +++ b/api/go/posixsdk/runtime_test.go @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package posixsdk for init and start +package posixsdk + +import ( + "os" + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/common" +) + +func TestInitRuntimeAndRun(t *testing.T) { + convey.Convey( + "Test InitRuntimeAndRun", t, func() { + cfg := common.GetConfig() + cfg.DriverMode = true + os.Create("./test.json") + defer os.Remove("./test.json") + os.Setenv(initArgsFilePathEnvKey, "./test.json") + defer os.Setenv(initArgsFilePathEnvKey, "") + os.Setenv("FUNCTION_LIB_PATH", "/tmp") + intfs, err := NewPosixFuncExecutionIntfs() + convey.So(err, convey.ShouldBeNil) + convey.Convey("Test InitRuntime Failed", func() { + err = InitRuntime(cfg, intfs) + Run() + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestExecShutdownHandler(t *testing.T) { + convey.Convey( + "Test ExecShutdownHandler", t, func() { + ExecShutdownHandler(15) + convey.So(nil, convey.ShouldBeNil) + }) +} diff --git a/api/go/runtime/system_function_bootstrap b/api/go/runtime/system_function_bootstrap new file mode 100644 index 0000000..5913cb1 --- /dev/null +++ b/api/go/runtime/system_function_bootstrap @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +export ENABLE_SERVER_MODE=${ENABLE_SERVER_MODE:-"true"} +export YR_FUNCTION_LIB_PATH=${YR_FUNCTION_LIB_PATH:-"./"} +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-"./"} +./goruntime -jobId ${JOB_ID} -runtimeId=${RUNTIME_ID} -instanceId=${INSTANCE_ID} -logLevel=${LOG_LEVEL} \ + -logPath=${LOG_PATH} -functionSystemAddress=${FS_ADDRESS} -driverMode true diff --git a/api/go/runtime/yr_runtime_main.go b/api/go/runtime/yr_runtime_main.go new file mode 100644 index 0000000..1c8023f --- /dev/null +++ b/api/go/runtime/yr_runtime_main.go @@ -0,0 +1,76 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package main for start the program. +package main + +import ( + "fmt" + + "yuanrong.org/kernel/runtime/faassdk" + "yuanrong.org/kernel/runtime/libruntime/common" + "yuanrong.org/kernel/runtime/libruntime/execution" + "yuanrong.org/kernel/runtime/posixsdk" + "yuanrong.org/kernel/runtime/yr" +) + +func start() { + execType, err := posixsdk.Load() + if err != nil { + fmt.Printf("Load function execution Intfs error:%s\n", err.Error()) + } + + switch execType { + case execution.ExecutionTypeActor: + err = yr.InitRuntime() + if err != nil { + fmt.Print("init runtime failed: " + err.Error()) + return + } + yr.Run() + case execution.ExecutionTypeFaaS: + err = faassdk.InitRuntime() + if err != nil { + fmt.Print("init runtime failed: " + err.Error()) + return + } + faassdk.Run() + case execution.ExecutionTypePosix: + conf := common.GetConfig() + intfs, err := posixsdk.NewPosixFuncExecutionIntfs() + if err != nil { + fmt.Printf("failed to new posix intfs, error %s\n", err.Error()) + return + } + err = posixsdk.InitRuntime(conf, intfs) + if err != nil { + fmt.Print("init runtime failed: " + err.Error()) + return + } + posixsdk.Run() + default: + err = yr.InitRuntime() + if err != nil { + fmt.Print("init runtime failed: " + err.Error()) + return + } + yr.Run() + } +} + +func main() { + start() +} diff --git a/api/go/runtime/yr_runtime_main_test.go b/api/go/runtime/yr_runtime_main_test.go new file mode 100644 index 0000000..4fb8c04 --- /dev/null +++ b/api/go/runtime/yr_runtime_main_test.go @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package main +This package start the program. +*/ +package main + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestStart(t *testing.T) { + convey.Convey( + "Test start", t, func() { + convey.So(main, convey.ShouldNotPanic) + }, + ) +} diff --git a/api/go/yr/actorhandler.go b/api/go/yr/actorhandler.go new file mode 100644 index 0000000..f17029c --- /dev/null +++ b/api/go/yr/actorhandler.go @@ -0,0 +1,327 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package yr +This package provides methods to obtain the execution interface. +*/ +package yr + +import ( + "fmt" + "os" + "path/filepath" + "plugin" + "reflect" + + "github.com/vmihailenco/msgpack" + + "yuanrong.org/kernel/runtime/libruntime" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/logger/log" + "yuanrong.org/kernel/runtime/libruntime/config" + "yuanrong.org/kernel/runtime/libruntime/execution" +) + +type actorHandlers struct { + execType execution.ExecutionType + plugins map[string]*plugin.Plugin + instance *reflect.Value +} + +var ( + actorHdlrs actorHandlers = actorHandlers{ + execType: execution.ExecutionTypeActor, + plugins: make(map[string]*plugin.Plugin), + instance: nil, + } +) + +// parseArgsValue obtain the parameter type through reflection and deserialize the parameter. +func parseArgsValue(method reflect.Value, args []api.Arg) ([]reflect.Value, error) { + if len(args) != method.Type().NumIn() { + return make([]reflect.Value, 0), fmt.Errorf("args number is not valid") + } + + values := make([]reflect.Value, 0, len(args)) + for i := 0; i < method.Type().NumIn(); i++ { + value, err := parseArgValue(args[i], method.Type().In(i)) + if err != nil { + return values, err + } + + if method.Type().In(i).Kind() != reflect.Ptr { + values = append(values, value.Elem()) + continue + } + values = append(values, value) + } + return values, nil +} + +func parseArgValue(arg api.Arg, dataType reflect.Type) (reflect.Value, error) { + if arg.Type == api.ObjectRef { + // todo is ref + } + + value := reflect.New(dataType).Interface() + if unmarshalError := msgpack.Unmarshal(arg.Data, value); unmarshalError != nil { + return reflect.Value{}, unmarshalError + } + + return reflect.ValueOf(value), nil +} + +// GetExecType get execution type +func (h *actorHandlers) GetExecType() execution.ExecutionType { + return h.execType +} + +// LoadFunction load user function +func (h *actorHandlers) LoadFunction(codePaths []string) error { + for _, path := range codePaths { + files, err := os.ReadDir(path) + if err != nil { + return err + } + + for _, file := range files { + // not support load recursively + if file.IsDir() { + continue + } + if filepath.Ext(file.Name()) != ".so" { + continue + } + pluginPath := filepath.Join(path, file.Name()) + p, err := plugin.Open(pluginPath) + if err != nil { + return err + } else { + h.plugins[pluginPath] = p + } + } + } + return nil +} + +// invokeFunction invoke user function +func (h *actorHandlers) invokeFunction(funcMeta api.FunctionMeta, args []api.Arg) ([]reflect.Value, error) { + for _, p := range h.plugins { + symbol, err := p.Lookup(funcMeta.FuncName) + if err != nil { + continue + } + method := reflect.ValueOf(symbol) + methodArgs, err := parseArgsValue(method, args) + if err != nil { + return make([]reflect.Value, 0), err + } + returnValues := method.Call(methodArgs) + err = catchUserErr(returnValues) + if err != nil { + return make([]reflect.Value, 0), err + } + return returnValues, nil + } + + return make([]reflect.Value, 0), fmt.Errorf("could not find function({})", funcMeta.FuncName) +} + +func (h *actorHandlers) invokeMemberFunction(funcMeta api.FunctionMeta, args []api.Arg) ([]reflect.Value, error) { + if h.instance == nil { + return make([]reflect.Value, 0), fmt.Errorf("did not create instance firstly") + } + method := h.instance.MethodByName(funcMeta.FuncName) + if !method.IsValid() { + return make([]reflect.Value, 0), fmt.Errorf("can not find method: ", funcMeta.FuncName) + } + + methodArgs, err := parseArgsValue(method, args) + if err != nil { + return make([]reflect.Value, 0), err + } + returnValues := method.Call(methodArgs) + err = catchUserErr(returnValues) + if err != nil { + return make([]reflect.Value, 0), err + } + return returnValues, err +} + +func catchUserErr(returnValues []reflect.Value) error { + for i := 0; i < len(returnValues); i++ { + if returnValues[i].Interface() == nil { + continue + } + errorInterface := reflect.TypeOf((*error)(nil)).Elem() + if returnValues[i].Type().Implements(errorInterface) { + e, ok := returnValues[i].Interface().(error) + if !ok { + return nil + } + return e + } + } + return nil +} + +func (h *actorHandlers) createInstance( + funcMeta api.FunctionMeta, args []api.Arg, returnobjs []config.DataObject, +) error { + fmt.Println("stateful create") + var err error + var returnValues []reflect.Value + returnValues, err = h.invokeFunction(funcMeta, args) + if err != nil { + fmt.Println("Create stateful instance failed, error: ", err) + return api.AddStack(err, GetStackTrace(h.instance, funcMeta.FuncName)) + } + + if len(returnValues) != 1 { + fmt.Println("return value number is valid") + err = fmt.Errorf("return value number is valid") + return err + } + + h.instance = &returnValues[0] + return nil +} + +func processInvokeRes(returnValues []reflect.Value, returnobjs []config.DataObject) error { + if len(returnValues) == 0 { + fmt.Println("return number is zero") + return nil + } + + packedReturnValues := make([][]byte, 0, len(returnValues)) + for _, returnValue := range returnValues { + packedReturnValue, err := msgpack.Marshal(returnValue.Interface()) + if err != nil { + fmt.Println("marshel failed ", err) + return nil + } + packedReturnValues = append(packedReturnValues, packedReturnValue) + } + + ret := packedReturnValues[0] + var totalNativeBufferSize uint = 0 + var do *config.DataObject = &returnobjs[0] + if err := libruntime.AllocReturnObject(do, uint(len(ret)), []string{}, &totalNativeBufferSize); err != nil { + return err + } + + if err := libruntime.WriterLatch(do); err != nil { + return err + } + defer func() { + if err := libruntime.WriterUnlatch(do); err != nil { + log.GetLogger().Errorf("%v", err) + } + }() + + if err := libruntime.MemoryCopy(do, ret); err != nil { + return err + } + + if err := libruntime.Seal(do); err != nil { + return err + } + + return nil +} + +func (h *actorHandlers) invokeInstanceStateless( + funcMeta api.FunctionMeta, args []api.Arg, returnobjs []config.DataObject, +) error { + fmt.Println("stateless invoke") + returnValues, err := h.invokeFunction(funcMeta, args) + if err != nil { + fmt.Println("Invoke Instance Stateless failed, error: ", err) + return api.AddStack(err, GetStackTrace(h.instance, funcMeta.FuncName)) + } + err = processInvokeRes(returnValues, returnobjs) + if err != nil { + fmt.Println("failed to process invokeInstanceStateless result by runtime: ", err) + } + return err +} + +func (h *actorHandlers) invokeInstance( + funcMeta api.FunctionMeta, args []api.Arg, returnobjs []config.DataObject, +) error { + fmt.Println("stateful invoke") + returnValues, err := h.invokeMemberFunction(funcMeta, args) + if err != nil { + fmt.Println("Invoke Instance Stateless failed, error: ", err) + return api.AddStack(err, GetStackTrace(h.instance, funcMeta.FuncName)) + } + err = processInvokeRes(returnValues, returnobjs) + if err != nil { + fmt.Println("failed to process invokeInstance result by runtime: ", err) + } + return err +} + +// FunctionExecute function execute hook +func (h *actorHandlers) FunctionExecute( + funcMeta api.FunctionMeta, invokeType config.InvokeType, args []api.Arg, returnobjs []config.DataObject, +) error { + switch invokeType { + case config.CreateInstance: + return h.createInstance(funcMeta, args, returnobjs) + case config.CreateInstanceStateless: + fmt.Println("stateless create") + return nil + case config.InvokeInstance: + return h.invokeInstance(funcMeta, args, returnobjs) + case config.InvokeInstanceStateless: + return h.invokeInstanceStateless(funcMeta, args, returnobjs) + default: + fmt.Println("invalid invoke type: ", invokeType) + return fmt.Errorf("invalid invoke type: %s", invokeType) + } +} + +// Checkpoint check point +func (h *actorHandlers) Checkpoint(checkpointID string) ([]byte, error) { + return []byte{}, nil +} + +// Recover recover hook +func (h *actorHandlers) Recover(state []byte) error { + return nil +} + +// Shutdown hook +func (h *actorHandlers) Shutdown(gracePeriod uint64) error { + return nil +} + +// Signal hook +func (h *actorHandlers) Signal(sig int, data []byte) error { + return nil +} + +func (h *actorHandlers) HealthCheck() (api.HealthType, error) { + return api.Healthy, nil +} + +func newActorFuncExecutionIntfs() execution.FunctionExecutionIntfs { + runtime := &ClusterModeRuntime{} + GetRuntimeHolder().Init(runtime) + return &actorHdlrs +} diff --git a/api/go/yr/actorhandler_test.go b/api/go/yr/actorhandler_test.go new file mode 100644 index 0000000..29c036c --- /dev/null +++ b/api/go/yr/actorhandler_test.go @@ -0,0 +1,235 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package yr +This package provides methods to obtain the execution interface. +*/ +package yr + +import ( + "errors" + "os" + "plugin" + "reflect" + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/config" + "yuanrong.org/kernel/runtime/libruntime/execution" +) + +func TestLoadFunction(t *testing.T) { + convey.Convey( + "Test ActorHandler LoadFunction", t, func() { + intfs := newActorFuncExecutionIntfs() + convey.So(intfs.GetExecType(), convey.ShouldEqual, execution.ExecutionTypeActor) + convey.Convey("LoadFunction success when path error", func() { + err := intfs.LoadFunction([]string{"path"}) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("LoadFunction success when plugin error", func() { + os.Create("lib.so") + err := intfs.LoadFunction([]string{"./"}) + convey.So(err, convey.ShouldNotBeNil) + os.Remove("lib.so") + }) + convey.Convey("LoadFunction success when codePaths is empty", func() { + err := intfs.LoadFunction([]string{}) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestFunctionExecute(t *testing.T) { + convey.Convey( + "Test ActorHandler FunctionExecute", t, func() { + intfs := newActorFuncExecutionIntfs() + convey.Convey("LoadFunction success when case config.CreateInstance:", func() { + convey.So(func() { + intfs.FunctionExecute(api.FunctionMeta{}, config.CreateInstance, []api.Arg{}, []config.DataObject{}) + }, convey.ShouldNotPanic) + }) + convey.Convey("LoadFunction success when case config.CreateInstanceStateless:", func() { + convey.So(func() { + intfs.FunctionExecute(api.FunctionMeta{}, config.CreateInstanceStateless, []api.Arg{}, []config.DataObject{}) + }, convey.ShouldNotPanic) + }) + convey.Convey("LoadFunction success when case config.InvokeInstance:", func() { + convey.So(func() { + intfs.FunctionExecute(api.FunctionMeta{}, config.InvokeInstance, []api.Arg{}, []config.DataObject{}) + }, convey.ShouldNotPanic) + }) + convey.Convey("LoadFunction success when case config.InvokeInstanceStateless:", func() { + convey.So(func() { + intfs.FunctionExecute(api.FunctionMeta{}, config.InvokeInstanceStateless, []api.Arg{}, []config.DataObject{}) + }, convey.ShouldNotPanic) + }) + convey.Convey("LoadFunction success when case default:", func() { + convey.So(func() { + intfs.FunctionExecute(api.FunctionMeta{}, 5, []api.Arg{}, []config.DataObject{}) + }, convey.ShouldNotPanic) + }) + }, + ) +} + +func TestCheckpoint(t *testing.T) { + convey.Convey( + "Test Checkpoint", t, func() { + intfs := newActorFuncExecutionIntfs() + convey.Convey("Checkpoint success", func() { + bytes, err := intfs.Checkpoint("checkpointID1") + convey.So(bytes, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }) + }, + ) +} + +func TestRecover(t *testing.T) { + convey.Convey( + "Test Recover", t, func() { + intfs := newActorFuncExecutionIntfs() + convey.Convey("Recover success", func() { + err := intfs.Recover([]byte{}) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestShutdown(t *testing.T) { + convey.Convey( + "Test Shutdown", t, func() { + intfs := newActorFuncExecutionIntfs() + convey.Convey("Shutdown success", func() { + err := intfs.Shutdown(1) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestSignal(t *testing.T) { + convey.Convey( + "Test Signal", t, func() { + intfs := newActorFuncExecutionIntfs() + convey.Convey("Signal success", func() { + err := intfs.Signal(1, []byte{}) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestHealthCheck(t *testing.T) { + convey.Convey( + "Test HealthCheck", t, func() { + intfs := newActorFuncExecutionIntfs() + convey.Convey("HealthCheck success", func() { + healthType, err := intfs.HealthCheck() + convey.So(healthType, convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestInvokeFunction(t *testing.T) { + convey.Convey( + "Test invokeFunction", t, func() { + intfs := newActorFuncExecutionIntfs() + h := intfs.(*actorHandlers) + h.plugins = map[string]*plugin.Plugin{"p1": {}} + convey.Convey("invokeFunction success", func() { + values, err := h.invokeFunction(api.FunctionMeta{}, []api.Arg{}) + convey.So(values, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("invokeMemberFunction when h.instance == nil", func() { + values, err := h.invokeMemberFunction(api.FunctionMeta{}, []api.Arg{}) + convey.So(values, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldNotBeNil) + + }) + v := reflect.ValueOf("v") + h.instance = &v + convey.Convey("invokeMemberFunction success", func() { + values, err := h.invokeMemberFunction(api.FunctionMeta{}, []api.Arg{}) + convey.So(values, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestCatchUserErr(t *testing.T) { + convey.Convey("Test catchUserErr", t, func() { + values := []reflect.Value{reflect.ValueOf(&api.FunctionMeta{}), reflect.ValueOf("add")} + convey.Convey("catchUserErr success", func() { + err := catchUserErr(values) + convey.So(err, convey.ShouldBeNil) + }) + values = []reflect.Value{reflect.ValueOf(errors.New("err1"))} + convey.Convey("catchUserErr when error", func() { + err := catchUserErr(values) + convey.So(err.Error(), convey.ShouldEqual, "err1") + }) + }) +} + +func TestProcessInvokeRes(t *testing.T) { + convey.Convey("Test processInvokeRes", t, func() { + values := []reflect.Value{reflect.ValueOf(&api.FunctionMeta{}), reflect.ValueOf("add")} + convey.Convey("processInvokeRes when len(returnValues) == 0", func() { + err := processInvokeRes([]reflect.Value{}, []config.DataObject{}) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("processInvokeRes success", func() { + err := processInvokeRes(values, []config.DataObject{config.DataObject{}}) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestParseArgsValue(t *testing.T) { + convey.Convey("Test parseArgsValue", t, func() { + convey.Convey("parseArgsValue when len(args) != method.Type().NumIn()", func() { + values, err := parseArgsValue(reflect.ValueOf(func(foo int) {}), []api.Arg{}) + convey.So(len(values), convey.ShouldBeZeroValue) + convey.So(err.Error(), convey.ShouldEqual, "args number is not valid") + }) + convey.Convey("parseArgsValue when parseArgValue error", func() { + args := []api.Arg{{Data: []byte("a")}} + values, err := parseArgsValue(reflect.ValueOf(func(foo int) {}), args) + convey.So(len(values), convey.ShouldEqual, 1) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("parseArgsValue success", func() { + values, err := parseArgsValue(reflect.ValueOf(func(foo int) {}), []api.Arg{{Data: make([]byte, 1)}}) + convey.So(len(values), convey.ShouldEqual, 1) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestParseArgValue(t *testing.T) { + convey.Convey("Test parseArgValue", t, func() { + convey.Convey("parseArgsValue when arg.Type == api.ObjectRef ", func() { + value, err := parseArgValue(api.Arg{Type: api.ObjectRef}, reflect.TypeOf("v")) + convey.So(value, convey.ShouldEqual, reflect.Value{}) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} diff --git a/api/go/yr/cluster_mode_runtime.go b/api/go/yr/cluster_mode_runtime.go new file mode 100644 index 0000000..d65c665 --- /dev/null +++ b/api/go/yr/cluster_mode_runtime.go @@ -0,0 +1,257 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "fmt" + + "yuanrong.org/kernel/runtime/libruntime" + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/clibruntime" +) + +// ClusterModeRuntime implements the libRuntimeAPI interface +type ClusterModeRuntime struct { +} + +// ReleaseInstance - +func (r *ClusterModeRuntime) ReleaseInstance(allocation api.InstanceAllocation, stateID string, + abnormal bool, option api.InvokeOptions) { + return +} + +// AcquireInstance - +func (r *ClusterModeRuntime) AcquireInstance(state string, funcMeta api.FunctionMeta, + invokeOpt api.InvokeOptions) (api.InstanceAllocation, error) { + return api.InstanceAllocation{}, fmt.Errorf("not implement") +} + +// Init initializes the ClusterModeRuntime +func (r *ClusterModeRuntime) Init() error { + return libruntime.Init(GetConfigManager().Config) +} + +// InvokeByFunctionName stateless function invoke +func (r *ClusterModeRuntime) InvokeByFunctionName( + funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions, +) (string, error) { + return clibruntime.InvokeByFunctionName(funcMeta, args, invokeOpt) +} + +// CreateInstance create instance +func (r *ClusterModeRuntime) CreateInstance( + funcMeta api.FunctionMeta, args []api.Arg, invokeOpt api.InvokeOptions, +) (string, error) { + return clibruntime.CreateInstance(funcMeta, args, invokeOpt) +} + +// InvokeByInstanceId invoke by instanceId +func (r *ClusterModeRuntime) InvokeByInstanceId( + funcMeta api.FunctionMeta, instanceID string, args []api.Arg, invokeOpt api.InvokeOptions, +) (string, error) { + return clibruntime.InvokeByInstanceId(funcMeta, instanceID, args, api.InvokeOptions{}) +} + +// Kill send kill instance request +func (r *ClusterModeRuntime) Kill(instanceID string, signal int, payload []byte) error { + return clibruntime.Kill(instanceID, signal, payload) +} + +// CreateInstanceRaw not support raw interface +func (r *ClusterModeRuntime) CreateInstanceRaw(createReqRaw []byte) ([]byte, error) { + return nil, nil +} + +// InvokeByInstanceIdRaw not support raw interface +func (r *ClusterModeRuntime) InvokeByInstanceIdRaw(invokeReqRaw []byte) ([]byte, error) { + return nil, nil +} + +// KillRaw not support raw interface +func (r *ClusterModeRuntime) KillRaw(killReqRaw []byte) ([]byte, error) { return nil, nil } + +// SaveState no implement +func (r *ClusterModeRuntime) SaveState(state []byte) (string, error) { return "", nil } + +// LoadState no implement +func (r *ClusterModeRuntime) LoadState(checkpointID string) ([]byte, error) { return nil, nil } + +// Exit no implement +func (r *ClusterModeRuntime) Exit(code int, message string) {} + +// Finalize release resources +func (r *ClusterModeRuntime) Finalize() { + clibruntime.Finalize() +} + +// KVSet save binary data to the data system. +func (r *ClusterModeRuntime) KVSet(key string, value []byte, param api.SetParam) error { + return clibruntime.KVSet(key, value, param) +} + +// KVSetWithoutKey no implement +func (r *ClusterModeRuntime) KVSetWithoutKey(value []byte, param api.SetParam) (string, error) { + return "", nil +} + +// KVMSetTx no implement +func (r *ClusterModeRuntime) KVMSetTx(keys []string, values [][]byte, param api.MSetParam) error { + return clibruntime.KVMSetTx(keys, values, param) +} + +// KVGet get binary data from data system. +func (r *ClusterModeRuntime) KVGet(key string, timeoutms uint) ([]byte, error) { + return clibruntime.KVGet(key, timeoutms) +} + +// KVGetMulti get multi binary data from data system. +func (r *ClusterModeRuntime) KVGetMulti(keys []string, timeoutms uint) ([][]byte, error) { + return clibruntime.KVGetMulti(keys, timeoutms) +} + +// KVDel del data from data system. +func (r *ClusterModeRuntime) KVDel(key string) error { + return clibruntime.KVDel(key) +} + +// KVDelMulti del multi data from data system. +func (r *ClusterModeRuntime) KVDelMulti(keys []string) ([]string, error) { + return clibruntime.KVDelMulti(keys) +} + +// CreateProducer creates and returns api.StreamProducer +func (r *ClusterModeRuntime) CreateProducer( + streamName string, producerConf api.ProducerConf, +) (api.StreamProducer, error) { + return clibruntime.CreateStreamProducer(streamName, producerConf) +} + +// Subscribe creates and returns api.StreamConsumer +func (r *ClusterModeRuntime) Subscribe(streamName string, config api.SubscriptionConfig) (api.StreamConsumer, error) { + return clibruntime.CreateStreamConsumer(streamName, config) +} + +// DeleteStream Delete a data flow. +// When the number of global producers and consumers is 0, +// the data flow is no longer used and the metadata related to the data flow is deleted from each worker and the +// master. This function can be invoked on any host node. +func (r *ClusterModeRuntime) DeleteStream(streamName string) error { + return clibruntime.DeleteStream(streamName) +} + +// QueryGlobalProducersNum Specifies the flow name to query the number of all producers of the flow. +func (r *ClusterModeRuntime) QueryGlobalProducersNum(streamName string) (uint64, error) { + return clibruntime.QueryGlobalProducersNum(streamName) +} + +// QueryGlobalConsumersNum Specifies the flow name to query the number of all consumers of the flow. +func (r *ClusterModeRuntime) QueryGlobalConsumersNum(streamName string) (uint64, error) { + return clibruntime.QueryGlobalConsumersNum(streamName) +} + +// SetTraceID no implement +func (r *ClusterModeRuntime) SetTraceID(traceID string) {} + +// SetTenantID - +func (r *ClusterModeRuntime) SetTenantID(tenantID string) error { + return clibruntime.SetTenantID(tenantID) +} + +// Put put obj data to data system. +func (r *ClusterModeRuntime) Put(objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string) error { + return clibruntime.Put(objectID, value, param, nestedObjectIDs...) +} + +// PutRaw - +func (r *ClusterModeRuntime) PutRaw( + objectID string, value []byte, param api.PutParam, nestedObjectIDs ...string, +) error { + return clibruntime.PutRaw(objectID, value, param, nestedObjectIDs...) +} + +// Get to get objs from data system. +func (r *ClusterModeRuntime) Get(objectIDs []string, timeoutMs int) ([][]byte, error) { + return clibruntime.Get(objectIDs, timeoutMs) +} + +// GetRaw - +func (r *ClusterModeRuntime) GetRaw(objectIDs []string, timeoutMs int) ([][]byte, error) { + return clibruntime.GetRaw(objectIDs, timeoutMs) +} + +// Wait until result return or timeout +func (r *ClusterModeRuntime) Wait( + objectIDs []string, waitNum uint64, timeoutMs int, +) ([]string, []string, map[string]error) { + return clibruntime.Wait(objectIDs, waitNum, timeoutMs) +} + +// GIncreaseRef increase object reference count +func (r *ClusterModeRuntime) GIncreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + return clibruntime.GIncreaseRef(objectIDs, remoteClientID...) +} + +// GIncreaseRefRaw - +func (r *ClusterModeRuntime) GIncreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + return clibruntime.GIncreaseRefRaw(objectIDs, remoteClientID...) +} + +// GDecreaseRef no implement +func (r *ClusterModeRuntime) GDecreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + return nil, nil +} + +// GDecreaseRefRaw - +func (r *ClusterModeRuntime) GDecreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + return clibruntime.GDecreaseRefRaw(objectIDs, remoteClientID...) +} + +// GetAsync no implement +func (r *ClusterModeRuntime) GetAsync(objectID string, cb api.GetAsyncCallback) {} + +// GetFormatLogger no implement +func (r *ClusterModeRuntime) GetFormatLogger() api.FormatLogger { return nil } + +// CreateClient - +func (r *ClusterModeRuntime) CreateClient(config api.ConnectArguments) (api.KvClient, error) { + return nil, nil +} + +// ReleaseGRefs release object refs by remote client id +func (r *ClusterModeRuntime) ReleaseGRefs(remoteClientID string) error { + return fmt.Errorf("not support") +} + +// GetCredential - +func (r *ClusterModeRuntime) GetCredential() api.Credential { + return api.Credential{} +} + +// UpdateSchdulerInfo - +func (r *ClusterModeRuntime) UpdateSchdulerInfo(schedulerName string, schedulerId string, option string) { +} + +// IsHealth - +func (r *ClusterModeRuntime) IsHealth() bool { + return true +} + +// IsDsHealth - +func (r *ClusterModeRuntime) IsDsHealth() bool { + return true +} diff --git a/api/go/yr/cluster_mode_runtime_test.go b/api/go/yr/cluster_mode_runtime_test.go new file mode 100644 index 0000000..3f04e84 --- /dev/null +++ b/api/go/yr/cluster_mode_runtime_test.go @@ -0,0 +1,303 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +var clusterRt = &ClusterModeRuntime{} + +func TestClusterModeRuntimeInit(t *testing.T) { + convey.Convey( + "Test ClusterModeRuntimeInit", t, func() { + convey.Convey( + "Init success", func() { + err := clusterRt.Init() + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestReleaseInstance(t *testing.T) { + convey.Convey( + "Test ReleaseInstance", t, func() { + convey.Convey( + "ReleaseInstance success", func() { + convey.So(func() { + clusterRt.ReleaseInstance(api.InstanceAllocation{}, "", false, api.InvokeOptions{}) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestAcquireInstance(t *testing.T) { + convey.Convey( + "Test AcquireInstance", t, func() { + convey.Convey( + "AcquireInstance success", func() { + instanceAllocation, err := clusterRt.AcquireInstance("", api.FunctionMeta{}, api.InvokeOptions{}) + convey.So(instanceAllocation, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestCreateInstanceRaw(t *testing.T) { + convey.Convey( + "Test CreateInstanceRaw", t, func() { + convey.Convey( + "CreateInstanceRaw success", func() { + bytes, err := clusterRt.CreateInstanceRaw([]byte{0}) + convey.So(bytes, convey.ShouldBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestInvokeByInstanceIdRaw(t *testing.T) { + convey.Convey( + "Test InvokeByInstanceIdRaw", t, func() { + convey.Convey( + "InvokeByInstanceIdRaw success", func() { + bytes, err := clusterRt.InvokeByInstanceIdRaw([]byte{0}) + convey.So(bytes, convey.ShouldBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestKillRaw(t *testing.T) { + convey.Convey( + "Test KillRaw", t, func() { + convey.Convey( + "KillRaw success", func() { + bytes, err := clusterRt.KillRaw([]byte{0}) + convey.So(bytes, convey.ShouldBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestSaveState(t *testing.T) { + convey.Convey( + "Test SaveState", t, func() { + convey.Convey( + "SaveState success", func() { + str, err := clusterRt.SaveState([]byte{0}) + convey.So(str, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestLoadState(t *testing.T) { + convey.Convey( + "Test LoadState", t, func() { + convey.Convey( + "LoadState success", func() { + bytes, err := clusterRt.LoadState("") + convey.So(bytes, convey.ShouldBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestKVSetWithoutKey(t *testing.T) { + convey.Convey( + "Test KVSetWithoutKey", t, func() { + convey.Convey( + "KVSetWithoutKey success", func() { + str, err := clusterRt.KVSetWithoutKey([]byte{0}, api.SetParam{}) + convey.So(str, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestSetTraceID(t *testing.T) { + convey.Convey( + "Test SetTraceID", t, func() { + convey.Convey( + "SetTraceID success", func() { + convey.So(func() { + clusterRt.SetTraceID("") + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestSetTenantID(t *testing.T) { + convey.Convey( + "Test SetTenantID", t, func() { + convey.Convey( + "SetTenantID success", func() { + err := clusterRt.SetTenantID("") + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestPutRaw(t *testing.T) { + convey.Convey( + "Test PutRaw", t, func() { + convey.Convey( + "PutRaw success", func() { + err := clusterRt.PutRaw("", []byte{0}, api.PutParam{}, "") + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestGetRaw(t *testing.T) { + convey.Convey( + "Test GetRaw", t, func() { + convey.Convey( + "GetRaw success", func() { + bytesMatrix, err := clusterRt.GetRaw([]string{""}, 3000) + convey.So(bytesMatrix, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestGIncreaseRefRaw(t *testing.T) { + convey.Convey( + "Test GIncreaseRefRaw", t, func() { + convey.Convey( + "GIncreaseRefRaw success", func() { + strs, err := clusterRt.GIncreaseRefRaw([]string{""}, "") + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestGDecreaseRef(t *testing.T) { + convey.Convey( + "Test GDecreaseRef", t, func() { + convey.Convey( + "GDecreaseRef success", func() { + strs, err := clusterRt.GDecreaseRef([]string{""}, "") + convey.So(strs, convey.ShouldBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestGDecreaseRefRaw(t *testing.T) { + convey.Convey( + "Test GDecreaseRefRaw", t, func() { + convey.Convey( + "GDecreaseRefRaw success", func() { + strs, err := clusterRt.GDecreaseRefRaw([]string{""}, "") + convey.So(strs, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestGetAsync(t *testing.T) { + convey.Convey( + "Test GetAsync", t, func() { + convey.Convey( + "GetAsync success", func() { + convey.So(func() { + clusterRt.GetAsync("", nil) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestGetFormatLogger(t *testing.T) { + convey.Convey( + "Test GetFormatLogger", t, func() { + convey.Convey( + "GetFormatLogger success", func() { + logger := clusterRt.GetFormatLogger() + convey.So(logger, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestCreateClient(t *testing.T) { + convey.Convey( + "Test CreateClient", t, func() { + convey.Convey( + "CreateClient success", func() { + kvClient, err := clusterRt.CreateClient(api.ConnectArguments{}) + convey.So(kvClient, convey.ShouldBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestReleaseGRefs(t *testing.T) { + convey.Convey( + "Test ReleaseGRefs", t, func() { + convey.Convey( + "ReleaseGRefs success", func() { + err := clusterRt.ReleaseGRefs("") + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} diff --git a/api/go/yr/config.go b/api/go/yr/config.go new file mode 100644 index 0000000..21cf11d --- /dev/null +++ b/api/go/yr/config.go @@ -0,0 +1,90 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/config" +) + +// Mode defines the function running mode +type Mode int + +const ( + // LocalMode Single-machine multithreading + LocalMode = iota + // ClusterMode Multi-machine multiprocessing + ClusterMode + // Invalid invalid + Invalid +) + +// Config Used for initializing the YR system +type Config struct { + Mode Mode + FunctionUrn string + ServerAddr string + DataSystemAddr string + DataSystemAgentAddr string + InCluster bool + LogLevel string + AutoStart bool +} + +// FlagsConfig from the command line +type FlagsConfig struct { + RuntimeID string + InstanceID string + LogLevel string + Address string + LogPath string + JobID string + Hooks config.HookIntfs + FunctionExectionPool config.Pool +} + +// InvokeOptions function invoke option +type InvokeOptions struct { + Cpu int + Memory int + InvokeLabels map[string]string + CustomResources map[string]float64 + CustomExtensions map[string]string + RetryTime uint + RecoverRetryTimes int + CreateNotifyTimeout int + Labels []string + Affinity map[string]string + ScheduleAffinities []api.Affinity +} + +// NewInvokeOptions return *InvokeOptions +func NewInvokeOptions() *InvokeOptions { + return &InvokeOptions{ + Cpu: 500, + Memory: 500, + InvokeLabels: make(map[string]string), + CustomResources: make(map[string]float64), + CustomExtensions: make(map[string]string), + RetryTime: 0, + CreateNotifyTimeout: -1, + Labels: make([]string, 0), + Affinity: make(map[string]string), + ScheduleAffinities: make([]api.Affinity, 0), + } +} diff --git a/api/go/yr/config_manager.go b/api/go/yr/config_manager.go new file mode 100644 index 0000000..524318c --- /dev/null +++ b/api/go/yr/config_manager.go @@ -0,0 +1,137 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "fmt" + "strings" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong.org/kernel/runtime/libruntime/common/uuid" + "yuanrong.org/kernel/runtime/libruntime/config" +) + +const defaultStackTracesNum = 10 + +// ConfigManager config manager +type ConfigManager struct { + config.Config + Mode Mode +} + +var manager *ConfigManager = nil + +// ClientInfo client info +type ClientInfo struct { + jobId string + version string +} + +// GetClientInfo return ClientInfo +func (m *ConfigManager) GetClientInfo() *ClientInfo { + return &ClientInfo{m.JobID, ""} +} + +// GetConfigManager return the ConfigManager +func GetConfigManager() *ConfigManager { + if manager != nil { + return manager + } + + manager = &ConfigManager{} + manager.JobID = "" + manager.RuntimeID = "driver" + manager.IsDriver = true + manager.LoadPaths = make([]string, 0) + manager.InCluster = true + manager.FunctionSystemAddress = "" + manager.DataSystemIPAddr = "" + manager.DataSystemPort = 0 + manager.LogDir = "" + manager.LogLevel = "DEBUG" + manager.ThreadPoolSize = 10 + manager.EnableMTLS = false + manager.ServerName = "" + manager.Namespace = "" + manager.EnableMetrics = false + manager.MaxConcurrencyCreateNum = 100 + manager.HttpIocThreadsNum = 200 + manager.RecycleTime = 2 + manager.MaxTaskInstanceNum = -1 + manager.LogFileNumMax = 20 + manager.LogFileSizeMax = 400 + return manager +} + +// Init initializes the ConfigManager +func (m *ConfigManager) Init(yrConfig *Config, yrFlagsConfig *FlagsConfig) error { + m.IsDriver = !yrConfig.AutoStart + + m.FunctionSystemAddress = yrConfig.ServerAddr + m.DataSystemAddress = yrConfig.DataSystemAddr + m.LogLevel = yrConfig.LogLevel + m.GrpcAddress = yrFlagsConfig.Address + m.InCluster = yrConfig.InCluster + m.LogDir = yrFlagsConfig.LogPath + m.Api = api.ActorApi + m.Hooks = config.HookIntfs{ + LoadFunctionCb: yrFlagsConfig.Hooks.LoadFunctionCb, + FunctionExecutionCb: yrFlagsConfig.Hooks.FunctionExecutionCb, + CheckpointCb: yrFlagsConfig.Hooks.CheckpointCb, + RecoverCb: yrFlagsConfig.Hooks.RecoverCb, + ShutdownCb: yrFlagsConfig.Hooks.ShutdownCb, + SignalCb: yrFlagsConfig.Hooks.SignalCb, + } + m.FunctionExectionPool = yrFlagsConfig.FunctionExectionPool + if yrConfig.Mode != Invalid { + m.Mode = yrConfig.Mode + } + + if yrConfig.Mode == ClusterMode && (m.IsDriver || len(yrConfig.FunctionUrn) != 0) { + var err error + m.FunctionId, err = ConvertFunctionUrn2Id(yrConfig.FunctionUrn) + if err != nil { + return err + } + } + + if !m.IsDriver && len(yrFlagsConfig.RuntimeID) != 0 { + m.RuntimeID = yrFlagsConfig.RuntimeID + } + + if m.IsDriver { + id := uuid.New() + m.JobID = fmt.Sprintf("job-%s", strings.ReplaceAll(id.String(), "-", "")[:8]) + } else { + m.JobID = yrFlagsConfig.JobID + } + callStackLayerNum, enableCallStack := getCallStackTracesMgr() + m.EnableCallStack = enableCallStack + m.CallStackLayerNum = callStackLayerNum + return nil +} + +// ConvertFunctionUrn2Id convert a functionUrn to id +func ConvertFunctionUrn2Id(functionUrn string) (string, error) { + functionUrns := strings.Split(functionUrn, ":") + if len(functionUrns) != 7 { + return "", fmt.Errorf("functionUrn is not valid") + } + + return fmt.Sprintf("%s/%s/%s", functionUrns[3], functionUrns[5], functionUrns[6]), nil +} diff --git a/api/go/yr/config_manager_test.go b/api/go/yr/config_manager_test.go new file mode 100644 index 0000000..751239e --- /dev/null +++ b/api/go/yr/config_manager_test.go @@ -0,0 +1,90 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "strings" + "testing" +) + +func TestConvertFunctionUrn2Id(t *testing.T) { + // convert correct value + functionUrn := "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest" + if functionId, err := ConvertFunctionUrn2Id(functionUrn); functionId != "12345678901234561234567890123456/0-opc-opc/$latest" || err != nil { + t.Errorf("Convert %s failed.", functionUrn) + } + + // convert invalid value + functionUrn = "sn:cn:yrk:12345678901234561234567890123456^function:0-opc-opc" + if functionId, err := ConvertFunctionUrn2Id(functionUrn); len(functionId) > 0 || err == nil { + t.Errorf("Convert %s failed.", functionUrn) + } +} + +func TestConfigManagerInit(t *testing.T) { + yrConfig := &Config{ + Mode: ClusterMode, + FunctionUrn: "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest", + ServerAddr: "10.90.41.99:31220", + DataSystemAddr: "10.90.41.99:31501", + DataSystemAgentAddr: "", + InCluster: true, + LogLevel: "INFO", + } + configManager := GetConfigManager() + res := configManager.Init(yrConfig, &FlagsConfig{}) + if res != nil { + t.Errorf("configManager init failed") + } + + if configManager.Config.FunctionSystemAddress != "10.90.41.99:31220" { + t.Errorf("FunctionSystemAddress init failed") + } + + if configManager.Config.DataSystemAddress != "10.90.41.99:31501" { + t.Errorf("DataSystemIPAddr init failed") + } + + if !configManager.Config.IsDriver { + t.Errorf("IsDriver init failed") + } + + if !strings.Contains(configManager.Config.JobID, "job-") || len(configManager.Config.JobID) != 12 { + t.Errorf("JobID init failed") + } + + if configManager.Config.RuntimeID != "driver" { + t.Errorf("RuntimeID init failed") + } + + if configManager.Config.LogLevel != "INFO" { + t.Errorf("LogLevel init failed") + } + + if !configManager.Config.InCluster { + t.Errorf("InCluster init failed") + } + + if configManager.FunctionId != "12345678901234561234567890123456/0-opc-opc/$latest" { + t.Errorf("FunctionId init failed") + } + + if configManager.Mode != ClusterMode { + t.Errorf("Mode init failed") + } +} diff --git a/api/go/yr/config_test.go b/api/go/yr/config_test.go new file mode 100644 index 0000000..7e12c7c --- /dev/null +++ b/api/go/yr/config_test.go @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestNewInvokeOptions(t *testing.T) { + convey.Convey( + "Test NewInvokeOptions", t, func() { + convey.Convey( + "NewInvokeOptions success", func() { + option := NewInvokeOptions() + convey.So(option, convey.ShouldNotBeNil) + convey.So(len(option.ScheduleAffinities), convey.ShouldEqual, 0) + }, + ) + }, + ) +} diff --git a/api/go/yr/function_handler.go b/api/go/yr/function_handler.go new file mode 100644 index 0000000..1aaeaf0 --- /dev/null +++ b/api/go/yr/function_handler.go @@ -0,0 +1,95 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "fmt" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +// FunctionHandler provides method to execute the function +type FunctionHandler struct { + funcMeta api.FunctionMeta + invokeOptions api.InvokeOptions +} + +// Function return a *FunctionHandler +func Function(fn any) *FunctionHandler { + if !IsFunction(fn) { + panic(fmt.Errorf("fn parament is not a function")) + } + functionName, _ := GetFunctionName(fn) + funcMeta := api.FunctionMeta{ + FuncName: functionName, + FuncID: GetConfigManager().FunctionId, + Language: api.Golang, + } + + invokeOptions := api.InvokeOptions{ + Cpu: 500, + Memory: 500, + InvokeLabels: make(map[string]string), + CustomResources: make(map[string]float64), + CustomExtensions: make(map[string]string), + Labels: make([]string, 0), + Affinities: make(map[string]string), + ScheduleAffinities: make([]api.Affinity, 0), + RetryTimes: 0, + } + return &FunctionHandler{ + funcMeta: funcMeta, + invokeOptions: invokeOptions, + } +} + +// Options set function option +func (handler *FunctionHandler) Options(options *InvokeOptions) *FunctionHandler { + invokeOptions := api.InvokeOptions{ + Cpu: options.Cpu, + Memory: options.Memory, + InvokeLabels: options.InvokeLabels, + CustomResources: options.CustomResources, + CustomExtensions: options.CustomExtensions, + Labels: options.Labels, + Affinities: options.Affinity, + ScheduleAffinities: options.ScheduleAffinities, + RetryTimes: int(options.RetryTime), + } + handler.invokeOptions = invokeOptions + return handler +} + +// Invoke function invoke +func (handler *FunctionHandler) Invoke(args ...any) (refs []*ObjectRef) { + packedArgs, err := PackInvokeArgs(args...) + if err != nil { + panic(fmt.Errorf(err.Error())) + } + + objId, err := GetRuntimeHolder().GetRuntime().InvokeByFunctionName( + handler.funcMeta, packedArgs, handler.invokeOptions, + ) + if err != nil { + panic(fmt.Sprintf("invoke failed, err: %v", err)) + } + + objRef := ObjectRef{objId: objId} + refs = append(refs, &objRef) + return +} diff --git a/api/go/yr/function_handler_test.go b/api/go/yr/function_handler_test.go new file mode 100644 index 0000000..50060ec --- /dev/null +++ b/api/go/yr/function_handler_test.go @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestFunctionHandler(t *testing.T) { + convey.Convey( + "Test FunctionHandler", t, func() { + convey.Convey( + "Function success when !IsFunction", func() { + convey.So(func() { + Function("") + }, convey.ShouldPanic) + }, + ) + handler := Function(PlusOne) + convey.Convey( + "Function success", func() { + convey.So(handler, convey.ShouldNotBeNil) + }, + ) + opts := new(InvokeOptions) + newHandler := handler.Options(opts) + convey.Convey( + "Options success", func() { + convey.So(newHandler, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "Invoke success", func() { + refs := newHandler.Invoke(2, 3) + convey.So(refs, convey.ShouldNotBeNil) + }, + ) + }, + ) +} diff --git a/api/go/yr/instance_creator.go b/api/go/yr/instance_creator.go new file mode 100644 index 0000000..fba91b9 --- /dev/null +++ b/api/go/yr/instance_creator.go @@ -0,0 +1,90 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "fmt" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +// InstanceCreator provides method to create instance +type InstanceCreator struct { + meta api.FunctionMeta + options api.InvokeOptions +} + +// Options set function option +func (instanceCreator *InstanceCreator) Options(options *InvokeOptions) *InstanceCreator { + invokeOptions := api.InvokeOptions{ + Cpu: options.Cpu, + Memory: options.Memory, + InvokeLabels: options.InvokeLabels, + CustomResources: options.CustomResources, + CustomExtensions: options.CustomExtensions, + Labels: options.Labels, + ScheduleAffinities: options.ScheduleAffinities, + Affinities: options.Affinity, + RetryTimes: int(options.RetryTime), + RecoverRetryTimes: options.RecoverRetryTimes, + } + instanceCreator.options = invokeOptions + return instanceCreator +} + +// Invoke create instance +func (instanceCreator *InstanceCreator) Invoke(args ...any) *NamedInstance { + packedArgs, err := PackInvokeArgs(args...) + if err != nil { + panic(fmt.Errorf(err.Error())) + } + instanceId, err := GetRuntimeHolder().GetRuntime().CreateInstance( + instanceCreator.meta, packedArgs, instanceCreator.options, + ) + if err != nil { + panic(err) + } + handler := NewNamedInstance(instanceId) + return handler +} + +// Instance return a *InstanceCreator +func Instance(fn any) *InstanceCreator { + if !IsFunction(fn) { + panic(fmt.Errorf("paramenter is not a function")) + } + functionName, _ := GetFunctionName(fn) + funcMeta := api.FunctionMeta{ + FuncName: functionName, + Language: api.Golang, + FuncID: GetConfigManager().FunctionId, + } + invokeOptions := api.InvokeOptions{ + Cpu: 500, + Memory: 500, + InvokeLabels: make(map[string]string), + CustomResources: make(map[string]float64), + CustomExtensions: make(map[string]string), + Labels: make([]string, 0), + ScheduleAffinities: make([]api.Affinity, 0), + Affinities: make(map[string]string), + RetryTimes: 0, + Timeout: 60, + } + return &InstanceCreator{meta: funcMeta, options: invokeOptions} +} diff --git a/api/go/yr/instance_creator_test.go b/api/go/yr/instance_creator_test.go new file mode 100644 index 0000000..5d420aa --- /dev/null +++ b/api/go/yr/instance_creator_test.go @@ -0,0 +1,71 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +type CounterTest struct { + Count int +} + +func (c *CounterTest) Add(x int) int { + c.Count += x + return c.Count +} + +func NewCounter(init int) *CounterTest { + return &CounterTest{Count: init} +} + +func TestInstanceCreator(t *testing.T) { + convey.Convey( + "Test InstanceCreator", t, func() { + convey.Convey( + "InstanceCreator success when !IsFunction", func() { + convey.So(func() { + Instance("") + }, convey.ShouldPanic) + }, + ) + creator := Instance(NewCounter(1).Add) + convey.Convey( + "InstanceCreator success", func() { + convey.So(creator, convey.ShouldNotBeNil) + }, + ) + opts := new(InvokeOptions) + newCreator := creator.Options(opts) + convey.Convey( + "Options success", func() { + convey.So(newCreator, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "Invoke success", func() { + convey.So(func() { + newCreator.Invoke(2, 3) + }, convey.ShouldPanic) + }, + ) + }, + ) +} diff --git a/api/go/yr/instance_function_handler.go b/api/go/yr/instance_function_handler.go new file mode 100644 index 0000000..dc913ab --- /dev/null +++ b/api/go/yr/instance_function_handler.go @@ -0,0 +1,61 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "fmt" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +// InstanceFunctionHandler provides method to execute the function +type InstanceFunctionHandler struct { + instanceId string + funcMeta api.FunctionMeta +} + +// NewInstanceFunctionHandler return a *InstanceFunctionHandler +func NewInstanceFunctionHandler(fn any, instanceId string) *InstanceFunctionHandler { + if !IsFunction(fn) { + panic(fmt.Errorf("fn parameter is not a function")) + } + functionName, _ := GetFunctionName(fn) + meta := api.FunctionMeta{ + FuncName: functionName, + FuncID: GetConfigManager().FunctionId, + Language: api.Golang, + } + return &InstanceFunctionHandler{funcMeta: meta, instanceId: instanceId} +} + +// Invoke function invoke +func (handler *InstanceFunctionHandler) Invoke(args ...any) (refs []*ObjectRef) { + packedArgs, err := PackInvokeArgs(args...) + if err != nil { + panic(fmt.Errorf(err.Error())) + } + objId, err := GetRuntimeHolder().GetRuntime().InvokeByInstanceId(handler.funcMeta, handler.instanceId, + packedArgs, api.InvokeOptions{}) + if err != nil { + panic(fmt.Sprintf("invoke failed, err: %v", err)) + } + + objRef := ObjectRef{objId: objId} + refs = append(refs, &objRef) + return +} diff --git a/api/go/yr/instance_function_handler_test.go b/api/go/yr/instance_function_handler_test.go new file mode 100644 index 0000000..0d81bfe --- /dev/null +++ b/api/go/yr/instance_function_handler_test.go @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestInstanceFunctionHandler(t *testing.T) { + convey.Convey( + "Test InstanceFunctionHandler", t, func() { + convey.Convey( + "InstanceFunctionHandler success when !isFunction", func() { + convey.So(func() { + NewInstanceFunctionHandler("", "") + }, convey.ShouldPanic) + }, + ) + funcHandler := NewInstanceFunctionHandler((*CounterTest).Add, "") + convey.Convey( + "InstanceFunctionHandler success ", func() { + convey.So(funcHandler, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "Invoke success ", func() { + refs := funcHandler.Invoke(1, 2) + convey.So(refs, convey.ShouldNotBeNil) + }, + ) + }, + ) +} diff --git a/api/go/yr/named_instance.go b/api/go/yr/named_instance.go new file mode 100644 index 0000000..6f53b76 --- /dev/null +++ b/api/go/yr/named_instance.go @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "github.com/vmihailenco/msgpack" + + "yuanrong.org/kernel/runtime/libruntime/config" +) + +// NamedInstance provides function to construct *InstanceFunctionHandler +type NamedInstance struct { + instanceId string +} + +// NewNamedInstance return *NamedInstance +func NewNamedInstance(instanceId string) *NamedInstance { + instance := &NamedInstance{ + instanceId: instanceId, + } + return instance +} + +// EncodeMsgpack serialize +func (instance NamedInstance) EncodeMsgpack(enc *msgpack.Encoder) error { + return enc.EncodeMulti(instance.instanceId) +} + +// DecodeMsgpack deserialize +func (instance *NamedInstance) DecodeMsgpack(dec *msgpack.Decoder) error { + return dec.DecodeMulti(&instance.instanceId) +} + +// Function return *InstanceFunctionHandler +func (instance *NamedInstance) Function(fn any) *InstanceFunctionHandler { + return NewInstanceFunctionHandler(fn, instance.instanceId) +} + +// Terminate kill instance +func (instance *NamedInstance) Terminate() error { + return GetRuntimeHolder().GetRuntime().Kill(instance.instanceId, config.KillInstance, make([]byte, 0)) +} diff --git a/api/go/yr/named_instance_test.go b/api/go/yr/named_instance_test.go new file mode 100644 index 0000000..46feca3 --- /dev/null +++ b/api/go/yr/named_instance_test.go @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "bytes" + "testing" + + "github.com/smartystreets/goconvey/convey" + "github.com/vmihailenco/msgpack" +) + +func TestNewNamedInstance(t *testing.T) { + convey.Convey( + "Test NamedInstance", t, func() { + instance := NewNamedInstance("") + convey.Convey( + "NewNamedInstance success", func() { + convey.So(instance, convey.ShouldNotBeNil) + }, + ) + buf := new(bytes.Buffer) + convey.Convey( + "EncodeMsgpack success", func() { + enc := msgpack.NewEncoder(buf) + err := instance.EncodeMsgpack(enc) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "DecodeMsgpack success", func() { + dec := msgpack.NewDecoder(buf) + err := instance.DecodeMsgpack(dec) + convey.So(err.Error(), convey.ShouldEqual, "EOF") + }, + ) + convey.Convey( + "Function success", func() { + ifh := instance.Function((*CounterTest).Add) + convey.So(ifh, convey.ShouldNotBeNil) + }, + ) + convey.Convey( + "Terminate success", func() { + err := instance.Terminate() + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} diff --git a/api/go/yr/object_ref.go b/api/go/yr/object_ref.go new file mode 100644 index 0000000..0c11cb5 --- /dev/null +++ b/api/go/yr/object_ref.go @@ -0,0 +1,49 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "reflect" + + "github.com/vmihailenco/msgpack" + + "yuanrong.org/kernel/runtime/libruntime/common/uuid" +) + +// ObjectRef Reference to a data object +type ObjectRef struct { + objId string + refType reflect.Type +} + +// EncodeMsgpack serialize +func (ref ObjectRef) EncodeMsgpack(enc *msgpack.Encoder) error { + return enc.EncodeMulti(ref.objId) +} + +// DecodeMsgpack deserialize +func (ref *ObjectRef) DecodeMsgpack(dec *msgpack.Decoder) error { + return dec.DecodeMulti(&ref.objId) +} + +// NewObjectRef return a *ObjectRef +func NewObjectRef() *ObjectRef { + objId := "yr-api-obj-" + uuid.New().String() + o := &ObjectRef{objId: objId} + return o +} diff --git a/api/go/yr/object_ref_test.go b/api/go/yr/object_ref_test.go new file mode 100644 index 0000000..09fb328 --- /dev/null +++ b/api/go/yr/object_ref_test.go @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "bytes" + "testing" + + "github.com/smartystreets/goconvey/convey" + "github.com/vmihailenco/msgpack" +) + +func TestNewObjectRef(t *testing.T) { + convey.Convey( + "Test NewObjectRef", t, func() { + convey.Convey( + "NewObjectRef success", func() { + objRef := NewObjectRef() + convey.So(objRef, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestDecodeMsgpack(t *testing.T) { + convey.Convey( + "Test DecodeMsgpack", t, func() { + convey.Convey( + "DecodeMsgpack success", func() { + objRef := NewObjectRef() + r := bytes.NewReader([]byte{}) + err := objRef.DecodeMsgpack(msgpack.NewDecoder(r)) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} diff --git a/api/go/yr/runtime.go b/api/go/yr/runtime.go new file mode 100644 index 0000000..ee9a353 --- /dev/null +++ b/api/go/yr/runtime.go @@ -0,0 +1,64 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for init and start +package yr + +import ( + "os" + + "yuanrong.org/kernel/runtime/libruntime" + "yuanrong.org/kernel/runtime/libruntime/common" + "yuanrong.org/kernel/runtime/libruntime/config" +) + +// InitRuntime init runtime +func InitRuntime() error { + conf := common.GetConfig() + intfs := newActorFuncExecutionIntfs() + hooks := config.HookIntfs{ + LoadFunctionCb: intfs.LoadFunction, + FunctionExecutionCb: intfs.FunctionExecute, + CheckpointCb: intfs.Checkpoint, + RecoverCb: intfs.Recover, + ShutdownCb: intfs.Shutdown, + SignalCb: intfs.Signal, + } + + yrConf := &Config{ + InCluster: true, + Mode: ClusterMode, + DataSystemAddr: os.Getenv("DATASYSTEM_ADDR"), + LogLevel: conf.LogLevel, + AutoStart: true, + } + + yrFlagsConfig := &FlagsConfig{ + RuntimeID: conf.RuntimeID, + InstanceID: conf.InstanceID, + LogPath: conf.LogPath, + JobID: conf.JobID, + Hooks: hooks, + } + + _, err := InitWithFlags(yrConf, yrFlagsConfig) + return err +} + +// Run begins loop processing the received request. +func Run() { + libruntime.ReceiveRequestLoop() +} diff --git a/api/go/yr/runtime_holder.go b/api/go/yr/runtime_holder.go new file mode 100644 index 0000000..debcf7b --- /dev/null +++ b/api/go/yr/runtime_holder.go @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import "yuanrong.org/kernel/runtime/libruntime/api" + +var holder *RuntimeHolder = nil + +// RuntimeHolder holds a libRuntimeAPI +type RuntimeHolder struct { + runtime api.LibruntimeAPI +} + +// GetRuntimeHolder return RuntimeHolder +func GetRuntimeHolder() *RuntimeHolder { + if holder != nil { + return holder + } + + holder = &RuntimeHolder{} + return holder +} + +// Init initializes the RuntimeHolder +func (holder *RuntimeHolder) Init(runtime api.LibruntimeAPI) { + if holder != nil { + holder.runtime = runtime + } +} + +// GetRuntime return libRuntimeAPI +func (holder *RuntimeHolder) GetRuntime() api.LibruntimeAPI { + if holder != nil { + return holder.runtime + } + return nil +} diff --git a/api/go/yr/runtime_holder_test.go b/api/go/yr/runtime_holder_test.go new file mode 100644 index 0000000..95c4f79 --- /dev/null +++ b/api/go/yr/runtime_holder_test.go @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestGetRuntimeHolder(t *testing.T) { + convey.Convey( + "Test GetRuntimeHolder", t, func() { + convey.Convey( + "GetRuntimeHolder success", func() { + rtHolder := GetRuntimeHolder() + convey.So(rtHolder, convey.ShouldNotBeNil) + }, + ) + }, + ) +} diff --git a/api/go/yr/runtime_test.go b/api/go/yr/runtime_test.go new file mode 100644 index 0000000..c8ec538 --- /dev/null +++ b/api/go/yr/runtime_test.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for init and start +package yr + +import ( + "os" + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/common" +) + +func TestInitRuntimeAndRun(t *testing.T) { + convey.Convey( + "Test InitRuntimeAndRun", t, func() { + cfg := common.GetConfig() + cfg.DriverMode = true + os.Setenv("FUNCTION_LIB_PATH", "/tmp") + convey.Convey("Test InitRuntime Failed", func() { + err := InitRuntime() + Run() + convey.So(err, convey.ShouldBeNil) + }) + }) +} diff --git a/api/go/yr/stacktrace.go b/api/go/yr/stacktrace.go new file mode 100644 index 0000000..1d58555 --- /dev/null +++ b/api/go/yr/stacktrace.go @@ -0,0 +1,170 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "fmt" + "os" + "reflect" + "regexp" + "runtime" + "strconv" + "strings" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +const defaultClassName = "main" + +func getFuncInfos(s string) (string, string, string, error) { + var methodName, parameters string + items := strings.Split(s, ".") + if len(items) > 0 { + s = strings.TrimPrefix(s, items[0]+".") + } + if s == "" { + return "", "", "", fmt.Errorf("failed to parse stack") + } + sByte := []byte(s) + l := len(sByte) + if sByte[l-1] != ')' { + // no param + return defaultClassName, s, "", nil + } + for j := l - 1; j >= 0; j-- { + if sByte[j] == '(' { + parameters = string(sByte[j:l]) + break + } + } + methodName = strings.TrimSuffix(s, parameters) + return defaultClassName, methodName, parameters, nil +} + +func getFileInfos(s string) (string, string, int64, error) { + var fileName, offset string + var linNumber int64 + fileInfoSep := "[: ]" + fileInfos := regexp.MustCompile(fileInfoSep).Split(strings.TrimSpace(s), -1) + // for example: [a.go 22 0x03], lengh should larger then 2 + if len(fileInfos) >= 2 { + fileName = fileInfos[0] + line, err := strconv.Atoi(fileInfos[1]) + if err != nil { + return "", "", 0, err + } else { + linNumber = int64(line) + } + if len(fileInfos) > 2 { + offset = fileInfos[2] + } + } + return fileName, offset, linNumber, nil +} + +func parseStack(funcInfo, fileInfo string) (api.StackTrace, error) { + className, methodName, parameters, err := getFuncInfos(funcInfo) + if err != nil { + fmt.Println(err.Error(), funcInfo) + return api.StackTrace{}, err + } + fineName, offset, linNumber, err := getFileInfos(fileInfo) + if err != nil { + fmt.Println(err.Error(), funcInfo) + return api.StackTrace{}, err + } + stack := api.StackTrace{ + ClassName: className, + MethodName: methodName, + FileName: fineName, + LineNumber: linNumber, + ExtensionInfo: make(map[string]string, 10), + } + if parameters != "" { + stack.ExtensionInfo["parameters"] = parameters + } + if offset != "" { + stack.ExtensionInfo["offset"] = offset + } + return stack, nil +} + +func getStackTraceInfos() []api.StackTrace { + if !GetConfigManager().EnableCallStack { + return nil + } + stackNum := GetConfigManager().CallStackLayerNum + stack := make([]byte, 1024*1024) + n := runtime.Stack(stack, false) + info := string(stack[:n]) + items := strings.Split(info, "\n") + var stacks []api.StackTrace + for i := 0; i < len(items); i++ { + if (strings.HasPrefix(items[i], "plugin") || strings.HasPrefix(items[i], "main")) && i+1 < len(items) { + stack, err := parseStack(items[i], items[i+1]) + if err != nil { + continue + } + stacks = append(stacks, stack) + if len(stacks) == stackNum { + break + } + i++ + } + } + return stacks +} + +// GetStackTrace return api.StackTrace +func GetStackTrace(instance *reflect.Value, functionName string) api.StackTrace { + if !GetConfigManager().EnableCallStack { + return api.StackTrace{} + } + return api.StackTrace{ + ClassName: defaultClassName, + MethodName: getMethodName(instance, functionName), + } +} + +func getMethodName(instance *reflect.Value, funcName string) string { + if instance == nil { + return funcName + } + // <*main.Item Value> + s := instance.String() + infos := strings.Split(s, ".") + // for example: [<*main Item Value>], lengh should larger then 2 + if len(infos) < 2 { + return funcName + } + items := strings.Split(infos[1], " ") + return fmt.Sprintf("(*%s).%s", items[0], funcName) +} + +func getCallStackTracesMgr() (int, bool) { + isEnable := os.Getenv("ENABLE_DIS_CONV_CALL_STACK") != "false" + if !isEnable { + return 0, false + } + callStackLayer := os.Getenv("MAX_CALL_STACK_LAYER_NUM") + layerNum, err := strconv.Atoi(callStackLayer) + if err != nil { + layerNum = defaultStackTracesNum + } + return layerNum, true +} diff --git a/api/go/yr/stacktrace_test.go b/api/go/yr/stacktrace_test.go new file mode 100644 index 0000000..2f97c86 --- /dev/null +++ b/api/go/yr/stacktrace_test.go @@ -0,0 +1,128 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "reflect" + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestGetFuncInfos(t *testing.T) { + convey.Convey( + "Test getFuncInfos", t, func() { + var str, str1, str2, str3 string + var err error + + convey.Convey( + "when s == \"\" success", func() { + str = "" + str1, str2, str3, err = getFuncInfos(str) + convey.So(str1, convey.ShouldBeEmpty) + convey.So(str2, convey.ShouldBeEmpty) + convey.So(str3, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + + convey.Convey( + "when sByte[l-1] != ')' success", func() { + str = "1.1" + str1, str2, str3, err = getFuncInfos(str) + convey.So(str1, convey.ShouldEqual, defaultClassName) + convey.So(str2, convey.ShouldEqual, "1") + convey.So(str3, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestGetStackTraceInfos(t *testing.T) { + stacks := getStackTraceInfos() + if len(stacks) != 0 { + t.Errorf("failed to test getStackTraceInfos") + } +} + +func TestParseStack(t *testing.T) { + GetConfigManager().EnableCallStack = true + funcInfo := "plugin/unnamed-6748bd019e0aa0de0891ed0eb3826f47e13fb900.Functiond(0x45,0x57)" + fileInfo := "./test_exception.go:34 +0x265" + stack, err := parseStack(funcInfo, fileInfo) + if err != nil { + t.Errorf("parse stack shoud success %s", err.Error()) + return + } + if stack.MethodName != "Functiond" { + t.Errorf("methodName should be Functiond, but get %s", stack.MethodName) + return + } + if stack.LineNumber != 34 { + t.Errorf("methodName should be 34, but get %d", stack.LineNumber) + return + } + fileInfo = "./test_exception.go:aa +0x265" + stack, err = parseStack(funcInfo, fileInfo) + if err == nil { + t.Errorf("parse stack %s shoud be failed", fileInfo) + return + } +} + +func TestGetCallStackTracesMgr(t *testing.T) { + callNum, enbaleCall := getCallStackTracesMgr() + if callNum != defaultStackTracesNum || enbaleCall == false { + t.Error("failed call stack configure") + return + } +} + +func TestGetStackTrace(t *testing.T) { + GetConfigManager().EnableCallStack = true + + type mockInterface struct{} + m := &mockInterface{} + r := reflect.ValueOf(m) + stack := GetStackTrace(&r, "mockFunction") + if stack.ClassName != defaultClassName || stack.MethodName != "(*mockInterface).mockFunction" { + t.Errorf("failed to get stack, %v", stack) + } + stack = GetStackTrace(nil, "mockFunction") + if stack.ClassName != defaultClassName || stack.MethodName != "mockFunction" { + t.Errorf("failed to get stack, %v", stack) + } +} + +func TestGetMethodName(t *testing.T) { + convey.Convey( + "Test getMethodName", t, func() { + convey.Convey( + "when len(infos) < 2 success", func() { + str := "main" + instance := reflect.ValueOf(str) + + res := getMethodName(&instance, "funcName") + convey.So(res, convey.ShouldEqual, "funcName") + }, + ) + }, + ) +} diff --git a/api/go/yr/stream.go b/api/go/yr/stream.go new file mode 100644 index 0000000..d3b37ce --- /dev/null +++ b/api/go/yr/stream.go @@ -0,0 +1,129 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package yr +This package encapsulates all cgo invocations. +*/ +package yr + +import "C" +import ( + "sync" + "unsafe" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +// SubscriptionType subscribe type +type SubscriptionType int + +// Element data to be sent or receive +type Element struct { + Data []byte + Size int + Id int +} + +// ProducerConf represents configuration information for a producer. +type ProducerConf struct { + DelayFlushTime int + PageSize int + MaxStreamSize int +} + +// Producer struct represents a producer of streaming data. +type Producer struct { + producer api.StreamProducer + mutex *sync.RWMutex +} + +// Consumer struct represents a consumer of streaming data. +type Consumer struct { + consumer api.StreamConsumer + mutex *sync.RWMutex +} + +// Send sends an element. +// This method can be used to send data to consumers. +func (producer *Producer) Send(data []byte) error { + element := api.Element{ + Size: uint64(len(data)), + Ptr: (*uint8)(unsafe.Pointer(&data[0])), + } + producer.mutex.Lock() + defer producer.mutex.Unlock() + return producer.producer.Send(element) +} + +// Flush ensure flush buffered data so that it is visible to the consumer. +func (producer *Producer) Flush() error { + producer.mutex.Lock() + defer producer.mutex.Unlock() + return producer.producer.Flush() +} + +// Close signals the producer to stop accepting new data and automatically flushes +// any pending data in the buffer. Once closed, the producer is no longer available. +func (producer *Producer) Close() error { + producer.mutex.Lock() + defer producer.mutex.Unlock() + return producer.producer.Close() +} + +// Receive retrieves data from the consumer with an optional timeout. +// Parameters: +// - expectNum: The expected number of elements to receive. +// - timeoutMs: Maximum time in milliseconds to wait for data before timing out. +// Returns: +// - []api.Element: The received data. +// - error: nil if data was received within the timeout, error otherwise. +func (consumer *Consumer) Receive(expectNum, timeoutMs uint32) ([]*Element, error) { + datas, err := consumer.consumer.ReceiveExpectNum(expectNum, timeoutMs) + if err != nil { + return make([]*Element, 0), err + } + return consumer.receiveArr(datas) +} + +func (consumer *Consumer) receiveArr(datas []api.Element) ([]*Element, error) { + res := make([]*Element, 0, len(datas)) + for _, data := range datas { + element := Element{ + Data: C.GoBytes(unsafe.Pointer(data.Ptr), C.int(data.Size)), + Size: int(data.Size), + Id: int(data.Id)} + res = append(res, &element) + } + return res, nil +} + +// Ack confirms that the consumer has completed processing the element identified by elementID. +// This function signals to other workers whether the consumer has finished processing the element. +// If all consumers have acknowledged processing the element, it triggers internal memory reclamation +// for the corresponding page. +// Parameters: +// - elementID: The identifier of the element that has been consumed. +func (consumer *Consumer) Ack(elementId uint64) error { + return consumer.consumer.Ack(elementId) +} + +// Close closes the consumer, unsubscribing it from further data consumption. +// This method also acknowledges any unacknowledged elements on the consumer, +// ensuring that they are marked as processed before shutting down. +func (consumer *Consumer) Close() error { + return consumer.consumer.Close() +} diff --git a/api/go/yr/stream_test.go b/api/go/yr/stream_test.go new file mode 100644 index 0000000..8765f15 --- /dev/null +++ b/api/go/yr/stream_test.go @@ -0,0 +1,132 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +func TestProducer(t *testing.T) { + convey.Convey( + "Test Producer", t, func() { + convey.Convey( + "Init success", func() { + yrConfig := &Config{ + Mode: ClusterMode, + FunctionUrn: "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest", + ServerAddr: "127.0.0.1:12345", + DataSystemAddr: "127.0.0.1:12346", + InCluster: true, + AutoStart: false, + } + clientInfo, err := Init(yrConfig) + convey.So(clientInfo.jobId, convey.ShouldNotBeEmpty) + convey.So(err, convey.ShouldBeNil) + convey.Convey( + "CreateProducer success", func() { + producerConf := ProducerConf{ + DelayFlushTime: 5, + PageSize: 1024 * 1024, + MaxStreamSize: 1024 * 1024 * 1024, + } + producer, err := CreateProducer("teststream", producerConf) + convey.So(producer, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) + }, + ) +} + +func TestStream(t *testing.T) { + producer, err := CreateProducer("streamName", ProducerConf{}) + if err != nil { + t.Errorf("create producer failed, err: %s", err) + return + } + consumer, err := Subscribe("streamName", "subStreamName", 0) + if err != nil { + t.Errorf("create consumer failed, err: %s", err) + return + } + + convey.Convey( + "Test producer", t, func() { + convey.Convey( + "Send success", func() { + err = producer.Send([]byte("value1")) + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Flush success", func() { + err = producer.Flush() + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) + + convey.Convey("Test consumer", t, func() { + convey.Convey( + "Receive success", func() { + elements, err1 := consumer.Receive(1, 3000) + convey.So(len(elements), convey.ShouldBeZeroValue) + convey.So(err1, convey.ShouldBeNil) + }, + ) + convey.Convey( + "receiveArr success", func() { + var value uint8 = 97 + data1 := api.Element{Ptr: &value, Size: 1, Id: 1} + elements, err1 := consumer.receiveArr([]api.Element{data1}) + convey.So(len(elements), convey.ShouldEqual, 1) + convey.So(err1, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Ack success", func() { + err = consumer.Ack(0) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) + + convey.Convey("Test Close", t, func() { + convey.Convey( + "Producer Close success", func() { + err = producer.Close() + convey.So(err, convey.ShouldBeNil) + }, + ) + convey.Convey( + "Consumer Close success", func() { + err = consumer.Close() + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} diff --git a/api/go/yr/utils.go b/api/go/yr/utils.go new file mode 100644 index 0000000..2175642 --- /dev/null +++ b/api/go/yr/utils.go @@ -0,0 +1,82 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "fmt" + "reflect" + "runtime" + "strings" + + "github.com/vmihailenco/msgpack" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +// IsFunction return if type is Func +func IsFunction(fn any) bool { + return reflect.TypeOf(fn).Kind() == reflect.Func +} + +// GetFunctionName return function name +func GetFunctionName(fn any) (string, error) { + functionName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + functionNameSlice := strings.Split(functionName, ".") + if len(functionNameSlice) >= 1 { + return functionNameSlice[len(functionNameSlice)-1], nil + } + + return "", fmt.Errorf("get function name %s failed", functionName) +} + +// PackInvokeArgs pack func args and return []Arg +func PackInvokeArgs(args ...any) ([]api.Arg, error) { + res := make([]api.Arg, 0, len(args)) + for _, arg := range args { + packedArg, err := PackInvokeArg(arg) + if err != nil { + return res, err + } + res = append(res, packedArg) + } + return res, nil +} + +// PackInvokeArg pack func arg and return Arg +func PackInvokeArg(arg any) (api.Arg, error) { + res := api.Arg{} + switch reflect.TypeOf(arg).String() { + case "yr.ObjectRef": + return res, nil + case "*yr.ObjectRef": + return res, nil + case "[]yr.ObjectRef": + return res, nil + case "[]*yr.ObjectRef": + return res, nil + default: + { + data, err := msgpack.Marshal(arg) + if err != nil { + return res, err + } + res.Data = data + return res, nil + } + } +} diff --git a/api/go/yr/utils_test.go b/api/go/yr/utils_test.go new file mode 100644 index 0000000..34afdea --- /dev/null +++ b/api/go/yr/utils_test.go @@ -0,0 +1,133 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "math" + "testing" + + "github.com/smartystreets/goconvey/convey" + "github.com/vmihailenco/msgpack/v5" +) + +func PlusOne() int { + return 1 +} + +type T struct{} + +func (t *T) A() {} +func (t T) B() {} + +func TestIsFunction(t *testing.T) { + if !IsFunction(PlusOne) { + t.Errorf("PlusOne is a function, judge failed.") + } + + if !IsFunction((*T).A) { + t.Errorf("(*T).A is a function, judge failed.") + } + + if !IsFunction((*T).B) { + t.Errorf("(*T).B is a function, judge failed.") + } + + if !IsFunction(T.B) { + t.Errorf("T.B is a function, judge failed.") + } + + a := func() {} + if !IsFunction(a) { + t.Errorf("a is a function, judge failed.") + } + + b := 1 + if IsFunction(b) { + t.Errorf("b is not a function, judge failed.") + } +} + +func TestPackInvokeArg(t *testing.T) { + convey.Convey( + "Test DeleteStream", t, func() { + convey.Convey( + "arg type: yr.ObjectRef success", func() { + res, err := PackInvokeArg(ObjectRef{}) + convey.So(res, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + + convey.Convey( + "arg type: *yr.ObjectRef success", func() { + objRef := ObjectRef{} + res, err := PackInvokeArg(&objRef) + convey.So(res, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + + convey.Convey( + "arg type: []yr.ObjectRef success", func() { + res, err := PackInvokeArg([]ObjectRef{}) + convey.So(res, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + + convey.Convey( + "arg type: []*yr.ObjectRef success", func() { + res, err := PackInvokeArg([]*ObjectRef{}) + convey.So(res, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + + convey.Convey( + "arg type: else success", func() { + d := 122.0 + res, err := PackInvokeArg(d) + convey.So(err, convey.ShouldBeNil) + var f float64 + err = msgpack.Unmarshal(res.Data, &f) + if err != nil { + t.Errorf("msgpack.Unmarshal err: %s", err) + return + } + convey.So(math.Abs(d-f), convey.ShouldBeLessThanOrEqualTo, 1e-9) + }, + ) + }, + ) +} + +func TestPackInvokeArgs(t *testing.T) { + convey.Convey( + "Test PackInvokeArgs", t, func() { + convey.Convey( + "PackInvokeArgs success", func() { + objRef := ObjectRef{} + args := []any{objRef, &objRef, []ObjectRef{}, []*ObjectRef{}} + res, err := PackInvokeArgs(args) + convey.So(len(res), convey.ShouldNotBeZeroValue) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} diff --git a/api/go/yr/yr.go b/api/go/yr/yr.go new file mode 100644 index 0000000..487e3a6 --- /dev/null +++ b/api/go/yr/yr.go @@ -0,0 +1,266 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package yr for actor +package yr + +import ( + "fmt" + "sync" + + "github.com/vmihailenco/msgpack" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +// IsInit Used to determine whether initialization has been completed +var IsInit = false + +// Init Initialization entry, which is used to initialize the data system and function system. +func Init(yrConfig *Config) (*ClientInfo, error) { + return InitWithFlags(yrConfig, &FlagsConfig{}) +} + +// InitWithFlags Init with flags args. +func InitWithFlags(yrConfig *Config, yrFlagsConfig *FlagsConfig) (*ClientInfo, error) { + if IsInit { + return GetConfigManager().GetClientInfo(), nil + } + GetConfigManager().Init(yrConfig, yrFlagsConfig) + fmt.Println(GetConfigManager().Config) + runtime := &ClusterModeRuntime{} + if err := runtime.Init(); err != nil { + return nil, err + } + GetRuntimeHolder().Init(runtime) + IsInit = true + return GetConfigManager().GetClientInfo(), nil +} + +// Finalize release resources +func Finalize(all bool) { + GetRuntimeHolder().GetRuntime().Finalize() +} + +// Get to get obj from data system. +func Get[T any](objRef *ObjectRef, timeoutMs int) (res T, err error) { + err = Wait(objRef, timeoutMs/1000) + if err != nil { + return + } + return get[T](objRef, timeoutMs) +} + +func get[T any](objRef *ObjectRef, timeoutMs int) (res T, err error) { + objIds := []string{objRef.objId} + data, err := GetRuntimeHolder().GetRuntime().Get(objIds, timeoutMs) + if err != nil { + stacks := getStackTraceInfos() + err = api.NewErrorInfoWithStackInfo(err, stacks) + return + } + + result := new(T) + err = msgpack.Unmarshal(data[0], result) + if err != nil { + return + } + + return *result, nil +} + +// BatchGet to get objs from data system. +func BatchGet[T any](objRefs []*ObjectRef, timeoutMs int, allowPartial bool) (res []T, err error) { + if len(objRefs) == 0 { + return + } + + objIds := make([]string, 0, len(objRefs)) + for _, ref := range objRefs { + objIds = append(objIds, ref.objId) + } + datas, err := GetRuntimeHolder().GetRuntime().Get(objIds, timeoutMs) + if err != nil { + return + } + + res = make([]T, 0, len(objRefs)) + for _, data := range datas { + result := new(T) + err = msgpack.Unmarshal(data, result) + if err != nil { + return + } + res = append(res, *result) + } + return +} + +// Put put obj data to data system. +func Put(val any) (*ObjectRef, error) { + data, err := msgpack.Marshal(val) + if err != nil { + return nil, err + } + + ref := NewObjectRef() + + objectIds := []string{ref.objId} + _, err = GetRuntimeHolder().GetRuntime().GIncreaseRef(objectIds, make([]string, 0)...) + if err != nil { + return nil, err + } + + err = GetRuntimeHolder().GetRuntime().Put(ref.objId, data, api.PutParam{}, make([]string, 0)...) + if err != nil { + return nil, err + } + return ref, nil +} + +// Wait until result return or timeout +func Wait(objRef *ObjectRef, timeout int) error { + objIds := []string{objRef.objId} + readyIds, _, errs := GetRuntimeHolder().GetRuntime().Wait(objIds, 1, timeout) + if errs != nil { + return waitErr(errs) + } + + if len(readyIds) != 1 { + return fmt.Errorf("wait failed") + } + return nil +} + +func waitErr(errs map[string]error) error { + for _, v := range errs { + stacks := getStackTraceInfos() + return api.NewErrorInfoWithStackInfo(v, stacks) + } + return fmt.Errorf("wait failed") +} + +// WaitNum specify a specific count to wait +func WaitNum(objRefs []*ObjectRef, waitNum uint64, timeout int) (readyObjRefs, unreadyObjRefs []*ObjectRef, err error) { + objIds := make([]string, 0, len(objRefs)) + objIdMap := make(map[string]*ObjectRef) + for _, ref := range objRefs { + objIds = append(objIds, ref.objId) + objIdMap[ref.objId] = ref + } + readyIds, unReadyIds, errors := GetRuntimeHolder().GetRuntime().Wait(objIds, waitNum, timeout) + if errors != nil { + err = fmt.Errorf("wait num failed") + return + } + + readyObjRefs = make([]*ObjectRef, 0, len(readyIds)) + unreadyObjRefs = make([]*ObjectRef, 0, len(unReadyIds)) + for _, id := range readyIds { + readyObjRefs = append(readyObjRefs, objIdMap[id]) + } + + for _, id := range unReadyIds { + unreadyObjRefs = append(unreadyObjRefs, objIdMap[id]) + } + return +} + +// SetKV save binary data to the data system. +func SetKV(key, value string) error { + return GetRuntimeHolder().GetRuntime().KVSet(key, []byte(value), api.SetParam{}) +} + +// GetKV get binary data from data system. +func GetKV(key string, timeoutMs uint) (string, error) { + data, err := GetRuntimeHolder().GetRuntime().KVGet(key, timeoutMs) + if err != nil { + return "", err + } + + return string(data), nil +} + +// GetKVs get multi binary data from data system. +func GetKVs(keys []string, timeoutMs uint, allowPartial bool) ([]string, error) { + datas, err := GetRuntimeHolder().GetRuntime().KVGetMulti(keys, timeoutMs) + if err != nil { + return make([]string, 0), err + } + + res := make([]string, 0, len(keys)) + for _, data := range datas { + res = append(res, string(data)) + } + return res, nil +} + +// DelKV del data from data system. +func DelKV(key string) error { + return GetRuntimeHolder().GetRuntime().KVDel(key) +} + +// DelKVs del multi data from data system. +func DelKVs(keys []string) ([]string, error) { + return GetRuntimeHolder().GetRuntime().KVDelMulti(keys) +} + +// CreateProducer creates and return Producer +func CreateProducer(streamName string, opt ProducerConf) (*Producer, error) { + conf := api.ProducerConf{ + DelayFlushTime: int64(opt.DelayFlushTime), + PageSize: int64(opt.PageSize), + MaxStreamSize: uint64(opt.MaxStreamSize), + } + producer, err := GetRuntimeHolder().GetRuntime().CreateProducer(streamName, conf) + if err != nil { + return nil, err + } + + return &Producer{producer: producer, mutex: new(sync.RWMutex)}, nil +} + +// Subscribe creates and return Consumer +func Subscribe(streamName, subscriptionName string, subscriptionType SubscriptionType) (*Consumer, error) { + conf := api.SubscriptionConfig{ + SubscriptionName: subscriptionName, + SubscriptionType: api.SubscriptionType(subscriptionType), + } + consumer, err := GetRuntimeHolder().GetRuntime().Subscribe(streamName, conf) + if err != nil { + return nil, err + } + + return &Consumer{consumer: consumer, mutex: new(sync.RWMutex)}, nil +} + +// DeleteStream Delete a data flow. +// When the number of global producers and consumers is 0, +// the data flow is no longer used and the metadata related to the data flow is deleted from each worker and the +// master. This function can be invoked on any host node. +func DeleteStream(streamName string) error { + return GetRuntimeHolder().GetRuntime().DeleteStream(streamName) +} + +// QueryGlobalProducersNum Specifies the flow name to query the number of all producers of the flow. +func QueryGlobalProducersNum(streamName string) (uint64, error) { + return GetRuntimeHolder().GetRuntime().QueryGlobalProducersNum(streamName) +} + +// QueryGlobalConsumersNum Specifies the flow name to query the number of all consumers of the flow. +func QueryGlobalConsumersNum(streamName string) (uint64, error) { + return GetRuntimeHolder().GetRuntime().QueryGlobalConsumersNum(streamName) +} diff --git a/api/go/yr/yr_test.go b/api/go/yr/yr_test.go new file mode 100644 index 0000000..931438f --- /dev/null +++ b/api/go/yr/yr_test.go @@ -0,0 +1,354 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package execution +This package provides methods to obtain the execution interface. +*/ +package yr + +import ( + "errors" + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +var clientInfo *ClientInfo +var initErr error + +func init() { + yrConfig := &Config{ + FunctionUrn: "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest", + Mode: ClusterMode, + AutoStart: false, + ServerAddr: "127.0.0.1:12345", + DataSystemAddr: "127.0.0.1:12346", + InCluster: true, + } + clientInfo, initErr = Init(yrConfig) +} + +func TestInit(t *testing.T) { + convey.Convey( + "Test Init", t, func() { + convey.Convey( + "Init success", func() { + convey.So(clientInfo.jobId, convey.ShouldNotBeEmpty) + convey.So(initErr, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestInitWithFlags(t *testing.T) { + convey.Convey( + "Test InitWithFlags", t, func() { + yrConfig := &Config{ + FunctionUrn: "sn:cn:yrk:12345678901234561234567890123456:function:0-opc-opc:$latest", + Mode: ClusterMode, + AutoStart: false, + ServerAddr: "127.0.0.1:12345", + DataSystemAddr: "127.0.0.1:12346", + InCluster: true, + } + yrFlagsConfig := &FlagsConfig{} + convey.Convey( + "InitWithFlags success", func() { + clientInfo, err := InitWithFlags(yrConfig, yrFlagsConfig) + convey.So(clientInfo.jobId, convey.ShouldNotBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestFinalize(t *testing.T) { + convey.Convey( + "Test Finalize", t, func() { + convey.Convey( + "Finalize killAllInstance success", func() { + convey.So(func() { + Finalize(true) + }, convey.ShouldNotPanic) + }, + ) + convey.Convey( + "Finalize notKillAllInstance success", func() { + convey.So(func() { + Finalize(false) + }, convey.ShouldNotPanic) + }, + ) + }, + ) +} + +func TestGet(t *testing.T) { + convey.Convey( + "Test Get", t, func() { + convey.Convey( + "Get success", func() { + objRef, err := Put(250) + if err != nil { + t.Errorf("Put failed, error: %s", err) + return + } + res, err := Get[int](objRef, 3000) + convey.So(res, convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestSubGet(t *testing.T) { + convey.Convey( + "Test SubGet", t, func() { + convey.Convey( + "SubGet success", func() { + objRef, err := Put(250) + if err != nil { + t.Errorf("Put failed, error: %s", err) + return + } + res, err := get[int](objRef, 3000) + convey.So(res, convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestBatchGet(t *testing.T) { + convey.Convey( + "Test BatchGet", t, func() { + convey.Convey( + "BatchGet success", func() { + obj, err := Put(250) + if err != nil { + t.Errorf("Put failed, error: %s", err) + return + } + obj1, err := Put(2560) + if err != nil { + t.Errorf("Put failed, error: %s", err) + return + } + objs := []*ObjectRef{obj, obj1} + + res, err := BatchGet[int](objs, 3000, false) + convey.So(len(res), convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestPut(t *testing.T) { + convey.Convey( + "Test Put", t, func() { + convey.Convey( + "Put success", func() { + objRef, err := Put(250) + if err != nil { + t.Errorf("Put failed, error: %s", err) + return + } + convey.So(objRef.objId, convey.ShouldNotBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestWaitErr(t *testing.T) { + convey.Convey( + "Test waitErr", t, func() { + convey.Convey( + "waitErr success", func() { + errs := map[string]error{"1": errors.New("wait failed")} + err := waitErr(errs) + convey.So(err, convey.ShouldNotBeNil) + }, + ) + }, + ) +} + +func TestWaitNum(t *testing.T) { + convey.Convey( + "Test WaitNum", t, func() { + convey.Convey( + "WaitNum success", func() { + objRef, err := Put(250) + if err != nil { + t.Errorf("Put failed, error: %s", err) + return + } + + readyObjRefs, unreadyObjRefs, err := WaitNum([]*ObjectRef{objRef}, 1, 3000) + convey.So(len(readyObjRefs), convey.ShouldBeZeroValue) + convey.So(len(unreadyObjRefs), convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestSetKV(t *testing.T) { + convey.Convey( + "Test SetKV", t, func() { + convey.Convey( + "SetKV success", func() { + err := SetKV("key1", "value1") + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestGetKV(t *testing.T) { + convey.Convey( + "Test GetKV", t, func() { + convey.Convey( + "GetKV success", func() { + str, err := GetKV("key1", 3000) + convey.So(str, convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestGetKVs(t *testing.T) { + convey.Convey( + "Test GetKVs", t, func() { + convey.Convey( + "GetKVs success", func() { + values, err := GetKVs([]string{"key1"}, 3000, true) + convey.So(values[0], convey.ShouldBeEmpty) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestDelKV(t *testing.T) { + convey.Convey( + "Test DelKV", t, func() { + convey.Convey( + "DelKV success", func() { + err := DelKV("key1") + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestDelKVs(t *testing.T) { + convey.Convey( + "Test DelKVs", t, func() { + convey.Convey( + "DelKVs success", func() { + failedKeys, err := DelKVs([]string{"key1", "key2"}) + convey.So(len(failedKeys), convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestCreateProducer(t *testing.T) { + convey.Convey( + "Test CreateProducer", t, func() { + convey.Convey( + "CreateProducer success", func() { + producer, err := CreateProducer("streamName", ProducerConf{}) + convey.So(producer.producer, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestSubscribe(t *testing.T) { + convey.Convey( + "Test Subscribe", t, func() { + convey.Convey( + "Subscribe success", func() { + consumer, err := Subscribe("streamName", "subscriptionName", 0) + convey.So(consumer.consumer, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestDeleteStream(t *testing.T) { + convey.Convey( + "Test DeleteStream", t, func() { + convey.Convey( + "DeleteStream success", func() { + err := DeleteStream("streamName") + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestQueryGlobalProducersNum(t *testing.T) { + convey.Convey( + "Test QueryGlobalProducersNum", t, func() { + convey.Convey( + "QueryGlobalProducersNum success", func() { + producersNum, err := QueryGlobalProducersNum("streamName") + convey.So(producersNum, convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} + +func TestQueryGlobalConsumersNum(t *testing.T) { + convey.Convey( + "Test QueryGlobalConsumersNum", t, func() { + convey.Convey( + "QueryGlobalConsumersNum success", func() { + consumersNum, err := QueryGlobalConsumersNum("streamName") + convey.So(consumersNum, convey.ShouldBeZeroValue) + convey.So(err, convey.ShouldBeNil) + }, + ) + }, + ) +} diff --git a/api/java/example/GoInstanceExample.java b/api/java/example/GoInstanceExample.java new file mode 100644 index 0000000..ef73db8 --- /dev/null +++ b/api/java/example/GoInstanceExample.java @@ -0,0 +1,72 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package example; + +import com.yuanrong.Config; + +public class GoInstanceExample { + public static void main(String[] args) throws Exception { + //! [GoFunctionHandle options 样例代码] + InvokeOptions invokeOptions = new InvokeOptions(); + invokeOptions.setCpu(1500); + invokeOptions.setMemory(1500); + GoFunction goFunction = GoFunction.of("PlusOne", int.class, 1); + GoFunctionHandler goFuncHandler = YR.function(goFunction).options(invokeOptions); + ObjectRef ref = goFuncHandler.invoke(); + int result = YR.get(ref, 15); + //! [GoFunctionHandle options 样例代码] + + //! [GoInstanceCreator options 样例代码] + InvokeOptions invokeOptions = new InvokeOptions(); + invokeOptions.setCpu(1500); + invokeOptions.setMemory(1500); + GoInstanceCreator goInstanceCreator = YR.instance(GoInstanceClass.of("Counter")).options(invokeOptions); + GoInstanceHandler goInstanceHandler = goInstanceCreator.invoke(1); + ObjectRef ref = goInstanceHandler.function(GoInstanceMethod.of("Add", int.class)).invoke(5); + int res = (int)YR.get(ref, 100); + //! [GoInstanceCreator options 样例代码] + + //! [GoInstanceHandler function example] + GoInstanceCreator goInstanceCreator = YR.instance(GoInstanceClass.of("Counter")); + GoInstanceHandler goInstanceHandler = goInstanceCreator.invoke(1); + GoInstanceFunctionHandler goInstFuncHandler = goInstanceHandler.function(GoInstanceMethod.of("Add", int.class)); + ObjectRef ref = goInstFuncHandler.invoke(5); + int res = (int)YR.get(ref, 100); + //! [GoInstanceHandler function example] + + //! [GoInstanceFunctionHandler options 样例代码] + InvokeOptions invokeOptions = new InvokeOptions(); + invokeOptions.addCustomExtension("app_name", "myApp"); + GoInstanceHandler goInstanceHandler = YR.instance(GoInstanceClass.of("Counter")).invoke(1); + GoInstanceFunctionHandler goInstFuncHandler = goInstanceHandler.function(GoInstanceMethod.of("Add", int.class)).options(invokeOptions); + ObjectRef ref = goInstFuncHandler.invoke(5); + String res = (String)YR.get(ref, 100); + //! [GoInstanceFunctionHandler options 样例代码] + + //! [GoInstanceHandler terminate example] + GoInstanceCreator goInstanceCreator = YR.instance(GoInstanceClass.of("Counter")); + GoInstanceHandler goInstanceHandler = goInstanceCreator.invoke(1); + goInstanceHandler.terminate(); + //! [GoInstanceHandler terminate example] + + //! [GoInstanceHandler terminate sync example] + GoInstanceCreator goInstanceCreator = YR.instance(GoInstanceClass.of("Counter")); + GoInstanceHandler goInstanceHandler = goInstanceCreator.invoke(1); + goInstanceHandler.terminate(true); + //! [GoInstanceHandler terminate sync example] + } +} \ No newline at end of file diff --git a/api/java/example/InstanceExample.java b/api/java/example/InstanceExample.java index 2049c8f..952d6ab 100644 --- a/api/java/example/InstanceExample.java +++ b/api/java/example/InstanceExample.java @@ -56,7 +56,7 @@ public class InstanceExample { } } public static void main(String[] args) throws Exception { - Config conf = new Config("FunctionURN", "ip", "ip", ""); + Config conf = new Config("FunctionURN", "ip", "ip", "", false); YR.init(conf); MyYRApp myapp = new MyYRApp(); InstanceCreator myYRapp = YR.instance(MyYRApp::new); @@ -76,7 +76,7 @@ public class InstanceExample { } } public static void main(String[] args) throws Exception { - Config conf = new Config("FunctionURN", "ip", "ip", ""); + Config conf = new Config("FunctionURN", "ip", "ip", "", false); YR.init(conf); MyYRApp myapp = new MyYRApp(); // The instance name of this named instance is funcB @@ -104,7 +104,7 @@ public class InstanceExample { } } public static void main(String[] args) throws Exception { - Config conf = new Config("FunctionURN", "ip", "ip", ""); + Config conf = new Config("FunctionURN", "ip", "ip", "", false); YR.init(conf); MyYRApp myapp = new MyYRApp(); // The instance name of this named instance is nsA-funcB diff --git a/api/java/example/OptionsExample.java b/api/java/example/OptionsExample.java index aac5d89..be83e6c 100644 --- a/api/java/example/OptionsExample.java +++ b/api/java/example/OptionsExample.java @@ -30,7 +30,7 @@ public class OptionsExample { public static void main(String[] args) throws YRException { //! [function options 样例代码] - Config conf = new Config("FunctionURN", "ip", "ip", ""); + Config conf = new Config("FunctionURN", "ip", "ip", "", false); YR.init(conf); InvokeOptions opts = new InvokeOptions(); opts.setCpu(300); diff --git a/api/java/example/VoidFunctionExample.java b/api/java/example/VoidFunctionExample.java index fb4a953..825c835 100644 --- a/api/java/example/VoidFunctionExample.java +++ b/api/java/example/VoidFunctionExample.java @@ -26,7 +26,7 @@ import com.yuanrong.call.VoidInstanceFunctionHandler; public class VoidFunctionExample { public static void main(String[] args) throws Exception { //! [VoidFunctionHandler options 样例代码] - Config conf = new Config("FunctionURN", "ip", "ip", ""); + Config conf = new Config("FunctionURN", "ip", "ip", "", false); YR.init(conf); InvokeOptions invokeOptions = new InvokeOptions(); @@ -38,7 +38,7 @@ public class VoidFunctionExample { //! [VoidFunctionHandler options 样例代码] //! [VoidInstanceFunctionHandler options 样例代码] - Config conf = new Config("FunctionURN", "ip", "ip", ""); + Config conf = new Config("FunctionURN", "ip", "ip", "", false); YR.init(conf); InvokeOptions invokeOptions = new InvokeOptions(); diff --git a/api/java/faas-function-sdk/pom.xml b/api/java/faas-function-sdk/pom.xml new file mode 100644 index 0000000..63dda01 --- /dev/null +++ b/api/java/faas-function-sdk/pom.xml @@ -0,0 +1,91 @@ + + + + + + yr-api-java + com.yuanrong + 1.0.0 + ../pom.xml + + + 4.0.0 + + faas-function-sdk + + + UTF-8 + 8 + 8 + 1.0.0 + + + + + com.yuanrong + function-common + + + com.google.code.gson + gson + + + junit + junit + test + + + org.powermock + powermock-module-junit4 + test + + + objenesis + org.objenesis + + + + + org.jacoco + org.jacoco.agent + runtime + test + + + org.powermock + powermock-api-mockito2 + test + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.2 + + 1.8 + 1.8 + + + + + \ No newline at end of file diff --git a/api/java/faas-function-sdk/src/main/java/com/function/CreateOptions.java b/api/java/faas-function-sdk/src/main/java/com/function/CreateOptions.java new file mode 100644 index 0000000..aa72a5d --- /dev/null +++ b/api/java/faas-function-sdk/src/main/java/com/function/CreateOptions.java @@ -0,0 +1,96 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function; + +import com.yuanrong.InvokeOptions; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +import java.util.Map; +import java.util.HashMap; + +/** + * Create options for new function instance + * + * @since 2024-09-06 + */ +@Setter +@Getter +@NoArgsConstructor +public class CreateOptions { + private int cpu = 0; + + private int memory = 0; + private Map aliasParams = new HashMap<>(); + + /** + * construct method + * + * @param memory function instance's memory + */ + public CreateOptions(int memory) { + this(0, memory, new HashMap<>()); + } + + /** + * construct method + * + * @param aliasParams alias params + */ + public CreateOptions(Map aliasParams) { + this(0, 0, aliasParams); + } + + /** + * construct method + * + * @param cpu function instance's cpu + * @param memory function instance's memory + */ + public CreateOptions(int cpu, int memory) { + this.cpu = cpu; + this.memory = memory; + } + + /** + * construct method + * + * @param cpu function instance's cpu + * @param memory function instance's memory + * @param aliasParams function alias params + */ + public CreateOptions(int cpu, int memory, Map aliasParams) { + this.cpu = cpu; + this.memory = memory; + this.aliasParams = aliasParams; + } + + /** + * convertInvokeOptions convert to invoke options + * + * @return InvokeOptions + */ + public InvokeOptions convertInvokeOptions() { + InvokeOptions yrOptions = new InvokeOptions(); + yrOptions.setCpu(this.cpu); + yrOptions.setMemory(this.memory); + yrOptions.setAliasParams(this.aliasParams); + return yrOptions; + } +} diff --git a/api/java/faas-function-sdk/src/main/java/com/function/Function.java b/api/java/faas-function-sdk/src/main/java/com/function/Function.java new file mode 100644 index 0000000..d7a3226 --- /dev/null +++ b/api/java/faas-function-sdk/src/main/java/com/function/Function.java @@ -0,0 +1,174 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function; + +import com.function.common.Util; +import com.function.runtime.exception.InvokeException; +import com.services.model.CallRequest; +import com.services.runtime.Context; +import com.yuanrong.api.InvokeArg; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.Pair; +import com.yuanrong.exception.LibRuntimeException; +import com.yuanrong.InvokeOptions; +import com.yuanrong.jni.LibRuntime; +import com.yuanrong.libruntime.generated.Libruntime.ApiType; +import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; +import com.yuanrong.libruntime.generated.Libruntime.LanguageType; +import com.yuanrong.runtime.util.Utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Locale; + +/** + * The type Function. + * + * @since 2024/8/15 + */ +public class Function { + private static Logger LOG = LoggerFactory.getLogger(Function.class); + + private final Context context; + + private String functionNameWithVersion; + + private String functionId; + + private CreateOptions createOptions = new CreateOptions(); + + private String instanceId; + + /** + * construct method + * + * @param context context + */ + public Function(Context context) { + this(context, null); + } + + /** + * construct method + * + * @param functionNameWithVersion functionNameWithVersion, eg: javafunc:latest + */ + public Function(String functionNameWithVersion) { + this(null, functionNameWithVersion); + } + + /** + * construct method + * + * @param context context + * @param functionNameWithVersion functionNameWithVersion, eg: javafunc:latest + */ + public Function(Context context, String functionNameWithVersion) { + this.context = context; + this.functionNameWithVersion = functionNameWithVersion; + } + + /** + * options + * + * @param opts the CreateOptions value + * @return Function + */ + public Function options(CreateOptions opts) { + Util.checkDynamicResource(opts.getCpu(), opts.getMemory()); + this.createOptions = opts; + return this; + } + + /** + * SDK invoke api + * + * @param payload passed argument + * @return ObjectRef + */ + public ObjectRef invoke(String payload) { + LOG.info("SDK invoke api beginning"); + String functionService = com.services.runtime.utils.Util.getServiceNameFromEnv(context); + String tenantId = com.services.runtime.utils.Util.getTenantIdFromEnv(context); + String[] nameAndVersion = Util.checkFuncName(this.functionNameWithVersion); + this.functionId = String.format(Locale.ROOT, "%s/0@%s@%s/%s", tenantId, functionService, nameAndVersion[0], + nameAndVersion[1]); + + FunctionMeta funcMeta = FunctionMeta.newBuilder() + .setApiType(ApiType.Faas) + .setFunctionID(this.functionId) + .setLanguage(LanguageType.Java) + .build(); + Util.checkPayload(payload); + CallRequest arg1 = new CallRequest.Builder().build(); + CallRequest arg2 = new CallRequest.Builder().withBody(payload).build(); + List argList; + argList = Utils.packFaasInvokeArgs(arg1, arg2); + Pair res; + try { + InvokeOptions invokeOptions = createOptions.convertInvokeOptions(); + if (context != null) { + invokeOptions.setTraceId(context.getTraceID()); + } + res = LibRuntime.InvokeInstance(funcMeta, "", argList, invokeOptions); + } catch (LibRuntimeException e) { + throw new InvokeException(e.getErrorCode().getValue(), e.getMessage()); + } + Util.checkErrorAndThrow(res.getFirst(), "faas invoke"); + return new ObjectRef<>(res.getSecond()); + } + + /** + * SDK terminate api + * + * @return ObjectRef + * @throws InvokeException when internal error + */ + public ObjectRef terminate() { + LOG.debug("SDK terminate api beginning"); + return null; + } + + /** + * SDK saveState api + * + * @throws InvokeException when internal error + */ + public void saveState() { + LOG.debug("saveState api beginning"); + } + + /** + * get context + * + * @return context + */ + public Context getContext() { + return this.context; + } + + /** + * SDK getInstanceID api + * + * @return InstanceID + */ + public String getInstanceID() { + return this.instanceId; + } +} diff --git a/api/java/faas-function-sdk/src/main/java/com/function/ObjectRef.java b/api/java/faas-function-sdk/src/main/java/com/function/ObjectRef.java new file mode 100644 index 0000000..7b299cc --- /dev/null +++ b/api/java/faas-function-sdk/src/main/java/com/function/ObjectRef.java @@ -0,0 +1,222 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function; + +import com.function.common.RspErrorCode; +import com.function.common.Util; +import com.function.runtime.exception.InvokeException; +import com.services.enums.FaasErrorCode; +import com.services.model.CallResponse; +import com.services.model.CallResponseJsonObject; +import com.services.model.Response; +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.Pair; +import com.yuanrong.exception.LibRuntimeException; +import com.yuanrong.jni.LibRuntime; +import com.yuanrong.runtime.util.Constants; + +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import com.google.gson.JsonParseException; + +import lombok.extern.slf4j.Slf4j; + +import java.lang.reflect.Type; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; + +/** + * ObjectRef class + * + * @since 2024-7-10 + */ +@Slf4j +public class ObjectRef { + /** + * GSON + */ + protected Gson gson = new Gson(); + + /** + * flag represents whether the result has been got. + */ + protected boolean hasFlag; + + /** + * result is the result of the ObjectRef. + */ + protected T result; + + private final String objectID; + + /** + * Decrease reference flag of ObjectRef. + */ + private boolean isReleased = false; + + protected ObjectRef(String objectID) { + this.objectID = objectID; + this.hasFlag = false; + this.result = null; + } + + private int checkAndGetTimeoutMs(int timeoutSec) throws InvokeException { + if (timeoutSec < Constants.NO_TIMEOUT) { + throw new InvokeException(RspErrorCode.INVALID_PARAMETER.getErrorCode(), + "get config timeout (" + timeoutSec + ") is invalid"); + } + int timeoutMs = Constants.NO_TIMEOUT; + if (timeoutSec != Constants.NO_TIMEOUT) { + timeoutMs = timeoutSec * Constants.SEC_TO_MS; + } + return timeoutMs; + } + + /** + * get result + * + * @param timeoutSec the timeoutSec + * @return result of ObjectRef + */ + public T get(int timeoutSec) { + if (this.hasFlag) { + return this.result; + } + int timeoutMs = checkAndGetTimeoutMs(timeoutSec); + List refIds = Arrays.asList(this.objectID); + Pair> getResult; + try { + // This method may throw 'LibruntimeException' when timeout occurs. + getResult = LibRuntime.Get(refIds, timeoutMs, true); + } catch (LibRuntimeException e) { + String message = String.format(Locale.ROOT, "failed to get result %s, error: %s", refIds, e.getMessage()); + throw new InvokeException(FaasErrorCode.FUNCTION_RUN_ERROR.getCode(), message); + } + Util.checkErrorAndThrow(getResult.getFirst(), "faas get"); + Response response; + String responseStr = new String(getResult.getSecond().get(0), StandardCharsets.UTF_8); + try { + response = gson.fromJson(responseStr, CallResponseJsonObject.class); + } catch (JsonParseException e) { + try { + response = gson.fromJson(responseStr, CallResponse.class); + } catch (JsonParseException exception) { + throw new InvokeException(FaasErrorCode.FUNCTION_RUN_ERROR.getCode(), exception.getMessage()); + } + } + String innerCode = response.getInnerCode(); + if (!ErrorCode.ERR_OK.toString().equals(innerCode)) { + throw new InvokeException(Integer.parseInt(innerCode), String.valueOf(response.getBody())); + } + this.result = (T) response.getBody(); + this.hasFlag = true; + return this.result; + } + + /** + * get result + * + * @return result of ObjectRef + */ + public T get() { + return this.get(Constants.NO_TIMEOUT); + } + + /** + * get result + * + * @param classType the classType + * @param timeoutSec the timeoutSec + * @return result of ObjectRef + */ + public T get(Class classType, int timeoutSec) { + if (this.hasFlag) { + return this.result; + } + int timeoutMs = checkAndGetTimeoutMs(timeoutSec); + List refIds = Arrays.asList(this.objectID); + Pair> getRes; + try { + // This method may throw 'LibruntimeException' when timeout occurs. + getRes = LibRuntime.Get(refIds, timeoutMs, true); + } catch (LibRuntimeException e) { + String message = String.format(Locale.ROOT, "failed to get result %s, error: %s", refIds, e.getMessage()); + throw new InvokeException(FaasErrorCode.FUNCTION_RUN_ERROR.getCode(), message); + } + Util.checkErrorAndThrow(getRes.getFirst(), "faas get"); + Response response; + String responseString = new String(getRes.getSecond().get(0), StandardCharsets.UTF_8); + try { + if (classType.equals(JsonObject.class)) { + response = gson.fromJson(responseString, CallResponseJsonObject.class); + } else { + response = gson.fromJson(responseString, CallResponse.class); + } + } catch (JsonParseException e) { + try { + response = gson.fromJson(responseString, CallResponse.class); + } catch (JsonParseException exception) { + throw new InvokeException(FaasErrorCode.FUNCTION_RUN_ERROR.getCode(), exception.getMessage()); + } + } + Object responseBody = response.getBody(); + String innerCode = response.getInnerCode(); + if (!ErrorCode.ERR_OK.toString().equals(innerCode)) { + throw new InvokeException(Integer.parseInt(innerCode), String.valueOf(responseBody)); + } + if (classType.isInstance(responseBody)) { + this.result = (T) responseBody; + } else { + try { + this.result = gson.fromJson(String.valueOf(responseBody), (Type) classType); + } catch (JsonParseException e) { + throw new InvokeException(RspErrorCode.INTERNAL_ERROR.getErrorCode(), e.getMessage()); + } + } + this.hasFlag = true; + return this.result; + } + + /** + * get result + * + * @param classType the classType + * @return result of ObjectRef + */ + public T get(Class classType) { + return this.get(classType, Constants.NO_TIMEOUT); + } + + @Override + protected void finalize() throws Throwable { + release(); + } + + /** + * Release the ObjectRef, decrease reference. + */ + public void release() { + if (!isReleased && LibRuntime.IsInitialized()) { + LibRuntime.DecreaseReference(Collections.singletonList(this.objectID)); + isReleased = true; + } + } +} diff --git a/api/java/faas-function-sdk/src/main/java/com/function/common/RspErrorCode.java b/api/java/faas-function-sdk/src/main/java/com/function/common/RspErrorCode.java new file mode 100644 index 0000000..9464d3b --- /dev/null +++ b/api/java/faas-function-sdk/src/main/java/com/function/common/RspErrorCode.java @@ -0,0 +1,61 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function.common; + +/** + * The response error code + * + * @since 2024/07/02 + */ +public enum RspErrorCode { + GET_TIMEOUT_ERROR(133, "Get timeout failed"), + USER_STATE_LARGE_ERROR(4003, "state content is too large"), + USER_STATE_UNDEFINED_ERROR(4005, "state is undefined"), + INVALID_PARAMETER(4040, "invalid input parameter"), + INTERNAL_ERROR(110500, "internal system error"); + + private static final String TIMEOUT_ERROR_CODE = "211408"; + private static final String RUNTIME_INVOKE_TIMEOUT_CODE = "4010"; + private static final String SLA_ERROR_CODE = "216001"; + private static final String QUEUE_TIMEOUT_CODE = "150430"; + private static final String RUNTIME_ERROR_PREFIX = "4"; + private int errorCode; + private String desc; + + private RspErrorCode(int errorCode, String desc) { + this.errorCode = errorCode; + this.desc = desc; + } + + /** + * getErrorCode + * + * @return errorCode + */ + public int getErrorCode() { + return this.errorCode; + } + + /** + * getDesc + * + * @return desc + */ + public String getDesc() { + return this.desc; + } +} \ No newline at end of file diff --git a/api/java/faas-function-sdk/src/main/java/com/function/common/Util.java b/api/java/faas-function-sdk/src/main/java/com/function/common/Util.java new file mode 100644 index 0000000..f3c28b8 --- /dev/null +++ b/api/java/faas-function-sdk/src/main/java/com/function/common/Util.java @@ -0,0 +1,179 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function.common; + +import com.function.common.RspErrorCode; +import com.function.runtime.exception.InvokeException; +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; + +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import com.google.gson.JsonSyntaxException; + +import lombok.extern.slf4j.Slf4j; + +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Util API + * + * @since 2024-07-10 + */ +@Slf4j +public class Util { + private static Logger LOG = LoggerFactory.getLogger(Util.class); + + private static final String FUNC_NAME_PATTERN_STRING = "^[a-zA-Z]([a-zA-Z0-9_-]*[a-zA-Z0-9])?$"; + private static final Pattern FUNC_NAME_PATTERN = Pattern.compile(FUNC_NAME_PATTERN_STRING); + private static final int FUNC_NAME_LENGTH_LIMIT = 60; + private static final String VERSION_PATTERN_STRING = + "^[a-zA-Z0-9]([a-zA-Z0-9_-]*\\\\.)*[a-zA-Z0-9_-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$"; + private static final Pattern VERSION_PATTERN = Pattern.compile(VERSION_PATTERN_STRING); + private static final int VERSION_LENGTH_LIMIT = 32; + + private static final String ALIAS_PATTERN_STRING = "^[a-zA-Z]([a-zA-Z0-9_-]*[a-zA-Z0-9])?$"; + private static final Pattern ALIAS_PATTERN = Pattern.compile(ALIAS_PATTERN_STRING); + private static final int ALIAS_LENGTH_LIMIT = 32; + private static final String ALIAS_PREFIX = "!"; + + private static final Gson GSON = new Gson(); + + /** + * check funcName is valid + * + * @param funcName function name + * @return res[0]: funcNameBase, res[1]: version + * @throws InvokeException when invalid funcName + */ + public static String[] checkFuncName(String funcName) { + if (funcName == null) { + throw new InvokeException(RspErrorCode.INVALID_PARAMETER.getErrorCode(), + "invalid funcName, expect not null"); + } + String name = ""; + String version = "latest"; + if (StringUtils.contains(funcName, ":")) { + String[] nameAndVersion = StringUtils.split(funcName, ":"); + if (nameAndVersion.length != 2) { + throw new InvokeException(RspErrorCode.INVALID_PARAMETER.getErrorCode(), + "invalid funcName, not match regular expression"); + } + name = nameAndVersion[0]; + if (!checkFunctionName(nameAndVersion[0])) { + throw new InvokeException(RspErrorCode.INVALID_PARAMETER.getErrorCode(), + "invalid funcName, not match regular expression"); + } + version = nameAndVersion[1]; + if (StringUtils.startsWith(version, ALIAS_PREFIX)) { + if (!checkAlias(version)) { + throw new InvokeException(RspErrorCode.INVALID_PARAMETER.getErrorCode(), + "invalid funcName, not match regular expression"); + } + } else { + if (!checkVersion(version)) { + throw new InvokeException(RspErrorCode.INVALID_PARAMETER.getErrorCode(), + "invalid funcName, not match regular expression"); + } + } + } else { + name = funcName; + if (!checkFunctionName(funcName)) { + throw new InvokeException(RspErrorCode.INVALID_PARAMETER.getErrorCode(), + "invalid funcName, not match regular expression"); + } + } + + return new String[]{name, version}; + } + + /** + * check payload is valid json string + * + * @param payload json string + * @throws InvokeException when invalid payload + */ + public static void checkPayload(String payload) { + if (StringUtils.isBlank(payload)) { + throw new InvokeException(RspErrorCode.INVALID_PARAMETER.getErrorCode(), + "invalid payload, expect not null"); + } + try { + GSON.fromJson(payload, JsonObject.class); + } catch (JsonSyntaxException jsonSyntaxException) { + log.error("throw JsonSyntaxException", jsonSyntaxException); + throw new InvokeException(RspErrorCode.INVALID_PARAMETER.getErrorCode(), + "invalid payload, invalid json string"); + } + } + + /** + * check error and throw + * + * @param errorInfo errorInfo + * @param msg msg + * @throws InvokeException when not ok errorInfo + */ + public static void checkErrorAndThrow(ErrorInfo errorInfo, String msg) throws InvokeException { + if (!errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { + log.error("{} occurs exception {}", msg, errorInfo); + throw new InvokeException(errorInfo.getErrorCode().getValue(), errorInfo.getErrorMessage()); + } + } + + /** + * Check cpu and memory non-negative + * + * @param cpu cpu to check + * @param memory memory to check + */ + public static void checkDynamicResource(int cpu, int memory) { + if (cpu < 0 || memory < 0) { + throw new InvokeException(RspErrorCode.INVALID_PARAMETER.getErrorCode(), + "Invalid dynamic resource options, not allow negative number, cpu is " + cpu + ".memory is " + memory); + } + } + + private static boolean checkFunctionName(String funcName) { + Matcher matcher = FUNC_NAME_PATTERN.matcher(funcName); + if (!matcher.matches()) { + return false; + } + return StringUtils.length(funcName) <= FUNC_NAME_LENGTH_LIMIT; + } + + private static boolean checkVersion(String version) { + Matcher matcher = VERSION_PATTERN.matcher(version); + if (!matcher.matches()) { + return false; + } + return StringUtils.length(version) <= VERSION_LENGTH_LIMIT; + } + + private static boolean checkAlias(String alias) { + Matcher matcher = ALIAS_PATTERN.matcher(alias); + if (!matcher.matches()) { + return false; + } + return StringUtils.length(alias) <= ALIAS_LENGTH_LIMIT; + } +} diff --git a/api/java/faas-function-sdk/src/main/java/com/function/runtime/exception/InvokeException.java b/api/java/faas-function-sdk/src/main/java/com/function/runtime/exception/InvokeException.java new file mode 100644 index 0000000..14a4f23 --- /dev/null +++ b/api/java/faas-function-sdk/src/main/java/com/function/runtime/exception/InvokeException.java @@ -0,0 +1,106 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function.runtime.exception; + +/** + * Description: + * + * @since 2024-07-20 + */ +public class InvokeException extends RuntimeException { + private int errorCode; + private String message; + + /** + * Instantiates a new InvokeException exception. + * + * @param errorCode the error code + */ + public InvokeException(int errorCode) { + this(errorCode, null, null); + } + + /** + * Instantiates a new InvokeException exception. + * + * @param errorCode the error code + * @param message the message + */ + public InvokeException(int errorCode, String message) { + this(errorCode, message, null); + } + + /** + * Instantiates a new InvokeException exception. + * + * @param errorCode the error code + * @param message the message + * @param cause the cause + */ + public InvokeException(int errorCode, String message, Throwable cause) { + super(cause); + this.errorCode = errorCode; + this.message = message; + } + + /** + * Gets error code. + * + * @return the error code + */ + public int getErrorCode() { + return errorCode; + } + + /** + * Sets error code. + * + * @param errorCode the error code + */ + public void setErrorCode(int errorCode) { + this.errorCode = errorCode; + } + + /** + * Gets message. + * + * @return the message + */ + @Override + public String getMessage() { + return message; + } + + /** + * Sets message. + * + * @param message the message + */ + public void setMessage(String message) { + this.message = message; + } + + /** + * To string string. + * + * @return the string + */ + @Override + public String toString() { + return "{" + "\"code\":\"" + errorCode + "\", \"message\":\"" + message + "\"}"; + } +} diff --git a/api/java/faas-function-sdk/src/test/java/com/function/TestCreateOptions.java b/api/java/faas-function-sdk/src/test/java/com/function/TestCreateOptions.java new file mode 100644 index 0000000..f383571 --- /dev/null +++ b/api/java/faas-function-sdk/src/test/java/com/function/TestCreateOptions.java @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function; + +import com.yuanrong.InvokeOptions; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class TestCreateOptions { + @Test + public void testInitCreateOptions() { + CreateOptions createOptions = new CreateOptions(10, 10); + InvokeOptions invokeOptions = createOptions.convertInvokeOptions(); + createOptions.setCpu(10); + createOptions.setMemory(10); + Assert.assertNotNull(createOptions); + Assert.assertNotNull(invokeOptions); + + HashMap m = new HashMap<>(); + m.put("a", "b"); + createOptions.setAliasParams(m); + Assert.assertNotNull(createOptions.getAliasParams().get("a")); + + } +} diff --git a/api/java/faas-function-sdk/src/test/java/com/function/TestFunction.java b/api/java/faas-function-sdk/src/test/java/com/function/TestFunction.java new file mode 100644 index 0000000..b8046b3 --- /dev/null +++ b/api/java/faas-function-sdk/src/test/java/com/function/TestFunction.java @@ -0,0 +1,164 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.when; + +import com.function.common.ContextMock; +import com.function.runtime.exception.InvokeException; +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.errorcode.Pair; +import com.yuanrong.exception.LibRuntimeException; +import com.yuanrong.jni.LibRuntime; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.lang.reflect.Field; + +@RunWith(PowerMockRunner.class) +@PrepareForTest( {LibRuntime.class}) +@SuppressStaticInitializationFor( {"com.yuanrong.jni.LibRuntime"}) +@PowerMockIgnore("javax.management.*") +public class TestFunction { + private ContextMock context = null; + + @Before + public void setup() { + context = new ContextMock(); + context.setInvokeProperty("{'1':'aa'}"); + } + + @Test + public void testFunctionSecond() { + new Thread(() -> { + try { + Function function = new Function(context, "ss"); + Assert.assertNotNull(function); + } catch (InvokeException e) { + Assert.fail(); + } + }).start(); + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Assert.fail(); + } + } + + @Test + public void testFunctionThird() { + Thread p =new Thread(() -> { + try { + Function function = new Function(context, "ss:alias"); + Assert.assertNotNull(function); + } catch (InvokeException e) { + Assert.fail(); + } + }); + p.start(); + try { + p.join(1000); + } catch (InterruptedException e) { + Assert.fail(); + } + } + + @Test + public void testSaveState() { + new Thread(() -> { + Function function = new Function(context); + try { + function.saveState(); + } catch (InvokeException e) { + Assert.fail(); + } + }).start(); + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Assert.fail(); + } + } + + @Test + public void testTerminate() { + new Thread(() -> { + try { + Function function = new Function(context, "demo:latest"); + Class cl = function.getClass(); + Field field = cl.getDeclaredField("instanceID"); + field.setAccessible(true); + field.set(function, "30"); + ObjectRef terminate = function.terminate(); + Assert.assertNull(terminate); + } catch (InvokeException | NoSuchFieldException | IllegalAccessException ignored) { + Assert.assertNull(ignored); + } + }).start(); + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Assert.fail(); + } + try { + Thread.sleep(500); + } catch (InterruptedException e) { + Assert.fail(); + } + } + + @Test + public void testFunctionWithInvalidResource() { + CreateOptions createOptions = new CreateOptions(-1); + boolean isException = false; + try { + Function f = new Function(context).options(createOptions); + } catch (InvokeException e) { + isException = true; + Assert.assertTrue(e.getMessage().contains("Invalid dynamic resource options, not allow negative number")); + } + Assert.assertTrue(isException); + } + + @Test + public void testInitFunction() throws Exception { + Function testFunc = new Function("testFunc"); + testFunc.options(new CreateOptions(10, 10)); + testFunc.getContext(); + testFunc.getInstanceID(); + testFunc.terminate(); + + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, ""); + Pair getRes = new Pair<>(errorInfo, "ok"); + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.InvokeInstance(any(),anyString(), anyList(), any())).thenReturn(getRes); + testFunc.invoke("{\n" + " \"name\": \"test\"\n" + "}"); + } +} diff --git a/api/java/faas-function-sdk/src/test/java/com/function/TestObjectRef.java b/api/java/faas-function-sdk/src/test/java/com/function/TestObjectRef.java new file mode 100644 index 0000000..2795d6b --- /dev/null +++ b/api/java/faas-function-sdk/src/test/java/com/function/TestObjectRef.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function; + +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.Mockito.when; + +import com.function.runtime.exception.InvokeException; +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.errorcode.Pair; +import com.yuanrong.exception.LibRuntimeException; +import com.yuanrong.jni.LibRuntime; + +import com.google.gson.JsonObject; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +@RunWith(PowerMockRunner.class) +@PrepareForTest( {LibRuntime.class}) +@SuppressStaticInitializationFor( {"com.yuanrong.jni.LibRuntime"}) +@PowerMockIgnore("javax.management.*") +public class TestObjectRef { + @Test + public void testCheckAndGetTimeoutMs() { + ObjectRef objectRef = new ObjectRef("testID"); + boolean isException = false; + try { + objectRef.get(-2); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testGet() throws LibRuntimeException { + ObjectRef objectRef = new ObjectRef("testID"); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, ""); + ArrayList bytes = new ArrayList<>(); + String jsonString = "{\"innerCode\": \"0\", \"body\": {}}"; + bytes.add(jsonString.getBytes(StandardCharsets.UTF_8)); + Pair> getRes = new Pair<>(errorInfo, bytes); + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.Get(anyList(), anyInt(), anyBoolean())).thenReturn(getRes); + boolean isException = false; + try { + objectRef.get(10); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + isException = false; + try { + objectRef.get(); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + ObjectRef objectRef2 = new ObjectRef<>("1"); + isException = false; + try { + objectRef2.get(JsonObject.class); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + isException = false; + try { + objectRef2.get(String.class,1); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testGetInvokeException() throws LibRuntimeException { + ObjectRef objectRef = new ObjectRef("testID"); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, ""); + ArrayList bytes = new ArrayList<>(); + String jsonString = "{\"innerCode\": \"4004\", \"body\": \"response body size 6400000 exceeds the limit of 6291456\"}"; + bytes.add(jsonString.getBytes(StandardCharsets.UTF_8)); + Pair> getRes = new Pair<>(errorInfo, bytes); + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.Get(anyList(), anyInt(), anyBoolean())).thenReturn(getRes); + boolean isException = false; + try { + objectRef.get(10); + } catch (InvokeException e) { + isException = true; + Assert.assertEquals(4004, e.getErrorCode()); + Assert.assertTrue(e.getMessage().contains("exceeds the limit of 6291456")); + } + Assert.assertTrue(isException); + + ObjectRef objectRef2 = new ObjectRef<>("1"); + isException = false; + try { + objectRef2.get(JsonObject.class); + } catch (InvokeException e) { + isException = true; + Assert.assertEquals(4004, e.getErrorCode()); + Assert.assertTrue(e.getMessage().contains("exceeds the limit of 6291456")); + } + Assert.assertTrue(isException); + } +} diff --git a/api/java/faas-function-sdk/src/test/java/com/function/common/ContextMock.java b/api/java/faas-function-sdk/src/test/java/com/function/common/ContextMock.java new file mode 100644 index 0000000..2e60f41 --- /dev/null +++ b/api/java/faas-function-sdk/src/test/java/com/function/common/ContextMock.java @@ -0,0 +1,235 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function.common; + +import com.services.runtime.Context; +import com.services.runtime.RuntimeLogger; +import com.services.runtime.action.ExtendedMetaData; + +import java.util.Map; + +/* + * when we run unit test of Mock, it would be triggered + */ +public class ContextMock implements Context { + private static final String INVOKE_ID = "invokeID"; + private static final String INVOKE_STATE = "invokeState"; + private static String TRACE_ID = "traceID"; + private static String INVOKE_PROPERTY = "invokeProperty"; + private Object state; + + @Override + public String getRequestID() { + return null; + } + + @Override + public int getRemainingTimeInMilliSeconds() { + return 0; + } + + @Override + public String getAccessKey() { + return null; + } + + @Override + public String getSecretKey() { + return null; + } + + @Override + public String getSecurityAccessKey() { + return null; + } + + @Override + public String getSecuritySecretKey() { + return null; + } + + @Override + public String getUserData(String s) { + return null; + } + + @Override + public String getFunctionName() { + return null; + } + + @Override + public int getRunningTimeInSeconds() { + return 0; + } + + @Override + public String getVersion() { + return null; + } + + @Override + public int getMemorySize() { + return 0; + } + + @Override + public int getCPUNumber() { + return 0; + } + + @Override + public String getProjectID() { + return null; + } + + @Override + public String getPackage() { + return null; + } + + @Override + public String getToken() { + return null; + } + + @Override + public String getAlias() { + return null; + } + + @Override + public String getSecurityToken() { + return null; + } + + @Override + public RuntimeLogger getLogger() { + return null; + } + + @Override + public Object getState() { + return state; + } + + @Override + public void setState(Object state) { + this.state = state; + } + + @Override + public String getInstanceID() { + return "001"; + } + + @Override + public String getInstanceLabel() { + return ""; + } + + public void setInstanceLabel(String instanceLabel) { + } + + @Override + public String getInvokeProperty() { + return INVOKE_PROPERTY; + } + + public void setInvokeProperty(String invokeProperty) { + INVOKE_PROPERTY = invokeProperty; + } + + @Override + public String getTraceID() { + return TRACE_ID; + } + + public void setTraceID(String traceID) { + TRACE_ID = traceID; + } + + @Override + public String getInvokeID() { + return INVOKE_ID; + } + + @Override + public String getWorkflowID() { + return null; + } + + @Override + public String getWorkflowRunID() { + return null; + } + + @Override + public String getWorkflowStateID() { + return null; + } + + @Override + public String getReqStreamName() { + return null; + } + + @Override + public String getRespStreamName() { + return null; + } + + @Override + public String getFrontendResponseStreamName() { + return null; + } + + @Override + public String getIAMToken() { + return ""; + } + + @Override + public Map getExtraMap() { + return null; + } + + @Override + public void setExtraMap(Map extraMap) { + } + + public static class State { + + private String key = "invokeProperty"; + + public State(String key) { + this.key = key; + } + + public State() { + + } + + public String getKey() { + return key; + } + + public void setKey(String key) { + this.key = key; + } + } +} \ No newline at end of file diff --git a/api/java/faas-function-sdk/src/test/java/com/function/common/TestUtil.java b/api/java/faas-function-sdk/src/test/java/com/function/common/TestUtil.java new file mode 100644 index 0000000..0405a84 --- /dev/null +++ b/api/java/faas-function-sdk/src/test/java/com/function/common/TestUtil.java @@ -0,0 +1,137 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function.common; + +import com.function.runtime.exception.InvokeException; +import com.services.runtime.action.ContextImpl; +import com.services.runtime.action.FunctionMetaData; +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; + +import com.google.gson.Gson; + +import org.junit.Assert; +import org.junit.Test; + +public class TestUtil { + private static final Gson GSON = new Gson(); + + @Test + public void testCheckFuncName() { + // func name is empty + try { + Util.checkFuncName(null); + } catch (InvokeException e) { + Assert.assertEquals(RspErrorCode.INVALID_PARAMETER.getErrorCode(), e.getErrorCode()); + Assert.assertEquals("invalid funcName, expect not null", e.getMessage()); + } + + // funcName contains multiple ':' + try { + Util.checkFuncName("name1:name2:name3"); + } catch (InvokeException e) { + Assert.assertEquals(RspErrorCode.INVALID_PARAMETER.getErrorCode(), e.getErrorCode()); + Assert.assertEquals("invalid funcName, not match regular expression", e.getMessage()); + } + + // funcName not match regular + try { + Util.checkFuncName("+++:name2"); + } catch (InvokeException e) { + Assert.assertEquals(RspErrorCode.INVALID_PARAMETER.getErrorCode(), e.getErrorCode()); + Assert.assertEquals("invalid funcName, not match regular expression", e.getMessage()); + } + + // funcName contains ':' and version not regular + try { + Util.checkFuncName("funcName:!v1"); + } catch (InvokeException e) { + Assert.assertEquals(RspErrorCode.INVALID_PARAMETER.getErrorCode(), e.getErrorCode()); + Assert.assertEquals("invalid funcName, not match regular expression", e.getMessage()); + } + + // funcName contains ':' and version not regular + try { + Util.checkFuncName("funcName:v1@2"); + } catch (InvokeException ie) { + Assert.assertEquals(RspErrorCode.INVALID_PARAMETER.getErrorCode(), ie.getErrorCode()); + Assert.assertEquals("invalid funcName, not match regular expression", ie.getMessage()); + } + + // funcName contains ':' and version not regular + try { + Util.checkFuncName("funcName"); + } catch (InvokeException e) { + Assert.assertEquals(RspErrorCode.INVALID_PARAMETER.getErrorCode(), e.getErrorCode()); + Assert.assertEquals("invalid funcName, not match regular expression", e.getMessage()); + } + + // funcName not contains ':' and version not regular + try { + Util.checkFuncName("funcName@"); + } catch (InvokeException e) { + Assert.assertEquals(RspErrorCode.INVALID_PARAMETER.getErrorCode(), e.getErrorCode()); + Assert.assertEquals("invalid funcName, not match regular expression", e.getMessage()); + } + + // funcName check right + String funcName = "funcName"; + String version = "v1"; + String[] strings = Util.checkFuncName(funcName + ":" + version); + Assert.assertEquals(funcName, strings[0]); + Assert.assertEquals(version, strings[1]); + } + + @Test + public void testCheckPayload() { + try { + Util.checkPayload(""); + } catch (InvokeException e) { + Assert.assertEquals(RspErrorCode.INVALID_PARAMETER.getErrorCode(), e.getErrorCode()); + Assert.assertEquals("invalid payload, expect not null", e.getMessage()); + } + + try { + Util.checkPayload("abc"); + } catch (InvokeException e) { + Assert.assertEquals(RspErrorCode.INVALID_PARAMETER.getErrorCode(), e.getErrorCode()); + Assert.assertEquals("invalid payload, invalid json string", e.getMessage()); + } + } + + @Test + public void testInitUtil() { + Util util = new Util(); + ContextImpl context = new ContextImpl(); + FunctionMetaData functionMetaData = new FunctionMetaData(); + context.setFuncMetaData(functionMetaData); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_ETCD_OPERATION_ERROR, ModuleCode.CORE, "test"); + boolean isException = false; + try { + Util.checkErrorAndThrow(errorInfo, "test"); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testRspErrorCode() { + Assert.assertNotNull(RspErrorCode.INVALID_PARAMETER); + } +} diff --git a/api/java/faas-function-sdk/src/test/java/com/function/runtime/exception/TestInvokeException.java b/api/java/faas-function-sdk/src/test/java/com/function/runtime/exception/TestInvokeException.java new file mode 100644 index 0000000..b35776a --- /dev/null +++ b/api/java/faas-function-sdk/src/test/java/com/function/runtime/exception/TestInvokeException.java @@ -0,0 +1,31 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.function.runtime.exception; + +import org.junit.Assert; +import org.junit.Test; + +public class TestInvokeException { + @Test + public void testInitInvokeException() { + InvokeException exception = new InvokeException(1); + exception.setErrorCode(2); + exception.setMessage("test"); + exception.toString(); + Assert.assertEquals("test", exception.getMessage()); + } +} diff --git a/api/java/function-common/src/main/cpp/com_yuanrong_jni_Consumer.cpp b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Consumer.cpp new file mode 100644 index 0000000..9f6d0e4 --- /dev/null +++ b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Consumer.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "com_yuanrong_jni_Consumer.h" + +#include "jni_types.h" + +#include "src/dto/stream_conf.h" +#include "src/libruntime/err_type.h" +#include "src/libruntime/streamstore/stream_producer_consumer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +using YR::Libruntime::ErrorInfo; + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniConsumer_receive(JNIEnv *env, jclass, jlong consumerPtr, + jlong expectNum, jint timeoutMs, + jboolean hasExpectedNum) +{ + auto consumer = reinterpret_cast *>(consumerPtr); + std::vector elements; + ErrorInfo err; + if (hasExpectedNum == JNI_TRUE) { + err = (*consumer)->Receive(expectNum, timeoutMs, elements); + } else { + err = (*consumer)->Receive(timeoutMs, elements); + } + + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + if (jerr == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Consumer_receive, get null"); + return nullptr; + } + jclass arrayListClass = env->FindClass("java/util/ArrayList"); + jobject elementList = env->NewObject(arrayListClass, env->GetMethodID(arrayListClass, "", "()V")); + jmethodID elementAdd = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z"); + + jclass elementClass = env->FindClass("com/yuanrong/stream/Element"); + jmethodID elementInit = env->GetMethodID(elementClass, "", "(JLjava/nio/ByteBuffer;)V"); + jobject elementBuffer = NULL; + jobject elementObject = NULL; + for (const auto &element : elements) { + elementBuffer = env->NewDirectByteBuffer(element.ptr, element.size); + elementObject = env->NewObject(elementClass, elementInit, element.id, elementBuffer); + env->CallBooleanMethod(elementList, elementAdd, elementObject); + } + // Clean up + env->DeleteLocalRef(arrayListClass); + env->DeleteLocalRef(elementClass); + env->DeleteLocalRef(elementBuffer); + env->DeleteLocalRef(elementObject); + return YR::jni::JNIPair::CreateJPair(env, jerr, elementList); +} + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniConsumer_ack(JNIEnv *env, jclass, jlong consumerPtr, + jlong elementId) +{ + auto consumer = reinterpret_cast *>(consumerPtr); + auto err = (*consumer)->Ack(static_cast(elementId)); + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + if (jerr == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Consumer_ack, get null"); + return nullptr; + } + return jerr; +} + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniConsumer_close(JNIEnv *env, jclass, jlong consumerPtr) +{ + auto consumer = reinterpret_cast *>(consumerPtr); + auto err = (*consumer)->Close(); + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + if (jerr == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Consumer_close, get null"); + return nullptr; + } + delete consumer; + return jerr; +} + +JNIEXPORT void JNICALL Java_com_datasystem_streamcache_ConsumerImpl_freeJNIPtrNative(JNIEnv *, jclass, jlong handle) +{ + auto consumer = reinterpret_cast *>(handle); + delete consumer; +} + +#ifdef __cplusplus +} +#endif diff --git a/api/java/function-common/src/main/cpp/com_yuanrong_jni_Consumer.h b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Consumer.h new file mode 100644 index 0000000..fbd7975 --- /dev/null +++ b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Consumer.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#pragma once +#ifdef __cplusplus +extern "C" { +#endif + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniConsumer_receive(JNIEnv *, jclass, jlong, jlong, jint, + jboolean); + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniConsumer_ack(JNIEnv *, jclass, jlong, jlong); + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniConsumer_close(JNIEnv *, jclass, jlong); + +JNIEXPORT void JNICALL Java_com_yuanrong_jni_JniConsumer_freeJNIPtrNative(JNIEnv *, jclass, jlong); + +#ifdef __cplusplus +} +#endif diff --git a/api/java/function-common/src/main/cpp/com_yuanrong_jni_LibRuntime.cpp b/api/java/function-common/src/main/cpp/com_yuanrong_jni_LibRuntime.cpp index 367061c..c601359 100644 --- a/api/java/function-common/src/main/cpp/com_yuanrong_jni_LibRuntime.cpp +++ b/api/java/function-common/src/main/cpp/com_yuanrong_jni_LibRuntime.cpp @@ -35,6 +35,7 @@ #include "src/dto/invoke_arg.h" #include "src/dto/invoke_options.h" #include "src/dto/status.h" +#include "src/dto/stream_conf.h" #include "src/libruntime/auto_init.h" #include "src/libruntime/err_type.h" #include "src/libruntime/invokeadaptor/request_manager.h" @@ -42,6 +43,7 @@ #include "src/libruntime/libruntime_manager.h" #include "src/libruntime/libruntime_options.h" +#include "src/libruntime/streamstore/stream_producer_consumer.h" #include "src/proto/libruntime.pb.h" /** @@ -81,11 +83,13 @@ extern "C" { } \ } -static jobject get_runtime_context_callback(JNIEnv *env, jclass c) +static std::string get_runtime_context_callback(JNIEnv *env, jclass c) { jmethodID callbackMethodID = env->GetStaticMethodID(c, "GetRuntimeContext", "()Ljava/lang/String;"); jobject result = env->CallStaticObjectMethod(c, callbackMethodID); - return result; + std::string resultStr = YR::jni::JNIString::FromJava(env, (jstring)result); + env->DeleteLocalRef(result); + return resultStr; } JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_Init(JNIEnv *env, jclass c, jobject jconfig) @@ -144,8 +148,7 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_Init(JNIEnv *env, jcl config.libruntimeOptions.checkpointCallback = checkpointCb; config.libruntimeOptions.recoverCallback = recoverCb; config.libruntimeOptions.shutdownCallback = functionShutdownCb; - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); auto err = YR::Libruntime::LibruntimeManager::Instance().Init(config, rtCtx); jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); if (jerr == nullptr) { @@ -165,25 +168,24 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_CreateInstance(JNIEnv jobject functionMeta, jobject args, jobject opt) { - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); auto funcMeta = YR::jni::JNIFunctionMeta::FromJava(env, functionMeta); CHECK_JAVA_EXCEPTION_AND_THROW_NEW_AND_RETURN_IF_OCCUR(env, nullptr, "exception occurred when convert funcMeta from java to cc"); - auto invokeArgs = YR::jni::JNIInvokeArg::FromJavaList( - env, args, YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->GetTenantId()); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + auto invokeArgs = YR::jni::JNIInvokeArg::FromJavaList(env, args, libRuntime->GetTenantId()); CHECK_JAVA_EXCEPTION_AND_THROW_NEW_AND_RETURN_IF_OCCUR( env, nullptr, "exception occurred when convert invokeArgs from java to cc"); auto invokeOptions = YR::jni::JNIInvokeOptions::FromJava(env, opt); CHECK_JAVA_EXCEPTION_AND_THROW_NEW_AND_RETURN_IF_OCCUR( env, nullptr, "exception occurred when convert invokeOptions from java to cc"); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); - auto [err, objectID] = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->CreateInstance( - funcMeta, invokeArgs, invokeOptions); + libRuntime->SetTenantIdWithPriority(); + auto [err, objectID] = libRuntime->CreateInstance(funcMeta, invokeArgs, invokeOptions); if (!err.OK()) { - YRLOG_WARN("failed to CreateInstance, err({}), msg({})", err.Code(), err.Msg()); + YRLOG_WARN("failed to CreateInstance, err({}), msg({})", fmt::underlying(err.Code()), err.Msg()); YR::jni::JNILibruntimeException::Throw( env, err.Code(), err.MCode(), "failed to CreateInstance, err " + std::to_string(err.Code()) + ", msg: " + err.Msg()); @@ -204,7 +206,10 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_CreateInstance(JNIEnv return nullptr; } - return YR::jni::JNIPair::CreateJPair(env, jerr, jobjectID); + jobject jpair = YR::jni::JNIPair::CreateJPair(env, jerr, jobjectID); + env->DeleteLocalRef(jerr); + env->DeleteLocalRef(jobjectID); + return jpair; } /* @@ -218,16 +223,16 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_InvokeInstance(JNIEnv jstring instanceId, jobject args, jobject opt) { - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); auto funcMeta = YR::jni::JNIFunctionMeta::FromJava(env, functionMeta); CHECK_JAVA_EXCEPTION_AND_THROW_NEW_AND_RETURN_IF_OCCUR(env, nullptr, "exception occurred when convert funcMeta from java to cc"); auto instanceIdStr = YR::jni::JNIString::FromJava(env, instanceId); CHECK_JAVA_EXCEPTION_AND_THROW_NEW_AND_RETURN_IF_OCCUR( env, nullptr, "exception occurred when convert instanceIdStr from java to cc"); - auto invokeArgs = YR::jni::JNIInvokeArg::FromJavaList( - env, args, YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->GetTenantId()); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + auto invokeArgs = YR::jni::JNIInvokeArg::FromJavaList(env, args, libRuntime->GetTenantId()); CHECK_JAVA_EXCEPTION_AND_THROW_NEW_AND_RETURN_IF_OCCUR( env, nullptr, "exception occurred when convert invokeArgs from java to cc"); auto invokeOptions = YR::jni::JNIInvokeOptions::FromJava(env, opt); @@ -237,13 +242,11 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_InvokeInstance(JNIEnv env, nullptr, "exception occurred when convert returnDataObjs from java to cc"); YR::Libruntime::ErrorInfo err; std::vector returnDataObjs{{""}}; - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); + libRuntime->SetTenantIdWithPriority(); if (instanceIdStr.size() > 0) { - err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->InvokeByInstanceId( - funcMeta, instanceIdStr, invokeArgs, invokeOptions, returnDataObjs); + err = libRuntime->InvokeByInstanceId(funcMeta, instanceIdStr, invokeArgs, invokeOptions, returnDataObjs); } else { - err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->InvokeByFunctionName( - funcMeta, invokeArgs, invokeOptions, returnDataObjs); + err = libRuntime->InvokeByFunctionName(funcMeta, invokeArgs, invokeOptions, returnDataObjs); } if (!err.OK()) { YR::jni::JNILibruntimeException::Throw( @@ -263,15 +266,18 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_InvokeInstance(JNIEnv env, "failed to convert ReturnDataObjId when invokeByInstanceID, get null"); return nullptr; } - return YR::jni::JNIPair::CreateJPair(env, jerr, jreturnDataObjId); + + jobject jpair = YR::jni::JNIPair::CreateJPair(env, jerr, jreturnDataObjId); + env->DeleteLocalRef(jerr); + env->DeleteLocalRef(jreturnDataObjId); + return jpair; } JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_Put(JNIEnv *env, jclass c, jbyteArray byteArray, jobject objectIds) { auto nestedObjectIds = YR::jni::JNIString::FromJArrayToUnorderedSet(env, objectIds); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); std::shared_ptr dataObj; auto errInfo = YR::jni::JNIDataObject::WriteDataObject(env, dataObj, byteArray); if (errInfo.Code() != YR::Libruntime::ErrorCode::ERR_OK) { @@ -280,9 +286,10 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_Put(JNIEnv *env, jcla "put finished, return code is " + std::to_string(errInfo.Code()) + ", msg: " + errInfo.Msg()); return nullptr; } - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); - auto [err, objId] = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->Put(dataObj, nestedObjectIds); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); + auto [err, objId] = libRuntime->Put(dataObj, nestedObjectIds); jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); jstring jobjId = YR::jni::JNIString::FromCc(env, objId); @@ -295,8 +302,7 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_PutWithParam(JNIEnv * { auto nestedObjectIds = YR::jni::JNIString::FromJArrayToUnorderedSet(env, objectIds); auto ccreateParam = YR::jni::JNICreateParam::FromJava(env, createParam); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); std::shared_ptr dataObj; auto errInfo = YR::jni::JNIDataObject::WriteDataObject(env, dataObj, byteArray); if (errInfo.Code() != YR::Libruntime::ErrorCode::ERR_OK) { @@ -305,9 +311,10 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_PutWithParam(JNIEnv * "put finished, return code is " + std::to_string(errInfo.Code()) + ", msg: " + errInfo.Msg()); return nullptr; } - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); - auto [err, objId] = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->Put(dataObj, nestedObjectIds, ccreateParam); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); + auto [err, objId] = libRuntime->Put(dataObj, nestedObjectIds, ccreateParam); jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); jstring jobjId = YR::jni::JNIString::FromCc(env, objId); @@ -325,16 +332,16 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_Get(JNIEnv *env, jcla auto objIds = YR::jni::JNIList::FromJava(env, listOfIds, [](JNIEnv *env, jobject obj) -> std::string { return YR::jni::JNIString::FromJava(env, static_cast(obj)); }); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); CHECK_JAVA_EXCEPTION_AND_THROW_NEW_AND_RETURN_IF_OCCUR( env, nullptr, "LibRuntime_Get called, but exception occurred when convert args from java to cc"); int timeoutMsInt = static_cast(timeoutMs); bool allowPartialBool = static_cast(allowPartial); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); // A special constraint: call wait first before call get - auto [err, res] = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->Get(objIds, timeoutMsInt, allowPartialBool); + auto [err, res] = libRuntime->Get(objIds, timeoutMsInt, allowPartialBool); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { YR::jni::JNILibruntimeException::Throw( env, err.Code(), err.MCode(), @@ -354,7 +361,10 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_Get(JNIEnv *env, jcla return nullptr; } - return YR::jni::JNIPair::CreateJPair(env, jerr, listResult); + jobject jpair = YR::jni::JNIPair::CreateJPair(env, jerr, listResult); + env->DeleteLocalRef(jerr); + env->DeleteLocalRef(listResult); + return jpair; } JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_Wait(JNIEnv *env, jclass c, jobject objList, @@ -363,15 +373,15 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_Wait(JNIEnv *env, jcl auto objIds = YR::jni::JNIList::FromJava(env, objList, [](JNIEnv *env, jobject obj) -> std::string { return YR::jni::JNIString::FromJava(env, static_cast(obj)); }); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); CHECK_JAVA_EXCEPTION_AND_THROW_NEW_AND_RETURN_IF_OCCUR( env, nullptr, "LibRuntime_Wait called, but exception occurred when convert args from java to cc"); int timeoutSecInt = static_cast(timeoutSec); int waitNumInt = static_cast(waitNum); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); - auto internalWaitResult = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->Wait(objIds, waitNumInt, timeoutSecInt); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); + auto internalWaitResult = libRuntime->Wait(objIds, waitNumInt, timeoutSecInt); jobject res = YR::jni::JNIInternalWaitResult::FromCc(env, internalWaitResult); if (res == nullptr) { YR::jni::JNILibruntimeException::ThrowNew(env, "get null when transform wait result from cpp to java"); @@ -391,8 +401,7 @@ JNIEXPORT void JNICALL Java_com_yuanrong_jni_LibRuntime_DecreaseReference(JNIEnv auto objIds = YR::jni::JNIList::FromJava(env, objList, [](JNIEnv *env, jobject obj) -> std::string { return YR::jni::JNIString::FromJava(env, static_cast(obj)); }); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); if (libRuntime) { libRuntime->SetTenantIdWithPriority(); @@ -407,8 +416,7 @@ JNIEXPORT void JNICALL Java_com_yuanrong_jni_LibRuntime_DecreaseReference(JNIEnv */ JNIEXPORT void JNICALL Java_com_yuanrong_jni_LibRuntime_ReceiveRequestLoop(JNIEnv *env, jclass c) { - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); YR::Libruntime::LibruntimeManager::Instance().ReceiveRequestLoop(rtCtx); } @@ -431,8 +439,7 @@ JNIEXPORT void JNICALL Java_com_yuanrong_jni_LibRuntime_FinalizeWithCtx(JNIEnv * */ JNIEXPORT void JNICALL Java_com_yuanrong_jni_LibRuntime_Finalize(JNIEnv *env, jclass c) { - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); YR::Libruntime::LibruntimeManager::Instance().Finalize(rtCtx); } @@ -443,9 +450,11 @@ JNIEXPORT void JNICALL Java_com_yuanrong_jni_LibRuntime_Finalize(JNIEnv *env, jc */ JNIEXPORT void JNICALL Java_com_yuanrong_jni_LibRuntime_Exit(JNIEnv *env, jclass c) { - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->Exit(); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + if (libRuntime) { + libRuntime->Exit(); + } } /* @@ -463,9 +472,10 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_AutoInitYR(JNIEnv *en JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_Kill(JNIEnv *env, jclass c, jstring instanceID) { auto instID = YR::jni::JNIString::FromJava(env, instanceID); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->Kill(instID); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + auto err = libRuntime->Kill(instID); jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); if (jerr == nullptr) { YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Libruntime_Kill, get null"); @@ -481,8 +491,7 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_Kill(JNIEnv *env, jcl */ JNIEXPORT jboolean JNICALL Java_com_yuanrong_jni_LibRuntime_IsInitialized(JNIEnv *env, jclass c) { - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); return static_cast(YR::Libruntime::LibruntimeManager::Instance().IsInitialized(rtCtx)); } @@ -506,9 +515,10 @@ JNIEXPORT jstring JNICALL Java_com_yuanrong_jni_LibRuntime_GetRealInstanceId(JNI jstring objectID) { std::string cobjectID = YR::jni::JNIString::FromJava(env, objectID); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - auto instanceID = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->GetRealInstanceId(cobjectID); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + auto instanceID = libRuntime->GetRealInstanceId(cobjectID); jstring jinstanceID = YR::jni::JNIString::FromCc(env, instanceID); ASSERT_NOT_NULL(jinstanceID); @@ -529,10 +539,10 @@ JNIEXPORT void JNICALL Java_com_yuanrong_jni_LibRuntime_SaveRealInstanceId(JNIEn auto opts = YR::jni::JNIInvokeOptions::FromJava(env, opt); YR::Libruntime::InstanceOptions instOpts; instOpts.needOrder = opts.needOrder; - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SaveRealInstanceId(cobjectID, cinstanceID, - instOpts); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN_VOID(env, libRuntime, "exception occurred because LibRuntime is null"); + libRuntime->SaveRealInstanceId(cobjectID, cinstanceID, instOpts); } /* @@ -547,12 +557,12 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_KVWrite(JNIEnv *env, auto ckey = YR::jni::JNIString::FromJava(env, key); auto cvalue = YR::jni::JNIByteBuffer::FromJava(env, value); auto csetParam = YR::jni::JNISetParam::FromJava(env, setParam); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); // result from cpp to java - auto cerrorInfo = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->KVWrite(ckey, cvalue, csetParam); + auto cerrorInfo = libRuntime->KVWrite(ckey, cvalue, csetParam); jobject jerrorInfo = YR::jni::JNIErrorInfo::FromCc(env, cerrorInfo); ASSERT_NOT_NULL(jerrorInfo); @@ -578,12 +588,12 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_KVMSetTx(JNIEnv *env, }); auto cmSetParam = YR::jni::JNIMSetParam::FromJava(env, mSetParam); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); // result from cpp to java - auto cerrorInfo = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->KVMSetTx(ckeys, cvalues, cmSetParam); + auto cerrorInfo = libRuntime->KVMSetTx(ckeys, cvalues, cmSetParam); jobject jerrorInfo = YR::jni::JNIErrorInfo::FromCc(env, cerrorInfo); ASSERT_NOT_NULL(jerrorInfo); @@ -601,13 +611,13 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_KVRead__Ljava_lang_St { // parameters from java to cpp auto ckey = YR::jni::JNIString::FromJava(env, key); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); int ctimeoutMS = static_cast(timeoutMS); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); // result from cpp to java (std::pair, ErrorInfo>) - auto [sbuf_ptr, cerrorInfo] = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->KVRead(ckey, ctimeoutMS); + auto [sbuf_ptr, cerrorInfo] = libRuntime->KVRead(ckey, ctimeoutMS); jbyteArray byteArray = nullptr; if (sbuf_ptr != nullptr) { byteArray = env->NewByteArray(sbuf_ptr->GetSize()); @@ -634,14 +644,14 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_KVRead__Ljava_util_Li auto ckeys = YR::jni::JNIList::FromJava(env, keys, [](JNIEnv *env, jobject obj) -> std::string { return YR::jni::JNIString::FromJava(env, static_cast(obj)); }); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); int ctimeoutMS = static_cast(timeoutMS); bool callowPartial = static_cast(allowPartial); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); // result from cpp to java (std::pair>, ErrorInfo>) - auto [sbuf_ptr_vector, cerrorInfo] = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->KVRead(ckeys, ctimeoutMS, callowPartial); + auto [sbuf_ptr_vector, cerrorInfo] = libRuntime->KVRead(ckeys, ctimeoutMS, callowPartial); jobject listByteArray = YR::jni::JNIByteBuffer::FromCcPrtVectorToList(env, sbuf_ptr_vector); ASSERT_NOT_NULL(listByteArray); @@ -663,14 +673,13 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_KVGetWithParam(JNIEnv auto ckeys = YR::jni::JNIList::FromJava(env, keys, [](JNIEnv *env, jobject obj) -> std::string { return YR::jni::JNIString::FromJava(env, static_cast(obj)); }); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); int ctimeoutMS = static_cast(timeoutMS); auto cgetParams = YR::jni::JNIGetParams::FromJava(env, getParams); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); - auto [sbuf_ptr_vector, cerrorInfo] = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->KVGetWithParam(ckeys, cgetParams, - ctimeoutMS); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); + auto [sbuf_ptr_vector, cerrorInfo] = libRuntime->KVGetWithParam(ckeys, cgetParams, ctimeoutMS); jobject listByteArray = YR::jni::JNIByteBuffer::FromCcPrtVectorToList(env, sbuf_ptr_vector); ASSERT_NOT_NULL(listByteArray); @@ -690,11 +699,12 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_KVDel__Ljava_lang_Str { // parameters from java to cpp auto ckey = YR::jni::JNIString::FromJava(env, key); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); // result from cpp to java (ErrorInfo) - auto cerrorInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->KVDel(ckey); + auto cerrorInfo = libRuntime->KVDel(ckey); jobject jerrorInfo = YR::jni::JNIErrorInfo::FromCc(env, cerrorInfo); ASSERT_NOT_NULL(jerrorInfo); return jerrorInfo; @@ -712,11 +722,12 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_KVDel__Ljava_util_Lis auto ckeys = YR::jni::JNIList::FromJava(env, keys, [](JNIEnv *env, jobject obj) -> std::string { return YR::jni::JNIString::FromJava(env, static_cast(obj)); }); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SetTenantIdWithPriority(); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); // result from cpp to java (std::pair>, ErrorInfo>) - auto [cvector_keys, cerrorInfo] = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->KVDel(ckeys); + auto [cvector_keys, cerrorInfo] = libRuntime->KVDel(ckeys); jobject keysDeleted = YR::jni::JNIString::FromCcVectorToList(env, cvector_keys); ASSERT_NOT_NULL(keysDeleted); jobject jerrorInfo = YR::jni::JNIErrorInfo::FromCc(env, cerrorInfo); @@ -732,8 +743,7 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_KVDel__Ljava_util_Lis */ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_SaveState(JNIEnv *env, jclass c, jint timeoutMs) { - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); std::shared_ptr data; auto errInfo = YR::jni::JNICodeExecutor::DumpInstance(env, "", data); @@ -743,7 +753,9 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_SaveState(JNIEnv *env return jerrorInfo; } - errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SaveState(data, timeoutMs); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + errInfo = libRuntime->SaveState(data, timeoutMs); jobject jerrorInfo = YR::jni::JNIErrorInfo::FromCc(env, errInfo); ASSERT_NOT_NULL(jerrorInfo); return jerrorInfo; @@ -756,12 +768,13 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_SaveState(JNIEnv *env */ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_LoadState(JNIEnv *env, jclass c, jint timeoutMs) { - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); std::shared_ptr data; jobject jerrorInfo; - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->LoadState(data, timeoutMs); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + auto errInfo = libRuntime->LoadState(data, timeoutMs); if (!errInfo.OK()) { jerrorInfo = YR::jni::JNIErrorInfo::FromCc(env, errInfo); ASSERT_NOT_NULL(jerrorInfo); @@ -779,9 +792,10 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_GroupCreate(JNIEnv *e { auto groupOpts = YR::jni::JNIGroupOptions::FromJava(env, opt); auto cStr = YR::jni::JNIString::FromJava(env, str); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->GroupCreate(cStr, groupOpts); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + auto errInfo = libRuntime->GroupCreate(cStr, groupOpts); jobject jerrorInfo = YR::jni::JNIErrorInfo::FromCc(env, errInfo); ASSERT_NOT_NULL(jerrorInfo); return jerrorInfo; @@ -790,17 +804,20 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_GroupCreate(JNIEnv *e JNIEXPORT void JNICALL Java_com_yuanrong_jni_LibRuntime_GroupTerminate(JNIEnv *env, jclass c, jstring str) { auto cStr = YR::jni::JNIString::FromJava(env, str); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->GroupTerminate(cStr); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + if (libRuntime) { + libRuntime->GroupTerminate(cStr); + } } JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_GroupWait(JNIEnv *env, jclass c, jstring str) { auto cStr = YR::jni::JNIString::FromJava(env, str); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->GroupWait(cStr); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + auto errInfo = libRuntime->GroupWait(cStr); jobject jerrorInfo = YR::jni::JNIErrorInfo::FromCc(env, errInfo); ASSERT_NOT_NULL(jerrorInfo); return jerrorInfo; @@ -808,8 +825,7 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_GroupWait(JNIEnv *env JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_processLog(JNIEnv *env, jclass c, jobject functionLog) { - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); + auto rtCtx = get_runtime_context_callback(env, c); auto cFunctionLog = YR::jni::JNIFunctionLog::FromJava(env, functionLog); auto errInfo = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->ProcessLog(cFunctionLog); jobject jerrorInfo = YR::jni::JNIErrorInfo::FromCc(env, errInfo); @@ -817,6 +833,125 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_processLog(JNIEnv *en return jerrorInfo; } +JNIEXPORT jlong JNICALL Java_com_yuanrong_jni_LibRuntime_CreateStreamProducer( + JNIEnv *env, jclass c, jstring streamName, jlong delay, jlong pageSize, jlong maxStreamSize, jboolean autoCleanup, + jboolean encryptStream, jlong retainForNumConsumers, jlong reserveSize) +{ + auto rtCtx = get_runtime_context_callback(env, c); + auto cStreamName = YR::jni::JNIString::FromJava(env, streamName); + YR::Libruntime::ProducerConf producerConf = {.delayFlushTime = delay, + .pageSize = pageSize, + .maxStreamSize = static_cast(maxStreamSize), + .autoCleanup = static_cast(autoCleanup), + .encryptStream = static_cast(encryptStream), + .retainForNumConsumers = static_cast(retainForNumConsumers), + .reserveSize = static_cast(reserveSize)}; + std::shared_ptr streamProducer; + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, -1, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); + auto err = libRuntime->CreateStreamProducer(cStreamName, producerConf, streamProducer); + CHECK_ERROR_AND_THROW( + env, err, -1, + "create stream producer finished, return code is " + std::to_string(err.Code()) + ", msg: " + err.Msg()) + std::unique_ptr> pOutProducer = + std::make_unique>(std::move(streamProducer)); + return reinterpret_cast(pOutProducer.release()); +} + +JNIEXPORT jlong JNICALL Java_com_yuanrong_jni_LibRuntime_CreateStreamConsumer( + JNIEnv *env, jclass c, jstring streamName, jstring subName, jobject subscription, jboolean shouldAutoAck) +{ + auto rtCtx = get_runtime_context_callback(env, c); + auto cStreamName = YR::jni::JNIString::FromJava(env, streamName); + auto cSubName = YR::jni::JNIString::FromJava(env, subName); + jclass jniSubscription = env->GetObjectClass(subscription); + jmethodID jniMethoId = env->GetMethodID(jniSubscription, "ordinal", "()I"); + int subscriptionTypeInt = env->CallIntMethod(jniSubscription, jniMethoId); + libruntime::SubscriptionType subType = (libruntime::SubscriptionType)subscriptionTypeInt; + const struct YR::Libruntime::SubscriptionConfig config(cSubName, subType); + std::shared_ptr streamConsumer; + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, -1, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); + auto err = libRuntime->CreateStreamConsumer(cStreamName, config, streamConsumer, static_cast(shouldAutoAck)); + CHECK_ERROR_AND_THROW( + env, err, -1, + "create stream consumer finished, return code is " + std::to_string(err.Code()) + ", msg: " + err.Msg()) + std::unique_ptr> pOutConsumer = + std::make_unique>(std::move(streamConsumer)); + return reinterpret_cast(pOutConsumer.release()); +} + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_DeleteStream(JNIEnv *env, jclass c, + jstring streamName) +{ + auto rtCtx = get_runtime_context_callback(env, c); + auto cStreamName = YR::jni::JNIString::FromJava(env, streamName); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); + auto err = libRuntime->DeleteStream(cStreamName); + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + if (jerr == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Libruntime_DeleteStream, get null"); + return nullptr; + } + return jerr; +} + +JNIEXPORT jlong JNICALL Java_com_yuanrong_jni_LibRuntime_QueryGlobalProducersNum(JNIEnv *env, jclass c, + jstring streamName) +{ + auto rtCtx = get_runtime_context_callback(env, c); + auto cStreamName = YR::jni::JNIString::FromJava(env, streamName); + uint64_t producerNum; + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, -1, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); + auto err = libRuntime->QueryGlobalProducersNum(cStreamName, producerNum); + CHECK_ERROR_AND_THROW( + env, err, -1, + "query global producers num finished, return code is " + std::to_string(err.Code()) + ", msg: " + err.Msg()) + return static_cast(producerNum); +} + +JNIEXPORT jlong JNICALL Java_com_yuanrong_jni_LibRuntime_QueryGlobalConsumersNum(JNIEnv *env, jclass c, + jstring streamName) +{ + auto rtCtx = get_runtime_context_callback(env, c); + auto cStreamName = YR::jni::JNIString::FromJava(env, streamName); + uint64_t consumerNum; + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, -1, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); + auto err = libRuntime->QueryGlobalConsumersNum(cStreamName, consumerNum); + CHECK_ERROR_AND_THROW( + env, err, -1, + "query global consumers num finished, return code is " + std::to_string(err.Code()) + ", msg: " + err.Msg()) + return static_cast(consumerNum); +} + +JNIEXPORT jlong JNICALL Java_com_yuanrong_jni_LibRuntime_CreateStreamProducerWithConfig(JNIEnv *env, jclass c, + jstring streamName, + jobject producerConfig) +{ + auto rtCtx = get_runtime_context_callback(env, c); + auto cStreamName = YR::jni::JNIString::FromJava(env, streamName); + auto cProducerConf = YR::jni::JNIProducerConfig::FromJava(env, producerConfig); + std::shared_ptr streamProducer; + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, -1, "exception occurred because LibRuntime is null"); + libRuntime->SetTenantIdWithPriority(); + auto err = libRuntime->CreateStreamProducer(cStreamName, cProducerConf, streamProducer); + CHECK_ERROR_AND_THROW( + env, err, -1, + "create stream producer finished, return code is " + std::to_string(err.Code()) + ", msg: " + err.Msg()) + std::unique_ptr> pOutProducer = + std::make_unique>(std::move(streamProducer)); + return reinterpret_cast(pOutProducer.release()); +} + /* * Class: com_yuanrong_jni_LibRuntime * Method: GetInstanceRoute @@ -826,10 +961,10 @@ JNIEXPORT jstring JNICALL Java_com_yuanrong_jni_LibRuntime_GetInstanceRoute(JNIE jstring objectID) { std::string cobjectID = YR::jni::JNIString::FromJava(env, objectID); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - auto instanceRoute = - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->GetInstanceRoute(cobjectID); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + auto instanceRoute = libRuntime->GetInstanceRoute(cobjectID); jstring jinstanceRoute = YR::jni::JNIString::FromCc(env, instanceRoute); ASSERT_NOT_NULL(jinstanceRoute); @@ -847,18 +982,19 @@ JNIEXPORT void JNICALL Java_com_yuanrong_jni_LibRuntime_SaveInstanceRoute(JNIEnv { std::string cobjectID = YR::jni::JNIString::FromJava(env, objectID); std::string cinstanceRoute = YR::jni::JNIString::FromJava(env, instanceRoute); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->SaveInstanceRoute(cobjectID, cinstanceRoute); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN_VOID(env, libRuntime, "exception occurred because LibRuntime is null"); + libRuntime->SaveInstanceRoute(cobjectID, cinstanceRoute); } JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_KillSync(JNIEnv *env, jclass c, jstring instanceID) { auto instID = YR::jni::JNIString::FromJava(env, instanceID); - auto result = get_runtime_context_callback(env, c); - auto rtCtx = YR::jni::JNIString::FromJava(env, (jstring)result); - auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx)->Kill( - instID, libruntime::Signal::killInstanceSync); + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + auto err = libRuntime->Kill(instID, libruntime::Signal::killInstanceSync); jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); if (jerr == nullptr) { YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Libruntime_Kill, get null"); @@ -867,6 +1003,35 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_KillSync(JNIEnv *env, return jerr; } +/* + * Class: com_yuanrong_jni_LibRuntime + * Method: Nodes + * Signature: ()Lcom/yuanrong/errorcode/Pair + */ +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_nodes(JNIEnv *env, jclass c) +{ + auto rtCtx = get_runtime_context_callback(env, c); + auto libRuntime = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime(rtCtx); + CHECK_NULL_THROW_NEW_AND_RETURN(env, libRuntime, nullptr, "exception occurred because LibRuntime is null"); + auto [err, resourceUnitVector] = libRuntime->GetResources(); + if (!err.OK()) { + YRLOG_WARN("failed to GetResources, err({}), msg({})", fmt::underlying(err.Code()), err.Msg()); + YR::jni::JNILibruntimeException::Throw( + env, err.Code(), err.MCode(), + "failed to GetResources, err " + std::to_string(err.Code()) + ", msg: " + err.Msg()); + return nullptr; + } + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + jobject jnodes = YR::jni::JNIArrayList::FromCc( + env, resourceUnitVector, [](JNIEnv *env, const YR::Libruntime::ResourceUnit &resourceUnit) { + return YR::jni::JNINode::FromCc(env, resourceUnit); + }); + jobject jpair = YR::jni::JNIPair::CreateJPair(env, jerr, jnodes); + env->DeleteLocalRef(jerr); + env->DeleteLocalRef(jnodes); + return jpair; +} + #ifdef __cplusplus } #endif diff --git a/api/java/function-common/src/main/cpp/com_yuanrong_jni_LibRuntime.h b/api/java/function-common/src/main/cpp/com_yuanrong_jni_LibRuntime.h index 0f48b73..ef6ab50 100644 --- a/api/java/function-common/src/main/cpp/com_yuanrong_jni_LibRuntime.h +++ b/api/java/function-common/src/main/cpp/com_yuanrong_jni_LibRuntime.h @@ -18,11 +18,14 @@ /* Header for class com_yuanrong_jni_LibRuntime */ #pragma once + +#include + #ifdef __cplusplus extern "C" { #endif -static jobject get_runtime_context_callback(JNIEnv *env, jclass c); +static std::string get_runtime_context_callback(JNIEnv *env, jclass c); /* * Class: com_yuanrong_jni_LibRuntime @@ -224,6 +227,21 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_GroupWait(JNIEnv *, j JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_processLog(JNIEnv *, jclass, jobject); +JNIEXPORT jlong JNICALL Java_com_yuanrong_jni_LibRuntime_CreateStreamProducer(JNIEnv *, jclass, jstring, jlong, + jlong, jlong, jboolean, jboolean, + jlong, jlong); + +JNIEXPORT jlong JNICALL Java_com_yuanrong_jni_LibRuntime_CreateStreamConsumer(JNIEnv *, jclass, jstring, jstring, + jobject, jboolean); + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_LibRuntime_DeleteStream(JNIEnv *, jclass, jstring); + +JNIEXPORT jlong JNICALL Java_com_yuanrong_jni_LibRuntime_QueryGlobalProducersNum(JNIEnv *, jclass, jstring); + +JNIEXPORT jlong JNICALL Java_com_yuanrong_jni_LibRuntime_QueryGlobalConsumersNum(JNIEnv *, jclass, jstring); + +JNIEXPORT jlong JNICALL Java_com_yuanrong_jni_LibRuntime_CreateStreamProducerWithConfig(JNIEnv *, jclass, + jstring, jobject); /* * Class: com_yuanrong_jni_LibRuntime * Method: GetInstanceRoute diff --git a/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.cpp b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.cpp new file mode 100644 index 0000000..f1141b0 --- /dev/null +++ b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.cpp @@ -0,0 +1,151 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "com_yuanrong_jni_Producer.h" +#include + +#include "jni_types.h" + +#include "src/dto/stream_conf.h" +#include "src/libruntime/err_type.h" +#include "src/libruntime/streamstore/stream_producer_consumer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBufferDefaultTimeout(JNIEnv *env, jclass, + jlong handle, + jbyteArray bytes, + jlong len) +{ + auto producer = reinterpret_cast *>(handle); + jbyte *bytekey = env->GetByteArrayElements(bytes, 0); + if (bytekey == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "An empty array is passed in the Java layer."); + return nullptr; + } else { + YR::Libruntime::Element element(reinterpret_cast(bytekey), len); + auto err = (*producer)->Send(element); + env->ReleaseByteArrayElements(bytes, bytekey, 0); + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + if (jerr == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Producer_send, get null"); + return nullptr; + } + return jerr; + } +} + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBufferDefaultTimeout(JNIEnv *env, jclass, + jlong handle, + jobject buf) +{ + auto producer = reinterpret_cast *>(handle); + auto body = env->GetDirectBufferAddress(buf); + if (!body) { + YR::jni::JNILibruntimeException::ThrowNew(env, "cannot get element address"); + return nullptr; + } + int limit = YR::jni::JNIByteBuffer::GetByteBufferLimit(env, buf); + YR::Libruntime::Element element(static_cast(body), limit); + auto err = (*producer)->Send(element); + body = NULL; + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + if (jerr == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Producer_send, get null"); + return nullptr; + } + return jerr; +} + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBuffer(JNIEnv *env, jclass, jlong handle, + jbyteArray bytes, jlong len, + jint timeoutMs) +{ + auto producer = reinterpret_cast *>(handle); + jbyte *bytekey = env->GetByteArrayElements(bytes, 0); + if (bytekey == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "An empty array is passed in the Java layer."); + return nullptr; + } else { + YR::Libruntime::Element element(reinterpret_cast(bytekey), len); + auto err = (*producer)->Send(element, timeoutMs); + env->ReleaseByteArrayElements(bytes, bytekey, 0); + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + if (jerr == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Producer_send, get null"); + return nullptr; + } + return jerr; + } +} + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBuffer(JNIEnv *env, jclass, jlong handle, + jobject buf, jint timeoutMs) +{ + auto producer = reinterpret_cast *>(handle); + auto body = env->GetDirectBufferAddress(buf); + if (!body) { + YR::jni::JNILibruntimeException::ThrowNew(env, "cannot get element address"); + return nullptr; + } + int limit = YR::jni::JNIByteBuffer::GetByteBufferLimit(env, buf); + YR::Libruntime::Element element(static_cast(body), limit); + auto err = (*producer)->Send(element, timeoutMs); + body = NULL; + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + if (jerr == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Producer_send, get null"); + return nullptr; + } + return jerr; +} + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_flush(JNIEnv *env, jclass, jlong handle) +{ + auto producer = reinterpret_cast *>(handle); + auto err = (*producer)->Flush(); + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + if (jerr == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Producer_flush, get null"); + return nullptr; + } + return jerr; +} + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_close(JNIEnv *env, jclass, jlong handle) +{ + auto producer = reinterpret_cast *>(handle); + auto err = (*producer)->Close(); + delete (producer); + jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); + if (jerr == nullptr) { + YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Producer_close, get null"); + return nullptr; + } + return jerr; +} + +JNIEXPORT void JNICALL Java_com_yuanrong_jni_JniProducer_freeJNIPtrNative(JNIEnv *, jclass, jlong handle) +{ + auto producer = reinterpret_cast *>(handle); + delete producer; +} + +#ifdef __cplusplus +} +#endif diff --git a/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.h b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.h new file mode 100644 index 0000000..8b1d099 --- /dev/null +++ b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#pragma once +#ifdef __cplusplus +extern "C" { +#endif + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBufferDefaultTimeout(JNIEnv *, jclass, jlong, + jbyteArray, jlong); + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBufferDefaultTimeout(JNIEnv *, jclass, + jlong, jobject); + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBuffer(JNIEnv *, jclass, jlong, jbyteArray, + jlong, jint); + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBuffer(JNIEnv *, jclass, jlong, jobject, + jint); + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_flush(JNIEnv *, jclass, jlong); + +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_close(JNIEnv *, jclass, jlong); + +JNIEXPORT void JNICALL Java_com_yuanrong_jni_JniProducer_freeJNIPtrNative(JNIEnv *, jclass, jlong); + +#ifdef __cplusplus +} +#endif diff --git a/api/java/function-common/src/main/cpp/jni_errorinfo.cpp b/api/java/function-common/src/main/cpp/jni_errorinfo.cpp index 03f7c2f..ab3b342 100644 --- a/api/java/function-common/src/main/cpp/jni_errorinfo.cpp +++ b/api/java/function-common/src/main/cpp/jni_errorinfo.cpp @@ -61,8 +61,13 @@ jobject JNIErrorInfo::FromCc(JNIEnv *env, const YR::Libruntime::ErrorInfo &error } jobject jstackTraceInfos = JNIStackTraceInfo::ListFromCc(env, errorInfo.GetStackTraceInfos()); + jobject jerrorInfo = (jobject)env->NewObject(clz_, init_, jerrorCode, jmoduleCode, jmsg, jstackTraceInfos); - return (jobject)env->NewObject(clz_, init_, jerrorCode, jmoduleCode, jmsg, jstackTraceInfos); + env->DeleteLocalRef(jmsg); + env->DeleteLocalRef(jerrorCode); + env->DeleteLocalRef(jmoduleCode); + env->DeleteLocalRef(jstackTraceInfos); + return jerrorInfo; } YR::Libruntime::ErrorInfo JNIErrorInfo::FromJava(JNIEnv *env, jobject o) @@ -71,8 +76,13 @@ YR::Libruntime::ErrorInfo JNIErrorInfo::FromJava(JNIEnv *env, jobject o) std::string cmsg = JNIString::FromJava(env, jstr); env->DeleteLocalRef(jstr); - YR::Libruntime::ErrorCode errorCode = JNIErrorCode::FromJava(env, env->CallObjectMethod(o, getCode_)); - YR::Libruntime::ModuleCode moduleCode = JNIModuleCode::FromJava(env, env->CallObjectMethod(o, getMCode_)); + jobject errorCodeObj = env->CallObjectMethod(o, getCode_); + YR::Libruntime::ErrorCode errorCode = JNIErrorCode::FromJava(env, errorCodeObj); + env->DeleteLocalRef(errorCodeObj); + + jobject moduleCodeObj = env->CallObjectMethod(o, getMCode_); + YR::Libruntime::ModuleCode moduleCode = JNIModuleCode::FromJava(env, moduleCodeObj); + env->DeleteLocalRef(moduleCodeObj); jobject objList = env->CallObjectMethod(o, getStackTraceInfos_); std::vector stackTraceInfos = JNIStackTraceInfo::ListFromJava(env, objList); diff --git a/api/java/function-common/src/main/cpp/jni_function_meta.cpp b/api/java/function-common/src/main/cpp/jni_function_meta.cpp index 845b2d6..685f952 100644 --- a/api/java/function-common/src/main/cpp/jni_function_meta.cpp +++ b/api/java/function-common/src/main/cpp/jni_function_meta.cpp @@ -82,6 +82,13 @@ jobject JNIFunctionMeta::FromCc(JNIEnv *env, const YR::Libruntime::FunctionMeta jobject japiType = JNIApiType::FromCc(env, meta.apiType); jobject obj = env->CallStaticObjectMethod(factoryClz_, init_, jappName, jmoduleName, jfuncName, jclassName, jlanguage, japiType, jsignature); + env->DeleteLocalRef(jappName); + env->DeleteLocalRef(jmoduleName); + env->DeleteLocalRef(jfuncName); + env->DeleteLocalRef(jclassName); + env->DeleteLocalRef(jlanguage); + env->DeleteLocalRef(jsignature); + env->DeleteLocalRef(japiType); return obj; } diff --git a/api/java/function-common/src/main/cpp/jni_init.cpp b/api/java/function-common/src/main/cpp/jni_init.cpp index 9d4665e..518297b 100644 --- a/api/java/function-common/src/main/cpp/jni_init.cpp +++ b/api/java/function-common/src/main/cpp/jni_init.cpp @@ -19,6 +19,7 @@ JavaVM *jvm; +/// Load and cache frequently-used Java classes and methods jint JNI_OnLoad(JavaVM *vm, void *reserved) { JNIEnv *env; @@ -44,10 +45,12 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) YR::jni::JNIList::Init(env); YR::jni::JNIArrayList::Init(env); YR::jni::JNIByteBuffer::Init(env); + YR::jni::JNIFloat::Init(env); YR::jni::JNIIterator::Init(env); YR::jni::JNIMapEntry::Init(env); YR::jni::JNISet::Init(env); YR::jni::JNIMap::Init(env); + YR::jni::JNIHashMap::Init(env); YR::jni::JNIApiType::Init(env); YR::jni::JNILanguageType::Init(env); YR::jni::JNIInvokeType::Init(env); @@ -79,6 +82,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) YR::jni::JNIStackTraceElement::Init(env); YR::jni::JNIYRAutoInitInfo::Init(env); YR::jni::JNIFunctionLog::Init(env); + YR::jni::JNIProducerConfig::Init(env); + YR::jni::JNINode::Init(env); if (isOnCloud) { YR::jni::JNIReturnType::Init(env); YR::jni::JNICodeLoader::Init(env); @@ -90,6 +95,7 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) return CURRENT_JNI_VERSION; } +/// Load and cache frequently-used Java classes and methods void JNI_OnUnload(JavaVM *vm, void *reserved) { JNIEnv *env; @@ -98,12 +104,14 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) YR::jni::JNIString::Recycle(env); YR::jni::JNIApacheCommonsExceptionUtils::Recycle(env); YR::jni::JNIList::Recycle(env); + YR::jni::JNIFloat::Recycle(env); YR::jni::JNIArrayList::Recycle(env); YR::jni::JNIByteBuffer::Recycle(env); YR::jni::JNIIterator::Recycle(env); YR::jni::JNIMapEntry::Recycle(env); YR::jni::JNISet::Recycle(env); YR::jni::JNIMap::Recycle(env); + YR::jni::JNIHashMap::Recycle(env); YR::jni::JNIInvokeType::Recycle(env); YR::jni::JNIApiType::Recycle(env); YR::jni::JNILanguageType::Recycle(env); @@ -137,4 +145,6 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) YR::jni::JNIStackTraceElement::Recycle(env); YR::jni::JNIYRAutoInitInfo::Recycle(env); YR::jni::JNIFunctionLog::Recycle(env); + YR::jni::JNIProducerConfig::Recycle(env); + YR::jni::JNINode::Recycle(env); } \ No newline at end of file diff --git a/api/java/function-common/src/main/cpp/jni_stacktrace_element.cpp b/api/java/function-common/src/main/cpp/jni_stacktrace_element.cpp index 570694f..9992383 100644 --- a/api/java/function-common/src/main/cpp/jni_stacktrace_element.cpp +++ b/api/java/function-common/src/main/cpp/jni_stacktrace_element.cpp @@ -75,7 +75,11 @@ jobject JNIStackTraceElement::FromCc(JNIEnv *env, const YR::Libruntime::StackTra return nullptr; } - return (jobject)env->NewObject(clz_, init_, jclassName, jmethodName, jfileName, lineNumberInt); + jobject jStackTraceElement = env->NewObject(clz_, init_, jclassName, jmethodName, jfileName, lineNumberInt); + env->DeleteLocalRef(jclassName); + env->DeleteLocalRef(jmethodName); + env->DeleteLocalRef(jfileName); + return jStackTraceElement; } YR::Libruntime::StackTraceElement JNIStackTraceElement::FromJava(JNIEnv *env, jobject obj) diff --git a/api/java/function-common/src/main/cpp/jni_stacktrace_info.cpp b/api/java/function-common/src/main/cpp/jni_stacktrace_info.cpp index 5bc9688..2d161ec 100644 --- a/api/java/function-common/src/main/cpp/jni_stacktrace_info.cpp +++ b/api/java/function-common/src/main/cpp/jni_stacktrace_info.cpp @@ -60,23 +60,33 @@ jobject JNIStackTraceInfo::FromCc(JNIEnv *env, const YR::Libruntime::StackTraceI return nullptr; } - return (jobject)env->NewObject(clz_, init_, jtype, jmessage, jstackTraceElements, jlanguage); + jobject jstackTraceInfo = (jobject)env->NewObject(clz_, init_, jtype, jmessage, jstackTraceElements, jlanguage); + + env->DeleteLocalRef(jtype); + env->DeleteLocalRef(jmessage); + env->DeleteLocalRef(jlanguage); + env->DeleteLocalRef(jstackTraceElements); + return jstackTraceInfo; } YR::Libruntime::StackTraceInfo JNIStackTraceInfo::FromJava(JNIEnv *env, jobject o) { jstring jtype = static_cast(env->CallObjectMethod(o, getType_)); std::string ctype = JNIString::FromJava(env, jtype); + env->DeleteLocalRef(jtype); jstring jmessage = static_cast(env->CallObjectMethod(o, getMessage_)); std::string cmessage = JNIString::FromJava(env, jmessage); + env->DeleteLocalRef(jmessage); jstring jlanguage = static_cast(env->CallObjectMethod(o, getLanguage_)); std::string clanguage = JNIString::FromJava(env, jlanguage); + env->DeleteLocalRef(jlanguage); jobject objList = env->CallObjectMethod(o, getStackTraceElements_); std::vector stackTraceElements = JNIStackTraceElement::ListFromJava(env, objList); + env->DeleteLocalRef(objList); YR::Libruntime::StackTraceInfo stackTraceInfo(ctype, cmessage, stackTraceElements, clanguage); return stackTraceInfo; } diff --git a/api/java/function-common/src/main/cpp/jni_types.cpp b/api/java/function-common/src/main/cpp/jni_types.cpp index d16b1bf..a8b0e7e 100644 --- a/api/java/function-common/src/main/cpp/jni_types.cpp +++ b/api/java/function-common/src/main/cpp/jni_types.cpp @@ -53,6 +53,8 @@ const int INSTANCE_PREFERRED_ANTI = 22; const int INSTANCE_REQUIRED = 23; const int INSTANCE_REQUIRED_ANTI = 24; const int MAX_PASSWD_LENGTH = 100; +const int POD = 1; +const int NODE = 2; inline jfieldID GetJStaticField(JNIEnv *env, const jclass &clz, const std::string &fieldName, const std::string &sig) { @@ -130,6 +132,9 @@ void JNILibruntimeException::Throw(JNIEnv *env, const YR::Libruntime::ErrorCode jthrowable exception = (jthrowable)env->NewObject(clz_, constructorId, jerrorCode, jmoduleCode, jmessage); env->Throw(exception); + env->DeleteLocalRef(jerrorCode); + env->DeleteLocalRef(jmoduleCode); + env->DeleteLocalRef(jmessage); } void JNIString::Init(JNIEnv *env) {} @@ -171,6 +176,25 @@ std::unordered_set JNIString::FromJArrayToUnorderedSet(JNIEnv *env, }); } +void JNIFloat::Init(JNIEnv *env) +{ + clz_ = LoadClass(env, "java/lang/Float"); + jmInit_ = env->GetMethodID(clz_, "", "(F)V"); +} + +void JNIFloat::Recycle(JNIEnv *env) +{ + if (clz_) { + env->DeleteGlobalRef(clz_); + } +} + +jobject JNIFloat::FromCc(JNIEnv *env, const float &arg) +{ + float val = arg; + return env->NewObject(clz_, jmInit_, val); +} + void JNIList::Init(JNIEnv *env) { clz_ = LoadClass(env, "java/util/List"); @@ -254,8 +278,10 @@ std::string JNIApacheCommonsExceptionUtils::GetStackTrace(JNIEnv *env, jthrowabl env->ReleaseStringUTFChars(jst, cstr); if (env->ExceptionOccurred()) { YRLOG_ERROR("Exception occurred when convert exception info to C string."); + env->DeleteLocalRef(jst); return "exception occurred when convert exception info to C string"; } + env->DeleteLocalRef(jst); return result; } @@ -677,6 +703,7 @@ void JNILibRuntimeConfig::Init(JNIEnv *env) jmGetRuntimePrivateKeyContextPath_ = GetJMethod(env, clz_, "getRuntimePrivateKeyContextPath", "()Ljava/lang/String;"); jmGetVerifyFilePath_ = GetJMethod(env, clz_, "getVerifyFilePath", "()Ljava/lang/String;"); + jmGetPrivateKeyPaaswd_ = GetJMethod(env, clz_, "getPrivateKeyPaaswd", "()Ljava/lang/String;"); jmGetServerName_ = GetJMethod(env, clz_, "getServerName", "()Ljava/lang/String;"); jmGetFunctionSystemIpAddr_ = GetJMethod(env, clz_, "getFunctionSystemIpAddr", "()Ljava/lang/String;"); jmGetFunctionSystemPort_ = GetJMethod(env, clz_, "getFunctionSystemPort", "()I"); @@ -701,6 +728,8 @@ void JNILibRuntimeConfig::Init(JNIEnv *env) jmGetMaxConcurrencyCreateNum_ = GetJMethod(env, clz_, "getMaxConcurrencyCreateNum", "()I"); jmGetThreadPoolSize_ = GetJMethod(env, clz_, "getThreadPoolSize", "()I"); jmGetLoadPaths_ = GetJMethod(env, clz_, "getLoadPaths", "()Ljava/util/List;"); + jGetHttpIocThreadsNum_ = GetJMethod(env, clz_, "getHttpIocThreadsNum", "()I"); + jGetHttpIdleTime_ = GetJMethod(env, clz_, "getHttpIdleTime", "()I"); jGetRpcTimeout_ = GetJMethod(env, clz_, "getRpcTimeout", "()I"); jGetTenantId_ = GetJMethod(env, clz_, "getTenantId", "()Ljava/lang/String;"); jGetNs_ = GetJMethod(env, clz_, "getNs", "()Ljava/lang/String;"); @@ -781,7 +810,19 @@ YR::Libruntime::LibruntimeConfig JNILibRuntimeConfig::FromJava(JNIEnv *env, cons [](JNIEnv *env, const jobject &vo) -> std::string { return JNIString::FromJava(env, static_cast(vo)); }); + + jstring jStrPasswd = static_cast(env->CallObjectMethod(meta, jmGetPrivateKeyPaaswd_)); + if (jStrPasswd != nullptr) { + const char *passwd = env->GetStringUTFChars(jStrPasswd, nullptr); + if (passwd != nullptr) { + size_t passwdLen = strlen(passwd) + 1; + memcpy_s(libConfig.privateKeyPaaswd, passwdLen, passwd, passwdLen); + env->ReleaseStringUTFChars(jStrPasswd, passwd); + } + } libConfig.inCluster = static_cast(env->CallBooleanMethod(meta, jmIsInCluster_)); + libConfig.httpIocThreadsNum = static_cast(env->CallIntMethod(meta, jGetHttpIocThreadsNum_)); + libConfig.httpIdleTime = env->CallIntMethod(meta, jGetHttpIdleTime_); libConfig.rpcTimeout = static_cast(env->CallIntMethod(meta, jGetRpcTimeout_)); auto codePath = JNIList::FromJava( @@ -828,6 +869,36 @@ std::unordered_map JNIMap::FromJava(JNIEnv *env, const jobject &jmap, return cmap; } +void JNIHashMap::Init(JNIEnv *env) +{ + clz_ = LoadClass(env, "java/util/HashMap"); + init_ = GetJMethod(env, clz_, "", "()V"); + jmPut_ = GetJMethod(env, clz_, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"); +} + +void JNIHashMap::Recycle(JNIEnv *env) +{ + if (clz_) { + env->DeleteGlobalRef(clz_); + } +} + +template +jobject JNIHashMap::FromCc(JNIEnv *env, const std::unordered_map &map, + std::function converterK, std::function converterV) +{ + jobject hashMap = env->NewObject(clz_, init_); + for (const auto& kv : map) { + auto tmpKey = converterK(env, kv.first); + auto tmpValue = converterV(env, kv.second); + env->CallObjectMethod(hashMap, jmPut_, tmpKey, tmpValue); + + env->DeleteLocalRef(tmpKey); + env->DeleteLocalRef(tmpValue); + } + return hashMap; +} + void JNISet::Init(JNIEnv *env) { clz_ = LoadClass(env, "java/util/Set"); @@ -1479,7 +1550,7 @@ jobject JNIErrorCode::FromCc(JNIEnv *env, const YR::Libruntime::ErrorCode &error }; if (auto it = fieldMap.find(errorCode); it == fieldMap.end()) { - YRLOG_ERROR("Failed to match errorcode, code: ", errorCode); + YRLOG_ERROR("Failed to match errorcode, code: ", fmt::underlying(errorCode)); return nullptr; } return env->NewObject(clz_, jfInitWithInt_, static_cast(fieldMap[errorCode])); @@ -1605,6 +1676,7 @@ void JNIAffinity::Init(JNIEnv *env) clz_ = LoadClass(env, "com/yuanrong/affinity/Affinity"); getValue_ = GetJMethod(env, clz_, "getAffinityValue", "()I"); getOperators_ = GetJMethod(env, clz_, "getLabelOperators", "()Ljava/util/List;"); + getAffinityScopeValue_ = GetJMethod(env, clz_, "getAffinityScopeValue", "()I"); } std::shared_ptr JNIAffinity::FromJava(JNIEnv *env, jobject o, bool preferredPriority, @@ -1653,6 +1725,18 @@ std::shared_ptr JNIAffinity::FromJava(JNIEnv *env, job affinity->SetRequiredPriority(requiredPriority); affinity->SetPreferredAntiOtherLabels(preferredAntiOtherLabels); } + int affinityScope = static_cast(env->CallIntMethod(o, getAffinityScopeValue_)); + switch (affinityScope) { + case POD: + affinity->SetAffinityScope(YR::Libruntime::AFFINITYSCOPE_POD); + break; + case NODE: + affinity->SetAffinityScope(YR::Libruntime::AFFINITYSCOPE_NODE); + break; + default: + YRLOG_DEBUG("affinityScope is not set."); + break; + } return affinity; } @@ -2080,8 +2164,13 @@ jobject JNIInternalWaitResult::FromCc(JNIEnv *env, const std::shared_ptrfirst); jobject err = JNIErrorInfo::FromCc(env, it->second); env->CallObjectMethod(jmap, mPut_, key, err); + env->DeleteLocalRef(key); + env->DeleteLocalRef(err); } jobject internalWaitRes = env->NewObject(clz_, init_, jreadyList, junreadyList, jmap); + env->DeleteLocalRef(jreadyList); + env->DeleteLocalRef(junreadyList); + env->DeleteLocalRef(jmap); return internalWaitRes; } @@ -2140,11 +2229,16 @@ void JNIYRAutoInitInfo::Recycle(JNIEnv *env) jobject JNIYRAutoInitInfo::FromCc(JNIEnv *env, YR::Libruntime::ClusterAccessInfo info) { - jobject yrAutoInitInfo = env->NewObject(clz_, init_, env->NewStringUTF(info.serverAddr.c_str()), - env->NewStringUTF(info.dsAddr.c_str()), info.inCluster); + jstring jServerAddr = env->NewStringUTF(info.serverAddr.c_str()); + jstring jDsAddr = env->NewStringUTF(info.dsAddr.c_str()); + + jobject yrAutoInitInfo = env->NewObject(clz_, init_, jServerAddr, jDsAddr, info.inCluster); if (yrAutoInitInfo == nullptr) { YRLOG_WARN("Failed to create Java object of com/yuanrong/jni/YRAutoInitInfo"); } + + env->DeleteLocalRef(jServerAddr); + env->DeleteLocalRef(jDsAddr); return yrAutoInitInfo; } @@ -2298,5 +2392,152 @@ int JNIFunctionLog::GetErrorCode(JNIEnv *env, jobject obj) int result = static_cast(value); return result; } + +void JNIProducerConfig::Init(JNIEnv *env) +{ + clz_ = LoadClass(env, "com/yuanrong/stream/ProducerConfig"); + jmGetDelayFlushTimeMs_ = GetJMethod(env, clz_, "getDelayFlushTimeMs", "()J"); + jmGetPageSizeByte_ = GetJMethod(env, clz_, "getPageSizeByte", "()J"); + jmGetMaxStreamSize_ = GetJMethod(env, clz_, "getMaxStreamSize", "()J"); + jmGetAutoCleanup_ = GetJMethod(env, clz_, "isAutoCleanup", "()Z"); + jmGetEncryptStream_ = GetJMethod(env, clz_, "isEncryptStream", "()Z"); + jmGetRetainForNumConsumers_ = GetJMethod(env, clz_, "getRetainForNumConsumers", "()J"); + jmGetReserveSize_ = GetJMethod(env, clz_, "getReserveSize", "()J"); + jmGetExtendConfig_ = GetJMethod(env, clz_, "getExtendConfig", "()Ljava/util/Map;"); +} + +void JNIProducerConfig::Recycle(JNIEnv *env) +{ + if (clz_) { + env->DeleteGlobalRef(clz_); + } +} + +YR::Libruntime::ProducerConf JNIProducerConfig::FromJava(JNIEnv *env, jobject obj) +{ + YR::Libruntime::ProducerConf producerConf; + producerConf.delayFlushTime = GetDelayFlushTimeMs(env, obj); + producerConf.pageSize = GetPageSizeByte(env, obj); + producerConf.maxStreamSize = GetMaxStreamSize(env, obj); + producerConf.autoCleanup = GetAutoCleanup(env, obj); + producerConf.encryptStream = GetEncryptStream(env, obj); + producerConf.retainForNumConsumers = GetRetainForNumConsumers(env, obj); + producerConf.reserveSize = GetReserveSize(env, obj); + producerConf.extendConfig = GetExtendConfig(env, obj); + return producerConf; +} + +int64_t JNIProducerConfig::GetDelayFlushTimeMs(JNIEnv *env, jobject obj) +{ + jlong value = static_cast(env->CallLongMethod(obj, jmGetDelayFlushTimeMs_)); + int64_t result = static_cast(value); + return result; +} + +int64_t JNIProducerConfig::GetPageSizeByte(JNIEnv *env, jobject obj) +{ + jlong value = static_cast(env->CallLongMethod(obj, jmGetPageSizeByte_)); + int64_t result = static_cast(value); + return result; +} + +uint64_t JNIProducerConfig::GetMaxStreamSize(JNIEnv *env, jobject obj) +{ + jlong value = static_cast(env->CallLongMethod(obj, jmGetMaxStreamSize_)); + uint64_t result = static_cast(value); + return result; +} + +bool JNIProducerConfig::GetAutoCleanup(JNIEnv *env, jobject obj) +{ + jboolean value = static_cast(env->CallBooleanMethod(obj, jmGetAutoCleanup_)); + bool result = static_cast(value); + return result; +} + +bool JNIProducerConfig::GetEncryptStream(JNIEnv *env, jobject obj) +{ + jboolean value = static_cast(env->CallBooleanMethod(obj, jmGetEncryptStream_)); + bool result = static_cast(value); + return result; +} + +uint64_t JNIProducerConfig::GetRetainForNumConsumers(JNIEnv *env, jobject obj) +{ + jlong value = static_cast(env->CallLongMethod(obj, jmGetRetainForNumConsumers_)); + uint64_t result = static_cast(value); + return result; +} + +uint64_t JNIProducerConfig::GetReserveSize(JNIEnv *env, jobject obj) +{ + jlong value = static_cast(env->CallLongMethod(obj, jmGetReserveSize_)); + uint64_t result = static_cast(value); + return result; +} + +std::unordered_map JNIProducerConfig::GetExtendConfig(JNIEnv *env, jobject obj) +{ + std::unordered_map result = JNIMap::FromJava( + env, env->CallObjectMethod(obj, jmGetExtendConfig_), + [](JNIEnv *env, const jobject &ko) -> std::string { + return JNIString::FromJava(env, static_cast(ko)); + }, + [](JNIEnv *env, const jobject &vo) -> std::string { + return JNIString::FromJava(env, static_cast(vo)); + }); + return result; +} + +void JNINode::Init(JNIEnv *env) +{ + clz_ = LoadClass(env, "com/yuanrong/api/Node"); + init_ = GetJMethod(env, clz_, "", "()V"); + jmInit_ = GetJMethod(env, clz_, "", "(Ljava/lang/String;ZLjava/util/Map;Ljava/util/Map;)V"); +} + +void JNINode::Recycle(JNIEnv *env) +{ + if (clz_) { + env->DeleteGlobalRef(clz_); + } +} + +jobject JNINode::GetResourcesFromResourceUnit(JNIEnv *env, const YR::Libruntime::ResourceUnit &resourceUnit) +{ + return JNIHashMap::FromCc(env, resourceUnit.capacity, + [](JNIEnv *env, const std::string &key) { return env->NewStringUTF(key.c_str()); }, + [](JNIEnv *env, const float &value) { return JNIFloat::FromCc(env, value); } + ); +} + + +jobject JNINode::GetLabelsFromResourceUnit(JNIEnv *env, const YR::Libruntime::ResourceUnit &resourceUnit) +{ + return JNIHashMap::FromCc>(env, resourceUnit.nodeLabels, + [](JNIEnv *env, const std::string &key) { + return env->NewStringUTF(key.c_str()); + }, + [](JNIEnv *env, const std::vector &value) { + return JNIArrayList::FromCc( + env, value, [](JNIEnv *env, const std::string &s) { return env->NewStringUTF(s.c_str()); } + ); + } + ); +} + +jobject JNINode::FromCc(JNIEnv *env, const YR::Libruntime::ResourceUnit &resourceUnit) +{ + jstring jId = JNIString::FromCc(env, resourceUnit.id); + jboolean jAlive = (resourceUnit.status == 0) ? JNI_TRUE : JNI_FALSE; + jobject jResources = JNINode::GetResourcesFromResourceUnit(env, resourceUnit); + jobject jLabels = JNINode::GetLabelsFromResourceUnit(env, resourceUnit); + jobject result = env->NewObject(clz_, jmInit_, jId, jAlive, jResources, jLabels); + + env->DeleteLocalRef(jId); + env->DeleteLocalRef(jResources); + env->DeleteLocalRef(jLabels); + return result; +} } // namespace jni } // namespace YR diff --git a/api/java/function-common/src/main/cpp/jni_types.h b/api/java/function-common/src/main/cpp/jni_types.h index 395fa23..d8a5aae 100644 --- a/api/java/function-common/src/main/cpp/jni_types.h +++ b/api/java/function-common/src/main/cpp/jni_types.h @@ -38,7 +38,9 @@ #include "src/dto/internal_wait_result.h" #include "src/dto/invoke_arg.h" #include "src/dto/invoke_options.h" +#include "src/dto/resource_unit.h" #include "src/dto/status.h" +#include "src/dto/stream_conf.h" #include "src/libruntime/auto_init.h" #include "src/libruntime/err_type.h" #include "src/libruntime/libruntime_config.h" @@ -119,6 +121,20 @@ using FunctionLog = ::libruntime::FunctionLog; } \ } while (false) +#define CHECK_NULL_THROW_NEW_AND_RETURN(env, ptr, returnValue, msg) \ + if ((ptr) == nullptr) { \ + YR::jni::JNILibruntimeException::Throw(env, YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, \ + YR::Libruntime::ModuleCode::RUNTIME, std::string(msg)); \ + return returnValue; \ + } + +#define CHECK_NULL_THROW_NEW_AND_RETURN_VOID(env, ptr, msg) \ + if ((ptr) == nullptr) { \ + YR::jni::JNILibruntimeException::Throw(env, YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, \ + YR::Libruntime::ModuleCode::RUNTIME, std::string(msg)); \ + return; \ + } + inline jclass LoadClass(JNIEnv *env, const std::string &className) { jclass tempLocalClassRef = env->FindClass(className.c_str()); @@ -188,6 +204,18 @@ private: inline static jclass clz_ = nullptr; }; +class JNIFloat { +public: + static void Init(JNIEnv *env); + static void Recycle(JNIEnv *env); + + static jobject FromCc(JNIEnv *env, const float &arg); + +private: + inline static jclass clz_ = nullptr; + inline static jmethodID jmInit_ = nullptr; +}; + class JNIList { public: static void Init(JNIEnv *env); @@ -252,6 +280,7 @@ public: for (const T &ele : vect) { auto eleTmp = converter(env, ele); JNIList::Add(env, jlst, eleTmp); + env->DeleteLocalRef(eleTmp); } return jlst; } @@ -374,6 +403,20 @@ private: inline static jmethodID jmEntrySet_ = nullptr; }; +class JNIHashMap { +public: + static void Init(JNIEnv *env); + static void Recycle(JNIEnv *env); + template + static jobject FromCc(JNIEnv *env, const std::unordered_map &map, + std::function converterK, std::function converterV); + +private: + inline static jclass clz_ = nullptr; + inline static jmethodID init_ = nullptr; + inline static jmethodID jmPut_ = nullptr; +}; + class JNIInvokeType { public: static void Init(JNIEnv *env); @@ -459,6 +502,7 @@ private: inline static jmethodID jmGetRuntimePublicKeyContextPath_ = nullptr; inline static jmethodID jmGetRuntimePrivateKeyContextPath_ = nullptr; inline static jmethodID jmGetVerifyFilePath_ = nullptr; + inline static jmethodID jmGetPrivateKeyPaaswd_ = nullptr; inline static jmethodID jmGetServerName_ = nullptr; inline static jmethodID jmGetFunctionSystemIpAddr_ = nullptr; inline static jmethodID jmGetFunctionSystemPort_ = nullptr; @@ -482,6 +526,8 @@ private: inline static jmethodID jmGetMaxConcurrencyCreateNum_ = nullptr; inline static jmethodID jmGetThreadPoolSize_ = nullptr; inline static jmethodID jmGetLoadPaths_ = nullptr; + inline static jmethodID jGetHttpIocThreadsNum_ = nullptr; + inline static jmethodID jGetHttpIdleTime_ = nullptr; inline static jmethodID jGetRpcTimeout_ = nullptr; inline static jmethodID jGetTenantId_ = nullptr; inline static jmethodID jGetNs_ = nullptr; @@ -677,6 +723,7 @@ private: inline static jclass clz_ = nullptr; inline static jmethodID getValue_ = nullptr; inline static jmethodID getOperators_ = nullptr; + inline static jmethodID getAffinityScopeValue_ = nullptr; }; class JNIReturnType { @@ -911,5 +958,45 @@ private: inline static jmethodID jmIsFinish_ = nullptr; }; +class JNIProducerConfig { +public: + static void Init(JNIEnv *env); + static void Recycle(JNIEnv *env); + static YR::Libruntime::ProducerConf FromJava(JNIEnv *env, jobject obj); + static int64_t GetDelayFlushTimeMs(JNIEnv *env, jobject o); + static int64_t GetPageSizeByte(JNIEnv *env, jobject obj); + static uint64_t GetMaxStreamSize(JNIEnv *env, jobject obj); + static bool GetAutoCleanup(JNIEnv *env, jobject obj); + static bool GetEncryptStream(JNIEnv *env, jobject obj); + static uint64_t GetRetainForNumConsumers(JNIEnv *env, jobject obj); + static uint64_t GetReserveSize(JNIEnv *env, jobject obj); + static std::unordered_map GetExtendConfig(JNIEnv *env, jobject obj); + +private: + inline static jclass clz_ = nullptr; + inline static jmethodID jmGetDelayFlushTimeMs_ = nullptr; + inline static jmethodID jmGetPageSizeByte_ = nullptr; + inline static jmethodID jmGetMaxStreamSize_ = nullptr; + inline static jmethodID jmGetAutoCleanup_ = nullptr; + inline static jmethodID jmGetEncryptStream_ = nullptr; + inline static jmethodID jmGetRetainForNumConsumers_ = nullptr; + inline static jmethodID jmGetReserveSize_ = nullptr; + inline static jmethodID jmGetExtendConfig_ = nullptr; +}; + +class JNINode { +public: + static void Init(JNIEnv *env); + static void Recycle(JNIEnv *env); + static jobject GetResourcesFromResourceUnit(JNIEnv *env, const YR::Libruntime::ResourceUnit &resourceUnit); + static jobject GetLabelsFromResourceUnit(JNIEnv *env, const YR::Libruntime::ResourceUnit &resourceUnit); + static jobject FromCc(JNIEnv *env, const YR::Libruntime::ResourceUnit &resourceUnit); + +private: + inline static jclass clz_ = nullptr; + inline static jmethodID init_ = nullptr; + inline static jmethodID jmInit_ = nullptr; +}; + } // namespace jni } // namespace YR diff --git a/api/java/function-common/src/main/java/com/services/enums/FaasErrorCode.java b/api/java/function-common/src/main/java/com/services/enums/FaasErrorCode.java new file mode 100644 index 0000000..e7b81de --- /dev/null +++ b/api/java/function-common/src/main/java/com/services/enums/FaasErrorCode.java @@ -0,0 +1,73 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.services.enums; + +/** + * The faaS error code + * + * @since 2024-07-20 + */ +public enum FaasErrorCode { + NONE_ERROR(0, "no error"), + + FAAS_ERROR(500, "error in faaS executor"), + + FAAS_INIT_ERROR(6001, "illegal argument exception"), + + ENTRY_NOT_FOUND(4001, "user entry not found"), + + FUNCTION_RUN_ERROR(4002, "user function failed to run"), + + RESPONSE_EXCEED_LIMIT(4004, "response body size exceeds the limit of 6291456"), + + INITIALIZE_FUNCTION_ERROR(4009, "function initialization exception"), + + INVOKE_FUNCTION_TIMEOUT(4010, "invoke timed out"), + + INIT_FUNCTION_FAIL(4201, "function initialization failed"), + + INIT_FUNCTION_TIMEOUT(4211, "runtime initialization timed out"), + + REQUEST_BODY_EXCEED_LIMIT(4140, "request body exceeds limit"); + + private final int code; + + private final String errorMessage; + + FaasErrorCode(int code, String errorMessage) { + this.code = code; + this.errorMessage = errorMessage; + } + + /** + * getCode + * + * @return code + */ + public int getCode() { + return code; + } + + /** + * getErrorMessage + * + * @return errorMessage + */ + public String getErrorMessage() { + return errorMessage; + } +} diff --git a/api/java/function-common/src/main/java/com/services/exception/FaaSException.java b/api/java/function-common/src/main/java/com/services/exception/FaaSException.java new file mode 100644 index 0000000..6cd2794 --- /dev/null +++ b/api/java/function-common/src/main/java/com/services/exception/FaaSException.java @@ -0,0 +1,85 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.services.exception; + +import java.util.Objects; + +/** + * The faaS executor exception + * + * @since 2024-07-06 + */ +public class FaaSException extends Exception { + private final String errorMessage; + + /** + * constructor + * + * @param message message + * @param exception exception + */ + public FaaSException(String message, Throwable exception) { + super(exception); + this.errorMessage = message; + } + + /** + * constructor + * + * @param message message + */ + public FaaSException(String message) { + this.errorMessage = message; + } + + /** + * getMessage + * + * @return errorMessage + */ + public String getMessage() { + return this.errorMessage; + } + + /** + * equals + * + * @param obj the Object + * @return boolean + */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FaaSException that = (FaaSException) obj; + return errorMessage.equals(that.errorMessage); + } + + /** + * hashCode + * + * @return hash value + */ + @Override + public int hashCode() { + return Objects.hash(errorMessage); + } +} diff --git a/api/java/function-common/src/main/java/com/services/logger/UserFunctionLogger.java b/api/java/function-common/src/main/java/com/services/logger/UserFunctionLogger.java index 8465de7..6df292c 100644 --- a/api/java/function-common/src/main/java/com/services/logger/UserFunctionLogger.java +++ b/api/java/function-common/src/main/java/com/services/logger/UserFunctionLogger.java @@ -18,12 +18,9 @@ package com.services.logger; import com.services.runtime.RuntimeLogger; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Set; /** * UserFunctionLogger @@ -34,15 +31,9 @@ public class UserFunctionLogger implements RuntimeLogger { /** * userLogger name at log4j2.xml */ - private static final Logger logger = LogManager.getLogger("userLogger"); + private static final Logger logger = LoggerFactory.getLogger(UserFunctionLogger.class); - private static final Set LOG_LEVELS = new HashSet<>(); - - static { - LOG_LEVELS.addAll(Arrays.asList("DEBUG", "INFO", "WARN", "ERROR")); - } - - private String logLevel = "INFO"; + private static String logLevel = "INFO"; /** * log with msg @@ -55,6 +46,9 @@ public class UserFunctionLogger implements RuntimeLogger { case "DEBUG": this.debug(msg); break; + case "INFO": + this.info(msg); + break; case "WARN": this.warn(msg); break; @@ -113,7 +107,7 @@ public class UserFunctionLogger implements RuntimeLogger { */ @Override public void setLevel(String level) { - if (LOG_LEVELS.contains(level)) { + if (level.equals("DEBUG") || level.equals("INFO") || level.equals("WARN") || level.equals("ERROR")) { this.logLevel = level; } } diff --git a/api/java/function-common/src/main/java/com/services/model/Response.java b/api/java/function-common/src/main/java/com/services/model/Response.java index 313550c..403710e 100644 --- a/api/java/function-common/src/main/java/com/services/model/Response.java +++ b/api/java/function-common/src/main/java/com/services/model/Response.java @@ -33,6 +33,8 @@ public abstract class Response { private String logResult; + private long userFuncTime; + /** * setBody * diff --git a/api/java/function-common/src/main/java/com/services/runtime/Context.java b/api/java/function-common/src/main/java/com/services/runtime/Context.java index 34b7076..ff13f66 100644 --- a/api/java/function-common/src/main/java/com/services/runtime/Context.java +++ b/api/java/function-common/src/main/java/com/services/runtime/Context.java @@ -41,6 +41,34 @@ public interface Context { */ int getRemainingTimeInMilliSeconds(); + /** + * Get the Access key information of the tenant + * + * @return access key + */ + String getAccessKey(); + + /** + * Get the Secret Key information of the tenant + * + * @return secret key + */ + String getSecretKey(); + + /** + * Get the Security Access key information of the tenant + * + * @return access key + */ + String getSecurityAccessKey(); + + /** + * Get the Security Secret Key information of the tenant + * + * @return secret key + */ + String getSecuritySecretKey(); + /** * Get the user data, which saved in a map * @@ -103,6 +131,21 @@ public interface Context { */ String getPackage(); + /** + * Get token of the tenant + * + * @return token + */ + String getToken(); + + /** + * Get security token. In order to invoke interface of other service, + * AK,SK and security token should been provided. + * + * @return security token + */ + String getSecurityToken(); + /** * Get function alias. * @@ -162,6 +205,13 @@ public interface Context { */ String getTraceID(); + /** + * Set the traceID + * + * @param traceID for the user traceId + */ + void setTraceID(String traceID); + /** * Gets the invoke id * @@ -211,6 +261,13 @@ public interface Context { */ String getFrontendResponseStreamName(); + /** + * Get iam token + * + * @return iam token + */ + String getIAMToken(); + /** * Get extra map * diff --git a/api/java/function-common/src/main/java/com/services/runtime/action/ContextImpl.java b/api/java/function-common/src/main/java/com/services/runtime/action/ContextImpl.java index 2b96449..40d66f7 100644 --- a/api/java/function-common/src/main/java/com/services/runtime/action/ContextImpl.java +++ b/api/java/function-common/src/main/java/com/services/runtime/action/ContextImpl.java @@ -107,6 +107,26 @@ public class ContextImpl implements Context { return (int) (this.funcMetaData.getTimeout() * 1000L - System.currentTimeMillis() + this.startTime); } + @Override + public String getAccessKey() { + return this.delegateDecrypt.getAccessKey(); + } + + @Override + public String getSecretKey() { + return this.delegateDecrypt.getSecretKey(); + } + + @Override + public String getSecurityAccessKey() { + return null; + } + + @Override + public String getSecuritySecretKey() { + return null; + } + @Override public String getUserData(String key) { return this.funcMetaData.getUserData().get(key); @@ -152,6 +172,16 @@ public class ContextImpl implements Context { return this.funcMetaData.getService(); } + @Override + public String getToken() { + return this.delegateDecrypt.getAuthToken(); + } + + @Override + public String getSecurityToken() { + return this.delegateDecrypt.getSecurityToken(); + } + @Override public String getAlias() { return this.funcMetaData.getAlias(); @@ -194,6 +224,11 @@ public class ContextImpl implements Context { return this.contextInvokeParams.getRequestID(); } + @Override + public void setTraceID(String requestID) { + this.contextInvokeParams.setRequestID(requestID); + } + @Override public String getInvokeID() { return this.contextInvokeParams.getInvokeID(); @@ -229,6 +264,11 @@ public class ContextImpl implements Context { return null; } + @Override + public String getIAMToken() { + return ""; + } + @Override public Map getExtraMap() { return null; diff --git a/api/java/function-common/src/main/java/com/services/runtime/action/ContextInvokeParams.java b/api/java/function-common/src/main/java/com/services/runtime/action/ContextInvokeParams.java index 90fcba1..7e5753f 100644 --- a/api/java/function-common/src/main/java/com/services/runtime/action/ContextInvokeParams.java +++ b/api/java/function-common/src/main/java/com/services/runtime/action/ContextInvokeParams.java @@ -27,8 +27,14 @@ import lombok.Setter; @Setter @Getter public class ContextInvokeParams { + private String accessKey = ""; + private String secretKey = ""; + private String securityAccessKey = ""; + private String securitySecretKey = ""; private String requestID = ""; private String invokeID = ""; + private String token = ""; + private String securityToken = ""; private String alias = ""; private String workflowID = ""; private String workflowRunID = ""; diff --git a/api/java/function-common/src/main/java/com/services/runtime/action/DelegateDecrypt.java b/api/java/function-common/src/main/java/com/services/runtime/action/DelegateDecrypt.java index e416112..630308c 100644 --- a/api/java/function-common/src/main/java/com/services/runtime/action/DelegateDecrypt.java +++ b/api/java/function-common/src/main/java/com/services/runtime/action/DelegateDecrypt.java @@ -27,6 +27,14 @@ import lombok.Data; */ @Data public class DelegateDecrypt { + private String accessKey; + + private String secretKey; + + private String authToken; + + private String securityToken; + private String environment; @SerializedName("encrypted_user_data") diff --git a/api/java/function-common/src/main/java/com/services/runtime/action/ExtendedMetaData.java b/api/java/function-common/src/main/java/com/services/runtime/action/ExtendedMetaData.java index 6987c47..96abde8 100644 --- a/api/java/function-common/src/main/java/com/services/runtime/action/ExtendedMetaData.java +++ b/api/java/function-common/src/main/java/com/services/runtime/action/ExtendedMetaData.java @@ -29,6 +29,9 @@ import lombok.Data; public class ExtendedMetaData { private Initializer initializer; + @SerializedName("pre_stop") + private PreStop preStop; + @SerializedName("log_tank_service") private LogTankService logTankService; } diff --git a/api/java/function-common/src/main/java/com/services/runtime/action/PreStop.java b/api/java/function-common/src/main/java/com/services/runtime/action/PreStop.java new file mode 100644 index 0000000..285c7b1 --- /dev/null +++ b/api/java/function-common/src/main/java/com/services/runtime/action/PreStop.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.services.runtime.action; + +import com.google.gson.annotations.SerializedName; + +import lombok.Data; + +/** + * The type pre_stop + * + * @since 2025.05.28 + */ +@Data +public class PreStop { + @SerializedName("pre_stop_handler") + private String preStopHandler; + + @SerializedName("pre_stop_timeout") + private int preStopTimeout; +} diff --git a/api/java/function-common/src/main/java/com/services/runtime/utils/Util.java b/api/java/function-common/src/main/java/com/services/runtime/utils/Util.java index a14a680..f877705 100644 --- a/api/java/function-common/src/main/java/com/services/runtime/utils/Util.java +++ b/api/java/function-common/src/main/java/com/services/runtime/utils/Util.java @@ -17,11 +17,13 @@ package com.services.runtime.utils; +import com.services.runtime.Context; import com.services.runtime.action.CustomLoggerStream; import com.yuanrong.jni.LibRuntime; import com.yuanrong.libruntime.generated.Socket.FunctionLog; import com.yuanrong.runtime.util.Constants; import com.yuanrong.runtime.util.ExtClasspathLoader; +import com.yuanrong.runtime.util.Utils; import lombok.extern.slf4j.Slf4j; @@ -31,6 +33,7 @@ import java.text.DateFormat; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Date; +import java.util.Locale; /** * The type Util. @@ -172,6 +175,51 @@ public class Util { inheritableThreadLocal.remove(); } + /** + * get ServiceNameFromEnv + * + * @param context the context + * @return service + */ + public static String getServiceNameFromEnv(Context context) { + String service = ""; + if (context != null) { + service = context.getPackage(); + } + if (service == null || service.isEmpty()) { + service = Utils.getFromJavaEnv("RUNTIME_PACKAGE"); + } + return service; + } + + /** + * get TenantIdFromEnv + * + * @param context the context + * @return tenantId + */ + public static String getTenantIdFromEnv(Context context) { + String tenantId = ""; + if (context != null) { + tenantId = context.getProjectID(); + } + if (tenantId == null || tenantId.isEmpty()) { + tenantId = Utils.getFromJavaEnv("RUNTIME_PROJECT_ID"); + } + return tenantId; + } + + /** + * getFunctionInfo + * + * @param context context + * @return function info + */ + public static String getFunctionInfo(Context context) { + return String.format(Locale.ROOT, "%s:function:0@%s@%s:%s", getTenantIdFromEnv(context), + getServiceNameFromEnv(context), context.getFunctionName(), context.getVersion()); + } + /** * Get UTC ISO8601 timestamp * diff --git a/api/java/function-common/src/main/java/com/yuanrong/InvokeOptions.java b/api/java/function-common/src/main/java/com/yuanrong/InvokeOptions.java index 27fd9d7..1db24dc 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/InvokeOptions.java +++ b/api/java/function-common/src/main/java/com/yuanrong/InvokeOptions.java @@ -18,6 +18,7 @@ package com.yuanrong; import com.yuanrong.affinity.Affinity; import com.yuanrong.affinity.AffinityKind; +import com.yuanrong.affinity.AffinityScope; import com.yuanrong.affinity.AffinityType; import com.yuanrong.affinity.LabelOperator; import com.yuanrong.runtime.util.Constants; @@ -300,7 +301,7 @@ public class InvokeOptions { * @param val the customExtensions value * @return InvokeOptions Builder class object. */ - public Builder addCustomExtensions(String key, String val) { + public Builder addCustomExtension(String key, String val) { if (Constants.POST_START_EXEC.equals(key)) { options.createOptions.put(key, val); return this; @@ -357,6 +358,20 @@ public class InvokeOptions { return this; } + /** + * set the scheduleAffinities with AffinityKind.INSTANCE and affinityScope + * + * @param type the affinity type + * @param operators the affinity operators + * @param affinityScope the affinity scope + * @return InvokeOptions Builder class object. + */ + public Builder addInstanceAffinity(AffinityType type, List operators, + AffinityScope affinityScope) { + options.scheduleAffinities.add(new Affinity(AffinityKind.INSTANCE, type, operators, affinityScope)); + return this; + } + /** * set the scheduleAffinities with AffinityKind.RESOURCE * diff --git a/api/java/function-common/src/main/java/com/yuanrong/affinity/Affinity.java b/api/java/function-common/src/main/java/com/yuanrong/affinity/Affinity.java index 2755997..b17137e 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/affinity/Affinity.java +++ b/api/java/function-common/src/main/java/com/yuanrong/affinity/Affinity.java @@ -33,6 +33,8 @@ public class Affinity { private List labelOperators; + private AffinityScope affinityScope; + private ResourceAffinity resourceAffinity; private InstanceAffinity instanceAffinity; @@ -50,6 +52,20 @@ public class Affinity { this.labelOperators = labelOperators; } + /** + * Init Affinity with affinityScope + * + * @param affinityKind the affinityKind + * @param affinityType the affinityType + * @param labelOperators the labelOperators + * @param affinityScope the affinityScope + */ + public Affinity(AffinityKind affinityKind, AffinityType affinityType, List labelOperators, + AffinityScope affinityScope) { + this(affinityKind, affinityType, labelOperators); + this.affinityScope = affinityScope; + } + /** * Init Affinity * @@ -90,4 +106,16 @@ public class Affinity { } return this.affinityKind.getKind() + this.affinityType.getType(); } + + /** + * get affinityKind value + * + * @return value + */ + public int getAffinityScopeValue() { + if (this.affinityScope == null) { + return 0; + } + return this.affinityScope.getScope(); + } } diff --git a/api/java/function-common/src/main/java/com/yuanrong/affinity/AffinityScope.java b/api/java/function-common/src/main/java/com/yuanrong/affinity/AffinityScope.java index dc978e4..f6c1608 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/affinity/AffinityScope.java +++ b/api/java/function-common/src/main/java/com/yuanrong/affinity/AffinityScope.java @@ -22,8 +22,8 @@ package com.yuanrong.affinity; * @since 2024-09-11 */ public enum AffinityScope { - NODE(1), - POD(2); + POD(1), + NODE(2); private int scope; diff --git a/api/java/function-common/src/main/java/com/yuanrong/affinity/LabelOperator.java b/api/java/function-common/src/main/java/com/yuanrong/affinity/LabelOperator.java index f99676c..4c26cfd 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/affinity/LabelOperator.java +++ b/api/java/function-common/src/main/java/com/yuanrong/affinity/LabelOperator.java @@ -83,7 +83,8 @@ public class LabelOperator { * @param values the values */ public LabelOperator(OperatorType operatorType, List values) { - this(operatorType, "", values); + this.operatorType = operatorType; + this.values = values; } /** diff --git a/api/java/function-common/src/main/java/com/yuanrong/api/Node.java b/api/java/function-common/src/main/java/com/yuanrong/api/Node.java new file mode 100644 index 0000000..66ec8ab --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/api/Node.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.api; + +import lombok.Data; + +import java.util.List; +import java.util.Map; + +/** + * The node information. + * + * @since 2025/07/28 + */ +@Data +public class Node { + private String id; + private boolean alive; + private Map resources; + private Map> labels; + + public Node() {} + + /** + * Init Node with id, alive, resources and labels. + * + * @param id node id. + * @param alive whether this node is alive. + * @param resources resources of this node. + * @param labels labels of this node. + */ + public Node(String id, boolean alive, Map resources, Map> labels) { + this.id = id; + this.alive = alive; + this.resources = resources; + this.labels = labels; + } +} diff --git a/api/java/function-common/src/main/java/com/yuanrong/errorcode/ErrorCode.java b/api/java/function-common/src/main/java/com/yuanrong/errorcode/ErrorCode.java index 8bed65a..f899b35 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/errorcode/ErrorCode.java +++ b/api/java/function-common/src/main/java/com/yuanrong/errorcode/ErrorCode.java @@ -211,6 +211,23 @@ public class ErrorCode { * ERR_CREATE_RETURN_BUFFER */ public static final ErrorCode ERR_CREATE_RETURN_BUFFER = new ErrorCode(9001); + // ErrorCode provided for the jobExecutor interface: + + /** + * ERR_JOB_USER_CODE_EXCEPTION + */ + public static final ErrorCode ERR_JOB_USER_CODE_EXCEPTION = new ErrorCode(50001); + + /** + * ERR_JOB_RUNTIME_EXCEPTION + */ + public static final ErrorCode ERR_JOB_RUNTIME_EXCEPTION = new ErrorCode(50002); + + /** + * ERR_JOB_INNER_SYSTEM_EXCEPTION + */ + public static final ErrorCode ERR_JOB_INNER_SYSTEM_EXCEPTION = new ErrorCode(50003); + private int code; diff --git a/api/java/function-common/src/main/java/com/yuanrong/exception/handler/traceback/StackTraceUtils.java b/api/java/function-common/src/main/java/com/yuanrong/exception/handler/traceback/StackTraceUtils.java index 3c17967..a173496 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/exception/handler/traceback/StackTraceUtils.java +++ b/api/java/function-common/src/main/java/com/yuanrong/exception/handler/traceback/StackTraceUtils.java @@ -18,6 +18,7 @@ package com.yuanrong.exception.handler.traceback; import com.yuanrong.errorcode.ErrorCode; import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; import com.yuanrong.errorcode.Pair; import com.yuanrong.exception.YRException; @@ -59,10 +60,16 @@ public class StackTraceUtils { private static final String REGEX_BEFORE = "((\\w+\\.?)+\\w+)\\.(\\w+)"; private static final String REGEX_AFTER = "(\\w+\\.java):(\\d+)"; - private static final String AT_PATTERN = "^\\tat.*$"; + private static final String AT_PATTERN = "\tat"; + private static final String SPLIT_SYMBOL = "\\("; + private static final String BRACKET_CLOSE_SYMBOL = ")"; - private static final int METHOD_MATCHER_NUMBER = 3; - private static final int FILE_MATCHER_NUMBER = 2; + + private static final int CLASS_NAME_INDEX = 3; + private static final int FILE_NAME_INDEX = 0; + + private static final char CLASS_METHOD_SEPARATOR = '.'; + private static final char FILE_INFO_SEPARATOR = ':'; /** * Check error and throw. @@ -72,6 +79,10 @@ public class StackTraceUtils { * @throws YRException the YR exception */ public static void checkErrorAndThrowForInvokeException(ErrorInfo errorInfo, String msg) throws YRException { + if (errorInfo == null) { + throw new YRException(ErrorCode.ERR_INNER_SYSTEM_ERROR, ModuleCode.RUNTIME, + "unknown exception occurred, errorInfo is null, msg: " + msg); + } if (errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { return; } @@ -80,12 +91,12 @@ public class StackTraceUtils { List stackTraceInfos = errorInfo.getStackTraceInfos(); LOGGER.error("occurs exception: {}, ErrorCode:{}, stackTraceInfo number:{} ", msg, errorInfo.getErrorCode(), stackTraceInfos.size()); - if (stackTraceInfos.size() == 0) { + if (stackTraceInfos.isEmpty()) { throw new YRException(errorInfo); } Exception exception = fromStackTraceInfoListToException(stackTraceInfos); - throw new YRException(errorInfo.getErrorCode(), - errorInfo.getModuleCode(), errorInfo.getErrorMessage(), exception); + throw new YRException(errorInfo.getErrorCode(), errorInfo.getModuleCode(), + errorInfo.getErrorMessage(), exception); } else { // process exception of yuanrong throw new YRException(errorInfo); @@ -110,7 +121,7 @@ public class StackTraceUtils { return exp; } catch (ClassNotFoundException | NoSuchMethodException | SecurityException | InstantiationException | IllegalAccessException | InvocationTargetException e) { - throw new RuntimeException("Unable to find " + className + ", messsage:" + message, e); + throw new RuntimeException("Unable to find " + className + ", message:" + message, e); } } @@ -163,6 +174,13 @@ public class StackTraceUtils { String[] strArray = printedStackTrace.split(System.lineSeparator()); + /** + * sometimes there is no caused by information there + */ + /** + * java.io.IOException: Error creating file + * at com.example.CalleeC.throwUncheckedException(CalleeC.java:9) + */ int suppressedExceptionBeginIndex = getSuppressedExceptionBeginIndex(strArray, printedStackTrace); Pair returnValue = getRuntimeExceptionBeginIndex(strArray); int rtExceptionIndex = returnValue.getFirst(); @@ -176,15 +194,12 @@ public class StackTraceUtils { suppressedExceptionBeginIndex = 1; } - Pattern pattern = Pattern.compile(AT_PATTERN, Pattern.MULTILINE); - if (rtExceptionIndex == strArray.length - RUNTIME_EXCEPTION_EXECUTE_COUNT) { // RUNTIME_EXCEPTION_BEGIN_TAG LOGGER.debug("rtExceptionIndex == strArray.length - RUNTIME_EXCEPTION_EXECUTE_COUNT : {}, {}", rtExceptionIndex, strArray.length - RUNTIME_EXCEPTION_EXECUTE_COUNT); for (int j = suppressedExceptionBeginIndex; j < rtExceptionIndex; j++) { - Matcher matcher = pattern.matcher(strArray[j]); - if (!strArray[j].contains(RUNTIME_CLASS_PATH) && matcher.find()) { + if (!strArray[j].contains(RUNTIME_CLASS_PATH) && strArray[j].startsWith(AT_PATTERN)) { result.add(stringToStackTraceEle(strArray[j].trim())); } } @@ -198,23 +213,20 @@ public class StackTraceUtils { end = strArray.length; } for (int j = suppressedExceptionBeginIndex; j < end; j++) { - Matcher matcher = pattern.matcher(strArray[j]); - if (!strArray[j].contains(RUNTIME_CLASS_PATH) && matcher.find()) { + if (!strArray[j].contains(RUNTIME_CLASS_PATH) && strArray[j].startsWith(AT_PATTERN)) { result.add(stringToStackTraceEle(strArray[j].trim())); } } } else { LOGGER.debug("rtExceptionIndex and strArray.length : {}, {}", rtExceptionIndex, strArray.length); for (int j = suppressedExceptionBeginIndex; j < rtExceptionIndex; j++) { - Matcher matcher = pattern.matcher(strArray[j]); - if (!strArray[j].contains(RUNTIME_CLASS_PATH) && matcher.find()) { + if (!strArray[j].contains(RUNTIME_CLASS_PATH) && strArray[j].startsWith(AT_PATTERN)) { result.add(stringToStackTraceEle(strArray[j].trim())); } } for (int j = 1; j < rtExceptionIndex; j++) { - Matcher matcher = pattern.matcher(strArray[j]); - if (matcher.find()) { + if (strArray[j].startsWith(AT_PATTERN)) { result.add(stringToStackTraceEle(strArray[j].trim())); } } @@ -231,26 +243,29 @@ public class StackTraceUtils { * @return the stack trace element */ public static StackTraceElement stringToStackTraceEle(String elementStr) { - String[] splits = elementStr.split("\\("); + // Example: at java.lang.reflect.Method.invoke(Method.java:498) + String[] splits = elementStr.split(SPLIT_SYMBOL); if (splits.length < 2) { throw new RuntimeException("exception happened while parsing staceTraceElement string"); } - Matcher methodMatcher = getStackMatcher(REGEX_BEFORE, splits[0]); - Matcher fileMatcher = getStackMatcher(REGEX_AFTER, splits[1]); + try { + // Parse method information. + String methodInfo = splits[0].trim(); + int classNameEndIndex = methodInfo.lastIndexOf(CLASS_METHOD_SEPARATOR); + String classInfo = methodInfo.substring(CLASS_NAME_INDEX, classNameEndIndex); + String methodName = methodInfo.substring(classNameEndIndex + 1); - if (methodMatcher.groupCount() < METHOD_MATCHER_NUMBER || fileMatcher.groupCount() < FILE_MATCHER_NUMBER) { - throw new RuntimeException("exception happened while matching method and file info"); - } + // Parse file information. + String fileInfo = splits[1].replace(BRACKET_CLOSE_SYMBOL, "").trim(); + int fileNameEndIndex = fileInfo.lastIndexOf(FILE_INFO_SEPARATOR); + String fileName = fileInfo.substring(FILE_NAME_INDEX, fileNameEndIndex); + int lineNumber = Integer.parseInt(fileInfo.substring(fileNameEndIndex + 1)); - StackTraceElement stackTrace = null; - if (methodMatcher.find() && fileMatcher.find()) { - String className = methodMatcher.group(1); - String methodName = methodMatcher.group(3); - String fileName = fileMatcher.group(1); - int lineNumber = Integer.parseInt(fileMatcher.group(2)); - stackTrace = new StackTraceElement(className, methodName, fileName, lineNumber); + // Create StackTraceElement object. + return new StackTraceElement(classInfo, methodName, fileName, lineNumber); + } catch (StringIndexOutOfBoundsException | NumberFormatException e) { + return null; } - return stackTrace; } private static Matcher getStackMatcher(String regex, String stackStr) { diff --git a/api/java/function-common/src/main/java/com/yuanrong/jni/JniConsumer.java b/api/java/function-common/src/main/java/com/yuanrong/jni/JniConsumer.java new file mode 100644 index 0000000..3e54397 --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/jni/JniConsumer.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jni; + +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.Pair; +import com.yuanrong.stream.Element; + +import java.util.List; + +/** + * NativeConsumer represents methods of libRuntime + * + * @since 2024-09-04 + */ +public class JniConsumer { + /** + * receive + * + * @param consumerPtr the consumerPtr + * @param expectNum the expectNum + * @param timeoutMs the timeoutMs + * @param hasExpectedNum the hasExpectedNum + * @return pair + */ + public static native Pair> receive(long consumerPtr, long expectNum, int timeoutMs, + boolean hasExpectedNum); + + /** + * ack + * + * @param consumerPtr the consumerPtr + * @param elementId the elementId + * @return ErrorInfo + */ + public static native ErrorInfo ack(long consumerPtr, long elementId); + + /** + * close + * + * @param consumerPtr the consumerPtr + * @return ErrorInfo + */ + public static native ErrorInfo close(long consumerPtr); + + /** + * freeJNIPtrNative + * + * @param clientPtr the clientPtr + * @return ErrorInfo + */ + public static native ErrorInfo freeJNIPtrNative(long clientPtr); +} diff --git a/api/java/function-common/src/main/java/com/yuanrong/jni/JniProducer.java b/api/java/function-common/src/main/java/com/yuanrong/jni/JniProducer.java new file mode 100644 index 0000000..be323b6 --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/jni/JniProducer.java @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jni; + +import com.yuanrong.errorcode.ErrorInfo; + +import java.nio.ByteBuffer; + +/** + * NativeProducer represents methods of libRuntime + * + * @since 2024-09-04 + */ +public class JniProducer { + /** + * sendHeapBufferDefaultTimeout + * + * @param producerPtr producerPtr + * @param bytes bytes + * @param len len + * @return ErrorInfo + */ + public static native ErrorInfo sendHeapBufferDefaultTimeout(long producerPtr, byte[] bytes, long len); + + /** + * sendDirectBufferDefaultTimeout + * + * @param producerPtr producerPtr + * @param buffers buffers + * @return ErrorInfo + */ + public static native ErrorInfo sendDirectBufferDefaultTimeout(long producerPtr, ByteBuffer buffers); + + /** + * sendHeapBuffer + * + * @param producerPtr producerPtr + * @param bytes bytes + * @param len len + * @param timeoutMs timeoutMs + * @return ErrorInfo + */ + public static native ErrorInfo sendHeapBuffer(long producerPtr, byte[] bytes, long len, int timeoutMs); + + /** + * sendDirectBuffer + * + * @param producerPtr producerPtr + * @param buffers buffers + * @param timeoutMs timeoutMs + * @return ErrorInfo + */ + public static native ErrorInfo sendDirectBuffer(long producerPtr, ByteBuffer buffers, int timeoutMs); + + /** + * flush + * + * @param producerPtr producerPtr + * @return ErrorInfo + */ + public static native ErrorInfo flush(long producerPtr); + + /** + * close + * + * @param producerPtr producerPtr + * @return ErrorInfo + */ + public static native ErrorInfo close(long producerPtr); + + /** + * freeJNIPtrNative + * + * @param producerPtr producerPtr + */ + public static native void freeJNIPtrNative(long producerPtr); +} diff --git a/api/java/function-common/src/main/java/com/yuanrong/jni/LibRuntime.java b/api/java/function-common/src/main/java/com/yuanrong/jni/LibRuntime.java index ec6ce0a..43df850 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/jni/LibRuntime.java +++ b/api/java/function-common/src/main/java/com/yuanrong/jni/LibRuntime.java @@ -23,6 +23,7 @@ import com.yuanrong.InvokeOptions; import com.yuanrong.MSetParam; import com.yuanrong.SetParam; import com.yuanrong.api.InvokeArg; +import com.yuanrong.api.Node; import com.yuanrong.errorcode.ErrorInfo; import com.yuanrong.errorcode.Pair; import com.yuanrong.exception.LibRuntimeException; @@ -30,6 +31,8 @@ import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; import com.yuanrong.libruntime.generated.Socket.FunctionLog; import com.yuanrong.runtime.config.RuntimeContext; import com.yuanrong.storage.InternalWaitResult; +import com.yuanrong.stream.ProducerConfig; +import com.yuanrong.stream.SubscriptionType; import java.util.List; @@ -84,9 +87,10 @@ public class LibRuntime { * @param args invoke arguments * @param opt invoke options * @return Pair of ErrorInfo and String + * @throws LibRuntimeException the LibRuntimeException */ public static native Pair InvokeInstance(FunctionMeta funcMeta, String instanceId, - List args, InvokeOptions opt); + List args, InvokeOptions opt) throws LibRuntimeException; /** * Get invoke results with ObjectRefIds @@ -107,8 +111,9 @@ public class LibRuntime { * @param data the data * @param nestObjIds the nestObjIds * @return Pair of ErrorInfo and String + * @throws LibRuntimeException the LibRuntimeException */ - public static native Pair Put(byte[] data, List nestObjIds); + public static native Pair Put(byte[] data, List nestObjIds) throws LibRuntimeException; /** * Native method for put data @@ -117,9 +122,10 @@ public class LibRuntime { * @param nestObjIds the nestObjIds * @param createParam create param of datasystem * @return Pair of ErrorInfo and String + * @throws LibRuntimeException the LibRuntimeException */ public static native Pair PutWithParam(byte[] data, List nestObjIds, - CreateParam createParam); + CreateParam createParam) throws LibRuntimeException; /** * Native method for Wait @@ -128,8 +134,10 @@ public class LibRuntime { * @param waitNum the waitNum * @param timeoutSec the timeoutSec * @return InternalWaitResult + * @throws LibRuntimeException the LibRuntimeException */ - public static native InternalWaitResult Wait(List ids, int waitNum, int timeoutSec); + public static native InternalWaitResult Wait(List ids, int waitNum, int timeoutSec) + throws LibRuntimeException; /** * Native method for ReceiveRequestLoop @@ -162,7 +170,6 @@ public class LibRuntime { /** * Native method for setRuntimeContext - * * @param jobID jobID */ public static native void setRuntimeContext(String jobID); @@ -172,8 +179,9 @@ public class LibRuntime { * * @param instanceId the instanceId * @return ErrorInfo + * @throws LibRuntimeException the LibRuntimeException */ - public static native ErrorInfo Kill(String instanceId); + public static native ErrorInfo Kill(String instanceId) throws LibRuntimeException; /** * Native method for autoInitYR @@ -189,8 +197,9 @@ public class LibRuntime { * @param objectID The ID of the object whose real instance ID is to be * retrieved * @return The real instance ID of the object as a String + * @throws LibRuntimeException the LibRuntimeException */ - public static native String GetRealInstanceId(String objectID); + public static native String GetRealInstanceId(String objectID) throws LibRuntimeException; /** * Native method for saving the real instance ID of an object. @@ -198,8 +207,10 @@ public class LibRuntime { * @param objectID the object id * @param instanceID the instance id * @param opts the invoke options + * @throws LibRuntimeException the LibRuntimeException */ - public static native void SaveRealInstanceId(String objectID, String instanceID, InvokeOptions opts); + public static native void SaveRealInstanceId(String objectID, String instanceID, InvokeOptions opts) + throws LibRuntimeException; /** * Native method for writing key-value pairs to a key-value store. @@ -210,8 +221,9 @@ public class LibRuntime { * ttlSecond, existence and cacheType. * @return An ErrorInfo object containing information about any errors that * occurred during the write operation. + * @throws LibRuntimeException the LibRuntimeException */ - public static native ErrorInfo KVWrite(String key, byte[] value, SetParam setParam); + public static native ErrorInfo KVWrite(String key, byte[] value, SetParam setParam) throws LibRuntimeException; /** * Native method for writing key-value pairs to a key-value store. @@ -222,8 +234,10 @@ public class LibRuntime { * ttlSecond, existence and cacheType. * @return An ErrorInfo object containing information about any errors that * occurred during the write operation. + * @throws LibRuntimeException the LibRuntimeException */ - public static native ErrorInfo KVMSetTx(List keys, List values, MSetParam mSetParam); + public static native ErrorInfo KVMSetTx(List keys, List values, MSetParam mSetParam) + throws LibRuntimeException; /** * Native method for reading a single key-value pair. @@ -232,8 +246,10 @@ public class LibRuntime { * @param timeoutMS The maximum amount of time in milliseconds to wait for the read operation * @return A Pair object containing the byte array value associated with the * key, and an ErrorInfo object if there was an error. + * @throws LibRuntimeException the LibRuntimeException */ - public static native Pair KVRead(String key, int timeoutMS); + public static native Pair KVRead(String key, int timeoutMS) + throws LibRuntimeException; /** * Native method for reading multiple key-value pairs. @@ -245,8 +261,10 @@ public class LibRuntime { * results if some of the keys are not found. * @return A Pair object containing a list of byte array values associated with * the keys, and an ErrorInfo object if there was an error. + * @throws LibRuntimeException the LibRuntimeException */ - public static native Pair, ErrorInfo> KVRead(List keys, int timeoutMS, boolean allowPartial); + public static native Pair, ErrorInfo> KVRead(List keys, int timeoutMS, boolean allowPartial) + throws LibRuntimeException; /** * Native method for reading multiple key-value pairs with get params. @@ -257,9 +275,10 @@ public class LibRuntime { * respond. * @return A Pair object containing a list of byte array values associated with * the keys, and an ErrorInfo object if there was an error. + * @throws LibRuntimeException the LibRuntimeException */ public static native Pair, ErrorInfo> KVGetWithParam(List keys, - GetParams params, int timeoutMS); + GetParams params, int timeoutMS) throws LibRuntimeException; /** * Native method for deleting the value associated with the given key from the @@ -268,8 +287,9 @@ public class LibRuntime { * @param key The key of the value to be deleted * @return An ErrorInfo object indicating the success or failure of the * operation + * @throws LibRuntimeException the LibRuntimeException */ - public static native ErrorInfo KVDel(String key); + public static native ErrorInfo KVDel(String key) throws LibRuntimeException; /** * Native method for deleting the values associated with the given keys from the @@ -279,24 +299,27 @@ public class LibRuntime { * @return A Pair object containing a List of keys that were successfully * deleted and an instance of ErrorInfo indicating the success or * failure of the operation + * @throws LibRuntimeException the LibRuntimeException */ - public static native Pair, ErrorInfo> KVDel(List keys); + public static native Pair, ErrorInfo> KVDel(List keys) throws LibRuntimeException; /** * Native method for SaveState * * @param timeoutMs the timeoutMs * @return An ErrorInfo object indicating the success or failure of the operation + * @throws LibRuntimeException the LibRuntimeException */ - public static native ErrorInfo SaveState(int timeoutMs); + public static native ErrorInfo SaveState(int timeoutMs) throws LibRuntimeException; /** * Native method for LoadState * * @param timeoutMs the timeoutMs * @return An ErrorInfo object indicating the success or failure of the operation + * @throws LibRuntimeException the LibRuntimeException */ - public static native ErrorInfo LoadState(int timeoutMs); + public static native ErrorInfo LoadState(int timeoutMs) throws LibRuntimeException; /** * Native method for GroupCreate @@ -304,8 +327,9 @@ public class LibRuntime { * @param groupName the groupName * @param opts the opts * @return An ErrorInfo object indicating the success or failure of the operation + * @throws LibRuntimeException the LibRuntimeException */ - public static native ErrorInfo GroupCreate(String groupName, GroupOptions opts); + public static native ErrorInfo GroupCreate(String groupName, GroupOptions opts) throws LibRuntimeException; /** * Native method for GroupTerminate @@ -320,8 +344,9 @@ public class LibRuntime { * * @param groupName the groupName * @return An ErrorInfo object indicating the success or failure of the operation + * @throws LibRuntimeException the LibRuntimeException */ - public static native ErrorInfo GroupWait(String groupName); + public static native ErrorInfo GroupWait(String groupName) throws LibRuntimeException; /** * Native method for processLog @@ -331,6 +356,64 @@ public class LibRuntime { */ public static native ErrorInfo processLog(FunctionLog functionLog); + /** + * Use jni to create producer. + * + * @param streamName The name of the stream. + * @param delayFlushTimeMs The time used in automatic flush after send. + * @param pageSizeByte The size used in allocate page. + * @param maxStreamSize The max stream size in worker. + * @param shouldCleanup Should auto delete when last producer/consumer exit + * @param encryptStream The encrypt stream + * @param retainForNumConsumers The retain for num consumers + * @param reserveSize The reserve size + * @return The Producer pointer. + * @throws LibRuntimeException if there is an exception during creating stream producer + */ + public static native long CreateStreamProducer(String streamName, long delayFlushTimeMs, + long pageSizeByte, long maxStreamSize, boolean shouldCleanup, boolean encryptStream, + long retainForNumConsumers, long reserveSize) + throws LibRuntimeException; + + /** + * Use jni to subscribe a new consumer onto master request + * + * @param streamName The name of the stream. + * @param subName The name of subscription. + * @param subscriptionType The type of SubscriptionType. + * @param shouldAutoAck Should AutoAck be enabled for this subscriber or not. + * @return The Consumer pointer. + * @throws LibRuntimeException if there is an exception during creating stream consumer + */ + public static native long CreateStreamConsumer(String streamName, String subName, + SubscriptionType subscriptionType, boolean shouldAutoAck) throws LibRuntimeException; + + /** + * Use jni to delete the stream + * + * @param streamName The name of the target stream. + * @return ErrorInfo + */ + public static native ErrorInfo DeleteStream(String streamName); + + /** + * Use jni to query numbers of producer in global worker node + * + * @param streamName The name of the target stream. + * @return The query result. + * @throws LibRuntimeException if there is an exception during quering global producersNum + */ + public static native long QueryGlobalProducersNum(String streamName) throws LibRuntimeException; + + /** + * Use jni to query numbers of consumer in global worker node + * + * @param streamName The name of the target stream. + * @return The query result. + * @throws LibRuntimeException if there is an exception during quering global consumersNum + */ + public static native long QueryGlobalConsumersNum(String streamName) throws LibRuntimeException; + /** * Native method for DecreaseReference * @@ -338,21 +421,34 @@ public class LibRuntime { */ public static native void DecreaseReference(List ids); + /** + * Use jni to create producer. + * + * @param streamName The name of the stream. + * @param producerConfig The producer config + * @return The Producer pointer. + * @throws LibRuntimeException if there is an exception during creating stream producer + */ + public static native long CreateStreamProducerWithConfig(String streamName, ProducerConfig producerConfig) + throws LibRuntimeException; + /** * Native method for retrieving the instance route of an object. * * @param objectID The ID of the object whose instance route is to be retrieved * @return The instance route of the object as a String + * @throws LibRuntimeException the LibRuntimeException */ - public static native String GetInstanceRoute(String objectID); + public static native String GetInstanceRoute(String objectID) throws LibRuntimeException; /** * Native method for saving the instance route of an object. * * @param objectID the object id * @param instanceRoute the instance route + * @throws LibRuntimeException the LibRuntimeException */ - public static native void SaveInstanceRoute(String objectID, String instanceRoute); + public static native void SaveInstanceRoute(String objectID, String instanceRoute) throws LibRuntimeException; /** * Native method for Kill instance sync @@ -361,4 +457,12 @@ public class LibRuntime { * @return ErrorInfo */ public static native ErrorInfo KillSync(String instanceId); + + /** + * Get node information in the cluster. + * + * @return Pair of ErrorInfo and list of node information. + * @throws LibRuntimeException the LibRuntimeException. + */ + public static native Pair> nodes() throws LibRuntimeException; } diff --git a/api/java/function-common/src/main/java/com/yuanrong/jni/LibRuntimeConfig.java b/api/java/function-common/src/main/java/com/yuanrong/jni/LibRuntimeConfig.java index 937d526..2074d58 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/jni/LibRuntimeConfig.java +++ b/api/java/function-common/src/main/java/com/yuanrong/jni/LibRuntimeConfig.java @@ -59,7 +59,10 @@ public class LibRuntimeConfig { private String certificateFilePath; private String privateKeyPath; private String verifyFilePath; + private String privateKeyPaaswd; private String serverName; + private int httpIocThreadsNum; + private int httpIdleTime; private boolean inCluster = true; private int rpcTimeout = 60; private String tenantId = ""; diff --git a/api/java/function-common/src/main/java/com/yuanrong/jni/LoadUtil.java b/api/java/function-common/src/main/java/com/yuanrong/jni/LoadUtil.java index 7fc262c..70a26b8 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/jni/LoadUtil.java +++ b/api/java/function-common/src/main/java/com/yuanrong/jni/LoadUtil.java @@ -192,6 +192,9 @@ public class LoadUtil { } if (outChannel.size() == 0) { copyJarSoToLocal(soFilePath, outChannel); + if (!localSoFile.setReadOnly()) { + LOGGER.warn("set file: {} read permission failed.", localSoFile.getAbsolutePath()); + } System.load(localSoFile.getCanonicalPath()); return true; } @@ -205,6 +208,9 @@ public class LoadUtil { try (FileChannel outChannel = FileChannel.open(tempSoFile.toPath(), StandardOpenOption.WRITE, StandardOpenOption.APPEND)) { copyJarSoToLocal(soFilePath, outChannel); + if (!tempSoFile.setReadOnly()) { + LOGGER.warn("set file: {} read permission failed.", tempSoFile.getAbsolutePath()); + } System.load(tempSoFile.getCanonicalPath()); } } diff --git a/api/java/function-common/src/main/java/com/yuanrong/runtime/client/ObjectRef.java b/api/java/function-common/src/main/java/com/yuanrong/runtime/client/ObjectRef.java index 1ae09ab..119e799 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/runtime/client/ObjectRef.java +++ b/api/java/function-common/src/main/java/com/yuanrong/runtime/client/ObjectRef.java @@ -69,6 +69,11 @@ public class ObjectRef { */ private final String objectID; + /** + * Decrease reference flag of ObjectRef. + */ + private boolean isReleased = false; + /** * The constructor for ObjectRef. * @@ -162,8 +167,16 @@ public class ObjectRef { @Override protected void finalize() throws Throwable { - if (LibRuntime.IsInitialized()) { + release(); + } + + /** + * Release the ObjectRef, decrease reference. + */ + public void release() { + if (!isReleased && LibRuntime.IsInitialized()) { LibRuntime.DecreaseReference(Collections.singletonList(this.objectID)); + isReleased = true; } } } diff --git a/api/java/function-common/src/main/java/com/yuanrong/runtime/util/Constants.java b/api/java/function-common/src/main/java/com/yuanrong/runtime/util/Constants.java index 27aecb1..1e78af4 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/runtime/util/Constants.java +++ b/api/java/function-common/src/main/java/com/yuanrong/runtime/util/Constants.java @@ -16,6 +16,8 @@ package com.yuanrong.runtime.util; +import java.util.regex.Pattern; + /** * Description: Constants * @@ -87,6 +89,11 @@ public final class Constants { */ public static final int DEFAULT_HTTP_IO_THREAD_CNT = 100; + /** + * idle time for http client + */ + public static final int DEFAULT_HTTP_IDLE_TIME = 120; + /** * The key to python packages, which are going to be installed in remote runtime. */ @@ -152,6 +159,10 @@ public final class Constants { */ public static final String CONCURRENCY = "Concurrency"; + /** + * PATTERN_FAAS_ENTRY + */ + public static final Pattern PATTERN_FAAS_ENTRY = Pattern.compile("^[^/]*.[^/]*$"); /** * The constant KEY_USER_INIT_ENTRY. diff --git a/api/java/function-common/src/main/java/com/yuanrong/runtime/util/ExtClasspathLoader.java b/api/java/function-common/src/main/java/com/yuanrong/runtime/util/ExtClasspathLoader.java index 4cd0cb4..7e1d50f 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/runtime/util/ExtClasspathLoader.java +++ b/api/java/function-common/src/main/java/com/yuanrong/runtime/util/ExtClasspathLoader.java @@ -63,11 +63,12 @@ public final class ExtClasspathLoader { * load jar classpath。 * * @param codePaths the codePaths + * @throws IOException throw load user function Failed * @throws InvocationTargetException throw load user function Failed * @throws IllegalAccessException throw load user function Failed */ public static void loadClasspath(List codePaths) - throws InvocationTargetException, IllegalAccessException { + throws IOException, InvocationTargetException, IllegalAccessException { codePaths.addAll(getJarFiles()); LOG.debug("function lib path: {}", codePaths); for (String filepath : codePaths) { diff --git a/api/java/function-common/src/main/java/com/yuanrong/runtime/util/FuncClassLoader.java b/api/java/function-common/src/main/java/com/yuanrong/runtime/util/FuncClassLoader.java index c2035b6..047f8ae 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/runtime/util/FuncClassLoader.java +++ b/api/java/function-common/src/main/java/com/yuanrong/runtime/util/FuncClassLoader.java @@ -44,6 +44,7 @@ public class FuncClassLoader extends URLClassLoader { add("com.services"); add("com.function"); add("com.datasystem"); + add("com.faas"); add("com.google.protobuf"); add("com.google.gson"); } @@ -64,7 +65,7 @@ public class FuncClassLoader extends URLClassLoader { * Bootstrap ClassLoader * | * Ext Classloader - * not start with / \ com.yuanrong.runtime + * not start with / \ start whit com.faas || com.yuanrong.runtime * FuncClassLoader SystemClassLoader * | | * SystemClassLoader FuncClassLoader diff --git a/api/java/function-common/src/main/java/com/yuanrong/runtime/util/Utils.java b/api/java/function-common/src/main/java/com/yuanrong/runtime/util/Utils.java index 55c417e..a4d95ce 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/runtime/util/Utils.java +++ b/api/java/function-common/src/main/java/com/yuanrong/runtime/util/Utils.java @@ -18,8 +18,10 @@ package com.yuanrong.runtime.util; +import com.yuanrong.api.InvokeArg; import com.yuanrong.errorcode.ErrorCode; import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; import com.yuanrong.exception.YRException; import com.yuanrong.exception.HandlerNotAvailableException; import com.yuanrong.serialization.Serializer; @@ -33,9 +35,12 @@ import org.objectweb.asm.Type; import java.io.IOException; import java.lang.reflect.Constructor; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -208,15 +213,14 @@ public class Utils { } } if (Objects.isNull(parameterTypes)) { - String uClassName = uClass.getName(); String errorMsg = new StringBuilder("Failed to find user-definded method: [") - .append(uClassName) + .append(uClass.getName()) .append(".") .append(methodName) .append(" signature: ") .append(methodSignature) .append("] from class ") - .append(uClassName) + .append(uClass.getName()) .append(" loaded in runtime.") .toString(); throw new IllegalArgumentException(errorMsg); @@ -305,6 +309,43 @@ public class Utils { return expMsg; } + /** + * The put method retry 3 times. + * + * @param objectId the object ref ids + * @param putData the id of object ref + * @param nestedObjectIds the nestedObjectIds that current ObjectRef depends on. + * @param client dataSystem client + * @param interval interval time to continuously put when spill may occur in milliseconds + * @param retryTime retry time to continuously put when spill may occur + * @return the boolean + */ + public static boolean put( + String objectId, + ByteBuffer putData, + List nestedObjectIds, + Object client, + long interval, + long retryTime) { + return true; + } + + /** + * Get list retry 3 times. + * + * @param ids the object ref ids + * @param timeout the timeout + * @param client the client + * @return Result : the list which contains all objects got from DataSystem If + * some objectRef failed to get its object, elements in the corresponding positions in + * the Result is null + * @throws YRException exception + */ + public static List get(List ids, int timeout, Object client) throws YRException { + throw new YRException(ErrorCode.ERR_GET_OPERATION_FAILED, ModuleCode.DATASYSTEM, + "failed to retry get from dataSystem"); + } + /** * Sleep seconds. * @@ -352,4 +393,118 @@ public class Utils { throw new YRException(errorInfo); } } + + /** + * getFromJavaEnv + * + * @param key the key + * @return String + */ + public static String getFromJavaEnv(String key) { + String res = ""; + String filedName = "theUnmodifiableEnvironment"; + try { + Class cls = Class.forName("java.lang.ProcessEnvironment"); + // get field and access + Field oldFiled = cls.getDeclaredField(filedName); + oldFiled.setAccessible(true); + // get Filed map + Object map = oldFiled.get(null); + Class unmodifiableMap = Class.forName("java.util.Collections$UnmodifiableMap"); + Field field = unmodifiableMap.getDeclaredField("m"); + field.setAccessible(true); + Object obj = field.get(map); + res = ((Map) obj).get(key); + } catch (ReflectiveOperationException e) { + LOGGER.error("get field: {} has an error: {}", filedName, e); + } + return res; + } + + /** + * Pack the arguments to be invoked into a list of InvokeArg objects. + * + * @param args The arguments to be packed. + * @return A list of InvokeArg objects. + */ + public static List packFaasInvokeArgs(Object... args) { + List invokeArgs = new ArrayList(); + for (Object arg : args) { + InvokeArg invokeArg; + invokeArg = new InvokeArg(gson.toJson(arg).getBytes(StandardCharsets.UTF_8)); + invokeArg.setObjectRef(false); + invokeArg.setNestedObjects(new HashSet<>()); + invokeArgs.add(invokeArg); + } + return invokeArgs; + } + + /** + * get method of class + * + * @param clazz Class + * @param methodName String + * @return Method + * @throws NoSuchMethodException Exception + */ + public static Method getMethod(Class clazz, String methodName) throws NoSuchMethodException { + Method specifiedMethod = null; + Method[] methods = clazz.getDeclaredMethods(); + for (Method method : methods) { + if (methodName.equals(method.getName())) { + specifiedMethod = method; + break; + } + } + if (specifiedMethod == null) { + throw new NoSuchMethodException("cannot find such method: " + methodName); + } + return specifiedMethod; + } + + /** + * Get User Code Entry class and method + * + * @param userCodeEntry user code string + * @param isInitialize isInitialize + * @return [class, method] + * @throws NoSuchMethodException user code not found class or entry method + */ + public static String[] splitUserClassAndMethod(String userCodeEntry, boolean isInitialize) + throws NoSuchMethodException { + if (userCodeEntry == null || userCodeEntry.isEmpty()) { + throw new NoSuchMethodException(USER_CODE_CLASS_NOT_FOUND); + } + return splitEntryWithSeparators(userCodeEntry, + new String[] {DOUBLE_COLON_SEPARATOR, String.valueOf(DOT_SEPARATOR)}); + } + + /** + * Get User Code Entry class and method with separator + * + * @param userCodeEntry user code string + * @param separator separator in [., ::] + * @return [class, method] + * @throws NoSuchMethodException user code not found separator + */ + public static String[] splitEntryWithSeparators(String userCodeEntry, String[] separator) + throws NoSuchMethodException { + for (String sep : separator) { + int lastIndex = userCodeEntry.lastIndexOf(sep); + if (lastIndex == -1) { + continue; + } + if (String.valueOf(DOT_SEPARATOR).equals(sep)) { + return new String[]{userCodeEntry.substring(0, lastIndex), + userCodeEntry.substring(lastIndex + 1)}; + } else if (DOUBLE_COLON_SEPARATOR.equals(sep)) { + // classPath::Method lastIndex must plus 2 + return new String[]{userCodeEntry.substring(0, lastIndex), + userCodeEntry.substring(lastIndex + 2)}; + } else { + throw new NoSuchMethodException("user class and method separator invalid"); + } + } + throw new NoSuchMethodException("cannot separate user entry: " + userCodeEntry); + } } \ No newline at end of file diff --git a/api/java/function-common/src/main/java/com/yuanrong/serialization/strategy/Strategy.java b/api/java/function-common/src/main/java/com/yuanrong/serialization/strategy/Strategy.java index 4755c53..b0d5ca0 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/serialization/strategy/Strategy.java +++ b/api/java/function-common/src/main/java/com/yuanrong/serialization/strategy/Strategy.java @@ -94,7 +94,8 @@ public class Strategy { * @param getRes the byte array list to retrieve objects from * @param refs the object reference list to use for deserialization * @return a list of deserialized objects - * @throws YRException if there is an error retrieving or deserializing the objects. + * @throws YRException if there is an error retrieving or deserializing + * the objects */ public static List getObjects(List getRes, List refs) throws YRException { LOGGER.debug("getting objects {}", refs.size()); diff --git a/api/java/function-common/src/main/java/com/yuanrong/stream/Consumer.java b/api/java/function-common/src/main/java/com/yuanrong/stream/Consumer.java new file mode 100644 index 0000000..730d80b --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/stream/Consumer.java @@ -0,0 +1,69 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import com.yuanrong.exception.YRException; + +import java.util.List; + +/** + * Consumer interface class. + * + * @class Consumer Consumer.java "function-common/src/main/java/com/yuanrong/stream/Consumer.java". + * + * @since 2024/09/04 + */ +public interface Consumer { + /** + * The consumer receives data with a subscription function. The consumer waits for expectNum elements. + * The call returns when the timeout time timeoutMs is reached or the expected number of data is received. + * + * @param expectNum The number of elements expected to be received. + * @param timeoutMs Timeout for receiving. + * @return List A list of Elements that store data. + * @throws YRException Unified exception types thrown. + */ + List receive(long expectNum, int timeoutMs) throws YRException; + + /** + * The consumer receives data with a subscription function. The call returns when the timeout time timeoutMs is + * reached. + * + * @param timeoutMs Timeout for receiving. + * @return List A list of Elements that store data. + * @throws YRException Unified exception types thrown. + */ + List receive(int timeoutMs) throws YRException; + + /** + * After a consumer finishes using an element identified by a certain elementId, it needs to confirm that it has + * finished consuming, so that each worker can obtain information on whether all consumers have finished consuming. + * If a certain page has been consumed, the internal memory recovery mechanism can be triggered. If not ack it will + * be automatically ack when the consumer exits. + * + * @param elementId The id of the consumed element to be confirmed. + * @throws YRException Unified exception types thrown. + */ + void ack(long elementId) throws YRException; + + /** + * Close the consumer. Once closed, the consumer cannot be used. + * + * @throws YRException Unified exception types thrown. + */ + void close() throws YRException; +} diff --git a/api/java/function-common/src/main/java/com/yuanrong/stream/ConsumerImpl.java b/api/java/function-common/src/main/java/com/yuanrong/stream/ConsumerImpl.java new file mode 100644 index 0000000..8191404 --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/stream/ConsumerImpl.java @@ -0,0 +1,151 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.errorcode.Pair; +import com.yuanrong.exception.YRException; +import com.yuanrong.exception.handler.traceback.StackTraceUtils; +import com.yuanrong.jni.JniConsumer; + +import java.util.List; +import java.util.Objects; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * The consumer implement, visible in this package. + * + * @since 2024-09-04 + */ +public class ConsumerImpl implements Consumer { + // for consumerPtr. + private final ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock(); + + private final Lock rLock = rwLock.readLock(); + + private final Lock wLock = rwLock.writeLock(); + + // Point to jni consumer object. + private long consumerPtr; + + /** + * set the consumerPtr + * + * @param consumerPtr the consumerPtr + */ + public ConsumerImpl(long consumerPtr) { + this.consumerPtr = consumerPtr; + } + + @Override + public List receive(long expectNum, int timeoutMs) throws YRException { + rLock.lock(); + try { + ensureOpen(); + Pair> res = JniConsumer.receive(consumerPtr, expectNum, timeoutMs, true); + StackTraceUtils.checkErrorAndThrowForInvokeException(res.getFirst(), res.getFirst().getErrorMessage()); + return res.getSecond(); + } finally { + rLock.unlock(); + } + } + + @Override + public List receive(int timeoutMs) throws YRException { + rLock.lock(); + try { + ensureOpen(); + Pair> res = JniConsumer.receive(consumerPtr, 0L, timeoutMs, false); + StackTraceUtils.checkErrorAndThrowForInvokeException(res.getFirst(), res.getFirst().getErrorMessage()); + return res.getSecond(); + } finally { + rLock.unlock(); + } + } + + @Override + public void ack(long elementId) throws YRException { + rLock.lock(); + try { + ensureOpen(); + ErrorInfo err = JniConsumer.ack(consumerPtr, elementId); + StackTraceUtils.checkErrorAndThrowForInvokeException(err, err.getErrorMessage()); + } finally { + rLock.unlock(); + } + } + + /** + * Checks to make sure that producer has not been closed. + * + * @throws YRException the YRException + */ + private void ensureOpen() throws YRException { + if (consumerPtr == 0) { + throw new YRException(ErrorCode.ERR_INNER_SYSTEM_ERROR, ModuleCode.RUNTIME, + this.getClass().getName() + ": Consumer closed"); + } + } + + @Override + public void close() throws YRException { + wLock.lock(); + try { + if (consumerPtr == 0) { + throw new YRException(ErrorCode.ERR_INNER_SYSTEM_ERROR, ModuleCode.RUNTIME, + "Consumer has been closed"); + } + ErrorInfo err = JniConsumer.close(consumerPtr); + consumerPtr = 0; + StackTraceUtils.checkErrorAndThrowForInvokeException(err, err.getErrorMessage()); + } finally { + wLock.unlock(); + } + } + + @Override + public String toString() { + return super.toString(); + } + + @Override + protected void finalize() { + if (consumerPtr != 0) { + JniConsumer.freeJNIPtrNative(consumerPtr); + } + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + ConsumerImpl otherConsumerImpl = (ConsumerImpl) other; + return consumerPtr == otherConsumerImpl.consumerPtr; + } + + @Override + public int hashCode() { + return Objects.hash(consumerPtr); + } +} diff --git a/api/java/function-common/src/main/java/com/yuanrong/stream/Element.java b/api/java/function-common/src/main/java/com/yuanrong/stream/Element.java new file mode 100644 index 0000000..fe67b08 --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/stream/Element.java @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import lombok.Data; +import lombok.experimental.Accessors; + +import java.nio.ByteBuffer; +import java.util.Objects; + +/** + * Element class that contains element id and data cache. + * + * @since 2024/09/04 + */ +@Data +@Accessors(chain = true) +public class Element { + /** + * The id of the element. + */ + private long id; + + /** + * Data cache. + */ + private ByteBuffer buffer; + + /** + * The constructor of Element. + * + * @param id The id of the element. + * @param buffer Stored data + */ + public Element(long id, ByteBuffer buffer) { + this.id = id; + this.buffer = buffer; + } + + /** + * Default constructor for Element. + */ + public Element() {} + + @Override + public String toString() { + return super.toString(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + Element otherElement = (Element) other; + return id == otherElement.id; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/api/java/function-common/src/main/java/com/yuanrong/stream/Producer.java b/api/java/function-common/src/main/java/com/yuanrong/stream/Producer.java new file mode 100644 index 0000000..6cd8d84 --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/stream/Producer.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import com.yuanrong.exception.YRException; + +/** + * Producer interface class. + * + * @class Producer Producer.java "function-common/src/main/java/com/yuanrong/stream/Producer.java". + * + * @since 2024/09/04 + */ +public interface Producer { + /** + * The producer sends data, which is first placed in a buffer. The buffer is flushed according to the configured + * automatic flush policy (send at a certain interval or when the buffer is full) or by actively calling flush to + * allow consumers to access it. + * + * @param element The Element data to be sent. Element can refer to the Element object structure in the public + * structure. + * @throws YRException Unified exception types thrown. + */ + void send(Element element) throws YRException; + + /** + * The producer sends data, which is first placed in a buffer. The buffer is flushed according to the configured + * automatic flush policy (send at a certain interval or when the buffer is full) or by actively calling flush to + * allow consumers to access it. + * + * @param element The Element data to be sent. Element can refer to the Element object structure in the public + * structure. + * @param timeoutMs Timeout period. + * @throws YRException Unified exception types thrown. + */ + void send(Element element, int timeoutMs) throws YRException; + + /** + * Manually flush the buffer data to make it visible to consumers. + * + * @throws YRException Unified exception types thrown. + */ + void flush() throws YRException; + + /** + * Closing a producer triggers an automatic flush of the data buffer and indicates that the data buffer is no longer + * in use. Once closed, the producer can no longer be used. + * + * @throws YRException Unified exception types thrown. + */ + void close() throws YRException; +} diff --git a/api/java/function-common/src/main/java/com/yuanrong/stream/ProducerConfig.java b/api/java/function-common/src/main/java/com/yuanrong/stream/ProducerConfig.java new file mode 100644 index 0000000..f0d4c06 --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/stream/ProducerConfig.java @@ -0,0 +1,219 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import lombok.Data; +import lombok.experimental.Accessors; + +import java.util.HashMap; +import java.util.Map; + +/** + * Configuration class for creating the producer + * + * @since 2024/09/04 + */ +@Data +@Accessors(chain = true) +public class ProducerConfig { + /** + * After sending, the flush is triggered after the maximum delay time. <0: No automatic flush. 0: Flush immediately. + * Otherwise, it indicates the delay time in milliseconds. + */ + private long delayFlushTimeMs = 5L; + + /** + * Represents the buffer page size corresponding to the producer, in bytes (B); when the page is full, flush is + * triggered. The default is ``1MB``, and it must be greater than 0 and a multiple of 4K. + */ + private long pageSizeByte = 1024 * 1024L; + + /** + * Specifies the maximum amount of shared memory that a stream can use on a worker, in units of B (bytes). + * The default is ``100MB``, with a range of [64KB, the size of the worker's shared memory]. + */ + private long maxStreamSize = 100 * 1024 * 1024L; + + /** + * Specifies whether the stream has the automatic cleanup feature enabled. The default is ``false``, which means it + * is disabled. + */ + private boolean autoCleanup = false; + + /** + * Specifies whether the stream has the content encryption feature enabled. The default is ``false``, which means it + * is disabled. + */ + private boolean encryptStream = false; + + /** + * The data sent by the producer will be retained until the Nth consumer receives it. The default value is ``0``, + * which means that if there are no consumers when the producer sends the data, the data will not be retained, + * and the consumer might not receive the data after it is created. This parameter is only effective for the first + * consumer created, and the current valid range is [0, 1], and it does not support multiple consumers. + */ + private long retainForNumConsumers = 0L; + + /** + * Represents the reserved memory size, in units of B (bytes). When creating a producer, it will attempt to reserve + * reserveSize bytes of memory. If the reservation fails, an exception will be thrown during the creation of the + * producer. reserveSize must be an integer multiple of pageSize, and its value range is [0, maxStreamSize]. If + * reserveSize is 0, it will be set to pageSize. The default value is ``0``. + */ + private long reserveSize = 0L; + + /** + * Producer expansion configuration. Common configuration items are as follows:\n + * "STREAM_MODE": Stream mode, can be ``MPMC``, ``MPSC``, or ``SPSC``, default is ``MPMC``. If it is not one of the + * above modes, an exception will be thrown. ``MPMC`` stands for multiple producers and multiple consumers, ``MPSC`` + * stands for multiple producers and single consumer, and ``SPSC`` stands for single producer and single consumer. + * If it is ``MPSC`` or ``SPSC`` mode, the data system internally enables the multi-stream shared Page function. + */ + private Map extendConfig = new HashMap<>(); + + /** + * The ProducerConfig class Builder. + * + * @return Builder object of ProducerConfig class. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * The ProducerConfig Builder Class. + */ + public static class Builder { + private ProducerConfig producerConfig; + + /** + * The ProducerConfig Builder. + */ + public Builder() { + producerConfig = new ProducerConfig(); + } + + /** + * Sets the delayFlushTimeMs of ProducerConfig class. + * + * @param delayFlushTimeMs the delayFlushTimeMs. + * @return ProducerConfig Builder class object. + */ + public Builder delayFlushTimeMs(long delayFlushTimeMs) { + producerConfig.delayFlushTimeMs = delayFlushTimeMs; + return this; + } + + /** + * Sets the pageSizeByte of ProducerConfig class. + * + * @param pageSizeByte the pageSizeByte. + * @return ProducerConfig Builder class object. + */ + public Builder pageSizeByte(long pageSizeByte) { + producerConfig.pageSizeByte = pageSizeByte; + return this; + } + + /** + * Sets the maxStreamSize of ProducerConfig class. + * + * @param maxStreamSize the maxStreamSize. + * @return ProducerConfig Builder class object. + */ + public Builder maxStreamSize(long maxStreamSize) { + producerConfig.maxStreamSize = maxStreamSize; + return this; + } + + /** + * Sets the autoCleanup of ProducerConfig class. + * + * @param autoCleanup the autoCleanup. + * @return ProducerConfig Builder class object. + */ + public Builder autoCleanup(boolean autoCleanup) { + producerConfig.autoCleanup = autoCleanup; + return this; + } + + /** + * Sets the encryptStream of ProducerConfig class. + * + * @param encryptStream the encryptStream. + * @return ProducerConfig Builder class object. + */ + public Builder encryptStream(boolean encryptStream) { + producerConfig.encryptStream = encryptStream; + return this; + } + + /** + * Set the retainForNumConsumers. + * + * @param retainForNumConsumers retainForNumConsumers. + * @return ProducerConfig Builder class object. + */ + public Builder retainForNumConsumers(long retainForNumConsumers) { + producerConfig.retainForNumConsumers = retainForNumConsumers; + return this; + } + + /** + * Sets the reserveSize of ProducerConfig class. + * + * @param reserveSize the reserveSize. + * @return ProducerConfig Builder class object. + */ + public Builder reserveSize(long reserveSize) { + producerConfig.reserveSize = reserveSize; + return this; + } + + /** + * Add the extendConfig of ProducerConfig class. + * + * @param key the extendConfig key. + * @param val the extendConfig value. + * @return ProducerConfig Builder class object. + */ + public Builder addExtendConfig(String key, String val) { + producerConfig.extendConfig.put(key, val); + return this; + } + + /** + * Sets the extendConfig of ProducerConfig class. + * + * @param extendConfig the extendConfig map. + * @return ProducerConfig Builder class object. + */ + public Builder extendConfig(Map extendConfig) { + producerConfig.extendConfig.putAll(extendConfig); + return this; + } + + /** + * Builds the ProducerConfig object. + * + * @return ProducerConfig class object. + */ + public ProducerConfig build() { + return producerConfig; + } + } +} diff --git a/api/java/function-common/src/main/java/com/yuanrong/stream/ProducerImpl.java b/api/java/function-common/src/main/java/com/yuanrong/stream/ProducerImpl.java new file mode 100644 index 0000000..158902f --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/stream/ProducerImpl.java @@ -0,0 +1,160 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.exception.YRException; +import com.yuanrong.exception.handler.traceback.StackTraceUtils; +import com.yuanrong.jni.JniProducer; + +import java.nio.ByteBuffer; +import java.util.Objects; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * The producer implement, visible in this package. + * + * @since 2024-09-04 + */ +public class ProducerImpl implements Producer { + // for producerPtr. + private final ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock(); + private final Lock rLock = rwLock.readLock(); + private final Lock wLock = rwLock.writeLock(); + + // Point to jni producer object. + private long producerPtr; + + /** + * set the producerPtr + * + * @param producerPtr the producerPtr + */ + public ProducerImpl(long producerPtr) { + this.producerPtr = producerPtr; + } + + @Override + public void send(Element element) throws YRException { + ByteBuffer buffer = element.getBuffer(); + rLock.lock(); + try { + ensureOpen(); + ErrorInfo err = null; + if (buffer.isDirect()) { + err = JniProducer.sendDirectBufferDefaultTimeout(producerPtr, buffer); + } else { + err = JniProducer.sendHeapBufferDefaultTimeout(producerPtr, buffer.array(), buffer.limit()); + } + StackTraceUtils.checkErrorAndThrowForInvokeException(err, err.getErrorMessage()); + } finally { + rLock.unlock(); + } + } + + @Override + public void send(Element element, int timeoutMs) throws YRException { + ByteBuffer buffer = element.getBuffer(); + rLock.lock(); + try { + ensureOpen(); + ErrorInfo err; + if (buffer.isDirect()) { + err = JniProducer.sendDirectBuffer(producerPtr, buffer, timeoutMs); + } else { + err = JniProducer.sendHeapBuffer(producerPtr, buffer.array(), buffer.limit(), timeoutMs); + } + if (err != null) { + StackTraceUtils.checkErrorAndThrowForInvokeException(err, err.getErrorMessage()); + } + } finally { + rLock.unlock(); + } + } + + @Override + public void flush() throws YRException { + rLock.lock(); + try { + ensureOpen(); + ErrorInfo err = JniProducer.flush(this.producerPtr); + StackTraceUtils.checkErrorAndThrowForInvokeException(err, err.getErrorMessage()); + } finally { + rLock.unlock(); + } + } + + /** + * Checks to make sure that producer has not been closed. + * + * @throws YRException the YRException + */ + private void ensureOpen() throws YRException { + if (producerPtr == 0) { + throw new YRException(ErrorCode.ERR_INNER_SYSTEM_ERROR, ModuleCode.RUNTIME, + this.getClass().getName() + ": Producer closed"); + } + } + + @Override + public void close() throws YRException { + wLock.lock(); + try { + if (producerPtr == 0) { + throw new YRException(ErrorCode.ERR_INNER_SYSTEM_ERROR, ModuleCode.RUNTIME, + "Consumer has been closed"); + } + ErrorInfo err = JniProducer.close(this.producerPtr); + producerPtr = 0; + StackTraceUtils.checkErrorAndThrowForInvokeException(err, err.getErrorMessage()); + } finally { + wLock.unlock(); + } + } + + @Override + protected void finalize() { + if (producerPtr != 0) { + JniProducer.freeJNIPtrNative(producerPtr); + } + } + + @Override + public String toString() { + return super.toString(); + } + + @Override + public int hashCode() { + return Objects.hash(producerPtr); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + ProducerImpl otherProducerImpl = (ProducerImpl) other; + return producerPtr == otherProducerImpl.producerPtr; + } +} diff --git a/api/java/function-common/src/main/java/com/yuanrong/stream/SubscriptionConfig.java b/api/java/function-common/src/main/java/com/yuanrong/stream/SubscriptionConfig.java new file mode 100644 index 0000000..8b038d0 --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/stream/SubscriptionConfig.java @@ -0,0 +1,125 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import lombok.Data; +import lombok.experimental.Accessors; + +import java.util.HashMap; +import java.util.Map; + +/** + * Consumer subscription configuration class + * + * @since 2024/09/04 + */ +@Data +@Accessors(chain = true) +public class SubscriptionConfig { + /** + * Subscription name + */ + private String subscriptionName = ""; + + /** + * Subscription types include ``STREAM``, ``ROUND_ROBIN``, and ``KEY_PARTITIONS``. Currently, only the ``STREAM`` + * type is supported. Other types are not supported for the time being. The default subscription type is ``STREAM``. + */ + private SubscriptionType subscriptionType = SubscriptionType.STREAM; + + /** + * Indicates extended configuration, reserved field. + */ + private Map extendConfig = new HashMap<>(); + + /** + * The SubscriptionConfig class Builder. + * + * @return Builder object of SubscriptionConfig class. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * The Builder class of SubscriptionConfig class. + */ + public static class Builder { + private SubscriptionConfig subscriptionConfig; + + /** + * The SubscriptionConfig Builder. + */ + public Builder() { + subscriptionConfig = new SubscriptionConfig(); + } + + /** + * Sets the subscriptionName of SubscriptionConfig class. + * + * @param subscriptionName the subscriptionName. + * @return SubscriptionConfig Builder class object. + */ + public Builder subscriptionName(String subscriptionName) { + subscriptionConfig.subscriptionName = subscriptionName; + return this; + } + + /** + * Sets the subscriptionName of SubscriptionConfig class. + * + * @param subscriptionType the subscriptionType. + * @return SubscriptionConfig Builder class object. + */ + public Builder subscriptionType(SubscriptionType subscriptionType) { + subscriptionConfig.subscriptionType = subscriptionType; + return this; + } + + /** + * Add the extendConfig of SubscriptionConfig class. + * + * @param key the extendConfig key. + * @param val the extendConfig value. + * @return SubscriptionConfig Builder class object. + */ + public Builder addExtendConfig(String key, String val) { + subscriptionConfig.extendConfig.put(key, val); + return this; + } + + /** + * Sets the extendConfig of SubscriptionConfig class. + * + * @param extendConfig the extendConfig map. + * @return SubscriptionConfig Builder class object. + */ + public Builder extendConfig(Map extendConfig) { + subscriptionConfig.extendConfig.putAll(extendConfig); + return this; + } + + /** + * Builds the SubscriptionConfig object. + * + * @return SubscriptionConfig class object. + */ + public SubscriptionConfig build() { + return subscriptionConfig; + } + } +} diff --git a/api/java/function-common/src/main/java/com/yuanrong/stream/SubscriptionType.java b/api/java/function-common/src/main/java/com/yuanrong/stream/SubscriptionType.java new file mode 100644 index 0000000..85dbfae --- /dev/null +++ b/api/java/function-common/src/main/java/com/yuanrong/stream/SubscriptionType.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +/** + * Specifies the subscription mode + * + * @since 2024-09-04 + */ +public enum SubscriptionType { + STREAM, ROUND_ROBIN, KEY_PARTITIONS +} diff --git a/api/java/function-common/src/test/java/com/services/model/TestFaaSModel.java b/api/java/function-common/src/test/java/com/services/model/TestFaaSModel.java new file mode 100644 index 0000000..0babfa6 --- /dev/null +++ b/api/java/function-common/src/test/java/com/services/model/TestFaaSModel.java @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.services.model; + +import com.google.gson.JsonObject; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +public class TestFaaSModel { + @Test + public void testCallRequest() { + CallRequest callRequest = new CallRequest(); + callRequest.setHeader(new HashMap(){ + { + put("key1", "val1"); + } + }); + callRequest.setBody("body1"); + Assert.assertEquals(callRequest.getHeader().get("key1"), "val1"); + Assert.assertEquals(callRequest.getBody(), "body1"); + } + + @Test + public void testCallResponseJsonObject() { + CallResponseJsonObject callResponse = new CallResponseJsonObject(); + JsonObject jsonObject = new JsonObject(); + jsonObject.addProperty("name", "faas"); + callResponse.setBody(jsonObject); + callResponse.setInnerCode("200"); + callResponse.setBillingDuration("x-bill-duration"); + callResponse.setLogResult("x-log-result"); + callResponse.setInvokerSummary("x-invoke-summary"); + Assert.assertEquals(callResponse.getInnerCode(), "200"); + Assert.assertEquals(callResponse.getBillingDuration(), "x-bill-duration"); + Assert.assertEquals(callResponse.getLogResult(), "x-log-result"); + Assert.assertEquals(callResponse.getInvokerSummary(), "x-invoke-summary"); + Assert.assertEquals(callResponse.getBody().get("name").getAsString(), "faas"); + } +} diff --git a/api/java/function-common/src/test/java/com/services/runtime/action/TestContextImpl.java b/api/java/function-common/src/test/java/com/services/runtime/action/TestContextImpl.java index ac99e6c..5a04c6e 100644 --- a/api/java/function-common/src/test/java/com/services/runtime/action/TestContextImpl.java +++ b/api/java/function-common/src/test/java/com/services/runtime/action/TestContextImpl.java @@ -63,6 +63,10 @@ public class TestContextImpl { Assert.assertEquals(context.getRequestID(), "requestID"); Assert.assertEquals(context.getTraceID(), "requestID"); Assert.assertEquals(context.getInvokeID(), ""); + Assert.assertNull(context.getAccessKey()); + Assert.assertNull(context.getSecretKey()); + Assert.assertNull(context.getToken()); + Assert.assertNull(context.getSecurityToken()); Assert.assertNull(context.getAlias()); Assert.assertNull(context.getInvokeProperty()); Assert.assertNotNull(context.getLogger()); @@ -120,16 +124,28 @@ public class TestContextImpl { @Test public void testContextInvokeParams() { ContextInvokeParams param = new ContextInvokeParams(); + param.setAccessKey("ak"); + param.setSecretKey("sk"); + param.setSecurityAccessKey("sak"); + param.setSecuritySecretKey("ssk"); param.setRequestID("reqID"); param.setInvokeID("invokeID"); + param.setToken("token"); + param.setSecurityToken("sToken"); param.setAlias("alias"); param.setWorkflowID("workflowID"); param.setWorkflowRunID("runid"); param.setWorkflowStateID("statID"); param.setReqStreamName("reqStreamName"); param.setRespStreamName("respStreamName"); + Assert.assertEquals(param.getAccessKey(), "ak"); + Assert.assertEquals(param.getSecretKey(), "sk"); + Assert.assertEquals(param.getSecurityAccessKey(), "sak"); + Assert.assertEquals(param.getSecuritySecretKey(), "ssk"); Assert.assertEquals(param.getRequestID(), "reqID"); Assert.assertEquals(param.getInvokeID(), "invokeID"); + Assert.assertEquals(param.getToken(), "token"); + Assert.assertEquals(param.getSecurityToken(), "sToken"); Assert.assertEquals(param.getAlias(), "alias"); Assert.assertEquals(param.getWorkflowID(), "workflowID"); Assert.assertEquals(param.getWorkflowRunID(), "runid"); diff --git a/api/java/function-common/src/test/java/com/yuanrong/TestInvokeOptions.java b/api/java/function-common/src/test/java/com/yuanrong/TestInvokeOptions.java index 24686f8..91cdfd3 100644 --- a/api/java/function-common/src/test/java/com/yuanrong/TestInvokeOptions.java +++ b/api/java/function-common/src/test/java/com/yuanrong/TestInvokeOptions.java @@ -18,6 +18,7 @@ package com.yuanrong; import com.yuanrong.affinity.Affinity; import com.yuanrong.affinity.AffinityKind; +import com.yuanrong.affinity.AffinityScope; import com.yuanrong.affinity.AffinityType; import com.yuanrong.affinity.LabelOperator; import com.yuanrong.affinity.OperatorType; @@ -99,6 +100,24 @@ public class TestInvokeOptions { Assert.assertEquals(expectedValue, options.getInstancePriority()); } + @Test + public void testGroupName() { + String inputGroupName = "testGroupName"; + InvokeOptions options = new InvokeOptions.Builder() + .groupName(inputGroupName) + .build(); + Assert.assertEquals(inputGroupName, options.getGroupName()); + } + + @Test + public void testTraceId() { + String inputTraceId = "testTraceId"; + InvokeOptions options = new InvokeOptions.Builder() + .traceId(inputTraceId) + .build(); + Assert.assertEquals(inputTraceId, options.getTraceId()); + } + @Test public void testPreemptedAllowed() { boolean expectedValue = true; @@ -139,8 +158,8 @@ public class TestInvokeOptions { testCustomExtensions.put(expectedMapKey, expectedMapValue); testCustomExtensions.put(Constants.POST_START_EXEC, "false"); InvokeOptions options = new InvokeOptions.Builder().customExtensions(testCustomExtensions) - .addCustomExtensions(expectedMapKey, expectedMapValue) - .addCustomExtensions(Constants.POST_START_EXEC, "true") + .addCustomExtension(expectedMapKey, expectedMapValue) + .addCustomExtension(Constants.POST_START_EXEC, "true") .build(); options.setCustomExtensions(testCustomExtensions); options.addCustomExtensions(expectedMapKey, expectedMapValue); @@ -207,12 +226,13 @@ public class TestInvokeOptions { testAffinityList.add(affinity); InvokeOptions options = new InvokeOptions.Builder().scheduleAffinity(testAffinityList) .addInstanceAffinity(AffinityType.PREFERRED, testOperatorsList) + .addInstanceAffinity(AffinityType.PREFERRED, testOperatorsList, AffinityScope.NODE) .addResourceAffinity(AffinityType.PREFERRED, testOperatorsList) .addScheduleAffinity(affinity) .build(); options.parserAffinityMsgFromJsonStr(""); options.parserAffinityMsgFromJsonStr(options.affinityMsgToJsonStr()); - Assert.assertEquals(4, options.getScheduleAffinities().size()); + Assert.assertEquals(5, options.getScheduleAffinities().size()); InvokeOptions newOptions = new InvokeOptions(options); List newList = newOptions.getScheduleAffinities(); diff --git a/api/java/function-common/src/test/java/com/yuanrong/affinity/TestAffinity.java b/api/java/function-common/src/test/java/com/yuanrong/affinity/TestAffinity.java index 8cc6673..22a0a87 100644 --- a/api/java/function-common/src/test/java/com/yuanrong/affinity/TestAffinity.java +++ b/api/java/function-common/src/test/java/com/yuanrong/affinity/TestAffinity.java @@ -46,6 +46,9 @@ public class TestAffinity { affinity3.setInstanceAffinity(new InstanceAffinity()); Assert.assertFalse(affinity.equals(affinity1)); Assert.assertTrue(affinity.equals(affinity)); + + Affinity affinity5 = new Affinity(AffinityKind.INSTANCE, AffinityType.PREFERRED, new ArrayList<>(), AffinityScope.POD); + Assert.assertEquals(affinity5.getAffinityScope(), AffinityScope.POD); } @Test @@ -88,4 +91,14 @@ public class TestAffinity { Affinity affinity1 = new Affinity(null, null, new ArrayList<>()); Assert.assertEquals(0, affinity1.getAffinityValue()); } + + @Test + public void testGetAffinityScopeValue() { + Affinity affinity = new Affinity(AffinityKind.RESOURCE, AffinityType.PREFERRED, new ArrayList<>()); + Assert.assertEquals(0, affinity.getAffinityScopeValue()); + + Affinity affinity1 = new Affinity(AffinityKind.INSTANCE, AffinityType.PREFERRED_ANTI, new ArrayList<>(), + AffinityScope.NODE); + Assert.assertEquals(2, affinity1.getAffinityScopeValue()); + } } diff --git a/api/java/function-common/src/test/java/com/yuanrong/affinity/TestInstanceAffinity.java b/api/java/function-common/src/test/java/com/yuanrong/affinity/TestInstanceAffinity.java index 335561d..0f894fb 100644 --- a/api/java/function-common/src/test/java/com/yuanrong/affinity/TestInstanceAffinity.java +++ b/api/java/function-common/src/test/java/com/yuanrong/affinity/TestInstanceAffinity.java @@ -62,8 +62,122 @@ public class TestInstanceAffinity { instanceAffinity6.toString(); instanceAffinity.canEqual(instanceAffinity2); - Assert.assertEquals(1, instanceAffinity11.getScope().getScope()); + Assert.assertEquals(2, instanceAffinity11.getScope().getScope()); Assert.assertFalse(instanceAffinity3.equals(instanceAffinity8)); Assert.assertEquals(AffinityScope.POD, instanceAffinity.getScope()); } + + @Test + public void testInitInstanceAffinityPreferredAffinity() { + InstanceAffinity instanceAffinity = new InstanceAffinity(AffinityType.PREFERRED, new Selector(), + AffinityScope.NODE); + + Assert.assertNotNull(instanceAffinity.getPreferredAffinity()); + + Assert.assertNull(instanceAffinity.getPreferredAntiAffinity()); + Assert.assertNull(instanceAffinity.getRequiredAffinity()); + Assert.assertNull(instanceAffinity.getRequiredAntiAffinity()); + + Assert.assertEquals(AffinityScope.NODE, instanceAffinity.getScope()); + } + + @Test + public void testInitInstanceAffinityPreferredAntiAffinity() { + InstanceAffinity instanceAffinity = new InstanceAffinity(AffinityType.PREFERRED_ANTI, new Selector(), + AffinityScope.NODE); + + Assert.assertNotNull(instanceAffinity.getPreferredAntiAffinity()); + + Assert.assertNull(instanceAffinity.getPreferredAffinity()); + Assert.assertNull(instanceAffinity.getRequiredAffinity()); + Assert.assertNull(instanceAffinity.getRequiredAntiAffinity()); + + Assert.assertEquals(AffinityScope.NODE, instanceAffinity.getScope()); + } + + @Test + public void testInitInstanceAffinityRequiredAffinity() { + InstanceAffinity instanceAffinity = new InstanceAffinity(AffinityType.REQUIRED, new Selector(), + AffinityScope.NODE); + + Assert.assertNotNull(instanceAffinity.getRequiredAffinity()); + + Assert.assertNull(instanceAffinity.getPreferredAffinity()); + Assert.assertNull(instanceAffinity.getPreferredAntiAffinity()); + Assert.assertNull(instanceAffinity.getRequiredAntiAffinity()); + + Assert.assertEquals(AffinityScope.NODE, instanceAffinity.getScope()); + } + + @Test + public void testInitInstanceAffinityRequiredAntiAffinity() { + InstanceAffinity instanceAffinity = new InstanceAffinity(AffinityType.REQUIRED_ANTI, new Selector(), + AffinityScope.NODE); + + Assert.assertNotNull(instanceAffinity.getRequiredAntiAffinity()); + + Assert.assertNull(instanceAffinity.getPreferredAffinity()); + Assert.assertNull(instanceAffinity.getPreferredAntiAffinity()); + Assert.assertNull(instanceAffinity.getRequiredAffinity()); + + Assert.assertEquals(AffinityScope.NODE, instanceAffinity.getScope()); + } + + @Test + public void testInitInstanceAffinitySelector() { + Selector pa = new Selector(); + InstanceAffinity instanceAffinity = new InstanceAffinity(pa); + + Assert.assertEquals(pa, instanceAffinity.getPreferredAffinity()); + + Assert.assertNull(instanceAffinity.getPreferredAntiAffinity()); + Assert.assertNull(instanceAffinity.getRequiredAffinity()); + Assert.assertNull(instanceAffinity.getRequiredAntiAffinity()); + + Assert.assertEquals(AffinityScope.NODE, instanceAffinity.getScope()); + } + + @Test + public void testInitInstanceAffinitySelectorScope() { + Selector pa = new Selector(); + InstanceAffinity instanceAffinity = new InstanceAffinity(pa, AffinityScope.POD); + + Assert.assertEquals(pa, instanceAffinity.getPreferredAffinity()); + + Assert.assertNull(instanceAffinity.getPreferredAntiAffinity()); + Assert.assertNull(instanceAffinity.getRequiredAffinity()); + Assert.assertNull(instanceAffinity.getRequiredAntiAffinity()); + + Assert.assertEquals(AffinityScope.POD, instanceAffinity.getScope()); + } + + @Test + public void testInitInstanceAffinitySelectors() { + Selector pa = new Selector(); + Selector paa = new Selector(); + Selector ra = new Selector(); + Selector raa = new Selector(); + InstanceAffinity instanceAffinity = new InstanceAffinity(pa, paa, ra, raa); + Assert.assertEquals(pa, instanceAffinity.getPreferredAffinity()); + Assert.assertEquals(paa, instanceAffinity.getPreferredAntiAffinity()); + Assert.assertEquals(ra, instanceAffinity.getRequiredAffinity()); + Assert.assertEquals(raa, instanceAffinity.getRequiredAntiAffinity()); + + Assert.assertEquals(AffinityScope.NODE, instanceAffinity.getScope()); + } + + @Test + public void testInitInstanceAffinitySelectorsScope() { + Selector pa = new Selector(); + Selector paa = new Selector(); + Selector ra = new Selector(); + Selector raa = new Selector(); + InstanceAffinity instanceAffinity = new InstanceAffinity(pa, paa, ra, raa, AffinityScope.POD); + Assert.assertEquals(pa, instanceAffinity.getPreferredAffinity()); + Assert.assertEquals(paa, instanceAffinity.getPreferredAntiAffinity()); + Assert.assertEquals(ra, instanceAffinity.getRequiredAffinity()); + Assert.assertEquals(raa, instanceAffinity.getRequiredAntiAffinity()); + + Assert.assertEquals(AffinityScope.POD, instanceAffinity.getScope()); + } } diff --git a/api/java/function-common/src/test/java/com/yuanrong/affinity/TestResourceAffinity.java b/api/java/function-common/src/test/java/com/yuanrong/affinity/TestResourceAffinity.java index 831efdd..4ab1f10 100644 --- a/api/java/function-common/src/test/java/com/yuanrong/affinity/TestResourceAffinity.java +++ b/api/java/function-common/src/test/java/com/yuanrong/affinity/TestResourceAffinity.java @@ -54,4 +54,74 @@ public class TestResourceAffinity { Assert.assertFalse(resourceAffinity10.equals(resourceAffinity11)); } + + @Test + public void testInitResourceAffinityPreferredAffinity() { + ResourceAffinity resourceAffinity = new ResourceAffinity(AffinityType.PREFERRED, new Selector()); + + Assert.assertNotNull(resourceAffinity.getPreferredAffinity()); + + Assert.assertNull(resourceAffinity.getPreferredAntiAffinity()); + Assert.assertNull(resourceAffinity.getRequiredAffinity()); + Assert.assertNull(resourceAffinity.getRequiredAntiAffinity()); + } + + @Test + public void testInitResourceAffinityPreferredAntiAffinity() { + ResourceAffinity resourceAffinity = new ResourceAffinity(AffinityType.PREFERRED_ANTI, new Selector()); + + Assert.assertNotNull(resourceAffinity.getPreferredAntiAffinity()); + + Assert.assertNull(resourceAffinity.getPreferredAffinity()); + Assert.assertNull(resourceAffinity.getRequiredAffinity()); + Assert.assertNull(resourceAffinity.getRequiredAntiAffinity()); + } + + @Test + public void testInitResourceAffinityRequiredAffinity() { + ResourceAffinity resourceAffinity = new ResourceAffinity(AffinityType.REQUIRED, new Selector()); + + Assert.assertNotNull(resourceAffinity.getRequiredAffinity()); + + Assert.assertNull(resourceAffinity.getPreferredAffinity()); + Assert.assertNull(resourceAffinity.getPreferredAntiAffinity()); + Assert.assertNull(resourceAffinity.getRequiredAntiAffinity()); + } + + @Test + public void testInitResourceAffinityRequiredAntiAffinity() { + ResourceAffinity resourceAffinity = new ResourceAffinity(AffinityType.REQUIRED_ANTI, new Selector()); + + Assert.assertNotNull(resourceAffinity.getRequiredAntiAffinity()); + + Assert.assertNull(resourceAffinity.getPreferredAffinity()); + Assert.assertNull(resourceAffinity.getPreferredAntiAffinity()); + Assert.assertNull(resourceAffinity.getRequiredAffinity()); + } + + @Test + public void testInitResourceAffinitySelector() { + Selector pa = new Selector(); + ResourceAffinity resourceAffinity = new ResourceAffinity(pa); + + Assert.assertEquals(pa, resourceAffinity.getPreferredAffinity()); + + Assert.assertNull(resourceAffinity.getPreferredAntiAffinity()); + Assert.assertNull(resourceAffinity.getRequiredAffinity()); + Assert.assertNull(resourceAffinity.getRequiredAntiAffinity()); + } + + @Test + public void testInitResourceAffinitySelectors() { + Selector pa = new Selector(); + Selector paa = new Selector(); + Selector ra = new Selector(); + Selector raa = new Selector(); + ResourceAffinity resourceAffinity = new ResourceAffinity(pa, paa, ra, raa); + + Assert.assertEquals(pa, resourceAffinity.getPreferredAffinity()); + Assert.assertEquals(paa, resourceAffinity.getPreferredAntiAffinity()); + Assert.assertEquals(ra, resourceAffinity.getRequiredAffinity()); + Assert.assertEquals(raa, resourceAffinity.getRequiredAntiAffinity()); + } } diff --git a/api/java/function-common/src/test/java/com/yuanrong/affinity/TestSelector.java b/api/java/function-common/src/test/java/com/yuanrong/affinity/TestSelector.java new file mode 100644 index 0000000..c22878d --- /dev/null +++ b/api/java/function-common/src/test/java/com/yuanrong/affinity/TestSelector.java @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.affinity; + +import org.junit.Assert; +import org.junit.Test; + +public class TestSelector { + @Test + public void testInitSelector() { + Selector selector = new Selector(); + + Condition cond = new Condition(); + Selector selector1 = new Selector(cond); + Assert.assertEquals(cond, selector1.getCondition()); + + selector.setCondition(new Condition()); + selector1.getCondition(); + selector1.hashCode(); + selector1.toString(); + selector1.canEqual(selector); + + Assert.assertNotEquals(selector, selector1); + Assert.assertFalse(selector.equals(selector1)); + } + +} diff --git a/api/java/function-common/src/test/java/com/yuanrong/affinity/TestSubCondition.java b/api/java/function-common/src/test/java/com/yuanrong/affinity/TestSubCondition.java index 6c488e6..eb2f360 100644 --- a/api/java/function-common/src/test/java/com/yuanrong/affinity/TestSubCondition.java +++ b/api/java/function-common/src/test/java/com/yuanrong/affinity/TestSubCondition.java @@ -46,8 +46,38 @@ public class TestSubCondition { } @Test - public void testInitSubCondition2() { + public void testInitSubConditionOperator() { + String labelKey = "key"; + LabelOperator operator = new LabelOperator(); + SubCondition subCondition = new SubCondition(labelKey, operator); + LabelExpression ex = subCondition.getExpressions().get(0); + Assert.assertEquals(labelKey, ex.getKey()); + Assert.assertEquals(operator, ex.getOperator()); + Assert.assertEquals(0, subCondition.getWeight()); + } + + @Test + public void testInitSubConditionExpression() { + LabelExpression ex = new LabelExpression(); + SubCondition subCondition = new SubCondition(ex); + Assert.assertEquals(ex, subCondition.getExpressions().get(0)); + Assert.assertEquals(0, subCondition.getWeight()); + } + @Test + public void testInitSubConditionExpressions() { + LabelExpression ex1 = new LabelExpression(); + LabelExpression ex2 = new LabelExpression(); + List expressions = new ArrayList<>(); + expressions.add(ex1); + expressions.add(ex2); + SubCondition subCondition = new SubCondition(expressions); + Assert.assertEquals(expressions, subCondition.getExpressions()); + Assert.assertEquals(0, subCondition.getWeight()); + } + + @Test + public void testInitSubConditionExpressionWeight() { String labelKey = "key"; List values = new ArrayList<>(); values.add("value1"); @@ -63,4 +93,20 @@ public class TestSubCondition { Assert.assertEquals(weight, subCondition.getWeight()); } + @Test + public void testInitSubConditionExpressionsWeight() { + LabelExpression ex1 = new LabelExpression(); + LabelExpression ex2 = new LabelExpression(); + List expressions = new ArrayList<>(); + expressions.add(ex1); + expressions.add(ex2); + + int weight = 10; + + SubCondition subCondition = new SubCondition(expressions, weight); + + Assert.assertEquals(expressions, subCondition.getExpressions()); + Assert.assertEquals(weight, subCondition.getWeight()); + } + } diff --git a/api/java/function-common/src/test/java/com/yuanrong/runtime/TestUtils.java b/api/java/function-common/src/test/java/com/yuanrong/runtime/TestUtils.java index 404a4f9..147b60b 100644 --- a/api/java/function-common/src/test/java/com/yuanrong/runtime/TestUtils.java +++ b/api/java/function-common/src/test/java/com/yuanrong/runtime/TestUtils.java @@ -30,6 +30,18 @@ import java.util.ArrayList; import java.util.HashMap; public class TestUtils { + @Test + public void testSplitEntryWithSeparators() { + String udfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.timeoutInitializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler-handler\"}"; + String[] entry = null; + try { + entry = Utils.splitEntryWithSeparators(udfEntry, new String[] {"-"}); + } catch (NoSuchMethodException e) { + Assert.assertEquals(null, entry); + } + } + @Test public void testInitUtils() throws NoSuchMethodException { Utils utils = new Utils(); @@ -90,6 +102,24 @@ public class TestUtils { } String stringWithLen = stringBuilder.toString(); Utils.getProcessedExceptionMsg(new Throwable(stringWithLen)); + Utils.put("test", null, new ArrayList<>(), null, 1L, 1L); + + isException = false; + try { + Utils.get(new ArrayList<>(), 1, 1); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); Utils.sleepSeconds(10); + + + isException = false; + try { + Utils.splitUserClassAndMethod(null,true); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); } } diff --git a/api/java/function-common/src/test/java/com/yuanrong/stream/TestConsumerImpl.java b/api/java/function-common/src/test/java/com/yuanrong/stream/TestConsumerImpl.java new file mode 100644 index 0000000..881611f --- /dev/null +++ b/api/java/function-common/src/test/java/com/yuanrong/stream/TestConsumerImpl.java @@ -0,0 +1,145 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.when; + +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.errorcode.Pair; +import com.yuanrong.jni.JniConsumer; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.ArrayList; +import java.util.List; + +@RunWith(PowerMockRunner.class) +@PrepareForTest( {JniConsumer.class}) +@SuppressStaticInitializationFor( {"com.yuanrong.jni.JniConsumer"}) +@PowerMockIgnore("javax.management.*") +public class TestConsumerImpl { + @Test + public void testInitConsumerImpl() { + ConsumerImpl consumer = new ConsumerImpl(10L); + ConsumerImpl consumer1 = new ConsumerImpl(20L); + consumer.toString(); + consumer1.hashCode(); + Assert.assertTrue(consumer.equals(consumer)); + Assert.assertFalse(consumer1.equals(null)); + Assert.assertFalse(consumer.equals(consumer1)); + } + + @Test + public void testReceive() { + ConsumerImpl consumer = new ConsumerImpl(10L); + PowerMockito.mockStatic(JniConsumer.class); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.CORE, ""); + Pair> errorInfoListPair = new Pair<>(errorInfo, new ArrayList<>()); + when(JniConsumer.receive(anyLong(), anyLong(), anyInt(), anyBoolean())).thenReturn(errorInfoListPair); + + boolean isException = false; + try { + consumer.receive(10); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + isException = false; + try { + consumer.receive(10L, 10); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + ConsumerImpl consumer1 = new ConsumerImpl(0L); + isException = false; + try { + consumer1.receive(10L, 10); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testAck() { + ConsumerImpl consumer = new ConsumerImpl(10L); + PowerMockito.mockStatic(JniConsumer.class); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.CORE, ""); + when(JniConsumer.ack(anyLong(), anyLong())).thenReturn(errorInfo); + boolean isException = false; + try { + consumer.ack(10); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testClose() { + ConsumerImpl consumer = new ConsumerImpl(10L); + PowerMockito.mockStatic(JniConsumer.class); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.CORE, ""); + when(JniConsumer.close(anyLong())).thenReturn(errorInfo); + boolean isException = false; + try { + consumer.close(); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + ConsumerImpl consumer1 = new ConsumerImpl(0L); + isException = false; + try { + consumer1.close(); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testFinalize() { + ConsumerImpl consumer = new ConsumerImpl(10L); + PowerMockito.mockStatic(JniConsumer.class); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.CORE, ""); + when(JniConsumer.freeJNIPtrNative(anyLong())).thenReturn(errorInfo); + boolean isException = false; + try { + consumer.finalize(); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + } +} diff --git a/api/java/function-common/src/test/java/com/yuanrong/stream/TestElement.java b/api/java/function-common/src/test/java/com/yuanrong/stream/TestElement.java new file mode 100644 index 0000000..e23cd92 --- /dev/null +++ b/api/java/function-common/src/test/java/com/yuanrong/stream/TestElement.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteBuffer; + +public class TestElement { + @Test + public void testInitElement() { + Element element = new Element(); + Element element1 = new Element(1L, ByteBuffer.allocate(10)); + element.setId(2L); + element1.setBuffer(ByteBuffer.allocate(20)); + element.getBuffer(); + element.getId(); + element.toString(); + element1.hashCode(); + + Assert.assertTrue(element1.equals(element1)); + Assert.assertFalse(element1.equals(null)); + Assert.assertFalse(element1.equals(element)); + } +} diff --git a/api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerConfig.java b/api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerConfig.java new file mode 100644 index 0000000..5bd37e4 --- /dev/null +++ b/api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerConfig.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import org.junit.Assert; +import org.junit.Test; + +public class TestProducerConfig { + @Test + public void testInitProducerConfig() { + ProducerConfig.Builder builder = ProducerConfig.builder(); + ProducerConfig.Builder builder1 = new ProducerConfig.Builder(); + builder1.delayFlushTimeMs(1L); + builder1.pageSizeByte(2L); + builder1.maxStreamSize(3L); + builder1.autoCleanup(true); + builder1.encryptStream(false); + builder1.reserveSize(4L); + ProducerConfig build = builder1.build(); + Assert.assertEquals(1L, build.getDelayFlushTimeMs()); + } +} diff --git a/api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerImpl.java b/api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerImpl.java new file mode 100644 index 0000000..a222390 --- /dev/null +++ b/api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerImpl.java @@ -0,0 +1,161 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.when; + +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.jni.JniProducer; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.nio.ByteBuffer; + +@RunWith(PowerMockRunner.class) +@PrepareForTest( {JniProducer.class}) +@SuppressStaticInitializationFor( {"com.yuanrong.jni.JniProducer"}) +@PowerMockIgnore("javax.management.*") +public class TestProducerImpl { + @Test + public void testInitProducerImpl() { + ProducerImpl producer = new ProducerImpl(10L); + producer.toString(); + producer.hashCode(); + ProducerImpl producer1 = new ProducerImpl(20L); + Assert.assertTrue(producer1.equals(producer1)); + Assert.assertFalse(producer1.equals(null)); + Assert.assertFalse(producer.equals(producer1)); + } + + @Test + public void testSend() { + ProducerImpl producer = new ProducerImpl(10L); + ByteBuffer byteBuffer = ByteBuffer.allocateDirect(100); + byteBuffer.put(new byte[] {1, 2, 3, 4, 5}); + byteBuffer.flip(); + + PowerMockito.mockStatic(JniProducer.class); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.CORE, ""); + when(JniProducer.sendDirectBufferDefaultTimeout(anyLong(), any())).thenReturn(errorInfo); + when(JniProducer.sendDirectBuffer(anyLong(), any(), anyInt())).thenReturn(errorInfo); + + boolean isException = false; + try { + producer.send(new Element(10L, byteBuffer)); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + isException = false; + try { + producer.send(new Element(10L, byteBuffer), 10); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + ByteBuffer byteBuffer1 = ByteBuffer.allocate(100); + when(JniProducer.sendHeapBufferDefaultTimeout(anyLong(), any(), anyLong())).thenReturn(errorInfo); + when(JniProducer.sendHeapBuffer(anyLong(), any(), anyLong(), anyInt())).thenReturn(errorInfo); + + isException = false; + try { + producer.send(new Element(10L, byteBuffer1)); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + isException = false; + try { + producer.send(new Element(10L, byteBuffer1), 10); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testFlush() { + ProducerImpl producer = new ProducerImpl(10L); + ProducerImpl producer1 = new ProducerImpl(0L); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.CORE, ""); + PowerMockito.mockStatic(JniProducer.class); + when(JniProducer.flush(anyLong())).thenReturn(errorInfo); + boolean isException = false; + try { + producer1.flush(); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + try { + producer.flush(); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testClose() { + ProducerImpl producer = new ProducerImpl(10L); + ProducerImpl producer1 = new ProducerImpl(0L); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.CORE, ""); + PowerMockito.mockStatic(JniProducer.class); + when(JniProducer.close(anyLong())).thenReturn(errorInfo); + boolean isException = false; + + try { + producer1.close(); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + try { + producer.close(); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testFinalize() throws Exception { + ProducerImpl producer = new ProducerImpl(10L); + PowerMockito.mockStatic(JniProducer.class); + PowerMockito.doNothing().when(JniProducer.class, "freeJNIPtrNative",anyLong()); + producer.finalize(); + } +} diff --git a/api/java/function-common/src/test/java/com/yuanrong/stream/TestSubscriptionConfig.java b/api/java/function-common/src/test/java/com/yuanrong/stream/TestSubscriptionConfig.java new file mode 100644 index 0000000..5f1f619 --- /dev/null +++ b/api/java/function-common/src/test/java/com/yuanrong/stream/TestSubscriptionConfig.java @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.stream; + +import org.junit.Assert; +import org.junit.Test; + +public class TestSubscriptionConfig { + @Test + public void testInitSubscriptionConfig() { + SubscriptionConfig.Builder builder = SubscriptionConfig.builder(); + SubscriptionConfig.Builder builder1 = new SubscriptionConfig.Builder(); + builder1.subscriptionName("test"); + builder1.subscriptionType(SubscriptionType.STREAM); + SubscriptionConfig subscriptionConfig = builder1.build(); + Assert.assertEquals("test", subscriptionConfig.getSubscriptionName()); + } +} diff --git a/api/java/yr-api-sdk/resource/sdkpom.xml b/api/java/yr-api-sdk/resource/sdkpom.xml index 8db81d3..eca7ea5 100644 --- a/api/java/yr-api-sdk/resource/sdkpom.xml +++ b/api/java/yr-api-sdk/resource/sdkpom.xml @@ -33,7 +33,7 @@ com.fasterxml.jackson.core jackson-core - 2.16.2 + 2.18.2 org.apache.commons diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/Config.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/Config.java index 684615c..6e22c0f 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/Config.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/Config.java @@ -19,6 +19,7 @@ package com.yuanrong; import com.yuanrong.errorcode.ErrorCode; import com.yuanrong.errorcode.ModuleCode; import com.yuanrong.exception.YRException; +import com.yuanrong.runtime.util.Constants; import com.yuanrong.utils.SdkUtils; import com.yuanrong.jni.YRAutoInitInfo; import com.yuanrong.runtime.util.Utils; @@ -33,8 +34,8 @@ import java.util.Map; /** * The Config class is the initialization data structure of Yuanrong, used to store basic information such as * IP, port, and URN needed when initializing the Yuanrong system. The Config instance is the input parameter of the - * init interface. Except for functionURN, serverAddress, dataSystemAddress, and cppFunctionURN(the top - * four in the table), which are mandatory configurations and supported through the constructor, the rest of the + * init interface. Except for functionURN, serverAddress, dataSystemAddress, cppFunctionURN, and isInCluster (the top + * five in the table), which are mandatory configurations and supported through the constructor, the rest of the * parameters are default and set through setters. For specific interfaces, please refer to the table at the end. * * @since 2023/10/16 @@ -50,6 +51,8 @@ public class Config { = "sn:cn:yrk:12345678901234561234567890123456:function:0-defaultservice-java:$latest"; private static final String DEFAULT_CPP_URN = "sn:cn:yrk:12345678901234561234567890123456:function:0-defaultservice-cpp:$latest"; + private static final String DEFAULT_GO_URN + = "sn:cn:yrk:12345678901234561234567890123456:function:0-defaultservice-go:$latest"; /** * Specify the so path. If not specified, it is specified by services.yaml. @@ -68,6 +71,11 @@ public class Config { */ private String cppFunctionURN = DEFAULT_CPP_URN; + /** + * The functionID returned by the deployment go function. + */ + private String goFunctionURN = DEFAULT_GO_URN; + /** * Cluster IP (Yuanrong cluster master node). */ @@ -104,6 +112,11 @@ public class Config { */ private int threadPoolSize = 0; + /** + * Inside/Outside the cluster, determine whether the runtime is connected to the bus, default is ``false``; + * clients outside the cluster have no retry mechanism. + */ + private boolean isInCluster = true; /** * Is driver or not, determines whether the API is used on the cloud or locally, default is ``true``. @@ -118,7 +131,7 @@ public class Config { /** * Whether to enable metric collection and reporting. The default value is ``false``. */ - private boolean enableMetrics = false; + private boolean enableMetrics = true; /** * Whether to enable two-way authentication for external cloud clients, default is ``off``. @@ -140,6 +153,11 @@ public class Config { */ private String verifyFilePath; + /** + * Client private key encryption password. + */ + private String privateKeyPaaswd; + /** * Server name. */ @@ -196,6 +214,13 @@ public class Config { */ private boolean enableDisConvCallStack = true; + /** + * When the user configures iamAuthToken, the token is included in the header of each request sent to the cluster. + * The remote runtime will obtain the token ciphertext from RuntimeManager and then carry the token ciphertext + * in the runtime communication request. + */ + private String iamAuthToken = ""; + /** * Whether to enable multi-cluster mode. The default is ``false``. If isThreadLocal is true, call YR.init in * different threads and set different cluster addresses. The runtime Java SDK can connect to different clusters. @@ -208,6 +233,21 @@ public class Config { */ private boolean enableSetContext = false; + /** + * Tenant ID, default is ``empty``. + */ + private String tenantId = ""; + + /** + * Number of HTTP connection working threads. Default value: ``200``. + */ + private int httpIocThreadsNum = Constants.DEFAULT_HTTP_IO_THREAD_CNT; + + /** + * HTTP connection idle time, default value ``120``. + */ + private int httpIdleTime = Constants.DEFAULT_HTTP_IDLE_TIME; + /** * Used to set custom environment variables for the runtime. Currently, only LD_LIBRARY_PATH is supported. */ @@ -229,10 +269,12 @@ public class Config { * @param serverAddress Cluster IP (Yuanrong cluster master node). * @param dataSystemAddress Data system IP (Yuanrong cluster master node). * @param cppFunctionURN The functionID returned by the deployment cpp function. + * @param isInCluster Inside/Outside the cluster. */ - public Config(String functionURN, String serverAddress, String dataSystemAddress, String cppFunctionURN) { + public Config(String functionURN, String serverAddress, String dataSystemAddress, String cppFunctionURN, + boolean isInCluster) { this(functionURN, serverAddress, DEFAULT_SERVER_PORT, dataSystemAddress, DEFAULT_DS_PORT, cppFunctionURN, - true); + isInCluster, true); } /** @@ -247,6 +289,24 @@ public class Config { * @param serverAddress Cluster IP (Yuanrong cluster master node). * @param dataSystemAddress Data system IP (Yuanrong cluster master node). * @param cppFunctionURN The functionID returned by the deployment cpp function. + * @param goFunctionURN The functionID returned by the deployment go function. + * @param isInCluster Inside/Outside the cluster. + */ + public Config(String functionURN, String serverAddress, String dataSystemAddress, String cppFunctionURN, + String goFunctionURN, boolean isInCluster) { + this(functionURN, serverAddress, DEFAULT_SERVER_PORT, dataSystemAddress, DEFAULT_DS_PORT, cppFunctionURN, + isInCluster, true); + this.setGoFunctionURN(goFunctionURN); + } + + /** + * The constructor of Config. + * + * @param functionURN The functionURN returned by the deployment function. + * @param serverAddress Cluster IP (Yuanrong cluster master node). + * @param dataSystemAddress Data system IP (Yuanrong cluster master node). + * @param cppFunctionURN The functionID returned by the deployment cpp function. + * @param isInCluster Inside/Outside the cluster. * @param isDriver On cloud or off cloud. */ public Config( @@ -254,9 +314,62 @@ public class Config { String serverAddress, String dataSystemAddress, String cppFunctionURN, - boolean isDriver) { + boolean isInCluster, boolean isDriver) { this(functionURN, serverAddress, DEFAULT_SERVER_PORT, dataSystemAddress, DEFAULT_DS_PORT, cppFunctionURN, - isDriver); + isInCluster, isDriver); + } + + /** + * The constructor of Config. + * + * @param functionURN The functionURN returned by the deployment function. + * @param serverAddress Cluster IP (Yuanrong cluster master node). + * @param dataSystemAddress Data system IP (Yuanrong cluster master node). + * @param cppFunctionURN The functionID returned by the deployment cpp function. + * @param goFunctionURN The functionID returned by the deployment go function. + * @param isInCluster Inside/Outside the cluster. + * @param isDriver On cloud or off cloud. + */ + public Config( + String functionURN, + String serverAddress, + String dataSystemAddress, + String cppFunctionURN, + String goFunctionURN, + boolean isInCluster, boolean isDriver) { + this(functionURN, serverAddress, DEFAULT_SERVER_PORT, dataSystemAddress, DEFAULT_DS_PORT, cppFunctionURN, + isInCluster, isDriver); + this.setGoFunctionURN(goFunctionURN); + } + + /** + * The constructor of Config. + * + * @param functionUrn The functionURN returned by the deployment function. + * @param serverAddr Cluster IP (Yuanrong cluster master node). + * @param serverAddressPort Cluster port number. + * @param dataSystemAddress Data system IP (Yuanrong cluster master node). + * @param dataSystemAddressPort DataSystem port number. + * @param cppFunctionUrn The functionID returned by the deployment cpp function. + * @param isInCluster Inside/Outside the cluster. + */ + public Config( + String functionUrn, + String serverAddr, + int serverAddressPort, + String dataSystemAddr, + int dataSystemAddressPort, + String cppFunctionUrn, + boolean isInCluster) { + this( + functionUrn, + serverAddr, + serverAddressPort, + dataSystemAddr, + dataSystemAddressPort, + cppFunctionUrn, + isInCluster, + true); } /** @@ -265,9 +378,11 @@ public class Config { * @param functionUrn The functionURN returned by the deployment function. * @param serverAddr Cluster IP (Yuanrong cluster master node). * @param serverAddressPort Cluster port number. - * @param dataSystemAddr Data system IP (Yuanrong cluster master node). + * @param dataSystemAddress Data system IP (Yuanrong cluster master node). * @param dataSystemAddressPort DataSystem port number. * @param cppFunctionUrn The functionID returned by the deployment cpp function. + * @param goFunctionUrn The functionID returned by the deployment go function. + * @param isInCluster Inside/Outside the cluster. */ public Config( String functionUrn, @@ -275,7 +390,9 @@ public class Config { int serverAddressPort, String dataSystemAddr, int dataSystemAddressPort, - String cppFunctionUrn) { + String cppFunctionUrn, + String goFunctionUrn, + boolean isInCluster) { this( functionUrn, serverAddr, @@ -283,7 +400,9 @@ public class Config { dataSystemAddr, dataSystemAddressPort, cppFunctionUrn, + isInCluster, true); + this.setGoFunctionURN(goFunctionUrn); } /** @@ -292,9 +411,10 @@ public class Config { * @param functionUrn The functionURN returned by the deployment function. * @param serverAddr Cluster IP (Yuanrong cluster master node). * @param serverAddressPort Cluster port number. - * @param dataSystemAddr Data system IP (Yuanrong cluster master node). + * @param dataSystemAddress Data system IP (Yuanrong cluster master node). * @param dataSystemAddressPort DataSystem port number. * @param cppFunctionUrn The functionID returned by the deployment cpp function. + * @param isInCluster Inside/Outside the cluster. * @param isDriver On cloud or off cloud. */ public Config( @@ -304,11 +424,13 @@ public class Config { String dataSystemAddr, int dataSystemAddressPort, String cppFunctionUrn, + boolean isInCluster, boolean isDriver) { this.functionURN = functionUrn; this.serverAddress = serverAddr; this.dataSystemAddress = dataSystemAddr; this.cppFunctionURN = cppFunctionUrn; + this.isInCluster = isInCluster; this.isDriver = isDriver; this.serverAddressPort = serverAddressPort; this.dataSystemAddressPort = dataSystemAddressPort; @@ -318,8 +440,12 @@ public class Config { this.functionURN = builder.functionURN; this.serverAddress = builder.serverAddress; this.dataSystemAddress = builder.dataSystemAddress; + this.iamAuthToken = builder.iamAuthToken; this.cppFunctionURN = builder.cppFunctionURN; + this.goFunctionURN = builder.goFunctionURN; this.ns = builder.ns; + this.tenantId = builder.tenantId; + this.isInCluster = builder.isInCluster; this.isDriver = builder.isDriver; this.serverAddressPort = builder.serverAddressPort; this.dataSystemAddressPort = builder.dataSystemAddressPort; @@ -381,6 +507,10 @@ public class Config { throw new YRException(ErrorCode.ERR_INCORRECT_INIT_USAGE, ModuleCode.RUNTIME, "cppFunctionURN is invalid"); } + if (!goFunctionURN.isEmpty() && !SdkUtils.checkURN(goFunctionURN)) { + throw new YRException(ErrorCode.ERR_INCORRECT_INIT_USAGE, ModuleCode.RUNTIME, + "goFunctionURN is invalid"); + } if (!SdkUtils.checkIP(serverAddress)) { throw new YRException(ErrorCode.ERR_INCORRECT_INIT_USAGE, ModuleCode.RUNTIME, "serverAddress is invalid"); @@ -415,10 +545,14 @@ public class Config { private String functionURN = DEFAULT_FUNC_URN; private String serverAddress = ""; private String dataSystemAddress = ""; + private String iamAuthToken = ""; private String cppFunctionURN = ""; + private String goFunctionURN = ""; private String ns = ""; private String logDir = DEFAULT_LOG_DIR; private String logLevel = ""; + private String tenantId = ""; + private boolean isInCluster = true; private boolean isDriver = true; private int serverAddressPort = DEFAULT_SERVER_PORT; private int dataSystemAddressPort = DEFAULT_DS_PORT; @@ -461,6 +595,17 @@ public class Config { return this; } + /** + * Sets the iamAuthToken of Config class. + * + * @param iamAuthToken the IAM token String. + * @return Config Builder class object. + */ + public Builder iamAuthToken(String iamAuthToken) { + this.iamAuthToken = iamAuthToken; + return this; + } + /** * Sets the cppFunctionURN string of Config class. * @@ -472,6 +617,39 @@ public class Config { return this; } + /** + * Sets the goFunctionURN string of Config class. + * + * @param goFunctionURN the goFunctionURN string. + * @return Config Builder class object. + */ + public Builder goFunctionURN(String goFunctionURN) { + this.goFunctionURN = goFunctionURN; + return this; + } + + /** + * Sets the tenantId string of Config class. + * + * @param tenantId the tenantId string. + * @return Config Builder class object. + */ + public Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + + /** + * Sets the isInCluster boolean of Config class. + * + * @param isInCluster the isInCluster boolean. Applying driver mode if it is true. + * @return Config Builder class object. + */ + public Builder isInCluster(boolean isInCluster) { + this.isInCluster = isInCluster; + return this; + } + /** * Sets the isDriver boolean of Config class. * @@ -519,7 +697,8 @@ public class Config { } /** - * If enableSetContext is true, tenant context switching is allowed. + * If enableSetContext is true, tenant context switching is allowed, + * including tenantid, cluster and so on. * * @param enableSetContext boolean indicates that whether support tenant context. * @return Config Builder class object. @@ -613,6 +792,7 @@ public class Config { */ public YRAutoInitInfo buildClusterAccessInfo() { YRAutoInitInfo info = new YRAutoInitInfo(); + info.setInCluster(this.isInCluster); info.setFunctionSystemServerIpAddr(this.serverAddress); if (this.serverAddress != null && !this.serverAddress.isEmpty()) { info.setFunctionSystemServerPort(this.serverAddressPort); diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/ConfigManager.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/ConfigManager.java index a9f50f4..f471b6a 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/ConfigManager.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/ConfigManager.java @@ -17,6 +17,7 @@ package com.yuanrong; import com.yuanrong.runtime.config.RuntimeContext; +import com.yuanrong.runtime.util.Constants; import com.yuanrong.utils.SdkUtils; import lombok.Data; @@ -43,14 +44,16 @@ public class ConfigManager { private String jobId = ""; private String runtimeId = ""; private String cppFunctionURN; + private String goFunctionURN; private int recycleTime; - private boolean isInCluster = true; + private boolean isInCluster; private int maxTaskInstanceNum = -1; - private boolean enableMetrics; + private boolean enableMetrics = true; private boolean enableMTLS = false; private String certificateFilePath; private String privateKeyPath; private String verifyFilePath; + private String privateKeyPaaswd; private String serverName; private String driverServerIP; private boolean isDriver = true; @@ -64,6 +67,8 @@ public class ConfigManager { private boolean isLogMerge = false; private int threadPoolSize; private ArrayList loadPaths; + private int httpIocThreadsNum = Constants.DEFAULT_HTTP_IO_THREAD_CNT; + private int httpIdleTime = Constants.DEFAULT_HTTP_IDLE_TIME; private boolean enableDisConvCallStack; private int rpcTimeout; private String tenantId = ""; @@ -101,7 +106,9 @@ public class ConfigManager { this.serverAddress = config.getServerAddress(); this.dataSystemAddress = config.getDataSystemAddress(); this.cppFunctionURN = config.getCppFunctionURN(); + this.goFunctionURN = config.getGoFunctionURN(); this.recycleTime = config.getRecycleTime(); + this.isInCluster = config.isInCluster(); this.maxTaskInstanceNum = config.getMaxTaskInstanceNum(); this.enableMetrics = config.isEnableMetrics(); this.maxConcurrencyCreateNum = config.getMaxConcurrencyCreateNum(); @@ -114,14 +121,17 @@ public class ConfigManager { this.loadPaths = config.getLoadPaths(); this.serverAddressPort = config.getServerAddressPort(); this.dataSystemAddressPort = config.getDataSystemAddressPort(); + this.httpIocThreadsNum = config.getHttpIocThreadsNum(); this.ns = config.getNs(); this.enableDisConvCallStack = config.isEnableDisConvCallStack(); this.rpcTimeout = config.getRpcTimeout(); + this.tenantId = config.getTenantId(); this.customEnvs = config.getCustomEnvs(); this.enableMTLS = config.isEnableMTLS(); this.certificateFilePath = config.getCertificateFilePath(); this.privateKeyPath = config.getPrivateKeyPath(); this.verifyFilePath = config.getVerifyFilePath(); + this.privateKeyPaaswd = config.getPrivateKeyPaaswd(); this.serverName = config.getServerName(); this.isInitialized = true; this.codePath = config.getCodePath(); diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/YRCall.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/YRCall.java index 4cb122b..ce3029f 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/YRCall.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/YRCall.java @@ -19,12 +19,16 @@ package com.yuanrong; import com.yuanrong.call.CppFunctionHandler; import com.yuanrong.call.CppInstanceCreator; import com.yuanrong.call.FunctionHandler; +import com.yuanrong.call.GoFunctionHandler; +import com.yuanrong.call.GoInstanceCreator; import com.yuanrong.call.InstanceCreator; import com.yuanrong.call.JavaFunctionHandler; import com.yuanrong.call.JavaInstanceCreator; import com.yuanrong.call.VoidFunctionHandler; import com.yuanrong.function.CppFunction; import com.yuanrong.function.CppInstanceClass; +import com.yuanrong.function.GoFunction; +import com.yuanrong.function.GoInstanceClass; import com.yuanrong.function.JavaFunction; import com.yuanrong.function.JavaInstanceClass; import com.yuanrong.function.YRFunc0; @@ -592,4 +596,52 @@ public class YRCall extends YRGetInstance { public static CppInstanceCreator instance(CppInstanceClass cppInstanceClass, String name, String nameSpace) { return new CppInstanceCreator(cppInstanceClass, name, nameSpace); } + + /** + * Function go function handler. + * + * @param Return value type. + * @param goFunction go Function name. + * @return GoFunctionHandler Instance. + */ + public static GoFunctionHandler function(GoFunction goFunction) { + return new GoFunctionHandler<>(goFunction); + } + + /** + * Instance go instance creator. + * + * @param goInstanceClass the go function instance class. + * @return GoInstanceCreator Instance. + */ + public static GoInstanceCreator instance(GoInstanceClass goInstanceClass) { + return new GoInstanceCreator(goInstanceClass, "", ConfigManager.getInstance().getNs()); + } + + /** + * Instance go instance creator. + * + * @param goInstanceClass the go instance class. + * @param name The instance name of the named instance, the second parameter. + * When only name exists, the instance name will be set to name. + * @return GoInstanceCreator Instance. + */ + public static GoInstanceCreator instance(GoInstanceClass goInstanceClass, String name) { + return new GoInstanceCreator(goInstanceClass, name, ConfigManager.getInstance().getNs()); + } + + /** + * Instance go instance creator. + * + * @param goInstanceClass the go instance class. + * @param name The instance name of the named instance, the second parameter. + * When only name exists, the instance name will be set to name. + * @param nameSpace Namespace of the named instance. When both name and nameSpace exist, the instance name is + * concatenated into nameSpace-name. + * This field is currently used only for concatenation. + * @return GoInstanceCreator Instance. + */ + public static GoInstanceCreator instance(GoInstanceClass goInstanceClass, String name, String nameSpace) { + return new GoInstanceCreator(goInstanceClass, name, nameSpace); + } } diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/api/JobExecutorCaller.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/api/JobExecutorCaller.java new file mode 100644 index 0000000..9f2609c --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/api/JobExecutorCaller.java @@ -0,0 +1,288 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.api; + +import com.yuanrong.InvokeOptions; +import com.yuanrong.call.InstanceCreator; +import com.yuanrong.call.InstanceHandler; +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.exception.YRException; +import com.yuanrong.exception.LibRuntimeException; +import com.yuanrong.exception.handler.traceback.StackTraceUtils; +import com.yuanrong.function.YRFunc4; +import com.yuanrong.jni.LibRuntime; +import com.yuanrong.jobexecutor.JobExecutor; +import com.yuanrong.jobexecutor.RuntimeEnv; +import com.yuanrong.jobexecutor.YRJobInfo; +import com.yuanrong.jobexecutor.YRJobParam; +import com.yuanrong.jobexecutor.YRJobStatus; +import com.yuanrong.libruntime.generated.Libruntime.ApiType; +import com.yuanrong.runtime.client.ObjectRef; +import com.yuanrong.runtime.config.RuntimeContext; +import com.yuanrong.runtime.util.Constants; +import com.yuanrong.storage.InternalWaitResult; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * The local methods for the invocation for remote runtime JobExecutor actor + * member functions. + * + * @since 2023 /06/06 + */ +public class JobExecutorCaller { + private static final Logger LOGGER = LoggerFactory.getLogger(JobExecutorCaller.class); + + private static final int DEFAULT_JOB_EXECUTOR_INVOKE_TIMEOUT_MS = 30000; + + private static ConcurrentHashMap> jobInfoCaches = + new ConcurrentHashMap<>(); + + private static final List JOB_ERROR_CODES = Arrays.asList(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, + ErrorCode.ERR_JOB_RUNTIME_EXCEPTION, + ErrorCode.ERR_JOB_INNER_SYSTEM_EXCEPTION); + + private static final List USER_CODE_ERROR_CODES = Arrays.asList(ErrorCode.ERR_INCORRECT_INIT_USAGE, + ErrorCode.ERR_INCORRECT_INVOKE_USAGE, + ErrorCode.ERR_PARAM_INVALID); + + private static final List RUNTIME_ERROR_CODES = Arrays.asList(ErrorCode.ERR_INIT_CONNECTION_FAILED, + ErrorCode.ERR_USER_CODE_LOAD, + ErrorCode.ERR_PARSE_INVOKE_RESPONSE_ERROR, ErrorCode.ERR_INSTANCE_ID_EMPTY); + + /** + * Invokes the JobExecutor actor instance. + * + * @param yrJobParam a YRJobParam object. All fields in this object are + * required except runtimeEnv. + * @return String the instanceID of the JobExecutor actor. + * @throws YRException the actor task exception. + */ + public static String submitJob(YRJobParam yrJobParam) throws YRException { + InstanceCreator jobExecutor = new InstanceCreator( + (YRFunc4, String, JobExecutor>) JobExecutor::new); + ArrayList entryPoint = yrJobParam.getLocalEntryPoint(); + InvokeOptions invokeOptions = yrJobParam.extractInvokeOptions(); + String objectID = ""; + try { + InstanceHandler handler = jobExecutor + .options(invokeOptions) + .invoke(yrJobParam.getJobName(), + yrJobParam.getRuntimeEnv(), + entryPoint, + invokeOptions.affinityMsgToJsonStr()); + objectID = handler.getInstanceId(); + } catch (YRException e) { + throw adaptException(e); + } + + String userJobID; + try { + userJobID = LibRuntime.GetRealInstanceId(objectID); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } + InternalWaitResult waitResult; + try { + waitResult = LibRuntime.Wait(Collections.singletonList(objectID), 1, Constants.NO_TIMEOUT); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } + if (waitResult == null) { + throw new YRException(ErrorCode.ERR_INNER_SYSTEM_ERROR, ModuleCode.RUNTIME, + "failed to create instance"); + } + if (!waitResult.getExceptionIds().isEmpty()) { + LOGGER.warn("wait objIds, exception ids size is {}", waitResult.getExceptionIds().size()); + Iterator> it = waitResult.getExceptionIds().entrySet().iterator(); + Map.Entry entry = it.next(); + StackTraceUtils.checkErrorAndThrowForInvokeException(entry.getValue(), + "wait objIds(" + entry.getKey() + "...)"); + } + // A record needs to be initialized here. Otherwise, + // no record is found when the 'listjobs()' is invoked. + getJobInfoCache().put(userJobID, new YRJobInfo()); + return userJobID; + } + + /** + * Stops the job and release resources of the related attached-runtime process. + * The JobExecutor actor instance remains alive. + * + * @param userJobID the instanceID of remote JobExecutor actor. + * @throws YRException the actor task exception. + */ + public static void stopJob(String userJobID) throws YRException { + try { + InstanceHandler handler = getInstanceHandler(userJobID); + handler.function(JobExecutor::stop).invoke(); + } catch (YRException e) { + throw adaptException(e); + } + } + + /** + * Gets the YRjobInfo object according to a given userJobID. + * + * @param userJobID the instanceID of remote JobExecutor actor. + * @return YRJobInfo object. + * @throws YRException the actor task exception. + */ + public static YRJobInfo getYrJobInfo(String userJobID) throws YRException { + return updateYRJobInfo(userJobID); + } + + /** + * Gets the current status of a specified job. The Status can be + * one of RUNNING/SUCCEEDED/STOPPED or FAILED. + * + * @param userJobID the instanceID of remote JobExecutor actor. + * @return YRJobStatus object indicates the current status of the job. + * @throws YRException the actor task exception. + */ + public static YRJobStatus getJobStatus(String userJobID) throws YRException { + YRJobInfo jobInfo = updateYRJobInfo(userJobID); + return jobInfo.getStatus(); + } + + /** + * Obtains Specified jobs information in the current SDK domain. + * + * @param userJobIDList String[] or Strings of userJobIDs. + * @return the Map map contains YRjobInfos related to given + * userJobIDs. + * @throws YRException the actor task exception. + */ + public static Map listJobs(String... userJobIDList) throws YRException { + Map jobsMap = new HashMap<>(); + for (String userJobID : userJobIDList) { + jobsMap.put(userJobID, updateYRJobInfo(userJobID)); + } + return jobsMap; + } + + /** + * Obtains all jobs information in the SDK domain. + * + * @return the Map map contains YRjobInfo objects. + * @throws YRException Failed to update jobs information. + */ + public static Map listJobs() throws YRException { + List keys = new ArrayList(getJobInfoCache().keySet()); + return listJobs(keys.toArray(new String[0])); + } + + /** + * Deletes all cached information related to user's job and + * terminates the corresponding instance. + * + * @param userJobID the instanceID of remote JobExecutor actor. + * @throws YRException the actor task exception. + */ + public static void deleteJob(String userJobID) throws YRException { + if (userJobID == null || userJobID.isEmpty()) { + throw new YRException(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, ModuleCode.RUNTIME_INVOKE, + "Instance ID is empty"); + } + YR.getRuntime().terminateInstance(userJobID); + getJobInfoCache().remove(userJobID); + } + + /** + * Note that this method may return a null value, which needs to be verified + * before using its. + * + * @param userJobID the job instanceID returned by submitJob. + * @return null or YRJobInfo object which contains jobInfo. + * @throws YRException the actor task exception. + */ + private static YRJobInfo updateYRJobInfo(String userJobID) throws YRException { + YRJobInfo jobInfo = getJobInfoCache().get(userJobID); + if (jobInfo != null && jobInfo.ifFinalized()) { + return new YRJobInfo(jobInfo); + } + + boolean isWithStatic = false; + if (jobInfo == null || jobInfo.getJobName() == null) { + isWithStatic = true; + jobInfo = new YRJobInfo(); + } else { + jobInfo = new YRJobInfo(jobInfo); + } + + Object yrObj; + try { + InstanceHandler handler = getInstanceHandler(userJobID); + ObjectRef objectRef = handler.function(JobExecutor::getJobInfo).invoke(isWithStatic); + yrObj = YR.getRuntime().get(objectRef, DEFAULT_JOB_EXECUTOR_INVOKE_TIMEOUT_MS); + } catch (YRException e) { + LOGGER.error("(JobExecutor) Failed to invoke remote job info to update the job (userJobID: {}). ", + userJobID); + throw adaptException(e); + } + + if (yrObj instanceof YRJobInfo) { + jobInfo.update((YRJobInfo) yrObj); + } else { + throw new YRException(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, + "yrObj is not an instance of YRJobInfo"); + } + getJobInfoCache().put(userJobID, jobInfo); + return new YRJobInfo(jobInfo); + } + + private static InstanceHandler getInstanceHandler(String userJobID) { + return new InstanceHandler(userJobID, ApiType.Function); + } + + private static YRException adaptException(YRException exception) { + if (JOB_ERROR_CODES.contains(exception.getErrorCode())) { + return exception; + } + + if (USER_CODE_ERROR_CODES.contains(exception.getErrorCode())) { + return new YRException(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, exception.getModuleCode(), + exception.getErrorMessage()); + } + + if (RUNTIME_ERROR_CODES.contains(exception.getErrorCode())) { + return new YRException(ErrorCode.ERR_JOB_RUNTIME_EXCEPTION, exception.getModuleCode(), + exception.getErrorMessage()); + } + + return new YRException(ErrorCode.ERR_JOB_INNER_SYSTEM_EXCEPTION, exception.getModuleCode(), + exception.getErrorMessage()); + } + + private static Map getJobInfoCache() { + String runtimeCtx = RuntimeContext.RUNTIME_CONTEXT.get(); + jobInfoCaches.putIfAbsent(runtimeCtx, new ConcurrentHashMap()); + return jobInfoCaches.get(runtimeCtx); + } +} \ No newline at end of file diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/api/YR.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/api/YR.java index 0fe3bdf..11a6b0e 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/api/YR.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/api/YR.java @@ -26,6 +26,9 @@ import com.yuanrong.exception.YRException; import com.yuanrong.jni.LibRuntime; import com.yuanrong.jni.LibRuntimeConfig; import com.yuanrong.jni.YRAutoInitInfo; +import com.yuanrong.jobexecutor.YRJobInfo; +import com.yuanrong.jobexecutor.YRJobParam; +import com.yuanrong.jobexecutor.YRJobStatus; import com.yuanrong.runtime.ClusterModeRuntime; import com.yuanrong.runtime.Runtime; import com.yuanrong.runtime.client.KVManager; @@ -34,6 +37,10 @@ import com.yuanrong.runtime.config.RuntimeContext; import com.yuanrong.runtime.util.Constants; import com.yuanrong.storage.InternalWaitResult; import com.yuanrong.storage.WaitResult; +import com.yuanrong.stream.Consumer; +import com.yuanrong.stream.Producer; +import com.yuanrong.stream.ProducerConfig; +import com.yuanrong.stream.SubscriptionConfig; import com.yuanrong.utils.SdkUtils; import org.slf4j.Logger; @@ -97,6 +104,9 @@ public class YR extends YRCall { * The Yuanrong initialization interface is used to configure parameters. * For parameter descriptions, see the data structure Config. * + * @note When the Yuanrong cluster is enabled for multiple tenants, users must configure the tenant ID. + * For details about configuring the tenant ID, see the description of the tenant ID in [Config]. + * * @param conf Initialization parameter configuration of Yuanrong. * @return Information returned to the user after init ends. ClientInfo class description: * jobID, String type, the jobID of init, used for subsequent tracking and management of the current job @@ -381,6 +391,89 @@ public class YR extends YRCall { return YR.getRuntime().getKVManager(); } + /** + * Submits a user job. The job would be executed in remote runtime. + * + * @param yrJobParam a YRJobParam object. All fields in this object are + * required, except runtimeEnv which is optional. + * @return a String representing the unique jobID of a submitted job. + * @throws YRException the actor task exception. + */ + public static String submitJob(YRJobParam yrJobParam) throws YRException { + return JobExecutorCaller.submitJob(yrJobParam); + } + + /** + * Stops a specified job and release resources of the related attached-runtime + * process. + * + * @param userJobID a String representing the job instanceID returned by + * submitJob. + * @throws YRException the actor task exception. + */ + public static void stopJob(String userJobID) throws YRException { + JobExecutorCaller.stopJob(userJobID); + } + + /** + * Gets the current status of a specified job. The Status can be + * RUNNING/SUCCEEDED/STOPPED or FAILED. + * + * @param userJobID a String representing the job instanceID returned by + * submitJob. + * @return YRJobStatus object, which contains an enum type. + * @throws YRException the actor task exception. + */ + public static YRJobStatus getJobStatus(String userJobID) throws YRException { + return JobExecutorCaller.getJobStatus(userJobID); + } + + /** + * Gets the current YRJobInfo of a specified job. + * + * @param userJobID a String representing the job instanceID returned by + * submitJob. + * @return YRJobInfo object, which contains YRJobInfo. + * @throws YRException the actor task exception. + */ + public static YRJobInfo getJobInfo(String userJobID) throws YRException { + return JobExecutorCaller.getYrJobInfo(userJobID); + } + + /** + * Obtains all jobs information in the current SDK domain. + * Jobs information is updated and synchronized with the reomte runtime. + * + * @return Map map object. + * @throws YRException the actor task exception. + */ + public static Map listJobs() throws YRException { + return JobExecutorCaller.listJobs(); + } + + /** + * Obtains Specified jobs information given userJobIDs in the current SDK + * domain. + * Jobs information is updated and synchronized with the reomte runtime. + * + * @param userJobIDlist String[] or Strings of userJobIDs. + * @return Map map + * @throws YRException the actor task exception. + */ + public static Map listJobs(String... userJobIDlist) throws YRException { + return JobExecutorCaller.listJobs(userJobIDlist); + } + + /** + * Delete a specific user job given a userJobID. + * + * @param userJobID the job instanceID returned by submitJob. + * @throws YRException the actor task exception. + */ + public static void deleteJob(String userJobID) throws YRException { + JobExecutorCaller.deleteJob(userJobID); + } + /** * Save the state of the runtime with a timeout. * @@ -419,6 +512,75 @@ public class YR extends YRCall { YR.getRuntime().loadState(DEFAULT_SAVE_LOAD_STATE_TIMEOUT); } + /** + * Create a producer. + * + * @param streamName The name of the stream. The length must be less than 256 characters and contain only the + * following characters ``(a-zA-Z0-9\\~\\.\\-\\/_!@#%\\^\\&\\*\\(\\)\\+\\=\\:;)``. + * @param producerConf Producer configuration information. + * @return Producer: Producer Interface. + * @throws YRException Unified exception types thrown. + */ + public static Producer createProducer(String streamName, + ProducerConfig producerConf) throws YRException { + return YR.getRuntime().createStreamProducer(streamName, producerConf); + } + + /** + * Create a producer. + * + * @param streamName The name of the stream. The length must be less than 256 characters and contain only the + * following characters ``(a-zA-Z0-9\\~\\.\\-\\/_!@#%\\^\\&\\*\\(\\)\\+\\=\\:;)``. + * @return Producer: Producer Interface. + * @throws YRException Unified exception types thrown. + */ + public static Producer createProducer(String streamName) throws YRException { + return YR.getRuntime().createStreamProducer(streamName, new ProducerConfig()); + } + + /** + * Create a consumer. + * + * @param streamName The name of the stream. The length must be less than 256 characters and contain only the + * following characters ``(a-zA-Z0-9\\~\\.\\-\\/_!@#%\\^\\&\\*\\(\\)\\+\\=\\:;)``. + * @param config Consumer configuration information. + * @return Consumer: Consumer interface. + * @throws YRException Unified exception types thrown. + */ + public static Consumer subscribe(String streamName, + SubscriptionConfig config) throws YRException { + return YR.getRuntime().createStreamConsumer(streamName, config, false); + } + + /** + * Create a consumer. + * + * @param streamName The name of the stream. The length must be less than 256 characters and contain only the + * following characters ``(a-zA-Z0-9\\~\\.\\-\\/_!@#%\\^\\&\\*\\(\\)\\+\\=\\:;)``. + * @param config Consumer configuration information. + * @param autoAck When autoAck = true, the consumer automatically sends an Ack to the data system for the previous + * message when it receives a message. + * @return Consumer: Consumer interface. + * @throws YRException Unified exception types thrown. + */ + public static Consumer subscribe(String streamName, SubscriptionConfig config, + boolean autoAck) throws YRException { + return YR.getRuntime().createStreamConsumer(streamName, config, autoAck); + } + + /** + * Delete data stream. When the number of global producers and consumers is 0, this data stream is no longer used, + * and the metadata related to this data stream on each worker and master is cleaned up. This function can be called + * on any Host node. + * + * @param streamName The name of the stream. The length must be less than 256 characters and contain only the + * following characters ``(a-zA-Z0-9\\~\\.\\-\\/_!@#%\\^\\&\\*\\(\\)\\+\\=\\:;)``. + * @throws YRException Unified exception types thrown. + */ + public static void deleteStream(String streamName) throws YRException { + YR.getRuntime().deleteStream(streamName); + } + private static ClientInfo initInternal(String runtimeCtx, Config conf) throws YRException { RuntimeContext.RUNTIME_CONTEXT.set(runtimeCtx); if (runtimeCache.get(runtimeCtx) != null) { @@ -433,6 +595,7 @@ public class YR extends YRCall { conf.setServerAddressPort(autoinfo.getFunctionSystemServerPort()); conf.setDataSystemAddress(autoinfo.getDataSystemIpAddr()); conf.setDataSystemAddressPort(autoinfo.getDataSystemPort()); + conf.setInCluster(autoinfo.isInCluster()); } } @@ -468,4 +631,14 @@ public class YR extends YRCall { LOGGER.debug("Succeeded to init YR, jobID is {}, tenant context is {}", jobID, ctx); return info; } + + /** + * Get node information in the cluster. + * + * @return List: node information. + * @throws YRException the actor task exception. + */ + public static List nodes() throws YRException { + return YR.getRuntime().nodes(); + } } diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppFunctionHandler.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppFunctionHandler.java index de6ab88..774cf52 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppFunctionHandler.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppFunctionHandler.java @@ -86,7 +86,9 @@ public class CppFunctionHandler { /** * When Java calls a stateless function in C++, set the functionUrn for the function. * - * @param urn functionUrn, can be obtained after the function is deployed. + * @param urn functionUrn, can be obtained after the function is deployed. The tenant ID in the function urn must be + * consistent with the tenant ID configured in the config. For information about tenant ID configuration, + * see "About tenant ID" in Config. * @return CppFunctionHandler, with built-in invoke method, can create and call the cpp function instance. */ public CppFunctionHandler setUrn(String urn) { diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppInstanceCreator.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppInstanceCreator.java index 4fdb295..ef96827 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppInstanceCreator.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppInstanceCreator.java @@ -106,14 +106,15 @@ public class CppInstanceCreator { CppInstanceHandler cppInstanceHandler = new CppInstanceHandler(instanceId, functionID, this.cppInstanceClass.className); cppInstanceHandler.setNeedOrder(options.isNeedOrder()); - runtime.collectInstanceHandlerInfo(cppInstanceHandler); return cppInstanceHandler; } /** * When Java calls a stateful function in C++, set the functionUrn for the function. * - * @param urn functionUrn, can be obtained after the function is deployed. + * @param urn functionUrn, can be obtained after the function is deployed. The tenant ID in the function urn must be + * consistent with the tenant ID configured in the config. For information about tenant ID configuration, + * see "About tenant ID" in Config. * @return CppInstanceCreator, with built-in invoke method, can create instances of this cpp function class. * * @snippet{trimleft} SetUrnExample.java set urn of java invoke cpp stateful function diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppInstanceHandler.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppInstanceHandler.java index 405b23d..ed4d766 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppInstanceHandler.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/CppInstanceHandler.java @@ -40,6 +40,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -245,6 +246,15 @@ public class CppInstanceHandler { className = ""; } + /** + * Release CppInstanceHandler, decrease reference. + * + * @throws YRException Unified exception types thrown. + */ + public void release() throws YRException { + YR.getRuntime().decreaseReference(Collections.singletonList(instanceId)); + } + /** * Obtain instance handle information. CppInstanceHandler class member method. * Users can obtain handle information through this method, which can be serialized and stored in a database or diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoFunctionHandler.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoFunctionHandler.java new file mode 100644 index 0000000..3238326 --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoFunctionHandler.java @@ -0,0 +1,89 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.call; + +import com.yuanrong.InvokeOptions; +import com.yuanrong.api.InvokeArg; +import com.yuanrong.api.YR; +import com.yuanrong.exception.YRException; +import com.yuanrong.function.GoFunction; +import com.yuanrong.libruntime.generated.Libruntime.ApiType; +import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; +import com.yuanrong.libruntime.generated.Libruntime.LanguageType; +import com.yuanrong.runtime.client.ObjectRef; +import com.yuanrong.utils.SdkUtils; + +import java.util.List; + +/** + * Class for invoking stateless functions in Go. + * + * @note The GoFunctionHandler class is the handle for a Go function created on the cloud by a Java function; + * it is the return type of the interface `YR.function(GoFunction func)`.\n + * Users can use GoFunctionHandler to create and invoke Go function instances. + * + * @since 2024/03/12 + */ +public class GoFunctionHandler { + private final GoFunction func; + + private InvokeOptions options = new InvokeOptions(); + + /** + * The constructor of GoFunctionHandler. + * + * @param func GoFunction class instance. + */ + public GoFunctionHandler(GoFunction func) { + this.func = func; + } + + /** + * Member method of the GoFunctionHandler class, used to call a Go function. + * + * @param args The input parameters required to call the specified method. + * @return ObjectRef: The "id" of the method's return value in the data system. Use YR.get() to get the actual + * return value of the method. + * @throws YRException Unified exception types thrown. + */ + public ObjectRef invoke(Object... args) throws YRException { + FunctionMeta functionMeta = FunctionMeta.newBuilder() + .setClassName("") + .setFunctionName(func.functionName) + .setSignature("") + .setLanguage(LanguageType.Golang) + .setApiType(ApiType.Function) + .build(); + List packedArgs = SdkUtils.packInvokeArgs(args); + String objId = YR.getRuntime().invokeByName(functionMeta, packedArgs, options); + return new ObjectRef(objId, func.returnType); + } + + /** + * The member method of the GoFunctionHandler class is used to dynamically modify the parameters of the called + * go function. + * + * @param opt Function call options, used to specify functions such as calling resources. + * @return GoFunctionHandler Class handle. + * + * @snippet{trimleft} GoInstanceExample.java GoFunctionHandle options 样例代码 + */ + public GoFunctionHandler options(InvokeOptions opt) { + this.options = new InvokeOptions(opt); + return this; + } +} diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceCreator.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceCreator.java new file mode 100644 index 0000000..58b59bf --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceCreator.java @@ -0,0 +1,121 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.call; + +import com.yuanrong.InvokeOptions; +import com.yuanrong.api.YR; +import com.yuanrong.exception.YRException; +import com.yuanrong.function.GoInstanceClass; +import com.yuanrong.libruntime.generated.Libruntime.ApiType; +import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; +import com.yuanrong.libruntime.generated.Libruntime.LanguageType; +import com.yuanrong.runtime.Runtime; +import com.yuanrong.utils.SdkUtils; + +import java.io.IOException; + +/** + * Create an operation class for creating a Go stateful function instance. + * + * @note The GoInstanceCreator class is a Java function that creates Go class instances; + * it is the return type of the interface `YR.instance(GoInstanceClass goInstanceClass)`.\n + * Users can use the invoke method of GoInstanceCreator to create a Go class instance and return the + * goInstanceHandler class handle. + * + * @since 2023/03/12 + */ +public class GoInstanceCreator { + private final GoInstanceClass goInstanceClass; + + private FunctionMeta functionMeta; + + private InvokeOptions options = new InvokeOptions(); + + private String name; + + private String nameSpace; + + /** + * The constructor of GoInstanceCreator. + * + * @param goInstanceClass GoInstanceClass class instance. + */ + public GoInstanceCreator(GoInstanceClass goInstanceClass) { + this(goInstanceClass, "", ""); + } + + /** + * The constructor of GoInstanceCreator. + * + * @param goInstanceClass GoInstanceClass class instance. + * @param name The instance name of a named instance. When only name exists, the instance name will be set to name. + * @param nameSpace Namespace of the named instance. When both name and nameSpace exist, the instance name will be + * concatenated into nameSpace-name. This field is currently only used for concatenation, and + * namespace isolation and other related functions will be completed later. + */ + public GoInstanceCreator(GoInstanceClass goInstanceClass, String name, String nameSpace) { + this.goInstanceClass = goInstanceClass; + this.name = name; + this.nameSpace = nameSpace; + } + + /** + * The member method of the GoInstanceCreator class is used to create a Go class instance. + * + * @param args The input parameters required to create an instance of the class. + * @return GoInstanceHandler class handle. + * @throws YRException Unified exception types thrown. + */ + public GoInstanceHandler invoke(Object... args) throws YRException { + String funcName = goInstanceClass.className; + this.functionMeta = FunctionMeta.newBuilder() + .setClassName("") + .setFunctionName(funcName) + .setSignature("") + .setLanguage(LanguageType.Golang) + .setApiType(ApiType.Function) + .setName(name) + .setNs(nameSpace) + .build(); + Runtime runtime = YR.getRuntime(); + String instanceId = runtime.createInstance(this.functionMeta, SdkUtils.packInvokeArgs(args), options); + return new GoInstanceHandler(instanceId, this.goInstanceClass.className); + } + + /** + * The member method of the GoInstanceCreator class is used to dynamically modify the parameters for creating a Go + * function instance. + * + * @param opt Function call options, used to specify functions such as calling resources. + * @return GoInstanceCreator Class handle. + * + * @snippet{trimleft} GoInstanceExample.java GoInstanceCreator options 样例代码 + */ + public GoInstanceCreator options(InvokeOptions opt) { + this.options = new InvokeOptions(opt); + return this; + } + + /** + * The member method of the GoInstanceCreator class is used to obtain function metadata. + * + * @return FunctionMeta class instance: function metadata. + */ + public FunctionMeta getFunctionMeta() { + return this.functionMeta; + } +} diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceFunctionHandler.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceFunctionHandler.java new file mode 100644 index 0000000..7af9df5 --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceFunctionHandler.java @@ -0,0 +1,104 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.call; + +import com.yuanrong.InvokeOptions; +import com.yuanrong.api.InvokeArg; +import com.yuanrong.api.YR; +import com.yuanrong.exception.YRException; +import com.yuanrong.function.GoInstanceMethod; +import com.yuanrong.libruntime.generated.Libruntime.ApiType; +import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; +import com.yuanrong.libruntime.generated.Libruntime.LanguageType; +import com.yuanrong.runtime.client.ObjectRef; +import com.yuanrong.utils.SdkUtils; + +import lombok.Getter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * The operation class that calls the go stateful function instance member function. + * + * @note The GoInstanceFunctionHandler class is the handle of the member function of the Go class instance after the + * Java function creates the Go class instance; it is the return value type of the interface + * `GoInstanceHandler.function`.\n + * Users can use the invoke method of GoInstanceFunctionHandler to call member functions of Go class instances. + * + * @since 2024/03/12 + */ +@Getter +public class GoInstanceFunctionHandler { + private GoInstanceMethod goInstanceMethod; + + private InvokeOptions options = new InvokeOptions(); + + private String instanceId; + + private String className; + + /** + * The constructor of GoInstanceFunctionHandler. + * + * @param instanceId Go function instance id. + * @param className Go function class name. + * @param goInstanceMethod GoInstanceMethod class instance. + */ + GoInstanceFunctionHandler(String instanceId, String className, GoInstanceMethod goInstanceMethod) { + this.className = className; + this.goInstanceMethod = goInstanceMethod; + this.instanceId = instanceId; + } + + /** + * The member method of the GoInstanceFunctionHandler class is used to call the member function of a Go class + * instance. + * + * @param args The input parameters required to call the specified method. + * @return ObjectRef: The "id" of the method's return value in the data system. Use YR.get() to get the actual + * return value of the method. + * @throws YRException Unified exception types thrown. + */ + public List invoke(Object... args) throws YRException { + FunctionMeta functionMeta = FunctionMeta.newBuilder() + .setClassName("") + .setFunctionName(this.goInstanceMethod.methodName) + .setSignature("") + .setLanguage(LanguageType.Golang) + .setApiType(ApiType.Function) + .build(); + List packedArgs = SdkUtils.packInvokeArgs(args); + String objId = YR.getRuntime().invokeInstance(functionMeta, this.instanceId, packedArgs, options); + return new ArrayList<>(Arrays.asList(new ObjectRef(objId, goInstanceMethod.returnType))); + } + + /** + * The member method of the GoInstanceFunctionHandler class is used to dynamically modify the parameters of the + * called GO function. + * + * @param opt Function call options, used to specify functions such as calling resources. + * @return GoInstanceFunctionHandler Class handle. + * + * @snippet{trimleft} GoInstanceExample.java GoInstanceFunctionHandler options 样例代码 + */ + public GoInstanceFunctionHandler options(InvokeOptions opt) { + this.options = new InvokeOptions(opt); + return this; + } +} diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceHandler.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceHandler.java new file mode 100644 index 0000000..7c54ea1 --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/GoInstanceHandler.java @@ -0,0 +1,207 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.call; + +import com.yuanrong.api.YR; +import com.yuanrong.exception.YRException; +import com.yuanrong.function.GoInstanceMethod; +import com.yuanrong.runtime.Runtime; +import com.yuanrong.runtime.util.Constants; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +import lombok.Getter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.HashMap; + +/** + * A helper class to serializer GoInstanceHandler. + * + * @since 2024.03.12 + */ +class GoInstanceHandlerHelper { + private static final Logger LOGGER = LoggerFactory.getLogger(GoInstanceHandlerHelper.class); + + /** + * The serializer of class GoInstanceHandler. + * + * @since 2024.03.12 + */ + public static class GoInstanceHandlerSerializer extends JsonSerializer { + @Override + public void serialize(GoInstanceHandler goInstanceHandler, JsonGenerator jsonGenerator, + SerializerProvider serializerProvider) { + try { + HashMap handlerMap = new HashMap<>(); + handlerMap.put(Constants.INSTANCE_KEY, ""); + handlerMap.put(Constants.INSTANCE_ID, goInstanceHandler.getInstanceId()); + handlerMap.put(Constants.FUNCTION_KEY, ""); + handlerMap.put(Constants.CLASS_NAME, goInstanceHandler.getClassName()); + handlerMap.put(Constants.MODULE_NAME, ""); + handlerMap.put(Constants.LANGUAGE_KEY, ""); + jsonGenerator.writeObject(handlerMap); + } catch (IOException e) { + LOGGER.error("Error while serialize object", e); + } + } + } + + /** + * The deserializer of class GoInstanceHandler . + * + * @since 2024.03.12 + */ + public static class GoInstanceHandlerDeserializer extends JsonDeserializer { + @Override + public GoInstanceHandler deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) { + try { + JsonToken currentToken = jsonParser.nextToken(); + String instanceID = ""; + String className = ""; + for (; currentToken == JsonToken.FIELD_NAME; currentToken = jsonParser.nextToken()) { + String fieldName = jsonParser.getCurrentName(); + currentToken = jsonParser.nextToken(); + if (Constants.INSTANCE_ID.equals(fieldName)) { + instanceID = jsonParser.getValueAsString(); + LOGGER.debug("set instanceID = {}", instanceID); + } else if (Constants.CLASS_NAME.equals(fieldName)) { + className = jsonParser.getValueAsString(); + LOGGER.debug("set className = {}", className); + } else { + LOGGER.debug("get fieldName = {}", fieldName); + } + } + return new GoInstanceHandler(instanceID, className); + } catch (IOException e) { + LOGGER.error("Error while deserializing object", e); + } + return new GoInstanceHandler(); + } + } +} + +/** + * Create an operation class for creating a Go stateful function instance. + * + * @note The GoInstanceHandler class is the handle returned after a Java function creates a Go class instance; + * it is the return type of the Go class instance created by the `GoInstanceCreator.invoke` interface.\n + * Users can use the function method of GoInstanceHandler to create a Go class instance member method handle and + * return the handle class GoInstanceFunctionHandler. + * + * @since 2024/03/12 + */ +@JsonSerialize(using = GoInstanceHandlerHelper.GoInstanceHandlerSerializer.class) +@JsonDeserialize(using = GoInstanceHandlerHelper.GoInstanceHandlerDeserializer.class) +@Getter +public class GoInstanceHandler { + private String instanceId; + + private String className = ""; + + /** + * The constructor of GoInstanceHandler. + * + * @param instanceId Go function instance id. + * @param className Go function class name. + */ + public GoInstanceHandler(String instanceId, String className) { + this.instanceId = instanceId; + this.className = className; + } + + /** + * Default constructor of GoInstanceHandler. + */ + public GoInstanceHandler() {} + + /** + * The member method of the GoInstanceHandler class is used to return the member function handle of the cloud Go + * class instance. + * + * @param the type of the object. + * @param goInstanceMethod GoInstanceMethod class instance. + * @return GoInstanceFunctionHandler Instance. + * + * @snippet{trimleft} GoInstanceExample.java GoInstanceHandler function example + */ + public GoInstanceFunctionHandler function(GoInstanceMethod goInstanceMethod) { + return new GoInstanceFunctionHandler<>(this.instanceId, className, goInstanceMethod); + } + + /** + * The member method of the GoInstanceHandler class is used to recycle cloud Go function instances. + * + * @note The default timeout for the current kill request is 30 seconds. In scenarios such as high disk load and + * etcd failure, the kill request processing time may exceed 30 seconds, causing the interface to throw a + * timeout exception. Since the kill request has a retry mechanism, users can choose not to handle or retry + * after capturing the timeout exception. + * + * @throws YRException Unified exception types thrown. + * + * @snippet{trimleft} GoInstanceExample.java GoInstanceHandler terminate example + */ + public void terminate() throws YRException { + YR.getRuntime().terminateInstance(instanceId); + } + + /** + * The member method of the GoInstanceHandler class is used to recycle cloud Go function instances. It supports + * synchronous or asynchronous termination. + * + * @note When synchronous termination is not enabled, the default timeout for the current kill request is + * 30 seconds. In scenarios such as high disk load or etcd failure, the kill request processing time may + * exceed 30 seconds, causing the interface to throw a timeout exception. Since the kill request has a retry + * mechanism, users can choose not to handle or retry after capturing the timeout exception. When synchronous + * termination is enabled, this interface will block until the instance is completely exited. + * + * @param isSync Whether to enable synchronization. If true, it indicates sending a kill request with the signal + * quantity of killInstanceSync to the function-proxy, and the kernel synchronously kills the + * instance; if false, it indicates sending a kill request with the signal quantity of killInstance to + * the function-proxy, and the kernel asynchronously kills the instance. + * @throws YRException Unified exception types thrown. + * + * @snippet{trimleft} GoInstanceExample.java GoInstanceHandler terminate sync example + */ + public void terminate(boolean isSync) throws YRException { + Runtime runtime = YR.getRuntime(); + if (isSync) { + runtime.terminateInstanceSync(instanceId); + return; + } + runtime.terminateInstance(instanceId); + } + + /** + * The member method of the GoInstanceHandler class is used to clear the information in the handle. + */ + public void clearHandlerInfo() { + instanceId = ""; + className = ""; + } +} \ No newline at end of file diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceCreator.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceCreator.java index 959a79d..bf2f249 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceCreator.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceCreator.java @@ -70,8 +70,8 @@ public class InstanceCreator extends Handler { * The constructor of InstanceCreator. * * @param func YRFuncR class instance, Java function name. - * @param apiType The enumeration class has two values: Function and Posix. - * It is used internally by Yuanrong to distinguish function types. The default is Function. + * @param apiType The enumeration class has three values: Function, Faas, and Posix. + * It is used internally by Yuanrong to distinguish function types. The default is Actor. */ public InstanceCreator(YRFuncR func, ApiType apiType) { this(func, "", "", apiType); @@ -98,8 +98,8 @@ public class InstanceCreator extends Handler { * @param nameSpace Namespace of the named instance. When both name and nameSpace exist, the instance name will be * concatenated into nameSpace-name. This field is currently only used for concatenation, and * namespace isolation and other related functions will be completed later. - * @param apiType The enumeration class has two values: Function and Posix. - * It is used internally by Yuanrong to distinguish function types. The default is Function. + * @param apiType The enumeration class has three values: Function, Faas, and Posix. + * It is used internally by Yuanrong to distinguish function types. The default is Actor. */ public InstanceCreator(YRFuncR func, String name, String nameSpace, ApiType apiType) { this.func = func; @@ -125,7 +125,6 @@ public class InstanceCreator extends Handler { String instanceId = runtime.createInstance(functionMeta, SdkUtils.packInvokeArgs(args), options); InstanceHandler handler = new InstanceHandler(instanceId, apiType); handler.setNeedOrder(options.isNeedOrder()); - runtime.collectInstanceHandlerInfo(handler); return handler; } diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceFunctionHandler.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceFunctionHandler.java index b990776..9d0f0e4 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceFunctionHandler.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceFunctionHandler.java @@ -62,8 +62,8 @@ public class InstanceFunctionHandler extends Handler { * * @param func YRFuncR Class instance. * @param instanceId Java function instance ID. - * @param apiType The enumeration class has two values: Function and Posix. - * It is used internally by Yuanrong to distinguish function types. The default is Function. + * @param apiType The enumeration class has three values: Function, Faas, and Posix. + * It is used internally by Yuanrong to distinguish function types. The default is Actor. */ public InstanceFunctionHandler(YRFuncR func, String instanceId, ApiType apiType) { this.func = func; diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceHandler.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceHandler.java index 52beb54..ed24b8f 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceHandler.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/InstanceHandler.java @@ -52,6 +52,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -152,8 +153,8 @@ public class InstanceHandler { * The constructor of the InstanceHandler. * * @param instanceId Java function instance ID. - * @param apiType The enumeration class has two values: Function and Posix. - * It is used internally by Yuanrong to distinguish function types. The default is Function. + * @param apiType The enumeration class has three values: Function, Faas, and Posix. + * It is used internally by Yuanrong to distinguish function types. The default is Actor. */ public InstanceHandler(String instanceId, ApiType apiType) { this.instanceId = instanceId; @@ -338,7 +339,7 @@ public class InstanceHandler { /** * set need order. * - * @param needOrder indicates wheather need order. + * @param needOrder indicates whether need order. * */ void setNeedOrder(boolean needOrder) { @@ -397,6 +398,15 @@ public class InstanceHandler { this.apiType = null; } + /** + * Release InstanceHandler, decrease reference. + * + * @throws YRException Unified exception types thrown. + */ + public void release() throws YRException { + YR.getRuntime().decreaseReference(Collections.singletonList(instanceId)); + } + /** * The member method of the InstanceHandler class allows users to obtain handle information, which can be serialized * and stored in a database or other persistence tools. When the tenant context is enabled, handle information can diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaFunctionHandler.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaFunctionHandler.java index e7832eb..d5bfeba 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaFunctionHandler.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaFunctionHandler.java @@ -84,7 +84,9 @@ public class JavaFunctionHandler { /** * When Java calls a stateless function in Java, set the functionUrn for the function. * - * @param urn functionUrn, can be obtained after the function is deployed. + * @param urn functionUrn, can be obtained after the function is deployed. The tenant ID in the function urn must be + * consistent with the tenant ID configured in the config. For information about tenant ID configuration, + * see "About tenant ID" in Config. * @return JavaFunctionHandler, with built-in invoke method, can create and invoke the java function instance. * * @snippet{trimleft} SetUrnExample.java set urn of java invoke java stateless function diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaInstanceCreator.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaInstanceCreator.java index 3b7e366..cbfbbbb 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaInstanceCreator.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaInstanceCreator.java @@ -106,14 +106,15 @@ public class JavaInstanceCreator { JavaInstanceHandler javaInstanceHandler = new JavaInstanceHandler(instanceId, functionID, this.javaInstanceClass.className); javaInstanceHandler.setNeedOrder(options.isNeedOrder()); - runtime.collectInstanceHandlerInfo(javaInstanceHandler); return javaInstanceHandler; } /** * When Java calls a Java stateful function, set the functionUrn for the function. * - * @param urn functionUrn, can be obtained after the function is deployed. + * @param urn functionUrn, can be obtained after the function is deployed. The tenant ID in the function urn must be + * consistent with the tenant ID configured in the config. For information about tenant ID configuration, + * see "About tenant ID" in Config. * @return JavaInstanceCreator, with built-in invoke method, can create instances of this Java function class. * * @snippet{trimleft} SetUrnExample.java set urn of java invoke java stateful function diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaInstanceHandler.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaInstanceHandler.java index 0973b21..568c809 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaInstanceHandler.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/JavaInstanceHandler.java @@ -40,6 +40,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -296,4 +297,13 @@ public class JavaInstanceHandler { void setNeedOrder(boolean needOrder) { this.needOrder = needOrder; } + + /** + * Release JavaInstanceHandler, decrease reference. + * + * @throws YRException Unified exception types thrown. + */ + public void release() throws YRException { + YR.getRuntime().decreaseReference(Collections.singletonList(instanceId)); + } } diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/VoidInstanceFunctionHandler.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/VoidInstanceFunctionHandler.java index c7d0432..b7194fb 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/VoidInstanceFunctionHandler.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/call/VoidInstanceFunctionHandler.java @@ -60,8 +60,8 @@ public class VoidInstanceFunctionHandler extends Handler { * * @param func Java function name, supports 0 ~ 5 parameters, no return value user function. * @param instanceId Java function instance ID. - * @param apiType The enumeration class has two values: Function and Posix. - * It is used internally by Yuanrong to distinguish function types. The default is Function. + * @param apiType The enumeration class has three values: Function, Faas, and Posix. + * It is used internally by Yuanrong to distinguish function types. The default is Actor. */ public VoidInstanceFunctionHandler(YRFuncVoid func, String instanceId, ApiType apiType) { this.func = func; diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoFunction.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoFunction.java new file mode 100644 index 0000000..26bd28f --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoFunction.java @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.function; + +/** + * Helper class to save function name and return type. + * + * @since 2024.03.12 + */ +public class GoFunction { + /** + * The name of this function + */ + public final String functionName; + + /** + * Type of the return value of this function + */ + public final Class returnType; + + /** + * Number of the return value of this function + */ + public final int returnNum; + + private GoFunction(String functionName, Class returnType, int returnNum) { + this.functionName = functionName; + this.returnType = returnType; + this.returnNum = returnNum; + } + + /** + * Create a go function. + * + * @param functionName The name of this function + * @param returnNum Number of the return values of this function + * @return a go function. + */ + public static GoFunction of(String functionName, int returnNum) { + return of(functionName, Object.class, returnNum); + } + + /** + * Create a go function. + * + * @param functionName The name of this function + * @param returnType Class of the return value of this function + * @param returnNum Number of the return values of this function + * @return a go function. + */ + public static GoFunction of(String functionName, Class returnType, int returnNum) { + return new GoFunction<>(functionName, returnType, returnNum); + } +} diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoInstanceClass.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoInstanceClass.java new file mode 100644 index 0000000..3715fd2 --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoInstanceClass.java @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.function; + +/** + * Helper class to new go instance class. + * + * @since 2024.03.12 + */ +public class GoInstanceClass { + /** + * the name of class. + */ + public final String className; + + /** + * Constructor of GoInstanceClass. + * + * @param className the name of class. + */ + GoInstanceClass(String className) { + this.className = className; + } + + /** + * Create a go instance class. + * + * @param className The name of the instance class + * @return a go instance class + */ + public static GoInstanceClass of(String className) { + return new GoInstanceClass(className); + } +} diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoInstanceMethod.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoInstanceMethod.java new file mode 100644 index 0000000..d3484ed --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/function/GoInstanceMethod.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.function; + +/** + * Helper class to invoke member function. + * + * @since 2024.03.12 + */ +public class GoInstanceMethod { + /** + * the name of this instance method. + */ + public final String methodName; + + /** + * Type of the return value of the instance method. + */ + public final Class returnType; + + /** + * Number of the return value of the instance method + */ + private int returnNum = 1; + + /** + * Constructor of GoInstanceMethod. + * + * @param methodName the name of the instance method. + * @param returnType the return type of method function. + * @param returnNum Number of the return values of this function. + */ + GoInstanceMethod(String methodName, Class returnType, int returnNum) { + this.methodName = methodName; + this.returnType = returnType; + this.returnNum = returnNum; + } + + /** + * Create a go instance method. + * + * @param methodName The name of the instance method. + * @param returnNum Number of the return values of this function. + * @return a go instance method. + */ + public static GoInstanceMethod of(String methodName, int returnNum) { + return new GoInstanceMethod<>(methodName, Object.class, returnNum); + } + + /** + * Create a go instance method. + * + * @param methodName The name of this instance method. + * @param returnType Class of the return value of the instance method. + * @param returnNum Number of the return values of this function. + * @return a go instance method. + */ + public static GoInstanceMethod of(String methodName, Class returnType, int returnNum) { + return new GoInstanceMethod<>(methodName, returnType, returnNum); + } +} diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/JobExecutor.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/JobExecutor.java new file mode 100644 index 0000000..2b8e200 --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/JobExecutor.java @@ -0,0 +1,362 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jobexecutor; + +import com.yuanrong.exception.YRException; +import com.yuanrong.runtime.util.Constants; +import com.yuanrong.runtime.util.Utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.lang.reflect.Field; +import java.nio.charset.StandardCharsets; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * The JobExecutor actor runs in remote runtime to manage driver code. + * + * @since 2023 /06/06 + */ +public class JobExecutor { + private static final Logger LOGGER = LoggerFactory.getLogger(JobExecutor.class); + + /** + * The list of programs that are allowed to run on the processBuilder. + */ + private static final List ALLOWED_PROGRAM = Collections.singletonList("python3.9"); + + /** + * The key to yuanrong functionProxy port in remote JobExecutor runtime + * environment parameter map. The value of the port will be pass to the + * attached-runtime process as it' environment parameter. + */ + private static final String DRIVER_SERVER_PORT = "DRIVER_SERVER_PORT"; + + /** + * The key to downloaded driver code path in JobExcutor actor runtime envronment. + */ + private static final String ENV_DELEGATE_DOWNLOAD = "ENV_DELEGATE_DOWNLOAD"; + + /** + * The key to JobExcutor actor instance ID, which is also called 'userJobID' + * in the context of JobExecutor. + */ + private static final String INSTANCE_ID = "INSTANCE_ID"; + private static final String LOG_PREFIX = "(JobExecutor) "; + private static final String DATASYSTEM_ADDR = "DATASYSTEM_ADDR"; + private static final String HOST_IP = "HOST_IP"; + private static final String YR_SPARK_HOME = "YR_SPARK_HOME"; + private static final String YR_SPARK_FUNC_URN = "YR_SPARK_FUNC_URN"; + private static final String POSIX_LISTEN_ADDR = "POSIX_LISTEN_ADDR"; + private static final String JAVA_HOME = "JAVA_HOME"; + private static final String YR_JOB_AFFINITY = "YR_JOB_AFFINITY"; + private static final String DEFAULT_DS_ADDRESS_PORT = "31501"; + private static final String DEFAULT_DRIVER_SERVER_PORT = "22771"; + private static final String DEFAULT_JAVA_HOME = "/opt/function/runtime/java8/rtsp/jre"; + private static final String DEFAULT_SPARK_HOME = "/dcache/layer/func/bucket-jobexecutor-test/spark_test.zip"; + private static final String DEFAULT_YR_SPARK_FUNC_URN = + "sn:cn:yrk:12345678901234561234567890123456:function:0-sparkonyr-core:$latest"; + private static final int CODE_PATH_INDEX = 1; + + /** + * The thread pool for asynchronous log printing and waiting + * for the end of the attached-runtime. + */ + private ThreadPoolExecutor attachedRuntimeThreadPool = new ThreadPoolExecutor(3, 3, 60L, + TimeUnit.SECONDS, new SynchronousQueue<>()); + private int attachedRuntimePid = -1; + private Process attachedRuntime; + private String jobName = ""; + private String userJobID; + private LocalDateTime jobStartTime; + private volatile LocalDateTime jobEndTime; + private volatile String errorMessage; + private RuntimeEnv runtimeEnv; + private volatile YRJobStatus attachedRuntimeStatus = YRJobStatus.RUNNING; + private int exitCode; + private BufferedReader inputReader; + private BufferedReader errReader; + + /** + * The Constructor of JobExecutor. + * + * @param jobName the user-defined job name. + * @param runtimeEnv the python environment to be installed in remote + * runtime. + * @param localEntryPoint the entry point to be executed in remote runtime. it + * will contains the localCodePath if provided by user. + * @param affinityMsgJsonStr the affinityMsgJsonStr + * @throws YRException the actor task exception. + */ + public JobExecutor(String jobName, RuntimeEnv runtimeEnv, ArrayList localEntryPoint, + String affinityMsgJsonStr) throws YRException { + this.runtimeEnv = runtimeEnv; + this.jobName = jobName; + this.userJobID = Utils.getEnvWithDefualtValue(INSTANCE_ID, "", LOG_PREFIX); + LOGGER.debug("(JobExecutor) Job submitted, jobExecutor actor instanceID: {}", userJobID); + + String downloadCodePath = System.getenv(ENV_DELEGATE_DOWNLOAD); + if (downloadCodePath != null && !downloadCodePath.isEmpty()) { + localEntryPoint.set(CODE_PATH_INDEX, + String.join(Constants.BACKSLASH, downloadCodePath, localEntryPoint.get(CODE_PATH_INDEX))); + } + + LOGGER.info("(JobExecutor) Starts attached-runtime process, entrypoint: {}", localEntryPoint); + startAttachedRuntime(generateAttachedRuntimeEnv(affinityMsgJsonStr), localEntryPoint); + asyncReadProcessStream(); + asyncWaitForAttachedRuntime(); + } + + /** + * Gets the jobInfo of the the attached-runtime. + * + * @param isWithStatic whether invoke static infomation from remote JobExecutor + * runtime. + * @return YRJobInfo + */ + public YRJobInfo getJobInfo(boolean isWithStatic) { + YRJobInfo jobInfo = new YRJobInfo(); + if (isWithStatic) { + jobInfo.setUserJobID(this.userJobID); + jobInfo.setJobName(this.jobName); + jobInfo.setRuntimeEnv(this.runtimeEnv); + if (this.jobStartTime != null) { + jobInfo.setJobStartTime(this.jobStartTime.toString()); + } + } + jobInfo.setStatus(this.attachedRuntimeStatus); + jobInfo.setErrorMessage(this.errorMessage); + if (this.jobEndTime != null) { + jobInfo.setJobEndTime(this.jobEndTime.toString()); + } + return jobInfo; + } + + /** + * Terminates the attached-runtime process. + * This method does NOT immediately free the resources occupied by the + * attached-runtime process. + */ + public void stop() { + // If 'stop()' is invoked more than once, + // 'jobEndTime' should not be changed. + // Therefore, an 'if' judgment is required here. + if (this.jobEndTime == null) { + this.jobEndTime = LocalDateTime.now(); + } + writeStatus(YRJobStatus.STOPPED); + if (this.attachedRuntime != null) { + this.attachedRuntime.destroy(); + } + this.clearThreadPool(); + LOGGER.info("(JobExecutor) Stop the attached-runtime process (pid: {}), job name: {}." + + " Current job status: {}, stop time: {}.", this.attachedRuntimePid, this.jobName, + this.attachedRuntimeStatus, this.jobEndTime); + } + + private Map generateAttachedRuntimeEnv(String affinityMsgJsonStr) throws YRException { + String defaultDsAddr = System.getenv(HOST_IP) + ":" + DEFAULT_DS_ADDRESS_PORT; + + Map arEnv = new HashMap(); + arEnv.put(YR_SPARK_HOME, Utils.getEnvWithDefualtValue(YR_SPARK_HOME, DEFAULT_SPARK_HOME, LOG_PREFIX)); + arEnv.put(YR_SPARK_FUNC_URN, + Utils.getEnvWithDefualtValue(YR_SPARK_FUNC_URN, DEFAULT_YR_SPARK_FUNC_URN, LOG_PREFIX)); + arEnv.put(DATASYSTEM_ADDR, Utils.getEnvWithDefualtValue(DATASYSTEM_ADDR, defaultDsAddr, LOG_PREFIX)); + arEnv.put(POSIX_LISTEN_ADDR, Utils.getEnvWithDefualtValue(POSIX_LISTEN_ADDR, "", LOG_PREFIX)); + arEnv.put(DRIVER_SERVER_PORT, + Utils.getEnvWithDefualtValue(DRIVER_SERVER_PORT, DEFAULT_DRIVER_SERVER_PORT, LOG_PREFIX)); + arEnv.put(JAVA_HOME, Utils.getEnvWithDefualtValue(JAVA_HOME, DEFAULT_JAVA_HOME, LOG_PREFIX)); + arEnv.put(Constants.AUTHORIZATION, Utils.getEnvWithDefualtValue(Constants.AUTHORIZATION, "", LOG_PREFIX)); + + String runtimeEnvCommand = this.runtimeEnv.toCommand(); + arEnv.put(Constants.POST_START_EXEC, runtimeEnvCommand); + LOGGER.debug("(JobExecutor) Sets environment key-value pair({}: {}) for job (jobName: {}).", + Constants.POST_START_EXEC, runtimeEnvCommand, this.jobName); + + arEnv.put(YR_JOB_AFFINITY, affinityMsgJsonStr); + LOGGER.debug("(JobExecutor) Sets environment key-value pair({}: {}) for job (jobName: {}).", + YR_JOB_AFFINITY, affinityMsgJsonStr, this.jobName); + + return arEnv; + } + + private void startAttachedRuntime(Map attachedRuntimeEnv, List entryPoint) { + String program = entryPoint.get(0); + if (!ALLOWED_PROGRAM.contains(program)) { + LOGGER.error("(JobExecutor) Program {} is not allowed to be executed for the security reason. " + + "Allowed programs are: {}", program, ALLOWED_PROGRAM); + writeStatus(YRJobStatus.FAILED); + return; + } + ProcessBuilder processBuilder = new ProcessBuilder(entryPoint); + Map env = processBuilder.environment(); + env.putAll(attachedRuntimeEnv); + + try { + this.attachedRuntime = processBuilder.start(); + this.jobStartTime = LocalDateTime.now(); + } catch (IOException e) { + LOGGER.error("(JobExecutor) Failed to start the attached-runtime process for the job: {}. Exception: ", + this.jobName, e); + this.errorMessage = "Failed to start the attached-runtime process, userJobID: " + this.userJobID; + writeStatus(YRJobStatus.FAILED); + return; + } + + this.attachedRuntimePid = getProcessPid(this.attachedRuntime); + LOGGER.info( + "(JobExecutor) Succeeded to start a process as an attached-runtime(pid: {}) to run driver code: {}", + this.attachedRuntimePid, this.jobName); + } + + private void asyncWaitForAttachedRuntime() { + if (this.attachedRuntime == null) { + LOGGER.warn("(JobExecutor) Attached-runtime has not been started."); + return; + } + this.attachedRuntimeThreadPool.execute(() -> { + try { + this.exitCode = this.attachedRuntime.waitFor(); + if (this.jobEndTime == null) { + this.jobEndTime = LocalDateTime.now(); + } + } catch (InterruptedException e) { + if (this.jobEndTime != null) { + // If 'this.jobEndTime' has not null value, + // attached-runtime is stopped by calling 'stop()', + // which is not regarded as an unexpected error. + return; + } + LOGGER.warn("(JobExecutor) Attached-runtime process (pid: {}) exits abnormally, Exception: ", + this.attachedRuntimePid, e); + writeStatus(YRJobStatus.FAILED); + this.clearThreadPool(); + return; + } + LOGGER.info("(JobExecutor) Attached-runtime process(pid: {}) exit. Exit code: {}", + this.attachedRuntimePid, this.exitCode); + if (this.exitCode != 0) { + writeStatus(YRJobStatus.FAILED); + this.clearThreadPool(); + return; + } + writeStatus(YRJobStatus.SUCCEEDED); + /* + * There is a one-to-one correspondence between JobExecutor runtime and + * attached-runtime. Therefore, after attached-runtime is finished, the thread pool belonging + * to JobExecutor runtime can be closed. + */ + this.clearThreadPool(); + }); + } + + private int getProcessPid(Process process) { + int pid = -1; + Field field; + try { + field = process.getClass().getDeclaredField("pid"); + field.setAccessible(true); + pid = (int) field.get(process); + } catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) { + LOGGER.warn("(JobExecutor) Failed to get pid of driver code (jobName: {}) process. Exception: ", + this.jobName, e); + } + return pid; + } + + private void writeStatus(YRJobStatus status) { + switch (status) { + case RUNNING: + this.attachedRuntimeStatus = status; + break; + default: + if (this.attachedRuntimeStatus == null || this.attachedRuntimeStatus == YRJobStatus.RUNNING) { + this.attachedRuntimeStatus = status; + } + } + } + + private void asyncReadProcessStream() { + if (this.attachedRuntime == null) { + LOGGER.warn("(JobExecutor) Attached-runtime has not been started."); + return; + } + + this.errReader = new BufferedReader( + new InputStreamReader(this.attachedRuntime.getErrorStream(), StandardCharsets.UTF_8)); + this.attachedRuntimeThreadPool.execute(() -> { + String line; + this.errorMessage = ""; + try { + while ((line = errReader.readLine()) != null) { + this.errorMessage = this.errorMessage + System.lineSeparator() + line; + } + } catch (IOException e) { + LOGGER.error("(JobExecutor) Failed to read attached-runtime error message, Exception: ", e); + this.errorMessage = "Failed to read attached-runtime error message. Exception: " + e.getMessage(); + } + if (!this.errorMessage.isEmpty()) { + LOGGER.error("(Attached-runtime) {}", this.errorMessage); + } + }); + + this.inputReader = new BufferedReader( + new InputStreamReader(this.attachedRuntime.getInputStream(), StandardCharsets.UTF_8)); + this.attachedRuntimeThreadPool.execute(() -> { + String line; + try { + while ((line = inputReader.readLine()) != null) { + LOGGER.info("(Attached-runtime) {}", line); + } + } catch (IOException e) { + LOGGER.error("(JobExecutor) Failed to read attached-runtime inputStream message. Exception: ", e); + } + }); + } + + private void clearThreadPool() { + try { + if (this.inputReader != null) { + this.inputReader.close(); + } + } catch (IOException e) { + LOGGER.error("(JobExecutor) Failed to close attached-runtime input stream, Exception: ", e); + } + + try { + if (errReader != null) { + errReader.close(); + } + } catch (IOException e) { + LOGGER.error("(JobExecutor) Failed to close attached-runtime error stream, Exception: ", e); + } + + this.attachedRuntimeThreadPool.shutdownNow(); + } +} diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/OBSoptions.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/OBSoptions.java new file mode 100644 index 0000000..a346078 --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/OBSoptions.java @@ -0,0 +1,152 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jobexecutor; + +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.exception.YRException; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; + +import lombok.Getter; + +import java.util.HashMap; +import java.util.Map; + +/** + * The OBS(Object Storage Service) configuration for remote JobExecutor runtime + * to download driver code. + * + * @since 2023 /11/11 + */ +@Getter +public class OBSoptions { + private static final String DELEGATE_DOWNLOAD = "DELEGATE_DOWNLOAD"; + + private static final String BUCKET_ID = "bucketId"; + + private static final String OBJECT_ID = "objectId"; + + private static final String HOSTNAME = "hostName"; + + private static final String SECURITY_TOKEN = "securityToken"; + + private static final String TEMPORARY_ACCESS_KEY = "temporaryAccessKey"; + + private static final String TEMPORARY_SECRET_KEY = "temporarySecretKey"; + + private String endPoint = ""; + + private String bucketID = ""; + + private String objectID = ""; + + private String securityToken = ""; + + private String ak = ""; + + private String sk = ""; + + /** + * The access address of OBS. + * + * @param endPoint the access address String. + */ + public void setEndPoint(String endPoint) { + this.endPoint = endPoint; + } + + /** + * The container ID for storing objects in OBS. + * + * @param bucketID the container ID String. + */ + public void setBucketID(String bucketID) { + this.bucketID = bucketID; + } + + /** + * The ID of OBS object to be downloaded. + * + * @param objectID the object ID String. + */ + public void setObjectID(String objectID) { + this.objectID = objectID; + } + + /** + * The access token issued by the system to an IAM user. It carries information + * such as user identities and permissions. + * + * @param securityToken the security token String. + */ + public void setSecurityToken(String securityToken) { + this.securityToken = securityToken; + } + + /** + * The temporary access key(AK) to OBS. It is a unique identifier associated + * with the secret access key(SK). AK and SK are used together to encrypt and + * sign requests. + * + * @param ak the temporary access key String. + */ + public void setAk(String ak) { + this.ak = ak; + } + + /** + * The temporary secret access key(SK) used together with the temporary access + * key(AK) to encrypt and sign requests to OBS. + * + * @param sk the temporary secret access key String. + */ + public void setSk(String sk) { + this.sk = sk; + } + + /** + * Converts OBS options to a map contains an single item { "DELEGATE_DOWNLOAD": + * jsonString }, which can be put into createOptions. + * + * @return Map map contains String key and jsonString value. + * @throws YRException Failed to convert mapper to json string. + */ + public Map toMap() throws YRException { + ObjectMapper mapper = new ObjectMapper(); + ObjectNode rootNode = mapper.createObjectNode(); + rootNode.put(HOSTNAME, endPoint); + rootNode.put(BUCKET_ID, bucketID); + rootNode.put(OBJECT_ID, objectID); + rootNode.put(SECURITY_TOKEN, securityToken); + rootNode.put(TEMPORARY_ACCESS_KEY, ak); + rootNode.put(TEMPORARY_SECRET_KEY, sk); + + String jsonStr; + try { + jsonStr = mapper.writeValueAsString(rootNode); + } catch (JsonProcessingException e) { + throw new YRException(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, ModuleCode.RUNTIME, + "Failed to convert OBS options to a json string."); + } + Map map = new HashMap<>(); + map.put(DELEGATE_DOWNLOAD, jsonStr); + return map; + } +} \ No newline at end of file diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/RuntimeEnv.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/RuntimeEnv.java new file mode 100644 index 0000000..2622702 --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/RuntimeEnv.java @@ -0,0 +1,161 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jobexecutor; + +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.exception.YRException; +import com.yuanrong.runtime.util.Constants; + +import lombok.Getter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; + +/** + * The environment to be installed in remote runtime, used by the + * attached-runtime and runtimes started by it. + * + * @since 2023 /06/06 + */ +@Getter +public class RuntimeEnv { + private static final Logger LOGGER = LoggerFactory.getLogger(RuntimeEnv.class); + + /** + * The string indicates "pip3.9", typicallly used to construct a pip installation + * command. + */ + private static final String PIP = "pip3.9"; + + /** + * The string indicates "pip3.9 check", typicallly used to construct a pip check + * command after installation. + */ + private static final String PIP_CHECK = "&& " + RuntimeEnv.PIP + " check"; + + private String packageManager = PIP; + + /** + * The packages to be installed in runtime. It should be initialized as an empty + * String array rather than null when calling YR.invoke. + */ + private String[] packages = new String[0]; + + private boolean shouldPipCheck = false; + + private String trustedSource = ""; + + /** + * The default constructor of type RuntimeEnv. + */ + public RuntimeEnv() {} + + /** + * The copy constructor of type RuntimeEnv. + * + * @param runtimeEnv the RuntimeEnv object. + */ + public RuntimeEnv(RuntimeEnv runtimeEnv) { + this.packageManager = runtimeEnv.packageManager; + this.packages = Arrays.copyOf(runtimeEnv.packages, runtimeEnv.packages.length); + this.shouldPipCheck = runtimeEnv.shouldPipCheck; + } + + /** + * The package manager for environment installation, such as "pip3.9". + * + * @param packageManager the package manager String. + */ + public void setPackageManager(String packageManager) { + this.packageManager = packageManager; + } + + /** + * Sets packages to be installed. Example of the value: + *

{"numpy", "pandas"}

+ * + * @param packages the packages Strings. + */ + public void setPackages(String... packages) { + this.packages = packages; + } + + /** + * Gets packages to be installed. Example of the value: + *

{"numpy", "pandas"}

+ * + * @return the packages to be installed in remote runtime. + */ + public String[] getPackages() { + return Arrays.copyOf(this.packages, this.packages.length); + } + + /** + * Sets pipCheck. Pip check would be enable after the pip installation is + * complete if true is set. The default value is false. + * + * @param shouldPipCheck the boolean indicates whether to perform pip check. + */ + public void setShouldPipCheck(boolean shouldPipCheck) { + this.shouldPipCheck = shouldPipCheck; + } + + /** + * The trusted source to be used. NOTE that an empty string is also allowed. + * + * @param trustedSource the trusted source of the format "--trusted-host HOST_NAME -i ADDRESS". + */ + public void setTrustedSource(String trustedSource) { + this.trustedSource = trustedSource; + } + + /** + * Converts environment infomation to a completed packages installation command. + * + * @return the completed packages installation command. + * @throws YRException the actor task exception. + */ + public String toCommand() throws YRException { + if (this.packageManager == null || this.packageManager.isEmpty()) { + LOGGER.warn("(JobExecutor) PackageManager is empty, no environment would be installed in remote runtime."); + return ""; + } + + if (this.packages.length == 0) { + LOGGER.warn("(JobExecutor) Packages array is empty, no environment would be installed in remote runtime."); + return ""; + } + + String command = ""; + if (PIP.equals(this.packageManager)) { + String packs = String.join(Constants.SPACE, this.packages); + command = String.join(Constants.SPACE, this.packageManager, "install", packs, this.trustedSource); + } else { + throw new YRException(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, ModuleCode.RUNTIME, + "Only pip3.9 is supported currently."); + } + + if (this.shouldPipCheck) { + command = String.join(Constants.SPACE, command, RuntimeEnv.PIP_CHECK); + } + + return command; + } +}; \ No newline at end of file diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobInfo.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobInfo.java new file mode 100644 index 0000000..bf6db70 --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobInfo.java @@ -0,0 +1,163 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jobexecutor; + +import lombok.AccessLevel; +import lombok.Setter; + +/** + * The type YRJobInfo + * + * @since 2023 /06/06 + */ +@Setter(value = AccessLevel.PACKAGE) +public class YRJobInfo { + private String jobName; + + private String userJobID; + + private String jobStartTime; + + private String jobEndTime; + + private RuntimeEnv runtimeEnv; + + private YRJobStatus status; + + private String errorMessage; + + /** + * The copy constructor of type YRJobInfo. + * + * @param yrJobInfo the YRJobInfo object. + */ + public YRJobInfo(YRJobInfo yrJobInfo) { + this.userJobID = yrJobInfo.userJobID; + this.jobName = yrJobInfo.jobName; + this.jobStartTime = yrJobInfo.jobStartTime; + this.jobEndTime = yrJobInfo.jobEndTime; + if (yrJobInfo.runtimeEnv != null) { + this.runtimeEnv = new RuntimeEnv(yrJobInfo.runtimeEnv); + } + this.status = yrJobInfo.status; + this.errorMessage = yrJobInfo.errorMessage; + } + + /** + * The constructor of type YRJobInfo. + */ + public YRJobInfo() {} + + /** + * Gets the user-defined job name set in YRJobParam. + * + * @return the user-defined job name String. + */ + public String getJobName() { + return jobName; + } + + /** + * Gets the job instanceID return by YRJobExecutor.sumitJob(). + * + * @return unique userJobID String. + */ + public String getUserJobID() { + return userJobID; + } + + /** + * Gets the start time of remote attaced-runtime process of formatt + * "yyyy-MM-dd'T'HH:mm:ss.SSSSS" + * + * @return the start time String. + */ + public String getJobStartTime() { + return jobStartTime; + } + + /** + * Gets the end time of remote attaced-runtime process of formatt + * "yyyy-MM-dd'T'HH:mm:ss.SSSSS". + * The attaced-runtime process may be stopped by the user or finished the job + * successfully/unsuccessfully. + * + * @return the end time String. + */ + public String getJobEndTime() { + return jobEndTime; + } + + /** + * Gets the environment to be installed for the execution of user jobs. + * + * @return RuntimeEnv object. + */ + public RuntimeEnv getRuntimeEnv() { + return runtimeEnv; + } + + /** + * Gets the current job's status. + * + * @return YRJobStatus enum type. + */ + public YRJobStatus getStatus() { + return status; + } + + /** + * Get the error message raised by the attached-runtime. + * The error message defaults to an empty string "". + * + * @return the error message String. + */ + public String getErrorMessage() { + return errorMessage; + } + + /** + * Whether the job is in final states including STOPPED/SUCCEEDED/FAILED. + * + * @return a boolean. It is true when the job is in final state. + */ + public boolean ifFinalized() { + return this.status != YRJobStatus.RUNNING && this.status != null; + } + + /** + * Updates members' values from another YRJobInfo object's NOT null values. + * + * @param source another YRJobInfo object. + */ + public void update(YRJobInfo source) { + this.userJobID = updateNotNull(this.userJobID, source.userJobID); + this.jobName = updateNotNull(this.jobName, source.jobName); + this.jobStartTime = updateNotNull(this.jobStartTime, source.jobStartTime); + this.jobEndTime = updateNotNull(this.jobEndTime, source.jobEndTime); + this.runtimeEnv = updateNotNull(this.runtimeEnv, source.runtimeEnv); + this.status = updateNotNull(this.status, source.status); + this.errorMessage = updateNotNull(this.errorMessage, source.errorMessage); + } + + private T updateNotNull(T target, T source) { + if (source != null) { + return source; + } + return target; + } +} diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobParam.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobParam.java new file mode 100644 index 0000000..42a15b2 --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobParam.java @@ -0,0 +1,471 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jobexecutor; + +import com.yuanrong.InvokeOptions; +import com.yuanrong.affinity.Affinity; +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.exception.YRException; +import com.yuanrong.runtime.util.Constants; + +import com.google.gson.Gson; + +import lombok.Getter; +import lombok.Setter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * The type YRJobParam used as the parameter for YR.submitJob + * + * @since 2023 /06/06 + */ +@Getter +@Setter +public class YRJobParam { + private static final Logger LOGGER = LoggerFactory.getLogger(YRJobParam.class); + + /** + * The index of code path in user defined entryPoint. + */ + private static final int CODE_PATH_INDEX = 1; + + /** + * The Minimum length of user defined entryPoint command after seperated by + * space. + */ + private static final int ENTRYPOINT_MIN_LENGTH = 2; + + /** + * The default CPU value for the JobExecutor runtime. + */ + private static final int DEFUALT_CPU_NUM = 500; + + /** + * The default memory value for the JobExecutor runtime. + */ + private static final int DEFUALT_MEM_NUM = 500; + + /** + * The maximum value of memory for the JobExecutor runtime. + */ + private static final int MAX_MEMORY = 65536; + + /** + * The minimum value of memory for the JobExecutor runtime. + */ + private static final int MIN_MEMORY = 128; + + /** + * The maximum value of CPU for the JobExecutor runtime. + */ + private static final int MAX_CPU = 16000; + + /** + * The minimum value of CPU for the JobExecutor runtime. + */ + private static final int MIN_CPU = 300; + + /** + * The user-defined job name. + */ + private String jobName; + private ArrayList entryPoint; + private RuntimeEnv runtimeEnv = new RuntimeEnv(); + private OBSoptions obsOptions; + private String localCodePath; + private int cpu = DEFUALT_CPU_NUM; + private int memory = DEFUALT_MEM_NUM; + private List scheduleAffinities = new ArrayList(); + private boolean preferredPriority = true; + private boolean requiredPriority = false; + + /** + * The Custom resources."nvidia.com/gpu" + */ + private Map customResources = new HashMap<>(); + + /** + * The Custom extensions. concurrency, a int in the range of [1, 1000] + */ + private Map customExtensions = new HashMap<>(); + + /** + * entryPoint is an ArrayList represents the command for starting a subprocess + * (attached-runtime) running the driver code. For example: + *

+ * {"python", "my_script.py"} + *

+ * + * @param entryPoint the command ArrayList. + * @throws YRException the actor task exception. + */ + public void setEntryPoint(ArrayList entryPoint) throws YRException { + if (entryPoint.size() < ENTRYPOINT_MIN_LENGTH) { + throw new YRException(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, ModuleCode.RUNTIME, + "Length of the entryPoint(=" + entryPoint.size() + ") is not valid, it should be >= " + + ENTRYPOINT_MIN_LENGTH); + } + + this.entryPoint = entryPoint; + } + + /** + * Sets the CPU value for the JobExecutor runtime. It is in 1/1024 cpu core, 300 + * to 16000 supported + * + * @param cpu the CPU value. + * @throws YRException the actor task exception. + */ + public void setCpu(int cpu) throws YRException { + if (cpu < MIN_CPU || cpu > MAX_CPU) { + throw new YRException(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, ModuleCode.RUNTIME, + "The CPU value(" + cpu + ") is not in [" + MIN_CPU + ", " + MAX_CPU + "]"); + } + this.cpu = cpu; + } + + /** + * Sets the memory value for the JobExecutor runtime. It is in 1MB, 128 to 65536 + * supported. + * + * @param memory the memory value. + * @throws YRException the actor task exception. + */ + public void setMemory(int memory) throws YRException { + if (memory < MIN_MEMORY || memory > MAX_MEMORY) { + throw new YRException(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, ModuleCode.RUNTIME, + "The memory value(" + memory + ") is not in [" + MIN_MEMORY + ", " + MAX_MEMORY + "]"); + } + this.memory = memory; + } + + /** + * Sets the value of runtimeEnv, which is the environment to be installed for + * the execution of user jobs. + * + * @param packageManager the package manager. For example: "pip" + * @param packages packages to be installed. For example: {"numpy", + * "pandas"} + */ + public void setRuntimeEnv(String packageManager, String[] packages) { + this.runtimeEnv.setPackageManager(packageManager); + this.runtimeEnv.setPackages(packages); + } + + /** + * Sets the value of runtimeEnv, which is the environment to be installed for + * the execution of user jobs. + * + * @param runtimeEnv the RuntimeEnv object contains environment infomation. + */ + public void setRuntimeEnv(RuntimeEnv runtimeEnv) { + this.runtimeEnv = runtimeEnv; + } + + /** + * The options for downloading user driver code from OBS. + * The setting of obsOptions is optional, but either obsOptions or + * localCodePath must be set and only one can be set. + * + * @param obsOptions the OBSoptions object. + */ + public void setObsOptions(OBSoptions obsOptions) { + if (this.localCodePath != null && !this.localCodePath.isEmpty()) { + LOGGER.warn("(JobExecutor) localCodePath has been set, OBS setting will not work."); + return; + } + this.obsOptions = obsOptions; + } + + /** + * Adds the given Affinity object to the list of schedule affinities. + * + * @param affinity the Affinity object to be added. + */ + public void addScheduleAffinity(Affinity affinity) { + this.scheduleAffinities.add(affinity); + } + + /** + * Sets the preferred priority. + * + * @param isPreferred the boolean value indicating the preferred priority. + */ + public void preferredPriority(boolean isPreferred) { + this.preferredPriority = isPreferred; + } + + /** + * Sets the required priority. + * + * @param isRequired the boolean value indicating the required priority. + */ + public void requiredPriority(boolean isRequired) { + this.requiredPriority = isRequired; + } + + /** + * The path to driver code in remote runtime. + * The setting of localCodePath is optional, but either obsOptions or + * localCodePath must be set and only one can be set. + * + * @param path the path to driver code in remote runtime. + * @throws YRException the actor task exception. + */ + public void setLocalCodePath(String path) throws YRException { + if (this.obsOptions != null) { + LOGGER.warn("(JobExecutor) obsOptions has been set, localCodePath setting will not work."); + return; + } + if (path.isEmpty()) { + throw new YRException(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, ModuleCode.RUNTIME, + "localCodePath cannot be an empty String."); + } + + this.localCodePath = path; + } + + /** + * The localEntryPoint is converted from the entryPoint, which is to be + * performed by remote JobExecutor runtime. It provides the path in the + * JobExecutor runtime of the driver code if localCodePath is set. + * + * @param ArrayList the actual entryPoint executed in remote JobExecutor + * actor runtime. + * @return the localEntryPoint performed in remote JobExecutor runtime if + * localCodepath is set. + * @throws YRException the actor task exception. + */ + public ArrayList getLocalEntryPoint() throws YRException { + if (this.entryPoint == null) { + throw new YRException(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, ModuleCode.RUNTIME, + "Failed to get entryPoint. EntryPoint should be set."); + } + if (this.localCodePath != null && !this.localCodePath.isEmpty()) { + ArrayList localEntryPoint = new ArrayList<>(this.entryPoint); + localEntryPoint.set(CODE_PATH_INDEX, + String.join(Constants.BACKSLASH, this.localCodePath, this.entryPoint.get(CODE_PATH_INDEX))); + return localEntryPoint; + } + return this.entryPoint; + } + + /** + * Generates an InvokeOptions for YR.instance invocation. + * + * @return InvokeOptions + * @throws YRException the actor task exception. + */ + public InvokeOptions extractInvokeOptions() throws YRException { + // sets runtime env to invokeOptions + String runtimeEnvStr = this.runtimeEnv.toCommand(); + if (!runtimeEnvStr.isEmpty()) { + this.customExtensions.put(Constants.POST_START_EXEC, runtimeEnvStr); + } + + // sets OBS options to invokeOptions + if (this.obsOptions == null && this.localCodePath == null) { + throw new YRException(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION, ModuleCode.RUNTIME, + "Either YRJobParam.obsOptions or YRJobParam.localCodePath should be set. Job name: " + + this.jobName); + } + if (this.obsOptions != null) { + this.customExtensions.putAll(obsOptions.toMap()); + } + + InvokeOptions option = InvokeOptions.builder() + .cpu(this.cpu) + .memory(this.memory) + .customResources(this.customResources) + .customExtensions(this.customExtensions) + .preferredPriority(this.preferredPriority) + .requiredPriority(this.requiredPriority) + .scheduleAffinity(this.scheduleAffinities) + .build(); + + LOGGER.debug("Succeeded to extract InvokeOption from YRJobparam: {}", new Gson().toJson(option)); + return option; + } + + /** + * The YRJobParam Builder + * + * @return Builder for YRJobParam + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for YRJobParam + * + * @since 2023 /06/06 + */ + public static class Builder { + private YRJobParam param; + + /** + * Builder + */ + public Builder() { + this.param = new YRJobParam(); + } + + /** + * set job name + * + * @param jobName the job name + * @return Builder + */ + + public Builder jobName(String jobName) { + this.param.setJobName(jobName); + return this; + } + + /** + * set entryPoint value + * + * @param entryPoint the command ArrayList. + * @return Builder + * @throws YRException the actor task exception. + */ + public Builder entryPoint(ArrayList entryPoint) throws YRException { + this.param.setEntryPoint(entryPoint); + return this; + } + + /** + * set runtimeEnv value + * + * @param runtimeEnv the runtime env + * @return Builder + */ + public Builder runtimeEnv(RuntimeEnv runtimeEnv) { + this.param.setRuntimeEnv(runtimeEnv); + return this; + } + + /** + * set obsOptions value + * + * @param obsOptions the OBSOptions object. + * @return Builder + */ + public Builder obsOptions(OBSoptions obsOptions) { + this.param.setObsOptions(obsOptions); + return this; + } + + /** + * set localCodePath value + * + * @param localCodePath the path to driver code in remote runtime. + * @return Builder + * @throws YRException the actor task exception. + */ + public Builder localCodePath(String localCodePath) throws YRException { + this.param.setLocalCodePath(localCodePath); + return this; + } + + /** + * set cpu value + * + * @param cpu the cpu value + * @return Builder + * @throws YRException the actor task exception. + */ + public Builder cpu(int cpu) throws YRException { + this.param.setCpu(cpu); + return this; + } + + /** + * set memory value + * + * @param memory the memory value. + * @return Builder + * @throws YRException the actor task exception. + */ + public Builder memory(int memory) throws YRException { + this.param.setMemory(memory); + return this; + } + + /** + * set addScheduleAffinity value + * + * @param affinity the Affinity object to be added. + * @return Builder + */ + public Builder addScheduleAffinity(Affinity affinity) { + this.param.scheduleAffinities.add(affinity); + return this; + } + + /** + * set addScheduleAffinity value + * + * @param affinities the Affinity object to be added. + * @return Builder + */ + public Builder scheduleAffinity(List affinities) { + this.param.scheduleAffinities = affinities; + return this; + } + + /** + * set preferredPriority value + * + * @param isPreferred the boolean value indicating the preferred priority. + * @return Builder + */ + public Builder preferredPriority(boolean isPreferred) { + this.param.preferredPriority = isPreferred; + return this; + } + + /** + * set requiredPriority value + * + * @param isRequired the boolean value indicating the required priority. + * @return Builder + */ + public Builder requiredPriority(boolean isRequired) { + this.param.requiredPriority = isRequired; + return this; + } + + /** + * YRJobParam build + * + * @return YRJobParam + */ + public YRJobParam build() { + return this.param; + } + } +} diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobStatus.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobStatus.java new file mode 100644 index 0000000..512e412 --- /dev/null +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/jobexecutor/YRJobStatus.java @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jobexecutor; + +/** + * The enum type YRJobStatus. + * + * @since 2023 /06/06 + */ +public enum YRJobStatus { + /** + * The status indicates that attached-runtime is successfully created, and the + * driver code is being executed. + */ + RUNNING, + + /** + * The status indicates that the attached-runtime running the driver code exits + * with exit code 0. + */ + SUCCEEDED, + + /** + * The status indicates that the attached-runtime has been stopped by user while + * the driver code is NOT completely executed. + */ + STOPPED, + + /** + * The status indicates that the attached-runtime running the driver code or the + * function instance of JobExecutor exits abnormally. + */ + FAILED +} diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/ClusterModeRuntime.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/ClusterModeRuntime.java index f90846d..77f9fb7 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/ClusterModeRuntime.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/ClusterModeRuntime.java @@ -25,10 +25,8 @@ import com.yuanrong.InvokeOptions; import com.yuanrong.MSetParam; import com.yuanrong.SetParam; import com.yuanrong.api.InvokeArg; +import com.yuanrong.api.Node; import com.yuanrong.api.YR; -import com.yuanrong.call.CppInstanceHandler; -import com.yuanrong.call.InstanceHandler; -import com.yuanrong.call.JavaInstanceHandler; import com.yuanrong.errorcode.ErrorCode; import com.yuanrong.errorcode.ErrorInfo; import com.yuanrong.errorcode.ModuleCode; @@ -47,6 +45,12 @@ import com.yuanrong.runtime.util.Utils; import com.yuanrong.serialization.Serializer; import com.yuanrong.serialization.strategy.Strategy; import com.yuanrong.storage.InternalWaitResult; +import com.yuanrong.stream.Consumer; +import com.yuanrong.stream.ConsumerImpl; +import com.yuanrong.stream.Producer; +import com.yuanrong.stream.ProducerConfig; +import com.yuanrong.stream.ProducerImpl; +import com.yuanrong.stream.SubscriptionConfig; import com.yuanrong.utils.SdkUtils; import lombok.Getter; @@ -96,27 +100,6 @@ public class ClusterModeRuntime implements Runtime { org.apache.commons.lang3.tuple.Pair>> functions = new ConcurrentHashMap<>(); - /** - * The Inheritance info. - */ - @Getter - Map> inheritanceInfo = new ConcurrentHashMap<>(); - - /** - * The Java instance handler map. - */ - Map instanceHandlerMap = new ConcurrentHashMap<>(); - - /** - * The Java instance handler map. - */ - Map javaInstanceHandlerMap = new ConcurrentHashMap<>(); - - /** - * The Cpp instance handler map. - */ - Map cppInstanceHandlerMap = new ConcurrentHashMap<>(); - private final ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock(); private final Lock rLock = rwLock.readLock(); @@ -193,7 +176,12 @@ public class ClusterModeRuntime implements Runtime { throw new YRException(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, e); } List nestObjIds = new ArrayList<>(Serializer.CONTAINED_OBJECT_IDS.get()); - Pair res = LibRuntime.Put(byteBuffer.array(), nestObjIds); + Pair res; + try { + res = LibRuntime.Put(byteBuffer.array(), nestObjIds); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } Utils.checkErrorAndThrow(res.getFirst(), "put object"); ObjectRef ret = new ObjectRef(res.getSecond()); ret.setByteBuffer((obj instanceof ByteBuffer)); @@ -211,7 +199,12 @@ public class ClusterModeRuntime implements Runtime { throw new YRException(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, e); } List nestObjIds = new ArrayList<>(Serializer.CONTAINED_OBJECT_IDS.get()); - Pair res = LibRuntime.PutWithParam(byteBuffer.array(), nestObjIds, createParam); + Pair res; + try { + res = LibRuntime.PutWithParam(byteBuffer.array(), nestObjIds, createParam); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } Utils.checkErrorAndThrow(res.getFirst(), "put object"); ObjectRef ret = new ObjectRef(res.getSecond()); ret.setByteBuffer((obj instanceof ByteBuffer)); @@ -232,7 +225,12 @@ public class ClusterModeRuntime implements Runtime { if (waitNum == 0) { throw new YRException(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, "waitNum cannot be 0"); } - InternalWaitResult waitResult = LibRuntime.Wait(objIds, waitNum, timeoutSec); + InternalWaitResult waitResult; + try { + waitResult = LibRuntime.Wait(objIds, waitNum, timeoutSec); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } if (waitResult == null) { throw new YRException(ErrorCode.ERR_INNER_SYSTEM_ERROR, ModuleCode.RUNTIME, "failed to get wait result"); @@ -278,7 +276,12 @@ public class ClusterModeRuntime implements Runtime { FunctionMeta functionMeta, List args, InvokeOptions opt) throws YRException { String language = functionMeta.getLanguage().name(); LOGGER.debug("start invoke function by name({}), language({})", functionMeta.getFunctionName(), language); - Pair res = LibRuntime.InvokeInstance(functionMeta, "", args, opt); + Pair res; + try { + res = LibRuntime.InvokeInstance(functionMeta, "", args, opt); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } Utils.checkErrorAndThrow(res.getFirst(), "invoke function"); String objId = res.getSecond(); LOGGER.debug("succeed to invoke function by name({}), language({}), objId({})", functionMeta.getFunctionName(), @@ -312,7 +315,12 @@ public class ClusterModeRuntime implements Runtime { public String invokeInstance(FunctionMeta functionMeta, String instanceId, List args, InvokeOptions opt) throws YRException { LOGGER.debug("start to invoke instance({}) functionMeta is {}", instanceId, functionMeta); - Pair res = LibRuntime.InvokeInstance(functionMeta, instanceId, args, opt); + Pair res; + try { + res = LibRuntime.InvokeInstance(functionMeta, instanceId, args, opt); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } Utils.checkErrorAndThrow(res.getFirst(), "invoke instance"); String objId = res.getSecond(); LOGGER.debug("succeed to invoke instance({}) objId({})", instanceId, objId); @@ -361,19 +369,34 @@ public class ClusterModeRuntime implements Runtime { */ @Override public void terminateInstance(String instanceId) throws YRException { - ErrorInfo errorInfo = LibRuntime.Kill(instanceId); + ErrorInfo errorInfo; + try { + errorInfo = LibRuntime.Kill(instanceId); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } Utils.checkErrorAndThrow(errorInfo, "kill instance(" + instanceId + ")"); LOGGER.info("succeed to terminate instance({})", instanceId); } + /** + * Decrease reference. + * + * @param ids the ids. + */ + @Override + public void decreaseReference(List ids) { + LibRuntime.DecreaseReference(ids); + } + /** * Is on cloud boolean. * * @return the boolean */ @Override - public boolean isOnCloud() { - return false; + public boolean isDriver() { + return true; } /** @@ -381,10 +404,15 @@ public class ClusterModeRuntime implements Runtime { * * @param objectId the object id * @return the string representing the real instance id + * @throws YRException the actor task exception. */ @Override - public String getRealInstanceId(String objectId) { - return LibRuntime.GetRealInstanceId(objectId); + public String getRealInstanceId(String objectId) throws YRException { + try { + return LibRuntime.GetRealInstanceId(objectId); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } } /** @@ -393,42 +421,15 @@ public class ClusterModeRuntime implements Runtime { * @param objectId the object id * @param instanceId the instance id * @param opts the invoke options + * @throws YRException the actor task exception. */ @Override - public void saveRealInstanceId(String objectId, String instanceId, InvokeOptions opts) { - LibRuntime.SaveRealInstanceId(objectId, instanceId, opts); - } - - /** - * Collect instance handler info. - * - * @param instanceHandler the instance handler - */ - @Override - public void collectInstanceHandlerInfo(InstanceHandler instanceHandler) { - instanceHandlerMap.put(instanceHandler.getInstanceId(), instanceHandler); - } - - @Override - public void collectInstanceHandlerInfo(JavaInstanceHandler javaInstanceHandler) { - javaInstanceHandlerMap.put(javaInstanceHandler.getInstanceId(), javaInstanceHandler); - } - - @Override - public void collectInstanceHandlerInfo(CppInstanceHandler cppInstanceHandler) { - cppInstanceHandlerMap.put(cppInstanceHandler.getInstanceId(), cppInstanceHandler); - } - - /** - * Returns an InstanceHandler object that contains the Java instance handler - * which is NOT terminated and associated with the specified instanceID. - * - * @param instanceID The instanceID that identifies the Java instance handler - * @return An InstanceHandler object associated with the specified instanceID - */ - @Override - public InstanceHandler getInstanceHandlerInfo(String instanceID) { - return this.instanceHandlerMap.get(instanceID); + public void saveRealInstanceId(String objectId, String instanceId, InvokeOptions opts) throws YRException { + try { + LibRuntime.SaveRealInstanceId(objectId, instanceId, opts); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } } /** @@ -440,7 +441,6 @@ public class ClusterModeRuntime implements Runtime { @Override public void Finalize() { LibRuntime.Finalize(); - clearInstanceHandlerInfo(); } /** @@ -455,9 +455,6 @@ public class ClusterModeRuntime implements Runtime { @Override public void Finalize(String runtimeCtx, int leftRuntimeNum) { LibRuntime.FinalizeWithCtx(runtimeCtx); - if (leftRuntimeNum == 0) { - clearInstanceHandlerInfo(); - } } /** @@ -494,7 +491,12 @@ public class ClusterModeRuntime implements Runtime { throw new YRException(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, "Cannot set a null value to key: " + key); } - ErrorInfo errorInfo = LibRuntime.KVWrite(key, value, setParam); + ErrorInfo errorInfo; + try { + errorInfo = LibRuntime.KVWrite(key, value, setParam); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } if (!errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { LOGGER.error("KVWrite err: Code:{}, MCode:{}, Msg:{}", errorInfo.getErrorCode(), errorInfo.getModuleCode(), errorInfo.getErrorMessage()); @@ -520,7 +522,12 @@ public class ClusterModeRuntime implements Runtime { "Cannot set a null value to key: " + keys.get(i)); } } - ErrorInfo errorInfo = LibRuntime.KVMSetTx(keys, values, mSetParam); + ErrorInfo errorInfo; + try { + errorInfo = LibRuntime.KVMSetTx(keys, values, mSetParam); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } if (!errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { LOGGER.error("KVMSetTx err: Code:{}, MCode:{}, Msg:{}", errorInfo.getErrorCode(), errorInfo.getModuleCode(), errorInfo.getErrorMessage()); @@ -545,7 +552,6 @@ public class ClusterModeRuntime implements Runtime { synchronized (this) { allClassFunctions = functions.get(className); if (allClassFunctions == null) { - inheritanceInfo.putIfAbsent(className, new HashSet<>()); allClassFunctions = loadFunctionsForClass(className); functions.putIfAbsent(className, allClassFunctions); } @@ -618,7 +624,6 @@ public class ClusterModeRuntime implements Runtime { } final String signature = type.getDescriptor(); String declaringClassName = executable.getDeclaringClass().getName(); - inheritanceInfo.get(className).add(declaringClassName); final String methodName = executable instanceof Method ? executable.getName() : CONSTRUCTOR_NAME; FunctionMeta meta = FunctionMeta.newBuilder().setClassName(declaringClassName) .setFunctionName(methodName) @@ -642,7 +647,12 @@ public class ClusterModeRuntime implements Runtime { */ @Override public byte[] KVRead(String key, int timeoutMS) throws YRException { - Pair result = LibRuntime.KVRead(key, timeoutMS); + Pair result; + try { + result = LibRuntime.KVRead(key, timeoutMS); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } ErrorInfo errorInfo = result.getSecond(); if (!errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { LOGGER.error("KVRead err: Code:{}, MCode:{}, Msg:{}", errorInfo.getErrorCode(), errorInfo.getModuleCode(), @@ -665,7 +675,12 @@ public class ClusterModeRuntime implements Runtime { */ @Override public List KVRead(List keys, int timeoutMS, boolean allowPartial) throws YRException { - Pair, ErrorInfo> result = LibRuntime.KVRead(keys, timeoutMS, allowPartial); + Pair, ErrorInfo> result; + try { + result = LibRuntime.KVRead(keys, timeoutMS, allowPartial); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } ErrorInfo errorInfo = result.getSecond(); if (!errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { LOGGER.error("KVRead err: Code:{}, MCode:{}, Msg:{}", errorInfo.getErrorCode(), errorInfo.getModuleCode(), @@ -688,7 +703,12 @@ public class ClusterModeRuntime implements Runtime { */ @Override public List KVGetWithParam(List keys, GetParams params, int timeoutMS) throws YRException { - Pair, ErrorInfo> result = LibRuntime.KVGetWithParam(keys, params, timeoutMS); + Pair, ErrorInfo> result; + try { + result = LibRuntime.KVGetWithParam(keys, params, timeoutMS); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } ErrorInfo errorInfo = result.getSecond(); if (!errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { LOGGER.error("KVGetWithParam err: Code:{}, MCode:{}, Msg:{}", errorInfo.getErrorCode(), @@ -706,7 +726,12 @@ public class ClusterModeRuntime implements Runtime { */ @Override public void KVDel(String key) throws YRException { - ErrorInfo errorInfo = LibRuntime.KVDel(key); + ErrorInfo errorInfo; + try { + errorInfo = LibRuntime.KVDel(key); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } if (!errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { LOGGER.error("KVDel err: Code:{}, MCode:{}, Msg:{}", errorInfo.getErrorCode(), errorInfo.getModuleCode(), errorInfo.getErrorMessage()); @@ -719,10 +744,16 @@ public class ClusterModeRuntime implements Runtime { * * @param keys A list of keys of the pairs to be deleted * @return A list of keys that were failed to be deleted + * @throws YRException If an error occurs while performing the delete operation */ @Override - public List KVDel(List keys) { - Pair, ErrorInfo> result = LibRuntime.KVDel(keys); + public List KVDel(List keys) throws YRException { + Pair, ErrorInfo> result = null; + try { + result = LibRuntime.KVDel(keys); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } ErrorInfo errorInfo = result.getSecond(); if (!errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { LOGGER.error("KVDel err: Code:{}, MCode:{}, Msg:{}", errorInfo.getErrorCode(), errorInfo.getModuleCode(), @@ -744,7 +775,12 @@ public class ClusterModeRuntime implements Runtime { if (timeoutSec != Constants.NO_TIMEOUT) { timeoutMs = timeoutSec * Constants.SEC_TO_MS; } - ErrorInfo errorInfo = LibRuntime.LoadState(timeoutMs); + ErrorInfo errorInfo; + try { + errorInfo = LibRuntime.LoadState(timeoutMs); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } Utils.checkErrorAndThrow(errorInfo, "Load state error"); } @@ -761,23 +797,13 @@ public class ClusterModeRuntime implements Runtime { if (timeoutSec != Constants.NO_TIMEOUT) { timeoutMs = timeoutSec * Constants.SEC_TO_MS; } - ErrorInfo errorInfo = LibRuntime.SaveState(timeoutMs); - Utils.checkErrorAndThrow(errorInfo, "Save state error"); - } - - private void clearInstanceHandlerInfo() { - for (Map.Entry entry : instanceHandlerMap.entrySet()) { - InstanceHandler instanceHandler = entry.getValue(); - instanceHandler.clearHandlerInfo(); - } - for (Map.Entry entry : javaInstanceHandlerMap.entrySet()) { - JavaInstanceHandler javaInstanceHandler = entry.getValue(); - javaInstanceHandler.clearHandlerInfo(); - } - for (Map.Entry entry : cppInstanceHandlerMap.entrySet()) { - CppInstanceHandler cppInstanceHandler = entry.getValue(); - cppInstanceHandler.clearHandlerInfo(); + ErrorInfo errorInfo; + try { + errorInfo = LibRuntime.SaveState(timeoutMs); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); } + Utils.checkErrorAndThrow(errorInfo, "Save state error"); } @Override @@ -787,7 +813,12 @@ public class ClusterModeRuntime implements Runtime { "The value of timeout should be -1 or greater than 0"); throw new YRException(errorInfo); } - ErrorInfo errorInfo = LibRuntime.GroupCreate(groupName, opts); + ErrorInfo errorInfo; + try { + errorInfo = LibRuntime.GroupCreate(groupName, opts); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } if (!errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { LOGGER.error("group create error: Code:{}, MCode:{}, Msg:{}", errorInfo.getErrorCode(), errorInfo.getModuleCode(), errorInfo.getErrorMessage()); @@ -796,7 +827,12 @@ public class ClusterModeRuntime implements Runtime { } @Override public void groupWait(String groupName) throws YRException { - ErrorInfo errorInfo = LibRuntime.GroupWait(groupName); + ErrorInfo errorInfo; + try { + errorInfo = LibRuntime.GroupWait(groupName); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } if (!errorInfo.getErrorCode().equals(ErrorCode.ERR_OK)) { LOGGER.error("group wait error: Code:{}, MCode:{}, Msg:{}", errorInfo.getErrorCode(), errorInfo.getModuleCode(), errorInfo.getErrorMessage()); @@ -809,15 +845,132 @@ public class ClusterModeRuntime implements Runtime { LibRuntime.GroupTerminate(groupName); } + /** + * createStreamProducer + * + * @param streamName the stream name + * @param producerConf the producer conf + * @return Producer the stream producer + * @throws YRException if there is an exception during creating stream producer + */ + @Override + public Producer createStreamProducer(String streamName, ProducerConfig producerConf) throws YRException { + if (producerConf.getMaxStreamSize() < 0) { + throw new YRException(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, + "maxStreamSize (" + producerConf.getMaxStreamSize() + ") is invalid, expect >= 0"); + } + if (producerConf.getRetainForNumConsumers() < 0) { + throw new YRException(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, + "retainForNumConsumers (" + producerConf.getRetainForNumConsumers() + ") is invalid, expect >= 0"); + } + if (producerConf.getReserveSize() < 0) { + throw new YRException(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, + "reserveSize (" + producerConf.getReserveSize() + ") is invalid, expect >= 0"); + } + rLock.lock(); + try { + long producerPtr = LibRuntime.CreateStreamProducerWithConfig(streamName, producerConf); + return new ProducerImpl(producerPtr); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } finally { + rLock.unlock(); + } + } + + /** + * createStreamConsumer + * + * @param streamName the stream name + * @param config the subscription conf + * @param autoAck if consumer auto ack + * @return Consumer the stream consumer + * @throws YRException if there is an exception during creating stream consumer + */ + @Override + public Consumer createStreamConsumer(String streamName, SubscriptionConfig config, boolean autoAck) + throws YRException { + rLock.lock(); + try { + long consumerPtr = LibRuntime.CreateStreamConsumer(streamName, config.getSubscriptionName(), + config.getSubscriptionType(), autoAck); + return new ConsumerImpl(consumerPtr); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } finally { + rLock.unlock(); + } + } + + /** + * deleteStream + * + * @param streamName the stream name + * @throws YRException if there is an exception when delete stream + */ + @Override + public void deleteStream(String streamName) throws YRException { + rLock.lock(); + try { + ErrorInfo err = LibRuntime.DeleteStream(streamName); + StackTraceUtils.checkErrorAndThrowForInvokeException(err, err.getErrorMessage()); + } finally { + rLock.unlock(); + } + } + + /** + * queryGlobalProducersNum + * + * @param streamName the stream name + * @return long the producers num + * @throws YRException if there is an exception during quering global producersNum + */ + @Override + public long queryGlobalProducersNum(String streamName) throws YRException { + rLock.lock(); + try { + return LibRuntime.QueryGlobalProducersNum(streamName); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } finally { + rLock.unlock(); + } + } + + /** + * queryGlobalConsumersNum + * + * @param streamName the stream name + * @return long the consumers num + * @throws YRException if there is an exception during quering global consumersNum + */ + @Override + public long queryGlobalConsumersNum(String streamName) throws YRException { + rLock.lock(); + try { + return LibRuntime.QueryGlobalConsumersNum(streamName); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } finally { + rLock.unlock(); + } + } + /** * Get instance route. * * @param objectId the object id * @return the string representing the instance route + * @throws YRException the YR exception. */ @Override - public String getInstanceRoute(String objectId) { - return LibRuntime.GetInstanceRoute(objectId); + public String getInstanceRoute(String objectId) throws YRException { + try { + return LibRuntime.GetInstanceRoute(objectId); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } } /** @@ -825,10 +978,15 @@ public class ClusterModeRuntime implements Runtime { * * @param objectId the object id * @param instanceRoute the instance route + * @throws YRException the YR exception. */ @Override - public void saveInstanceRoute(String objectId, String instanceRoute) { - LibRuntime.SaveInstanceRoute(objectId, instanceRoute); + public void saveInstanceRoute(String objectId, String instanceRoute) throws YRException { + try { + LibRuntime.SaveInstanceRoute(objectId, instanceRoute); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } } /** @@ -843,4 +1001,22 @@ public class ClusterModeRuntime implements Runtime { Utils.checkErrorAndThrow(errorInfo, "kill instance sync(" + instanceId + ")"); LOGGER.info("succeed to terminate instance sync({})", instanceId); } + + /** + * Get node information in the cluster. + * + * @return List: node information + * @throws YRException the actor task exception. + */ + @Override + public List nodes() throws YRException { + Pair> res; + try { + res = LibRuntime.nodes(); + } catch (LibRuntimeException e) { + throw new YRException(e.getErrorCode(), e.getModuleCode(), e.getMessage()); + } + Utils.checkErrorAndThrow(res.getFirst(), "get node information"); + return res.getSecond(); + } } diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/Runtime.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/Runtime.java index 88c82f0..a67457d 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/Runtime.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/Runtime.java @@ -24,15 +24,17 @@ import com.yuanrong.InvokeOptions; import com.yuanrong.MSetParam; import com.yuanrong.SetParam; import com.yuanrong.api.InvokeArg; -import com.yuanrong.call.CppInstanceHandler; -import com.yuanrong.call.InstanceHandler; -import com.yuanrong.call.JavaInstanceHandler; +import com.yuanrong.api.Node; import com.yuanrong.exception.YRException; import com.yuanrong.jni.LibRuntimeConfig; import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; import com.yuanrong.runtime.client.KVManager; import com.yuanrong.runtime.client.ObjectRef; import com.yuanrong.storage.InternalWaitResult; +import com.yuanrong.stream.Consumer; +import com.yuanrong.stream.Producer; +import com.yuanrong.stream.ProducerConfig; +import com.yuanrong.stream.SubscriptionConfig; import java.util.List; @@ -172,20 +174,28 @@ public interface Runtime { */ void terminateInstance(String instanceId) throws YRException; + /** + * Decrease reference. + * + * @param ids the ids. + */ + void decreaseReference(List ids); + /** * Is on cloud boolean. * * @return the boolean */ - boolean isOnCloud(); + boolean isDriver(); /** * Get real instance id. * * @param objectId the object id * @return the string representing the real instance id + * @throws YRException the actor task exception. */ - String getRealInstanceId(String objectId); + String getRealInstanceId(String objectId) throws YRException; /** * Save real instance id. @@ -193,38 +203,9 @@ public interface Runtime { * @param objectId the object id * @param instanceId the instance id * @param opts the invoke options + * @throws YRException the YR exception. */ - void saveRealInstanceId(String objectId, String instanceId, InvokeOptions opts); - - /** - * Collect instance handler info. - * - * @param instanceHandler the instance handler - */ - void collectInstanceHandlerInfo(InstanceHandler instanceHandler); - - /** - * Collect instance handler info. - * - * @param javaInstanceHandler the java instance handler - */ - void collectInstanceHandlerInfo(JavaInstanceHandler javaInstanceHandler); - - /** - * Collect instance handler info. - * - * @param cppInstanceHandler the cpp instance handler - */ - void collectInstanceHandlerInfo(CppInstanceHandler cppInstanceHandler); - - /** - * Returns an InstanceHandler object that contains the Java instance handler - * which is NOT terminated and associated with the specified instanceID. - * - * @param instanceID The instanceID that identifies the Java instance handler - * @return An InstanceHandler object associated with the specified instanceID - */ - InstanceHandler getInstanceHandlerInfo(String instanceID); + void saveRealInstanceId(String objectId, String instanceId, InvokeOptions opts) throws YRException; /** * Finalizes all actors and tasks and release any resources associated with @@ -336,8 +317,9 @@ public interface Runtime { * * @param keys A list of keys of the pairs to be deleted * @return A list of keys that were failed to be deleted + * @throws YRException the YR exception. */ - List KVDel(List keys); + List KVDel(List keys) throws YRException; /** * loadState @@ -379,21 +361,71 @@ public interface Runtime { */ void groupWait(String groupName) throws YRException; + /** + * createStreamProducer + * + * @param streamName the stream name + * @param producerConf the producer conf + * @return Producer the stream producer + * @throws YRException if there is an exception during creating stream producer + */ + Producer createStreamProducer(String streamName, ProducerConfig producerConf) throws YRException; + + /** + * createStreamConsumer + * + * @param streamName the stream name + * @param config the subscription conf + * @param autoAck if consumer auto ack + * @return Consumer the stream consumer + * @throws YRException if there is an exception during creating stream consumer + */ + Consumer createStreamConsumer(String streamName, SubscriptionConfig config, + boolean autoAck) throws YRException; + + /** + * deleteStream + * + * @param streamName the stream name + * @throws YRException if there is an exception when delete stream + */ + void deleteStream(String streamName) throws YRException; + + /** + * queryGlobalProducersNum + * + * @param streamName the stream name + * @return long the producers num + * @throws YRException if there is an exception during quering global producersNum + */ + long queryGlobalProducersNum(String streamName) throws YRException; + + /** + * queryGlobalConsumersNum + * + * @param streamName the stream name + * @return long the consumers num + * @throws YRException if there is an exception during quering global consumersNum + */ + long queryGlobalConsumersNum(String streamName) throws YRException; + /** * Get instance route. * * @param objectId the object id * @return the string representing the instance route + * @throws YRException the YR exception. */ - String getInstanceRoute(String objectId); + String getInstanceRoute(String objectId) throws YRException; /** * Save instance route. * * @param objectId the object id * @param instanceRoute the instance route + * @throws YRException the YR exception. */ - void saveInstanceRoute(String objectId, String instanceRoute); + void saveInstanceRoute(String objectId, String instanceRoute) throws YRException; /** * Sync terminate instance. @@ -402,4 +434,12 @@ public interface Runtime { * @throws YRException the YR exception. */ void terminateInstanceSync(String instanceId) throws YRException; + + /** + * Get node information in the cluster. + * + * @return List: node information + * @throws YRException the YR exception. + */ + List nodes() throws YRException; } diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/client/KVManager.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/client/KVManager.java index 0d195f3..41075c5 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/client/KVManager.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/runtime/client/KVManager.java @@ -104,8 +104,19 @@ public class KVManager { * writing to the key-value store */ public void set(String key, byte[] value, Integer length, ExistenceOpt existence) throws YRException { + byte[] newValue; + try { + newValue = Arrays.copyOfRange(value, 0, length); + } catch (ArrayIndexOutOfBoundsException e) { + LOGGER.error("Length of value({}) is smaller than then parameter 'length'({}), key: {}", value.length, + length, key); + throw new YRException(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, e); + } catch (NullPointerException e) { + throw new YRException(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, + "Cannot set a null value to key: " + key); + } SetParam setParam = new SetParam.Builder().existence(existence).build(); - set(key, value, length, setParam); + YR.getRuntime().KVWrite(key, newValue, setParam); } /** @@ -527,11 +538,10 @@ public class KVManager { List objects = new ArrayList(keys.size()); for (int i = 0; i < keys.size(); i++) { try { - byte[] value = serializedValue.get(i); - objects.add(value == null ? null : Serializer.deserialize(value, types.get(i))); + objects.add(Serializer.deserialize(serializedValue.get(i), types.get(i))); } catch (IOException e) { throw new YRException(ErrorCode.ERR_DESERIALIZATION_FAILED, ModuleCode.RUNTIME, - "Failed to deserialize the value associated with the key: " + keys.get(i)); + "Failed to deserialize the value associated with the key: " + keys.get(i)); } } return objects; diff --git a/api/java/yr-api-sdk/src/main/java/com/yuanrong/utils/SdkUtils.java b/api/java/yr-api-sdk/src/main/java/com/yuanrong/utils/SdkUtils.java index ef9a334..d4a1911 100644 --- a/api/java/yr-api-sdk/src/main/java/com/yuanrong/utils/SdkUtils.java +++ b/api/java/yr-api-sdk/src/main/java/com/yuanrong/utils/SdkUtils.java @@ -179,6 +179,8 @@ public class SdkUtils { libConfig.setThreadPoolSize(configManager.getThreadPoolSize()); libConfig.setLoadPaths(configManager.getLoadPaths()); libConfig.setInCluster(configManager.isInCluster()); + libConfig.setHttpIocThreadsNum(configManager.getHttpIocThreadsNum()); + libConfig.setHttpIdleTime(configManager.getHttpIdleTime()); libConfig.setRpcTimeout(configManager.getRpcTimeout()); libConfig.setTenantId(configManager.getTenantId()); libConfig.setCustomEnvs(configManager.getCustomEnvs()); @@ -187,6 +189,7 @@ public class SdkUtils { libConfig.setVerifyFilePath(configManager.getVerifyFilePath()); libConfig.setPrivateKeyPath(configManager.getPrivateKeyPath()); libConfig.setServerName(configManager.getServerName()); + libConfig.setPrivateKeyPaaswd(configManager.getPrivateKeyPaaswd()); libConfig.setCodePath(configManager.getCodePath()); Map functionIds = new HashMap() {}; @@ -196,6 +199,9 @@ public class SdkUtils { if (configManager.getCppFunctionURN() != null && !configManager.getCppFunctionURN().isEmpty()) { functionIds.put(Libruntime.LanguageType.Cpp, reformatFunctionUrn(configManager.getCppFunctionURN())); } + if (configManager.getGoFunctionURN() != null && !configManager.getGoFunctionURN().isEmpty()) { + functionIds.put(Libruntime.LanguageType.Golang, reformatFunctionUrn(configManager.getGoFunctionURN())); + } libConfig.setFunctionIds(functionIds); LOGGER.debug("java functionIds: {}", functionIds.get(Libruntime.LanguageType.Java)); LOGGER.debug("cpp functionIds: {}", functionIds.get(Libruntime.LanguageType.Cpp)); diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestConfig.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestConfig.java index c79f159..53c3fd7 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestConfig.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestConfig.java @@ -17,10 +17,14 @@ package com.yuanrong; import com.yuanrong.exception.YRException; +import com.yuanrong.runtime.util.Constants; import org.junit.Assert; import org.junit.Test; +import java.util.ArrayList; +import java.util.HashMap; + public class TestConfig { @Test public void testInitConfig() { @@ -29,6 +33,7 @@ public class TestConfig { "127.0.0.0", "127.0.0.0", "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true, false); Config testConf2 = new Config( "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", @@ -36,28 +41,77 @@ public class TestConfig { 1, "127.0.0.0", 1, - "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest"); + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); Config testConf3 = new Config( "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", - true); - Config testConf4 = new Config.Builder() + "test-go-urn", + true, + false); + Config testConf4 = new Config( + "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", + "127.0.0.0", + 1, + "127.0.0.0", + 1, + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + false); + Config testConf5 = new Config.Builder().iamAuthToken("test-token") .cppFunctionURN("sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest") + .goFunctionURN("sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest") .isDriver(true) .ns("test-ns") .logDir("/tmp") .logLevel("test-level") .enableDisConvCallStack(false) .rpcTimeout(1) + .codePath(new ArrayList<>()) .build(); + Config testConf6 = new Config(); testConf1.setMaxLogSizeMb(10); testConf2.setMaxLogFileNum(10); - testConf4.toString(); - Assert.assertNotEquals(testConf3, testConf2); - Assert.assertNotEquals(testConf1, testConf2); + testConf5.toString(); + testConf4.hashCode(); + Assert.assertFalse(testConf3.equals(testConf4)); + Assert.assertFalse(testConf1.equals(testConf2)); + + Assert.assertTrue(testConf6.isInCluster()); + Assert.assertTrue(testConf6.isDriver()); + Assert.assertTrue(testConf6.isEnableMetrics()); + Assert.assertFalse(testConf6.isEnableMTLS()); + Assert.assertFalse(testConf6.isEnableDsAuth()); + Assert.assertTrue(testConf6.isEnableDisConvCallStack()); + Assert.assertFalse(testConf6.isThreadLocal()); + Assert.assertFalse(testConf6.isEnableSetContext()); + Assert.assertEquals("", testConf6.getServerAddress()); + Assert.assertEquals("", testConf6.getDataSystemAddress()); + Assert.assertEquals("", testConf6.getNs()); + Assert.assertEquals("", testConf6.getLogLevel()); + Assert.assertEquals("", testConf6.getIamAuthToken()); + Assert.assertEquals("", testConf6.getTenantId()); + Assert.assertEquals("sn:cn:yrk:12345678901234561234567890123456:function:0-defaultservice-cpp:$latest", + testConf6.getCppFunctionURN()); + Assert.assertEquals("sn:cn:yrk:12345678901234561234567890123456:function:0-defaultservice-go:$latest", + testConf6.getGoFunctionURN()); + Assert.assertEquals(30 * 60, testConf6.getRpcTimeout()); + Assert.assertEquals(31222, testConf6.getServerAddressPort()); + Assert.assertEquals(31222, testConf6.getDataSystemAddressPort()); + Assert.assertEquals(0, testConf6.getThreadPoolSize()); + Assert.assertEquals(10, testConf6.getRecycleTime()); + Assert.assertEquals(System.getProperty("user.dir"), testConf6.getLogDir()); + Assert.assertEquals(0, testConf6.getMaxLogSizeMb()); + Assert.assertEquals(0, testConf6.getMaxLogFileNum()); + Assert.assertEquals(-1, testConf6.getMaxTaskInstanceNum()); + Assert.assertEquals(100, testConf6.getMaxConcurrencyCreateNum()); + Assert.assertEquals(Constants.DEFAULT_HTTP_IO_THREAD_CNT, testConf6.getHttpIocThreadsNum()); + Assert.assertEquals(Constants.DEFAULT_HTTP_IDLE_TIME, testConf6.getHttpIdleTime()); + Assert.assertEquals(new ArrayList<>(), testConf6.getLoadPaths()); + Assert.assertEquals(new HashMap<>(), testConf6.getCustomEnvs()); } @Test @@ -67,6 +121,8 @@ public class TestConfig { "127.0.0.0", "127.0.0.0", "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + "test-go-urn", + true, true); boolean isException = false; @@ -81,6 +137,18 @@ public class TestConfig { } Assert.assertTrue(isException); + isException = false; + testConf.setGoFunctionURN("test-goFunction"); + try { + testConf.checkParameter(); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("goFunctionURN is invalid")); + testConf.setGoFunctionURN( + "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest"); + isException = true; + } + Assert.assertTrue(isException); + isException = false; testConf.setServerAddress("test-server-address"); try { diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestConfigManager.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestConfigManager.java index 0482de4..8d4aa21 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestConfigManager.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestConfigManager.java @@ -32,7 +32,9 @@ public class TestConfigManager { 1, "127.0.0.0", 1, - "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest"); + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); ConfigManager configManager = new ConfigManager(); configManager.init(config); @@ -54,6 +56,7 @@ public class TestConfigManager { testConfigManager.setJobId("test-id"); testConfigManager.setRuntimeId("test-runtime"); testConfigManager.setCppFunctionURN("sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest"); + testConfigManager.setGoFunctionURN("sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest"); testConfigManager.setRecycleTime(1); testConfigManager.setInitialized(expectedValue); testConfigManager.setMaxTaskInstanceNum(10); @@ -66,6 +69,8 @@ public class TestConfigManager { testConfigManager.setLogFileSizeMax(1024); testConfigManager.setThreadPoolSize(1); testConfigManager.setLoadPaths(new ArrayList<>()); + testConfigManager.setHttpIocThreadsNum(1); + testConfigManager.setHttpIdleTime(60); testConfigManager.setRpcTimeout(10); testConfigManager.setTenantId("test-tenantID"); testConfigManager.setCustomEnvs(new HashMap<>()); diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestGroup.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestGroup.java index 6911cb4..4377217 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestGroup.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/TestGroup.java @@ -21,6 +21,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.when; import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.exception.LibRuntimeException; import com.yuanrong.jni.LibRuntime; import org.junit.Assert; @@ -46,7 +47,7 @@ public class TestGroup { } @Test - public void testInitGroup() { + public void testInitGroup() throws LibRuntimeException { when(LibRuntime.GroupCreate(anyString(), any())).thenReturn(new ErrorInfo()); when(LibRuntime.GroupTerminate(anyString())).thenReturn(new ErrorInfo()); when(LibRuntime.GroupWait(anyString())).thenReturn(new ErrorInfo()); diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/api/TestJobExecutorCaller.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/api/TestJobExecutorCaller.java new file mode 100644 index 0000000..ae871dd --- /dev/null +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/api/TestJobExecutorCaller.java @@ -0,0 +1,214 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.api; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.when; + +import com.yuanrong.FunctionWrapper; +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.errorcode.Pair; +import com.yuanrong.exception.YRException; +import com.yuanrong.exception.handler.traceback.StackTraceUtils; +import com.yuanrong.jni.LibRuntime; +import com.yuanrong.jobexecutor.OBSoptions; +import com.yuanrong.jobexecutor.RuntimeEnv; +import com.yuanrong.jobexecutor.YRJobInfo; +import com.yuanrong.jobexecutor.YRJobParam; +import com.yuanrong.runtime.ClusterModeRuntime; +import com.yuanrong.storage.InternalWaitResult; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Optional; + +@RunWith(PowerMockRunner.class) +@PrepareForTest( {LibRuntime.class, YR.class, StackTraceUtils.class}) +@SuppressStaticInitializationFor( {"com.yuanrong.jni.LibRuntime"}) +@PowerMockIgnore("javax.management.*") +public class TestJobExecutorCaller { + @Test + public void testSubmitJob() throws Exception { + String mockUserJobID = "test-userJobID"; + ArrayList readyIds = new ArrayList<>(); + readyIds.add("test-readId"); + ArrayList unreadyIds = new ArrayList<>(); + unreadyIds.add("test-unreadyId"); + HashMap exceptionIds = new HashMap<>(); + exceptionIds.put("test-error", + new ErrorInfo(ErrorCode.ERR_ETCD_OPERATION_ERROR, ModuleCode.RUNTIME_INVOKE, "msg")); + InternalWaitResult result = new InternalWaitResult(readyIds, unreadyIds, exceptionIds); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, ""); + Pair errorInfoStringPair = new Pair(errorInfo, "ok"); + + ClusterModeRuntime runtime = PowerMockito.mock(ClusterModeRuntime.class); + PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); + + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.CreateInstance(any(), anyList(), any())).thenReturn(errorInfoStringPair); + when(LibRuntime.GetRealInstanceId(anyString())).thenReturn(mockUserJobID); + when(LibRuntime.Wait(anyList(), anyInt(), anyInt())).thenReturn(result); + + PowerMockito.mockStatic(StackTraceUtils.class); + PowerMockito.doNothing() + .when(StackTraceUtils.class, "checkErrorAndThrowForInvokeException", any(), anyString()); + + ArrayList testEntryPoint = new ArrayList<>(); + testEntryPoint.add("test-entryPoint1"); + testEntryPoint.add("test-entryPoint2"); + testEntryPoint.add("test-entryPoint3"); + testEntryPoint.add("test-entryPoint4"); + + boolean isException = false; + YRJobParam yrJobParam = null; + try { + yrJobParam = new YRJobParam.Builder().cpu(500) + .memory(500) + .entryPoint(testEntryPoint) + .localCodePath("/tmp") + .jobName("test-job") + .addScheduleAffinity(null) + .obsOptions(new OBSoptions()) + .runtimeEnv(new RuntimeEnv()) + .preferredPriority(true) + .scheduleAffinity(new ArrayList<>()) + .build(); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + + isException = false; + try { + JobExecutorCaller jobExecutorCaller = new JobExecutorCaller(); + JobExecutorCaller.submitJob(yrJobParam); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testStopJob() throws Exception { + ClusterModeRuntime runtime = PowerMockito.mock(ClusterModeRuntime.class); + PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); + + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.IsInitialized()).thenReturn(true); + + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, ""); + Pair errorInfoStringPair = new Pair(errorInfo, "ok"); + when(LibRuntime.InvokeInstance(any(), anyString(), anyList(), any())).thenReturn(errorInfoStringPair); + + boolean isException = false; + + try { + JobExecutorCaller.stopJob("test-ID"); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testGetYrJobInfo() throws Exception { + FunctionWrapper function = PowerMockito.mock(FunctionWrapper.class); + when(function.getReturnType()).thenReturn(Optional.empty()); + + ClusterModeRuntime runtime = PowerMockito.mock(ClusterModeRuntime.class); + PowerMockito.when(runtime.get(any(), anyInt())).thenReturn(new YRJobInfo()); + PowerMockito.when(runtime.getJavaFunction(any())).thenReturn(function); + PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); + + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.IsInitialized()).thenReturn(true); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, ""); + Pair errorInfoStringPair = new Pair(errorInfo, "ok"); + when(LibRuntime.InvokeInstance(any(), anyString(), anyList(), any())).thenReturn(errorInfoStringPair); + + boolean isException = false; + try { + JobExecutorCaller.getYrJobInfo("test-ID"); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + try { + JobExecutorCaller.getJobStatus("test-jobStatus"); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + try { + JobExecutorCaller.listJobs("test-listJob"); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + try { + JobExecutorCaller.listJobs(); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testDeleteJob() throws Exception { + boolean isException = false; + + ClusterModeRuntime runtime = PowerMockito.mock(ClusterModeRuntime.class); + PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.IsInitialized()).thenReturn(true); + + try { + JobExecutorCaller.deleteJob(""); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + try { + JobExecutorCaller.deleteJob("test-JobExecutorCaller"); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + } +} diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppFunctionHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppFunctionHandler.java index a4aaaa7..7c63015 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppFunctionHandler.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppFunctionHandler.java @@ -47,7 +47,9 @@ public class TestCppFunctionHandler { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - ""); + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); YR.init(conf); } diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceCreator.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceCreator.java index faa99dd..a2a8130 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceCreator.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceCreator.java @@ -47,7 +47,9 @@ public class TestCppInstanceCreator { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - ""); + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); YR.init(conf); } diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceFunctionHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceFunctionHandler.java index 6855983..26447bf 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceFunctionHandler.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceFunctionHandler.java @@ -47,7 +47,9 @@ public class TestCppInstanceFunctionHandler { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - ""); + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); YR.init(conf); } diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceHandler.java index f549876..cbea5bf 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceHandler.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestCppInstanceHandler.java @@ -53,7 +53,9 @@ public class TestCppInstanceHandler { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - ""); + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); YR.init(conf); } @@ -64,12 +66,13 @@ public class TestCppInstanceHandler { } @Test - public void testCppInstanceHandler() { + public void testCppInstanceHandler() throws YRException { String instanceId = "instanceId"; String functionId = "functionId"; String className = "Counter"; CppInstanceHandler instance = new CppInstanceHandler(instanceId, functionId, className); CppInstanceHandler cppInstanceHandler = new CppInstanceHandler(); + cppInstanceHandler.release(); cppInstanceHandler.clearHandlerInfo(); CppInstanceHandlerHelper cppInstanceHandlerHelper = new CppInstanceHandlerHelper(); CppInstanceFunctionHandler functionHandler = instance.function( @@ -153,6 +156,7 @@ public class TestCppInstanceHandler { Assert.assertEquals(newHandler.getInstanceId(), instanceId); Assert.assertEquals(newHandler.getRealInstanceId(), "realInsID"); Assert.assertEquals(newHandler.isNeedOrder(), true); + newHandler.release(); } catch (YRException exp) { exp.printStackTrace(); isException = true; diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestFunctionHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestFunctionHandler.java index b88dac8..813b761 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestFunctionHandler.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestFunctionHandler.java @@ -60,7 +60,9 @@ public class TestFunctionHandler { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - ""); + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); PowerMockito.mockStatic(LibRuntime.class); when(LibRuntime.IsInitialized()).thenReturn(true); when(LibRuntime.Init(any())).thenReturn(new ErrorInfo()); @@ -73,7 +75,7 @@ public class TestFunctionHandler { } @Test - public void testFunctionHandler() throws IOException, YRException { + public void testFunctionHandler() throws Exception { Pair mockResult = new Pair(new ErrorInfo(), "objID"); when(LibRuntime.InvokeInstance(any(), anyString(), anyList(), any())).thenReturn(mockResult); InvokeOptions options= new InvokeOptions(); diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoFunctionHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoFunctionHandler.java new file mode 100644 index 0000000..8ec442b --- /dev/null +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoFunctionHandler.java @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.call; + +import com.yuanrong.Config; +import com.yuanrong.InvokeOptions; +import com.yuanrong.api.YR; +import com.yuanrong.exception.YRException; +import com.yuanrong.function.GoFunction; +import com.yuanrong.runtime.ClusterModeRuntime; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.io.IOException; + +/** + * Go function handler test + * + * @since 2024/03/20 + */ +@RunWith(PowerMockRunner.class) +@PowerMockIgnore({"javax.net.ssl.*", "javax.management.*"}) +@PrepareForTest(YR.class) +public class TestGoFunctionHandler { + ClusterModeRuntime runtime = PowerMockito.mock(ClusterModeRuntime.class); + @Before + public void initYR() throws Exception { + Config conf = new Config( + "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", + "127.0.0.0", + "127.0.0.0", + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); + PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); + YR.init(conf); + } + + @Test + public void testGoInstanceHandlerInvoke() throws IOException, YRException { + GoFunction goFunction = GoFunction.of("PlusOne", int.class, 1); + GoFunction plusOne = GoFunction.of("PlusOne", 1); + Assert.assertEquals("PlusOne", goFunction.functionName); + Assert.assertEquals(int.class, goFunction.returnType); + Assert.assertEquals(1, goFunction.returnNum); + GoFunctionHandler goFunctionHandler = YR.function(goFunction); + InvokeOptions invokeOptions = new InvokeOptions(); + invokeOptions.setCpu(1500); + invokeOptions.setMemory(1500); + goFunctionHandler.options(invokeOptions); + goFunctionHandler.invoke(goFunction); + } + + @After + public void finalizeYR() throws YRException { + YR.Finalize(); + } +} diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceCreator.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceCreator.java new file mode 100644 index 0000000..719b5d7 --- /dev/null +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceCreator.java @@ -0,0 +1,86 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.call; + +import com.yuanrong.Config; +import com.yuanrong.InvokeOptions; +import com.yuanrong.api.YR; +import com.yuanrong.exception.YRException; +import com.yuanrong.function.GoInstanceClass; +import com.yuanrong.runtime.ClusterModeRuntime; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.io.IOException; + +/** + * Go instance creator test + * + * @since 2024/03/20 + */ +@RunWith(PowerMockRunner.class) +@PowerMockIgnore({"javax.net.ssl.*", "javax.management.*"}) +@PrepareForTest(YR.class) +public class TestGoInstanceCreator { + ClusterModeRuntime runtime = PowerMockito.mock(ClusterModeRuntime.class); + + @Before + public void initYR() throws Exception { + Config conf = new Config( + "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", + "127.0.0.0", + "127.0.0.0", + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); + PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); + YR.init(conf); + } + + @Test + public void testGoInstanceHandlerInvoke() { + GoInstanceClass goInstanceClass = GoInstanceClass.of("Counter"); + Assert.assertEquals("Counter", goInstanceClass.className); + InvokeOptions invokeOptions = new InvokeOptions(); + invokeOptions.setCpu(1500); + invokeOptions.setMemory(1500); + GoInstanceCreator goInstanceCreator = YR.instance(goInstanceClass).options(invokeOptions); + Assert.assertNotNull(goInstanceCreator); + } + + @Test + public void testInitGoInstanceCreator() throws IOException, YRException { + GoInstanceClass goInstanceClass = GoInstanceClass.of("Counter"); + GoInstanceCreator goInstanceCreator1 = new GoInstanceCreator(goInstanceClass); + goInstanceCreator1.invoke(goInstanceClass); + goInstanceCreator1.getFunctionMeta(); + Assert.assertEquals("Counter", goInstanceClass.className); + } + + @After + public void finalizeYR() throws YRException { + YR.Finalize(); + } +} diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceFunctionHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceFunctionHandler.java new file mode 100644 index 0000000..82da508 --- /dev/null +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceFunctionHandler.java @@ -0,0 +1,95 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.call; + +import com.yuanrong.Config; +import com.yuanrong.InvokeOptions; +import com.yuanrong.api.YR; +import com.yuanrong.exception.YRException; +import com.yuanrong.function.GoInstanceMethod; +import com.yuanrong.runtime.ClusterModeRuntime; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.io.IOException; + +/** + * Go instance function handler test + * + * @since 2024/03/20 + */ +@RunWith(PowerMockRunner.class) +@PowerMockIgnore({"javax.net.ssl.*", "javax.management.*"}) +@PrepareForTest(YR.class) +public class TestGoInstanceFunctionHandler { + ClusterModeRuntime runtime = PowerMockito.mock(ClusterModeRuntime.class); + + @Before + public void initYR() throws Exception { + Config conf = new Config( + "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", + "127.0.0.0", + "127.0.0.0", + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); + PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); + YR.init(conf); + } + + @Test + public void testGoInstanceHandlerInvoke() { + String instanceId = "instanceId"; + String className = "Counter"; + GoInstanceHandler instance = new GoInstanceHandler(instanceId, className); + Assert.assertEquals(className, instance.getClassName()); + Assert.assertEquals(instanceId, instance.getInstanceId()); + GoInstanceFunctionHandler functionHandler = instance.function(GoInstanceMethod.of("Add", int.class, 1)); + InvokeOptions invokeOptions = new InvokeOptions(); + invokeOptions.setCpu(1500); + invokeOptions.setMemory(2500); + functionHandler.options(invokeOptions); + Assert.assertEquals(className, functionHandler.getClassName()); + Assert.assertEquals(instanceId, functionHandler.getInstanceId()); + Assert.assertEquals(1500, functionHandler.getOptions().getCpu()); + Assert.assertEquals(2500, functionHandler.getOptions().getMemory()); + } + + @Test + public void testInvoke() throws IOException, YRException { + String instanceId = "instanceId"; + String className = "Counter"; + GoInstanceHandler instance = new GoInstanceHandler(instanceId, className); + GoInstanceFunctionHandler functionHandler = instance.function(GoInstanceMethod.of("Add", int.class, 1)); + functionHandler.invoke(instance); + functionHandler.getGoInstanceMethod(); + Assert.assertEquals(instanceId, functionHandler.getInstanceId()); + } + + @After + public void finalizeYR() throws YRException { + YR.Finalize(); + } +} diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceHandler.java new file mode 100644 index 0000000..fc98594 --- /dev/null +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestGoInstanceHandler.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.call; + +import com.yuanrong.Config; +import com.yuanrong.api.YR; +import com.yuanrong.exception.YRException; +import com.yuanrong.function.GoInstanceMethod; +import com.yuanrong.serialization.Serializer; +import com.yuanrong.runtime.ClusterModeRuntime; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.io.IOException; + +/** + * Description: test case of GoInstanceHandler. + * + * @since 2024/03/20 + */ +@RunWith(PowerMockRunner.class) +@PowerMockIgnore({"javax.net.ssl.*", "javax.management.*"}) +@PrepareForTest(YR.class) +public class TestGoInstanceHandler { + ClusterModeRuntime runtime = PowerMockito.mock(ClusterModeRuntime.class); + + @Before + public void init() throws Exception { + Config conf = new Config( + "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", + "127.0.0.0", + "127.0.0.0", + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); + PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); + YR.init(conf); + } + + @After + public void finalizeYR() throws YRException { + YR.Finalize(); + } + + @Test + public void testGoInstanceHandler() { + String instanceId = "instanceId"; + String className = "Counter"; + GoInstanceHandler instance = new GoInstanceHandler(instanceId, className); + GoInstanceHandlerHelper goInstanceHandlerHelper = new GoInstanceHandlerHelper(); + GoInstanceHandler goInstanceHandler = new GoInstanceHandler(); + goInstanceHandler.clearHandlerInfo(); + GoInstanceFunctionHandler functionHandler = instance.function( + GoInstanceMethod.of("Add", int.class, 1)); + GoInstanceMethod add = GoInstanceMethod.of("Add", 1); + + Assert.assertNotNull(goInstanceHandlerHelper); + Assert.assertEquals(functionHandler.getClassName(), className); + Assert.assertEquals(functionHandler.getInstanceId(), instanceId); + Assert.assertNotNull(functionHandler.getOptions()); + boolean isException = false; + try { + instance.terminate(); + instance.terminate(true); + instance.terminate(false); + } catch (YRException exp) { + exp.printStackTrace(); + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testGoInstanceHandlerDefault() { + GoInstanceHandler instance = new GoInstanceHandler(); + GoInstanceFunctionHandler functionHandler = instance.function( + GoInstanceMethod.of("Add", int.class, 1)); + + Assert.assertEquals(functionHandler.getClassName(), ""); + Assert.assertNull(functionHandler.getInstanceId()); + Assert.assertNotNull(functionHandler.getOptions()); + } + + @Test + public void testGoInstanceHandlerPacker() throws IOException { + String instanceId = "instanceId"; + String className = "Counter"; + GoInstanceHandler instance = new GoInstanceHandler(instanceId, className); + byte[] bytes = Serializer.serialize(instance); + if (Serializer.deserialize(bytes, GoInstanceHandler.class) instanceof GoInstanceHandler) { + GoInstanceHandler unpackHandler = (GoInstanceHandler) Serializer.deserialize(bytes, + GoInstanceHandler.class); + Assert.assertEquals(unpackHandler.getClassName(), className); + Assert.assertEquals(unpackHandler.getInstanceId(), instanceId); + } + } +} diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestInstanceHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestInstanceHandler.java index 2d0cfea..6563966 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestInstanceHandler.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestInstanceHandler.java @@ -65,7 +65,9 @@ public class TestInstanceHandler { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - ""); + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); YR.init(conf); } @@ -183,6 +185,7 @@ public class TestInstanceHandler { Assert.assertEquals(newHandler.getInstanceId(), instanceId); Assert.assertEquals(newHandler.getRealInstanceId(), "realInsID"); Assert.assertEquals(newHandler.isNeedOrder(), true); + newHandler.release(); } catch (YRException exp) { exp.printStackTrace(); isException = true; diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaFunctionHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaFunctionHandler.java index c5a240d..6d6ed69 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaFunctionHandler.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaFunctionHandler.java @@ -47,7 +47,9 @@ public class TestJavaFunctionHandler { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - ""); + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); YR.init(conf); } diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceCreator.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceCreator.java index 39126c3..b65b0dd 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceCreator.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceCreator.java @@ -52,7 +52,9 @@ public class TestJavaInstanceCreator { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - ""); + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); YR.init(conf); } diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceFunctionHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceFunctionHandler.java index f7b5d72..74d582e 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceFunctionHandler.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceFunctionHandler.java @@ -52,7 +52,9 @@ public class TestJavaInstanceFunctionHandler { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - ""); + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); YR.init(conf); } @@ -80,7 +82,7 @@ public class TestJavaInstanceFunctionHandler { } @Test - public void testInvokeJavaInstanceFunctionHandler() throws YRException { + public void testInvokeJavaInstanceFunctionHandler() throws IOException, YRException { JavaInstanceHandler javaInstanceHandler = new JavaInstanceHandler(); JavaInstanceFunctionHandler functionHandler = javaInstanceHandler.function(JavaInstanceMethod.of("Add", int.class)); functionHandler.invoke(javaInstanceHandler); diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceHandler.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceHandler.java index 06de012..097bab6 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceHandler.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/call/TestJavaInstanceHandler.java @@ -53,7 +53,9 @@ public class TestJavaInstanceHandler { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - ""); + "", + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); PowerMockito.whenNew(ClusterModeRuntime.class).withAnyArguments().thenReturn(runtime); YR.init(conf); } @@ -142,6 +144,7 @@ public class TestJavaInstanceHandler { Assert.assertEquals(newHandler.getInstanceId(), instanceId); Assert.assertEquals(newHandler.getRealInstanceId(), "realInsID"); Assert.assertEquals(newHandler.isNeedOrder(), true); + newHandler.release(); } catch (YRException exp) { exp.printStackTrace(); isException = true; diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestJobExecutor.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestJobExecutor.java new file mode 100644 index 0000000..c25af30 --- /dev/null +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestJobExecutor.java @@ -0,0 +1,65 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.yuanrong.jobexecutor; + +import com.yuanrong.exception.YRException; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; + +public class TestJobExecutor { + @Test + public void testGetJobInfo() { + RuntimeEnv runtimeEnv = new RuntimeEnv(); + runtimeEnv.setPackages("test-packages"); + runtimeEnv.setPackageManager("pip3.9"); + ArrayList testEntryPoint = new ArrayList<>(); + testEntryPoint.add("python3.9"); + + ArrayList testEntryPointWithWrongValue = new ArrayList<>(); + testEntryPointWithWrongValue.add("point01"); + testEntryPointWithWrongValue.add("point02"); + + boolean isException = false; + JobExecutor jobExecutor; + + try { + jobExecutor = new JobExecutor("test-job", runtimeEnv, testEntryPointWithWrongValue, "test-affinity"); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + + try { + jobExecutor = new JobExecutor("test-job", runtimeEnv, testEntryPoint, "test-affinity"); + jobExecutor.getJobInfo(true); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + + try { + jobExecutor = new JobExecutor("test-job", runtimeEnv, testEntryPoint, "test-affinity"); + jobExecutor.stop(); + jobExecutor.getJobInfo(false); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + } +} diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestOBSoptions.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestOBSoptions.java new file mode 100644 index 0000000..beab73c --- /dev/null +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestOBSoptions.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jobexecutor; + +import com.yuanrong.exception.YRException; + +import org.junit.Assert; +import org.junit.Test; + +public class TestOBSoptions { + @Test + public void testInitOptions() { + OBSoptions obSoptions = new OBSoptions(); + String expectedValue = "test-ID"; + obSoptions.setAk("test-Ak"); + obSoptions.setSk("test-SK"); + obSoptions.setEndPoint("test-EndPoint"); + obSoptions.setBucketID(expectedValue); + obSoptions.setObjectID(expectedValue); + obSoptions.setSecurityToken("test-token"); + + obSoptions.getAk(); + obSoptions.getSk(); + obSoptions.getEndPoint(); + obSoptions.getSecurityToken(); + + Assert.assertEquals(expectedValue, obSoptions.getBucketID()); + Assert.assertEquals(expectedValue, obSoptions.getObjectID()); + } + + @Test + public void testToMap() { + + OBSoptions obSoptions = new OBSoptions(); + obSoptions.setEndPoint("https://example.com"); + obSoptions.setBucketID("test-bucket"); + obSoptions.setObjectID("test-object"); + obSoptions.setSecurityToken("test-token"); + obSoptions.setAk("test-ak"); + obSoptions.setSk("test-sk"); + + boolean isException = false; + + try { + obSoptions.toMap(); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("Failed to convert OBS options to a json string")); + isException = true; + } + Assert.assertFalse(isException); + } +} diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestRuntimeEnv.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestRuntimeEnv.java new file mode 100644 index 0000000..97dfa74 --- /dev/null +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestRuntimeEnv.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jobexecutor; + +import com.yuanrong.exception.YRException; + +import org.junit.Assert; +import org.junit.Test; + +public class TestRuntimeEnv { + @Test + public void testInitRuntimeEnvWithExpectedValue() { + boolean expectedValue = true; + + RuntimeEnv runtimeEnv = new RuntimeEnv(); + RuntimeEnv testRuntimeEnv = new RuntimeEnv(runtimeEnv); + testRuntimeEnv.setPackages("test-package"); + testRuntimeEnv.setPackageManager("test-PIP"); + testRuntimeEnv.setShouldPipCheck(expectedValue); + testRuntimeEnv.setTrustedSource("test-source"); + testRuntimeEnv.getPackages(); + testRuntimeEnv.getPackageManager(); + testRuntimeEnv.getTrustedSource(); + + Assert.assertTrue(testRuntimeEnv.isShouldPipCheck()); + } + + @Test + public void testRuntimeEnvToCommand() { + RuntimeEnv runtimeEnv = new RuntimeEnv(); + boolean isException = false; + + runtimeEnv.setPackageManager(""); + try { + runtimeEnv.toCommand(); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + + runtimeEnv.setPackageManager(null); + try { + runtimeEnv.toCommand(); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + runtimeEnv.setPackageManager("test-PIP"); + + runtimeEnv.setPackages(); + try { + runtimeEnv.toCommand(); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + runtimeEnv.setPackages("test-packages"); + + runtimeEnv.setPackageManager("pip3.9"); + try { + runtimeEnv.toCommand(); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + + runtimeEnv.setPackageManager("pip3.8"); + try { + runtimeEnv.toCommand(); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("Only pip3.9 is supported currently.")); + runtimeEnv.setPackageManager("pip3.9"); + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + runtimeEnv.setShouldPipCheck(true); + try { + runtimeEnv.toCommand(); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + } + +} diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestYRJobInfo.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestYRJobInfo.java new file mode 100644 index 0000000..7dc75bb --- /dev/null +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestYRJobInfo.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jobexecutor; + +import org.junit.Assert; +import org.junit.Test; + +public class TestYRJobInfo { + @Test + public void testInitYRJobInfo() { + YRJobInfo yrJobInfo = new YRJobInfo(); + yrJobInfo.setRuntimeEnv(new RuntimeEnv()); + YRJobInfo testJobInfo = new YRJobInfo(yrJobInfo); + testJobInfo.setJobName("test-name"); + testJobInfo.setUserJobID("test-id"); + testJobInfo.setJobStartTime("2024-01-01"); + testJobInfo.setJobEndTime("2025-01-01"); + testJobInfo.setRuntimeEnv(new RuntimeEnv()); + testJobInfo.setStatus(YRJobStatus.SUCCEEDED); + testJobInfo.setErrorMessage("failed"); + + testJobInfo.getJobName(); + testJobInfo.getJobEndTime(); + testJobInfo.getUserJobID(); + testJobInfo.getStatus(); + testJobInfo.getErrorMessage(); + testJobInfo.getJobStartTime(); + testJobInfo.getRuntimeEnv(); + Assert.assertTrue(testJobInfo.ifFinalized()); + } + + @Test + public void testUpdate() { + YRJobInfo yrJobInfo = new YRJobInfo(); + YRJobInfo testJobInfo = new YRJobInfo(); + testJobInfo.setJobName("test-name"); + testJobInfo.setUserJobID("test-id"); + testJobInfo.setJobStartTime("2024-01-01"); + testJobInfo.setJobEndTime("2025-01-01"); + testJobInfo.setRuntimeEnv(new RuntimeEnv()); + testJobInfo.setStatus(YRJobStatus.SUCCEEDED); + testJobInfo.setErrorMessage(null); + yrJobInfo.update(testJobInfo); + Assert.assertTrue(yrJobInfo.ifFinalized()); + } +} diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestYRJobParam.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestYRJobParam.java new file mode 100644 index 0000000..46c6728 --- /dev/null +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/jobexecutor/TestYRJobParam.java @@ -0,0 +1,254 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.jobexecutor; + +import com.yuanrong.exception.YRException; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.HashMap; + +public class TestYRJobParam { + @Test + public void testInitYRJobParam() { + boolean isException = false; + YRJobParam yrJobParam = new YRJobParam(); + yrJobParam.setJobName("test-job"); + + try { + yrJobParam.setCpu(100); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("The CPU value")); + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + try { + yrJobParam.setCpu(500); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("The CPU value")); + isException = true; + } + Assert.assertFalse(isException); + + try { + yrJobParam.setMemory(100); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("The memory value")); + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + try { + yrJobParam.setMemory(500); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("The memory value")); + isException = true; + } + Assert.assertFalse(isException); + + try { + yrJobParam.setEntryPoint(new ArrayList<>()); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("Length of the entryPoint")); + isException = true; + } + Assert.assertTrue(isException); + + ArrayList testEntryPoint = new ArrayList<>(); + testEntryPoint.add("test-entryPoint1"); + testEntryPoint.add("test-entryPoint2"); + testEntryPoint.add("test-entryPoint3"); + testEntryPoint.add("test-entryPoint4"); + isException = false; + try { + yrJobParam.setEntryPoint(testEntryPoint); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("Length of the entryPoint")); + isException = true; + } + Assert.assertFalse(isException); + + yrJobParam.setRuntimeEnv(new RuntimeEnv()); + yrJobParam.setRuntimeEnv("test-packageManager", new String[] {"test-package"}); + yrJobParam.setObsOptions(null); + yrJobParam.addScheduleAffinity(null); + yrJobParam.setCustomExtensions(new HashMap<>()); + yrJobParam.setCustomResources(new HashMap<>()); + yrJobParam.setScheduleAffinities(new ArrayList<>()); + yrJobParam.preferredPriority(true); + + try { + yrJobParam.setLocalCodePath(""); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("localCodePath cannot be an empty String.")); + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + try { + yrJobParam.setLocalCodePath("/tmp"); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("localCodePath cannot be an empty String.")); + isException = true; + } + Assert.assertFalse(isException); + + isException = false; + try { + OBSoptions obSoptions = new OBSoptions(); + obSoptions.setAk("test-Ak"); + obSoptions.setSk("test-SK"); + yrJobParam.setObsOptions(obSoptions); + yrJobParam.setLocalCodePath("test"); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("localCodePath cannot be an empty String.")); + isException = true; + } + Assert.assertFalse(isException); + + yrJobParam.setObsOptions(new OBSoptions()); + yrJobParam.setPreferredPriority(true); + yrJobParam.requiredPriority(true); + yrJobParam.getJobName(); + yrJobParam.getCpu(); + yrJobParam.getMemory(); + yrJobParam.getEntryPoint(); + yrJobParam.getCustomResources(); + yrJobParam.getCustomExtensions(); + yrJobParam.getLocalCodePath(); + yrJobParam.getObsOptions(); + yrJobParam.getScheduleAffinities(); + yrJobParam.getObsOptions(); + yrJobParam.getRuntimeEnv(); + Assert.assertTrue(yrJobParam.isPreferredPriority()); + } + + @Test + public void testBuilder() { + YRJobParam testYRJobparam = YRJobParam.builder().build(); + Assert.assertTrue(testYRJobparam.isPreferredPriority()); + + ArrayList testEntryPoint = new ArrayList<>(); + testEntryPoint.add("test-entryPoint1"); + testEntryPoint.add("test-entryPoint2"); + testEntryPoint.add("test-entryPoint3"); + testEntryPoint.add("test-entryPoint4"); + + boolean isException = false; + try { + YRJobParam yrJobParam = new YRJobParam.Builder().cpu(500) + .memory(500) + .entryPoint(testEntryPoint) + .localCodePath("/tmp") + .jobName("test-job") + .addScheduleAffinity(null) + .obsOptions(new OBSoptions()) + .runtimeEnv(new RuntimeEnv()) + .preferredPriority(true) + .scheduleAffinity(new ArrayList<>()) + .requiredPriority(true) + .build(); + Assert.assertTrue(yrJobParam.isPreferredPriority()); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testGetLocalEntryPoint() { + YRJobParam yrJobParam = new YRJobParam(); + boolean isException = false; + + try { + yrJobParam.getLocalEntryPoint(); + } catch (YRException e) { + Assert.assertTrue(e.getMessage().contains("Failed to get entryPoint. EntryPoint should be set.")); + isException = true; + } + Assert.assertTrue(isException); + + ArrayList testEntryPoint = new ArrayList<>(); + testEntryPoint.add("test-entryPoint1"); + testEntryPoint.add("test-entryPoint2"); + testEntryPoint.add("test-entryPoint3"); + testEntryPoint.add("test-entryPoint4"); + isException = false; + + try { + yrJobParam.setEntryPoint(testEntryPoint); + yrJobParam.getLocalEntryPoint(); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + + try { + yrJobParam.setLocalCodePath("/tmp"); + yrJobParam.getLocalEntryPoint(); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testExtractInvokeOptions() { + YRJobParam yrJobParam = new YRJobParam(); + + RuntimeEnv runtimeEnv = new RuntimeEnv(); + runtimeEnv.setPackages("test-packages"); + runtimeEnv.setPackageManager("pip3.9"); + + OBSoptions obSoptions = new OBSoptions(); + obSoptions.setAk("test-Ak"); + obSoptions.setSk("test-SK"); + obSoptions.setEndPoint("test-EndPoint"); + obSoptions.setBucketID("test-ID"); + obSoptions.setObjectID("test-ID"); + obSoptions.setSecurityToken("test-token"); + + boolean isException = false; + + try { + yrJobParam.setRuntimeEnv(runtimeEnv); + yrJobParam.extractInvokeOptions(); + } catch (YRException e) { + Assert.assertTrue( + e.getMessage().contains("Either YRJobParam.obsOptions or YRJobParam.localCodePath should be set")); + isException = true; + } + Assert.assertTrue(isException); + + isException = false; + + try { + yrJobParam.setObsOptions(obSoptions); + yrJobParam.extractInvokeOptions(); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + } + +} diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/TestClusterModeRuntime.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/TestClusterModeRuntime.java index 5ef9d69..a80be2b 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/TestClusterModeRuntime.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/TestClusterModeRuntime.java @@ -24,10 +24,12 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.when; +import com.yuanrong.CreateParam; import com.yuanrong.GetParam; import com.yuanrong.GetParams; import com.yuanrong.GroupOptions; import com.yuanrong.InvokeOptions; +import com.yuanrong.MSetParam; import com.yuanrong.SetParam; import com.yuanrong.api.InvokeArg; import com.yuanrong.errorcode.ErrorCode; @@ -37,6 +39,7 @@ import com.yuanrong.errorcode.Pair; import com.yuanrong.exception.YRException; import com.yuanrong.exception.LibRuntimeException; import com.yuanrong.call.CppInstanceHandler; +import com.yuanrong.call.GoInstanceHandler; import com.yuanrong.call.InstanceHandler; import com.yuanrong.call.JavaInstanceHandler; import com.yuanrong.jni.LibRuntime; @@ -45,6 +48,8 @@ import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; import com.yuanrong.libruntime.generated.Libruntime.LanguageType; import com.yuanrong.runtime.client.ObjectRef; import com.yuanrong.storage.InternalWaitResult; +import com.yuanrong.stream.ProducerConfig; +import com.yuanrong.stream.SubscriptionConfig; import com.yuanrong.utils.SdkUtils; import org.junit.Assert; @@ -60,6 +65,7 @@ import org.powermock.modules.junit4.PowerMockRunner; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -81,9 +87,9 @@ public class TestClusterModeRuntime { public static void SetUp() {} @Test - public void testIsOnCloud() { + public void testIsDriver() { Runtime runtime = new ClusterModeRuntime(); - Assert.assertFalse(runtime.isOnCloud()); + Assert.assertTrue(runtime.isDriver()); } @Test @@ -93,7 +99,7 @@ public class TestClusterModeRuntime { } @Test - public void testSaveAndGetRealInstanceId() throws YRException { + public void testSaveAndGetRealInstanceId() throws Exception { PowerMockito.mockStatic(LibRuntime.class); Runtime runtime = new ClusterModeRuntime(); when(LibRuntime.GetRealInstanceId(anyString())).thenReturn("instanceID"); @@ -108,13 +114,15 @@ public class TestClusterModeRuntime { JavaInstanceHandler jHander = new JavaInstanceHandler("jInstanceId", "jFunctionId", "jClassName"); CppInstanceHandler cHandler = new CppInstanceHandler("cInstanceId", "cFunctionId", "cClassName"); InstanceHandler handler = new InstanceHandler("instance", ApiType.Function); - runtime.collectInstanceHandlerInfo(jHander); - runtime.collectInstanceHandlerInfo(cHandler); - runtime.collectInstanceHandlerInfo(handler); - Assert.assertNotNull(runtime.getInstanceHandlerInfo("instance")); - runtime.Finalize("ctx", 0); - runtime.Finalize(); - runtime.exit(); + boolean isException = false; + try { + runtime.Finalize("ctx", 0); + runtime.Finalize(); + runtime.exit(); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); } @Test @@ -128,7 +136,6 @@ public class TestClusterModeRuntime { for (int i = 0; i < numberOfThreads; i++) { executor.submit(() -> { InstanceHandler handler = new InstanceHandler("instance", ApiType.Function); - runtime.collectInstanceHandlerInfo(handler); }); } } catch (Exception e) { @@ -136,12 +143,14 @@ public class TestClusterModeRuntime { } Assert.assertFalse(isException); executor.shutdown(); + runtime.decreaseReference(Collections.singletonList("instanceID")); + runtime.Finalize("ctx", 0); runtime.Finalize(); runtime.exit(); } @Test - public void testKVWrite() throws YRException { + public void testKVWrite() throws Exception { String key = "key1"; byte[] value = "value1".getBytes(); PowerMockito.mockStatic(LibRuntime.class); @@ -174,7 +183,7 @@ public class TestClusterModeRuntime { } @Test - public void testKVRead() throws YRException { + public void testKVRead() throws Exception { String key = "key1"; PowerMockito.mockStatic(LibRuntime.class); byte[] failedKey = ("key1").getBytes(); @@ -202,7 +211,7 @@ public class TestClusterModeRuntime { } @Test - public void testKVReadList() throws YRException { + public void testKVReadList() throws Exception { List keys = new ArrayList(); keys.add("key1"); PowerMockito.mockStatic(LibRuntime.class); @@ -233,7 +242,7 @@ public class TestClusterModeRuntime { } @Test - public void testKVGetWithParam() throws YRException { + public void testKVGetWithParam() throws Exception { List keys = new ArrayList(){{add("key1");}}; PowerMockito.mockStatic(LibRuntime.class); List vals = new ArrayList(); @@ -266,7 +275,7 @@ public class TestClusterModeRuntime { } @Test - public void testKVdel() throws YRException { + public void testKVdel() throws Exception { String key = "key1"; PowerMockito.mockStatic(LibRuntime.class); ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_DATASYSTEM_FAILED, ModuleCode.DATASYSTEM, ""); @@ -291,7 +300,7 @@ public class TestClusterModeRuntime { } @Test - public void testKVdelList() throws YRException { + public void testKVdelList() throws Exception { List keys = new ArrayList(); keys.add("key1"); @@ -310,7 +319,7 @@ public class TestClusterModeRuntime { } @Test - public void testLoadState() throws YRException { + public void testLoadState() throws Exception { PowerMockito.mockStatic(LibRuntime.class); ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, ""); when(LibRuntime.LoadState(anyInt())).thenReturn(errorInfo); @@ -334,7 +343,7 @@ public class TestClusterModeRuntime { } @Test - public void testSaveState() throws YRException { + public void testSaveState() throws Exception { PowerMockito.mockStatic(LibRuntime.class); ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, ""); when(LibRuntime.SaveState(anyInt())).thenReturn(errorInfo); @@ -447,6 +456,31 @@ public class TestClusterModeRuntime { isException = true; } Assert.assertFalse(isException); + + when(LibRuntime.PutWithParam(any(), anyList(), any())).thenReturn(mockRes); + try { + runtime.put(obj, 10L, 20L, new CreateParam()); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + + when(LibRuntime.Put(any(), anyList())).thenThrow(new LibRuntimeException("error occurred")); + try { + runtime.put(obj, 10L, 20L); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); + + when(LibRuntime.PutWithParam(any(), anyList(), any())).thenThrow(new LibRuntimeException("error occurred")); + isException = false; + try { + runtime.put(obj, 10L, 20L, new CreateParam()); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); } @Test @@ -498,7 +532,7 @@ public class TestClusterModeRuntime { List readyIds = Arrays.asList("2", "1"); List unreadyIds = new ArrayList(); - Map exceptionIds = new HashMap(){ + Map exceptionIds = new HashMap() { { put("err", new ErrorInfo(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, "")); } @@ -518,6 +552,22 @@ public class TestClusterModeRuntime { when(LibRuntime.Wait(anyList(), anyInt(), anyInt())).thenReturn(waitResult); InternalWaitResult res = runtime.wait(normalObjIds, 2, 10); Assert.assertNotNull(res); + + exceptionIds = new HashMap() { + { + put("err", null); + } + }; + waitResult = new InternalWaitResult(readyIds, unreadyIds, exceptionIds); + when(LibRuntime.Wait(anyList(), anyInt(), anyInt())).thenReturn(waitResult); + isException = false; + try { + runtime.wait(normalObjIds, 2, 10); + } catch (YRException e) { + isException = true; + Assert.assertTrue(e.getMessage().contains("errorInfo is null")); + } + Assert.assertTrue(isException); } @Test @@ -532,20 +582,41 @@ public class TestClusterModeRuntime { Pair mockRes = new Pair(new ErrorInfo(), "objID"); when(LibRuntime.CreateInstance(any(), anyList(), any())).thenReturn(mockRes); String instanceID = runtime.createInstance(meta, args, opts); - Assert.assertTrue(instanceID.equals("objID")); + Assert.assertEquals("objID", instanceID); + + when(LibRuntime.CreateInstance(any(), anyList(), any())).thenThrow( + new LibRuntimeException("error occurred")); + boolean isException = false; + try { + runtime.createInstance(meta, args, opts); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); when(LibRuntime.InvokeInstance(any(), anyString(), anyList(), any())).thenReturn(mockRes); String objID = runtime.invokeInstance(meta, instanceID, args, opts); - Assert.assertTrue(objID.equals("objID")); + Assert.assertEquals("objID", objID); + + when(LibRuntime.InvokeInstance(any(), anyString(), anyList(), any())).thenThrow( + new LibRuntimeException("error occurred")); + isException = false; + try { + runtime.invokeInstance(meta, instanceID, args, opts); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); when(LibRuntime.Kill(anyString())).thenReturn(new ErrorInfo()); - boolean isException = false; + isException = false; try { runtime.terminateInstance(instanceID); } catch (YRException e) { isException = true; } Assert.assertFalse(isException); + when(LibRuntime.KillSync(anyString())).thenReturn(new ErrorInfo()); isException = false; try { @@ -557,7 +628,7 @@ public class TestClusterModeRuntime { } @Test - public void testGroupCreate() throws YRException { + public void testGroupCreate() throws Exception { Runtime runtime = new ClusterModeRuntime(); GroupOptions opt = new GroupOptions(); opt.setTimeout(-2); @@ -594,7 +665,7 @@ public class TestClusterModeRuntime { } @Test - public void testGroupWait() throws YRException { + public void testGroupWait() throws Exception { Runtime runtime = new ClusterModeRuntime(); PowerMockito.mockStatic(LibRuntime.class); ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, ""); @@ -615,4 +686,209 @@ public class TestClusterModeRuntime { } Assert.assertFalse(isException); } + + @Test + public void testGroupTerminate() throws YRException { + Runtime runtime = new ClusterModeRuntime(); + PowerMockito.mockStatic(LibRuntime.class); + runtime.groupTerminate("group1"); + } + + @Test + public void testCreateStreamProducer() throws Exception { + ProducerConfig cfg = new ProducerConfig(); + Runtime runtime = new ClusterModeRuntime(); + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.CreateStreamProducerWithConfig(anyString(), any())).thenReturn(10086L); + boolean isException = false; + try { + runtime.createStreamProducer("stream", cfg); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + when(LibRuntime.CreateStreamProducerWithConfig(anyString(), any())).thenThrow(new LibRuntimeException("err")); + isException = false; + try { + runtime.createStreamProducer("stream", cfg); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testCreateStreamProducerFailed() throws Exception { + Runtime runtime = new ClusterModeRuntime(); + boolean isException = false; + try { + ProducerConfig cfg = new ProducerConfig(); + cfg.setMaxStreamSize(-1L); + runtime.createStreamProducer("aaaaaaaa", cfg); + } catch (YRException e) { + Assert.assertTrue(e.getErrorCode().toString().equals("1001")); + Assert.assertTrue(e.getErrorMessage().contains("is invalid, expect >= 0")); + isException = true; + } + Assert.assertTrue(isException); + isException = false; + try { + ProducerConfig cfg = new ProducerConfig(); + cfg.setRetainForNumConsumers(-1L); + runtime.createStreamProducer("aaaaaaaa", cfg); + } catch (YRException e) { + Assert.assertTrue(e.getErrorCode().toString().equals("1001")); + Assert.assertTrue(e.getErrorMessage().contains("is invalid, expect >= 0")); + isException = true; + } + Assert.assertTrue(isException); + isException = false; + try { + ProducerConfig cfg = new ProducerConfig(); + cfg.setReserveSize(-1L); + runtime.createStreamProducer("aaaaaaaa", cfg); + } catch (YRException e) { + Assert.assertTrue(e.getErrorCode().toString().equals("1001")); + Assert.assertTrue(e.getErrorMessage().contains("is invalid, expect >= 0")); + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testCreateStreamConsumer() throws Exception { + SubscriptionConfig cfg = new SubscriptionConfig(); + Runtime runtime = new ClusterModeRuntime(); + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.CreateStreamConsumer(anyString(), anyString(), any(), anyBoolean())).thenReturn(10086L); + boolean isException = false; + try { + runtime.createStreamConsumer("stream", cfg, false); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + when(LibRuntime.CreateStreamConsumer(anyString(), anyString(), any(), anyBoolean())) + .thenThrow(new LibRuntimeException("err")); + isException = false; + try { + runtime.createStreamConsumer("stream", cfg, false); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testDeleteStream() throws Exception { + Runtime runtime = new ClusterModeRuntime(); + PowerMockito.mockStatic(LibRuntime.class); + ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, ""); + when(LibRuntime.DeleteStream(anyString())).thenReturn(errorInfo); + boolean isException = false; + try { + runtime.deleteStream("stream"); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); + when(LibRuntime.DeleteStream(anyString())).thenReturn(new ErrorInfo()); + isException = false; + try { + runtime.deleteStream("stream"); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testQueryGlobalConsumersNum() throws Exception { + Runtime runtime = new ClusterModeRuntime(); + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.QueryGlobalConsumersNum(anyString())).thenReturn(10086L); + boolean isException = false; + try { + runtime.queryGlobalConsumersNum("stream"); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + when(LibRuntime.QueryGlobalConsumersNum(anyString())).thenThrow(new LibRuntimeException("err")); + isException = false; + try { + runtime.queryGlobalConsumersNum("stream"); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testQueryGlobalProducersNum() throws Exception { + Runtime runtime = new ClusterModeRuntime(); + PowerMockito.mockStatic(LibRuntime.class); + when(LibRuntime.QueryGlobalProducersNum(anyString())).thenReturn(10086L); + boolean isException = false; + try { + runtime.queryGlobalProducersNum("stream"); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + when(LibRuntime.QueryGlobalProducersNum(anyString())).thenThrow(new LibRuntimeException("err")); + isException = false; + try { + runtime.queryGlobalProducersNum("stream"); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testKVMSetTx() throws Exception { + PowerMockito.mockStatic(LibRuntime.class); + ClusterModeRuntime runtime = new ClusterModeRuntime(); + ErrorInfo mockRes = new ErrorInfo(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, ""); + when(LibRuntime.KVMSetTx(anyList(), anyList(), any())).thenReturn(mockRes); + boolean isException = false; + try { + runtime.KVMSetTx(new ArrayList<>(), new ArrayList<>(), new MSetParam()); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); + + when(LibRuntime.KVMSetTx(anyList(), anyList(), any())).thenThrow(new LibRuntimeException("error occurred")); + isException = false; + try { + runtime.KVMSetTx(new ArrayList<>(), new ArrayList<>(), new MSetParam()); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); + } + + @Test + public void testGetInstanceRoute() throws Exception { + PowerMockito.mockStatic(LibRuntime.class); + ClusterModeRuntime runtime = new ClusterModeRuntime(); + when(LibRuntime.GetInstanceRoute(anyString())).thenReturn("getInstanceRoute"); + boolean isException = false; + try { + runtime.getInstanceRoute("objId"); + } catch (YRException e) { + isException = true; + } + Assert.assertFalse(isException); + + when(LibRuntime.GetInstanceRoute(anyString())).thenThrow(new LibRuntimeException("error occurred")); + try { + runtime.getInstanceRoute("objId"); + } catch (YRException e) { + isException = true; + } + Assert.assertTrue(isException); + } } diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/TestYR.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/TestYR.java index 96da7b9..1227a16 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/TestYR.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/TestYR.java @@ -26,6 +26,7 @@ import static org.mockito.Mockito.when; import com.yuanrong.Config; import com.yuanrong.ConfigManager; import com.yuanrong.api.ClientInfo; +import com.yuanrong.api.Node; import com.yuanrong.api.YR; import com.yuanrong.errorcode.ErrorCode; import com.yuanrong.errorcode.ErrorInfo; @@ -33,9 +34,14 @@ import com.yuanrong.errorcode.ModuleCode; import com.yuanrong.errorcode.Pair; import com.yuanrong.exception.YRException; import com.yuanrong.jni.LibRuntime; +import com.yuanrong.jobexecutor.YRJobParam; import com.yuanrong.runtime.client.ObjectRef; import com.yuanrong.storage.InternalWaitResult; import com.yuanrong.storage.WaitResult; +import com.yuanrong.stream.Consumer; +import com.yuanrong.stream.Producer; +import com.yuanrong.stream.ProducerConfig; +import com.yuanrong.stream.SubscriptionConfig; import com.yuanrong.utils.SdkUtils; import org.junit.Assert; @@ -51,6 +57,7 @@ import org.powermock.modules.junit4.PowerMockRunner; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -69,7 +76,8 @@ public class TestYR { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest"); + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); @Before public void init() throws Exception { @@ -107,7 +115,8 @@ public class TestYR { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest"); + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); invalidConf.setDataSystemAddress("invalidaddress"); boolean isException = false; try { @@ -127,7 +136,7 @@ public class TestYR { } Assert.assertFalse(isException); Config conf = new Config("sn:cn:yrk:12345678901234561234567890123456:function:0-st-stjava:$latest", - "127.0.0.1", 304822, "127.0.0.1", 290811, ""); + "127.0.0.1", 304822, "127.0.0.1", 290811, "", true); try { YR.init(invalidConf); } catch (Exception e) { @@ -158,6 +167,17 @@ public class TestYR { } } + /** + * Description: + * Test 'Exit' throws exception out of the runtime. + * Steps: + * 1. Calls 'Exit' when 'inCluster' is true. + * 2. Sets 'inCluster' false. + * 3. Calls 'Exit' again. + * Expectation: + * When 'inCluster' is true, call 'Exit' would not cause exception, + * and throws YRException otherwise. + */ @Test public void testExit() throws Exception { when(LibRuntime.IsInitialized()).thenReturn(true); @@ -170,6 +190,16 @@ public class TestYR { } Assert.assertFalse(isException); + conf.setInCluster(false); + ConfigManager.getInstance().init(conf); + try { + YR.exit(); + } catch (YRException e) { + isException = true; + Assert.assertTrue(e.getErrorMessage().contains("Not support exit out of cluster")); + } + Assert.assertTrue(isException); + // make other cases available YR.Finalize(); } @@ -177,10 +207,12 @@ public class TestYR { public void testTenantContext() { Config ctxConf = Config.builder() .functionURN("sn:cn:yrk:12345678901234561234567890123456:function:0-j-a:$latest") - .serverAddress("127.0.0.1") + .serverAddress("10.243.25.129") .serverAddressPort(31222) - .dataSystemAddress("127.0.0.1") + .dataSystemAddress("10.243.25.129") .dataSystemAddressPort(31501) + .isInCluster(false) + .tenantId("tenantId1") .enableSetContext(true) .build(); boolean isException = false; @@ -228,10 +260,12 @@ public class TestYR { public void testTenantContextFailed() { Config ctxConf = Config.builder() .functionURN("sn:cn:yrk:12345678901234561234567890123456:function:0-j-a:$latest") - .serverAddress("127.0.0.1") + .serverAddress("10.243.25.129") .serverAddressPort(31222) - .dataSystemAddress("127.0.0.1") + .dataSystemAddress("10.243.25.129") .dataSystemAddressPort(31501) + .isInCluster(false) + .tenantId("tenantId1") .isThreadLocal(true) .enableSetContext(true) .build(); @@ -268,14 +302,40 @@ public class TestYR { } @Test - public void testState() throws YRException { + public void testStream() throws YRException { + YR.init(conf); + ProducerConfig pCfg = new ProducerConfig(); + Producer producer = YR.createProducer("streamName", pCfg); + Assert.assertNotNull(producer); + producer = YR.createProducer("streamName"); + Assert.assertNotNull(producer); + + SubscriptionConfig sCfg = new SubscriptionConfig(); + Consumer consumer = YR.subscribe("streamName", sCfg); + Assert.assertNotNull(consumer); + consumer = YR.subscribe("streamName", sCfg, false); + Assert.assertNotNull(consumer); + + when(LibRuntime.DeleteStream(anyString())).thenReturn(new ErrorInfo()); + YR.deleteStream("streamName"); + YR.Finalize(); + } + + @Test + public void testState() throws Exception { YR.init(conf); when(LibRuntime.LoadState(anyInt())).thenReturn(new ErrorInfo()); when(LibRuntime.SaveState(anyInt())).thenReturn(new ErrorInfo()); - YR.saveState(20); - YR.saveState(); - YR.loadState(20); - YR.loadState(); + boolean isException = false; + try { + YR.saveState(20); + YR.saveState(); + YR.loadState(20); + YR.loadState(); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); YR.Finalize(); } @@ -295,8 +355,6 @@ public class TestYR { Assert.assertNotNull(ref); Object obj = YR.get(ref, 10); Assert.assertNotNull(obj); - obj = YR.get(ref); - Assert.assertNotNull(obj); obj = YR.get(Arrays.asList(ref), 10); Assert.assertNotNull(obj); obj = YR.get(Arrays.asList(ref), 10, false); @@ -336,4 +394,72 @@ public class TestYR { Assert.assertTrue(res.getReady().size() == 2); YR.Finalize(); } + + @Test + public void testJobs() throws Exception { + YR.init(conf); + Pair mockRes = new Pair(new ErrorInfo(), "objID"); + when(LibRuntime.CreateInstance(any(), anyList(), any())).thenReturn(mockRes); + when(LibRuntime.InvokeInstance(any(), anyString(), anyList(), any())).thenReturn(mockRes); + List readyIds = Arrays.asList("2", "1"); + List unreadyIds = new ArrayList(); + Map exceptionIds = new HashMap(); + InternalWaitResult waitResult = new InternalWaitResult(readyIds, unreadyIds, exceptionIds); + when(LibRuntime.Wait(anyList(), anyInt(), anyInt())).thenReturn(waitResult); + when(LibRuntime.GetRealInstanceId(anyString())).thenReturn("instanceID"); + List ok = new ArrayList(); + ok.add("result1".getBytes(StandardCharsets.UTF_8)); + ok.add("result2".getBytes(StandardCharsets.UTF_8)); + Pair> getRes = new Pair>(new ErrorInfo(), ok); + when(LibRuntime.Get(anyList(), anyInt(), anyBoolean())).thenReturn(getRes); + + YRJobParam param = new YRJobParam(); + ArrayList entryPoints = new ArrayList(){ + { + add("java1.8"); + add("/home/snuser/java"); + } + }; + param.setEntryPoint(entryPoints); + param.setJobName("jobName"); + param.setLocalCodePath("DEFAULT_LOCAL_PATH"); + + String res = YR.submitJob(param); + Assert.assertNotNull(res); + + YR.Finalize(); + } + + @Test + public void testNodes() throws Exception { + YR.init(conf); + Node node = new Node(); + String nodeId = "function-agent-x.x.x.x"; + node.setId(nodeId); + node.setAlive(true); + Map resources = new HashMap<>(); + resources.put("CPU", 1000F); + resources.put("Memory", 1000F); + node.setResources(resources); + Map> labels = new HashMap<>(); + List label = new ArrayList<>(); + label.add("label1"); + label.add("label2"); + labels.put(nodeId, label); + node.setLabels(labels); + List nodes = Collections.singletonList(node); + when(LibRuntime.nodes()).thenReturn(new Pair<>(new ErrorInfo(), nodes)); + + List getNodes = YR.nodes(); + Node node1 = getNodes.get(0); + Assert.assertEquals(node1.getId(), nodeId); + Assert.assertTrue(node1.isAlive()); + Assert.assertEquals(node1.getResources(), resources); + Assert.assertEquals(node1.getLabels(), labels); + + Node node2 = new Node(nodeId, true, resources, labels); + Assert.assertEquals(node, node2); + + YR.Finalize(); + } } diff --git a/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/client/TestKVManager.java b/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/client/TestKVManager.java index eb76820..af84b01 100644 --- a/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/client/TestKVManager.java +++ b/api/java/yr-api-sdk/src/test/java/com/yuanrong/runtime/client/TestKVManager.java @@ -57,7 +57,8 @@ public class TestKVManager { "sn:cn:yrk:12345678901234561234567890123456:function:0-crossyrlib-helloworld:$latest", "127.0.0.0", "127.0.0.0", - "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest"); + "sn:cn:yrk:12345678901234561234567890123456:function:0-test-hello:$latest", + true); @Before public void init() throws Exception { @@ -184,6 +185,9 @@ public class TestKVManager { add("test".getBytes(StandardCharsets.UTF_8)); add("test".getBytes(StandardCharsets.UTF_8)); }}; + List vals2 = new ArrayList(){{ + add(null); + }}; List lengths = Arrays.asList(5, 5); boolean isException = false; MSetParam abnormalParam = new MSetParam.Builder().ttlSecond(-1).build(); @@ -244,6 +248,14 @@ public class TestKVManager { isException = true; } Assert.assertTrue(isException); + + isException = false; + try { + kvManager.mSetTx(keys2, vals2, okParam); + } catch (Exception e) { + isException = true; + } + Assert.assertTrue(isException); } @Test diff --git a/api/java/yr-runtime/src/main/java/com/yuanrong/Entrypoint.java b/api/java/yr-runtime/src/main/java/com/yuanrong/Entrypoint.java index 2e86bfc..f103173 100644 --- a/api/java/yr-runtime/src/main/java/com/yuanrong/Entrypoint.java +++ b/api/java/yr-runtime/src/main/java/com/yuanrong/Entrypoint.java @@ -50,6 +50,7 @@ public class Entrypoint { * INSTANCE_ID, DRIVER_SERVER_PORT, HOST_IP, YR_FUNCTION_LIB_PATH, LAYER_LIB_PATH, LD_LIBRARY_PATH is unused. * see functionsystem/src/runtime_manager/config/build.cpp for details */ + private static final String ENV_POSIX_LISTEN_ADDR = "POSIX_LISTEN_ADDR"; private static final String ENV_DATASYSTEM_ADDR = "DATASYSTEM_ADDR"; @@ -58,6 +59,10 @@ public class Entrypoint { private static final String ENV_LOG_LEVEL = "logLevel"; + private static final String ENV_LOG_ID = "logId"; + + private static final String ENV_RUNTIME_ID = "runtimeId"; + private static final String ENV_GLOG_DIR = "GLOG_log_dir"; private static final String ENV_JOB_ID = "jobId"; @@ -70,6 +75,8 @@ public class Entrypoint { put(ENV_LOG_LEVEL, "INFO"); put(ENV_GLOG_DIR, "/home/snuser/log"); put(ENV_JOB_ID, Utils.generateCloudJobId()); + put(ENV_LOG_ID, ""); + put(ENV_RUNTIME_ID, ""); } }; diff --git a/api/java/yr-runtime/src/main/java/com/yuanrong/codemanager/CodeExecutor.java b/api/java/yr-runtime/src/main/java/com/yuanrong/codemanager/CodeExecutor.java index 931db36..79a131f 100644 --- a/api/java/yr-runtime/src/main/java/com/yuanrong/codemanager/CodeExecutor.java +++ b/api/java/yr-runtime/src/main/java/com/yuanrong/codemanager/CodeExecutor.java @@ -19,6 +19,7 @@ package com.yuanrong.codemanager; import com.yuanrong.errorcode.ErrorInfo; import com.yuanrong.errorcode.Pair; import com.yuanrong.executor.FunctionHandler; +import com.yuanrong.executor.FaaSHandler; import com.yuanrong.executor.HandlerIntf; import com.yuanrong.executor.PosixHandler; import com.yuanrong.executor.ReturnType; @@ -43,7 +44,9 @@ public class CodeExecutor { private static final String INIT_HANDLER = "INIT_HANDLER"; private static final String FUNCTION_HANDLER - = "com.yuanrong.handler.InitHandler"; + = "com.yuanrong.handler.InitHandler"; + + private static final String FAAS_HANDLER = "com.services.handler.FaaSExecutor.faasInitHandler"; private static final Logger LOG = LoggerFactory.getLogger(CodeExecutor.class); @@ -56,6 +59,8 @@ public class CodeExecutor { String handler = System.getenv(INIT_HANDLER); if (FUNCTION_HANDLER.equals(handler)) { handlerIntf = new FunctionHandler(); + } else if (FAAS_HANDLER.equals(handler)) { + handlerIntf = new FaaSHandler(); } else { handlerIntf = new PosixHandler(); } diff --git a/api/java/yr-runtime/src/main/java/com/yuanrong/codemanager/CodeLoader.java b/api/java/yr-runtime/src/main/java/com/yuanrong/codemanager/CodeLoader.java index 32845c7..32caa0e 100644 --- a/api/java/yr-runtime/src/main/java/com/yuanrong/codemanager/CodeLoader.java +++ b/api/java/yr-runtime/src/main/java/com/yuanrong/codemanager/CodeLoader.java @@ -24,6 +24,7 @@ import com.yuanrong.runtime.util.ExtClasspathLoader; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.util.List; @@ -39,7 +40,7 @@ public class CodeLoader { LOG.info("CodeLoader is running"); try { ExtClasspathLoader.loadClasspath(codePaths); - } catch (InvocationTargetException | IllegalAccessException e) { + } catch (InvocationTargetException | IllegalAccessException | IOException e) { String errorMsg = "failed to load code in specified path (" + codePaths + ") due to exception (" + e + ")"; LOG.error(errorMsg); return new ErrorInfo(ErrorCode.ERR_PARAM_INVALID, ModuleCode.RUNTIME, errorMsg); diff --git a/api/java/yr-runtime/src/main/java/com/yuanrong/executor/FaaSHandler.java b/api/java/yr-runtime/src/main/java/com/yuanrong/executor/FaaSHandler.java new file mode 100644 index 0000000..9e264dc --- /dev/null +++ b/api/java/yr-runtime/src/main/java/com/yuanrong/executor/FaaSHandler.java @@ -0,0 +1,745 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.executor; + +import static com.services.enums.FaasErrorCode.ENTRY_NOT_FOUND; +import static com.services.enums.FaasErrorCode.FAAS_INIT_ERROR; +import static com.services.enums.FaasErrorCode.FUNCTION_RUN_ERROR; +import static com.services.enums.FaasErrorCode.INITIALIZE_FUNCTION_ERROR; +import static com.services.enums.FaasErrorCode.INIT_FUNCTION_FAIL; +import static com.services.enums.FaasErrorCode.NONE_ERROR; +import static com.services.enums.FaasErrorCode.RESPONSE_EXCEED_LIMIT; +import static com.yuanrong.runtime.util.Utils.getMethod; +import static com.yuanrong.runtime.util.Utils.splitUserClassAndMethod; + +import com.services.UDFManager; +import com.services.enums.FaasErrorCode; +import com.services.model.CallResponse; +import com.services.model.CallResponseJsonObject; +import com.services.model.Response; +import com.services.runtime.Context; +import com.services.runtime.action.ContextImpl; +import com.services.runtime.action.ContextInvokeParams; +import com.services.runtime.action.DelegateDecrypt; +import com.services.runtime.action.LogTankService; +import com.services.runtime.utils.DataTypeAdapter; +import com.services.runtime.utils.Util; +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.errorcode.Pair; +import com.yuanrong.exception.HandlerNotAvailableException; +import com.yuanrong.libruntime.generated.Libruntime; +import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; +import com.yuanrong.runtime.util.Constants; +import com.yuanrong.runtime.util.ExtClasspathLoader; +import com.yuanrong.utils.RuntimeUtils; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonSyntaxException; +import com.google.gson.reflect.TypeToken; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * The type Faas handler. + * + * @since 2024-07-01 + */ +public class FaaSHandler implements HandlerIntf { + /** + * The constant LOG. + */ + private static final Logger LOG = LoggerFactory.getLogger(FaaSHandler.class); + + /** + * The constant KEY_USER_INIT_ENTRY. + */ + private static final String KEY_USER_INIT_ENTRY = "userInitEntry"; + + /** + * The constant KEY_USER_CALL_ENTRY. + */ + private static final String KEY_USER_CALL_ENTRY = "userCallEntry"; + + /** + * The constant KEY_INSTANCE_LABEL + */ + private static final String KEY_INSTANCE_LABEL = "instanceLabel"; + + /** + * The Index call context. + */ + private static final int INDEX_CALL_CONTEXT = 0; + + /** + * The Index user handler + */ + private static final int INDEX_USER_HANDLER = 1; + + private static final int INDEX_USER_ENTRY_CLASS = 0; + + private static final int INDEX_USER_ENTRY_METHOD = 1; + + /** + * The Index call user event. + */ + private static final int INDEX_CALL_USER_EVENT = 1; + + /** + * The first parameter type index. + */ + private static final int FIRST_PARAMETER_TYPE = 0; + + private static final int ARGS_MINIMUM_LENGTH = 2; + + private static final int USER_EVENT_MAX_SIZE = 6 * 1024 * 1024; + + private static final int RESPONSE_MAX_SIZE = 6 * 1024 * 1024; + + private static final String INVALID_ARGS_EXCEPTION = "faas get args invalid"; + + private static final String RUNTIME_ROOT = "/home/snuser/runtime"; + + private static final String RUNTIME_CODE_ROOT = "/opt/function/code"; + + private static final String RUNTIME_LOG_DIR = "/home/snuser/log"; + + private static final String LD_LIBRARY_PATH = "LD_LIBRARY_PATH"; + + private static final String HEADER_STR = "header"; + + private static final String BODY_STR = "body"; + + private static final String X_TRACE_ID = "X-Trace-Id"; + + private static final String ENV_DELEGATE_DECRYPT = "ENV_DELEGATE_DECRYPT"; + + /** + * The Gson. + */ + private static final Gson GSON = new GsonBuilder().serializeNulls().registerTypeAdapter( + new TypeToken>() { + }.getType(), new DataTypeAdapter() + ).disableHtmlEscaping().setPrettyPrinting().create(); + + private Context context = new ContextImpl(); + private Method callMethod; + private Class userEventClazz; + private int initializerTimeout; + private String preStopHandler; + private int preStopTimeout; + + /** + * Execute function + * + * @param meta the meta + * @param type the type + * @param args the args + * @return the return value, in ByteBuffer type, may need to release bytebuffer + * `buffer.clear()` + * @throws Exception the Exception + */ + @Override + public ReturnType execute(FunctionMeta meta, Libruntime.InvokeType type, List args) + throws Exception { + LOG.info("executing udf methods, current type: {}", type); + List argList = RuntimeUtils.convertArgListToStringList(args); + for (ByteBuffer buffer : args) { + releaseDirectBuffer(buffer); + } + args.clear(); + if (Objects.isNull(argList) || argList.isEmpty()) { + return new ReturnType(ErrorCode.ERR_PARAM_INVALID, "call handler arg list is empty."); + } + switch (type) { + case CreateInstance: + case CreateInstanceStateless: + LOG.debug("Invoking udf method matched, case create"); + ErrorInfo initErrorInfo = faasInitHandler(argList); + if (ErrorCode.ERR_OK.equals(initErrorInfo.getErrorCode())) { + return new ReturnType(ErrorCode.ERR_OK, initErrorInfo.getErrorMessage()); + } + return new ReturnType(initErrorInfo.getErrorCode(), initErrorInfo.getErrorMessage()); + case InvokeFunction: + case InvokeFunctionStateless: + LOG.debug("Invoking udf method matched, case invoke"); + ContextInvokeParams params = new ContextInvokeParams(); + ((ContextImpl) context).setContextInvokeParams(params); + return new ReturnType(ErrorCode.ERR_OK, faasCallHandler(argList)); + default: + LOG.debug("Invoking udf method matched, case dft"); + return new ReturnType(ErrorCode.ERR_INCORRECT_INVOKE_USAGE, "invalid invoke type"); + } + } + + private void releaseDirectBuffer(ByteBuffer buffer) { + if (buffer == null || !buffer.isDirect()) { + return; + } + // For java9+ + try { + Method cleanerMethod = buffer.getClass().getDeclaredMethod("cleaner"); + cleanerMethod.setAccessible(true); + Object cleaner = cleanerMethod.invoke(buffer); + if (cleaner != null) { + Method cleanMethod = cleaner.getClass().getMethod("clean"); + cleanMethod.invoke(cleaner); + } + return; + } catch (Exception e) { + LOG.warn("Failed to release direct buffer, it may be java8", e); + } + // For java8 + try { + Class directBufferClass = Class.forName("sun.nio.ch.DirectBuffer"); + if (directBufferClass.isInstance(buffer)) { + Method cleanerMethod = directBufferClass.getDeclaredMethod("cleaner"); + cleanerMethod.setAccessible(true); + Object cleaner = cleanerMethod.invoke(buffer); + if (cleaner != null) { + Method cleanMethod = cleaner.getClass().getMethod("clean"); + cleanMethod.invoke(cleaner); + } + } + } catch (Exception e) { + LOG.error("Failed to release direct buffer", e); + } + } + + /** + * Shutdown the instance gracefully. + * + * @param gracePeriodSeconds the time to wait for the instance to shutdown gracefully. + * @return ErrorInfo, the ErrorInfo of the execution of shutdown function. + */ + @Override + public ErrorInfo shutdown(int gracePeriodSeconds) { + if (preStopTimeout == 0 || preStopHandler == null || preStopHandler.isEmpty()) { + return new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, "no need to run shut down handler."); + } + ExecutorService shutdownExecutorService = Executors.newSingleThreadExecutor(); + Future shutdownFuture = shutdownExecutorService + .submit(() -> faasShutDownHandler(gracePeriodSeconds)); + try { + return shutdownFuture.get(preStopTimeout, TimeUnit.SECONDS); + } catch (TimeoutException e) { + String errorMassage = String.format(Locale.ROOT, "prestop timed out after %d s", preStopTimeout); + LOG.error(errorMassage); + return new ErrorInfo(new ErrorCode(FaasErrorCode.INVOKE_FUNCTION_TIMEOUT.getCode()), ModuleCode.RUNTIME, + errorMassage); + } catch (InterruptedException | ExecutionException e) { + String errorMassage = String.format(Locale.ROOT, + "faas failed to run user preStop code. err: %s , cause: %s", e.getMessage(), getCausedByString(e)); + LOG.error(errorMassage); + return new ErrorInfo(ErrorCode.ERR_USER_FUNCTION_EXCEPTION, ModuleCode.RUNTIME, errorMassage); + } finally { + shutdownExecutorService.shutdownNow(); + } + } + + /** + * Loads an instance of a class from serialized byte arrays. + * + * @param instanceBytes the serialized byte array representing the instance + * @param clzNameBytes the serialized byte array representing the class name + * @throws IOException if there is an error reading the byte arrays + * @throws ClassNotFoundException if the class specified by the class name is not found + */ + @Override + public void loadInstance(byte[] instanceBytes, byte[] clzNameBytes) {} + + /** + * Serializes the instance of the CodeExecutor class and returns a Pair + * containing the serialized byte arrays of the instance and the class name. + * + * @param instanceID the ID of the instance to be dumped + * @return a Pair containing the serialized byte arrays of the instance and the + * class name + * @throws JsonProcessingException if there is an error during serialization + */ + @Override + public Pair dumpInstance(String instanceID) throws JsonProcessingException { + return new Pair<>(new byte[0], new byte[0]); + } + + /** + * Recover the instance. + * + * @return ErrorInfo, the ErrorInfo of the execution of recover function. + */ + @Override + public ErrorInfo recover() { + return new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, ""); + } + + private Object getInputParameters(String userEvent, Class userEventClazz) { + if (userEventClazz.equals(String.class)) { + return userEvent; + } + return GSON.fromJson(userEvent, userEventClazz); + } + + /** + * Faas init handler. + * + * @param args the args + * @return InitErrorResponse + */ + public ErrorInfo faasInitHandler(List args) { + LOG.info("faas init handler called."); + if (args == null || args.size() < ARGS_MINIMUM_LENGTH) { + return new ErrorInfo(new ErrorCode(FAAS_INIT_ERROR.getCode()), ModuleCode.RUNTIME, INVALID_ARGS_EXCEPTION); + } + Map createParams = null; + try { + createParams = GSON.fromJson(args.get(INDEX_USER_HANDLER), Map.class); + context = initContext(args); + ((ContextImpl) context).setInstanceLabel(createParams.get(KEY_INSTANCE_LABEL)); + } catch (JsonSyntaxException e) { + String errorMessage = String.format(Locale.ROOT, "faas failed to convert json to object. err: %s", + e.getMessage()); + LOG.error(errorMessage); + return new ErrorInfo(new ErrorCode(INITIALIZE_FUNCTION_ERROR.getCode()), ModuleCode.RUNTIME, errorMessage); + } + LOG.debug("faas succeeds to init context "); + + // loadCallMethod must run before runUserInitHandler + String userCallEntry = createParams.get(KEY_USER_CALL_ENTRY); + ErrorInfo loadCallErrorInfo = loadCallMethod(userCallEntry); + if (!ErrorCode.ERR_OK.equals(loadCallErrorInfo.getErrorCode())) { + return loadCallErrorInfo; + } + String userInitEntry = createParams.get(KEY_USER_INIT_ENTRY); + ErrorInfo runUserInitErrorInfo = runUserInitHandler(userInitEntry); + if (!ErrorCode.ERR_OK.equals(runUserInitErrorInfo.getErrorCode())) { + return runUserInitErrorInfo; + } + LOG.info("faas init handler complete."); + return new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, "processed init handler successfully."); + } + + /** + * FaaS call handler. + * + * @param args the args + * @return String + * @throws HandlerNotAvailableException the handler not available exception + */ + public String faasCallHandler(List args) throws HandlerNotAvailableException { + LOG.info("faas call handler called."); + Object result; + int innerCode = NONE_ERROR.getCode(); + if (args == null || args.size() < ARGS_MINIMUM_LENGTH) { + result = INVALID_ARGS_EXCEPTION; + innerCode = FaasErrorCode.FAAS_INIT_ERROR.getCode(); + return handleResponse(result, innerCode); + } + if (args.get(INDEX_CALL_USER_EVENT).getBytes(StandardCharsets.UTF_8).length > USER_EVENT_MAX_SIZE) { + result = FaasErrorCode.REQUEST_BODY_EXCEED_LIMIT.getErrorMessage(); + innerCode = FaasErrorCode.REQUEST_BODY_EXCEED_LIMIT.getCode(); + return handleResponse(result, innerCode); + } + JsonObject jsonObject = GSON.fromJson(args.get(INDEX_CALL_USER_EVENT), JsonObject.class); + String userEvent = ""; + + if (jsonObject.has(BODY_STR) && !jsonObject.get(BODY_STR).isJsonNull()) { + JsonElement jsonElement = jsonObject.get(BODY_STR); + userEvent = (jsonElement.isJsonObject()) ? jsonElement.toString() : jsonElement.getAsString(); + } + + if (jsonObject.has(HEADER_STR) && !jsonObject.get(HEADER_STR).isJsonNull()) { + JsonObject headerObj = jsonObject.getAsJsonObject(HEADER_STR); + if (headerObj.has(X_TRACE_ID) && !headerObj.get(X_TRACE_ID).isJsonNull()) { + String traceId = headerObj.get(X_TRACE_ID).getAsString(); + context.setTraceID(traceId); + } + } + + userEvent = "null".equals(userEvent) ? "" : userEvent; + String logType = ""; + if (jsonObject.has(HEADER_STR)) { + JsonObject headerObject = jsonObject.get(HEADER_STR).getAsJsonObject(); + if (headerObject != null && headerObject.has(Constants.CFF_LOG_TYPE)) { + logType = headerObject.get(Constants.CFF_LOG_TYPE).toString(); + } + } + Util.setLogOpts(logType); + String logGroupId = ""; + String logStreamId = ""; + LogTankService logTankService = ((ContextImpl) context).getExtendedMetaData().getLogTankService(); + if (logTankService != null) { + if (logTankService.getLogGroupId() != null) { + logGroupId = logTankService.getLogGroupId(); + } + if (logTankService.getLogStreamId() != null) { + logStreamId = logTankService.getLogStreamId(); + } + } + + String[] callInfos = new String[] { + context.getInvokeID(), context.getRequestID(), context.getInstanceID(), Util.getFunctionInfo(context), + logGroupId, + logStreamId + }; + Util.setInheritableThreadLocal(callInfos); + UDFManager udfManager = UDFManager.getUDFManager(); + long startTime = System.currentTimeMillis(); + try { + result = callMethod.invoke(udfManager.loadInstance(KEY_USER_CALL_ENTRY), + getInputParameters(userEvent, userEventClazz), context); + } catch (IllegalAccessException | IllegalArgumentException e) { + String errorMsg = getErrorMsg(e); + String cause = getCausedByString(e); + LOG.error("faas run invoke method failed, errorMsg: {}, cause : {}", errorMsg, cause); + result = errorMsg; + innerCode = FaasErrorCode.FUNCTION_RUN_ERROR.getCode(); + } catch (InvocationTargetException e) { + innerCode = FaasErrorCode.FUNCTION_RUN_ERROR.getCode(); + String errorMsg = getErrorMsg(e); + String cause = getCausedByString(e.getCause()); + LOG.error("faas run invoke user method failed, errorMsg: {}, cause : {}", errorMsg, cause); + result = errorMsg; + } + Util.clearLogOpts(); + Util.clearInheritableThreadLocal(); + long userFuncTime = System.currentTimeMillis() - startTime; + return handleResponse(result, innerCode, userFuncTime); + } + + private static String getCausedByString(Throwable throwable) { + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + throwable.printStackTrace(printWriter); + return stringWriter.toString(); + } + + private static String getErrorMsg(Throwable exception) { + String msg = exception.getMessage(); + if (exception instanceof InvocationTargetException) { + Throwable cause = exception.getCause(); + if (cause != null) { + msg = cause.toString(); + } + } + return msg; + } + + /** + * Faas check point handler. + * + * @param args the args + */ + public void faasCheckPointHandler(List args) { + return; + } + + /** + * faasRecoverHandler + * + * @param args the args + */ + public void faasRecoverHandler(List args) { + return; + } + + /** + * faasShutDownHandler + * + * @param gracePeriodSeconds grace period seconds + * @return ErrorInfo, the ErrorInfo of the execution of preStopHandler. + */ + public ErrorInfo faasShutDownHandler(int gracePeriodSeconds) { + LOG.info("faas shut down handler called."); + UDFManager udfManager = UDFManager.getUDFManager(); + Class userClass = udfManager.loadClass(); + if (preStopHandler != null && !preStopHandler.isEmpty()) { + try { + String[] preStopClassMethod = splitUserClassAndMethod(preStopHandler, true); + Method method = userClass.getMethod(preStopClassMethod[INDEX_USER_ENTRY_METHOD], Context.class); + method.setAccessible(true); + method.invoke(udfManager.loadInstance(KEY_USER_CALL_ENTRY), context); + } catch (NoSuchMethodException e) { + LOG.error("faas failed to load user preStop code. err: {}", e.getMessage()); + return new ErrorInfo(new ErrorCode(ENTRY_NOT_FOUND.getCode()), ModuleCode.RUNTIME, + ENTRY_NOT_FOUND.getErrorMessage()); + } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) { + String errorMessage = String.format(Locale.ROOT, + "faas failed to run user preStop code. err: %s , cause: %s", getErrorMsg(e), getCausedByString(e)); + LOG.error(errorMessage); + return new ErrorInfo(new ErrorCode(FUNCTION_RUN_ERROR.getCode()), ModuleCode.RUNTIME, errorMessage); + } + } + return new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, "run shut down handler successfully."); + } + + /** + * faasSignalHandler + * + * @param args the args + */ + public void faasSignalHandler(List args) { + return; + } + + /** + * registerClass , registerInstance , set callMethod , set userEventClazz + * + * @param userCallEntry userCallEntry + * @return ErrorInfo + */ + private ErrorInfo loadCallMethod(String userCallEntry) { + ClassLoader classLoader = ExtClasspathLoader.getFunctionClassLoader(); + UDFManager udfManager = UDFManager.getUDFManager(); + try { + String[] callEntryClassMethod = splitUserClassAndMethod(userCallEntry, false); + Class userClass = classLoader.loadClass(callEntryClassMethod[INDEX_USER_ENTRY_CLASS]); + udfManager.registerClass(userClass); + Object entryInstance = userClass.newInstance(); + udfManager.registerInstance(KEY_USER_CALL_ENTRY, entryInstance); + callMethod = getMethod(userClass, callEntryClassMethod[INDEX_USER_ENTRY_METHOD]); + userEventClazz = callMethod.getParameterTypes()[FIRST_PARAMETER_TYPE]; + callMethod.setAccessible(true); + } catch (ClassNotFoundException | NoSuchMethodException | InstantiationException | IllegalAccessException e) { + LOG.error("faas failed to init user code. err: {}", e.getMessage()); + return new ErrorInfo(new ErrorCode(ENTRY_NOT_FOUND.getCode()), ModuleCode.RUNTIME, + ENTRY_NOT_FOUND.getErrorMessage()); + } catch (Exception e) { + String errorMessage = String.format(Locale.ROOT, "faas unexpected exception: %s", e.getMessage()); + LOG.error(errorMessage); + return new ErrorInfo(new ErrorCode(INITIALIZE_FUNCTION_ERROR.getCode()), ModuleCode.RUNTIME, errorMessage); + } + return new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, "load call method successfully."); + } + + /** + * load userClass , loadUserInstance , invoke userInitMethod + * + * @param userInitEntry userInitEntry + * @return InitErrorResponse + */ + private ErrorInfo runUserInitHandler(String userInitEntry) { + UDFManager udfManager = UDFManager.getUDFManager(); + Class userClass = udfManager.loadClass(); + // execute user init code if present + if (userInitEntry != null && !userInitEntry.isEmpty()) { + try { + String[] initEntryClassMethod = splitUserClassAndMethod(userInitEntry, true); + Method method = userClass.getMethod(initEntryClassMethod[INDEX_USER_ENTRY_METHOD], Context.class); + method.setAccessible(true); + method.invoke(udfManager.loadInstance(KEY_USER_CALL_ENTRY), context); + } catch (NoSuchMethodException e) { + LOG.error("faas failed to load user init code. err: {}", e.getMessage()); + return new ErrorInfo(new ErrorCode(ENTRY_NOT_FOUND.getCode()), ModuleCode.RUNTIME, + ENTRY_NOT_FOUND.getErrorMessage()); + } catch (IllegalAccessException | IllegalArgumentException e) { + String errorMessage = String.format(Locale.ROOT, "faas failed to run user init code. err: %s", + getErrorMsg(e)); + LOG.error(errorMessage); + return new ErrorInfo(new ErrorCode(INIT_FUNCTION_FAIL.getCode()), ModuleCode.RUNTIME, errorMessage); + } catch (InvocationTargetException e) { + String errorMessage = String.format(Locale.ROOT, "faas failed to run user init code. err: %s", + getErrorMsg(e)); + LOG.error("{}, cause: {}", errorMessage, getCausedByString(e.getCause())); + return new ErrorInfo(new ErrorCode(INIT_FUNCTION_FAIL.getCode()), ModuleCode.RUNTIME, errorMessage); + } + } + return new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME, "run user init handler successfully."); + } + + private Context initContext(List args) { + ContextImpl contextImpl = GSON.fromJson(args.get(INDEX_CALL_CONTEXT), ContextImpl.class); + String delegateEncryptEnv = System.getenv(ENV_DELEGATE_DECRYPT); + DelegateDecrypt delegateDecrypt; + if (delegateEncryptEnv != null && !delegateEncryptEnv.isEmpty()) { + delegateDecrypt = GSON.fromJson(delegateEncryptEnv, DelegateDecrypt.class); + } else { + delegateDecrypt = GSON.fromJson(args.get(args.size() - 1), DelegateDecrypt.class); + } + if (delegateDecrypt == null) { + delegateDecrypt = new DelegateDecrypt(); + } + contextImpl.setDelegateDecrypt(delegateDecrypt); + Map runtimeUserDataMap = new HashMap<>(); + Map envSetMap = new HashMap<>(); + if (delegateDecrypt.getEnvironment() != null && !delegateDecrypt.getEnvironment().isEmpty()) { + Map environmentMap = GSON.fromJson(delegateDecrypt.getEnvironment(), HashMap.class); + for (Map.Entry entry : environmentMap.entrySet()) { + String key = entry.getKey(); + String value = entry.getValue(); + if (!LD_LIBRARY_PATH.equals(key)) { + envSetMap.put(key, value); + } + runtimeUserDataMap.put(key, value); + } + } + + if (delegateDecrypt.getEncryptedUserData() != null && !delegateDecrypt.getEncryptedUserData().isEmpty()) { + Map userDataMap = GSON.fromJson(delegateDecrypt.getEncryptedUserData(), HashMap.class); + runtimeUserDataMap.putAll(userDataMap); + } + String pathValue = runtimeUserDataMap.getOrDefault(LD_LIBRARY_PATH, ""); + envSetMap.put(LD_LIBRARY_PATH, + System.getenv(LD_LIBRARY_PATH) + String.format(":%s", pathValue)); + contextImpl.getFuncMetaData().setUserData(runtimeUserDataMap); + String userDataString = GSON.toJson(runtimeUserDataMap); + envSetMap.put("RUNTIME_USERDATA", userDataString); + setEnvContext(contextImpl, envSetMap); + this.initializerTimeout = Integer.parseInt(envSetMap.get("RUNTIME_INITIALIZER_TIMEOUT")); + this.preStopHandler = envSetMap.get("PRE_STOP_HANDLER"); + this.preStopTimeout = Integer.parseInt(envSetMap.getOrDefault("PRE_STOP_TIMEOUT", "0")); + return contextImpl; + } + + private String handleResponse(Object result, int innerCode) { + return handleResponse(result, innerCode, 0); + } + + private String handleResponse(Object result, int innerCode, long userFuncTime) { + Response response; + if (result instanceof JsonObject) { + response = new CallResponseJsonObject(); + } else { + response = new CallResponse(); + } + response.setBody(result); + response.setBillingDuration("this is billing duration TODO"); + response.setInnerCode(String.valueOf(innerCode)); + response.setInvokerSummary("this is summary TODO"); + response.setLogResult(Base64.getEncoder() + .encodeToString("this is user log TODO".getBytes(StandardCharsets.UTF_8))); + if (userFuncTime != 0) { + response.setUserFuncTime(userFuncTime); + } + String resultJson = GSON.toJson(response); + + int respLength = resultJson.getBytes(StandardCharsets.UTF_8).length; + if (respLength > RESPONSE_MAX_SIZE) { + response.setBody(String.format(Locale.ROOT, "response body size %d exceeds the limit of 6291456", + respLength)); + response.setInnerCode(String.valueOf(RESPONSE_EXCEED_LIMIT.getCode())); + resultJson = GSON.toJson(response); + } + return resultJson; + } + + private void setEnvContext(ContextImpl contextImpl, Map envSetMap) { + envSetMap.put("RUNTIME_PROJECT_ID", contextImpl.getFuncMetaData().getTenantId()); + envSetMap.put("RUNTIME_PACKAGE", contextImpl.getFuncMetaData().getService()); + envSetMap.put("RUNTIME_FUNC_NAME", contextImpl.getFuncMetaData().getFuncName()); + envSetMap.put("RUNTIME_FUNC_VERSION", contextImpl.getFuncMetaData().getVersion()); + envSetMap.put("RUNTIME_HANDLER", contextImpl.getFuncMetaData().getHandler()); + envSetMap.put("RUNTIME_TIMEOUT", Integer.toString(contextImpl.getFuncMetaData().getTimeout())); + envSetMap.put("RUNTIME_CPU", Integer.toString(contextImpl.getResourceMetaData().getCpu())); + envSetMap.put("RUNTIME_MEMORY", Integer.toString(contextImpl.getResourceMetaData().getMemory())); + envSetMap.put("RUNTIME_MAX_RESP_BODY_SIZE", Integer.toString(USER_EVENT_MAX_SIZE)); + if (contextImpl.getExtendedMetaData() != null && contextImpl.getExtendedMetaData().getInitializer() != null) { + if (contextImpl.getExtendedMetaData().getInitializer().getInitializerHandler() != null) { + envSetMap.put("RUNTIME_INITIALIZER_HANDLER", + contextImpl.getExtendedMetaData().getInitializer().getInitializerHandler()); + } + envSetMap.put("RUNTIME_INITIALIZER_TIMEOUT", + Integer.toString(contextImpl.getExtendedMetaData().getInitializer().getInitializerTimeout())); + } + if (contextImpl.getExtendedMetaData() != null && contextImpl.getExtendedMetaData().getPreStop() != null) { + if (contextImpl.getExtendedMetaData().getPreStop().getPreStopHandler() != null) { + envSetMap.put("PRE_STOP_HANDLER", + contextImpl.getExtendedMetaData().getPreStop().getPreStopHandler()); + } + envSetMap.put("PRE_STOP_TIMEOUT", + Integer.toString(contextImpl.getExtendedMetaData().getPreStop().getPreStopTimeout())); + } + envSetMap.put("RUNTIME_ROOT", RUNTIME_ROOT); + envSetMap.put("RUNTIME_CODE_ROOT", RUNTIME_CODE_ROOT); + envSetMap.put("RUNTIME_LOG_DIR", RUNTIME_LOG_DIR); + setJavaProcessEnvMap(envSetMap); + } + + /** + * Sets java process env map. + * + * @param envMap the env map + */ + private static void setJavaProcessEnvMap(Map envMap) { + // keep process previous env map configs + try { + Class processEnvironmentClass = Class.forName("java.lang.ProcessEnvironment"); + updateJavaEnvMap(processEnvironmentClass, "theCaseInsensitiveEnvironment", envMap); + updateJavaEnvMap(processEnvironmentClass, "theUnmodifiableEnvironment", envMap); + } catch (ClassNotFoundException e) { + LOG.error("get field: theEnvironment has an error: ", e); + } + } + + /** + * Update java env map. + * + * @param cls the cls + * @param filedName the filed name + * @param envMap the env map + */ + private static void updateJavaEnvMap(Class cls, String filedName, Map envMap) { + try { + // get field and access + Field oldFiled = cls.getDeclaredField(filedName); + oldFiled.setAccessible(true); + // get Filed map + Object unmodifiableMap = oldFiled.get(null); + for (Map.Entry entry : envMap.entrySet()) { + LOG.debug("updateJavaEnvMap key: {}, value:{}", entry.getKey(), entry.getValue()); + injectIntoUnmodifiableMap(entry.getKey(), entry.getValue(), unmodifiableMap); + } + } catch (ReflectiveOperationException e) { + LOG.error("get field: {} has an error: {}", filedName, e); + } + } + + private static void injectIntoUnmodifiableMap(String key, String value, Object map) + throws ReflectiveOperationException { + Class unmodifiableMap = Class.forName("java.util.Collections$UnmodifiableMap"); + Field field = unmodifiableMap.getDeclaredField("m"); + field.setAccessible(true); + Object obj = field.get(map); + ((Map) obj).put(key, value); + } +} diff --git a/api/java/yr-runtime/src/main/java/com/yuanrong/executor/FunctionHandler.java b/api/java/yr-runtime/src/main/java/com/yuanrong/executor/FunctionHandler.java index c6ec0f9..b66922a 100644 --- a/api/java/yr-runtime/src/main/java/com/yuanrong/executor/FunctionHandler.java +++ b/api/java/yr-runtime/src/main/java/com/yuanrong/executor/FunctionHandler.java @@ -24,6 +24,7 @@ import com.yuanrong.exception.YRException; import com.yuanrong.exception.handler.filter.FilterFactory; import com.yuanrong.exception.handler.traceback.StackTraceInfo; import com.yuanrong.libruntime.generated.Libruntime; +import com.yuanrong.libruntime.generated.Libruntime.ApiType; import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; import com.yuanrong.runtime.config.RuntimeContext; import com.yuanrong.runtime.util.Utils; @@ -263,6 +264,9 @@ public class FunctionHandler implements HandlerIntf { LOG.info("Creating for udf methods, clz({}), func({}), sig({})", clzName, funcName, signature); try { constructObject(clzName, funcName, signature, args); + if (ApiType.Faas.equals(meta.getApiType()) && !CONSTRUCTOR_FUNC_NAME.equals(funcName)) { + return invoke(meta, args); + } } catch (ClassNotFoundException | InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | IOException e) { LOG.error("failed to execute udf methods, clz({}), func({}), sig({}), reason: {}", clzName, funcName, diff --git a/api/java/yr-runtime/src/main/java/com/yuanrong/runtime/server/RuntimeLogger.java b/api/java/yr-runtime/src/main/java/com/yuanrong/runtime/server/RuntimeLogger.java index b4285c6..253d8db 100644 --- a/api/java/yr-runtime/src/main/java/com/yuanrong/runtime/server/RuntimeLogger.java +++ b/api/java/yr-runtime/src/main/java/com/yuanrong/runtime/server/RuntimeLogger.java @@ -29,11 +29,14 @@ import java.io.File; public class RuntimeLogger { private static Logger LOG = null; private static final String ENV_GLOG_DIR = "GLOG_log_dir"; + private static final String ENV_LOGGER_ID = "YR_LOG_PREFIX"; private static final String DEFAULT_LOG_PATH = "/home/snuser/log/"; private static final String DEFAULT_LOG_LEVEL = "INFO"; private static final String EXCEPTION_LOG_PATH = "exception"; private static final String ENV_LOG_LEVEL = "logLevel"; private static final String ENV_JAVA_LOG_PATH = "logPath"; + private static final String ENV_LOG_ID = "logId"; + private static final String ENV_RUNTIME_ID = "runtimeId"; private static final String ENV_EXCEPTION_LOG_DIR = "java.io.tmpdir"; /** @@ -44,15 +47,22 @@ public class RuntimeLogger { */ public static void initLogger(String runtimeID) { String logDir = System.getenv(ENV_GLOG_DIR); + String logId = System.getenv(ENV_LOGGER_ID); if (logDir == null || logDir.trim().isEmpty()) { logDir = DEFAULT_LOG_PATH; } - String logPathName = logDir + File.separator + runtimeID; + String logPathName = logDir; String logPathException = logDir + File.separator + EXCEPTION_LOG_PATH; String logLevel = System.getProperty(ENV_LOG_LEVEL); if (logLevel == null || logLevel.trim().isEmpty()) { logLevel = DEFAULT_LOG_LEVEL; } + if (logId != null && !logId.isEmpty()) { + System.setProperty(ENV_LOG_ID, logId); + } else { + System.setProperty(ENV_LOG_ID, runtimeID); + } + System.setProperty(ENV_RUNTIME_ID, runtimeID); System.setProperty(ENV_LOG_LEVEL, logLevel); System.setProperty(ENV_JAVA_LOG_PATH, logPathName); System.setProperty(ENV_EXCEPTION_LOG_DIR, logPathException); diff --git a/api/java/yr-runtime/src/main/resources/log4j2.xml b/api/java/yr-runtime/src/main/resources/log4j2.xml index 17469e8..01dbdd4 100644 --- a/api/java/yr-runtime/src/main/resources/log4j2.xml +++ b/api/java/yr-runtime/src/main/resources/log4j2.xml @@ -10,44 +10,16 @@ DEBUG > TRACE > ALL --> - + + charset="UTF-8" pattern="[%d{HH:mm:ss:SSS}{GMT+8}] | [%p] | %l | %t | runtime-java | ${sys:runtimeId} | %enc{%m}{CRLF}%n" /> - - - - - - - - - - - - - - - - - - - - - - - - - + + @@ -59,8 +31,5 @@ DEBUG > TRACE > ALL --> - - - \ No newline at end of file diff --git a/api/java/yr-runtime/src/test/java/com/yuanrong/executor/MockFailedClass.java b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/MockFailedClass.java new file mode 100644 index 0000000..9603268 --- /dev/null +++ b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/MockFailedClass.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.executor; + +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.exception.YRException; + +public class MockFailedClass{ + private int cnt = 0; + + public int getCausingDeserializationFailure() { + return cnt; + } + + public void mockMethod() { + return; + } + + public void mockMethodWithException() throws YRException { + throw new YRException(ErrorCode.ERR_BUS_DISCONNECTION, ModuleCode.CORE, ""); + } + + public static void yrRecover() { + throw new RuntimeException("yrRecover failed"); + } + + public static void yrShutdown(int gracePeriodSeconds) { + throw new RuntimeException("yrShutdown failed"); + } +} diff --git a/api/java/yr-runtime/src/test/java/com/yuanrong/executor/MockNoneClass.java b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/MockNoneClass.java new file mode 100644 index 0000000..3d43473 --- /dev/null +++ b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/MockNoneClass.java @@ -0,0 +1,19 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.executor; + +public class MockNoneClass {} diff --git a/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestFaaSHandler.java b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestFaaSHandler.java new file mode 100644 index 0000000..e6f277a --- /dev/null +++ b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestFaaSHandler.java @@ -0,0 +1,652 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yuanrong.executor; + +import com.services.UDFManager; +import com.services.enums.FaasErrorCode; +import com.services.exception.FaaSException; +import com.services.model.CallRequest; +import com.services.model.CallResponse; +import com.services.runtime.action.ContextImpl; +import com.services.runtime.action.ContextInvokeParams; +import com.services.runtime.action.DelegateDecrypt; +import com.services.runtime.action.ExtendedMetaData; +import com.services.runtime.action.FunctionMetaData; +import com.services.runtime.action.Initializer; +import com.services.runtime.action.LogTankService; +import com.services.runtime.action.PreStop; +import com.services.runtime.action.ResourceMetaData; +import com.yuanrong.errorcode.ErrorCode; +import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.Pair; +import com.yuanrong.jni.LibRuntime; +import com.yuanrong.libruntime.generated.Libruntime; +import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; +import com.yuanrong.runtime.util.Constants; + +import com.google.gson.Gson; + +import org.junit.Assert; +import org.junit.Test; + +import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class TestFaaSHandler { + private static void setJavaProcessEnvMap(Map envMap) { + // keep process previous env map configs + try { + Class processEnvironmentClass = Class.forName("java.lang.ProcessEnvironment"); + updateJavaEnvMap(processEnvironmentClass, "theUnmodifiableEnvironment", envMap); + } catch (ClassNotFoundException e) { + e.printStackTrace(); + } + } + + private static void updateJavaEnvMap(Class cls, String filedName, Map envMap) { + try { + // get field and access + Field oldFiled = cls.getDeclaredField(filedName); + oldFiled.setAccessible(true); + // get Filed map + Object unmodifiableMap = oldFiled.get(null); + for (Map.Entry entry : envMap.entrySet()) { + injectIntoUnmodifiableMap(entry.getKey(), entry.getValue(), unmodifiableMap); + } + } catch (ReflectiveOperationException e) { + e.printStackTrace(); + } + } + + private static void injectIntoUnmodifiableMap(String key, String value, Object map) + throws ReflectiveOperationException { + if (key == null || value == null) { + return; + } + Class unmodifiableMap = Class.forName("java.util.Collections$UnmodifiableMap"); + Field field = unmodifiableMap.getDeclaredField("m"); + field.setAccessible(true); + Object obj = field.get(map); + ((Map) obj).put(key, value); + } + + public List generateInitArgs() { + List args = new ArrayList<>(); + ContextImpl contextImpl = getContext(); + Gson gson = new Gson(); + String contextConfig = gson.toJson(contextImpl); + String udfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.initializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::handler\"}"; + args.add(contextConfig); + args.add(udfEntry); + args.add(getSchedulerData()); + return args; + } + + public List generateInitArgsWiThPreStopHandler(String testCase) { + List args = new ArrayList<>(); + ContextImpl contextImpl = getContext(); + PreStop preStop = new PreStop(); + if ("NoSuchMethodException".equals(testCase)) { + preStop.setPreStopHandler("com.yuanrong.executor.UserTestHandler.wrongPreStop"); + preStop.setPreStopTimeout(60); + } else if ("InvocationTargetException".equals(testCase)) { + preStop.setPreStopHandler("com.yuanrong.executor.UserTestHandler.failedPreStop"); + preStop.setPreStopTimeout(60); + } else if ("InterruptedException".equals(testCase)) { + preStop.setPreStopHandler("com.yuanrong.executor.UserTestHandler.timeoutPreStop"); + preStop.setPreStopTimeout(1); + } else { + preStop.setPreStopHandler("com.yuanrong.executor.UserTestHandler.preStop"); + preStop.setPreStopTimeout(60); + } + contextImpl.getExtendedMetaData().setPreStop(preStop); + Gson gson = new Gson(); + String contextConfig = gson.toJson(contextImpl); + String udfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.initializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::handler\"}"; + args.add(contextConfig); + args.add(udfEntry); + args.add(getSchedulerData()); + return args; + } + + public List generateCallTimeoutArgs() { + List args = new ArrayList<>(); + ContextImpl contextImpl = getContext(); + Gson gson = new Gson(); + String contextConfig = gson.toJson(contextImpl); + String udfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.initializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::timeoutHandler\"}"; + args.add(contextConfig); + args.add(udfEntry); + return args; + } + + public List generateCallResponseLargeSizeArgs() { + List args = new ArrayList<>(); + ContextImpl contextImpl = getContext(); + Gson gson = new Gson(); + String contextConfig = gson.toJson(contextImpl); + String udfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.initializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::largeResponse\"}"; + args.add(contextConfig); + args.add(udfEntry); + return args; + } + + public List generateErrorInitArgs(String errorType) { + Gson gson = new Gson(); + List args = new ArrayList<>(); + ContextImpl contextImpl = getContext(); + String contextConfig = gson.toJson(contextImpl); + args.add(contextConfig); + switch (errorType) { + case "function initialization exception": { + args.add("Strings cannot unmarshal with map"); + break; + } + case "runtime initialization timed out": { + String udfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.timeoutInitializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::handler\"}"; + args.add(udfEntry); + break; + } + case "call user entry not found": { + String wrongUdfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.initializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::wrongHandler\"}"; + args.add(wrongUdfEntry); + break; + } + case "init user entry not found": { + String wrongUdfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.wrongInitializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::handler\"}"; + args.add(wrongUdfEntry); + break; + } + case "IllegalArgumentException" : { + String wrongUdfEntry = + "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.illegalArgumentInitializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::handler\"}"; + args.add(wrongUdfEntry); + break; + } + case "InvocationTargetException" : { + String wrongUdfEntry = + "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.failedInitializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::handler\"}"; + args.add(wrongUdfEntry); + break; + } + } + return args; + } + + private ContextImpl getContext() { + ContextImpl contextImpl = new ContextImpl(); + FunctionMetaData funcMetaData = new FunctionMetaData(); + ResourceMetaData resourceMetaData = new ResourceMetaData(); + ExtendedMetaData extendedMetaData = new ExtendedMetaData(); + Initializer initializer = new Initializer(); + DelegateDecrypt delegateDecrypt = new DelegateDecrypt(); + LogTankService logTankService = new LogTankService(); + funcMetaData.setTenantId("12345678910123456789"); + funcMetaData.setFuncName("testpythonbase001"); + funcMetaData.setService("base"); + funcMetaData.setVersion("latest"); + funcMetaData.setHandler("handler"); + funcMetaData.setTimeout(3); + resourceMetaData.setCpu(500); + resourceMetaData.setMemory(500); + initializer.setInitializerHandler("programmingmodel.TestJavaBase002::init"); + initializer.setInitializerTimeout(3); + extendedMetaData.setInitializer(initializer); + logTankService.setLogGroupId("groupID"); + logTankService.setLogStreamId("streamID"); + extendedMetaData.setLogTankService(logTankService); + delegateDecrypt.setAccessKey("accessKey123"); + delegateDecrypt.setSecretKey("secretKey123"); + delegateDecrypt.setAuthToken("authToken123"); + delegateDecrypt.setSecurityToken("securityToken123"); + contextImpl.setFuncMetaData(funcMetaData); + contextImpl.setResourceMetaData(resourceMetaData); + contextImpl.setExtendedMetaData(extendedMetaData); + contextImpl.setDelegateDecrypt(delegateDecrypt); + ContextInvokeParams params = new ContextInvokeParams(); + params.setRequestID("request-123456789"); + contextImpl.setContextInvokeParams(params); + return contextImpl; + } + + private DelegateDecrypt getDelegateDecrypt() { + Gson gson = new Gson(); + DelegateDecrypt delegateDecrypt = new DelegateDecrypt(); + Map map = new HashMap<>(2); + map.put("key1", "val1"); + map.put("key2", "val2"); + delegateDecrypt.setEnvironment(gson.toJson(map)); + return delegateDecrypt; + } + + private String getSchedulerData() { + return "{\"schedulerFuncKey\":\"12345678901234561234567890123456/0-system-faasscheduler/$latest\"," + + "\"schedulerIDList\":[\"2238fb12-0000-4000-8000-00abc0d9cc91\"]}"; + } + + public List generateCallArgs() { + Gson gson = new Gson(); + List args = new ArrayList<>(); + args.add("{\"codeID\":\"\",\"config\":{\"functionID\":{\"cpp\":\"\"," + + "\"python\":\"sn:cn:yrk:12345678901234561234567890123456:function:0-he-he:$latest\"}," + + "\"jodID\":\"96f2fc5e-c9ab-4d83-9aa2-89579a29ff4a\",\"logLevel\":30,\"recycleTime\":2}," + + "\"invokeType\":3,\"objectDescriptor\":{\"className\":\"\",\"functionName\":\"execute\"," + + "\"moduleName\":\"faasexecutor\",\"srcLanguage\":\"python\",\"targetLanguage\":\"python\"}}"); + TestRequestEvent testRequestEvent = new TestRequestEvent("yuanrong", 1); + CallRequest callRequest = new CallRequest(); + callRequest.setBody(testRequestEvent); + callRequest.setHeader(new HashMap(){ + { + put(Constants.CFF_LOG_TYPE, "tail"); + } + }); + args.add(gson.toJson(callRequest)); + return args; + } + + public List generateErrorCallArgs() { + Gson gson = new Gson(); + List args = new ArrayList<>(); + args.add("{\"codeID\":\"\",\"config\":{\"functionID\":{\"cpp\":\"\"," + + "\"python\":\"sn:cn:yrk:12345678901234561234567890123456:function:0-he-he:$latest\"}," + + "\"jodID\":\"96f2fc5e-c9ab-4d83-9aa2-89579a29ff4a\",\"logLevel\":30,\"recycleTime\":2}," + + "\"invokeType\":3,\"objectDescriptor\":{\"className\":\"\",\"functionName\":\"execute\"," + + "\"moduleName\":\"faasexecutor\",\"srcLanguage\":\"python\",\"targetLanguage\":\"python\"}}"); + TestRequestEvent testRequestEvent = new TestRequestEvent("yuanrong", 0); + args.add(gson.toJson(testRequestEvent)); + return args; + } + + @Test + public void testFaaSHandlerInitSuccess() throws Exception { + FaaSHandler faaSHandler = new FaaSHandler(); + List args = generateInitArgs(); + Assert.assertTrue(faaSHandler.faasInitHandler(args).getErrorMessage().contains("processed init handler successfully.")); + } + + @Test + public void testFaaSHandlerInitWithDelegate() throws Exception { + Gson gson = new Gson(); + DelegateDecrypt delegateDecrypt = new DelegateDecrypt(); + Map map = new HashMap<>(); + map.put("spring_start_class", "com.inventory.InventoryApplication"); + delegateDecrypt.setEnvironment(gson.toJson(map)); + Map envMap = new HashMap<>(); + envMap.put("ENV_DELEGATE_DECRYPT", gson.toJson(delegateDecrypt)); + setJavaProcessEnvMap(envMap); + FaaSHandler faaSHandler = new FaaSHandler(); + List args = new ArrayList<>(); + ContextImpl contextImpl = getContext(); + String contextConfig = gson.toJson(contextImpl); + String udfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.initializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler.contextHandler\"}"; + args.add(contextConfig); + args.add(udfEntry); + args.add(getSchedulerData()); + Assert.assertTrue(faaSHandler.faasInitHandler(args).getErrorMessage().contains("processed init handler successfully.")); + envMap.put("ENV_DELEGATE_DECRYPT", ""); + setJavaProcessEnvMap(envMap); + } + + @Test + public void testFaaSHandlerInitWithInstanceLabel() throws Exception { + Gson gson = new Gson(); + DelegateDecrypt delegateDecrypt = new DelegateDecrypt(); + Map map = new HashMap<>(); + map.put("spring_start_class", "com.inventory.InventoryApplication"); + delegateDecrypt.setEnvironment(gson.toJson(map)); + Map envMap = new HashMap<>(); + envMap.put("ENV_DELEGATE_DECRYPT", gson.toJson(delegateDecrypt)); + setJavaProcessEnvMap(envMap); + FaaSHandler faaSHandler = new FaaSHandler(); + List args = new ArrayList<>(); + ContextImpl contextImpl = getContext(); + String contextConfig = gson.toJson(contextImpl); + String udfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.initializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler.instanceLabelHandler\", \"instanceLabel\":\"aaaaa\"}"; + args.add(contextConfig); + args.add(udfEntry); + args.add(getSchedulerData()); + Assert.assertTrue(faaSHandler.faasInitHandler(args).getErrorMessage().contains("processed init handler successfully.")); + List callArgs = generateCallArgs(); + Assert.assertTrue(faaSHandler.faasCallHandler(callArgs).contains("aaaaa")); + } + + @Test + public void testFaaSHandlerInitFail() throws Exception { + FaaSHandler faaSHandler = new FaaSHandler(); + Gson gson = new Gson(); + List args = new ArrayList<>(); + ErrorInfo initErrorInfo = faaSHandler.faasInitHandler(args); + Assert.assertEquals(FaasErrorCode.FAAS_INIT_ERROR.getCode(), initErrorInfo.getErrorCode().getValue()); + + List wrongArgs = generateErrorInitArgs(FaasErrorCode.INITIALIZE_FUNCTION_ERROR.getErrorMessage()); + initErrorInfo = faaSHandler.faasInitHandler(wrongArgs); + Assert.assertEquals(FaasErrorCode.INITIALIZE_FUNCTION_ERROR.getCode(), initErrorInfo.getErrorCode().getValue()); + + wrongArgs = generateErrorInitArgs("call " + FaasErrorCode.ENTRY_NOT_FOUND.getErrorMessage()); + initErrorInfo = faaSHandler.faasInitHandler(wrongArgs); + Assert.assertEquals(FaasErrorCode.ENTRY_NOT_FOUND.getCode(), initErrorInfo.getErrorCode().getValue()); + + wrongArgs = generateErrorInitArgs("init " + FaasErrorCode.ENTRY_NOT_FOUND.getErrorMessage()); + initErrorInfo = faaSHandler.faasInitHandler(wrongArgs); + Assert.assertEquals(FaasErrorCode.ENTRY_NOT_FOUND.getCode(), initErrorInfo.getErrorCode().getValue()); + + wrongArgs = generateErrorInitArgs("IllegalArgumentException"); + initErrorInfo = faaSHandler.faasInitHandler(wrongArgs); + Assert.assertEquals(FaasErrorCode.INIT_FUNCTION_FAIL.getCode(), initErrorInfo.getErrorCode().getValue()); + + wrongArgs = generateErrorInitArgs("InvocationTargetException"); + initErrorInfo = faaSHandler.faasInitHandler(wrongArgs); + Assert.assertEquals(FaasErrorCode.INIT_FUNCTION_FAIL.getCode(), initErrorInfo.getErrorCode().getValue()); + } + + @Test + public void testFaaSHandlerCallSuccess() throws Exception { + FaaSHandler faaSHandler = new FaaSHandler(); + List initArgs = generateInitArgs(); + faaSHandler.faasInitHandler(initArgs); + List callArgs = generateCallArgs(); + String response = faaSHandler.faasCallHandler(callArgs); + Gson gson = new Gson(); + CallResponse response2 = gson.fromJson(response, CallResponse.class); + Assert.assertEquals("true", response2.getBody().toString()); + } + + @Test + public void testFaaSHandlerCallSuccessWithSerializeBody() throws Exception { + FaaSHandler faaSHandler = new FaaSHandler(); + List initArgs = generateInitArgs(); + faaSHandler.faasInitHandler(initArgs); + Gson gson = new Gson(); + List args = new ArrayList<>(); + args.add("{\"codeID\":\"\",\"config\":{\"functionID\":{\"cpp\":\"\"," + + "\"python\":\"sn:cn:yrk:12345678901234561234567890123456:function:0-he-he:$latest\"}," + + "\"jodID\":\"96f2fc5e-c9ab-4d83-9aa2-89579a29ff4a\",\"logLevel\":30,\"recycleTime\":2}," + + "\"invokeType\":3,\"objectDescriptor\":{\"className\":\"\",\"functionName\":\"execute\"," + + "\"moduleName\":\"faasexecutor\",\"srcLanguage\":\"python\",\"targetLanguage\":\"python\"}}"); + TestRequestEvent testRequestEvent = new TestRequestEvent("{\"id\":\"aaa\"}", 1); + CallRequest callRequest = new CallRequest(); + callRequest.setBody(testRequestEvent); + args.add(gson.toJson(callRequest)); + String response = (String) faaSHandler.faasCallHandler(args); + CallResponse response2 = gson.fromJson(response, CallResponse.class); + Assert.assertEquals("true", response2.getBody().toString()); + } + + @Test + public void testFaaSHandlerCallWithJsonObjectBody() throws Exception { + FaaSHandler faaSHandler = new FaaSHandler(); + Gson gson = new Gson(); + List initArgs = new ArrayList<>(); + String contextConfig = gson.toJson(getContext()); + String udfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.initializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::jsonHandler\"}"; + initArgs.add(contextConfig); + initArgs.add(udfEntry); + initArgs.add(getSchedulerData()); + faaSHandler.faasInitHandler(initArgs); + + String arg1 = "{\"codeID\":\"\",\"config\":{\"functionID\":{\"cpp\":\"\"," + + "\"python\":\"sn:cn:yrk:12345678901234561234567890123456:function:0-he-he:$latest\"}," + + "\"jodID\":\"96f2fc5e-c9ab-4d83-9aa2-89579a29ff4a\",\"logLevel\":30,\"recycleTime\":2}," + + "\"invokeType\":3,\"objectDescriptor\":{\"className\":\"\",\"functionName\":\"execute\"," + + "\"moduleName\":\"faasexecutor\",\"srcLanguage\":\"python\",\"targetLanguage\":\"python\"}}"; + List callArgs1 = new ArrayList<>(); + callArgs1.add(arg1); + CallRequest callRequest1 = new CallRequest.Builder().withBody("{}").build(); + callArgs1.add(gson.toJson(callRequest1)); + CallResponse response1 = gson.fromJson(faaSHandler.faasCallHandler(callArgs1), CallResponse.class); + Assert.assertEquals("{}", response1.getBody().toString()); + + List callArgs2 = new ArrayList<>(); + callArgs2.add(arg1); + CallRequest callRequest2 = new CallRequest.Builder().withBody("{\"id\":\"aaa\"}").build(); + callArgs2.add(gson.toJson(callRequest2)); + CallResponse response2 = gson.fromJson(faaSHandler.faasCallHandler(callArgs2), CallResponse.class); + Assert.assertEquals("{\"id\":\"aaa\"}", response2.getBody().toString()); + } + + @Test + public void testFaaSHandlerCallFail() throws Exception { + Gson gson = new Gson(); + FaaSHandler faaSHandler = new FaaSHandler(); + List initArgs = generateCallTimeoutArgs(); + Assert.assertTrue(faaSHandler.faasInitHandler(initArgs).getErrorMessage().contains("processed init handler successfully.")); + + List wrongArgs = new ArrayList<>(); + String resp1 = faaSHandler.faasCallHandler(wrongArgs); + CallResponse response1 = gson.fromJson(resp1, CallResponse.class); + Assert.assertEquals(String.valueOf(FaasErrorCode.FAAS_INIT_ERROR.getCode()), response1.getInnerCode()); + + List callArgs = generateErrorCallArgs(); + UDFManager.getUDFManager().registerInstance("userCallEntry", ""); + String resp3 = faaSHandler.faasCallHandler(callArgs); + CallResponse response3 = gson.fromJson(resp3, CallResponse.class); + Assert.assertEquals("4002", response3.getInnerCode()); + } + + @Test + public void testFaaSHandlerCallUserFail() throws Exception { + FaaSHandler faaSHandler = new FaaSHandler(); + Gson gson = new Gson(); + List initArgs = new ArrayList<>(); + String contextConfig = gson.toJson(getContext()); + String udfEntry = "{\"userInitEntry\":\"com.yuanrong.executor.UserTestHandler.initializer\"," + + "\"userCallEntry\":\"com.yuanrong.executor.UserTestHandler::failedHandler\"}"; + initArgs.add(contextConfig); + initArgs.add(udfEntry); + initArgs.add(getSchedulerData()); + faaSHandler.faasInitHandler(initArgs); + + String arg1 = "{\"codeID\":\"\",\"config\":{\"functionID\":{\"cpp\":\"\"," + + "\"python\":\"sn:cn:yrk:12345678901234561234567890123456:function:0-he-he:$latest\"}," + + "\"jodID\":\"96f2fc5e-c9ab-4d83-9aa2-89579a29ff4a\",\"logLevel\":30,\"recycleTime\":2}," + + "\"invokeType\":3,\"objectDescriptor\":{\"className\":\"\",\"functionName\":\"execute\"," + + "\"moduleName\":\"faasexecutor\",\"srcLanguage\":\"python\",\"targetLanguage\":\"python\"}}"; + List callArgs1 = new ArrayList<>(); + callArgs1.add(arg1); + CallRequest callRequest1 = new CallRequest.Builder().withBody("abc").build(); + callArgs1.add(gson.toJson(callRequest1)); + CallResponse response1 = gson.fromJson(faaSHandler.faasCallHandler(callArgs1), CallResponse.class); + Assert.assertEquals("4002", response1.getInnerCode()); + + } + + @Test + public void testFaaSHandlerResponseExceedSize() throws Exception { + Gson gson = new Gson(); + FaaSHandler faaSHandler = new FaaSHandler(); + List initArgs = generateCallResponseLargeSizeArgs(); + Assert.assertTrue(faaSHandler.faasInitHandler(initArgs).getErrorMessage().contains("processed init handler successfully.")); + + List callArgs = generateCallArgs(); + String response = faaSHandler.faasCallHandler(callArgs); + CallResponse response2 = gson.fromJson(response, CallResponse.class); + Assert.assertEquals("4004", response2.getInnerCode()); + } + + @Test + public void testFaaSException() { + FaaSException ex1 = new FaaSException("this is faas exception1"); + FaaSException ex2 = new FaaSException("this is faas exception2"); + Assert.assertEquals(ex1.equals(ex2), false); + int hash1 = ex1.hashCode(); + int hash2 = ex2.hashCode(); + Assert.assertEquals(hash1 != hash2, true); + Assert.assertTrue(ex1.equals(ex1)); + Assert.assertFalse(ex1.equals(null)); + } + + @Test + public void testFaasCheckPointHandler() { + FaaSHandler faaSHandler = new FaaSHandler(); + faaSHandler.faasCheckPointHandler(null); + } + + @Test + public void testFaasRecoverHandler() { + FaaSHandler faaSHandler = new FaaSHandler(); + faaSHandler.faasRecoverHandler(null); + } + + @Test + public void testFaasShutDownHandlerDoNothing() { + FaaSHandler faaSHandler = new FaaSHandler(); + ErrorInfo response = faaSHandler.faasShutDownHandler(0); + Assert.assertTrue(response.getErrorMessage().contains("run shut down handler successfully.")); + } + + @Test + public void testFaasShutDownHandlerSuccess() { + FaaSHandler faaSHandler = new FaaSHandler(); + List initArgs = generateInitArgsWiThPreStopHandler("Success"); + faaSHandler.faasInitHandler(initArgs); + ErrorInfo response = faaSHandler.faasShutDownHandler(0); + Assert.assertTrue(response.getErrorMessage().contains("run shut down handler successfully.")); + } + + @Test + public void testFaasShutDownHandlerNotExist() { + FaaSHandler faaSHandler = new FaaSHandler(); + List initArgs = generateInitArgsWiThPreStopHandler("NoSuchMethodException"); + faaSHandler.faasInitHandler(initArgs); + ErrorInfo errorInfo = faaSHandler.faasShutDownHandler(0); + Assert.assertEquals(FaasErrorCode.ENTRY_NOT_FOUND.getCode(), errorInfo.getErrorCode().getValue()); + } + + @Test + public void testFaasShutDownHandlerFailed() { + FaaSHandler faaSHandler = new FaaSHandler(); + List initArgs = generateInitArgsWiThPreStopHandler("InvocationTargetException"); + faaSHandler.faasInitHandler(initArgs); + ErrorInfo errorInfo = faaSHandler.faasShutDownHandler(0); + Assert.assertEquals(FaasErrorCode.FUNCTION_RUN_ERROR.getCode(), errorInfo.getErrorCode().getValue()); + } + + @Test + public void testFaasSignalHandler() { + FaaSHandler faaSHandler = new FaaSHandler(); + boolean isException = false; + try { + faaSHandler.faasSignalHandler(null); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testExecute() { + boolean isException = false; + FaaSHandler faaSHandler = new FaaSHandler(); + FunctionMeta meta = FunctionMeta.newBuilder() + .setClassName("MockClass") + .setFunctionName("mockMethod") + .setSignature("()V") + .build(); + ArrayList byteBuffers = new ArrayList<>(); + ByteBuffer buffer = ByteBuffer.allocateDirect(10); + buffer.put((byte)0x01); + byteBuffers.add(buffer); + try { + ReturnType returnType = faaSHandler.execute(meta, Libruntime.InvokeType.InvokeFunction, byteBuffers); + Assert.assertTrue(returnType.getErrorInfo().getErrorMessage().contains("faas get args invalid")); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + + ArrayList byteBuffers2 = new ArrayList<>(); + ByteBuffer buffer2 = ByteBuffer.allocateDirect(10); + buffer2.put((byte)0x01); + byteBuffers2.add(buffer2); + try { + ReturnType returnType = faaSHandler.execute(meta, Libruntime.InvokeType.CreateInstance, byteBuffers2); + Assert.assertTrue(returnType.getErrorInfo().getErrorMessage().contains("faas get args invalid")); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + try { + ReturnType returnType = faaSHandler.execute(meta, Libruntime.InvokeType.InvokeFunctionStateless, + new ArrayList<>()); + Assert.assertEquals("call handler arg list is empty.", + returnType.getErrorInfo().getErrorMessage()); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + ByteBuffer buffer3 = ByteBuffer.allocate(10); + buffer3.put((byte)0x01); + byteBuffers2.add(buffer3); + try { + ReturnType returnType = faaSHandler.execute(meta, Libruntime.InvokeType.GetNamedInstanceMeta, byteBuffers2); + Assert.assertEquals("invalid invoke type", returnType.getErrorInfo().getErrorMessage()); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + } + + @Test + public void testShutdown() { + FaaSHandler faaSHandler = new FaaSHandler(); + ErrorInfo errorInfo = faaSHandler.shutdown(10); + Assert.assertEquals(ErrorCode.ERR_OK, errorInfo.getErrorCode()); + + List initArgs = generateInitArgsWiThPreStopHandler("Success"); + faaSHandler.faasInitHandler(initArgs); + ErrorInfo errorInfo1 = faaSHandler.shutdown(10); + Assert.assertEquals(ErrorCode.ERR_OK, errorInfo1.getErrorCode()); + + List initArgs2 = generateInitArgsWiThPreStopHandler("InterruptedException"); + faaSHandler.faasInitHandler(initArgs2); + ErrorInfo errorInfo2 = faaSHandler.shutdown(1); + Assert.assertEquals(FaasErrorCode.INVOKE_FUNCTION_TIMEOUT.getCode(), errorInfo2.getErrorCode().getValue()); + + Pair pair = null; + boolean isException = false; + try { + pair = faaSHandler.dumpInstance("testID"); + } catch (Exception e) { + isException = true; + } + Assert.assertFalse(isException); + + faaSHandler.loadInstance(null,null); + Assert.assertNotNull(pair); + + Assert.assertEquals(ErrorCode.ERR_OK, faaSHandler.recover().getErrorCode()); + } +} diff --git a/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestFunctionHandler.java b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestFunctionHandler.java index 71f70dc..cf5f4cb 100644 --- a/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestFunctionHandler.java +++ b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestFunctionHandler.java @@ -20,6 +20,8 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -33,6 +35,8 @@ import org.powermock.modules.junit4.PowerMockRunnerDelegate; import com.yuanrong.errorcode.ErrorCode; import com.yuanrong.errorcode.ErrorInfo; +import com.yuanrong.errorcode.ModuleCode; +import com.yuanrong.exception.YRException; import com.yuanrong.libruntime.generated.Libruntime; import com.yuanrong.libruntime.generated.Libruntime.FunctionMeta; import com.yuanrong.serialization.Serializer; @@ -56,7 +60,7 @@ public class TestFunctionHandler { * @throws Exception */ @Test - public void testCreateExeceptionWithCause() throws Exception { + public void testCreateExceptionWithCause() throws Exception { FunctionMeta meta = FunctionMeta.newBuilder().setClassName("MockClassName").setSignature("MockSignature") .setFunctionName("MockMethodName").build(); HandlerIntf handler = new FunctionHandler(); @@ -65,6 +69,20 @@ public class TestFunctionHandler { assertTrue(returnType.getErrorInfo().getErrorMessage().contains("failed to create instance due to the cause:")); } + @Test + public void testCreateWithFailedtoFindMethod() throws Exception { + FunctionHandler handler = new FunctionHandler(); + handler.classCache.put("MockClass", MockClass.class); + + FunctionMeta meta = FunctionMeta.newBuilder() + .setClassName("MockClass") + .setFunctionName("mockMethod") + .setSignature("invalidSignature") + .build(); + ReturnType returnType = handler.execute(meta, Libruntime.InvokeType.CreateInstanceStateless, new ArrayList<>()); + assertEquals(returnType.getErrorInfo().getErrorCode(), ErrorCode.ERR_USER_FUNCTION_EXCEPTION); + } + /** * Description: * The message should be in detailed when runtime failed to find Method. @@ -146,6 +164,21 @@ public class TestFunctionHandler { assertEquals("", returnType.getErrorInfo().getErrorMessage()); } + @Test + public void testReturnTypeInvalidInvoke() throws Exception { + FunctionHandler handler = new FunctionHandler(); + handler.classCache.put("MockClass", MockClass.class); + MemberModifier.field(FunctionHandler.class, "instance").set(FunctionHandler.class, new MockClass()); + FunctionMeta meta = FunctionMeta.newBuilder() + .setClassName("MockClass") + .setFunctionName("mockMethod") + .setSignature("()V") + .build(); + ArrayList args = new ArrayList(); + ReturnType returnType = handler.execute(meta, Libruntime.InvokeType.GetNamedInstanceMeta, args); + assertEquals(ErrorCode.ERR_INCORRECT_INVOKE_USAGE, returnType.getErrorInfo().getErrorCode()); + } + @Test public void loadInstanceMismatchedInputException() throws ClassNotFoundException, IOException { String clzName = "MockClass"; @@ -201,6 +234,27 @@ public class TestFunctionHandler { assertTrue(err.getErrorMessage().contains("Failed to invoke instance function")); } + @Test + public void testShutdownNoSuchMethod() throws Exception { + FunctionHandler handler = new FunctionHandler(); + handler.classCache.put("MockNoneClass", MockNoneClass.class); + MemberModifier.field(FunctionHandler.class, "instance").set(FunctionHandler.class, new MockNoneClass()); + MemberModifier.field(FunctionHandler.class, "instanceClassName").set(FunctionHandler.class, "MockNoneClass"); + ErrorInfo err = handler.shutdown(10); + assertEquals(err.getErrorCode(), ErrorCode.ERR_OK); + } + + @Test + public void testShutdownFailedMethod() throws Exception { + FunctionHandler handler = new FunctionHandler(); + handler.classCache.put("MockFailedClass", MockFailedClass.class); + MemberModifier.field(FunctionHandler.class, "instance").set(FunctionHandler.class, new MockFailedClass()); + MemberModifier.field(FunctionHandler.class, "instanceClassName") + .set(FunctionHandler.class, "MockFailedClass"); + ErrorInfo err = handler.shutdown(10); + assertEquals(err.getErrorCode(), ErrorCode.ERR_INNER_SYSTEM_ERROR); + } + @Test public void testRecover() throws Exception { boolean isException = false; @@ -223,4 +277,45 @@ public class TestFunctionHandler { ErrorInfo err = handler.recover(); assertTrue(err.getErrorMessage().contains("Failed to invoke instance function")); } + + @Test + public void testRecoverNoSuchMethod() throws Exception { + FunctionHandler handler = new FunctionHandler(); + handler.classCache.put("MockNoneClass", MockNoneClass.class); + MemberModifier.field(FunctionHandler.class, "instance").set(FunctionHandler.class, new MockNoneClass()); + MemberModifier.field(FunctionHandler.class, "instanceClassName").set(FunctionHandler.class, "MockNoneClass"); + ErrorInfo err = handler.recover(); + assertEquals(err.getErrorCode(), ErrorCode.ERR_OK); + } + + @Test + public void testRecoverFailedMethod() throws Exception { + FunctionHandler handler = new FunctionHandler(); + handler.classCache.put("MockFailedClass", MockFailedClass.class); + MemberModifier.field(FunctionHandler.class, "instance").set(FunctionHandler.class, new MockFailedClass()); + MemberModifier.field(FunctionHandler.class, "instanceClassName") + .set(FunctionHandler.class, "MockFailedClass"); + ErrorInfo err = handler.recover(); + assertEquals(err.getErrorCode(), ErrorCode.ERR_INNER_SYSTEM_ERROR); + } + + @Test + public void testProcessException() throws Exception { + FunctionHandler handler = new FunctionHandler(); + Method method = FunctionHandler.class.getDeclaredMethod("processException", Exception.class, String.class, + String.class); + method.setAccessible(true); + NoSuchMethodException noSuchMethodException = new NoSuchMethodException("testException"); + ReturnType ret1 = (ReturnType) method.invoke(handler, noSuchMethodException, "MockClass", "MockMethod"); + assertEquals(ErrorCode.ERR_INNER_SYSTEM_ERROR, ret1.getCode()); + + ReturnType ret2 = (ReturnType) method.invoke(handler, new InvocationTargetException(noSuchMethodException), + "MockClass", "MockMethod"); + assertEquals(ErrorCode.ERR_USER_FUNCTION_EXCEPTION, ret2.getCode()); + + InvocationTargetException invocationException = new InvocationTargetException( + new YRException(ErrorCode.ERR_USER_FUNCTION_EXCEPTION, ModuleCode.RUNTIME, "testException")); + ReturnType ret3 = (ReturnType) method.invoke(handler, invocationException, "MockClass", "MockMethod"); + assertEquals(ErrorCode.ERR_USER_FUNCTION_EXCEPTION, ret3.getCode()); + } } \ No newline at end of file diff --git a/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestReturnType.java b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestReturnType.java index 15bbd54..831c29b 100644 --- a/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestReturnType.java +++ b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/TestReturnType.java @@ -28,6 +28,7 @@ import java.util.ArrayList; public class TestReturnType { @Test public void testInitReturnType() { + ReturnType returnType = new ReturnType(ErrorCode.ERR_JOB_USER_CODE_EXCEPTION); ReturnType returnType1 = new ReturnType(ErrorCode.ERR_USER_CODE_LOAD, ByteBuffer.allocate(10)); ReturnType returnType2 = new ReturnType(ErrorCode.ERR_OK, ModuleCode.CORE, "test2"); ReturnType returnType3 = new ReturnType(ErrorCode.ERR_CREATE_RETURN_BUFFER, new byte[] {}); @@ -38,7 +39,7 @@ public class TestReturnType { new ArrayList<>()); returnType6.getBytes(); Assert.assertNotEquals(returnType1, returnType2); - Assert.assertNotEquals(returnType1, returnType3); + Assert.assertNotEquals(returnType, returnType3); Assert.assertNotEquals(returnType4, returnType5); } } diff --git a/api/java/yr-runtime/src/test/java/com/yuanrong/executor/UserTestHandler.java b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/UserTestHandler.java index f7ad934..841f3af 100644 --- a/api/java/yr-runtime/src/test/java/com/yuanrong/executor/UserTestHandler.java +++ b/api/java/yr-runtime/src/test/java/com/yuanrong/executor/UserTestHandler.java @@ -56,6 +56,26 @@ public class UserTestHandler { System.out.println("initialize success"); } + public void illegalArgumentInitializer(Context context) { + throw new IllegalArgumentException("initialize failed"); + } + + public void failedInitializer(Context context) { + throw new RuntimeException("initialize failed"); + } + + public void preStop(Context context) { + System.out.println("preStop success"); + } + + public void failedPreStop(Context context) { + throw new RuntimeException("preStop failed"); + } + + public void timeoutPreStop(Context context) throws InterruptedException { + Thread.sleep(2000); + } + public void timeoutInitializer(Context context) { try { Thread.sleep(3 * 1000 + 500); diff --git a/api/java/yr-runtime/src/test/java/com/yuanrong/runtime/server/TestRuntimeLogger.java b/api/java/yr-runtime/src/test/java/com/yuanrong/runtime/server/TestRuntimeLogger.java index 9f1f9b0..578ef84 100644 --- a/api/java/yr-runtime/src/test/java/com/yuanrong/runtime/server/TestRuntimeLogger.java +++ b/api/java/yr-runtime/src/test/java/com/yuanrong/runtime/server/TestRuntimeLogger.java @@ -95,8 +95,8 @@ public class TestRuntimeLogger { Assert.assertTrue(isException); Mockito.verify(mockLogger, Mockito.times(1)).info("runtime ID {}", runtimeID); - Mockito.verify(mockLogger, Mockito.times(1)).debug("current log path = {}", - String.format("%s%s%s", logDir, File.separator, runtimeID)); + Mockito.verify(mockLogger, Mockito.times(2)).debug("current log path = {}", + String.format("%s", logDir)); Mockito.verify(mockLogger, Mockito.times(2)).debug("current log level = {}", logLevel); Assert.assertNotNull(System.getProperty(ENV_LOG_LEVEL)); diff --git a/api/python/BUILD.bazel b/api/python/BUILD.bazel index 004538e..a0a56d7 100644 --- a/api/python/BUILD.bazel +++ b/api/python/BUILD.bazel @@ -136,10 +136,10 @@ genrule( ln -f -s libssl.so.1.1 $$PYTHON_CODE_DIR/yr/libssl.so ln -f -s libcrypto.so.1.1 $$PYTHON_CODE_DIR/yr/libcrypto.so ln -f -s liblitebus.so.0.0.1 $$PYTHON_CODE_DIR/yr/liblitebus.so - + ln -f -s libspdlog.so.1.12.0 $$PYTHON_CODE_DIR/yr/libspdlog.so.1 ln -f -s libspdlog.so.1 $$PYTHON_CODE_DIR/yr/libspdlog.so - [ -f $$PYTHON_CODE_DIR/yr/libacl_plugin.so ] && chmod +w $$PYTHON_CODE_DIR/yr/libacl_plugin.so + ln -f -s libspdlog.so.1 $$PYTHON_CODE_DIR/yr/libspdlog.so cd $$BASE_DIR echo "$$BASE_DIR" >> $@ diff --git a/api/python/functionsdk.py b/api/python/functionsdk.py new file mode 100644 index 0000000..52bb982 --- /dev/null +++ b/api/python/functionsdk.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""faas function sdk""" + +import yr + + +class Function(): + """Class Function""" + def __init__(self, context, function_name=None, instance_name=None) -> None: + self.__context = context + self.__function_name = function_name + self.__instance_name = instance_name + self.__function = yr.Function(function_name) + + def invoke(self, payload): + """Invoke stateless or stateful function.""" + return self.__function.invoke(payload) diff --git a/api/python/requirements.txt b/api/python/requirements.txt index e6fc58a..a254e5a 100644 --- a/api/python/requirements.txt +++ b/api/python/requirements.txt @@ -1,4 +1,4 @@ -cloudpickle==2.0.0 +cloudpickle==2.2.1 msgpack==1.0.5 protobuf==4.25.5 cython==3.0.10 diff --git a/api/python/requirements_for_py37.txt b/api/python/requirements_for_py37.txt index 3f1b471..9c0b7fa 100644 --- a/api/python/requirements_for_py37.txt +++ b/api/python/requirements_for_py37.txt @@ -1,4 +1,4 @@ -cloudpickle==2.0.0 +cloudpickle==2.2.1 msgpack==1.0.5 protobuf==3.20.0 cython==3.0.10 \ No newline at end of file diff --git a/api/python/yr/__init__.py b/api/python/yr/__init__.py index 9092fbb..7b24736 100644 --- a/api/python/yr/__init__.py +++ b/api/python/yr/__init__.py @@ -17,24 +17,6 @@ """ yr api """ - -__all__ = [ - "init", "finalize", "Config", "UserTLSConfig", - "put", "get", - "wait", "cancel", "invoke", "instance", "method", "InvokeOptions", "exit", - "Context", "GetParam", "GetParams", - "Affinity", "AffinityType", "AffinityKind", "LabelOperator", "OperatorType", - "kv_read", "kv_write", "kv_set", "kv_get", "kv_get_with_param", "kv_del", "kv_m_write_tx", - "ExistenceOpt", "WriteMode", "CacheType", "SetParam", "MSetParam", "CreateParam", "ConsistencyType", - "save_state", "load_state", "get_instance", "is_initialized", - "Gauge", "Alarm", "java_instance_class", "go_instance_class", "create_function_group", - "AlarmSeverity", "AlarmInfo", "UInt64Counter", "DoubleCounter", - "FunctionGroupOptions", "SchedulingAffinityType", "FunctionGroupContext", "ServerInfo", "DeviceInfo", - "get_function_group_context", "create_resource_group", "remove_resource_group", "ResourceGroup", - "FunctionProxy", "InstanceCreator", "InstanceProxy", "MethodProxy", "FunctionGroupHandler", - "FunctionGroupMethodProxy", "get_node_ip_address", "list_named_instances" -] - import os import ctypes @@ -64,18 +46,22 @@ for so_path in [ except OSError: pass + # E402: import not at top of file # We must load so before import datasystem, so the lint is not really useful from yr.apis import ( # noqa: E402 init, finalize, put, get, invoke, instance, wait, cancel, method, exit, + create_stream_producer, create_stream_consumer, delete_stream, kv_read, kv_write, kv_del, kv_set, kv_get, kv_get_with_param, - kv_m_write_tx, kv_write_with_param, get_instance, is_initialized, save_state, load_state, + kv_m_write_tx, kv_write_with_param, get_instance, is_initialized, + query_global_producers_num, query_global_consumers_num, save_state, load_state, cpp_function, java_function, go_function, cpp_instance_class, java_instance_class, go_instance_class, resources, create_resource_group, remove_resource_group, get_node_ip_address, list_named_instances ) + from yr.fcc import ( # noqa: E402 create_function_group, get_function_group_context ) @@ -89,9 +75,32 @@ from yr.config import ( # noqa: E402 FunctionGroupContext, ServerInfo, DeviceInfo, ResourceGroupOptions ) -from yr.affinity import Affinity, AffinityType, AffinityKind, LabelOperator, OperatorType # noqa: E402 +from yr.stream import ProducerConfig, SubscriptionConfig, Element # noqa: E402 +from yr.functionsdk.function import Function # noqa: E402 +from yr.functionsdk.context import Context # noqa: E402 +from yr.affinity import Affinity, AffinityType, AffinityKind, AffinityScope, LabelOperator, OperatorType # noqa: E402 from yr.metrics import Gauge, Alarm, UInt64Counter, DoubleCounter # noqa: E402 + from yr.decorator.function_proxy import FunctionProxy # noqa: E402 from yr.decorator.instance_proxy import ( # noqa: E402 - InstanceCreator, InstanceProxy, MethodProxy, FunctionGroupHandler, FunctionGroupMethodProxy -) + InstanceCreator, InstanceProxy, MethodProxy, FunctionGroupHandler, FunctionGroupMethodProxy) + +__all__ = [ + "init", "finalize", "Config", "UserTLSConfig", + "put", "get", + "wait", "cancel", "invoke", "instance", "method", "InvokeOptions", "exit", + "ProducerConfig", "SubscriptionConfig", "Element", + "create_stream_producer", "create_stream_consumer", "delete_stream", + "Context", "Function", "GetParam", "GetParams", + "Affinity", "AffinityType", "AffinityKind", "AffinityScope", "LabelOperator", "OperatorType", + "kv_read", "kv_write", "kv_set", "kv_get", "kv_get_with_param", "kv_del", "kv_m_write_tx", + "ExistenceOpt", "WriteMode", "CacheType", "SetParam", "MSetParam", "CreateParam", "ConsistencyType", + "save_state", "load_state", "get_instance", "is_initialized", + "query_global_producers_num", "query_global_consumers_num", + "Gauge", "Alarm", "java_instance_class", "go_instance_class", "create_function_group", + "AlarmSeverity", "AlarmInfo", "UInt64Counter", "DoubleCounter", + "FunctionGroupOptions", "SchedulingAffinityType", "FunctionGroupContext", "ServerInfo", "DeviceInfo", + "get_function_group_context", "create_resource_group", "remove_resource_group", "ResourceGroup", + "FunctionProxy", "InstanceCreator", "InstanceProxy", "MethodProxy", "FunctionGroupHandler", + "FunctionGroupMethodProxy", "get_node_ip_address", "list_named_instances" +] diff --git a/api/python/yr/affinity.py b/api/python/yr/affinity.py index a2077f7..f470ca8 100644 --- a/api/python/yr/affinity.py +++ b/api/python/yr/affinity.py @@ -17,7 +17,8 @@ """affinity""" import enum -from typing import List +from dataclasses import dataclass +from typing import List, Optional class AffinityType(enum.Enum): @@ -105,16 +106,42 @@ class LabelOperator: self.values = values if values else [] +class AffinityScope(enum.Enum): + """ + Enum for Affinity scope of instances. + """ + POD = 1 + """ + POD level instance affinity + POD 级别实例亲和 + """ + NODE = 2 + """ + NODE level instance affinity + NODE 级别实例亲和 + """ + + +@dataclass class Affinity: """ Represents an affinity. """ - def __init__(self, affinity_kind: AffinityKind, affinity_type: AffinityType, label_operators: List[LabelOperator]): - """ - affinity_kind (AffinityKind): The kind of affinity. - affinity_type (AffinityType): The type of affinity. - label_operators (List[LabelOperator]): The label operators in the affinity. - """ + #: The kind of affinity. + affinity_kind: AffinityKind + #: The type of affinity. + affinity_type: AffinityType + #: The label operators in the affinity. + label_operators: List[LabelOperator] + #: The affinity scope of instances. + affinity_scope: Optional[AffinityScope] = None + + def __init__(self, + affinity_kind: AffinityKind, + affinity_type: AffinityType, + label_operators: List[LabelOperator], + affinity_scope: Optional[AffinityScope] = None): self.affinity_kind = affinity_kind self.affinity_type = affinity_type self.label_operators = label_operators + self.affinity_scope = affinity_scope diff --git a/api/python/yr/apis.py b/api/python/yr/apis.py index 396b9d6..780437e 100644 --- a/api/python/yr/apis.py +++ b/api/python/yr/apis.py @@ -31,10 +31,11 @@ from yr.config import ClientInfo, Config, InvokeOptions from yr.config_manager import ConfigManager from yr.decorator import function_proxy, instance_proxy from yr.executor.executor import Executor -from yr.fnruntime import auto_get_cluster_access_info +from yr.fnruntime import Consumer, Producer, auto_get_cluster_access_info from yr.object_ref import ObjectRef from yr.resource_group_ref import RgObjectRef from yr.runtime import ExistenceOpt, WriteMode, CacheType, SetParam, MSetParam, CreateParam, GetParams +from yr.stream import ProducerConfig, SubscriptionConfig from yr.decorator.function_proxy import FunctionProxy from yr.decorator.instance_proxy import InstanceCreator, InstanceProxy from yr.common.utils import CrossLanguageInfo @@ -113,10 +114,12 @@ def _auto_get_cluster_access_info(conf): cluster_access_info = auto_get_cluster_access_info({ "serverAddr": conf.server_address, "dsAddr": conf.ds_address, + "inCluster": conf.in_cluster }, args) conf.server_address = cluster_access_info["serverAddr"] conf.ds_address = cluster_access_info["dsAddr"] + conf.in_cluster = cluster_access_info["inCluster"] return conf @@ -256,7 +259,7 @@ def put(obj: object, create_param: CreateParam = CreateParam()) -> ObjectRef: >>> print(yr.get(obj_ref4)) >>> 100 """ - if (isinstance(obj, (bytes, bytearray, memoryview)) and len(obj) == 0): + if obj is None or (isinstance(obj, (bytes, bytearray, memoryview)) and len(obj) == 0): raise ValueError("value is None or has zero length") # Make sure that the value is not an object ref. if isinstance(obj, ObjectRef): @@ -273,6 +276,9 @@ def get(obj_refs: Union["ObjectRef", List, "RgObjectRef"], timeout: int = consta Retrieve the value of an object stored in the backend based on the object's key. The interface call will block until the object's value is obtained or a timeout occurs. + Note: + yr.get() uniformly returns a memoryview pointer for bytes, bytearray, and memoryview types. + Args: obj_refs (ObjectRef, List[ObjectRef]): The object_ref of the object in the data system. timeout (int, optional): The timeout value. A value of -1 means wait indefinitely. Limit: -1, (0, ∞). @@ -672,6 +678,121 @@ def exit() -> None: runtime_holder.global_runtime.get_runtime().exit() +@check_initialized +def create_stream_producer(stream_name: str, config: ProducerConfig) -> Producer: + """ + Create a producer. + + Args: + stream_name (str): The name of the stream. + The length must be less than 256 characters and contain only the following characters + `(a-zA-Z0-9\\~\\.\\-\\/_!@#%\\^&*()\\+\\=;:)` . + config (ProducerConfig): The configuration of the producer. + + Returns: + Producer. + + Raises: + RuntimeError: If creating the Producer fails. + + Examples: + >>> try: + ... producer_config = ProducerConfig( + ... delay_flush_time=5, + ... page_size=1024 * 1024, + ... max_stream_size=1024 * 1024 * 1024, + ... auto_clean_up=True, + ... ) + ... stream_producer = create_stream_producer("streamName", producer_config) + ... except RuntimeError as exp: + ... # 处理异常 + ... pass + """ + return runtime_holder.global_runtime.get_runtime().create_stream_producer(stream_name, config) + + +@check_initialized +def create_stream_consumer(stream_name: str, config: SubscriptionConfig) -> Consumer: + """ + Create a consumer. + + Args: + stream_name (str): The name of the stream. + The length must be less than 256 characters and contain only the following characters + `(a-zA-Z0-9\\~\\.\\-\\/_!@#%\\^&*()\\+\\=;:)` . + config (SubscriptionConfig): The configuration of the consumer. + + Returns: + Consumer. + + Raises: + RuntimeError: If creating the Consumer fails. + + Examples: + >>> try: + ... config = SubscriptionConfig("subName", SubscriptionType.STREAM) + ... consumer = create_stream_consumer("streamName", config) + ... except RuntimeError as exp: + ... pass + """ + return runtime_holder.global_runtime.get_runtime().create_stream_consumer(stream_name, config) + + +@check_initialized +def query_global_producers_num(stream_name: str) -> int: + """ + 查询流生产者数量 + + Args: + stream_name: 流名称 + + Returns: + 数量 + """ + return runtime_holder.global_runtime.get_runtime().query_global_producers_num(stream_name) + + +@check_initialized +def query_global_consumers_num(stream_name: str) -> int: + """ + 查询流消费者数量 + + Args: + stream_name: 流名称 + + Returns: + 数量 + """ + return runtime_holder.global_runtime.get_runtime().query_global_consumers_num(stream_name) + + +@check_initialized +def delete_stream(stream_name: str) -> None: + """ + Delete the data stream. When the global count of producers and consumers is 0, the data stream is no longer in use, + and the metadata related to the data stream on each worker and master is cleaned up. + This function can be called on any Host node. + + Args: + stream_name (str): The name of the stream. + The length must be less than 256 characters and contain only the following characters + `(a-zA-Z0-9\\~\\.\\-\\/_!@#%\\^&*()\\+\\=;:)` . + + Returns: + None. + + Raises: + RuntimeError: If the stream fails to be deleted in the data system, an exception will be thrown. + + Examples: + >>> try: + ... delete_stream("streamName") + ... except RuntimeError as exp: + ... pass + """ + runtime_holder.global_runtime.get_runtime().delete_stream(stream_name) + + @check_initialized def kv_write(key: str, value: bytes, existence: ExistenceOpt = ExistenceOpt.NONE, write_mode: WriteMode = WriteMode.NONE_L2_CACHE, ttl_second: int = constants.DEFAULT_NO_TTL_LIMIT, diff --git a/api/python/yr/cluster_mode_runtime.py b/api/python/yr/cluster_mode_runtime.py index 5546255..e35bc60 100644 --- a/api/python/yr/cluster_mode_runtime.py +++ b/api/python/yr/cluster_mode_runtime.py @@ -24,12 +24,13 @@ from yr.err_type import ErrorCode, ErrorInfo, ModuleCode from yr.common.types import InvokeArg, GroupInfo from yr.config import InvokeOptions from yr.config_manager import ConfigManager -from yr.fnruntime import Fnruntime, SharedBuffer -from yr.libruntime_pb2 import FunctionMeta +from yr.fnruntime import Consumer, Fnruntime, Producer, SharedBuffer +from yr.libruntime_pb2 import ApiType, FunctionMeta from yr.common.utils import GaugeData, UInt64CounterData, DoubleCounterData from yr.object_ref import ObjectRef from yr.runtime import Runtime, AlarmInfo, SetParam, MSetParam, CreateParam, GetParams from yr.serialization import Serialization +from yr.stream import ProducerConfig, SubscriptionConfig from yr.accelerate.shm_broadcast import Handle _logger = logging.getLogger(__name__) @@ -236,6 +237,8 @@ class ClusterModeRuntime(Runtime): :return: None """ self._check_init() + if func_meta.apiType == ApiType.Faas: + return self.libruntime.invoke_by_name(func_meta, self._package_python_args(args, True), opt, return_nums) return self.libruntime.invoke_by_name(func_meta, self._package_python_args(args), opt, return_nums, group_info) def create_instance(self, func_meta: FunctionMeta, args: List[Any], opt: InvokeOptions, @@ -320,6 +323,55 @@ class ClusterModeRuntime(Runtime): self.libruntime.finalize() self.__enable_flag = False + def create_stream_producer(self, stream_name: str, config: ProducerConfig) -> Producer: + """ + create stream producer + :param stream_name: stream name + :param config: ProducerConfig + :return: producer + """ + if config.max_stream_size < 0: + raise RuntimeError(f"Invalid parameter, max_stream_size: {config.max_stream_size}, expect >= 0") + if config.retain_for_num_consumers < 0: + raise RuntimeError( + f"Invalid parameter, retain_for_num_consumers: {config.retain_for_num_consumers}, expect >= 0") + if config.reserve_size < 0: + raise RuntimeError(f"Invalid parameter, reserve_size: {config.reserve_size}, expect >= 0") + return self.libruntime.create_stream_producer(stream_name, config) + + def create_stream_consumer(self, stream_name: str, config: SubscriptionConfig) -> Consumer: + """ + create stream consumer + :param stream_name: stream name + :param config: SubscriptionConfig + :return: consumer + """ + return self.libruntime.create_stream_consumer(stream_name, config) + + def delete_stream(self, stream_name: str) -> None: + """ + delete stream + :param stream_name: stream name + :return: None + """ + self.libruntime.delete_stream(stream_name) + + def query_global_producers_num(self, stream_name: str) -> int: + """ + query global producers num + :param stream_name: stream name + :return: producers num + """ + return self.libruntime.query_global_producers_num(stream_name) + + def query_global_consumers_num(self, stream_name: str) -> int: + """ + query global consumers num + :param stream_name: stream name + :return: consumers num + """ + return self.libruntime.query_global_consumers_num(stream_name) + def get_real_instance_id(self, instance_id: str) -> str: """ get real instance id @@ -503,6 +555,24 @@ class ClusterModeRuntime(Runtime): """ self.libruntime.set_alarm(name, description, info) + def peek_object_ref_stream(self, generator_id, blocking=True, timeout_ms=-1): + """ + peek object reference stream. + Args: + timeout_ms int. + generator_id str. + Returns: + object_id + """ + self._check_init() + result = self.libruntime.peek_object_ref_stream(generator_id, blocking, timeout_ms) + if not isinstance(result, str): + objects = Serialization().deserialize(result) + for obj in objects: + if isinstance(obj, YRInvokeError): + raise obj.origin_error() + return result + def generate_group_name(self) -> str: """ generate group name. @@ -597,17 +667,21 @@ class ClusterModeRuntime(Runtime): """ return self.libruntime.add_return_object(obj_ids) - def _package_python_args(self, args_list): + def _package_python_args(self, args_list, is_faas=False): """package python args""" args_list_new = [] for arg in args_list: if isinstance(arg, ObjectRef): invoke_arg = InvokeArg(buf=bytes(), is_ref=True, obj_id=arg.id, nested_objects=set()) else: - serialized_arg = Serialization().serialize(arg) - invoke_arg = InvokeArg(buf=None, is_ref=False, obj_id="", - nested_objects=set([ref.id for ref in serialized_arg.nested_refs]), - serialized_obj=serialized_arg) + if is_faas: + invoke_arg = InvokeArg(buf=arg, is_ref=False, obj_id="", + nested_objects=set()) + else: + serialized_arg = Serialization().serialize(arg) + invoke_arg = InvokeArg(buf=None, is_ref=False, obj_id="", + nested_objects=set([ref.id for ref in serialized_arg.nested_refs]), + serialized_obj=serialized_arg) args_list_new.append(invoke_arg) return args_list_new diff --git a/api/python/yr/code_manager.py b/api/python/yr/code_manager.py index 7f9c23c..c85afed 100644 --- a/api/python/yr/code_manager.py +++ b/api/python/yr/code_manager.py @@ -18,17 +18,29 @@ import importlib.util import os +import re import sys import threading from typing import Callable, List from yr import log -from yr.common import constants +from yr.common import constants, utils from yr.common.singleton import Singleton -from yr.err_type import ErrorInfo +from yr.err_type import ErrorCode, ErrorInfo, ModuleCode +from yr.functionsdk.error_code import FaasErrorCode from yr.libruntime_pb2 import LanguageType _DEFAULT_ADMIN_FUNC_PATH = "/adminfunc/" +_MAX_FAAS_ENTRY_NUMS = 3 +_MIN_FAAS_ENTRY_NUMS = 2 + + +def _are_faas_entries(code_paths: List[str]) -> bool: + if len(code_paths) > _MAX_FAAS_ENTRY_NUMS or len(code_paths) < _MIN_FAAS_ENTRY_NUMS: + return False + match_re = re.match(constants.PATTERN_FAAS_ENTRY, code_paths[constants.INDEX_SECOND]) is not None + # For example, '/dcache/init.d/bucket-id' is also not a valid faas entry. + return match_re @Singleton @@ -52,21 +64,6 @@ class CodeManager: self.deploy_dir = os.environ.get(constants.ENV_KEY_FUNCTION_LIBRARY_PATH) self.load_code_from_datasystem_func: Callable = None - @staticmethod - def load_functions(code_paths: List[str]) -> ErrorInfo: - """ - Load code paths and return ErrorInfo object. - - Args: - code_paths (List[str]): List of paths to load. - - Returns: - ErrorInfo: Error information for any errors encountered during loading. - """ - for code_path in code_paths: - sys.path.insert(constants.INDEX_FIRST, code_path) - return ErrorInfo() - def clear(self): """clear""" self.code_map.clear() @@ -179,6 +176,77 @@ class CodeManager: code_dir, module_name, code_key, entry_name) return code + def load_functions(self, code_paths: List[str]) -> ErrorInfo: + """ + Load code paths and return ErrorInfo object. + + Args: + code_paths (List[str]): List of paths to load. + + Returns: + ErrorInfo: Error information for any errors encountered during loading. + """ + if _are_faas_entries(code_paths): + # to judge 'code_paths' are FaaS entries. + for code_path, code_key in zip(code_paths, [constants.KEY_USER_INIT_ENTRY, constants.KEY_USER_CALL_ENTRY, + constants.KEY_USER_SHUT_DOWN_ENTRY]): + with self.__lock: + self.entry_map[code_key] = code_path + error_info = self.__load_faas_entry(code_path, code_key) + if error_info.error_code != ErrorCode.ERR_OK: + return error_info + else: + for code_path in code_paths: + sys.path.insert(constants.INDEX_FIRST, code_path) + return ErrorInfo() + + def __load_faas_entry(self, user_entry: str, code_key: str): + """ + load module and the entry code, throw RuntimeError if failed. + """ + user_hook_length = 2 + log.get_logger().debug("Faas load module and entry [%s] from [%s]", user_entry, self.custom_handler) + user_hook_splits = user_entry.rsplit(".", maxsplit=1) if isinstance(user_entry, str) else None + if len(user_hook_splits) != user_hook_length: + if code_key == constants.KEY_USER_INIT_ENTRY: + return ErrorInfo() + msg = convert_response_to_jsonstr( + "User hook not satisfy requirement, expect: xxx.xxx", + FaasErrorCode.INIT_FUNCTION_FAIL) + return ErrorInfo(ErrorCode.ERR_USER_CODE_LOAD, ModuleCode.RUNTIME, msg) + + user_module, user_entry = user_hook_splits[0], user_hook_splits[1] + log.get_logger().debug("User module: %s, entry: %s", user_module, user_entry) + + try: + user_code = self.load_code_from_local(self.custom_handler, user_module, user_entry, code_key) + except ValueError as exp: + log.get_logger().error("Missing user module: %s, exception: %s", user_entry, exp) + msg = convert_response_to_jsonstr(f"Missing user module: {user_entry}, err: {exp}", + FaasErrorCode.ENTRY_EXCEPTION) + return ErrorInfo(ErrorCode.ERR_USER_CODE_LOAD, ModuleCode.RUNTIME, msg) + except ImportError as exp: + log.get_logger().error("Failed to import user module: %s, exception: %s", user_entry, exp) + msg = convert_response_to_jsonstr(f"Failed to import user module: {user_entry}, err: {exp}", + FaasErrorCode.ENTRY_EXCEPTION) + return ErrorInfo(ErrorCode.ERR_USER_CODE_LOAD, ModuleCode.RUNTIME, msg) + except SyntaxError as exp: + log.get_logger().error("Failed to load user code: %s, exception: %s", user_entry, exp) + msg = convert_response_to_jsonstr("Failed to load user code. There is syntax error in user code: " + f"{user_entry}, err: {exp}", FaasErrorCode.ENTRY_EXCEPTION) + return ErrorInfo(ErrorCode.ERR_USER_CODE_LOAD, ModuleCode.RUNTIME, msg) + except Exception as exp: + log.get_logger().error("Failed to load user code: %s, exception: %s", user_entry, exp) + msg = convert_response_to_jsonstr(f"Failed to load user code: {user_entry}, err: {exp}", + FaasErrorCode.ENTRY_EXCEPTION) + return ErrorInfo(ErrorCode.ERR_USER_CODE_LOAD, ModuleCode.RUNTIME, msg) + + if user_code is None: + log.get_logger().error("Missing user entry. %s", user_entry) + msg = convert_response_to_jsonstr(f"Missing user entry. {user_entry}", FaasErrorCode.ENTRY_EXCEPTION) + return ErrorInfo(ErrorCode.ERR_USER_CODE_LOAD, ModuleCode.RUNTIME, msg) + return ErrorInfo() + def __load_module(self, code_dir, module_name): """load module using cache""" if code_dir is None: @@ -220,3 +288,13 @@ class CodeManager: raise self.module_cache[file_path] = module return module + + +def convert_response_to_jsonstr(message: str, status_code: FaasErrorCode) -> str: + """Method transform_call_response_to_str""" + message = "" if message is None else message + result = dict( + message=message, + errorCode=str(status_code.value) + ) + return utils.to_json_string(result) diff --git a/api/python/yr/common/constants.py b/api/python/yr/common/constants.py index c27ead9..0f15d9c 100644 --- a/api/python/yr/common/constants.py +++ b/api/python/yr/common/constants.py @@ -51,6 +51,7 @@ ENV_KEY_ENV_DELEGATE_DOWNLOAD = "ENV_DELEGATE_DOWNLOAD" ENV_KEY_LD_LIBRARY_PATH = "LD_LIBRARY_PATH" ENV_KEY_FUNCTION_LIBRARY_PATH = "YR_FUNCTION_LIB_PATH" +PATTERN_FAAS_ENTRY = r'^[^/]*\.[^/]*$' # conatains only one '.' and without '/' KEY_USER_INIT_ENTRY = "userInitEntry" KEY_USER_CALL_ENTRY = "userCallEntry" KEY_USER_SHUT_DOWN_ENTRY = "userShutDownEntry" @@ -70,3 +71,5 @@ class Metadata(IntEnum): CROSS_LANGUAGE = 1 PYTHON = 2 BYTES = 3 + MEMORYVIEW = 4 + BYTEARRAY = 5 diff --git a/api/python/yr/compiled_dag_ref.py b/api/python/yr/compiled_dag_ref.py new file mode 100644 index 0000000..6893416 --- /dev/null +++ b/api/python/yr/compiled_dag_ref.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Any, List, Optional + +import yr +from yr.exception import ( + GetTimeoutError, + YRChannelError, + YRChannelTimeoutError, + YRInvokeError, +) + + +def _process_return_vals(return_vals: List[Any], return_single_output: bool): + """ + Process return values for return to the DAG caller. Any exceptions found in + return_vals will be raised. If return_single_output=True, it indicates that + the original DAG did not have a MultiOutputNode, so the DAG caller expects + a single return value instead of a list. + """ + # Check for exceptions. + if isinstance(return_vals, Exception): + raise return_vals + + for val in return_vals: + if isinstance(val, YRInvokeError): + raise val + + if return_single_output: + if len(return_vals) != 1: + raise ValueError("The DAG caller expected a single return value, but got multiple.") + return return_vals[0] + + return return_vals + + +class CompiledDAGRef: + """ + A reference to a compiled DAG execution result. + + This is a subclass of ObjectRef and resembles ObjectRef. For example, + similar to ObjectRef, yr.get() can be called on it to retrieve the result. + However, there are several major differences: + 1. yr.get() can only be called once per CompiledDAGRef. + 2. yr.wait() is not supported. + 3. CompiledDAGRef cannot be copied, deep copied, or pickled. + 4. CompiledDAGRef cannot be passed as an argument to another task. + """ + + def __init__( + self, + dag: "yr.dag.CompiledDAG", + execution_index: int, + channel_index: Optional[int] = None, + ): + """ + Args: + dag: The compiled DAG that generated this CompiledDAGRef. + execution_index: The index of the execution for the DAG. + A DAG can be executed multiple times, and execution index + indicates which execution this CompiledDAGRef corresponds to. + actor_execution_loop_refs: The actor execution loop refs that + are used to execute the DAG. This can be used internally to + check the task execution errors in case of exceptions. + channel_index: The index of the DAG's output channel to fetch + the result from. A DAG can have multiple output channels, and + channel index indicates which channel this CompiledDAGRef + corresponds to. If channel index is not provided, this CompiledDAGRef + wraps the results from all output channels. + + """ + self._dag = dag + self._execution_index = execution_index + self._channel_index = channel_index + # Whether yr.get() was called on this CompiledDAGRef. + self._yr_get_called = False + self._dag_output_channels = dag.dag_output_channels + + def __str__(self): + return ( + f"CompiledDAGRef({self._dag.get_id()}, " + f"execution_index={self._execution_index}, " + f"channel_index={self._channel_index})" + ) + + def __del__(self): + # If the dag is already teardown, it should do nothing. + if self._dag.is_teardown: + return + + if self._yr_get_called: + # get() was already called, no further cleanup is needed. + return + + self._dag._delete_execution_results(self._execution_index, self._channel_index) + + def get(self, timeout: Optional[float] = None): + """ + Args: + timeout (float, optional): The timeout value. A value of -1 means wait indefinitely. Limit: -1, (0, ∞). + """ + if self._yr_get_called: + raise ValueError( + "yr.get() can only be called once " + "on a CompiledDAGRef, and it was already called." + ) + + self._yr_get_called = True + try: + self._dag.execute_until( + self._execution_index, self._channel_index, timeout + ) + return_vals = self._dag.get_execution_results( + self._execution_index, self._channel_index + ) + except YRChannelTimeoutError: + raise + except YRChannelError: + # If we get a channel error, we'd like to call yr.get() on + # the actor execution loop refs to check if this is a result + # of task execution error which could not be passed down + # (e.g., when a pure NCCL channel is used, it is only + # able to send tensors, but not the wrapped exceptions). + # In this case, we'd like to raise the task execution error + # (which is the actual cause of the channel error) instead + # of the channel error itself. + # actor task refs have errors. + actor_execution_loop_refs = list(self._dag.worker_task_refs.values()) + try: + yr.get(actor_execution_loop_refs, timeout=10) + except GetTimeoutError as timeout_error: + raise Exception( + "Timed out when getting the actor execution loop exception. " + "This should not happen, please file a GitHub issue." + ) from timeout_error + except Exception as execution_error: + # Use 'from None' to suppress the context of the original + # channel error, which is not useful to the user. + raise execution_error from None + except Exception as e: + raise e + return _process_return_vals(return_vals, True) + + +class CompiledDAGFuture: + """ + A reference to a compiled DAG execution result, when executed with asyncio. + This differs from CompiledDAGRef in that `await` must be called on the + future to get the result, instead of `yr.get()`. + + This resembles async usage of ObjectRefs. For example, similar to + ObjectRef, `await` can be called directly on the CompiledDAGFuture to + retrieve the result. However, there are several major differences: + 1. `await` can only be called once per CompiledDAGFuture. + 2. yr.wait() is not supported. + 3. CompiledDAGFuture cannot be copied, deep copied, or pickled. + 4. CompiledDAGFuture cannot be passed as an argument to another task. + """ + + def __init__( + self, + dag: "yr.dag.CompiledDAG", + execution_index: int, + fut: "asyncio.Future", + channel_index: Optional[int] = None, + ): + self._dag = dag + self._execution_index = execution_index + self._fut = fut + self._channel_index = channel_index + + def __str__(self): + return ( + f"CompiledDAGFuture({self._dag.get_id()}, " + f"execution_index={self._execution_index}, " + f"channel_index={self._channel_index})" + ) + + def __await__(self): + if self._fut is None: + raise ValueError( + "CompiledDAGFuture can only be awaited upon once, and it has " + "already been awaited upon." + ) + + # NOTE: If the object is zero-copy deserialized, then it will + # stay in scope as long as this future is in scope. Therefore, we + # delete self._fut here before we return the result to the user. + fut = self._fut + self._fut = None + + if not self._dag.has_execution_results(self._execution_index): + result = yield from fut.__await__() + self._dag._max_finished_execution_index += 1 + self._dag._cache_execution_results(self._execution_index, result) + + return_vals = self._dag.get_execution_results( + self._execution_index, self._channel_index + ) + return _process_return_vals(return_vals, True) + + def __del__(self): + if self._dag.is_teardown: + return + + if self._fut is None: + # await() was already called, no further cleanup is needed. + return + + self._dag.delete_execution_results(self._execution_index, self._channel_index) diff --git a/api/python/yr/config.py b/api/python/yr/config.py index 6c2dfdd..d103d8f 100644 --- a/api/python/yr/config.py +++ b/api/python/yr/config.py @@ -21,7 +21,7 @@ import dataclasses import json from dataclasses import asdict, dataclass, field from typing import Dict, List, Union, Optional, get_origin, Any -from enum import IntEnum +from enum import Enum, IntEnum from yr.affinity import Affinity _DEFAULT_CONNECTION_NUMS = 100 @@ -30,6 +30,7 @@ _DEFAULT_MAX_TASK_INSTANCE_NUM = -1 _DEFAULT_MAX_CONCURRENCY_CREATE_NUM = 100 _DEFAULT_CONCURRENCY = 1 _DEFAULT_RECYCLE_TIME = 2 +_DEFAULT_HTTP_IOC_THREADS_NUM = 400 _DEFAULT_RPC_TIMOUT = 30 * 60 _MAX_INT = 0x7FFFFFFF _MIN_INT = 0 @@ -56,7 +57,14 @@ class UserTLSConfig: @dataclass class DeploymentConfig: """ - Auto Deployment Configuration Class. + AutoDeploymentConfig + + Attributes: + cpu(str): cpu acquired, the unit is millicpu + mem(str): mem acquiored (MB) + datamem(str): data system mem acquired (MB) + spill_path(str): spill path, when out of memory will flush data to disk + spill_size(str): spill size limit (MB) """ cpu: int = 0 mem: int = 0 @@ -75,6 +83,10 @@ class Config: function_id: str = "" #: Cpp function id which you deploy, get default by env `YR_CPP_FUNCID`. cpp_function_id: str = "" + #: Use default function for cpp. + cpp_auto_function_name: str = "" + #: Function name which need in runtime. + function_name: str = "" #: System cluster address, get default by env `YR_SERVER_ADDRESS`. server_address: str = "" #: DataSystem address, get default by env `YR_DS_ADDRESS`. @@ -133,13 +145,18 @@ class Config: certificate_file_path: str = "" #: Server certificate file path. verify_file_path: str = "" + #: Client private key encryption password. + private_key_paaswd: str = "" + #: HTTP link worker thread. + http_ioc_threads_num: int = _DEFAULT_HTTP_IOC_THREADS_NUM #: Server name, used to identify and connect to a specific server instance. server_name: str = "" #: Namespace, used to organize and isolate configurations or resources. ns: str = "" + tenant_id: str = "" #: Whether to enable metric collection. ``False`` indicates disabled, and ``True`` indicates enabled. - #: The default value is ``False``. This takes effect only when called in the cluster. - enable_metrics: bool = False + #: The default value is ``True``. This takes effect only when called in the cluster. + enable_metrics: bool = True #: Used to set custom environment variables for the runtime. Currently, only `LD_LIBRARY_PATH` is supported. custom_envs: Dict[str, str] = field(default_factory=dict) #: Function master address list. @@ -161,6 +178,13 @@ class Config: runtime_private_key_path: str = "" num_cpus: Optional[int] = None runtime_env: Optional[Dict[str, Any]] = None + #: If ``True``, the output from all of the job processes on all nodes will be directed to the driver, + #: default is ``False``. + log_to_driver: bool = False + #: If ``True``, deduplicates logs that appear redundantly across multiple processes, default True. + #: The first instance of each log message is always immediately printed. However, subsequent log + #: messages of the same pattern are buffered for up to five seconds and printed in batch. + dedup_logs: bool = True @dataclass @@ -289,7 +313,7 @@ class FunctionGroupContext: #: Server info list for inter-instance communication. #: Default: empty list. - server_list: List[ServerInfo] = field(default_factory=list) + server_list: List['ServerInfo'] = field(default_factory=list) #: Name of the device used by this function instance, e.g., NPU/Ascend910B. #: Default: empty string. @@ -392,6 +416,10 @@ class InvokeOptions: #: Labels of instance labels: List[str] = field(default_factory=list) + #: Affinity of instance + affinity: Dict[str, str] = field(default_factory=dict) + #: Specify the name of the model used by the heterogeneous function. + device: Device = field(default_factory=Device) #: Specify the time when the invoke call of the desired heterogeneous function is completed. max_invoke_latency: int = 5000 #: Specify the minimum number of instances for a stateless function. @@ -411,6 +439,8 @@ class InvokeOptions: #: Set affinity condition list. schedule_affinities: List[Affinity] = field(default_factory=dict) + #: Whether to enable data affinity scheduling. + is_data_affinity: bool = False #: Set whether to enable weak affinity priority scheduling. If enabled, when multiple weak affinity conditions are #: passed, match and score them in order. Scheduling is successful as soon as one condition is met. preferred_priority = True @@ -478,6 +508,10 @@ class InvokeOptions: * `working_dir` configure the code path of the job. * `env_vars` configure process-level environment variables. ``runtime_env = {"env_vars":{"OMP_NUM_THREADS": "32", "TF_WARNINGS": "none"}}`` + * `shared_dir` supports configuring a shared directory for some instance, with yr managing the lifecycle of this + shared directory. `shared_dir` supports two fields: name and TTL. The name field only allows numbers, letters, + "-", and "_". The TTL supports integers greater than 0 and less than INTMAX. + ``runtime_env = {"shared_dir":{"name": "user_define", "TTL": 5}}`` * Constraints of `runtime_env`: * The keys supported by runtime_env are `conda`, `env_vars`, `pip`, `working_dir`. Other keys will not take effect and will not cause errors. @@ -498,8 +532,28 @@ class InvokeOptions: the configuration in `InvokeOptions.env_vars` will be used. * If `InvokeOptions.runtime_env["working_dir"]` is configured, use this configuration, otherwise, use `YR.Config.working_dir` and finally use the configuration in `InvokeOptions.env_vars`. - * If you use conda, you need to specify the environment variable `YR_CONDA_HOME` to point to installation path. - """ + * If you use conda, you need to specify the environment variable `YR_CONDA_HOME` to point to installation path. + * `shared_dir` has the following constraints: + 1. It is not recommended to configure different TTL for the same shared directory. + 2. The minimum cleanup interval for shared directories is 5 seconds. + 3. When multiple yr Agents are deployed on the same node, each Agent must be configured with + different root directory to prevent conflicts in shared directory management. + """ + + #: Whether an instance can be preempted is effective only in the priority scenario (when the maxPriority + #: configuration item deployed by YuanRong is greater than ``0``). The default value is ``False``. + preempted_allowed: bool = False + + #: The priority of an instance is determined by its value. The higher the value, the higher the priority. + #: A high-priority instance can preempt a low-priority instance that is configured as `preempted_allowed = True`. + #: It only takes effect in priority scenarios (scenarios where the maxPriority configuration item of YuanRong + #: deployment is greater than ``0``). The minimum value of `instance_priority` is ``0`` and the maximum value + #: is the maxPriority configuration of YuanRong deployment. The default is ``0``. + instance_priority: int = 0 + + #: The scheduling timeout time of an instance. Unit: milliseconds. Value range: + #: [-1, the maximum value of the int type]. Default value: ``30000``. + schedule_timeout_ms: int = 30000 def check_options_valid(self): """ diff --git a/api/python/yr/config/python-runtime-log.json b/api/python/yr/config/python-runtime-log.json index 94302b7..13fdbe9 100644 --- a/api/python/yr/config/python-runtime-log.json +++ b/api/python/yr/config/python-runtime-log.json @@ -3,7 +3,7 @@ "disable_existing_loggers": false, "formatters": { "extra": { - "format": "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] [%(podname)s %(thread)d] %(message)s" + "format": "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] [%(podname)s %(thread)d] [%(runtime_id)s] %(message)s" } }, "handlers": { diff --git a/api/python/yr/config_manager.py b/api/python/yr/config_manager.py index b629c4d..134ac6a 100644 --- a/api/python/yr/config_manager.py +++ b/api/python/yr/config_manager.py @@ -26,6 +26,7 @@ _DEFAULT_CLUSTER_PORT = "31222" _DEFAULT_IN_CLUSTER_CLUSTER_PORT = "21003" _DEFAULT_DS_PORT = "31501" _DEFAULT_DS_PORT_OUTER = "31222" +_DEFAULT_HTTP_IOC_THREADS_NUM = 200 _DEFAULT_RPC_TIMOUT = 30 * 60 _URN_LENGTH = 7 @@ -58,7 +59,7 @@ class ConfigManager: self.__ds_address = "" self.__connection_nums = None self.__log_level = logging.WARNING - self.__in_cluster = True + self.__in_cluster = False self.__deployment_config = DeploymentConfig() self.tls_config = None self.meta_config = None @@ -79,9 +80,11 @@ class ConfigManager: self.private_key_path = "" self.certificate_file_path = "" self.verify_file_path = "" + self.private_key_paaswd = "" + self.http_ioc_threads_num = _DEFAULT_HTTP_IOC_THREADS_NUM self.server_name = "" self.ns = "" - self.enable_metrics = False + self.enable_metrics = True self.master_add_list = [] self.working_dir = "" self.runtime_public_key_path = "" @@ -91,6 +94,8 @@ class ConfigManager: self._num_cpus = 0 self.runtime_env = "" self.namespace = "" + self.log_to_driver = False + self.dedup_logs = False @property def deployment_config(self) -> DeploymentConfig: @@ -273,10 +278,14 @@ class ConfigManager: self.load_paths = conf.load_paths self.custom_envs = conf.custom_envs self.rpc_timeout = conf.rpc_timeout + self.tenant_id = conf.tenant_id self.enable_mtls = conf.enable_mtls self.private_key_path = conf.private_key_path + self.private_key_paaswd = conf.private_key_paaswd + conf.private_key_paaswd = "" self.certificate_file_path = conf.certificate_file_path self.verify_file_path = conf.verify_file_path + self.http_ioc_threads_num = conf.http_ioc_threads_num self.server_name = conf.server_name self.ns = conf.ns self.working_dir = conf.working_dir @@ -286,6 +295,8 @@ class ConfigManager: self.runtime_private_key_path = conf.runtime_private_key_path self.num_cpus = conf.num_cpus self.runtime_env = conf.runtime_env + self.log_to_driver = conf.log_to_driver + self.dedup_logs = conf.dedup_logs def get_function_id_by_language(self, language): """ diff --git a/api/python/yr/decorator/function_proxy.py b/api/python/yr/decorator/function_proxy.py index aa2b3d7..c2049cc 100644 --- a/api/python/yr/decorator/function_proxy.py +++ b/api/python/yr/decorator/function_proxy.py @@ -34,6 +34,7 @@ from yr.config import InvokeOptions from yr.libruntime_pb2 import FunctionMeta, LanguageType from yr.object_ref import ObjectRef from yr.runtime_holder import global_runtime +from yr.generator import ObjectRefGenerator _logger = logging.getLogger(__name__) @@ -265,6 +266,9 @@ class FunctionProxy: for i in obj_list: objref_list.append(ObjectRef(i, need_incre=False)) + if self._is_generator: + return ObjectRefGenerator(objref_list[0]) + return objref_list[0] if return_nums == 1 else objref_list def _options_wrapper(self, invoke_options: InvokeOptions): diff --git a/api/python/yr/decorator/instance_proxy.py b/api/python/yr/decorator/instance_proxy.py index 26c9344..42ea8fe 100644 --- a/api/python/yr/decorator/instance_proxy.py +++ b/api/python/yr/decorator/instance_proxy.py @@ -28,6 +28,7 @@ from typing import List import yr from yr import signature from yr.code_manager import CodeManager +from yr.generator import ObjectRefGenerator from yr.common import constants, utils from yr.common.types import GroupInfo from yr.config import InvokeOptions, function_group_enabled @@ -289,6 +290,7 @@ class InstanceCreator: """ name = actor_options.get("name") namespace = actor_options.get("namespace") + lifecycle = actor_options.get("lifetime") if name is not None: if not isinstance(name, str): raise TypeError( @@ -301,9 +303,26 @@ class InstanceCreator: if namespace == "": raise ValueError('"" is not a valid namespace. ' "Pass None to not specify a namespace.") + if lifecycle is not None: + if not isinstance(lifecycle, str): + raise TypeError( + f"lifetime must be None or a string, got: '{type(lifecycle)}'.") + if lifecycle != "detached": + raise ValueError(f"lifetime is only support detached") + self.__invoke_options__.custom_extensions["lifecycle"] = lifecycle self.__invoke_options__.name = name self.__invoke_options__.namespace = namespace + + if "runtime_env" in actor_options: + if "env_vars" in actor_options["runtime_env"]: + self.__invoke_options__.env_vars = actor_options[ + "runtime_env"]["env_vars"] + if "resources" in actor_options: + resources = actor_options.get("resources") + if not isinstance(resources, dict): + raise TypeError("resources must be None or a string.") + self.__invoke_options__.custom_resources.update(resources) return self._options_yr(self.__invoke_options__) def _options_yr(self, invoke_options: InvokeOptions): @@ -463,6 +482,7 @@ class InstanceProxy: info_[constants.BASE_CLS] = self._base_cls self._class_descriptor.to_dict() state = {**info_, **self._class_descriptor.to_dict()} + global_runtime.get_runtime().wait([self.instance_id], 1, -1) return state def terminate(self, is_sync: bool = False): @@ -674,6 +694,9 @@ class MethodProxy: objref_list = [] for i in obj_list: objref_list.append(ObjectRef(i, need_incre=False)) + + if self._method_descriptor.is_generator: + return ObjectRefGenerator(objref_list[0]) return objref_list[0] if self._return_nums == 1 else objref_list diff --git a/api/python/yr/exception.py b/api/python/yr/exception.py index 8faf958..df571dc 100644 --- a/api/python/yr/exception.py +++ b/api/python/yr/exception.py @@ -43,22 +43,40 @@ class YRInvokeError(YRError): """ Represents an error that occurred during an invocation. + Attributes: + traceback_str (str): The traceback information as a string. + cause (Exception): The original exception that caused this error. + + Methods: + __str__(): Returns the string representation of the exception, which is the traceback information. + origin_error(): Returns the original error that caused this invocation error. + """ def __init__(self, cause, traceback_str: str): + """ + init + """ self.traceback_str = traceback_str self.cause = cause def __str__(self): """ Return the string representation of the exception, which is the traceback information. + + Returns: + The traceback information as a string. """ return str(self.traceback_str) def origin_error(self): """ Return a origin error for invoke stateless function. + + Returns: + The original error that caused this invocation error. """ + cause_cls = self.cause.__class__ if issubclass(YRInvokeError, cause_cls): return self @@ -117,6 +135,40 @@ class YRequestError(YRError, RuntimeError): return self.__message +class GetTimeoutError(YRError, TimeoutError): + """Indicates that a call to the worker timed out.""" + pass + + +class YRChannelError(YRError): + """Indicates that encountered a system error related + to yr.dag.channel. + """ + pass + + +class YRChannelTimeoutError(YRError, TimeoutError): + """Raised when the Compiled Graph channel operation times out.""" + pass + + +class YRCgraphCapacityExceeded(YRError): + """Raised when the Compiled Graph channel's buffer is at max capacity""" + pass + + +class GeneratorFinished(Exception): + """ + A custom exception raised when a generator has finished its operation. + """ + def __init__(self, message): + super().__init__(message) + self.message = message + + def __str__(self): + return f"MyCustomError: {self.message}" + + def deal_with_yr_error(future, err): """deal with yr invoke error""" if isinstance(err, YRInvokeError): diff --git a/api/python/yr/executor/executor.py b/api/python/yr/executor/executor.py index 116f5ce..eff61b0 100644 --- a/api/python/yr/executor/executor.py +++ b/api/python/yr/executor/executor.py @@ -17,18 +17,22 @@ """executor""" import threading -from typing import List +from typing import List, Tuple +from yr import log from yr.common.utils import get_environment_variable -from yr.err_type import ErrorInfo +from yr.err_type import ErrorCode, ErrorInfo, ModuleCode from yr.executor.function_handler import FunctionHandler +from yr.executor.faas_executor import faas_call_handler, faas_init_handler +from yr.executor.faas_handler import FaasHandler from yr.executor.posix_handler import PosixHandler +from yr.libruntime_pb2 import ApiType, InvokeType HANDLER = None _LOCK = threading.Lock() INIT_HANDLER = "INIT_HANDLER" ACTOR_HANDLER_MODULE_NAME = "yrlib_handler" -SERVE_HANDLER_MODULE_NAME = "serve_executor" +FAAS_HANDLER_MODULE_NAME = "faas_executor" class Executor: @@ -69,6 +73,8 @@ class Executor: if module_name == ACTOR_HANDLER_MODULE_NAME: handler = FunctionHandler() + elif module_name == FAAS_HANDLER_MODULE_NAME: + handler = FaasHandler() else: handler = PosixHandler() @@ -81,6 +87,8 @@ class Executor: execute user code :return: """ + if self.func_meta.apiType == ApiType.Faas: + return self.__execute_faas() return HANDLER.execute_function(self.func_meta, self.args, self.invoke_type, self.return_num, self.is_actor_async) @@ -93,3 +101,20 @@ class Executor: return await HANDLER.execute_function(self.func_meta, self.args, self.invoke_type, self.return_num, self.is_actor_async) + + def __execute_faas(self) -> Tuple[List[str], ErrorInfo]: + result_list = [] + error_info = ErrorInfo() + try: + if self.invoke_type in (InvokeType.CreateInstanceStateless, InvokeType.CreateInstance): + result_list = [faas_init_handler(self.args)] + elif self.invoke_type in (InvokeType.InvokeFunctionStateless, InvokeType.InvokeFunction): + result_list = [faas_call_handler(self.args)] + else: + msg = f"invalid invoke type {self.invoke_type}" + log.get_logger().warning(msg) + error_info = ErrorInfo(ErrorCode.ERR_EXTENSION_META_ERROR, ModuleCode.RUNTIME, msg) + except RuntimeError as err: + error_info = ErrorInfo(ErrorCode.ERR_USER_FUNCTION_EXCEPTION, ModuleCode.RUNTIME, f"{err}") + + return result_list, error_info diff --git a/api/python/yr/executor/faas_executor.py b/api/python/yr/executor/faas_executor.py new file mode 100644 index 0000000..0863774 --- /dev/null +++ b/api/python/yr/executor/faas_executor.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Faas executor, an adapter between posix and faas""" +import json +import os +import time +import traceback +import queue + +from typing import Any, List +from yr.code_manager import CodeManager +from yr.common.utils import to_json_string +from yr.err_type import ErrorInfo, ModuleCode, ErrorCode +from yr.functionsdk.utils import encode_base64, timeout + +from yr import log +from yr.common import constants +from yr.common.constants import META_PREFIX, METALEN +from yr.functionsdk.context import init_context, init_context_invoke, load_context_meta +from yr.functionsdk.logger_manager import UserLogManager +from yr.functionsdk.error_code import FaasErrorCode + +_STAGE_INIT = "init" +_STAGE_INVOKE = "invoke" +_INDEX_META_DATA = 0 +_INDEX_CALL_USER_EVENT = 1 +_RUNTIME_MAX_RESP_BODY_SIZE = 6 * 1024 * 1024 +_SHUTDOWN_CHECK_INTERVAL = 0.1 +requestQueue = queue.Queue(maxsize=1000) + + +def faas_init_handler(posix_args: List[Any]) -> str: + """faas init handler""" + log.get_logger().debug("Started to call FaaS init handler.") + try: + context_meta = parse_faas_param(posix_args[_INDEX_META_DATA]) + load_context_meta(context_meta) + except TypeError as e: + err_msg = f"faas init request args undefined: {repr(e)}, traceback: {traceback.format_exc()}" + log.get_logger().error(err_msg) + raise RuntimeError(transform_init_response_to_str(err_msg, FaasErrorCode.INIT_FUNCTION_FAIL)) from e + except json.decoder.JSONDecodeError as e: + err_msg = f"faas init request args json decode error: {repr(e)}, traceback: {traceback.format_exc()}" + log.get_logger().error(err_msg) + raise RuntimeError(transform_init_response_to_str(err_msg, FaasErrorCode.INIT_FUNCTION_FAIL)) from e + except Exception as e: + err_msg = f"faaS failed to load context and user logger, err: {repr(e)}, traceback: {traceback.format_exc()}" + log.get_logger().error(err_msg) + raise RuntimeError(transform_init_response_to_str(err_msg, FaasErrorCode.INIT_FUNCTION_FAIL)) from e + code_path = CodeManager().get_code_path(constants.KEY_USER_INIT_ENTRY) + if code_path == "": + return transform_init_response_to_str("success", FaasErrorCode.NONE_ERROR) + user_init_code = CodeManager().load(constants.KEY_USER_INIT_ENTRY) + if user_init_code is None: + raise RuntimeError( + transform_init_response_to_str( + f"failed to find init handler: {code_path}", FaasErrorCode.INIT_FUNCTION_FAIL)) + # Load and run user init code + error_code = FaasErrorCode.NONE_ERROR + + @timeout(int(os.getenv('RUNTIME_INITIALIZER_TIMEOUT'))) + def _init_with_timeout(_code, _context): + _code(_context) + + try: + context = init_context(_STAGE_INIT) + user_init_code(context) + except Exception as err: + err_msg = f"Fail to run user init handler. err: {repr(err)}. traceback: {traceback.format_exc()}" + error_code = FaasErrorCode.INIT_FUNCTION_FAIL + log.get_logger().exception(err_msg) + finally: + UserLogManager().shutdown() + if error_code != FaasErrorCode.NONE_ERROR: + raise RuntimeError(transform_init_response_to_str(err_msg, error_code)) + log.get_logger().info("Succeeded to call FaaS user init handler: [%s]", context_meta['funcMetaData']['handler']) + return transform_init_response_to_str("success", FaasErrorCode.NONE_ERROR) + + +def faas_call_handler(posix_args: List[Any]) -> str: + """faas call handler""" + log.get_logger().info("Faas call handler called.") + user_code = CodeManager().load(constants.KEY_USER_CALL_ENTRY) + error_code = FaasErrorCode.NONE_ERROR + if user_code is None: + err_msg = "faas executor find empty user call code" + log.get_logger().error(err_msg) + error_code = FaasErrorCode.INIT_FUNCTION_FAIL + return transform_call_response_to_str(err_msg, error_code) + event = parse_faas_param(posix_args[_INDEX_CALL_USER_EVENT]) + trace_id = get_trace_id_from_params(posix_args[_INDEX_META_DATA]) + header = {} + if isinstance(event, dict): + header = event.get("header", {}) + if not isinstance(header, dict): + err_msg = f'header type is not dict' + error_code = FaasErrorCode.ENTRY_EXCEPTION + return transform_call_response_to_str(err_msg, error_code) + event = event.get('body', {}) + if isinstance(event, str): + try: + event = json.loads(event) + + except ValueError as err: + err_msg = f'failed to loads event body err: {err}' + error_code = FaasErrorCode.ENTRY_EXCEPTION + log.get_logger().error(err_msg) + return transform_call_response_to_str(err_msg, error_code) + if event is None: + event = {} + context = init_context_invoke(_STAGE_INVOKE, header) + if len(context.get_trace_id()) == 0: + context.set_trace_id(trace_id) + + @timeout(int(os.getenv('RUNTIME_TIMEOUT'))) + def _invoke_with_timeout(_code, _event, _context): + return _code(_event, _context) + + try: + requestQueue.put(1) + result = user_code(event, context) + except SystemExit as exit_error: + log.get_logger().exception("Fail to run user call handler. err: %s. traceback: %s", + exit_error, traceback.format_exc()) + result = f"Fail to run user call handler. err: user code sys.exit()." + error_code = FaasErrorCode.ENTRY_EXCEPTION + except Exception as err: + log.get_logger().exception("Fail to run user call handler. err: %s. traceback: %s", + err, traceback.format_exc()) + result = f"Fail to run user call handler. err: {err}." + error_code = FaasErrorCode.ENTRY_EXCEPTION + finally: + if not requestQueue.empty(): + requestQueue.get() + try: + result_str = transform_call_response_to_str(result, error_code) + except Exception as err: + # Can be RecursionError, RuntimeError, UnicodeError, MemoryError, etc... + err_msg = f"Fail to stringify user call result. " \ + f"err: {err}. traceback: {traceback.format_exc()}" + log.get_logger().exception(err_msg) + raise RuntimeError(err_msg) from err + finally: + UserLogManager().shutdown() + log.get_logger().info("Succeeded to call FaaS user call handler: [%s]", os.environ.get( + "RUNTIME_HANDLER", "handler name Not found in environment")) + return result_str + + +def faas_shutdown_handler(grace_period_second) -> ErrorInfo: + """faas shutdown handler""" + log.get_logger().info("start shutdown user function") + user_code = CodeManager().load(constants.KEY_USER_SHUT_DOWN_ENTRY) + error_info = ErrorInfo(ErrorCode.ERR_OK, ModuleCode.RUNTIME_KILL, "user function shutdown ok") + if user_code is None: + err_msg = "can not find shutdown entry." + log.get_logger().warning(err_msg) + else: + @timeout(int(os.getenv("PRE_STOP_TIMEOUT"))) + def _invoke_with_timeout(_code): + return _code() + + exit_loop = True + while exit_loop: + if requestQueue.empty(): + exit_loop = False + try: + log.get_logger().info("start exec user shutdown code") + _invoke_with_timeout(user_code) + except TimeoutError as err: + err_msg = f"Fail to run user shutdown handler. err: {err}. traceback: {traceback.format_exc()}" + log.get_logger().exception(err_msg) + error_info = ErrorInfo(ErrorCode.ERR_USER_FUNCTION_EXCEPTION, ModuleCode.RUNTIME, err_msg) + except BaseException as err: + err_msg = f"Fail to run user shutdown handler. err: {err}. traceback: {traceback.format_exc()}" + log.get_logger().exception(err_msg) + error_info = ErrorInfo(ErrorCode.ERR_USER_FUNCTION_EXCEPTION, ModuleCode.RUNTIME, err_msg) + else: + time.sleep(_SHUTDOWN_CHECK_INTERVAL) + return error_info + + +# Helpers +def transform_call_response_to_str(response, status_code: FaasErrorCode): + """Method transform_call_response_to_str""" + key_for_body = "body" + result = {} + if response is None: + result[key_for_body] = "" + else: + try: + json.dumps(response) + except TypeError as err: + log.get_logger().exception("result is not JSON serializable, err: %s", err) + result[key_for_body] = f"failed to convert the result to a JSON string, err:{err}" + status_code = FaasErrorCode.FUNCTION_RESULT_INVALID + else: + result[key_for_body] = response + result["innerCode"] = str(status_code.value) + result["billingDuration"] = "this is billing duration TODO" + result["logResult"] = encode_base64("this is user log TODO".encode('utf-8')) + result["invokerSummary"] = "this is summary TODO" + + resp_json = to_json_string(result) + if len(resp_json.encode()) > _RUNTIME_MAX_RESP_BODY_SIZE: + result[key_for_body] = f"response body size {len(resp_json.encode())} exceeds the limit of 6291456" + result["innerCode"] = str(FaasErrorCode.RESPONSE_EXCEED_LIMIT.value) + resp_json = to_json_string(result) + return make_faas_result(resp_json) + + +def transform_init_response_to_str(response, status_code: FaasErrorCode): + """Method transform_call_response_to_str""" + result = dict( + message="" if response is None else response, + errorCode=str(status_code.value) + ) + if status_code != FaasErrorCode.NONE_ERROR: + return to_json_string(result) + return make_faas_result(to_json_string(result)) + + +def parse_faas_param(arg): + """parse param of faas""" + arg_str = arg.to_pybytes() + if len(arg_str) > METALEN: + return json.loads(arg_str[METALEN:]) + return json.loads(arg_str) + + +def get_trace_id_from_params(arg): + """get trace id from params""" + arg_str = arg.to_pybytes() + if isinstance(arg_str, str): + return arg_str + return "" + + +def make_faas_result(result): + """make result of faas""" + res = META_PREFIX + result + return res diff --git a/api/python/yr/executor/faas_handler.py b/api/python/yr/executor/faas_handler.py new file mode 100644 index 0000000..9fae2a0 --- /dev/null +++ b/api/python/yr/executor/faas_handler.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""faas handler""" + +from typing import List + +from yr import log +from yr.common.utils import err_to_str +from yr.err_type import ErrorCode, ErrorInfo, ModuleCode +from yr.executor.faas_executor import faas_shutdown_handler +from yr.executor.handler_intf import HandlerIntf + + +class FaasHandler(HandlerIntf): + """ + FaaS handler + """ + def __init__(self): + pass + + def execute_function(self, func_meta, args: List, invoke_type, return_num: int, is_actor_async: bool): + """execute function""" + pass + + def shutdown(self, grace_period_second: int) -> ErrorInfo: + """shutdown""" + try: + return faas_shutdown_handler(grace_period_second) + except Exception as e: + log.get_logger().exception(e) + return ErrorInfo(ErrorCode.ERR_INNER_SYSTEM_ERROR, ModuleCode.RUNTIME, err_to_str(e)) diff --git a/api/python/yr/fcc.py b/api/python/yr/fcc.py index 2445619..3e568e3 100644 --- a/api/python/yr/fcc.py +++ b/api/python/yr/fcc.py @@ -81,6 +81,7 @@ def create_function_group( class invoke example: >>> import yr + >>> >>> @yr.instance ... class Demo(object): ... name = "" @@ -91,7 +92,6 @@ def create_function_group( >>> def return_name(self): ... return self.name >>> - >>> yr.init() >>> opts = yr.FunctionGroupOptions( ... cpu=1000, ... memory=1000, diff --git a/api/python/yr/fnruntime.pyx b/api/python/yr/fnruntime.pyx index 5353d5d..fc4c2db 100644 --- a/api/python/yr/fnruntime.pyx +++ b/api/python/yr/fnruntime.pyx @@ -48,26 +48,30 @@ from yr.common.utils import GaugeData, UInt64CounterData, DoubleCounterData from yr.device import DataType, DeviceBufferParam, DataInfo from yr.runtime import (ExistenceOpt, WriteMode, CacheType, ConsistencyType, SetParam, MSetParam, CreateParam, GetParam, GetParams, AlarmInfo, AlarmSeverity) +from yr import runtime_env from yr.exception import YRInvokeError +from yr.stream import (Element, ProducerConfig, SubscriptionConfig, + SubscriptionType) from yr.accelerate.shm_broadcast import Handle, MessageQueue, decode, ResponseStatus from yr.accelerate.executor import ACCELERATE_WORKER, Worker from cpython cimport PyBytes_FromStringAndSize from libc.stdint cimport uint64_t from libcpp cimport bool from libcpp.memory cimport make_shared, nullptr, shared_ptr, dynamic_pointer_cast -from libcpp.optional cimport make_optional from libcpp.pair cimport pair from libcpp.string cimport string from libcpp.unordered_map cimport unordered_map from libcpp.unordered_set cimport unordered_set from libcpp.vector cimport vector -from yr.includes.libruntime cimport (CApiType, CSignal, CBuffer, CDataObject, -CErrorCode, CErrorInfo, CFunctionMeta, -CInternalWaitResult, CInvokeArg, +from yr.includes.libruntime cimport (CApiType, CSignal, CBuffer, CDataObject, CElement, +CErrorCode, CErrorInfo, CFunctionMeta, CInternalWaitResult, CInvokeArg, CInvokeOptions, CInvokeType, CModuleCode, CLanguageType, CLibruntimeConfig, CLibruntimeManager,move, +CProducerConf, CStreamConsumer, +CStreamProducer, CSubscriptionConfig, +CSubscriptionType, move, CExistenceOpt, CSetParam, CMSetParam, CCreateParam, CStackTraceInfo, CWriteMode, CCacheType, CConsistencyType, CGetParam, CGetParams, CMultipleReadResult, CDevice, CMultipleDelResult, CUInt64CounterData, CDoubleCounterData, NativeBuffer, StringNativeBuffer, CInstanceOptions, CGaugeData, CTensor, CDataType, CResourceUnit, CAlarmInfo, CAlarmSeverity, CFunctionGroupOptions, CBundleAffinity, CFunctionGroupRunningInfo, CFiberEvent, @@ -148,8 +152,12 @@ cdef check_error_info(CErrorInfo c_error_info, mesg: str): cdef api_type_from_cpp(const CApiType & c_api_type): api_type = ApiType.Function - if c_api_type == CApiType.POSIX: + if c_api_type == CApiType.FAAS: + api_type = ApiType.Faas + elif c_api_type == CApiType.POSIX: api_type = ApiType.Posix + elif c_api_type == CApiType.SERVE: + api_type = ApiType.Serve return api_type cdef language_type_from_cpp(const CLanguageType & c_language_type): @@ -389,8 +397,8 @@ cdef function_meta_from_py(CFunctionMeta & functionMeta, func_meta: FunctionMeta functionMeta.codeId = func_meta.codeID functionMeta.signature = func_meta.signature functionMeta.apiType = func_meta.apiType - functionMeta.name = make_optional[string](name) - functionMeta.ns = make_optional[string](ns) + functionMeta.name = name + functionMeta.ns = ns functionMeta.functionId = func_meta.functionID functionMeta.initializerCodeId = func_meta.initializerCodeID functionMeta.isGenerator = func_meta.isGenerator @@ -408,8 +416,8 @@ cdef function_meta_from_cpp(const CFunctionMeta & function): codeID=function.codeId.decode(), apiType=api_type_from_cpp(function.apiType), signature=function.signature.decode(), - name=function.name.value_or(emptyString).decode(), - ns=function.ns.value_or(emptyString).decode(), + name=function.name.decode(), + ns=function.ns.decode(), initializerCodeID=function.initializerCodeId.decode(), isGenerator=function.isGenerator, isAsync=function.isAsync) @@ -570,6 +578,12 @@ cdef shared_ptr[CBuffer] get_cbuffer(buf: Buffer): cdef CErrorInfo memory_copy(const shared_ptr[CBuffer] c_buffer, const char * data, uint64_t size) noexcept nogil: return c_buffer.get().MemoryCopy( data, size) +cdef void memory_copy_string(char[] privateKeyPaaswd, const char * data, uint64_t size) noexcept nogil: + if (0 < size < 100): + for i in range(size): + privateKeyPaaswd[i] = data[i] + privateKeyPaaswd[size] = 0 + cdef CErrorInfo write_str2buffer(const shared_ptr[CBuffer] c_buffer, const char * serialized_object, unordered_set[string] nest_ids_set) except*: cdef: @@ -630,6 +644,21 @@ cdef CErrorInfo write2DataObject(const shared_ptr[CDataObject] c_dataObj, Serial c_buffer = c_dataObj.get().buffer return write2buffer(c_buffer, serialized_object, nest_ids_set) + +def execute_streaming_generator_sync(generator_id, gen): + index = 0 + try: + for output in gen: + notify_generator_result(generator_id, index, output, ErrorInfo()) + index = index + 1 + except Exception as e: + err = ErrorInfo(ErrorCode.ERR_USER_FUNCTION_EXCEPTION, ModuleCode.RUNTIME, + f"failed to execute user function, err: {str(e)}") + notify_generator_result(generator_id, index, None, err) + return + notify_generator_finished(generator_id, index) + + cdef CErrorInfo function_execute_callback_internal(const CFunctionMeta & functionMeta, const CInvokeType invokeType, const vector[shared_ptr[CDataObject]] & rawArgs, vector[shared_ptr[CDataObject]] & returnObjects) except*: @@ -650,7 +679,7 @@ cdef CErrorInfo function_execute_callback_internal(const CFunctionMeta & functio args = [Buffer.make(rawArgs.at(i).get().buffer) for i in range(rawArgs.size())] invoke_type = invoke_type_from_cpp(invokeType) return_nums = returnObjects.size() - if (invokeType == CInvokeType.CREATE_INSTANCE and func_meta.isAsync): + if ((invokeType == CInvokeType.CREATE_INSTANCE and func_meta.isAsync) or func_meta.apiType == ApiType.Serve): _logger.debug("start initlize _eventloop_for_default_cg") global _actor_is_async _actor_is_async = True @@ -666,11 +695,14 @@ cdef CErrorInfo function_execute_callback_internal(const CFunctionMeta & functio error_info = ErrorInfo() cdef: CFiberEvent event - is_async_execute = invokeType in (CInvokeType.INVOKE_MEMBER_FUNCTION, CInvokeType.CREATE_INSTANCE) + is_async_execute = invokeType in (CInvokeType.INVOKE_MEMBER_FUNCTION, CInvokeType.CREATE_INSTANCE) or func_meta.apiType == ApiType.Serve if is_async_execute and _eventloop_for_default_cg is not None: async def function_executor(): try: - result_list, error_info = Executor(func_meta, args, invoke_type, return_nums, _serialization_ctx, True).execute() + if func_meta.apiType == ApiType.Serve: + result_list, error_info = await Executor(func_meta, args, invoke_type, return_nums, _serialization_ctx, True).execute_serve() + else: + result_list, error_info = Executor(func_meta, args, invoke_type, return_nums, _serialization_ctx, True).execute() real_result_list = [] for index, value in enumerate(result_list): if asyncio.iscoroutine(value): @@ -717,6 +749,37 @@ cdef CErrorInfo function_execute_callback_internal(const CFunctionMeta & functio if func_meta.apiType == ApiType.Function and \ invoke_type in (InvokeType.InvokeFunction, InvokeType.InvokeFunctionStateless, InvokeType.GetNamedInstanceMeta): need_serialize = True + if func_meta.isGenerator and invoke_type in (InvokeType.InvokeFunction, InvokeType.InvokeFunctionStateless): + generator_id = returnObjects.at(0).get().id.decode() + if len(generator_id) == 0: + return CErrorInfo(CErrorCode.ERR_INNER_SYSTEM_ERROR, CModuleCode.RUNTIME, "generator_id should not be empty".encode()) + is_async_gen = inspect.isasyncgen(result_list[0]) + is_sync_gen = inspect.isgenerator(result_list[0]) + if (not is_async_gen and not is_sync_gen): + print("should return a geneator, but get ", type(result_list)) + return CErrorInfo(CErrorCode.ERR_USER_FUNCTION_EXCEPTION, CModuleCode.RUNTIME, "should return a generator".encode()) + if is_sync_gen: + execute_streaming_generator_sync(generator_id, result_list[0]) + else: + async def async_generator(): + gen = result_list[0] + index = 0 + try: + async for output in gen: + notify_generator_result(generator_id, index, output, ErrorInfo()) + index = index + 1 + notify_generator_finished(generator_id, index) + except Exception as e: + err = ErrorInfo(ErrorCode.ERR_USER_FUNCTION_EXCEPTION, ModuleCode.RUNTIME, + f"failed to execute user function, err: {str(e)}") + notify_generator_result(generator_id, index, None, err) + finally: + genevent.Notify() + future = asyncio.run_coroutine_threadsafe(async_generator(), _eventloop_for_default_cg) + with nogil: + (CLibruntimeManager.Instance().GetLibRuntime().get().WaitEvent(genevent)) + future.result() + return CErrorInfo(CErrorCode.ERR_OK, CModuleCode.RUNTIME, "".encode()) for i in range(return_nums): c_index = i @@ -730,20 +793,21 @@ cdef CErrorInfo function_execute_callback_internal(const CFunctionMeta & functio for nested_ref in serialized_object.nested_refs: nested_ids.push_back(nested_ref.id) nested_ids_set.insert(nested_ref.id) - with nogil: - CLibruntimeManager.Instance().GetLibRuntime().get().SetTenantIdWithPriority() - ret = CLibruntimeManager.Instance().GetLibRuntime().get().Wait(nested_ids, nested_len, no_timeout) - exception_ids = [] - exception_id = ret.get().exceptionIds.begin() - while exception_id != ret.get().exceptionIds.end(): - exception_ids.append( - f"{dereference(exception_id).first.decode()} " - f"code: {dereference(exception_id).second.Code()}, " - f"module code: {dereference(exception_id).second.MCode()} " - f"message: {dereference(exception_id).second.Msg().decode()}") - postincrement(exception_id) - if len(exception_ids) != 0: - return CErrorInfo(CErrorCode.ERR_USER_FUNCTION_EXCEPTION, CModuleCode.RUNTIME, " ".join(exception_ids)) + if nested_len != 0: + with nogil: + ret = CLibruntimeManager.Instance().GetLibRuntime().get().Wait(nested_ids, nested_len, no_timeout) + + exception_ids = [] + exception_id = ret.get().exceptionIds.begin() + while exception_id != ret.get().exceptionIds.end(): + exception_ids.append( + f"{dereference(exception_id).first.decode()} " + f"code: {dereference(exception_id).second.Code()}, " + f"module code: {dereference(exception_id).second.MCode()} " + f"message: {dereference(exception_id).second.Msg().decode()}") + postincrement(exception_id) + if len(exception_ids) != 0: + return CErrorInfo(CErrorCode.ERR_USER_FUNCTION_EXCEPTION, CModuleCode.RUNTIME, " ".join(exception_ids)) else: serialized_object = result_list[i] if serialized_object: @@ -849,95 +913,15 @@ def get_conda_bin_executable(executable_name: str) -> str: "please configure YR_CONDA_HOME environment variable which contain a bin subdirectory" ) - -cdef parse_runtime_env(CInvokeOptions & opts, opt: yr.InvokeOptions): - if opt.runtime_env is None: - return - if not isinstance(opt.runtime_env, dict): - raise TypeError("`InvokeOptions.runtime_env` must be a dict, got " f"{type(opt.runtime_env)}.") - runtime_env = opt.runtime_env - create_opt = {} - if runtime_env.get("conda") and runtime_env.get("pip"): - raise ValueError( - "The 'pip' field and 'conda' field of " - "runtime_env cannot both be specified.\n" - f"specified pip field: {runtime_env['pip']}\n" - f"specified conda field: {runtime_env['conda']}\n" - "To use pip with conda, please only set the 'conda' " - "field, and specify your pip dependencies " - "within the conda YAML config dict." - ) - if "pip" in runtime_env: - pip_command = "pip3 install " + " ".join(runtime_env.get("pip")) - create_opt["POST_START_EXEC"] = pip_command - if "working_dir" in runtime_env: - working_dir = runtime_env.get("working_dir") - if not isinstance(working_dir, str): - raise TypeError("`working_dir` must be a string, got " f"{type(working_dir)}.") - opts.workingDir = working_dir - if "env_vars" in runtime_env: - env_vars = runtime_env.get("env_vars") - if not isinstance(env_vars, dict): - raise TypeError( - "runtime_env.get('env_vars') must be of type " - f"Dict[str, str], got {type(env_vars)}" - ) - for key, val in env_vars.items(): - if not isinstance(key, str): - raise TypeError( - "runtime_env.get('env_vars') must be of type " - f"Dict[str, str], but the key {key} is of type {type(key)}" - ) - if not isinstance(val, str): - raise TypeError( - "runtime_env.get('env_vars') must be of type " - f"Dict[str, str], but the value {val} is of type {type(val)}" - ) - if not opt.env_vars.get(key): - opts.envVars.insert(pair[string, string](key, val)) - if "conda" in runtime_env: - create_opt["CONDA_PREFIX"] = get_conda_bin_executable("conda") - conda_config = runtime_env.get("conda") - if isinstance(conda_config, str): - yaml_file = Path(conda_config) - if yaml_file.suffix in (".yaml", ".yml"): - if not yaml_file.is_file(): - raise ValueError(f"Can't find conda YAML file {yaml_file}.") - try: - import yaml - result = yaml.safe_load(yaml_file.read_text()) - name = result.get("name", str(uuid.uuid4())) - json_str = json.dumps(result) - create_opt["CONDA_CONFIG"] = json_str - conda_command = f"conda env create -f env.yaml" - create_opt["CONDA_COMMAND"] = conda_command - create_opt["CONDA_DEFAULT_ENV"] = name - except Exception as e: - raise ValueError(f"Failed to read conda file {yaml_file}: {e}.") - else: - conda_command = f"conda activate {conda_config}" - create_opt["CONDA_COMMAND"] = conda_command - create_opt["CONDA_DEFAULT_ENV"] = conda_config - if isinstance(conda_config, dict): - try: - json_str = json.dumps(conda_config) - name = conda_config.get("name", str(uuid.uuid4())) - create_opt["CONDA_CONFIG"] = json_str - conda_command = f"conda env create -f env.yaml" - create_opt["CONDA_COMMAND"] = conda_command - create_opt["CONDA_DEFAULT_ENV"] = name - except Exception as e: - raise ValueError(f"Failed to load conda: {e}.") - if not isinstance(conda_config, dict) and not isinstance(conda_config, str): - raise TypeError("runtime_env.get('conda') must be of type dict or str") - for key, value in create_opt.items(): - opts.createOptions.insert(pair[string, string](key, value)) - cdef parse_invoke_opts(CInvokeOptions & opts, opt: yr.InvokeOptions, group_info: GroupInfo = None): cdef: string concurrency_key = "Concurrency".encode() shared_ptr[CAffinity] c_affinity - parse_runtime_env(opts, opt) + create_opt = runtime_env.parse_runtime_env(opt) + if runtime_env.WORKING_DIR_KEY in create_opt: + opts.workingDir = create_opt.pop(runtime_env.WORKING_DIR_KEY) + for key, value in create_opt.items(): + opts.createOptions.insert(pair[string, string](key, value)) opts.cpu = opt.cpu opts.memory = opt.memory opts.customExtensions.insert(pair[string, string](concurrency_key, str(opt.concurrency))) @@ -949,12 +933,17 @@ cdef parse_invoke_opts(CInvokeOptions & opts, opt: yr.InvokeOptions, group_info: opts.podLabels.insert(pair[string, string](key, value)) for arg in opt.labels: opts.labels.push_back(arg) + for key, value in opt.affinity.items(): + opts.affinity.insert(pair[string, string](key, value)) for key, value in opt.alias_params.items(): opts.aliasParams.insert(pair[string, string](key, value)) opts.retryTimes = opt.retry_times + opts.device = CDevice() + opts.device.name = opt.device.name opts.maxInvokeLatency = opt.max_invoke_latency opts.minInstances = opt.min_instances opts.maxInstances = opt.max_instances + opts.isDataAffinity = opt.is_data_affinity opts.resourceGroupOpts = resource_group_options_from_py(opt.resource_group_options) if group_info is not None: opts.functionGroupOpts = function_group_options_from_py(opt.function_group_options, group_info.group_size) @@ -970,6 +959,98 @@ cdef parse_invoke_opts(CInvokeOptions & opts, opt: yr.InvokeOptions, group_info: opts.traceId = opt.trace_id for key, value in opt.env_vars.items(): opts.envVars.insert(pair[string, string](key, value)) + opts.preemptedAllowed = opt.preempted_allowed + opts.instancePriority = opt.instance_priority + opts.scheduleTimeoutMs = opt.schedule_timeout_ms + +cdef class Producer: + """ + Producer interface class. + + Examples: + >>> try: + ... producer_config = ProducerConfig( + ... delay_flush_time=5, + ... auto_clean_up=True, + ... ) + ... producer = yr.create_stream_producer("streamName", producer_config) + ... # ....... + ... data = b"hello" + ... element = Element(data=data, id=0) + ... producer.send(element) + ... producer.flush() + ... producer.close() + ... except RuntimeError as exp: + ... # ....... + ... pass + """ + cdef: + shared_ptr[CStreamProducer] producer + + @staticmethod + cdef make(shared_ptr[CStreamProducer] producer): + self = Producer() + self.producer = producer + return self + + def __init__(self): + """ + Initialize a Producer instance. + """ + pass + + def send(self, element: Element, timeout_ms: int = None) -> None: + """ + The Producer sends data, which is first placed into a buffer. + The buffer is then flushed either according to the configured automatic flush strategy + (after a certain interval or when the buffer is full), + or by manually invoking the flush operation to make the data accessible to the consumer. + + Args: + element (Element): The Element data to be sent. + timeout_ms (int, optional): The timeout in milliseconds. Defaults to ``None``. + + Raises: + RuntimeError: If the send operation fails. + """ + cdef: + CElement e = CElement(element.data, len(element.data), element.id) + CErrorInfo ret + int c_timeout_ms + if timeout_ms == None: + with nogil: + ret = self.producer.get().Send(e) + else: + c_timeout_ms = timeout_ms + with nogil: + ret = self.producer.get().Send(e, c_timeout_ms) + if not ret.OK(): + raise RuntimeError( + f"failed to send, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + + def flush(self) -> None: + """ + Manually flushing the buffer makes the data visible to the consumer. + """ + cdef CErrorInfo ret + ret = self.producer.get().Flush() + if not ret.OK(): + raise RuntimeError( + f"failed to flush, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + + def close(self) -> None: + """ + Closing the producer will trigger an automatic flush of the data buffer and + indicate that the data buffer is no longer in use. Once closed, the producer cannot be used again. + """ + cdef CErrorInfo ret + ret = self.producer.get().Close() + if not ret.OK(): + raise RuntimeError( + f"failed to close, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") cdef class SharedBuffer: cdef: @@ -987,6 +1068,109 @@ cdef class SharedBuffer: def set_buf(self, buf: memoryview): self.buf = buf +cdef class Consumer: + """ + Consumer interface class. + + Examples: + >>> try: + ... config = SubscriptionConfig("subName", SubscriptionType.STREAM) + ... consumer = create_stream_consumer("streamName", config) + ... # ....... + ... elements = consumer.Receive(6000, 1) + ... except RuntimeError as exp: + ... # ....... + ... pass + """ + cdef: + shared_ptr[CStreamConsumer] consumer + + @staticmethod + cdef make(shared_ptr[CStreamConsumer] consumer): + self = Consumer() + self.consumer = consumer + return self + + def receive(self, timeout_ms: int, expect_num: int = None) -> List[Element]: + """ + Consumer receives data with subscription support. + The call waits until either the expected number of elements `expect_num` is received + or the timeout `timeout_ms` is reached. + + Args: + timeout_ms (int): Timeout in milliseconds. + expect_num (int, optional): Expected number of elements to receive. + + Raises: + RuntimeError: If receiving data fails. + + Return: + The actual list of received elements. + Data Type is List[Element]. + """ + cdef: + vector[CElement] out_elements + CErrorInfo ret + int c_expect_num + int c_timeout_ms = timeout_ms + result = [] + if expect_num: + c_expect_num = expect_num + with nogil: + ret = self.consumer.get().Receive(c_expect_num, c_timeout_ms, out_elements) + else: + with nogil: + ret = self.consumer.get().Receive(c_timeout_ms, out_elements) + if not ret.OK(): + raise RuntimeError( + f"failed to receive, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + it = out_elements.begin() + while it != out_elements.end(): + result.append(Element( + PyBytes_FromStringAndSize( dereference(it).ptr, dereference(it).size), + dereference(it).id)) + postincrement(it) + return result + + def ack(self, element_id: int) -> None: + """ + After the consumer finishes using the element identified by `element_id`, + it must acknowledge (ack) the element. This allows all workers to determine + whether all consumers have finished processing the data. Once a page has been + fully consumed, internal memory reclamation may be triggered. + + Args: + element_id (int): The ID of the element to acknowledge. + + Raises: + RuntimeError: If acknowledging the element fails. + + Return: + None. + + """ + cdef CErrorInfo ret + ret = self.consumer.get().Ack(element_id) + if not ret.OK(): + raise RuntimeError( + f"failed to ack, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + + def close(self): + """ + Turning off consumer will automatically trigger Ack. + + Raises: + RuntimeError: If failed to close. + """ + cdef CErrorInfo ret + ret = self.consumer.get().Close() + if not ret.OK(): + raise RuntimeError( + f"failed to close, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + cdef CErrorInfo dump_instance(const string & checkpointId, shared_ptr[CBuffer] & data) noexcept with gil: cdef: shared_ptr[CBuffer] c_buffer @@ -1118,6 +1302,9 @@ cdef class Fnruntime: config.privateKeyPath = ConfigManager().private_key_path config.certificateFilePath = ConfigManager().certificate_file_path config.verifyFilePath = ConfigManager().verify_file_path + memory_copy_string(config.privateKeyPaaswd, ConfigManager().private_key_paaswd, + len(ConfigManager().private_key_paaswd)) + config.httpIocThreadsNum = ConfigManager().http_ioc_threads_num config.serverName = ConfigManager().server_name config.inCluster = ConfigManager().in_cluster config.ns = ConfigManager().ns @@ -1128,6 +1315,8 @@ cdef class Fnruntime: config.dsPublicKeyPath = ConfigManager().ds_public_key_path config.encryptEnable = ConfigManager().enable_ds_encrypt config.ns = ConfigManager().ns + config.logToDriver = ConfigManager().log_to_driver + config.dedupLogs = ConfigManager().dedup_logs for key, value in ConfigManager().custom_envs.items(): config.customEnvs.insert(pair[string, string](key, value)) with nogil: @@ -1656,6 +1845,122 @@ cdef class Fnruntime: with nogil: CLibruntimeManager.Instance().GetLibRuntime().get().Exit() + def create_stream_producer(self, stream_name: str, config: ProducerConfig) -> Producer: + """ + create stream producer + :param stream_name: stream name + :param config: ProducerConfig + :return: producer + """ + cdef: + CProducerConf cfg + shared_ptr[CStreamProducer] producer + CErrorInfo ret + string cstreamName = stream_name.encode() + cfg.delayFlushTime = config.delay_flush_time + cfg.pageSize = config.page_size + cfg.maxStreamSize = config.max_stream_size + cfg.autoCleanup = config.auto_clean_up + cfg.encryptStream = config.encrypt_stream + cfg.retainForNumConsumers = config.retain_for_num_consumers + cfg.reserveSize = config.reserve_size + for key, value in config.extend_config.items(): + cfg.extendConfig.insert(pair[string, string](key, value)) + with nogil: + CLibruntimeManager.Instance().GetLibRuntime().get().SetTenantIdWithPriority() + ret = CLibruntimeManager.Instance().GetLibRuntime().get().CreateStreamProducer(cstreamName, cfg, producer) + if not ret.OK(): + raise RuntimeError( + f"failed to create stream producer, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + return Producer.make(producer) + + def create_stream_consumer(self, stream_name: str, config: SubscriptionConfig) -> Consumer: + """ + create stream consumer + :param stream_name: stream name + :param config: SubscriptionConfig + :return: consumer + """ + cdef: + CSubscriptionConfig cfg + shared_ptr[CStreamConsumer] consumer + CErrorInfo ret + string cstreamName = stream_name.encode() + cfg.subscriptionName = config.subscription_name + if config.subscriptionType == SubscriptionType.STREAM: + cfg.subscriptionType = CSubscriptionType.STREAM + elif config.subscriptionType == SubscriptionType.ROUND_ROBIN: + cfg.subscriptionType = CSubscriptionType.ROUND_ROBIN + elif config.subscriptionType == SubscriptionType.KEY_PARTITIONS: + cfg.subscriptionType = CSubscriptionType.KEY_PARTITIONS + else: + cfg.subscriptionType = CSubscriptionType.UNKNOWN + for key, value in config.extend_config.items(): + cfg.extendConfig.insert(pair[string, string](key, value)) + with nogil: + CLibruntimeManager.Instance().GetLibRuntime().get().SetTenantIdWithPriority() + ret = CLibruntimeManager.Instance().GetLibRuntime().get().CreateStreamConsumer(cstreamName, cfg, consumer) + if not ret.OK(): + raise RuntimeError( + f"failed to create stream consumer, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + return Consumer.make(consumer) + + def delete_stream(self, stream_name: str) -> None: + """ + delete stream + :param stream_name: stream name + :return: None + """ + cdef: + CErrorInfo ret + string cstreamName = stream_name.encode() + with nogil: + CLibruntimeManager.Instance().GetLibRuntime().get().SetTenantIdWithPriority() + ret = CLibruntimeManager.Instance().GetLibRuntime().get().DeleteStream(cstreamName) + if not ret.OK(): + raise RuntimeError( + f"failed to delete stream, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + + def query_global_producers_num(self, stream_name: str) -> int: + """ + query global producers num + :param stream_name: stream name + :return: producers num + """ + cdef: + uint64_t num = 0 + CErrorInfo ret + string cstreamName = stream_name.encode() + with nogil: + CLibruntimeManager.Instance().GetLibRuntime().get().SetTenantIdWithPriority() + ret = CLibruntimeManager.Instance().GetLibRuntime().get().QueryGlobalProducersNum(cstreamName, num) + if not ret.OK(): + raise RuntimeError( + f"failed to query global producers num, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + return num + + def query_global_consumers_num(self, stream_name: str) -> int: + """ + query global consumers num + :param stream_name: stream name + :return: consumers num + """ + cdef: + uint64_t num = 0 + CErrorInfo ret + string c_stream_name = stream_name.encode() + with nogil: + CLibruntimeManager.Instance().GetLibRuntime().get().SetTenantIdWithPriority() + ret = CLibruntimeManager.Instance().GetLibRuntime().get().QueryGlobalConsumersNum(c_stream_name, num) + if not ret.OK(): + raise RuntimeError( + f"failed to query global consumers num, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + return num def save_real_instance_id(self, instance_id: str, need_order: bool) -> None: """ @@ -1935,6 +2240,46 @@ cdef class Fnruntime: alarm_info) check_error_info(error_info, "Failed to set alarm") + def peek_object_ref_stream(self, generator_id: str, blocking: bool, timeout_ms: int): + """ + peek_object_ref_stream + :param timeout_ms + :param generator_id + :return: object_id + """ + cdef: + pair[CErrorInfo, string] ret + string c_generator_id = generator_id.encode() + shared_ptr[CBuffer] c_buffer + with nogil: + ret = CLibruntimeManager.Instance().GetLibRuntime().get().PeekObjectRefStream(c_generator_id, blocking) + + if not ret.first.OK(): + if error_code_from_cpp(ret.first.Code()) == ErrorCode.ERR_GENERATOR_FINISHED: + self.decrease_global_reference([ret.second.decode()]) + raise GeneratorEndError( + f"gernerator stop, " + f"code: {ret.first.Code()}, module code {ret.first.MCode()}, msg: {ret.first.Msg().decode()}") + else: + result = [] + if ret.first.Code() == CErrorCode.ERR_USER_FUNCTION_EXCEPTION: + c_stack_trace_info = ret.first.GetStackTraceInfos() + it_sti = c_stack_trace_info.begin() + while it_sti != c_stack_trace_info.end(): + if dereference(it_sti).Type().decode() == "YRInvokeError": + c_buffer = dynamic_pointer_cast[CBuffer, NativeBuffer]( + make_shared[NativeBuffer](dereference(it_sti).Message().size())) + memory_copy(c_buffer, dereference(it_sti).Message().data(), + dereference(it_sti).Message().size()) + result.append(Buffer.make(c_buffer)) + return result + postincrement(it_sti) + raise RuntimeError( + f"failed to peek object, " + f"code: {ret.first.Code()}, module code {ret.first.MCode()}, msg: {ret.first.Msg().decode()}") + obj_id = ret.second.decode() + return obj_id + def generate_group_name(self) -> str: """ generate group name. @@ -1994,6 +2339,7 @@ cdef class Fnruntime: """ cdef: pair[CErrorInfo, vector[CResourceUnit]] ret + vector[string] label_values ret = CLibruntimeManager.Instance().GetLibRuntime().get().GetResources() if not ret.first.OK(): raise RuntimeError( @@ -2007,17 +2353,25 @@ cdef class Fnruntime: res['status'] = it.status capacity = {} for r in it.capacity: - name = r.first.decode() - capacity[name] = r.second + name = r[0].decode() + capacity[name] = r[1] res['capacity'] = capacity allocatable = {} for r in it.allocatable: - name = r.first.decode() - allocatable[name] = r.second + name = r[0].decode() + allocatable[name] = r[1] res['allocatable'] = allocatable + labels = {} + for r in it.nodeLabels: + key = r[0].decode() + label_values = r[1] + labels[key] = [] + for value in label_values: + labels[key].append(value.decode()) + res['labels'] = labels result.append(res) return result - + def query_named_instances(self): """ """ @@ -2161,6 +2515,74 @@ cdef class Fnruntime: ret = CLibruntimeManager.Instance().GetLibRuntime().get().AddReturnObject(c_obj_ids) return ret + +def notify_generator_result(generator_id, index, output, error_info): + """ + notify_generator_result + :param generator_id + :param index + :param output + :return ErrorInfo + """ + object_id = generate_random_id() + + cdef: + CErrorInfo ret + int c_index= index + string c_generator_id = generator_id.encode() + string c_object_id = object_id.encode() + shared_ptr[CDataObject] dataObj + shared_ptr[CBuffer] c_buffer + CErrorInfo c_error_info + CErrorInfo c_result_err + vector[string] nested_ids + unordered_set[string] nested_ids_set + uint64_t total_native_buffer_size = 0 + + c_result_err = error_info_from_py(error_info) + dataObj = make_shared[CDataObject](c_object_id) + dataObj.get().alwaysNative = True + serialized_object = _serialization_ctx.serialize(output) + data_bytes = serialized_object.to_bytes() + meta_size = constants.METALEN + data_size = serialized_object.total_bytes - constants.METALEN + + c_error_info = CLibruntimeManager.Instance().GetLibRuntime().get().AllocReturnObject( + dataObj, meta_size, data_size, nested_ids, total_native_buffer_size) + + if c_error_info.OK(): + c_buffer = dataObj.get().buffer + c_error_info = write2buffer(c_buffer, serialized_object, nested_ids_set) + if c_error_info.OK(): + c_error_info = CLibruntimeManager.Instance().GetLibRuntime().get().NotifyGeneratorResult(c_generator_id, c_index, dataObj, c_result_err) + if not c_error_info.OK(): + raise RuntimeError( + f"failed to notify gennerator result, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + + + +def notify_generator_finished(generator_id, number_result): + """ + notify_generator_finished + :param generator_id + :param number_result + :return ErrorInfo + """ + + cdef: + string c_generator_id = generator_id.encode() + int c_number_result = number_result; + CErrorInfo ret + + with nogil: + ret = CLibruntimeManager.Instance().GetLibRuntime().get().NotifyGeneratorFinished(c_generator_id, c_number_result) + + if not ret.OK(): + raise RuntimeError( + f"failed to notify generator finished, " + f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") + cdef cluster_access_info_cpp_to_py(const CClusterAccessInfo & c_cluster_info): return { "serverAddr": c_cluster_info.serverAddr.decode(), diff --git a/api/python/yr/functionsdk/__init__.py b/api/python/yr/functionsdk/__init__.py new file mode 100644 index 0000000..5452cb3 --- /dev/null +++ b/api/python/yr/functionsdk/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""faas function sdk""" diff --git a/api/python/yr/functionsdk/context.py b/api/python/yr/functionsdk/context.py new file mode 100644 index 0000000..0f3646c --- /dev/null +++ b/api/python/yr/functionsdk/context.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""faas context""" + +from dataclasses import dataclass, field +import logging +import os +import time +from typing import Any, Dict + +from yr import log +from yr.common import constants +from yr.functionsdk.utils import parse_json_data_to_dict, dump_data_to_json_str + +_RUNTIME_MAX_RESP_BODY_SIZE = 6 * 1024 * 1024 +_RUNTIME_CODE_ROOT = "/opt/function/code" +_RUNTIME_ROOT = "/home/snuser/runtime" +_RUNTIME_LOG_DIR = "/home/snuser/log" +_ENV_STORAGE = None + +_HEADER_ACCESS_KEY: str = "X-Access-Key" +_HEADER_SECRET_KEY: str = "X-Secret-Key" +_HEADER_AUTH_TOKEN: str = "X-Auth-Token" +_HEADER_SECURITY_ACCESS_KEY: str = "X-Security-Access-Key" +_HEADER_SECURITY_SECRET_KEY: str = "X-Security-Secret-Key" +_HEADER_SECURITY_TOKE: str = "X-Security-Token" +_HEADER_REQUEST_ID: str = "X-Request-Id" + + +def load_context_meta(context_meta: dict): + """ + load context meta + """ + global _ENV_STORAGE + _ENV_STORAGE = EnvStorage() + _ENV_STORAGE.load_context_meta(context_meta) + _ENV_STORAGE.load_user_data(_decrypt_user_data()) + + +def init_context(stage: str): + """ + create the context for init handler of user code + configuration user function logger + """ + logger = logging.getLogger(__name__) + options = { + "logger": logger, + "stage": stage, + } + context = Context(options=options) + return context + + +def init_context_invoke(stage: str, header: dict): + """ + create the context for call handler of user code + configuration user function logger + """ + global _ENV_STORAGE + _ENV_STORAGE.update_user_agency(header) + context = init_context(stage) + if _HEADER_REQUEST_ID in header: + context.set_trace_id(header[_HEADER_REQUEST_ID]) + return context + + +class Context: + """Class Context""" + + def __init__(self, options: dict): + self.__project_id = _ENV_STORAGE.env_project_id + self.__package = _ENV_STORAGE.env_package + self.__function_name = _ENV_STORAGE.env_function_name + self.__function_version = _ENV_STORAGE.env_function_version + self.__user_data = _ENV_STORAGE.env_user_data + self.__timeout = _ENV_STORAGE.env_timeout + self.__memory = _ENV_STORAGE.env_memory + self.__cpu = _ENV_STORAGE.env_cpu + self.__start_time = int(time.time() * 1000) + self.__logger = options.get('logger', logging.getLogger(__name__)) + self.__request_id = options.get('requestId', "") + self.__tenant_id = options.get('tenantId', _ENV_STORAGE.env_project_id) + self.__access_key = options.get('accessKey', _ENV_STORAGE.env_access_key) + self.__secret_key = options.get('secretKey', _ENV_STORAGE.env_secret_key) + self.__auth_token = options.get('authToken', _ENV_STORAGE.env_auth_token) + self.__security_access_key = options.get('securityAccessKey', _ENV_STORAGE.env_security_access_key) + self.__security_secret_key = options.get('securitySecretKey', _ENV_STORAGE.env_security_secret_key) + self.__security_token = options.get('securityToken', _ENV_STORAGE.env_security_token) + self.__alias = options.get('alias', _ENV_STORAGE.env_alias) + self.state = None + self.instance_id = None + self.invoke_property = None + self.future_id = options.get('future_id', "") + self.invoke_id = options.get('invoke_id', "") + + # Gets the request ID associated with the request. + def getRequestID(self): + """Method getRequestID""" + return self.__request_id + + def getProjectID(self): + """Method getProjectID""" + return self.__project_id + + def getTenantID(self): + """Method getTenantID""" + return self.__tenant_id + + def getPackage(self): + """Method getPackage""" + return self.__package + + # Gets name of the function + def getFunctionName(self): + """Method getFunctionName""" + return self.__function_name + + def getAlias(self): + """Method getAlias""" + return self.__alias + + # Get version of the function + def getVersion(self): + """Method getVersion""" + return self.__function_version + + # Get the memory size distributed the running function + def getMemorySize(self): + """Method getMemorySize""" + return self.__memory + + # Get the number of cpu distributed to the running function the cpu + # number scale by millicores, one cpu cores equals 1000 millicores. In + # function stage runtime, every function have base of 200 millicores, + # and increased by memory size distributed to function. the offset is + # about Memory Size(M)/128 * 100 + def getCPUNumber(self): + """Method getCPUNumber""" + return self.__cpu + + def getAccessKey(self): + """Method getAccessKey""" + return self.__access_key + + def setAccessKey(self, access_key): + """Method setAccessKey""" + self.__access_key = access_key + + def getSecretKey(self): + """Method getSecretKey""" + return self.__secret_key + + def setSecretKey(self, secret_key): + """Method SetSecretKey""" + self.__secret_key = secret_key + + def getAuthToken(self): + """Method getToken""" + return self.__auth_token + + def setAuthToken(self, auth_token): + """Method setToken""" + self.__auth_token = auth_token + + def getSecurityAccessKey(self): + """Method getAccessKey""" + return self.__security_access_key + + def setSecurityAccessKey(self, security_access_key): + """Method setAccessKey""" + self.__security_access_key = security_access_key + + def getSecuritySecretKey(self): + """Method getSecretKey""" + return self.__security_secret_key + + def setSecuritySecretKey(self, security_secret_key): + """Method SetSecretKey""" + self.__security_secret_key = security_secret_key + + def getSecurityToken(self): + """Method getSecurityToken""" + return self.__security_token + + def setSecurityToken(self, security_token): + """Method getSecurityToken""" + self.__security_token = security_token + + # Gets the user data,which saved in a map + def getUserData(self, key, default=None): + """Method getUserData""" + return self.__user_data.get(key, default) + + # Gets the time distributed to the running of the function, when exceed + # the specified time, the running of the function would be stopped by force + def getRunningTimeInSeconds(self): + """Method getRunningTimeInSeconds""" + return self.__timeout + + # Gets the time remaining for this execution in milliseconds + # Returns time before task is killed + def getRemainingTimeInMilliSeconds(self): + """Method getRemainingTimeInMilliSeconds""" + now = int(time.time() * 1000) + return self.__timeout + self.__start_time - now + + # Gets the logger for user to log out in standard output, The Logger + # interface must be provided in SDK + def getLogger(self): + """Method getLoggers""" + return self.__logger + + def set_state(self, state): + """Method set_state""" + self.state = state + + def get_state(self): + """Method get_state""" + return self.state + + def set_instance_id(self, instance_id): + """Method set_instance_id""" + self.instance_id = instance_id + + def get_instance_id(self): + """Method get_instance_id""" + return self.instance_id + + def get_invoke_id(self): + """Method get_invoke_id""" + return self.invoke_id + + def get_trace_id(self): + """Method get_trace_id""" + return self.__request_id + + def set_trace_id(self, request_id): + """Method get_trace_id""" + self.__request_id = request_id + + def get_invoke_property(self): + """Method get_invoke_property""" + return self.invoke_property + + +@dataclass +class EnvStorage: + """ + env storage + """ + env_project_id: str = "" + env_package: str = "" + env_function_name: str = "" + env_function_version: str = "" + env_user_data: Dict = field(default_factory=dict) + env_timeout: int = 0 + env_cpu: int = 0 + env_memory: int = 0 + env_access_key: str = "" + env_secret_key: str = "" + env_auth_token: str = "" + env_alias: str = "" + env_pre_stop_handler: str = "" + env_pre_stop_timeout: str = "" + env_security_access_key: str = "" + env_security_secret_key: str = "" + env_security_token: str = "" + + initializer_handler: str = "" + initializer_timeout: str = "" + name: str = "" + handler: str = "" + + def load_context_meta(self, context_meta: dict): + """ + load context + """ + func_meta_data = _check_map_value(context_meta, 'funcMetaData', {}) + resource_meta_data = _check_map_value(context_meta, 'resourceMetaData', {}) + extended_meta_data = _check_map_value(context_meta, 'extendedMetaData', {}) + initializer = _check_map_value(extended_meta_data, 'initializer', {}) + pre_stop = _check_map_value(extended_meta_data, "pre_stop", {}) + + self.env_project_id = _check_map_value(func_meta_data, 'tenantId', "") + self.env_package = _check_map_value(func_meta_data, 'service', "") + self.env_function_name = _check_map_value(func_meta_data, 'func_name', "") + self.env_function_version = _check_map_value(func_meta_data, 'version', "") + self.env_timeout = int(_check_map_value(func_meta_data, 'timeout', "3")) + self.env_cpu = int(_check_map_value(resource_meta_data, 'cpu', "0")) + self.env_memory = int(_check_map_value(resource_meta_data, 'memory', "0")) + self.env_alias = context_meta.get('alias', "") + self.env_pre_stop_handler = str(_check_map_value(pre_stop, "pre_stop_handler", "")) + self.env_pre_stop_timeout = str(_check_map_value(pre_stop, "pre_stop_timeout", "")) + + initializer_handler = str(_check_map_value(initializer, 'initializer_handler', "")) + initializer_timeout = str(_check_map_value(initializer, 'initializer_timeout', "")) + name = _check_map_value(func_meta_data, 'name', "") + hander = _check_map_value(func_meta_data, 'handler', "") + self.__write_env(initializer_handler, initializer_timeout, name, hander) + + def load_user_data(self, user_data: Dict): + """ + load user data + """ + self.env_access_key = user_data.get("ENV_ACCESS_KEY", "") + self.env_secret_key = user_data.get("ENV_SECRET_KEY", "") + self.env_auth_token = user_data.get("ENV_AUTH_TOKEN", "") + self.env_security_access_key = user_data.get("ENV_SECURITY_ACCESS_KEY", "") + self.env_security_secret_key = user_data.get("ENV_SECURITY_SECRET_KEY", "") + self.env_security_token = user_data.get("ENV_SECURITY_TOKEN", "") + + user_data["ENV_ALIAS"] = self.env_alias + self.env_user_data = user_data + os.environ["RUNTIME_USERDATA"] = dump_data_to_json_str(user_data) + + def update_user_agency(self, header: dict): + """update user agency""" + if _HEADER_ACCESS_KEY in header: + self.env_access_key = header[_HEADER_ACCESS_KEY] + if _HEADER_SECRET_KEY in header: + self.env_secret_key = header[_HEADER_SECRET_KEY] + if _HEADER_AUTH_TOKEN in header: + self.env_auth_token = header[_HEADER_AUTH_TOKEN] + if _HEADER_SECURITY_ACCESS_KEY in header: + self.env_security_access_key = header[_HEADER_SECURITY_ACCESS_KEY] + if _HEADER_SECURITY_SECRET_KEY in header: + self.env_security_secret_key = header[_HEADER_SECURITY_SECRET_KEY] + if _HEADER_SECURITY_TOKE in header: + self.env_security_token = header[_HEADER_SECURITY_TOKE] + + def __write_env(self, initializer_handler, initializer_timeout, name, hander): + os.environ["RUNTIME_PROJECT_ID"] = self.env_project_id + os.environ["RUNTIME_PACKAGE"] = self.env_package + os.environ["RUNTIME_FUNC_NAME"] = self.env_function_name + os.environ["RUNTIME_FUNC_VERSION"] = self.env_function_version + os.environ["RUNTIME_TIMEOUT"] = str(self.env_timeout) + os.environ["RUNTIME_CPU"] = str(self.env_cpu) + os.environ["RUNTIME_MEMORY"] = str(self.env_memory) + os.environ["RUNTIME_INITIALIZER_HANDLER"] = initializer_handler + os.environ["RUNTIME_INITIALIZER_TIMEOUT"] = initializer_timeout + os.environ["RUNTIME_SERVICE_FUNC_VERSION"] = name + os.environ["RUNTIME_HANDLER"] = hander + os.environ["RUNTIME_ROOT"] = _RUNTIME_ROOT + os.environ["RUNTIME_CODE_ROOT"] = _RUNTIME_CODE_ROOT + os.environ["RUNTIME_LOG_DIR"] = _RUNTIME_LOG_DIR + os.environ["RUNTIME_MAX_RESP_BODY_SIZE"] = str(_RUNTIME_MAX_RESP_BODY_SIZE) + os.environ["PRE_STOP_HANDLER"] = self.env_pre_stop_handler + os.environ["PRE_STOP_TIMEOUT"] = self.env_pre_stop_timeout + + +def _decrypt_user_data() -> dict: + """Decrypts user data from environment variables and returns it as a dictionary. + + Args: + alias (str): The alias of user function. + + Returns: + dict: A dictionary containing the decrypted user data. + """ + env_map = {} + delegate_decrypt = parse_json_data_to_dict(os.environ.get('ENV_DELEGATE_DECRYPT', "")) + # 'environment' could be None or '{}' string after parsing, to be compatible with these two cases, + # the default value is '{}' (not {}) and still have to be parsed. + environment = parse_json_data_to_dict(delegate_decrypt.get('environment', '{}')) + encrypted_user_data = parse_json_data_to_dict(delegate_decrypt.get('encrypted_user_data', '{}')) + + log.get_logger().debug( + f"Succeeded to read from ENV_DELEGATE_DECRYPT, delegate_decrypt={delegate_decrypt}, " + f"environment={environment}, encrypted_user_data={encrypted_user_data}") + + # write environment values + for key in environment: + if key == constants.ENV_KEY_LD_LIBRARY_PATH: + new_path = encrypted_user_data.get( + constants.ENV_KEY_LD_LIBRARY_PATH, + environment.get(constants.ENV_KEY_LD_LIBRARY_PATH, "")) + env_map[key] = os.environ.get(key, "") + f":{new_path}" + os.environ[key] = os.environ.get(key, "") + f":{new_path}" + else: + os.environ[key] = str(environment[key]) + env_map[key] = str(environment[key]) + + for key in encrypted_user_data: + env_map[key] = str(encrypted_user_data[key]) + + return env_map + + +def _check_map_value(check_map: dict, key: str, default: Any) -> Any: + value = check_map.get(key) + if value in ("", {}, None, "{}"): + log.get_logger().warning("%s is %s, using default value: %s", key, value, default) + return default + return value diff --git a/api/python/yr/functionsdk/error_code.py b/api/python/yr/functionsdk/error_code.py new file mode 100644 index 0000000..2c5bfd7 --- /dev/null +++ b/api/python/yr/functionsdk/error_code.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""faas error code""" + +import enum + + +class FaasErrorCode(enum.Enum): + """faas error code""" + # NoneError - + NONE_ERROR = 0 + ENTRY_EXCEPTION = 4001 + FUNCTION_INVOCATION_EXCEPTION = 4002 + # StateContentTooLarge state content is too large + STATE_CONTENT_TOO_LARGE = 4003 + # ResponseExceedLimit response of user function exceeds the platform limit + RESPONSE_EXCEED_LIMIT = 4004 + # UndefinedState state is undefined + UNDEFINED_STATE = 4005 + # HeartBeatFunction Invalid heart beat function of user invalid + HEARTBEAT_FUNCTION_INVALID = 4006 + # FunctionResultInvalid user function result is invalid + FUNCTION_RESULT_INVALID = 4007 + # InitializeFunctionError user initialize function error + INITIALIZE_FUNCTION_ERROR = 4009 + # HeartBeatInvokeError failed to invoke heart beat function + HEARTBEAT_INVOKE_ERROR = 4010 + # InvokeFunctionTimeout user function invoke timeout + INVOKE_FUNCTION_TIMEOUT = 4010 + # InitFunctionTimeout user function init timeout + INIT_FUNCTION_TIMEOUT = 4211 + # RequestBodyExceedLimit request body exceeds limit + REQUEST_BODY_EXCEED_LIMIT = 4140 + # InitFunctionFail function initialization failed + INIT_FUNCTION_FAIL = 4201 + # ShutDownFunctionTimeout user function shutdown timeout + SHUTDOWN_FUNCTION_TIMEOUT = 4202 + # FunctionShutDownError user function failed to shut down + FUNCTION_SHUTDOWN_ERROR = 4203 diff --git a/api/python/yr/functionsdk/function.py b/api/python/yr/functionsdk/function.py new file mode 100644 index 0000000..13df03c --- /dev/null +++ b/api/python/yr/functionsdk/function.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""faas function sdk""" + +from dataclasses import dataclass, field +import json +import os +import re +from typing import Tuple, Union, Dict, List + +from yr import log +from yr.config import InvokeOptions as YRInvokeOptions +from yr.common import constants +from yr.common.constants import META_PREFIX +from yr.libruntime_pb2 import ApiType, FunctionMeta, LanguageType +from yr.object_ref import ObjectRef +from yr.runtime_holder import global_runtime +from yr.functionsdk.context import Context + +DEFAULT_FUNCTION_VERSION = "latest" +DEFAULT_INVOKE_TIMEOUT = 900 +DEFAULT_CONNECTION_NUMS = 128 +DEFAULT_TENANT_ID = "12345678901234561234567890123456" + +FUNC_NAME_REG = r'^[a-zA-Z]([a-zA-Z0-9_-]*[a-zA-Z0-9])?$' +FUNC_NAME_LENGTH_LIMIT = 60 + +VERSION_NAME_REG = r'^[a-zA-Z0-9]([a-zA-Z0-9_-]*\\.)*[a-zA-Z0-9_-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$' +VERSION_NAME_LENGTH_LIMIT = 32 + +ALIAS_PREFIX = "!" +ALIAS_NAME_REG = r'^[a-zA-Z]([a-zA-Z0-9_-]*[a-zA-Z0-9])?$' +ALIAS_NAME_LENGTH_LIMIT = 63 + +FUNCTION_NAME_SEPERATOR = "@" +SPLIT_NUM_OF_FUNC_ID = 1 +FAAS_FUNCTION_RETURN_NUM = 1 + +ENV_KEY_RUNTIME_SERVICE_FUNC_VERSION = "RUNTIME_SERVICE_FUNC_VERSION" + +_RUNTIME_MAX_RESP_BODY_SIZE = 6 * 1024 * 1024 + + +@dataclass +class InvokeOptions: + """faas 调度参数。 + + Examples: + >>> from functionsdk import Function, InvokeOptions + >>> opt = InvokeOptions() + >>> + >>> def my_handler(event, context) + >>> f = Function(context, "hello") + >>> objRef = f.invoke(event) + >>> res = objRef.get() + >>> return { + >>> "statusCode": 200, + >>> "isBase64Encoded": False, + >>> "body": res, + >>> "headers": { + >>> "Content-Type": "application/json" + >>> } + """ + cpu: int = 0 + memory: int = 0 + concurrency: int = 100 + custom_resources: Dict[str, float] = field(default_factory=dict) + pod_labels: Dict[str, str] = field(default_factory=dict) + labels: List[str] = field(default_factory=list) + alias_params: Dict[str, str] = field(default_factory=dict) + + +class CallReq: + """call req""" + + def __init__(self, header: dict = {}, body: str = "") -> None: + self.header = header + self.body = body + + def __dict__(self) -> dict: + return {'header': self.header, 'body': self.body} + + def encode(self) -> dict: + """encode faas request """ + json_str = json.dumps(self.__dict__()) + return META_PREFIX + json_str + + +class Function: + """function sdk""" + + def __init__(self, function_name: str, context_: Context = None) -> None: + self.__function_name, self.__function_version = _check_function_name(function_name) + self.__function_service = _get_service_name_from_env() + self.__function_id = (f"{DEFAULT_TENANT_ID}/0@{self.__function_service}@{self.__function_name}" + f"/{self.__function_version}") + self.invoke_options = InvokeOptions() + self.context = context_ + + def options(self, invoke_options: InvokeOptions): + """ + Set user invoke options + Args: + invoke_options: invoke options for users to set resources + """ + self.invoke_options = invoke_options + return self + + def invoke(self, payload: Union[str, dict] = None) -> ObjectRef: + """调用 faas 函数。 + Args: + payload Union[str, dict]: 被调用 faas 函数的参数 + + Returns: + ObjectRef: 此次调用返回数据系统中的对象的 object_ref + + Examples: + >>> from functionsdk import Function, InvokeOptions + >>> def my_handler(event, context) + >>> f = Function(context, "hello") + >>> objRef = f.invoke(event) + >>> res = objRef.get() + >>> return { + >>> "statusCode": 200, + >>> "isBase64Encoded": False, + >>> "body": res, + >>> "headers": { + >>> "Content-Type": "application/json" + >>> } + """ + func_meta = FunctionMeta(apiType=ApiType.Faas, functionID=self.__function_id, language=LanguageType.Python) + payload_str = _check_payload(payload) + call_req = CallReq(body=payload_str) + args_list = [CallReq().encode(), call_req.encode()] + + obj_list = global_runtime.get_runtime().invoke_by_name( + func_meta=func_meta, + args=args_list, + opt=_convert_invoke_options(self.invoke_options, self.context), + return_nums=FAAS_FUNCTION_RETURN_NUM) + + return ObjectRef(obj_list[constants.INDEX_FIRST]) + + +def _check_payload(payload: Union[str, dict]) -> str: + """Checks whether the payload to call a function set by user is valid or not. + + Args: + payload (Union[str, dict]): Payload set by user as input parameter 'event' in + another function. + + Raises: + TypeError: If payload is not of 'str' or 'dict' type. + ValueError: If payload equals to string 'null'. + TypeError: If payload is not JSON deserializable. + + Returns: + str: The payload after dumping as a JSON string. + """ + if not isinstance(payload, (str, dict)): + msg = f"Invalid type({type(payload)}) of payload, 'str' or 'dict' is expected." + log.get_logger().error(msg) + raise TypeError(msg) + + if isinstance(payload, str): + if payload == 'null': + msg = f"Invalid value of payload: {payload}, it should not be equal to 'null'." + log.get_logger().error(msg) + raise ValueError(msg) + try: + json.loads(payload) + except Exception as err: + msg = f"Invalid payload: {payload}, it is not JSON deserializable." + log.get_logger().error(msg) + raise TypeError(msg) from err + else: + payload = json.dumps(payload) + + if len(payload) > _RUNTIME_MAX_RESP_BODY_SIZE: + msg = f"Event size[{len(payload)}] after serialization should not be larger than {_RUNTIME_MAX_RESP_BODY_SIZE}." + log.get_logger().error(msg) + raise ValueError(msg) + return payload + + +def _check_function_name(function_name: str) -> Tuple[str, str]: + """Checks whether the funciton name set by user is valid or not. Parses the + fucntion name to name and version. + + Args: + function_name (str): Funciton name set by user when initialize the Function + object. + + Raises: + TypeError: If function_name is not a string. + + Returns: + Union[str, str]: The function name after parsing the user input 'function_name' + in Function object. + """ + if not isinstance(function_name, str): + msg = f"Invalid type({type(function_name)}) of parameter 'function_name', 'str' is expected." + log.get_logger().error(msg) + raise TypeError(msg) + + names = function_name.split(':', SPLIT_NUM_OF_FUNC_ID) + if len(names) > SPLIT_NUM_OF_FUNC_ID: + function, version = names + _check_reg_length(function, FUNC_NAME_REG, FUNC_NAME_LENGTH_LIMIT) + + if version.startswith(ALIAS_PREFIX): + alias = version.strip(ALIAS_PREFIX) + _check_reg_length(alias, ALIAS_NAME_REG, ALIAS_NAME_LENGTH_LIMIT) + return function, version + + _check_reg_length(version, VERSION_NAME_REG, VERSION_NAME_LENGTH_LIMIT) + return function, version + + _check_reg_length(function_name, FUNC_NAME_REG, FUNC_NAME_LENGTH_LIMIT) + return function_name, DEFAULT_FUNCTION_VERSION + + +def _get_service_name_from_env(): + """Returns service name read from the environment key-value pairs + using key 'RUNTIME_SERVICE_FUNC_VERSION'. + + Raises: + RuntimeError: When failure of getting service name of the + function from environment key-value pairs happends. + ValueError: The Service name does not contain seperator '@'. + + Returns: + str: The service name of functions. + """ + current_func_id = os.environ.get(ENV_KEY_RUNTIME_SERVICE_FUNC_VERSION) + if current_func_id is None: + msg = ("Failed to get service name of the function from environment key-value pairs. " + f"key: {ENV_KEY_RUNTIME_SERVICE_FUNC_VERSION}") + log.get_logger().error(msg) + raise RuntimeError(msg) + names = current_func_id.split(FUNCTION_NAME_SEPERATOR) + + if len(names) <= SPLIT_NUM_OF_FUNC_ID: + msg = (f"Invalid Environment value({current_func_id}) of key " + f"'{ENV_KEY_RUNTIME_SERVICE_FUNC_VERSION}', " + f"it should contain seperator '{FUNCTION_NAME_SEPERATOR}'") + log.get_logger().error(msg) + raise ValueError(msg) + + log.get_logger().debug(f"Succeeded to get service name '{names[SPLIT_NUM_OF_FUNC_ID]}'.") + return names[SPLIT_NUM_OF_FUNC_ID] + + +def _check_reg_length(name: str, pattern: str, length_limit: int): + """Checks the length of a string and whether it conforms to a specific regular expression. + + Args: + name (str): The string to be checked. + pattern (str): The regular expression to be conformed. + length_limit (int): The maximun length of the string. + + Raises: + ValueError: When the length of the string exceeds the limitation or + the string does not conform a specific regular expression. + """ + name_length = len(name) + if name_length > length_limit or re.match(pattern, name) is None: + if name_length > length_limit: + msg = f"Length of '{name}'({name_length}) is larger than the limitation {length_limit}" + else: + msg = f"'{name}' does not match regular expression {pattern}" + log.get_logger().error(msg) + raise ValueError(msg) + + +def _convert_invoke_options(options: InvokeOptions, context: Context) -> YRInvokeOptions: + """convert invoke options to yr options""" + if options.concurrency == 0: + options.concurrency = 100 + yr_options = YRInvokeOptions() + if context is not None: + yr_options.trace_id = context.get_trace_id() + yr_options.concurrency = options.concurrency + yr_options.cpu = options.cpu + yr_options.memory = options.memory + yr_options.custom_resources = options.custom_resources + yr_options.pod_labels = options.pod_labels + yr_options.labels = options.labels + yr_options.alias_params = options.alias_params + return yr_options diff --git a/api/python/yr/functionsdk/logger.py b/api/python/yr/functionsdk/logger.py new file mode 100644 index 0000000..4b750f6 --- /dev/null +++ b/api/python/yr/functionsdk/logger.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""faas logger""" + +import gzip +import logging +import logging.handlers +import os +import shutil + +from yr.log import CustomFilter, RuntimeLogger + +LOG_DEFAULT_MAX_BYTES = 50 * 1024 * 1024 +LOG_DEFAULT_BACKUP_COUNT = 3 +# user function log default log level +# user function log handler +USER_FUNCTION_LOGGER = None + + +class RotatingFileHandler(logging.handlers.RotatingFileHandler): + """ + RotatingFileHandler + """ + def __init__(self, file_name, mode="a", max_bytes=0, backup_count=0, log_name=""): + super(RotatingFileHandler, self).__init__(file_name, mode, max_bytes, backup_count, None, False) + self.log_name = log_name + self.backup_count = backup_count + self.log_dir = os.path.dirname(log_name) + + def doRollover(self) -> None: + """ + go rollover + """ + super(RotatingFileHandler, self).doRollover() + to_compress = [] + for f in os.listdir(self.log_dir): + if f.startswith(self.log_name) and not f.endswith((".gz", ".log")): + to_compress.append(os.path.join(self.log_dir, f)) + + self._update_zip_file_name() + + for f in to_compress: + if os.path.exists(f): + with open(f, "rb") as _old, gzip.open(f + ".gz", "wb") as _new: + shutil.copyfileobj(_old, _new) + os.remove(f) + + def _update_zip_file_name(self): + zip_log_files = {} + for f in os.listdir(self.log_dir): + if f.startswith(self.log_name) and f.endswith(".gz"): + f_split = f.split(".") + if len(f_split) < 3: + os.remove(os.path.join(self.log_name, f)) + continue + try: + index = int(f_split[-2]) + except IndexError: + os.remove(os.path.join(self.log_dir, f)) + continue + finally: + pass + + if index >= self.backup_count: + os.remove(os.path.join(self.log_dir, f)) + continue + + zip_log_files[index] = os.path.join(self.log_dir, f) + + for i in range(self.backup_count - 1, 0, -1): + if i in zip_log_files: + f_split = zip_log_files.get(i).split(".") + f_split[-2] = str(i + 1) + os.rename(zip_log_files.get(i), ".".join(f_split)) + + del zip_log_files + + +def init_user_function_log(loglevel, logger=None): + """ initialize user function log """ + global USER_FUNCTION_LOGGER + if logger is not None: + USER_FUNCTION_LOGGER = logger + else: + USER_FUNCTION_LOGGER = logging.getLogger("user-function") + USER_FUNCTION_LOGGER.setLevel(loglevel) + USER_FUNCTION_LOGGER.addFilter(CustomFilter()) + log_file_name = os.path.join(RuntimeLogger().get_runtime_log_location(), RuntimeLogger().get_runtime_id() + + "-user-function.log") + handler = RotatingFileHandler(file_name=log_file_name, mode="a", max_bytes=LOG_DEFAULT_MAX_BYTES, + backup_count=LOG_DEFAULT_BACKUP_COUNT, log_name="user-function") + + formatter = logging.Formatter('{"Level":"%(levelname)s",' + '"log":"%(message)s",' + '"projectId":"%(tenant_id)s",' + '"podName":"%(pod_name)s",' + '"package":"%(package)s",' + '"function":"%(function_name)s",' + '"version":"%(version)s",' + '"stream":"%(stream)s",' + '"instanceId":"%(instance_id)s",' + '"requestId":"%(request_id)s",' + '"time":"%(asctime)s.%(msecs)03d",' + '"stage":"%(stage)s",' + '"status":"%(status)s",' + '"finishLog":"%(finish_log)s"}', datefmt='%Y-%m-%dT%H:%M:%S') + handler.setFormatter(formatter) + USER_FUNCTION_LOGGER.addHandler(handler) + + +def get_user_function_logger(log_level: int = logging.DEBUG): + """ + get user function log + return user function log, singleton mode + """ + global USER_FUNCTION_LOGGER + if USER_FUNCTION_LOGGER is None: + init_user_function_log(log_level) + return USER_FUNCTION_LOGGER diff --git a/api/python/yr/functionsdk/logger_manager.py b/api/python/yr/functionsdk/logger_manager.py new file mode 100644 index 0000000..a6b8a5a --- /dev/null +++ b/api/python/yr/functionsdk/logger_manager.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""faas logger manager""" + +import logging +import socket +import threading +from queue import Queue + +from yr.common.singleton import Singleton +from yr.functionsdk import logger + + +class Log: + """ base log entity """ + + def __init__(self, loglevel, msg): + self.loglevel = loglevel + self.msg = msg + + +class FaasLogger: + """ The FaasLogger class provides a way to log messages with different levels, + It enters logs into a queue so as to implement asynchronous printing. 'UserLogManager' + will use global 'USER_FUNCTION_LOG' to write the log got from this queue. + """ + + def __init__(self): + """ + log_level: should be the same as defined in logging. + debug, info, warning, error, critical = logging.DEBUG, logging.INFO, ... + """ + self.queue = None + + def set_queue(self, queue: Queue): + """ set a queue """ + self.queue = queue + + def log(self, level, msg): + """receive log message and put into the queue""" + log = Log(level, msg) + self.queue.put(log) + + def debug(self, msg: str): + """log with level debug""" + self.log(logging.DEBUG, msg) + + def info(self, msg: str): + """log with level debug""" + self.log(logging.INFO, msg) + + def warning(self, msg: str): + """log with level debug""" + self.log(logging.WARNING, msg) + + def error(self, msg: str): + """log with level debug""" + self.log(logging.ERROR, msg) + + +@Singleton +class UserLogManager: + """ + UserLogManager provide user a logger + It needs a queue and a user function logger + It has a loop, receiving user log, running till shutdown() called + """ + + def __init__(self): + self._user_function_log = None + self.log_level = logging.INFO + self.tenant_id = "" + self.function_name = "" + self.version = "" + self.package = "" + self.stream = "" + self.instance_id = "" + self.request_id = "" + self.stage = "" + self.finish_log = "" + self.status = "" + self.queue = None + self.__running = True + + def start_user_log(self): + """ start user function log with a new thread """ + self.__running = True + t = threading.Thread(target=self.run, name="user_log_thread", daemon=True) + t.start() + + def run(self) -> None: + """ a loop runs till user function finished """ + self.insert_start_log() + while self.__running: + ret_log = self.queue.get() + if ret_log is None: + break + self._write_log(ret_log.loglevel, ret_log.msg) + while not self.queue.empty(): + ret_log = self.queue.get() + if ret_log is None: + break + self._write_log(ret_log.loglevel, ret_log.msg) + self.insert_end_log() + + def set_log(self, log): + """ set a user function log """ + self._user_function_log = log + + def load_logger_config(self, cfg): + """ load logger config when function initializing """ + self.log_level = cfg['log_level'] + self.tenant_id = cfg['tenant_id'] + self.function_name = cfg['function_name'] + self.version = cfg['version'] + self.package = cfg['package'] + self.stream = cfg['stream'] + self.instance_id = cfg['instance_id'] + self._user_function_log = logger.get_user_function_logger(cfg['log_level']) + + def set_stage(self, stage: str): + """ + set stage + """ + self.stage = stage + + def insert_start_log(self): + """ insert a first log """ + self.status = "success" + self.finish_log = "false" + self._write_log(self.log_level, f"@@Start {self.stage} Reqeust") + + def insert_end_log(self, status="success"): + """ insert a last log """ + self.status = status + self.finish_log = "true" + self._write_log(self.log_level, f"@@End {self.stage} Reqeust") + + def shutdown(self): + """ stop the loop when user function finished """ + self.__running = False + if self.queue: + self.queue.put(None) + + def register_logger(self, log: FaasLogger): + """ + register logger + """ + self.queue = log.queue + + def _write_log(self, loglevel, message): + extra_info = { + "tenant_id": self.tenant_id, + "pod_name": {'podname': socket.gethostname()}, + "package": self.package, + "function_name": self.function_name, + "version": self.version, + "stream": self.stream, + "instance_id": self.instance_id, + "request_id": self.request_id, + "stage": self.stage, + "status": self.status, + "finish_log": self.finish_log, + } + self._user_function_log.log(loglevel, message, extra=extra_info) diff --git a/api/python/yr/functionsdk/utils.py b/api/python/yr/functionsdk/utils.py new file mode 100644 index 0000000..21045df --- /dev/null +++ b/api/python/yr/functionsdk/utils.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Faas Runtime utils +""" +import base64 +import ctypes +import functools +import json +import socket +import threading +import uuid +from typing import Union + +SUPPORTED_RESOURCE = { + 128: 300, + 256: 400, + 512: 600, + 768: 800, + 1024: 1000, + 1280: 1200, + 1536: 1400, + 1792: 1600, + 2048: 1800, + 2560: 2200, + 3072: 2600, + 3584: 3000, + 4096: 3400, + 8192: 6600, + 10240: 8200 +} + +HOST_NAME = socket.gethostname() +POD_NAME = {'podname': HOST_NAME} +G_TENANT_ID = "" + + +def encode_base64(data: bytes): + """Method encode string to base64""" + base64_result = base64.b64encode(data).decode('utf-8') + return base64_result + + +def convert_obj_to_json(obj): + """Method convert_obj_to_json""" + return obj.__dict__ + + +def is_instance_type(obj): + """Method is_instance_type""" + return hasattr(obj, '__dict__') + + +def to_json_string(obj, indent=None, sort_keys=False): + """Method to_json_string""" + if isinstance(obj, dict): + return json.dumps(obj, indent=indent, sort_keys=sort_keys) + return json.dumps(obj, indent=indent, default=convert_obj_to_json, sort_keys=sort_keys) + + +def generate_request_id() -> str: + """ + Format: task-188ca8cc-35ea-429c-8a07-68163b47c914 + """ + random_id = str(uuid.uuid4()) + return f"task-{random_id}" + + +def generate_trace_id() -> str: + """ + Format: trace-188ca8cc-35ea-429c-8a07-68163b47c914 + """ + random_id = str(uuid.uuid4()) + return f"trace-{random_id}" + + +def set_trace_id(trace_id: str) -> None: + """ + Set the trace id, this is unique in one processing thread + """ + thread_local = threading.local() + thread_local.trace_id = trace_id + + +def get_trace_id() -> str: + """ + Return trace id in this thread, if not set, generate one before return + """ + thread_local = threading.local() + if hasattr(thread_local, "trace_id"): + return thread_local.trace_id + return generate_trace_id() + + +def set_tenant_id(tenant_id) -> None: + """ + Set the tenant id of this instance, unique in this process + """ + global G_TENANT_ID + G_TENANT_ID = tenant_id + + +def get_tenant_id() -> str: + """ + Return the process unique tenant id + """ + return G_TENANT_ID + + +def parse_json_data_to_dict(user_data: Union[str, bytes, bytearray, dict]) -> dict: + '''This function parses user data in JSON format and returns a dictionary. + + Args: + user_data (Union[str, bytes, bytearray]): The user data to be parsed, + which can be a string, bytes, or bytearray. + + Raises: + RuntimeError: If there is an error during JSON parsing. + + Returns: + dict: The parsed user data in dictionary format. + ''' + if isinstance(user_data, dict): + return user_data + result = {} + if user_data in ("", {}, None): + return result + try: + result = json.loads(user_data) + except Exception as e: + raise RuntimeError(f"parse user_data error, err: {e}") from e + return result + + +def dump_data_to_json_str(user_data: object): + """This function converts an object to a JSON string. + + Args: + user_data (object): The object to be converted. + + Raises: + RuntimeError: If there is an error converting the object to a JSON string. + """ + try: + result = json.dumps(user_data) + except Exception as e: + raise RuntimeError(f"dump user_data error, err: {e}") from e + return result + + +def timeout(sec, check_sec=1): + """ + timeout decorator + :param sec: raise TimeoutError after %s seconds + :param check_sec: retry kill thread per %s seconds + default: 1 second + """ + + def decorator(func): + @functools.wraps(func) + def wrapped_func(*args, **kwargs): + + result, exception = [], [] + + def run_user_func(): + try: + response = func(*args, **kwargs) + except TimeoutError: + pass + except SystemExit as e: + exception.append(e) + except BaseException as err: + exception.append(err) + else: + if response: + result.append(response) + else: + result.append("") + + thread = TerminableThread(target=run_user_func, daemon=True) + thread.start() + thread.join(timeout=sec) + + if thread.is_alive(): + exc = type('TimeoutError', TimeoutError.__bases__, dict(TimeoutError.__dict__)) + thread.terminated(exception_cls=exc, check_sec=check_sec) + err_msg = f'invoke timed out after {sec} seconds' + raise TimeoutError(err_msg) + if exception: + raise exception[0] + return result[0] + + return wrapped_func + + return decorator + + +class UserThreadKiller(threading.Thread): + """separate thread to kill TerminableThread""" + + def __init__(self, target_thread, exception_cls, check_sec=2.0): + super().__init__() + self.user_thread = target_thread + self.exception_cls = exception_cls + self.check_sec = check_sec + self.daemon = True + + def run(self): + """loop util user code finished or raise exception""" + while self.user_thread.is_alive(): + ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(self.user_thread.ident), + ctypes.py_object(self.exception_cls)) + self.user_thread.join(self.check_sec) + + +class TerminableThread(threading.Thread): + """a thread that can be stopped by forcing an exception in the execution context""" + + def terminated(self, exception_cls, check_sec=2.0): + """Method to terminated user thread""" + if self.is_alive(): + killer = UserThreadKiller(self, exception_cls, check_sec=check_sec) + killer.start() diff --git a/api/python/yr/generator.py b/api/python/yr/generator.py new file mode 100644 index 0000000..0ca6a32 --- /dev/null +++ b/api/python/yr/generator.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ObjectRefGenerator""" +import asyncio +import logging +from typing import Optional +from yr.object_ref import ObjectRef +from yr.runtime_holder import global_runtime +from yr.fnruntime import GeneratorEndError +from yr.exception import GeneratorFinished + +_logger = logging.getLogger(__name__) + + +class ObjectRefGenerator: + """ ObjectRefGenerator streamming return""" + + def __init__(self, object_ref: ObjectRef): + self._obj_ref = object_ref + self._generator_id = object_ref.id + self._generator_task_exception = None + self._generator_ref = None + self._runtime = global_runtime.get_runtime() + self._stop = False + + def __iter__(self) -> "ObjectRefGenerator": + return self + + def __next__(self) -> ObjectRef: + return self._next_sync() + + def __aiter__(self) -> "ObjectRefGenerator": + return self + + async def __anext__(self): + return await self._next_async() + + def get_generator_id(self): + """get generator id """ + return self._generator_id + + def _next_sync( + self, + timeout_s: Optional[float] = None + ) -> ObjectRef: + """Waits for timeout_s and returns the object ref if available. + + If an object is not available within the given timeout, it + returns a nil object reference. + + If -1 timeout is provided, it means it waits infinitely. + + Waiting is implemented as busy waiting. + + Raises StopIteration if there's no more objects + to generate. + + The object ref will contain an exception if the task fails. + When the generator task returns N objects, it can return + up to N + 1 objects (if there's a system failure, the + last object will contain a system level exception). + + Args: + timeout_s: If the next object is not ready within + this timeout, it returns the nil object ref. + """ + if self._stop: + raise StopIteration + + try: + object_id = self._runtime.peek_object_ref_stream(self._generator_id, True) + except GeneratorEndError as e: + # generator stop error + self._stop = True + raise StopIteration from e + + except Exception as e: + self._generator_task_exception = e + self._stop = True + self._generator_ref = ObjectRef(object_id="", need_incre=False, exception=e) + return self._generator_ref + + if self._stop: + raise StopIteration + self._generator_ref = ObjectRef(object_id, need_incre=False) + return self._generator_ref + + async def _suppress_exceptions(self, ref: ObjectRef) -> None: + # Wrap a streamed ref to avoid asyncio warnings about not retrieving + # the exception when we are just waiting for the ref to become ready. + # The exception will get returned (or warned) to the user once they + # actually await the ref. + try: + data = await ref + self._generator_ref.set_data(data) + except GeneratorFinished as e: + self._stop = True + self._generator_task_exception = e + _logger.debug("generation finished, objectRef %s ", ref) + except Exception as e: + self._stop = True + self._generator_task_exception = e + _logger.debug("failed to await objectRef %s : %s", ref, str(e)) + + async def _next_async( + self, + timeout_s: Optional[float] = None + ): + """Same API as _next_sync, but it is for async context.""" + if self._stop: + raise StopAsyncIteration + + try: + object_id = self._runtime.peek_object_ref_stream(self._generator_id, False) + self._generator_ref = ObjectRef(object_id=object_id, need_incre=False) + _, unready = await asyncio.wait( + [asyncio.create_task(self._suppress_exceptions(self._generator_ref))], + timeout=timeout_s + ) + self._generator_ref.set_exception(self._generator_task_exception) + if len(unready) > 0: + return ObjectRef.nil() + except GeneratorEndError as e: + # generator stop error + self._stop = True + raise StopAsyncIteration from e + + except Exception as e: + self._generator_task_exception = e + self._stop = True + self._generator_ref = ObjectRef(object_id="", need_incre=False, exception=e) + return self._generator_ref + + if self._stop: + raise StopAsyncIteration + return self._generator_ref diff --git a/api/python/yr/includes/affinity.pxd b/api/python/yr/includes/affinity.pxd index 989cedc..f92bd0b 100644 --- a/api/python/yr/includes/affinity.pxd +++ b/api/python/yr/includes/affinity.pxd @@ -22,6 +22,7 @@ from libcpp.vector cimport vector cdef extern from "src/libruntime/fsclient/protobuf/common.pb.h" nogil: cdef cppclass PBAffinity "::common::Affinity" + cdef cppclass PBInstanceAffinity "::common::InstanceAffinity" cdef cppclass CLabelMatchExpression "::common::LabelMatchExpression" cdef extern from "src/dto/affinity.h" nogil: @@ -45,6 +46,9 @@ cdef extern from "src/dto/affinity.h" nogil: void SetRequiredPriority(bool requiredPriority) void SetPreferredAntiOtherLabels(bool preferredAntiOtherLabels) bool GetPreferredAntiOtherLabels() + string GetAffinityScope() const + void SetAffinityScope(const string &affinityScope) + void UpdateAffinityScope(PBInstanceAffinity *pbInstanceAffinity) void UpdatePbAffinity(PBAffinity *pbAffinity) size_t GetAffinityHash() vector[CLabelMatchExpression] GetLabels() diff --git a/api/python/yr/includes/affinity.pxi b/api/python/yr/includes/affinity.pxi index 9fb3273..dde92fe 100644 --- a/api/python/yr/includes/affinity.pxi +++ b/api/python/yr/includes/affinity.pxi @@ -23,6 +23,7 @@ from yr.affinity import ( AffinityType, LabelOperator, OperatorType, + AffinityScope, ) from yr.includes.affinity cimport ( CAffinity, @@ -108,6 +109,11 @@ cdef shared_ptr[CAffinity] affinity_from_py_to_cpp(affinity: Affinity, bool pref raise ValueError("Failed to convert LabelOperator to cpp LabelOperator.") c_operators.push_back(c_operator) + if affinity.affinity_scope == AffinityScope.POD: + c_affinity.get().SetAffinityScope(AffinityScope.POD.name.encode()) + elif affinity.affinity_scope == AffinityScope.NODE: + c_affinity.get().SetAffinityScope(AffinityScope.NODE.name.encode()) + c_affinity.get().SetLabelOperators(c_operators); c_affinity.get().SetPreferredPriority(preferredPriority); c_affinity.get().SetRequiredPriority(requiredPriority); diff --git a/api/python/yr/includes/libruntime.pxd b/api/python/yr/includes/libruntime.pxd index 59a093e..2f229de 100644 --- a/api/python/yr/includes/libruntime.pxd +++ b/api/python/yr/includes/libruntime.pxd @@ -235,6 +235,8 @@ cdef extern from "src/libruntime/libruntime_config.h" nogil: string logDir uint32_t logFileSizeMax uint32_t logFileNumMax + bool logToDriver + bool dedupLogs int logFlushInterval CLibruntimeOptions libruntimeOptions string metaConfig @@ -250,6 +252,9 @@ cdef extern from "src/libruntime/libruntime_config.h" nogil: string privateKeyPath string certificateFilePath string verifyFilePath + char privateKeyPaaswd[MAX_PASSWD_LENGTH] + shared_ptr[void] tlsContext + uint32_t httpIocThreadsNum string serverName string clientId bool inCluster @@ -293,7 +298,9 @@ cdef extern from "src/proto/libruntime.pb.h" nogil: cdef enum CApiType "libruntime::ApiType": ACTOR "libruntime::ApiType::Function", + FAAS "libruntime::ApiType::Faas" POSIX "libruntime::ApiType::Posix" + SERVE "libruntime::ApiType::Serve" cdef enum CSignal "libruntime::Signal": DEFAULTSIGNAL "libruntime::Signal::DefaultSignal", @@ -337,8 +344,8 @@ cdef extern from "src/dto/invoke_options.h" nogil: string signature string functionId CApiType apiType - optional[string] name - optional[string] ns + string name + string ns string initializerCodeId bool isGenerator bool isAsync @@ -377,6 +384,7 @@ cdef extern from "src/dto/invoke_options.h" nogil: int maxInvokeLatency int minInstances int maxInstances + bool isDataAffinity list[shared_ptr[CAffinity]] scheduleAffinities bool needOrder int recoverRetryTimes @@ -388,6 +396,9 @@ cdef extern from "src/dto/invoke_options.h" nogil: bool isGetInstance string traceId string workingDir + bool preemptedAllowed + int instancePriority + int64_t scheduleTimeoutMs cdef cppclass CMetaConfig "YR::Libruntime::MetaConfig": string jobID @@ -523,12 +534,38 @@ cdef extern from "src/libruntime/statestore/state_store.h" nogil: ctypedef pair[vector[shared_ptr[CBuffer]], CErrorInfo] CMultipleReadResult "YR::Libruntime::MultipleReadResult" +cdef extern from "src/dto/stream_conf.h" nogil: + cdef cppclass CElement "YR::Libruntime::Element": + CElement(uint8_t *ptr, uint64_t size, uint64_t id) + CElement() + uint8_t *ptr + uint64_t size + uint64_t id + + cdef cppclass CProducerConf "YR::Libruntime::ProducerConf": + int64_t delayFlushTime + int64_t pageSize + uint64_t maxStreamSize + bool autoCleanup + bool encryptStream + uint64_t retainForNumConsumers + uint64_t reserveSize + unordered_map[string, string] extendConfig + + cdef cppclass CSubscriptionConfig "YR::Libruntime::SubscriptionConfig": + string subscriptionName + CSubscriptionType subscriptionType + unordered_map[string, string] extendConfig + SubscriptionConfig(string subName, const CSubscriptionType subType) + SubscriptionConfig() + cdef extern from "src/dto/resource_unit.h" nogil: cdef cppclass CResourceUnit "YR::Libruntime::ResourceUnit": string id uint32_t status unordered_map[string, float] capacity unordered_map[string, float] allocatable + unordered_map[string, vector[string]] nodeLabels cdef extern from "src/dto/resource_unit.h" nogil: cdef cppclass CScalar "YR::Libruntime::Resource::Scalar": @@ -583,6 +620,19 @@ cdef extern from "src/dto/resource_unit.h" nogil: cdef cppclass CResourceGroupUnit "YR::Libruntime::ResourceGroupUnit": unordered_map[string, CRgInfo] resourceGroups; +cdef extern from "src/libruntime/streamstore/stream_producer_consumer.h" nogil: + cdef cppclass CStreamProducer "YR::Libruntime::StreamProducer": + CErrorInfo Send(const CElement & element) + CErrorInfo Send(const CElement & element, int64_t timeoutMs) + CErrorInfo Flush() + CErrorInfo Close() + + cdef cppclass CStreamConsumer "YR::Libruntime::StreamConsumer": + CErrorInfo Receive(uint32_t expectNum, uint32_t timeoutMs, vector[CElement] & outElements) + CErrorInfo Receive(uint32_t timeoutMs, vector[CElement] & outElements) + CErrorInfo Ack(uint64_t elementId) + CErrorInfo Close() + cdef extern from "src/libruntime/libruntime.h" nogil: cdef cppclass CLibruntime "YR::Libruntime::Libruntime": CLibruntime(shared_ptr[CLibruntimeConfig] config) @@ -630,6 +680,14 @@ cdef extern from "src/libruntime/libruntime.h" nogil: CErrorInfo KVDel(const string & key) CMultipleDelResult KVDel(const vector[string] & keys) + CErrorInfo CreateStreamProducer(const string & streamName, CProducerConf producerConf, + shared_ptr[CStreamProducer] producer) + CErrorInfo CreateStreamConsumer(const string & streamName, const CSubscriptionConfig & config, + shared_ptr[CStreamConsumer] consumer) + CErrorInfo DeleteStream(const string & streamName) + CErrorInfo QueryGlobalProducersNum(const string & streamName, uint64_t & gProducerNum) + CErrorInfo QueryGlobalConsumersNum(const string & streamName, uint64_t & gConsumerNum) + void SetTenantIdWithPriority() string GetTenantId() @@ -653,6 +711,9 @@ cdef extern from "src/libruntime/libruntime.h" nogil: string GenerateGroupName(); + pair[CErrorInfo, string] PeekObjectRefStream(const string & generatorId, bool blocking); + CErrorInfo NotifyGeneratorResult(const string & generatorId, int index, shared_ptr[CDataObject] & resultObj, CErrorInfo & resultErr); + CErrorInfo NotifyGeneratorFinished(const string & generatorId, int numResults); void WaitAsync(const string & objectId, CWaitAsyncCallback callback, void *userData); void GetAsync(const string & objectId, CGetAsyncCallback callback, void *userData); diff --git a/api/python/yr/includes/serialization.pxi b/api/python/yr/includes/serialization.pxi index f029bf2..aeae253 100644 --- a/api/python/yr/includes/serialization.pxi +++ b/api/python/yr/includes/serialization.pxi @@ -31,6 +31,11 @@ DEF METADATA_HEADER_OFFSET = 8 cdef extern from "src/utility/memory.h" namespace "YR::utility" nogil: void CopyInParallel(uint8_t *dst, const uint8_t *src, int64_t totalBytes, size_t blockSize) +cdef extern from "google/protobuf/repeated_field.h" nogil: + cdef cppclass RepeatedField[Element]: + const Element* data() const + + cdef int64_t padded_length(int64_t offsets, int64_t alignment): return ((offsets + alignment - 1) // alignment) * alignment diff --git a/api/python/yr/local_mode/local_mode_runtime.py b/api/python/yr/local_mode/local_mode_runtime.py index 1182813..b639831 100644 --- a/api/python/yr/local_mode/local_mode_runtime.py +++ b/api/python/yr/local_mode/local_mode_runtime.py @@ -25,9 +25,11 @@ from yr.accelerate.shm_broadcast import Handle from yr.common.types import GroupInfo from yr.config import InvokeOptions from yr.exception import YRInvokeError +from yr.stream import ProducerConfig, SubscriptionConfig from yr.common.utils import ( generate_random_id, generate_task_id, GaugeData, UInt64CounterData, DoubleCounterData ) +from yr.fnruntime import Producer, Consumer from yr.local_mode.local_client import LocalClient from yr.local_mode.local_object_store import LocalObjectStore from yr.local_mode.task_manager import TaskManager @@ -339,6 +341,48 @@ class LocalModeRuntime(Runtime, ABC): self.__enable_flag = False self.__local_store.clear() + def create_stream_producer(self, stream_name: str, config: ProducerConfig) -> Producer: + """ + create stream producer + :param stream_name: stream name + :param config: ProducerConfig + :return: producer + """ + raise RuntimeError("not support in local mode") + + def create_stream_consumer(self, stream_name: str, config: SubscriptionConfig) -> Consumer: + """ + create stream consumer + :param stream_name: stream name + :param config: SubscriptionConfig + :return: consumer + """ + raise RuntimeError("not support in local mode") + + def delete_stream(self, stream_name: str) -> None: + """ + delete stream + :param stream_name: stream name + :return: None + """ + raise RuntimeError("not support in local mode") + + def query_global_producers_num(self, stream_name: str) -> int: + """ + query global producers num + :param stream_name: stream name + :return: producers num + """ + raise RuntimeError("not support in local mode") + + def query_global_consumers_num(self, stream_name: str) -> int: + """ + query global consumers num + :param stream_name: stream name + :return: consumers num + """ + raise RuntimeError("not support in local mode") + def get_real_instance_id(self, instance_id: str) -> str: """ get real instance id diff --git a/api/python/yr/log.py b/api/python/yr/log.py index f79f299..0838cac 100644 --- a/api/python/yr/log.py +++ b/api/python/yr/log.py @@ -31,7 +31,7 @@ from yr.common.singleton import Singleton _MAX_ROW_SIZE = 1024 * 1024 # python runtime log location _BASE_LOG_NAME = "yr" -_LOG_SUFFIX = ".log" +_LOG_SUFFIX = "_runtime.log" class CustomFilter(logging.Filterer): @@ -97,7 +97,8 @@ class RuntimeLogger: self.__logger = logging.getLogger("FileLogger") self.__logger.addFilter(CustomFilter()) - self.__logger = logging.LoggerAdapter(self.__logger, {'podname': socket.gethostname()}) + self.__logger = logging.LoggerAdapter(self.__logger, {'podname': socket.gethostname(), + 'runtime_id': self.__runtime_id}) def __init_stream_logger(self, log_level: str) -> None: self.__logger = logging.getLogger(_BASE_LOG_NAME) @@ -117,15 +118,18 @@ class RuntimeLogger: log_file_name = os.getenv("GLOG_log_dir") os.environ["DATASYSTEM_CLIENT_LOG_DIR"] = log_file_name self.__runtime_log_location = log_file_name - - log_file_name = os.path.join(log_file_name, self.__runtime_id + _LOG_SUFFIX) + log_id = os.environ.get("YR_LOG_PREFIX", "") + if len(log_id) != 0: + log_file_name = os.path.join(log_file_name, log_id + _LOG_SUFFIX) + else: + log_file_name = os.path.join(log_file_name, self.__runtime_id + _LOG_SUFFIX) config["handlers"]["file"]["filename"] = log_file_name return log_file_name -def init_logger(is_driver: bool, runtime_id: str = "", log_level: str = "DEBUG") -> None: +def init_logger(on_cloud: bool, runtime_id: str = "", log_level: str = "DEBUG") -> None: """init log handler""" - RuntimeLogger().init(is_driver, runtime_id, log_level) + RuntimeLogger().init(on_cloud, runtime_id, log_level) def get_logger() -> Logger: diff --git a/api/python/yr/main/yr_runtime_main.py b/api/python/yr/main/yr_runtime_main.py index 6b5288e..df5d743 100644 --- a/api/python/yr/main/yr_runtime_main.py +++ b/api/python/yr/main/yr_runtime_main.py @@ -84,6 +84,7 @@ def configure(): config.log_dir = log_dir else: config.log_dir = DEFAULT_LOG_DIR + config.in_cluster = True return config @@ -102,9 +103,9 @@ def insert_sys_path(): def main(): """main""" # If args are invalid, the script automatically exits when calling 'parser.parse_args()'. + insert_sys_path() init(configure()) try_install_uvloop() - insert_sys_path() receive_request_loop() diff --git a/api/python/yr/object_ref.py b/api/python/yr/object_ref.py index 010a89b..83212de 100644 --- a/api/python/yr/object_ref.py +++ b/api/python/yr/object_ref.py @@ -21,8 +21,8 @@ import json from concurrent.futures import Future from typing import Any, Union +from yr.exception import YRInvokeError, YRError, GeneratorFinished from yr.err_type import ErrorInfo, ErrorCode -from yr.exception import YRInvokeError, YRError import yr from yr import log @@ -30,21 +30,26 @@ from yr.common import constants def _set_future_helper( - result: Any, - *, - f: Union[asyncio.Future, Future], + result: Any, + *, + f: Union[asyncio.Future, Future], ): if f.done(): return if isinstance(result, ErrorInfo): + if result.error_code == ErrorCode.ERR_GENERATOR_FINISHED.value: + f.set_exception(GeneratorFinished("")) + return if result.error_code != ErrorCode.ERR_OK.value: f.set_exception(RuntimeError( - f"code: {result.error_code}, module code {result.module_code}, msg: {result.msg}")) + f"code: {result.error_code}, module code {result.module_code}, msg: {result.msg}")) elif isinstance(result, YRInvokeError): f.set_exception(result.origin_error()) elif isinstance(result, YRError): f.set_exception(result) + elif isinstance(result, RuntimeError): + f.set_exception(result) else: f.set_result(result) @@ -128,7 +133,7 @@ class ObjectRef: """ f = Future() if self._exception is not None: - f.set_exception(RuntimeError(str(self._exception))) + _set_future_helper(self._exception, f=f) return f if self._data is not None: f.set_result(self._data) diff --git a/api/python/yr/runtime.py b/api/python/yr/runtime.py index 628545e..6489726 100644 --- a/api/python/yr/runtime.py +++ b/api/python/yr/runtime.py @@ -122,6 +122,11 @@ class MSetParam: class CreateParam: """Create param.""" + def __init__(self): + """ + Initialize a CreateParam instance. + """ + pass #: Configure the reliability of the data. #: When the server is configured to support a secondary cache for ensuring reliability, @@ -411,6 +416,48 @@ class Runtime(metaclass=ABCMeta): :return: None """ + @abstractmethod + def create_stream_producer(self, stream_name: str, config): + """ + create stream producer + :param stream_name: stream name + :param config: ProducerConfig + :return: producer + """ + + @abstractmethod + def create_stream_consumer(self, stream_name: str, config): + """ + create stream consumer + :param stream_name: stream name + :param config: SubscriptionConfig + :return: consumer + """ + + @abstractmethod + def delete_stream(self, stream_name: str) -> None: + """ + delete stream + :param stream_name: stream name + :return: None + """ + + @abstractmethod + def query_global_producers_num(self, stream_name: str) -> int: + """ + query global producers num + :param stream_name: stream name + :return: producers num + """ + + @abstractmethod + def query_global_consumers_num(self, stream_name: str) -> int: + """ + query global consumers num + :param stream_name: stream name + :return: consumers num + """ + @abstractmethod def is_object_existing_in_local(self, object_id: str) -> bool: """ diff --git a/api/python/yr/runtime_env.py b/api/python/yr/runtime_env.py new file mode 100644 index 0000000..7b27d94 --- /dev/null +++ b/api/python/yr/runtime_env.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import uuid +from pathlib import Path +from typing import Dict + +from yr.config import InvokeOptions + +WORKING_DIR_KEY = "WORKING_DIR" +CONDA_PREFIX = "CONDA_PREFIX" + + +def _get_conda_bin_executable(executable_name: str) -> str: + conda_home = os.environ.get("YR_CONDA_HOME") + if conda_home: + return conda_home + if CONDA_PREFIX in os.environ: + return os.environ.get(CONDA_PREFIX) + raise ValueError( + "please configure YR_CONDA_HOME environment variable which contain a bin subdirectory" + ) + + +def _check_pip_and_conda(runtime_env: Dict): + if runtime_env.get("conda") and runtime_env.get("pip"): + raise ValueError( + "The 'pip' field and 'conda' field of " + "runtime_env cannot both be specified.\n" + f"specified pip field: {runtime_env['pip']}\n" + f"specified conda field: {runtime_env['conda']}\n" + "To use pip with conda, please only set the 'conda' " + "field, and specify your pip dependencies " + "within the conda YAML config dict" + ) + + +def _process_pip(opt: InvokeOptions, create_opt: Dict): + _check_pip_and_conda(opt.runtime_env) + pip_command = "pip3 install " + " ".join(opt.runtime_env.get("pip")) + create_opt["POST_START_EXEC"] = pip_command + + +def _process_conda(opt: InvokeOptions, create_opt: Dict): + _check_pip_and_conda(opt.runtime_env) + create_opt[CONDA_PREFIX] = _get_conda_bin_executable("conda") + conda_config = opt.runtime_env.get("conda") + if isinstance(conda_config, str): + yaml_file = Path(conda_config) + if yaml_file.suffix in (".yaml", ".yml"): + if not yaml_file.is_file(): + raise ValueError(f"Can't find conda YAML file {yaml_file}.") + try: + import yaml + result = yaml.safe_load(yaml_file.read_text()) + name = result.get("name", str(uuid.uuid4())) + json_str = json.dumps(result) + create_opt["CONDA_CONFIG"] = json_str + conda_command = "conda env create -f env.yaml" + create_opt["CONDA_COMMAND"] = conda_command + create_opt["CONDA_DEFAULT_ENV"] = name + except Exception as e: + raise ValueError(f"Failed to read conda file {yaml_file}") from e + else: + conda_command = f"conda activate {conda_config}" + create_opt["CONDA_COMMAND"] = conda_command + create_opt["CONDA_DEFAULT_ENV"] = conda_config + if isinstance(conda_config, dict): + try: + json_str = json.dumps(conda_config) + name = conda_config.get("name", str(uuid.uuid4())) + create_opt["CONDA_CONFIG"] = json_str + conda_command = "conda env create -f env.yaml" + create_opt["CONDA_COMMAND"] = conda_command + create_opt["CONDA_DEFAULT_ENV"] = name + except Exception as e: + raise ValueError(f"Failed to load conda {conda_config}") from e + if not isinstance(conda_config, dict) and not isinstance(conda_config, str): + raise TypeError("runtime_env.get('conda') must be of type dict or str") + + +def _process_working_dir(opt: InvokeOptions, create_opt: Dict): + working_dir = opt.runtime_env.get("working_dir") + if not isinstance(working_dir, str): + raise TypeError("`working_dir` must be a string, got " f"{type(working_dir)}.") + create_opt[WORKING_DIR_KEY] = working_dir + + +def _process_env_vars(opt: InvokeOptions, create_opt: Dict): + env_vars = opt.runtime_env.get("env_vars") + if not isinstance(env_vars, dict): + raise TypeError( + "runtime_env.get('env_vars') must be of type " + f"Dict[str, str], got {type(env_vars)}" + ) + for key, val in env_vars.items(): + if not isinstance(key, str): + raise TypeError( + "runtime_env.get('env_vars') must be of type " + f"Dict[str, str], but the key {key} is of type {type(key)}" + ) + if not isinstance(val, str): + raise TypeError( + "runtime_env.get('env_vars') must be of type " + f"Dict[str, str], but the value {val} is of type {type(val)}" + ) + if not opt.env_vars.get(key): + opt.env_vars[key] = val + + +def _process_shared_dir(opt: InvokeOptions, create_opt: Dict): + shared_dir = opt.runtime_env.get("shared_dir") + if not isinstance(shared_dir, dict): + raise TypeError( + "runtime_env.get('shared_dir') must be of type " + f"Dict[str, str], got {type(shared_dir)}") + if "name" not in shared_dir: + raise ValueError("runtime_env.get('shared_dir') contain of 'name'") + name = shared_dir["name"] + if not isinstance(name, str): + raise TypeError( + "runtime_env['shared_dir']['name'] must be of type str" + f"but the value {name} is of type {type(name)}") + if "TTL" in shared_dir: + ttl = shared_dir["TTL"] + else: + ttl = 0 + if not isinstance(ttl, int): + raise TypeError( + "runtime_env['shared_dir']['TTL'] must be of type int" + f"but the value {ttl} is of type {type(ttl)}") + create_opt["DELEGATE_SHARED_DIRECTORY"] = name + create_opt["DELEGATE_SHARED_DIRECTORY_TTL"] = f"{ttl}" + + +_runtime_env_processors = { + "pip": _process_pip, + "conda": _process_conda, + "working_dir": _process_working_dir, + "env_vars": _process_env_vars, + "shared_dir": _process_shared_dir +} + + +def parse_runtime_env(opt: InvokeOptions) -> Dict: + """ + parse runtime env to create options + """ + create_opt = {} + if opt.runtime_env is None: + return create_opt + if not isinstance(opt.runtime_env, dict): + raise TypeError("`InvokeOptions.runtime_env` must be a dict, got " f"{type(opt.runtime_env)}.") + + for key in opt.runtime_env.keys(): + if key not in _runtime_env_processors: + raise ValueError(f"runtime_env.get('{key}') is not supported.") + _runtime_env_processors[key](opt, create_opt) + + return create_opt diff --git a/api/python/yr/serialization/__init__.py b/api/python/yr/serialization/__init__.py index 40fefe7..fd9f265 100644 --- a/api/python/yr/serialization/__init__.py +++ b/api/python/yr/serialization/__init__.py @@ -15,7 +15,6 @@ # limitations under the License. """serialization""" +from yr.serialization.serialization import Serialization, register_pack_hook, register_unpack_hook __all__ = ["Serialization", "register_pack_hook", "register_unpack_hook"] - -from yr.serialization.serialization import Serialization, register_pack_hook, register_unpack_hook diff --git a/api/python/yr/serialization/serialization.py b/api/python/yr/serialization/serialization.py index d55cb26..8d9ba67 100644 --- a/api/python/yr/serialization/serialization.py +++ b/api/python/yr/serialization/serialization.py @@ -37,13 +37,24 @@ class Serialization: self._protocol = 5 @staticmethod - def serialize(value: Any) -> SerializedObject: + def normalize_input(value: Any): + """ + Normalize bytes, memoryview, bytearray to memoryview to avoid copying and attach meta-information + """ + if isinstance(value, bytes): + return constants.Metadata.BYTES, value + if isinstance(value, memoryview): + return constants.Metadata.MEMORYVIEW, value + if isinstance(value, bytearray): + return constants.Metadata.BYTEARRAY, value + raise TypeError(f"Unsupported input type: {type(value)}") + + def serialize(self, value: Any) -> SerializedObject: """serialize""" metadata = constants.Metadata.CROSS_LANGUAGE py_serialized_object = None - if isinstance(value, bytes): - metadata = constants.Metadata.BYTES - msgpack_data = value + if isinstance(value, (bytes, memoryview, bytearray)): + metadata, msgpack_data = self.normalize_input(value) else: msgpack_serialized_object = MessagePackSerializer.serialize(value) msgpack_data = msgpack_serialized_object.msgpack_data @@ -51,6 +62,8 @@ class Serialization: if py_objects: metadata = constants.Metadata.PYTHON py_serialized_object = PySerializer.serialize(py_objects) + if not isinstance(msgpack_data, bytes): + msgpack_data = bytes(msgpack_data) return SerializedObject( metadata=metadata.value, msgpack_data=msgpack_data, @@ -63,7 +76,7 @@ class Serialization: structure. Otherwise, cloudpickles is directly used for deserialization. Args: - values (Union[bytes, List[bytes]]): Bytes data to be deserialized. + buffers (Union[None, Buffer, memoryview, List[Buffer]]): Bytes data to be deserialized. Returns: Union[Any, List]: The deserialized result. @@ -72,7 +85,6 @@ class Serialization: if is_buffer: buffers = [buffers] - pop_local_object_refs() result = [] # Deserializing user code may execute 'import' operation, which is not thread safe. @@ -83,21 +95,21 @@ class Serialization: result.append(None) continue metadata, msgpack_data, py_serialized_data = split_buffer(buf) - if constants.Metadata(metadata) == constants.Metadata.BYTES: - result.append(bytes(msgpack_data)) + if constants.Metadata(metadata) in [ + constants.Metadata.BYTES, + constants.Metadata.BYTEARRAY, + constants.Metadata.MEMORYVIEW + ]: + result.append(memoryview(msgpack_data)) continue - python_objects = [] if constants.Metadata(metadata) == constants.Metadata.PYTHON: python_objects = PySerializer.deserialize(py_serialized_data) - result.append(MessagePackSerializer.deserialize(msgpack_data, python_objects)) - object_refs = pop_local_object_refs() if len(object_refs) != 0: object_ref_ids = [ref.id for ref in object_refs] yr.runtime_holder.global_runtime.get_runtime().increase_global_reference(object_ref_ids) - return result[0] if is_buffer else result def register_pack_hook(self, hook): diff --git a/api/python/yr/stream.py b/api/python/yr/stream.py new file mode 100644 index 0000000..023a7d4 --- /dev/null +++ b/api/python/yr/stream.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""stream""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, Union + + +@dataclass +class ProducerConfig: + """ + The configuration class created by the producer. + """ + + #: After Send, Flush will be triggered after a delay up to the specified duration. + #: < 0: do not auto flush, + #: = 0: flush immediately, + #: > 0: delay duration in milliseconds before flushing. + #: Default value is 5. + delay_flush_time: int = 5 + #: Represents the buffer page size for the producer, in bytes (B). A flush is triggered when a page is full. + #: The value must be greater than 0 and a multiple of 4 KB. + #: Default is ``1`` MB (``1024 * 1024``). + page_size: int = 1024 * 1024 + #: Specifies the maximum shared memory size that a stream can use on a worker, in bytes (B). + #: The default is ``1`` GB ( ``1024 * 1024 * 1024``), + #: and the valid range is [64 KB, size of the worker's shared memory]. + max_stream_size: int = 1024 * 1024 * 1024 + #: Specifies whether the stream enables the auto-cleanup feature. Default is ``false``. + auto_clean_up: bool = False + #: Specifies whether content encryption is enabled for the stream. Default is ``false`` (disabled). + encrypt_stream: bool = False + #: The data sent by the producer will be retained until received by the Nth consumer. + #: The default value is ``0``, meaning that if there are no consumers when the producer sends data, + #: the data will not be retained and may be missed when consumers are created later. + retain_for_num_consumers: int = 0 + #: Represents the reserved memory size in bytes (B). + #: When creating a producer, an attempt will be made to reserve ``reserve_size`` bytes of memory. + #: If the reservation fails, creating the producer will raise an exception. + #: ``reserve_size`` must be an integer multiple of ``page_size`` and within the range ``[0, max_stream_size]``. + #: If ``reserve_size`` is ``0``, it will be set to `page_size`. + #: Default value is ``0``. + reserve_size: int = 0 + #: Extended configuration stored as a dictionary, allowing users to customize configuration items. + #: Default value is an empty dictionary. + extend_config: Dict[str, str] = field(default_factory=dict) + + +class SubscriptionType(Enum): + """ + SubscriptionType + + Attributes: + STREAM: default mode. + ROUND_ROBIN: not support + KEY_PARTITIONS: not support. + """ + STREAM = 0 + ROUND_ROBIN = 1 + KEY_PARTITIONS = 2 + + +@dataclass +class SubscriptionConfig: + """ + The configuration class subscribed by consumers. + """ + #: Subscription name, used to identify subscriptions in the producer configuration. + #: The value of this attribute is a string. + subscription_name: str + #: Subscription type, including ``STREAM``, ``ROUND_ROBIN``, and ``KEY_PARTITIONS``. + #: ``STREAM`` means single consumer consumption within a subscription group, + #: ``ROUND_ROBIN`` means multiple consumers in a subscription group share load in a round-robin manner, + #: ``KEY_PARTITIONS`` means multiple consumers in a subscription group share load by key partitioning. + #: Currently, only ``STREAM`` type is supported; other types are not supported. + #: The default subscription type is ``STREAM``. + subscriptionType: SubscriptionType = SubscriptionType.STREAM + #: Extended configuration. + #: stored in dictionary form, allows users to customize configuration items. + #: The default value is an empty dictionary. the dictionary generated through ``field(default_factory=dict)``. + extend_config: Dict[str, str] = field(default_factory=dict) + + +class Element: + """ + Element class containing an element ID and data buffer. + + Args: + value (Union[bytes, memoryview]): data to send. + ele_id (int, optional): element id. Default to ``0``. + """ + + def __init__(self, value: Union[bytes, memoryview], ele_id: int = 0) -> None: + self.data = value + self.id = ele_id diff --git a/api/python/yr/tests/BUILD.bazel b/api/python/yr/tests/BUILD.bazel index 6d7c638..0d6ce24 100644 --- a/api/python/yr/tests/BUILD.bazel +++ b/api/python/yr/tests/BUILD.bazel @@ -74,6 +74,16 @@ py_test( env = SAN_ENV, ) +py_test( + name = "test_generator", + size = "small", + srcs = ["test_generator.py"], + tags = ["smoke"], + imports = ["../../"], + deps = ["//api/python:yr_lib"], + env = SAN_ENV, +) + py_test( name = "test_instance_manager", size = "small", @@ -84,6 +94,15 @@ py_test( env = SAN_ENV, ) +py_test( + name = "test_faas_handler", + size = "small", + srcs = ["test_faas_handler.py"], + tags = ["smoke"], + imports = ["../../"], + deps = ["//api/python:yr_lib"], + env = SAN_ENV, +) py_test( name = "test_function_handler", size = "small", diff --git a/api/python/yr/tests/test_apis.py b/api/python/yr/tests/test_apis.py index ec212a2..6b0a23c 100644 --- a/api/python/yr/tests/test_apis.py +++ b/api/python/yr/tests/test_apis.py @@ -54,6 +54,7 @@ class TestApi(unittest.TestCase): conf = yr.Config() conf.function_id = "sn:cn:yrk:12345678901234561234567890123456:function:0-yr-test-config-init:$latest" conf.server_address = "127.0.0.1:11111" + conf.in_cluster = False with self.assertRaises(ValueError): yr.init(conf) @@ -91,6 +92,7 @@ class TestApi(unittest.TestCase): def test_yr_init_failed_when_input_invaild_function_id(self): conf = yr.Config() conf.function_id = "111" + conf.in_cluster = False with pytest.raises(ValueError): yr.init(conf) @@ -108,6 +110,10 @@ class TestApi(unittest.TestCase): assert affinity_type == affinity.affinity_type assert label_operators == affinity.label_operators + affinity_scope = yr.AffinityScope.NODE + affinity2 = yr.Affinity(affinity_kind, affinity_type, label_operators, affinity_scope) + assert affinity_scope == affinity2.affinity_scope + @patch("yr.apis.is_initialized") def test_cancel_with_invalid_value(self, is_initialized): is_initialized.return_value = True @@ -294,12 +300,26 @@ class TestApi(unittest.TestCase): @patch("yr.apis.is_initialized") def test_stream(self, is_initialized, get_runtime): mock_runtime = Mock() + mock_runtime.create_stream_producer.return_value = "producer" + mock_runtime.create_stream_consumer.return_value = "consummer" + mock_runtime.query_global_producers_num.return_value = 10 + mock_runtime.query_global_consumers_num.return_value = 10 + mock_runtime.delete_stream.side_effect = RuntimeError("mock exception") mock_runtime.kv_write.side_effect = RuntimeError("mock exception") mock_runtime.kv_read.return_value = ["value"] mock_runtime.kv_del.side_effect = RuntimeError("mock exception") get_runtime.return_value = mock_runtime is_initialized.return_value = True + self.assertEqual(yr.create_stream_producer("", None), "producer") + + self.assertEqual(yr.create_stream_consumer("", None), "consummer") + self.assertEqual(yr.query_global_producers_num(""), 10) + self.assertEqual(yr.query_global_consumers_num(""), 10) + + with self.assertRaises(RuntimeError): + yr.delete_stream("") + with self.assertRaises(RuntimeError): yr.kv_write("key", b"abc") @@ -484,6 +504,11 @@ class TestException(unittest.TestCase): self.assertEqual(err.message, "some exception") self.assertTrue("request1" in str(err)) + def test_generator_finished(self): + e = exception.GeneratorFinished(RuntimeError("some exception")) + with self.assertRaises(Exception): + raise e + def test_deal_with_yr_error(self): origin_err = exception.YRInvokeError("origin_err", "some origin exception") err = exception.YRInvokeError(origin_err, "some new exception") diff --git a/api/python/yr/tests/test_apis_get.py b/api/python/yr/tests/test_apis_get.py new file mode 100644 index 0000000..03a61b7 --- /dev/null +++ b/api/python/yr/tests/test_apis_get.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import yr +from unittest.mock import patch, Mock +from yr.object_ref import ObjectRef + +class TestGet(unittest.TestCase): + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_get_invalid_timeout(self, mock_get_runtime): + obj_ref = yr.object_ref.ObjectRef(123) + time = -10 + mock_get_runtime =Mock() + mock_get_runtime.get.return_value="Parameter 'timeout' should be greater than 0 or equal to -1 (no timeout)" + mock_get_runtime.return_value=mock_get_runtime + yr.apis.set_initialized() + with self.assertRaises(ValueError): + yr.apis.get(obj_ref,time) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/api/python/yr/tests/test_apis_put.py b/api/python/yr/tests/test_apis_put.py new file mode 100644 index 0000000..6be6cca --- /dev/null +++ b/api/python/yr/tests/test_apis_put.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import yr +from yr.object_ref import ObjectRef +from unittest.mock import patch, Mock + + +class TestPut(unittest.TestCase): + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_put_with_memoryview(self, mock_get_runtime): + obj_refs = memoryview(bytearray(10 * 1024 * 1024)) + mock_runtime = Mock() + mock_runtime.put.return_value = 1 + mock_get_runtime.return_value = mock_runtime + yr.apis.set_initialized() + result = yr.apis.put(obj_refs) + self.assertIsInstance(result, ObjectRef) + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_put_with_bytes(self, mock_get_runtime): + obj_refs = bytes(10 * 1024 * 1024) + mock_runtime = Mock() + mock_runtime.put.return_value = 10 + mock_get_runtime.return_value = mock_runtime + yr.apis.set_initialized() + result = yr.apis.put(obj_refs) + self.assertIsInstance(result, ObjectRef) + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_put_with_bytearray(self, mock_get_runtime): + obj_refs = bytearray(10 * 1024 * 1024) + mock_runtime = Mock() + mock_runtime.put.return_value = 1 + mock_get_runtime.return_value = mock_runtime + yr.apis.set_initialized() + result = yr.apis.put(obj_refs) + self.assertIsInstance(result, ObjectRef) + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_put_with_object_ref(self, mock_rt): + mock_rt.return_value.put.side_effect = TypeError + obj = ObjectRef(1) + yr.apis.set_initialized() + with self.assertRaises(TypeError): + yr.put(obj) + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_put_fail(self, mock_rt): + mock_rt.return_value.put.side_effect = RuntimeError("mock error") + with self.assertRaises(RuntimeError): + yr.put(1) + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_put_null_ptr(self,mock_rt): + mock_rt.return_value.put.side_effect = ValueError("value is None or has zero length") + obj = None + yr.apis.set_initialized() + with self.assertRaises(ValueError): + yr.put(obj) + + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_put_len_zero_bytes(self,mock_rt): + mock_rt.return_value.put.side_effect = ValueError + obj = bytes(0) + yr.apis.set_initialized() + with self.assertRaises(ValueError): + yr.put(obj) + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_put_len_zero_bytearray(self,mock_rt): + mock_rt.return_value.put.side_effect = ValueError + obj = bytearray(0) + yr.apis.set_initialized() + with self.assertRaises(ValueError): + yr.put(obj) + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_put_len_zero_memoryview(self, mock_rt): + mock_rt.return_value.put.side_effect = ValueError + o = bytes(0) + obj = memoryview(o) + yr.apis.set_initialized() + with self.assertRaises(ValueError): + yr.put(obj) + + + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/api/python/yr/tests/test_cluster_mode_runtime.py b/api/python/yr/tests/test_cluster_mode_runtime.py index 9901fc0..ce6c63d 100644 --- a/api/python/yr/tests/test_cluster_mode_runtime.py +++ b/api/python/yr/tests/test_cluster_mode_runtime.py @@ -62,6 +62,11 @@ class TestClusterModeRuntime(unittest.TestCase): mock_fnruntime.terminate_group.side_effect = RuntimeError("mock exception") mock_fnruntime.exit.side_effect = RuntimeError("mock exception") mock_fnruntime.receive_request_loop.side_effect = RuntimeError("mock exception") + mock_fnruntime.create_stream_producer.return_value = "producer" + mock_fnruntime.create_stream_consumer.return_value = "consumer" + mock_fnruntime.query_global_producers_num.return_value = 10 + mock_fnruntime.query_global_consumers_num.return_value = 10 + mock_fnruntime.delete_stream.side_effect = RuntimeError("mock exception") mock_fnruntime.get_real_instance_id.return_value = "get_real_instance_id" mock_fnruntime.save_real_instance_id.side_effect = RuntimeError("mock exception") @@ -78,6 +83,8 @@ class TestClusterModeRuntime(unittest.TestCase): mock_fnruntime.get_value_double_counter.return_value = 0.1 mock_fnruntime.get_value_uint64_counter.return_value = 1 + mock_fnruntime.peek_object_ref_stream.return_value = "peek_object_ref_stream" + mock_fnruntime.generate_group_name.return_value = "generate_group_name" mock_fnruntime.get_instances.return_value = ["instance1"] mock_fnruntime.resources.return_value = "resources" @@ -119,6 +126,29 @@ class TestClusterModeRuntime(unittest.TestCase): self.runtime.libruntime.save_state.assert_called_once_with(10 * 1000) self.runtime.libruntime.load_state.assert_called_once_with(10 * 1000) + def test_stream(self): + with self.assertRaises(RuntimeError): + cfg = yr.ProducerConfig() + cfg.max_stream_size = -1 + self.runtime.create_stream_producer("stream", cfg) + with self.assertRaises(RuntimeError): + cfg = yr.ProducerConfig() + cfg.retain_for_num_consumers = -1 + self.runtime.create_stream_producer("stream", cfg) + with self.assertRaises(RuntimeError): + cfg = yr.ProducerConfig() + cfg.reserve_size = -1 + self.runtime.create_stream_producer("stream", cfg) + + self.assertEqual(self.runtime.create_stream_producer("", yr.ProducerConfig()), "producer") + self.assertEqual(self.runtime.create_stream_consumer("", yr.ProducerConfig()), "consumer") + + self.assertEqual(self.runtime.query_global_producers_num(""), 10) + self.assertEqual(self.runtime.query_global_consumers_num(""), 10) + + with self.assertRaises(RuntimeError): + self.runtime.delete_stream("") + def test_instance(self): self.assertEqual(self.runtime.get_real_instance_id(""), "get_real_instance_id") self.assertEqual(self.runtime.get_instances("", "")[0], "instance1") @@ -160,6 +190,25 @@ class TestClusterModeRuntime(unittest.TestCase): self.runtime.increase_global_reference([]) self.runtime.decrease_global_reference([]) + def test_invoke(self): + with self.assertRaises(RuntimeError): + meta = FunctionMeta() + meta.apiType = ApiType.Faas + args = [ObjectRef(""), "abc"] + self.runtime.invoke_by_name(meta, args, None, 1, None) + + self.assertEqual(self.runtime.create_instance(None, [], None), "instance") + self.assertEqual(self.runtime.invoke_instance(None, "", [], None, 1)[0], "invoke_instance") + + with self.assertRaises(RuntimeError): + self.runtime.cancel([], False, False) + self.assertTrue(self.runtime.is_object_existing_in_local("")) + + self.assertEqual(self.runtime.resources(), "resources") + self.assertEqual(self.runtime.get_function_group_context(), "get_function_group_context") + self.assertEqual(self.runtime.get_instance_by_name( + "", "", 1), "get_instance_by_name") + def test_metrics(self): with self.assertRaises(RuntimeError): self.runtime.set_uint64_counter(None) @@ -188,6 +237,10 @@ class TestClusterModeRuntime(unittest.TestCase): self.assertEqual(self.runtime.get_value_uint64_counter(None), 1) self.assertEqual(self.runtime.get_value_double_counter(None), 0.1) + def test_generator(self): + self.assertEqual(self.runtime.peek_object_ref_stream(""), "peek_object_ref_stream") + self.assertEqual(self.runtime.generate_group_name(), "generate_group_name") + def test_private(self): with self.assertRaises(RuntimeError): self.runtime.finalize() diff --git a/api/python/yr/tests/test_code_manager.py b/api/python/yr/tests/test_code_manager.py index f1ac22d..82b8e03 100644 --- a/api/python/yr/tests/test_code_manager.py +++ b/api/python/yr/tests/test_code_manager.py @@ -41,13 +41,20 @@ class TestCodeManager(TestCase): self.cm.load_functions([path]) mock_sys_path.insert.assert_called_once_with(0, path) + @mock.patch("yr.log.get_logger") + def test_load_functions_when_input_invalid_faas_entry(self, mock_logger): + mock_logger.return_value = logger + self.cm.custom_handler = "/tmp" + err = self.cm.load_functions(["test.init", "test.handler"]) + assert err.error_code == ErrorCode.ERR_USER_CODE_LOAD + @mock.patch.object(CodeManager(), 'load_code_from_local') @mock.patch("yr.log.get_logger") def test_load_functions_when_user_code_syntax_err(self, mock_logger, mock_load_code_from_local): mock_logger.return_value = logger mock_load_code_from_local.side_effect = SyntaxError("a syntax error in user code") err = CodeManager().load_functions(["test.init", "test.handler"]) - self.assertTrue(err.error_code == ErrorCode.ERR_OK) + assert err.error_code == ErrorCode.ERR_USER_CODE_LOAD @mock.patch("yr.log.get_logger") def test_entry_load(self, mock_logger): diff --git a/api/python/yr/tests/test_executor.py b/api/python/yr/tests/test_executor.py index 04d82e7..3fcfa4c 100644 --- a/api/python/yr/tests/test_executor.py +++ b/api/python/yr/tests/test_executor.py @@ -21,8 +21,10 @@ from unittest.mock import Mock, patch import yr from yr.executor.posix_handler import PosixHandler +from yr.executor.faas_handler import FaasHandler from yr.executor.function_handler import FunctionHandler from yr.executor.executor import INIT_HANDLER, Executor +import yr.executor.faas_executor as faas from yr.libruntime_pb2 import FunctionMeta, LanguageType, InvokeType, ApiType logger = logging.getLogger(__name__) @@ -42,7 +44,7 @@ class TestExecutor(TestCase): functionName = "add", language=LanguageType.Python, codeID="123456", - apiType=ApiType.Function, + apiType=ApiType.Faas, signature="", name="", ns="", @@ -59,11 +61,33 @@ class TestExecutor(TestCase): from yr.executor.executor import HANDLER self.assertTrue(isinstance(HANDLER, FunctionHandler), f"Failed to load executor, HANDLER type is {type(HANDLER)}") + os.environ[INIT_HANDLER] = "faas_executor.init" + Executor.load_handler() + from yr.executor.executor import HANDLER + self.assertTrue(isinstance(HANDLER, FaasHandler), f"Failed to load executor, HANDLER type is {type(HANDLER)}") + os.environ[INIT_HANDLER] = "posix.init" Executor.load_handler() from yr.executor.executor import HANDLER self.assertTrue(isinstance(HANDLER, PosixHandler), f"Failed to load executor, HANDLER type is {type(HANDLER)}") + @patch("yr.log.get_logger") + def test_executor(self, mock_logger): + mock_logger.return_value = logger + e = Executor(self.function_meta, [], InvokeType.CreateInstanceStateless, 1, None, False) + _, err = e.execute() + + self.assertTrue("faaS failed" in err.msg) + + e = Executor(self.function_meta, [], InvokeType.InvokeFunctionStateless, 1, None, False) + res, _ = e.execute() + self.assertTrue("faas executor find empty user call code" in res[0]) + + e = Executor(self.function_meta, [], 10, 1, None, False) + _, err = e.execute() + self.assertTrue("invalid invoke type" in err.msg) + + if __name__ == "__main__": main() diff --git a/api/python/yr/tests/test_faas_handler.py b/api/python/yr/tests/test_faas_handler.py new file mode 100644 index 0000000..aa58e8a --- /dev/null +++ b/api/python/yr/tests/test_faas_handler.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import json +from unittest import TestCase, main +from unittest.mock import Mock, patch +import yr.executor.faas_executor as faas +from yr.code_manager import CodeManager + +from yr.libruntime_pb2 import FunctionMeta, LanguageType, InvokeType, ApiType +from yr.functionsdk.context import load_context_meta +from yr.err_type import ErrorCode + +logger = logging.getLogger(__name__) + + +class TestFaasExecutor(TestCase): + + @patch("yr.log.get_logger") + @patch("yr.executor.faas_executor.parse_faas_param") + @patch.object(CodeManager(), 'get_code_path') + @patch.object(CodeManager(), 'load') + def test_parse_faas_param(self, mock_load, mock_get_code_path, parse_faas_param, mock_logger): + mock_logger.return_value = logger + mock_get_code_path.return_value = "code_path" + + def user_init_func(context): + return {"body": "hello world"} + + mock_load.return_value = user_init_func + parse_faas_param.return_value = {"funcMetaData": {"handler": "user_init_func"}, + "extendedMetaData": {"initializer": {"initializer_timeout": "10"}}} + + # case1 init success + + res = faas.faas_init_handler(["arg0"]) + self.assertTrue("success" in res) + + # case2 user function exception + def err_user_init_func(context): + raise RuntimeError("Something is wrong with the world") + + mock_load.return_value = err_user_init_func + try: + res = faas.faas_init_handler(["arg0"]) + except Exception as e: + self.assertTrue("Fail to run user init handler" in str(e), str(e)) + + # case3 args exception + parse_faas_param.side_effect = TypeError("mock type error") + try: + res = faas.faas_init_handler(["arg0"]) + except Exception as e: + self.assertTrue("faas init request args undefined" in str(e), str(e)) + + parse_faas_param.side_effect = json.decoder.JSONDecodeError("mock json error", "", 0) + try: + res = faas.faas_init_handler(["arg0"]) + except Exception as e: + self.assertTrue("faas init request args json decode error" in str(e), str(e)) + parse_faas_param.reset_mock() + + @patch("yr.log.get_logger") + @patch("yr.executor.faas_executor.parse_faas_param") + @patch("yr.executor.faas_executor.get_trace_id_from_params") + @patch.object(CodeManager(), 'load') + def test_faas_call_handler(self, mock_load, get_trace_id_from_params, mock_parse_faas_param, mock_logger): + mock_logger.return_value = logger + mock_parse_faas_param.side_effect = None + get_trace_id_from_params.side_effect = None + + def user_call_func(event, context): + return {"body": "hello world"} + + mock_load.return_value = user_call_func + + mock_parse_faas_param.return_value = {"header": "error_header"} + get_trace_id_from_params.return_value = "traceid" + res = faas.faas_call_handler(["arg0", "arg1"]) + self.assertTrue("header type is not dict" in res) + + mock_parse_faas_param.return_value = {"header": {}, "body": "error_body"} + res = faas.faas_call_handler(["arg0", "arg1"]) + self.assertTrue("failed to loads event body err" in res, res) + + load_context_meta({"funcMetaData": {"timeout": "3"}}) + body = {"name": "world"} + mock_parse_faas_param.return_value = {"header": {}, "body": json.dumps(body)} + res = faas.faas_call_handler(["arg0", "arg1"]) + self.assertTrue("hello world" in res, res) + + def user_call_err_func(event, context): + raise RuntimeError("Something is wrong with the world") + + mock_load.return_value = user_call_err_func + res = faas.faas_call_handler(["arg0", "arg1"]) + self.assertTrue("Fail to run user call handler" in res, res) + + def user_call_err_return_func(event, context): + return {1, 2, 3} + + mock_load.return_value = user_call_err_return_func + res = faas.faas_call_handler(["arg0", "arg1"]) + self.assertTrue("failed to convert the result to a JSON string" in res, res) + + @patch("yr.log.get_logger") + @patch.object(CodeManager(), 'load') + def test_faas_shutdown_handler(self, mock_load, mock_logger): + mock_logger.return_value = logger + load_context_meta({"extendedMetaData": {"pre_stop": {"pre_stop_timeout": 10}}}) + + def user_shutdown_func(): + return + + mock_load.return_value = user_shutdown_func + err = faas.faas_shutdown_handler(0) + self.assertTrue(err.error_code == ErrorCode.ERR_OK, err.error_code) + + def user_shutdown_err_func(): + raise RuntimeError("shutdown exception") + + mock_load.return_value = user_shutdown_err_func + err = faas.faas_shutdown_handler(0) + self.assertTrue(err.error_code == ErrorCode.ERR_USER_FUNCTION_EXCEPTION, err.error_code) + + +if __name__ == "__main__": + main() diff --git a/api/python/yr/tests/test_functionsdk.py b/api/python/yr/tests/test_functionsdk.py new file mode 100644 index 0000000..0514894 --- /dev/null +++ b/api/python/yr/tests/test_functionsdk.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import logging +from logging import handlers +from yr.functionsdk import context +from yr.functionsdk import function +from yr.functionsdk import utils +from yr.functionsdk import logger as sdklogger +from yr.functionsdk import logger_manager +from unittest import TestCase, main +from unittest.mock import Mock, patch +import os +import logging +import queue + + + +logger = logging.getLogger(__name__) + + +class TestFunctionSdk(TestCase): + + @patch("yr.log.get_logger") + def test_context(self, mock_logger): + mock_logger.return_value = logger + + delegate = {"environment": {"LD_LIBRARY_PATH": "default_path", "TEST_ENV_KEY": "test_env_value"}, + "encrypted_user_data": {"LD_LIBRARY_PATH": "user_path", "TEST_USER_KEY": "test_user_value"}} + os.environ["ENV_DELEGATE_DECRYPT"] = json.dumps(delegate) + context_meta = {"funcMetaData": {"timeout": "3"}, "extendedMetaData": {"pre_stop": {"pre_stop_timeout": 10}}} + context.load_context_meta(context_meta) + env_user_data = os.environ.get("RUNTIME_USERDATA") + user_data = json.loads(env_user_data) + path = user_data.get("LD_LIBRARY_PATH", "") + self.assertTrue("user_path" in path, path) + self.assertTrue("test_env_value" in user_data.get("TEST_ENV_KEY", ""), user_data) + self.assertTrue("test_user_value" in user_data.get("TEST_USER_KEY", ""), user_data) + + header = {"X-Request-Id": "12345"} + invoke_context = context.init_context_invoke("invoke", header) + self.assertEqual(invoke_context.get_trace_id(), "12345") + self.assertEqual(invoke_context.getUserData("TEST_ENV_KEY"), "test_env_value") + + @patch("yr.log.get_logger") + @patch("yr.runtime_holder.global_runtime.get_runtime") + def test_invoke(self, get_runtime, mock_logger): + mock_logger.return_value = logger + + mock_runtime = Mock() + mock_runtime.invoke_by_name.return_value = ["obj_abcd"] + get_runtime.return_value = mock_runtime + + os.environ["RUNTIME_SERVICE_FUNC_VERSION"] = "0@faashello@hello" + + f = function.Function("hello") + opt = function.InvokeOptions(cpu=200, memory=200) + args = {"body": "test"} + obj = f.options(opt).invoke(args) + self.assertEqual(obj.id, "obj_abcd", obj.id) + + args = [1, 2, 3] + with self.assertRaises(TypeError): + f.options(opt).invoke(args) + + args = 'null' + with self.assertRaises(ValueError): + f.options(opt).invoke(args) + + args = 'abcd' + with self.assertRaises(TypeError): + f.options(opt).invoke(args) + + large_body = "a" * (1024 * 1024 * 7) + args = {"body": large_body} + args_str = json.dumps(args) + with self.assertRaises(ValueError): + f.options(opt).invoke(args_str) + + func_name = {} + with self.assertRaises(TypeError): + function.Function(func_name) + + @patch("yr.log.get_logger") + @patch("yr.runtime_holder.global_runtime.get_runtime") + def test_invoke_alias(self, get_runtime, mock_logger): + mock_logger.return_value = logger + + mock_runtime = Mock() + mock_runtime.invoke_by_name.return_value = ["obj_abcd"] + get_runtime.return_value = mock_runtime + + os.environ["RUNTIME_SERVICE_FUNC_VERSION"] = "0@faashello@hello" + + f = function.Function("hello:alias") + opt = function.InvokeOptions(cpu=200, memory=200, alias_params={"a": "b"}) + args = {"body": "test"} + obj = f.options(opt).invoke(args) + self.assertEqual(obj.id, "obj_abcd", obj.id) + + args = [1, 2, 3] + with self.assertRaises(TypeError): + f.options(opt).invoke(args) + + args = 'null' + with self.assertRaises(ValueError): + f.options(opt).invoke(args) + + args = 'abcd' + with self.assertRaises(TypeError): + f.options(opt).invoke(args) + + large_body = "a" * (1024 * 1024 * 7) + args = {"body": large_body} + args_str = json.dumps(args) + with self.assertRaises(ValueError): + f.options(opt).invoke(args_str) + + func_name = {} + with self.assertRaises(TypeError): + function.Function(func_name) + + @patch("yr.log.get_logger") + def test_private_function(self, mock_logger): + mock_logger.return_value = logger + + func_name = "hello:latest" + name, version = function._check_function_name(func_name) + self.assertEqual(name, "hello", name) + self.assertEqual(version, "latest", version) + + func_name = "hello:!alise" + name, version = function._check_function_name(func_name) + self.assertEqual(name, "hello", name) + self.assertEqual(version, "!alise", version) + + large_body = "1" * (1024) + func_name = f"hello:!{large_body}" + with self.assertRaises(ValueError): + function._check_function_name(func_name) + + del os.environ["RUNTIME_SERVICE_FUNC_VERSION"] + with self.assertRaises(RuntimeError): + function._get_service_name_from_env() + + os.environ["RUNTIME_SERVICE_FUNC_VERSION"] = "faashello" + with self.assertRaises(ValueError): + function._get_service_name_from_env() + + os.environ["RUNTIME_SERVICE_FUNC_VERSION"] = "0@faashello@hello" + service_name = function._get_service_name_from_env() + self.assertEqual(service_name, "faashello", service_name) + + reg = r'^[a-zA-Z]([a-zA-Z0-9_-]*[a-zA-Z0-9])?$' + with self.assertRaises(ValueError): + function._check_reg_length("aaa", reg, 1) + + with self.assertRaises(ValueError): + function._check_reg_length("_*aaa", reg, 10) + + def test_utils(self): + class SampleClass: + def __init__(self, value): + self.value = value + + obj = SampleClass(10) + expected = {"value": 10} + self.assertEqual(utils.convert_obj_to_json(obj), expected) + + obj = {"key": "value"} + expected = json.dumps(obj) + self.assertEqual(utils.to_json_string(obj), expected) + + obj = {"b": 2, "a": 1} + expected = json.dumps(obj, indent=4, sort_keys=True) + self.assertEqual(utils.to_json_string(obj, indent=4, sort_keys=True), expected) + + request_id = utils.generate_request_id() + self.assertTrue(request_id.startswith("task-")) + self.assertEqual(len(request_id), len("task-") + 36) + + trace_id = utils.generate_trace_id() + self.assertTrue(trace_id.startswith("trace-")) + self.assertEqual(len(trace_id), len("trace-") + 36) + + trace_id = "test-id" + utils.set_trace_id(trace_id) + self.assertTrue("trace" in utils.get_trace_id()) + + json_str = '{"key": "value"}' + expected = {"key": "value"} + self.assertEqual(utils.parse_json_data_to_dict(json_str), expected) + + json_bytes = b'{"key": "value"}' + self.assertEqual(utils.parse_json_data_to_dict(json_bytes), expected) + + json_bytearray = bytearray(b'{"key": "value"}') + self.assertEqual(utils.parse_json_data_to_dict(json_bytearray), expected) + + self.assertEqual(utils.parse_json_data_to_dict(""), {}) + self.assertEqual(utils.parse_json_data_to_dict({}), {}) + self.assertEqual(utils.parse_json_data_to_dict(None), {}) + + with self.assertRaises(RuntimeError): + utils.parse_json_data_to_dict("invalid json") + + obj = {"value": 10} + expected = json.dumps(obj) + self.assertEqual(utils.dump_data_to_json_str(obj), expected) + + expected = json.dumps({"value": 10}) + self.assertEqual(utils.dump_data_to_json_str(obj), expected) + + with self.assertRaises(RuntimeError): + utils.dump_data_to_json_str(object()) + + def test_utils_timeout(self): + import time + + @utils.timeout(1, 1 ) + def test_timeout(t): + time.sleep(t) + print("finished sleep", t) + return "hello" + + res = test_timeout(0) + self.assertEqual(res, "hello", res) + with self.assertRaises(TimeoutError): + test_timeout(2) + + def test_sdklogger(self): + print("start test get_user_function_logger 1") + log = sdklogger.get_user_function_logger(logging.DEBUG) + self.assertEqual(log.name, "user-function", log.name) + + def test_sdklogger_mgr(self): + print("start test faaslog 1") + log1 = logger_manager.Log(logging.DEBUG, "test") + self.assertEqual(log1.msg, "test", log1.msg) + + print("start test faaslog") + log2 = logger_manager.FaasLogger() + q = queue.Queue(maxsize=10) + log2.set_queue(q) + log2.log(logging.DEBUG, "testlog") + get_msg = q.get(block=False).msg + self.assertEqual(get_msg, "testlog", get_msg) + + log2.debug("debugtest") + get_msg = q.get(block=False).msg + self.assertEqual(get_msg, "debugtest", get_msg) + + log2.error("errlog") + get_msg = q.get(block=False).msg + self.assertEqual(get_msg, "errlog", get_msg) + + log2.warning("warning") + get_msg = q.get(block=False).msg + self.assertEqual(get_msg, "warning", get_msg) + + log2.info("infolog") + get_msg = q.get(block=False).msg + self.assertEqual(get_msg, "infolog", get_msg) + + def test_UserLogManager(self): + user_logger = logger_manager.UserLogManager() + cfg = {"log_level":logging.DEBUG, "tenant_id":"test_tenant", "function_name":"", "version":"", "package":"", "stream":"", "instance_id":""} + user_logger.load_logger_config(cfg) + user_logger.set_stage("test_stage") + #user_logger.start_user_log() + user_logger.insert_start_log() + user_logger.insert_end_log() + self.assertEqual(user_logger.tenant_id, "test_tenant", user_logger.tenant_id) + user_logger.shutdown() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/api/python/yr/tests/test_generator.py b/api/python/yr/tests/test_generator.py new file mode 100644 index 0000000..ffd7273 --- /dev/null +++ b/api/python/yr/tests/test_generator.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import unittest +import logging +from unittest.mock import Mock, patch +from concurrent.futures import Future + +import yr +from yr.fnruntime import GeneratorEndError +from yr.generator import ObjectRefGenerator +from yr.object_ref import ObjectRef + + +logger = logging.getLogger(__name__) + +class testObjectRef(unittest.TestCase): + + def test_obj(self): + obj = ObjectRef(object_id="object_1234", task_id="task_abcd") + self.assertEqual(obj.task_id, "task_abcd") + self.assertEqual(obj.id, "object_1234") + with patch.object(obj, 'on_complete') as mock_on_complete: + mock_on_complete.return_value = "on_complete" + f = obj.get_future() + self.assertTrue(isinstance(f, Future)) + + with patch.object(obj, 'get_future') as mock_get_future: + f = Future() + f.set_result("result") + mock_get_future.return_value = f + obj.wait() + self.assertTrue(obj.done()) + + with patch.object(obj, 'get_future') as mock_get_future: + f = Future() + f.set_exception(RuntimeError("mock exception")) + mock_get_future.return_value = f + res = obj.is_exception() + self.assertTrue(res, res) + + with patch.object(obj, 'get_future') as mock_get_future: + f = Future() + mock_get_future.return_value = f + obj.cancel() + with self.assertRaises(Exception): + f.result() + + + def test_exception(self): + obj = ObjectRef("object_1234") + obj.set_exception(RuntimeError("mock exception")) + with self.assertRaises(RuntimeError): + obj.exception() + + with self.assertRaises(RuntimeError): + obj.get() + + f = obj.get_future() + self.assertTrue(f.exception() is not None) + + def test_data(self): + obj = ObjectRef("object_1234") + obj.set_data("data") + self.assertEqual(obj.get_future().result(), "data") + + @patch('yr.runtime_holder.global_runtime.get_runtime') + @patch("yr.log.get_logger") + def test_get(self, mock_logger, get_runtime): + mock_logger.return_value = logger + mock_runtime = Mock() + mock_runtime.get.return_value = ["data1"] + get_runtime.return_value = mock_runtime + + obj = ObjectRef("object_1234") + with self.assertRaises(ValueError): + obj.get(-2) + + self.assertEqual(obj.get(), "data1") + + +class TestObjectRefGenerator(unittest.TestCase): + + def setUp(self): + pass + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_iter(self, get_runtime): + mock_runtime = Mock() + mock_runtime.peek_object_ref_stream.return_value = 'test_object_id' + get_runtime.return_value = mock_runtime + generator_id = 'test_id' + generator = ObjectRefGenerator( + ObjectRef(generator_id, need_incre=False)) + self.assertEqual(iter(generator), generator) + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_next_sync(self, get_runtime): + mock_runtime = Mock() + mock_runtime.peek_object_ref_stream.return_value = 'test_object_id' + get_runtime.return_value = mock_runtime + generator_id = 'test_id' + generator = ObjectRefGenerator( + ObjectRef(generator_id, need_incre=False)) + self.assertEqual(generator._next_sync().id, 'test_object_id') + print("finished test_next_sync") + + @patch('yr.runtime_holder.global_runtime.get_runtime') + def test_next_sync_with_exception(self, get_runtime): + mock_runtime = Mock() + mock_runtime.peek_object_ref_stream.side_effect = GeneratorEndError("failed") + get_runtime.return_value = mock_runtime + generator_id = 'test_id' + generator = ObjectRefGenerator( + ObjectRef(generator_id, need_incre=False)) + self.assertRaises(StopIteration, generator._next_sync) + print("finished test_next_sync") + + +if __name__ == '__main__': + unittest.main() diff --git a/api/python/yr/tests/test_local_mode.py b/api/python/yr/tests/test_local_mode.py index 2dafc66..f18713d 100644 --- a/api/python/yr/tests/test_local_mode.py +++ b/api/python/yr/tests/test_local_mode.py @@ -117,6 +117,21 @@ class TestApi(TestCase): lr.finalize() + with self.assertRaises(RuntimeError): + lr.create_stream_producer("stream", None) + + with self.assertRaises(RuntimeError): + lr.create_stream_consumer("stream", None) + + with self.assertRaises(RuntimeError): + lr.delete_stream("stream") + + with self.assertRaises(RuntimeError): + lr.query_global_producers_num("stream") + + with self.assertRaises(RuntimeError): + lr.query_global_consumers_num("stream") + with self.assertRaises(RuntimeError): lr.save_state(1) diff --git a/api/python/yr/tests/test_runtime_env.py b/api/python/yr/tests/test_runtime_env.py new file mode 100644 index 0000000..7636370 --- /dev/null +++ b/api/python/yr/tests/test_runtime_env.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# coding=UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile +import unittest +from unittest.mock import patch, Mock + +from yr import runtime_env, InvokeOptions + + +class TestPut(unittest.TestCase): + + def tearDown(self): + if "YR_CONDA_HOME" in os.environ: + os.environ.pop("YR_CONDA_HOME") + + def test_runtime_env_pip_succeed(self): + opt = InvokeOptions() + opt.runtime_env["pip"] = ["numpy==1.24", "scipy==1.11"] + create_opt = runtime_env.parse_runtime_env(opt) + self.assertEqual(create_opt["POST_START_EXEC"], "pip3 install numpy==1.24 scipy==1.11") + + def test_runtime_env_pip_failed(self): + opt = InvokeOptions() + opt.runtime_env["pip"] = ["numpy==1.24"] + opt.runtime_env["conda"] = "test" + with self.assertRaises(ValueError) as context: + _ = runtime_env.parse_runtime_env(opt) + + def test_runtime_env_conda_succeed(self): + os.environ["YR_CONDA_HOME"] = "/tmp" + opt = InvokeOptions() + opt.runtime_env["conda"] = "test" + create_opt = runtime_env.parse_runtime_env(opt) + self.assertEqual(create_opt["CONDA_COMMAND"], "conda activate test") + self.assertEqual(create_opt["CONDA_DEFAULT_ENV"], "test") + self.assertEqual(create_opt[runtime_env.CONDA_PREFIX], "/tmp") + + with tempfile.NamedTemporaryFile(mode='w+t', delete=True, suffix='.yaml') as temp_file: + temp_file.write("name: test\n") + temp_file.seek(0) + opt.runtime_env["conda"] = temp_file.name + create_opt = runtime_env.parse_runtime_env(opt) + self.assertEqual(json.loads(create_opt["CONDA_CONFIG"]), {"name": "test"}) + self.assertEqual(create_opt["CONDA_COMMAND"], "conda env create -f env.yaml") + self.assertEqual(create_opt["CONDA_DEFAULT_ENV"], "test") + + opt.runtime_env["conda"] = {"name": "test"} + create_opt = runtime_env.parse_runtime_env(opt) + self.assertEqual(json.loads(create_opt["CONDA_CONFIG"]), {"name": "test"}) + self.assertEqual(create_opt["CONDA_COMMAND"], "conda env create -f env.yaml") + self.assertEqual(create_opt["CONDA_DEFAULT_ENV"], "test") + + def test_runtime_env_conda_failed(self): + opt = InvokeOptions() + opt.runtime_env["conda"] = "test" + with self.assertRaises(ValueError) as context: + _ = runtime_env.parse_runtime_env(opt) + + os.environ["YR_CONDA_HOME"] = "/tmp" + + opt.runtime_env["conda"] = [] + with self.assertRaises(TypeError) as context: + _ = runtime_env.parse_runtime_env(opt) + + opt.runtime_env["conda"] = "/tmp/aaa.yaml" + with self.assertRaises(ValueError) as context: + _ = runtime_env.parse_runtime_env(opt) + + opt.runtime_env["conda"] = {"aaa": 1+2j} + with self.assertRaises(ValueError) as context: + _ = runtime_env.parse_runtime_env(opt) + + with tempfile.NamedTemporaryFile(mode='w+t', delete=True, suffix='.yaml') as temp_file: + temp_file.write("name:x\nbb") + temp_file.seek(0) + opt.runtime_env["conda"] = temp_file.name + with self.assertRaises(ValueError) as context: + _ = runtime_env.parse_runtime_env(opt) + + def test_runtime_env_working_dir_succeed(self): + opt = InvokeOptions() + opt.runtime_env["working_dir"] = "aaa" + create_opt = runtime_env.parse_runtime_env(opt) + self.assertEqual(create_opt[runtime_env.WORKING_DIR_KEY], "aaa") + + def test_runtime_env_working_dir_failed(self): + opt = InvokeOptions() + opt.runtime_env["working_dir"] = 1 + with self.assertRaises(TypeError) as context: + _ = runtime_env.parse_runtime_env(opt) + + def test_runtime_env_env_vars_succeed(self): + opt = InvokeOptions() + opt.runtime_env["env_vars"] = {"key": "value"} + _ = runtime_env.parse_runtime_env(opt) + self.assertEqual(opt.env_vars, {"key": "value"}) + + def test_runtime_env_env_vars_failed(self): + opt = InvokeOptions() + opt.runtime_env["env_vars"] = 1 + with self.assertRaises(TypeError) as context: + _ = runtime_env.parse_runtime_env(opt) + + opt.runtime_env["env_vars"] = {"key": 1} + with self.assertRaises(TypeError) as context: + _ = runtime_env.parse_runtime_env(opt) + + opt.runtime_env["env_vars"] = {1: "value"} + with self.assertRaises(TypeError) as context: + _ = runtime_env.parse_runtime_env(opt) + + def test_runtime_env_shared_dir_succeed(self): + opt = InvokeOptions() + opt.runtime_env["shared_dir"] = {"name": "abc", "TTL": 1} + create_opt = runtime_env.parse_runtime_env(opt) + self.assertEqual(create_opt["DELEGATE_SHARED_DIRECTORY"], "abc") + self.assertEqual(create_opt["DELEGATE_SHARED_DIRECTORY_TTL"], "1") + + opt.runtime_env["shared_dir"] = {"name": "abc"} + create_opt = runtime_env.parse_runtime_env(opt) + self.assertEqual(create_opt["DELEGATE_SHARED_DIRECTORY"], "abc") + self.assertEqual(create_opt["DELEGATE_SHARED_DIRECTORY_TTL"], "0") + + def test_runtime_env_shared_dir_failed(self): + opt = InvokeOptions() + opt.runtime_env["shared_dir"] = 1 + with self.assertRaises(TypeError) as context: + _ = runtime_env.parse_runtime_env(opt) + + opt.runtime_env["shared_dir"] = {} + with self.assertRaises(ValueError) as context: + _ = runtime_env.parse_runtime_env(opt) + + opt.runtime_env["shared_dir"] = {"name": 1} + with self.assertRaises(TypeError) as context: + _ = runtime_env.parse_runtime_env(opt) + + opt.runtime_env["shared_dir"] = {"TTL": "aaa"} + with self.assertRaises(ValueError) as context: + _ = runtime_env.parse_runtime_env(opt) + + +if __name__ == '__main__': + unittest.main() diff --git a/api/python/yr/tests/test_serialization.py b/api/python/yr/tests/test_serialization.py index e9140d5..68d66d1 100644 --- a/api/python/yr/tests/test_serialization.py +++ b/api/python/yr/tests/test_serialization.py @@ -17,6 +17,7 @@ import unittest import pickle from unittest.mock import patch, Mock +from yr.common import constants from yr.serialization import Serialization from yr.fnruntime import write_to_cbuffer import numpy as np @@ -70,5 +71,17 @@ class TestApi(unittest.TestCase): assert ret.sum() == value.sum() pickle.HIGHEST_PROTOCOL = default_highest_protocol + def test_normalize_input(self): + input_one = b"hello, world" + res_one, res_two = Serialization().normalize_input(input_one) + assert res_one == constants.Metadata.BYTES + + input_two = memoryview(input_one) + res_three, res_four = Serialization().normalize_input(input_two) + assert res_three == constants.Metadata.MEMORYVIEW + + input_three = bytearray(input_one) + res_five, res_six = Serialization().normalize_input(input_three) + assert res_five == constants.Metadata.BYTEARRAY if __name__ == "__main__": unittest.main() diff --git a/bazel/local_patched_repository.bzl b/bazel/local_patched_repository.bzl index fc0061e..f826dd6 100644 --- a/bazel/local_patched_repository.bzl +++ b/bazel/local_patched_repository.bzl @@ -11,12 +11,14 @@ def _impl(repository_ctx): result = repository_ctx.execute(["patch", "-N", "-p0", "-i", patch_file]) if result.return_code != 0: fail("Failed to patch (%s): %s, %s" % (result.return_code, result.stderr, result.stdout)) - + if repository_ctx.attr.build_file: + repository_ctx.symlink(repository_ctx.attr.build_file, repository_ctx.path("BUILD")) local_patched_repository = repository_rule( implementation=_impl, attrs={ "path": attr.string(mandatory=True), - "patch_files": attr.label_list(allow_files=True) + "patch_files": attr.label_list(allow_files=True), + "build_file": attr.label(allow_single_file=True), }, local = True) \ No newline at end of file diff --git a/bazel/metrics_sdk.bzl b/bazel/metrics_sdk.bzl index f973dc8..ef0ac3f 100644 --- a/bazel/metrics_sdk.bzl +++ b/bazel/metrics_sdk.bzl @@ -8,7 +8,6 @@ cc_library( "lib/liblitebus.so.0.0.1", "lib/libyrlogs.so", "lib/libspdlog.so.1.*", - "lib/libspdlog.so.1.*.0", ]), hdrs = glob(["include/metrics/**/*.h"]), strip_include_prefix = "include", diff --git a/bazel/openssl.bazel b/bazel/openssl.bazel index d465dbc..128cfc5 100644 --- a/bazel/openssl.bazel +++ b/bazel/openssl.bazel @@ -25,13 +25,13 @@ MAKE_TARGETS = [ alias( name = "ssl", - actual = "openssl", + actual = "boringssl_sdk", visibility = ["//visibility:public"], ) alias( name = "crypto", - actual = "openssl", + actual = "boringssl_sdk", visibility = ["//visibility:public"], ) @@ -57,13 +57,6 @@ filegroup( visibility = ["//visibility:public"], ) -filter_files_with_suffix( - name = "shared", - srcs = [":openssl"], - suffix = ".so", - visibility = ["//visibility:public"], -) - cc_library( name = "boringssl_sdk", hdrs = glob(["install/include/**/*.h"]), @@ -75,3 +68,10 @@ cc_library( visibility = ["//visibility:public"], alwayslink = True, ) + +filter_files_with_suffix( + name = "shared", + srcs = glob(["install/lib/lib*.so*"]), + suffix = ".so", + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/bazel/preload_opentelemetry.bzl b/bazel/preload_opentelemetry.bzl new file mode 100644 index 0000000..106e05e --- /dev/null +++ b/bazel/preload_opentelemetry.bzl @@ -0,0 +1,27 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +def preload_opentelemetry(): + http_archive( + name = "com_github_opentelemetry_proto", + sha256 = "df491a05f3fcbf86cc5ba5c9de81f6a624d74d4773d7009d573e37d6e2b6af64", + strip_prefix = "opentelemetry-proto-1.1.0", + urls = ["https://github.com/open-telemetry/opentelemetry-proto/archive/refs/tags/v1.1.0.tar.gz"], + build_file = "@opentelemetry_cpp//bazel:opentelemetry_proto.BUILD", + ) + + http_archive( + name = "com_github_jupp0r_prometheus_cpp", + sha256 = "8104d3b216aae60a1d0bca04adea4ba9ac1748eb1ed8646e123cf8e1591d99a3", + strip_prefix = "prometheus-cpp-1.1.0", + urls = ["https://github.com/jupp0r/prometheus-cpp/archive/refs/tags/v1.1.0.zip"], + ) + + http_archive( + name = "github_nlohmann_json", + build_file = "@opentelemetry_cpp//bazel:nlohmann_json.BUILD", + sha256 = "0deac294b2c96c593d0b7c0fb2385a2f4594e8053a36c52b11445ef4b9defebb", + strip_prefix = "nlohmann-json-v3.11.3", + urls = ["https://gitee.com/mirrors/nlohmann-json/repository/archive/v3.11.3.zip"], + ) + + diff --git a/bazel/yr_go.bzl b/bazel/yr_go.bzl new file mode 100644 index 0000000..8cef7d7 --- /dev/null +++ b/bazel/yr_go.bzl @@ -0,0 +1,51 @@ +go_packages = [ + "runtime_go", +] + +def yr_go_test(name, sanitizer): + asan_options="" + cgo_ldflags="" + coverage_suffix = "_coverage.out" + test_result_out = "test_result.out" + if sanitizer != "off": + cgo_ldflags = "-fsanitize={}".format(sanitizer) + coverage_suffix = "_coverage_{}.out".format(sanitizer) + test_result_out = "test_result_{}.out".format(sanitizer) + if sanitizer == "address": + asan_options= "detect_odr_violation=0" + + native.genrule( + name = name, + srcs = [ + "//api/go/libruntime/cpplibruntime:libcpplibruntime.so", + "@datasystem_sdk//:shared", + "@metrics_sdk//:shared", + ":go_sources", + ], + outs = [package + coverage_suffix for package in go_packages] + [test_result_out], + cmd = r""" + BASE_DIR="$$(pwd)" && + TEST_RESULT_OUT=$$BASE_DIR/$(location {test_result_out}) && + CGO_LINKDIR=$$(dirname $$BASE_DIR/$(location //api/go/libruntime/cpplibruntime:libcpplibruntime.so)) && + DATASYSTEM_DIR=$$(dirname $$BASE_DIR/$(locations @datasystem_sdk//:shared) | head -1) && + METRICS_DIR=$$(dirname $$BASE_DIR/$(locations @metrics_sdk//:shared) | head -1) && + export LD_LIBRARY_PATH=$$LD_LIBRARY_PATH:$$DATASYSTEM_DIR:$$METRICS_DIR:$$CGO_LINKDIR && + export CGO_LDFLAGS="{cgo_ldflags} -L$$CGO_LINKDIR -lcpplibruntime" && + export ASAN_OPTIONS={asan_options} && + OUT_DIR=$$(dirname $$TEST_RESULT_OUT) && + for package in {all_packages}; do + cd $$(realpath $$BASE_DIR/api/go) && + go mod tidy && + go test ./libruntime/... ./faassdk/... ./posixsdk/... ./yr/... -covermode=set -coverprofile=$$OUT_DIR/"$$package"{suffix} -cover -gcflags=all=-l -vet=off || exit 2 + done && + echo "$$OUT_DIR" > $$TEST_RESULT_OUT + """.format( + cgo_ldflags = cgo_ldflags, + asan_options = asan_options, + test_result_out = test_result_out, + all_packages = " ".join(go_packages), + suffix = coverage_suffix, + ), + local = True, + visibility = ["//visibility:public"], + ) \ No newline at end of file diff --git a/build.sh b/build.sh index 527a544..ad9f459 100644 --- a/build.sh +++ b/build.sh @@ -63,9 +63,9 @@ OUTPUT_DIR="${BASE_DIR}/output" OUTPUT_BASE="${BASE_DIR}/build/output" BAZEL_COMMAND="build" BUILD_VERSION="v0.0.1" -BAZEL_OPTIONS="--experimental_cc_shared_library=true --verbose_failures --strategy=CcStrip=standalone" +BAZEL_OPTIONS="--experimental_cc_shared_library=true --verbose_failures --strategy=CcStrip=standalone --@opentelemetry_cpp//api:with_abseil=true" BAZEL_OPTIONS_CONFIG=" --config=release " -BAZEL_TARGETS="//api/cpp:yr_cpp_pkg //api/java:yr_java_pkg //api/python:yr_python_pkg" +BAZEL_TARGETS="//api/cpp:yr_cpp_pkg //api/java:yr_java_pkg //api/python:yr_python_pkg //api/go:yr_go_pkg" BAZEL_PRE_OPTIONS="--output_user_root=${BUILD_BASE} --output_base=${OUTPUT_BASE}" THIRD_PARTY_DIR="$(dirname "$BASE_DIR")/thirdparty" PYTHON3_BIN_PATH="python3.9" @@ -77,6 +77,7 @@ BAZEL_OPTIONS_ENV="" SECBRELLA_CCE="OFF" PACKAGE_ALL="false" LD_LIBRARY_PATH=/opt/buildtools/python3.7/lib:/opt/buildtools/python3.9/lib:/opt/buildtools/python3.11/lib:${LD_LIBRARY_PATH} +BOOST_VERSION="1.87.0" export BUILD_ALL="false" if [ ! -d "${THIRD_PARTY_DIR}" ]; then mkdir -p "${THIRD_PARTY_DIR}" @@ -102,11 +103,40 @@ log_fatal() { exit 1 } +MODULE_LIST=(\ +"runtime_go" +) + PYTHON_VERSION_LIST=(\ "python3.11" \ "python3.10" ) +function go_module_coverage_report() { + COVERAGE_SUFFIX="_coverage.out" + [ "${SANITIZER}" != "off" ] && COVERAGE_SUFFIX="_coverage_${SANITIZER}.out" + MODULE=$1 + MODULE_SOURCE=$(echo "$MODULE" | cut -d '-' -f 1) + pushd ${GO_SRC_BASE} + sed -i "/clibruntime.go/d" ${BASE_DIR}/bazel-bin/go/${MODULE_SOURCE}${COVERAGE_SUFFIX} + gocov convert ${BASE_DIR}/bazel-bin/go/${MODULE_SOURCE}${COVERAGE_SUFFIX} > ${BASE_DIR}/bazel-bin/go/${MODULE_SOURCE}_coverage.json + gocov report ${BASE_DIR}/bazel-bin/go/${MODULE_SOURCE}_coverage.json > ${BASE_DIR}/bazel-bin/go/${MODULE_SOURCE}_coverage.txt + gocov-html ${BASE_DIR}/bazel-bin/go/${MODULE_SOURCE}_coverage.json > ${BASE_DIR}/bazel-bin/go/${MODULE_SOURCE}_coverage.html + popd +} + +function run_go_coverage_report() { + for i in "${!MODULE_LIST[@]}" + do + MODULE_SOURCE=$(echo "${MODULE_LIST[$i]}" | cut -d '-' -f 1) + go_module_coverage_report ${MODULE_LIST[$i]} + coverage_info=$(tail -1 ${BASE_DIR}/bazel-bin/go/${MODULE_SOURCE}_coverage.txt) + ((cov_go_total+=$(echo $coverage_info | awk -F '[()/]' '{print $(NF-1)}'))) + ((cov_go+=$(echo $coverage_info | awk -F '[()/]' '{print $(NF-2)}'))) + done + go_coverage=$(echo "scale=4; $cov_go / $cov_go_total * 100" | bc) +} + function run_java_coverage_report() { rm -rf bazel-testlogs/api/java/liblib_yr_api_sdk/ rm -rf bazel-testlogs/api/java/libfunction_common/ @@ -228,7 +258,7 @@ while getopts 'athr:v:S:DcCgPET:p:bm:j:g' opt; do ;; t) BAZEL_COMMAND="test" - BAZEL_TARGETS="//test/... //api/python/yr/tests/... //api/java:java_tests" + BAZEL_TARGETS="//api/go:yr_go_test //test/... //api/python/yr/tests/... //api/java:java_tests" install_python_requirements ;; T) @@ -261,8 +291,8 @@ while getopts 'athr:v:S:DcCgPET:p:bm:j:g' opt; do ;; c) BAZEL_COMMAND="coverage" - BAZEL_TARGETS="//test/... //api/python/yr/tests/... //api/java:java_tests" - BAZEL_OPTIONS="$BAZEL_OPTIONS --combined_report=lcov --nocache_test_results --instrumentation_filter=^//.*[/:]" + BAZEL_TARGETS="//api/go:yr_go_test //test/... //api/python/yr/tests/..." + BAZEL_OPTIONS="$BAZEL_OPTIONS --combined_report=lcov --nocache_test_results --instrumentation_filter=^//.*[/:] --test_tag_filters=-cgo" install_python_requirements ;; C) @@ -319,7 +349,7 @@ sed -i "s/1.0.0<\/version>/${BUILD_VERSION}<\/version>/" $API_ build_multi_python_version -BAZEL_OPTIONS_ENV="${BAZEL_OPTIONS_ENV} --action_env=BUILD_VERSION=${BUILD_VERSION} --action_env=PYTHON3_BIN_PATH=$PYTHON3_BIN_PATH" +BAZEL_OPTIONS_ENV="${BAZEL_OPTIONS_ENV} --action_env=BOOST_VERSION=$BOOST_VERSION --action_env=GOPATH=$(go env GOPATH) --action_env=GOEXPERIMENT=$(go env GOEXPERIMENT) --action_env=GOCACHE=$(go env GOCACHE) --action_env=BUILD_VERSION=${BUILD_VERSION} --action_env=PYTHON3_BIN_PATH=$PYTHON3_BIN_PATH" BAZEL_OPTIONS="${BAZEL_OPTIONS} ${BAZEL_OPTIONS_CONFIG} ${BAZEL_OPTIONS_ENV}" cd $BASE_DIR @@ -334,11 +364,13 @@ if [ "$BAZEL_COMMAND" == "coverage" ]; then lcov -q -r ${BASE_DIR}/bazel-out/_coverage/_coverage_report.dat '*python*' '*.pb.*' '*test*' -o ${BASE_DIR}/bazel-out/_coverage/_coverage_report.info genhtml -q --ignore-errors source --output genhtml ${BASE_DIR}/bazel-out/_coverage/_coverage_report.info cpp_coverage=$(grep headerCovTableEntryMed genhtml/index.html | head -n 1 | awk -F '>' '{print $2}'| awk -F '<' '{print $1}') + run_go_coverage_report run_java_coverage_report run_python_coverage_report echo "cpp_covearge: $cpp_coverage" >> genhtml/coverage.txt echo "python_covearge: $python_coverage" >> genhtml/coverage.txt echo "java_coverage: $java_coverage%" >> genhtml/coverage.txt + echo "go_coverage: $go_coverage%" >> genhtml/coverage.txt cat genhtml/coverage.txt fi @@ -354,6 +386,6 @@ if [ "$BAZEL_COMMAND" == "build" ]; then fi if [ "$PACKAGE_ALL" == "true" ]; then - bash ${BASE_DIR}/scripts/package.sh -t ${BUILD_VERSION} + bash ${BASE_DIR}/scripts/package.sh -t ${BUILD_VERSION} --python_bin_path ${PYTHON3_BIN_PATH} fi cd - \ No newline at end of file diff --git a/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/FunctionHandler-Options.rst b/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/FunctionHandler-Options.rst index 7807c12..9def0f2 100644 --- a/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/FunctionHandler-Options.rst +++ b/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/FunctionHandler-Options.rst @@ -43,7 +43,7 @@ FunctionHandler::Options auto r1 = YR::Function(AddOne).Options(opts).Invoke(5); YR::Get(r1); - opts.retryTime = 1; + opts.retryTimes = 1; opts.retryChecker = RetryChecker; auto r2 = YR::Function(ThrowRuntimeError).Options(opts).Invoke(); YR::Get(r2); diff --git a/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/InstanceFunctionHandler-Options.rst b/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/InstanceFunctionHandler-Options.rst index 6b183c1..22727a5 100644 --- a/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/InstanceFunctionHandler-Options.rst +++ b/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/InstanceFunctionHandler-Options.rst @@ -16,7 +16,7 @@ InstanceFunctionHandler::Options YR::Init(conf); YR::InvokeOptions opts; - opts.retryTime = 5; + opts.retryTimes = 5; auto ins = YR::Instance(SimpleCaculator::Constructor).Invoke(); auto r3 = ins.Function(&SimpleCaculator::Plus).Options(opts).Invoke(1, 1); diff --git a/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/struct-InvokeOptions.rst b/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/struct-InvokeOptions.rst index 8cb651c..eb007bc 100644 --- a/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/struct-InvokeOptions.rst +++ b/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/struct-InvokeOptions.rst @@ -77,7 +77,7 @@ InvokeOptions 参数 +--------------------------+-----------------------------------------------+-------------------------------------------------------------------------------------------------+------------------------+ | retryTime | size_t | 无状态函数的重试次数。 | ``0`` | +--------------------------+-----------------------------------------------+-------------------------------------------------------------------------------------------------+------------------------+ -| retryChecker | bool (*)(const YR::Exception &e) noexcept | 无状态函数的重试判断钩子,默认空(当 `retryTime = 0` 时无效)。 | - | +| retryChecker | bool (*)(const YR::Exception &e) noexcept | 无状态函数的重试判断钩子,默认空(当 `retryTimes = 0` 时无效)。 | - | +--------------------------+-----------------------------------------------+-------------------------------------------------------------------------------------------------+------------------------+ | priority | size_t | 定义无状态函数的优先级。 | ``0`` | +--------------------------+-----------------------------------------------+-------------------------------------------------------------------------------------------------+------------------------+ diff --git a/docs/multi_language_function_programming_interface/development_guide/data_object/KV.md b/docs/multi_language_function_programming_interface/development_guide/data_object/KV.md index 156c164..3a72821 100644 --- a/docs/multi_language_function_programming_interface/development_guide/data_object/KV.md +++ b/docs/multi_language_function_programming_interface/development_guide/data_object/KV.md @@ -81,7 +81,7 @@ obs: # iamHostName, identityProvider, projectId, regionId at least. # In addition, obsEndpoint and obsBucket need to be specified. enable: false - # Domain name of the IAM token to be obtained. Example: iam.cn-north-7.myhuaweicloud.com. + # Domain name of the IAM token to be obtained. iamHostName: "" # Provider that provides permissions for the ds-worker. Example: csms-datasystem. identityProvider: "" diff --git a/patch/spdlog-change-namespace-and-library-name-with-yr.patch b/patch/spdlog-change-namespace-and-library-name-with-yr.patch new file mode 100644 index 0000000..0a3db0b --- /dev/null +++ b/patch/spdlog-change-namespace-and-library-name-with-yr.patch @@ -0,0 +1,6696 @@ +diff -ruN bench/async_bench.cpp bench/async_bench.cpp +--- bench/async_bench.cpp 2025-07-02 15:16:11.164618390 +0800 ++++ bench/async_bench.cpp 2025-07-02 15:16:19.630618059 +0800 +@@ -27,11 +27,11 @@ + + using namespace std; + using namespace std::chrono; +-using namespace spdlog; +-using namespace spdlog::sinks; ++using namespace yr_spdlog; ++using namespace yr_spdlog::sinks; + using namespace utils; + +-void bench_mt(int howmany, std::shared_ptr log, int thread_count); ++void bench_mt(int howmany, std::shared_ptr log, int thread_count); + + #ifdef _MSC_VER + # pragma warning(push) +@@ -55,14 +55,14 @@ + + void verify_file(const char *filename, int expected_count) + { +- spdlog::info("Verifying {} to contain {} line..", filename, expected_count); ++ yr_spdlog::info("Verifying {} to contain {} line..", filename, expected_count); + auto count = count_lines(filename); + if (count != expected_count) + { +- spdlog::error("Test failed. {} has {} lines instead of {}", filename, count, expected_count); ++ yr_spdlog::error("Test failed. {} has {} lines instead of {}", filename, count, expected_count); + exit(1); + } +- spdlog::info("Line count OK ({})\n", count); ++ yr_spdlog::info("Line count OK ({})\n", count); + } + + #ifdef _MSC_VER +@@ -79,10 +79,10 @@ + + try + { +- spdlog::set_pattern("[%^%l%$] %v"); ++ yr_spdlog::set_pattern("[%^%l%$] %v"); + if (argc == 1) + { +- spdlog::info("Usage: {} ", argv[0]); ++ yr_spdlog::info("Usage: {} ", argv[0]); + return 0; + } + +@@ -95,7 +95,7 @@ + queue_size = atoi(argv[3]); + if (queue_size > 500000) + { +- spdlog::error("Max queue size allowed: 500,000"); ++ yr_spdlog::error("Max queue size allowed: 500,000"); + exit(1); + } + } +@@ -103,44 +103,44 @@ + if (argc > 4) + iters = atoi(argv[4]); + +- auto slot_size = sizeof(spdlog::details::async_msg); +- spdlog::info("-------------------------------------------------"); +- spdlog::info("Messages : {:L}", howmany); +- spdlog::info("Threads : {:L}", threads); +- spdlog::info("Queue : {:L} slots", queue_size); +- spdlog::info("Queue memory : {:L} x {:L} = {:L} KB ", queue_size, slot_size, (queue_size * slot_size) / 1024); +- spdlog::info("Total iters : {:L}", iters); +- spdlog::info("-------------------------------------------------"); ++ auto slot_size = sizeof(yr_spdlog::details::async_msg); ++ yr_spdlog::info("-------------------------------------------------"); ++ yr_spdlog::info("Messages : {:L}", howmany); ++ yr_spdlog::info("Threads : {:L}", threads); ++ yr_spdlog::info("Queue : {:L} slots", queue_size); ++ yr_spdlog::info("Queue memory : {:L} x {:L} = {:L} KB ", queue_size, slot_size, (queue_size * slot_size) / 1024); ++ yr_spdlog::info("Total iters : {:L}", iters); ++ yr_spdlog::info("-------------------------------------------------"); + + const char *filename = "logs/basic_async.log"; +- spdlog::info(""); +- spdlog::info("*********************************"); +- spdlog::info("Queue Overflow Policy: block"); +- spdlog::info("*********************************"); ++ yr_spdlog::info(""); ++ yr_spdlog::info("*********************************"); ++ yr_spdlog::info("Queue Overflow Policy: block"); ++ yr_spdlog::info("*********************************"); + for (int i = 0; i < iters; i++) + { + auto tp = std::make_shared(queue_size, 1); +- auto file_sink = std::make_shared(filename, true); ++ auto file_sink = std::make_shared(filename, true); + auto logger = std::make_shared("async_logger", std::move(file_sink), std::move(tp), async_overflow_policy::block); + bench_mt(howmany, std::move(logger), threads); + // verify_file(filename, howmany); + } + +- spdlog::info(""); +- spdlog::info("*********************************"); +- spdlog::info("Queue Overflow Policy: overrun"); +- spdlog::info("*********************************"); ++ yr_spdlog::info(""); ++ yr_spdlog::info("*********************************"); ++ yr_spdlog::info("Queue Overflow Policy: overrun"); ++ yr_spdlog::info("*********************************"); + // do same test but discard oldest if queue is full instead of blocking + filename = "logs/basic_async-overrun.log"; + for (int i = 0; i < iters; i++) + { + auto tp = std::make_shared(queue_size, 1); +- auto file_sink = std::make_shared(filename, true); ++ auto file_sink = std::make_shared(filename, true); + auto logger = + std::make_shared("async_logger", std::move(file_sink), std::move(tp), async_overflow_policy::overrun_oldest); + bench_mt(howmany, std::move(logger), threads); + } +- spdlog::shutdown(); ++ yr_spdlog::shutdown(); + } + catch (std::exception &ex) + { +@@ -151,7 +151,7 @@ + return 0; + } + +-void thread_fun(std::shared_ptr logger, int howmany) ++void thread_fun(std::shared_ptr logger, int howmany) + { + for (int i = 0; i < howmany; i++) + { +@@ -159,7 +159,7 @@ + } + } + +-void bench_mt(int howmany, std::shared_ptr logger, int thread_count) ++void bench_mt(int howmany, std::shared_ptr logger, int thread_count) + { + using std::chrono::high_resolution_clock; + vector threads; +@@ -182,5 +182,5 @@ + + auto delta = high_resolution_clock::now() - start; + auto delta_d = duration_cast>(delta).count(); +- spdlog::info("Elapsed: {} secs\t {:L}/sec", delta_d, int(howmany / delta_d)); ++ yr_spdlog::info("Elapsed: {} secs\t {:L}/sec", delta_d, int(howmany / delta_d)); + } +diff -ruN bench/bench.cpp bench/bench.cpp +--- bench/bench.cpp 2025-07-02 15:16:11.164618390 +0800 ++++ bench/bench.cpp 2025-07-02 15:16:19.630618059 +0800 +@@ -27,11 +27,11 @@ + #include + #include + +-void bench(int howmany, std::shared_ptr log); +-void bench_mt(int howmany, std::shared_ptr log, size_t thread_count); ++void bench(int howmany, std::shared_ptr log); ++void bench_mt(int howmany, std::shared_ptr log, size_t thread_count); + +-// void bench_default_api(int howmany, std::shared_ptr log); +-// void bench_c_string(int howmany, std::shared_ptr log); ++// void bench_default_api(int howmany, std::shared_ptr log); ++// void bench_c_string(int howmany, std::shared_ptr log); + + static const size_t file_size = 30 * 1024 * 1024; + static const size_t rotating_files = 5; +@@ -39,81 +39,81 @@ + + void bench_threaded_logging(size_t threads, int iters) + { +- spdlog::info("**************************************************************"); +- spdlog::info(spdlog::fmt_lib::format(std::locale("en_US.UTF-8"), "Multi threaded: {:L} threads, {:L} messages", threads, iters)); +- spdlog::info("**************************************************************"); ++ yr_spdlog::info("**************************************************************"); ++ yr_spdlog::info(yr_spdlog::fmt_lib::format(std::locale("en_US.UTF-8"), "Multi threaded: {:L} threads, {:L} messages", threads, iters)); ++ yr_spdlog::info("**************************************************************"); + +- auto basic_mt = spdlog::basic_logger_mt("basic_mt", "logs/basic_mt.log", true); ++ auto basic_mt = yr_spdlog::basic_logger_mt("basic_mt", "logs/basic_mt.log", true); + bench_mt(iters, std::move(basic_mt), threads); +- auto basic_mt_tracing = spdlog::basic_logger_mt("basic_mt/backtrace-on", "logs/basic_mt.log", true); ++ auto basic_mt_tracing = yr_spdlog::basic_logger_mt("basic_mt/backtrace-on", "logs/basic_mt.log", true); + basic_mt_tracing->enable_backtrace(32); + bench_mt(iters, std::move(basic_mt_tracing), threads); + +- spdlog::info(""); +- auto rotating_mt = spdlog::rotating_logger_mt("rotating_mt", "logs/rotating_mt.log", file_size, rotating_files); ++ yr_spdlog::info(""); ++ auto rotating_mt = yr_spdlog::rotating_logger_mt("rotating_mt", "logs/rotating_mt.log", file_size, rotating_files); + bench_mt(iters, std::move(rotating_mt), threads); +- auto rotating_mt_tracing = spdlog::rotating_logger_mt("rotating_mt/backtrace-on", "logs/rotating_mt.log", file_size, rotating_files); ++ auto rotating_mt_tracing = yr_spdlog::rotating_logger_mt("rotating_mt/backtrace-on", "logs/rotating_mt.log", file_size, rotating_files); + rotating_mt_tracing->enable_backtrace(32); + bench_mt(iters, std::move(rotating_mt_tracing), threads); + +- spdlog::info(""); +- auto daily_mt = spdlog::daily_logger_mt("daily_mt", "logs/daily_mt.log"); ++ yr_spdlog::info(""); ++ auto daily_mt = yr_spdlog::daily_logger_mt("daily_mt", "logs/daily_mt.log"); + bench_mt(iters, std::move(daily_mt), threads); +- auto daily_mt_tracing = spdlog::daily_logger_mt("daily_mt/backtrace-on", "logs/daily_mt.log"); ++ auto daily_mt_tracing = yr_spdlog::daily_logger_mt("daily_mt/backtrace-on", "logs/daily_mt.log"); + daily_mt_tracing->enable_backtrace(32); + bench_mt(iters, std::move(daily_mt_tracing), threads); + +- spdlog::info(""); +- auto empty_logger = std::make_shared("level-off"); +- empty_logger->set_level(spdlog::level::off); ++ yr_spdlog::info(""); ++ auto empty_logger = std::make_shared("level-off"); ++ empty_logger->set_level(yr_spdlog::level::off); + bench(iters, empty_logger); +- auto empty_logger_tracing = std::make_shared("level-off/backtrace-on"); +- empty_logger_tracing->set_level(spdlog::level::off); ++ auto empty_logger_tracing = std::make_shared("level-off/backtrace-on"); ++ empty_logger_tracing->set_level(yr_spdlog::level::off); + empty_logger_tracing->enable_backtrace(32); + bench(iters, empty_logger_tracing); + } + + void bench_single_threaded(int iters) + { +- spdlog::info("**************************************************************"); +- spdlog::info(spdlog::fmt_lib::format(std::locale("en_US.UTF-8"), "Single threaded: {} messages", iters)); +- spdlog::info("**************************************************************"); ++ yr_spdlog::info("**************************************************************"); ++ yr_spdlog::info(yr_spdlog::fmt_lib::format(std::locale("en_US.UTF-8"), "Single threaded: {} messages", iters)); ++ yr_spdlog::info("**************************************************************"); + +- auto basic_st = spdlog::basic_logger_st("basic_st", "logs/basic_st.log", true); ++ auto basic_st = yr_spdlog::basic_logger_st("basic_st", "logs/basic_st.log", true); + bench(iters, std::move(basic_st)); + +- auto basic_st_tracing = spdlog::basic_logger_st("basic_st/backtrace-on", "logs/basic_st.log", true); ++ auto basic_st_tracing = yr_spdlog::basic_logger_st("basic_st/backtrace-on", "logs/basic_st.log", true); + bench(iters, std::move(basic_st_tracing)); + +- spdlog::info(""); +- auto rotating_st = spdlog::rotating_logger_st("rotating_st", "logs/rotating_st.log", file_size, rotating_files); ++ yr_spdlog::info(""); ++ auto rotating_st = yr_spdlog::rotating_logger_st("rotating_st", "logs/rotating_st.log", file_size, rotating_files); + bench(iters, std::move(rotating_st)); +- auto rotating_st_tracing = spdlog::rotating_logger_st("rotating_st/backtrace-on", "logs/rotating_st.log", file_size, rotating_files); ++ auto rotating_st_tracing = yr_spdlog::rotating_logger_st("rotating_st/backtrace-on", "logs/rotating_st.log", file_size, rotating_files); + rotating_st_tracing->enable_backtrace(32); + bench(iters, std::move(rotating_st_tracing)); + +- spdlog::info(""); +- auto daily_st = spdlog::daily_logger_st("daily_st", "logs/daily_st.log"); ++ yr_spdlog::info(""); ++ auto daily_st = yr_spdlog::daily_logger_st("daily_st", "logs/daily_st.log"); + bench(iters, std::move(daily_st)); +- auto daily_st_tracing = spdlog::daily_logger_st("daily_st/backtrace-on", "logs/daily_st.log"); ++ auto daily_st_tracing = yr_spdlog::daily_logger_st("daily_st/backtrace-on", "logs/daily_st.log"); + daily_st_tracing->enable_backtrace(32); + bench(iters, std::move(daily_st_tracing)); + +- spdlog::info(""); +- auto empty_logger = std::make_shared("level-off"); +- empty_logger->set_level(spdlog::level::off); ++ yr_spdlog::info(""); ++ auto empty_logger = std::make_shared("level-off"); ++ empty_logger->set_level(yr_spdlog::level::off); + bench(iters, empty_logger); + +- auto empty_logger_tracing = std::make_shared("level-off/backtrace-on"); +- empty_logger_tracing->set_level(spdlog::level::off); ++ auto empty_logger_tracing = std::make_shared("level-off/backtrace-on"); ++ empty_logger_tracing->set_level(yr_spdlog::level::off); + empty_logger_tracing->enable_backtrace(32); + bench(iters, empty_logger_tracing); + } + + int main(int argc, char *argv[]) + { +- spdlog::set_automatic_registration(false); +- spdlog::default_logger()->set_pattern("[%^%l%$] %v"); ++ yr_spdlog::set_automatic_registration(false); ++ yr_spdlog::default_logger()->set_pattern("[%^%l%$] %v"); + int iters = 250000; + size_t threads = 4; + try +@@ -130,7 +130,7 @@ + + if (threads > max_threads) + { +- throw std::runtime_error(spdlog::fmt_lib::format("Number of threads exceeds maximum({})", max_threads)); ++ throw std::runtime_error(yr_spdlog::fmt_lib::format("Number of threads exceeds maximum({})", max_threads)); + } + + bench_single_threaded(iters); +@@ -139,13 +139,13 @@ + } + catch (std::exception &ex) + { +- spdlog::error(ex.what()); ++ yr_spdlog::error(ex.what()); + return EXIT_FAILURE; + } + return EXIT_SUCCESS; + } + +-void bench(int howmany, std::shared_ptr log) ++void bench(int howmany, std::shared_ptr log) + { + using std::chrono::duration; + using std::chrono::duration_cast; +@@ -160,12 +160,12 @@ + auto delta = high_resolution_clock::now() - start; + auto delta_d = duration_cast>(delta).count(); + +- spdlog::info(spdlog::fmt_lib::format( ++ yr_spdlog::info(yr_spdlog::fmt_lib::format( + std::locale("en_US.UTF-8"), "{:<30} Elapsed: {:0.2f} secs {:>16L}/sec", log->name(), delta_d, int(howmany / delta_d))); +- spdlog::drop(log->name()); ++ yr_spdlog::drop(log->name()); + } + +-void bench_mt(int howmany, std::shared_ptr log, size_t thread_count) ++void bench_mt(int howmany, std::shared_ptr log, size_t thread_count) + { + using std::chrono::duration; + using std::chrono::duration_cast; +@@ -191,34 +191,34 @@ + + auto delta = high_resolution_clock::now() - start; + auto delta_d = duration_cast>(delta).count(); +- spdlog::info(spdlog::fmt_lib::format( ++ yr_spdlog::info(yr_spdlog::fmt_lib::format( + std::locale("en_US.UTF-8"), "{:<30} Elapsed: {:0.2f} secs {:>16L}/sec", log->name(), delta_d, int(howmany / delta_d))); +- spdlog::drop(log->name()); ++ yr_spdlog::drop(log->name()); + } + + /* +-void bench_default_api(int howmany, std::shared_ptr log) ++void bench_default_api(int howmany, std::shared_ptr log) + { + using std::chrono::high_resolution_clock; + using std::chrono::duration; + using std::chrono::duration_cast; + +- auto orig_default = spdlog::default_logger(); +- spdlog::set_default_logger(log); ++ auto orig_default = yr_spdlog::default_logger(); ++ yr_spdlog::set_default_logger(log); + auto start = high_resolution_clock::now(); + for (auto i = 0; i < howmany; ++i) + { +- spdlog::info("Hello logger: msg number {}", i); ++ yr_spdlog::info("Hello logger: msg number {}", i); + } + + auto delta = high_resolution_clock::now() - start; + auto delta_d = duration_cast>(delta).count(); +- spdlog::drop(log->name()); +- spdlog::set_default_logger(std::move(orig_default)); +- spdlog::info("{:<30} Elapsed: {:0.2f} secs {:>16}/sec", log->name(), delta_d, int(howmany / delta_d)); ++ yr_spdlog::drop(log->name()); ++ yr_spdlog::set_default_logger(std::move(orig_default)); ++ yr_spdlog::info("{:<30} Elapsed: {:0.2f} secs {:>16}/sec", log->name(), delta_d, int(howmany / delta_d)); + } + +-void bench_c_string(int howmany, std::shared_ptr log) ++void bench_c_string(int howmany, std::shared_ptr log) + { + using std::chrono::high_resolution_clock; + using std::chrono::duration; +@@ -230,19 +230,19 @@ + "augue pretium, nec scelerisque est maximus. Nullam convallis, sem nec blandit maximus, nisi turpis ornare " + "nisl, sit amet volutpat neque massa eu odio. Maecenas malesuada quam ex, posuere congue nibh turpis duis."; + +- auto orig_default = spdlog::default_logger(); +- spdlog::set_default_logger(log); ++ auto orig_default = yr_spdlog::default_logger(); ++ yr_spdlog::set_default_logger(log); + auto start = high_resolution_clock::now(); + for (auto i = 0; i < howmany; ++i) + { +- spdlog::log(spdlog::level::info, msg); ++ yr_spdlog::log(yr_spdlog::level::info, msg); + } + + auto delta = high_resolution_clock::now() - start; + auto delta_d = duration_cast>(delta).count(); +- spdlog::drop(log->name()); +- spdlog::set_default_logger(std::move(orig_default)); +- spdlog::info("{:<30} Elapsed: {:0.2f} secs {:>16}/sec", log->name(), delta_d, int(howmany / delta_d)); ++ yr_spdlog::drop(log->name()); ++ yr_spdlog::set_default_logger(std::move(orig_default)); ++ yr_spdlog::info("{:<30} Elapsed: {:0.2f} secs {:>16}/sec", log->name(), delta_d, int(howmany / delta_d)); + } + + */ +\ No newline at end of file +diff -ruN bench/CMakeLists.txt bench/CMakeLists.txt +--- bench/CMakeLists.txt 2025-07-02 15:16:11.164618390 +0800 ++++ bench/CMakeLists.txt 2025-07-02 15:16:19.630618059 +0800 +@@ -29,13 +29,13 @@ + + add_executable(bench bench.cpp) + spdlog_enable_warnings(bench) +-target_link_libraries(bench PRIVATE spdlog::spdlog) ++target_link_libraries(bench PRIVATE yr_spdlog::spdlog) + + add_executable(async_bench async_bench.cpp) +-target_link_libraries(async_bench PRIVATE spdlog::spdlog) ++target_link_libraries(async_bench PRIVATE yr_spdlog::spdlog) + + add_executable(latency latency.cpp) +-target_link_libraries(latency PRIVATE benchmark::benchmark spdlog::spdlog) ++target_link_libraries(latency PRIVATE benchmark::benchmark yr_spdlog::spdlog) + + add_executable(formatter-bench formatter-bench.cpp) +-target_link_libraries(formatter-bench PRIVATE benchmark::benchmark spdlog::spdlog) ++target_link_libraries(formatter-bench PRIVATE benchmark::benchmark yr_spdlog::spdlog) +diff -ruN bench/formatter-bench.cpp bench/formatter-bench.cpp +--- bench/formatter-bench.cpp 2025-07-02 15:16:11.164618390 +0800 ++++ bench/formatter-bench.cpp 2025-07-02 15:16:19.630618059 +0800 +@@ -10,13 +10,13 @@ + + void bench_formatter(benchmark::State &state, std::string pattern) + { +- auto formatter = spdlog::details::make_unique(pattern); +- spdlog::memory_buf_t dest; ++ auto formatter = yr_spdlog::details::make_unique(pattern); ++ yr_spdlog::memory_buf_t dest; + std::string logger_name = "logger-name"; + const char *text = "Hello. This is some message with length of 80 "; + +- spdlog::source_loc source_loc{"a/b/c/d/myfile.cpp", 123, "some_func()"}; +- spdlog::details::log_msg msg(source_loc, logger_name, spdlog::level::info, text); ++ yr_spdlog::source_loc source_loc{"a/b/c/d/myfile.cpp", 123, "some_func()"}; ++ yr_spdlog::details::log_msg msg(source_loc, logger_name, yr_spdlog::level::info, text); + + for (auto _ : state) + { +@@ -59,10 +59,10 @@ + int main(int argc, char *argv[]) + { + +- spdlog::set_pattern("[%^%l%$] %v"); ++ yr_spdlog::set_pattern("[%^%l%$] %v"); + if (argc != 2) + { +- spdlog::error("Usage: {} (or \"all\" to bench all)", argv[0]); ++ yr_spdlog::error("Usage: {} (or \"all\" to bench all)", argv[0]); + exit(1); + } + +diff -ruN bench/latency.cpp bench/latency.cpp +--- bench/latency.cpp 2025-07-02 15:16:11.164618390 +0800 ++++ bench/latency.cpp 2025-07-02 15:16:19.630618059 +0800 +@@ -16,7 +16,7 @@ + #include "spdlog/sinks/null_sink.h" + #include "spdlog/sinks/rotating_file_sink.h" + +-void bench_c_string(benchmark::State &state, std::shared_ptr logger) ++void bench_c_string(benchmark::State &state, std::shared_ptr logger) + { + const char *msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Vestibulum pharetra metus cursus " + "lacus placerat congue. Nulla egestas, mauris a tincidunt tempus, enim lectus volutpat mi, eu consequat sem " +@@ -30,7 +30,7 @@ + } + } + +-void bench_logger(benchmark::State &state, std::shared_ptr logger) ++void bench_logger(benchmark::State &state, std::shared_ptr logger) + { + int i = 0; + for (auto _ : state) +@@ -38,17 +38,17 @@ + logger->info("Hello logger: msg number {}...............", ++i); + } + } +-void bench_global_logger(benchmark::State &state, std::shared_ptr logger) ++void bench_global_logger(benchmark::State &state, std::shared_ptr logger) + { +- spdlog::set_default_logger(std::move(logger)); ++ yr_spdlog::set_default_logger(std::move(logger)); + int i = 0; + for (auto _ : state) + { +- spdlog::info("Hello logger: msg number {}...............", ++i); ++ yr_spdlog::info("Hello logger: msg number {}...............", ++i); + } + } + +-void bench_disabled_macro(benchmark::State &state, std::shared_ptr logger) ++void bench_disabled_macro(benchmark::State &state, std::shared_ptr logger) + { + int i = 0; + benchmark::DoNotOptimize(i); // prevent unused warnings +@@ -59,9 +59,9 @@ + } + } + +-void bench_disabled_macro_global_logger(benchmark::State &state, std::shared_ptr logger) ++void bench_disabled_macro_global_logger(benchmark::State &state, std::shared_ptr logger) + { +- spdlog::set_default_logger(std::move(logger)); ++ yr_spdlog::set_default_logger(std::move(logger)); + int i = 0; + benchmark::DoNotOptimize(i); // prevent unused warnings + benchmark::DoNotOptimize(logger); // prevent unused warnings +@@ -74,20 +74,20 @@ + #ifdef __linux__ + void bench_dev_null() + { +- auto dev_null_st = spdlog::basic_logger_st("/dev/null_st", "/dev/null"); ++ auto dev_null_st = yr_spdlog::basic_logger_st("/dev/null_st", "/dev/null"); + benchmark::RegisterBenchmark("/dev/null_st", bench_logger, std::move(dev_null_st))->UseRealTime(); +- spdlog::drop("/dev/null_st"); ++ yr_spdlog::drop("/dev/null_st"); + +- auto dev_null_mt = spdlog::basic_logger_mt("/dev/null_mt", "/dev/null"); ++ auto dev_null_mt = yr_spdlog::basic_logger_mt("/dev/null_mt", "/dev/null"); + benchmark::RegisterBenchmark("/dev/null_mt", bench_logger, std::move(dev_null_mt))->UseRealTime(); +- spdlog::drop("/dev/null_mt"); ++ yr_spdlog::drop("/dev/null_mt"); + } + #endif // __linux__ + + int main(int argc, char *argv[]) + { +- using spdlog::sinks::null_sink_mt; +- using spdlog::sinks::null_sink_st; ++ using yr_spdlog::sinks::null_sink_mt; ++ using yr_spdlog::sinks::null_sink_st; + + size_t file_size = 30 * 1024 * 1024; + size_t rotating_files = 5; +@@ -96,23 +96,23 @@ + auto full_bench = argc > 1 && std::string(argv[1]) == "full"; + + // disabled loggers +- auto disabled_logger = std::make_shared("bench", std::make_shared()); +- disabled_logger->set_level(spdlog::level::off); ++ auto disabled_logger = std::make_shared("bench", std::make_shared()); ++ disabled_logger->set_level(yr_spdlog::level::off); + benchmark::RegisterBenchmark("disabled-at-compile-time", bench_disabled_macro, disabled_logger); + benchmark::RegisterBenchmark("disabled-at-compile-time (global logger)", bench_disabled_macro_global_logger, disabled_logger); + benchmark::RegisterBenchmark("disabled-at-runtime", bench_logger, disabled_logger); + benchmark::RegisterBenchmark("disabled-at-runtime (global logger)", bench_global_logger, disabled_logger); + // with backtrace of 64 +- auto tracing_disabled_logger = std::make_shared("bench", std::make_shared()); ++ auto tracing_disabled_logger = std::make_shared("bench", std::make_shared()); + tracing_disabled_logger->enable_backtrace(64); + benchmark::RegisterBenchmark("disabled-at-runtime/backtrace", bench_logger, tracing_disabled_logger); + +- auto null_logger_st = std::make_shared("bench", std::make_shared()); ++ auto null_logger_st = std::make_shared("bench", std::make_shared()); + benchmark::RegisterBenchmark("null_sink_st (500_bytes c_str)", bench_c_string, std::move(null_logger_st)); + benchmark::RegisterBenchmark("null_sink_st", bench_logger, null_logger_st); + benchmark::RegisterBenchmark("null_sink_st (global logger)", bench_global_logger, null_logger_st); + // with backtrace of 64 +- auto tracing_null_logger_st = std::make_shared("bench", std::make_shared()); ++ auto tracing_null_logger_st = std::make_shared("bench", std::make_shared()); + tracing_null_logger_st->enable_backtrace(64); + benchmark::RegisterBenchmark("null_sink_st/backtrace", bench_logger, tracing_null_logger_st); + +@@ -123,64 +123,64 @@ + if (full_bench) + { + // basic_st +- auto basic_st = spdlog::basic_logger_st("basic_st", "latency_logs/basic_st.log", true); ++ auto basic_st = yr_spdlog::basic_logger_st("basic_st", "latency_logs/basic_st.log", true); + benchmark::RegisterBenchmark("basic_st", bench_logger, std::move(basic_st))->UseRealTime(); +- spdlog::drop("basic_st"); ++ yr_spdlog::drop("basic_st"); + // with backtrace of 64 +- auto tracing_basic_st = spdlog::basic_logger_st("tracing_basic_st", "latency_logs/tracing_basic_st.log", true); ++ auto tracing_basic_st = yr_spdlog::basic_logger_st("tracing_basic_st", "latency_logs/tracing_basic_st.log", true); + tracing_basic_st->enable_backtrace(64); + benchmark::RegisterBenchmark("basic_st/backtrace", bench_logger, std::move(tracing_basic_st))->UseRealTime(); +- spdlog::drop("tracing_basic_st"); ++ yr_spdlog::drop("tracing_basic_st"); + + // rotating st +- auto rotating_st = spdlog::rotating_logger_st("rotating_st", "latency_logs/rotating_st.log", file_size, rotating_files); ++ auto rotating_st = yr_spdlog::rotating_logger_st("rotating_st", "latency_logs/rotating_st.log", file_size, rotating_files); + benchmark::RegisterBenchmark("rotating_st", bench_logger, std::move(rotating_st))->UseRealTime(); +- spdlog::drop("rotating_st"); ++ yr_spdlog::drop("rotating_st"); + // with backtrace of 64 + auto tracing_rotating_st = +- spdlog::rotating_logger_st("tracing_rotating_st", "latency_logs/tracing_rotating_st.log", file_size, rotating_files); ++ yr_spdlog::rotating_logger_st("tracing_rotating_st", "latency_logs/tracing_rotating_st.log", file_size, rotating_files); + benchmark::RegisterBenchmark("rotating_st/backtrace", bench_logger, std::move(tracing_rotating_st))->UseRealTime(); +- spdlog::drop("tracing_rotating_st"); ++ yr_spdlog::drop("tracing_rotating_st"); + + // daily st +- auto daily_st = spdlog::daily_logger_mt("daily_st", "latency_logs/daily_st.log"); ++ auto daily_st = yr_spdlog::daily_logger_mt("daily_st", "latency_logs/daily_st.log"); + benchmark::RegisterBenchmark("daily_st", bench_logger, std::move(daily_st))->UseRealTime(); +- spdlog::drop("daily_st"); +- auto tracing_daily_st = spdlog::daily_logger_mt("tracing_daily_st", "latency_logs/daily_st.log"); ++ yr_spdlog::drop("daily_st"); ++ auto tracing_daily_st = yr_spdlog::daily_logger_mt("tracing_daily_st", "latency_logs/daily_st.log"); + benchmark::RegisterBenchmark("daily_st/backtrace", bench_logger, std::move(tracing_daily_st))->UseRealTime(); +- spdlog::drop("tracing_daily_st"); ++ yr_spdlog::drop("tracing_daily_st"); + + // + // Multi threaded bench, 10 loggers using same logger concurrently + // +- auto null_logger_mt = std::make_shared("bench", std::make_shared()); ++ auto null_logger_mt = std::make_shared("bench", std::make_shared()); + benchmark::RegisterBenchmark("null_sink_mt", bench_logger, null_logger_mt)->Threads(n_threads)->UseRealTime(); + + // basic_mt +- auto basic_mt = spdlog::basic_logger_mt("basic_mt", "latency_logs/basic_mt.log", true); ++ auto basic_mt = yr_spdlog::basic_logger_mt("basic_mt", "latency_logs/basic_mt.log", true); + benchmark::RegisterBenchmark("basic_mt", bench_logger, std::move(basic_mt))->Threads(n_threads)->UseRealTime(); +- spdlog::drop("basic_mt"); ++ yr_spdlog::drop("basic_mt"); + + // rotating mt +- auto rotating_mt = spdlog::rotating_logger_mt("rotating_mt", "latency_logs/rotating_mt.log", file_size, rotating_files); ++ auto rotating_mt = yr_spdlog::rotating_logger_mt("rotating_mt", "latency_logs/rotating_mt.log", file_size, rotating_files); + benchmark::RegisterBenchmark("rotating_mt", bench_logger, std::move(rotating_mt))->Threads(n_threads)->UseRealTime(); +- spdlog::drop("rotating_mt"); ++ yr_spdlog::drop("rotating_mt"); + + // daily mt +- auto daily_mt = spdlog::daily_logger_mt("daily_mt", "latency_logs/daily_mt.log"); ++ auto daily_mt = yr_spdlog::daily_logger_mt("daily_mt", "latency_logs/daily_mt.log"); + benchmark::RegisterBenchmark("daily_mt", bench_logger, std::move(daily_mt))->Threads(n_threads)->UseRealTime(); +- spdlog::drop("daily_mt"); ++ yr_spdlog::drop("daily_mt"); + } + + // async + auto queue_size = 1024 * 1024 * 3; +- auto tp = std::make_shared(queue_size, 1); +- auto async_logger = std::make_shared( +- "async_logger", std::make_shared(), std::move(tp), spdlog::async_overflow_policy::overrun_oldest); ++ auto tp = std::make_shared(queue_size, 1); ++ auto async_logger = std::make_shared( ++ "async_logger", std::make_shared(), std::move(tp), yr_spdlog::async_overflow_policy::overrun_oldest); + benchmark::RegisterBenchmark("async_logger", bench_logger, async_logger)->Threads(n_threads)->UseRealTime(); + +- auto async_logger_tracing = std::make_shared( +- "async_logger_tracing", std::make_shared(), std::move(tp), spdlog::async_overflow_policy::overrun_oldest); ++ auto async_logger_tracing = std::make_shared( ++ "async_logger_tracing", std::make_shared(), std::move(tp), yr_spdlog::async_overflow_policy::overrun_oldest); + async_logger_tracing->enable_backtrace(32); + benchmark::RegisterBenchmark("async_logger/tracing", bench_logger, async_logger_tracing)->Threads(n_threads)->UseRealTime(); + +diff -ruN CMakeLists.txt CMakeLists.txt +--- CMakeLists.txt 2025-07-02 15:16:11.164618390 +0800 ++++ CMakeLists.txt 2025-07-02 15:16:19.630618059 +0800 +@@ -344,8 +344,8 @@ + # --------------------------------------------------------------------------------------- + # Install CMake config files + # --------------------------------------------------------------------------------------- +- export(TARGETS spdlog NAMESPACE spdlog:: FILE "${CMAKE_CURRENT_BINARY_DIR}/${config_targets_file}") +- install(EXPORT spdlog DESTINATION ${export_dest_dir} NAMESPACE spdlog:: FILE ${config_targets_file}) ++ export(TARGETS spdlog NAMESPACE yr_spdlog:: FILE "${CMAKE_CURRENT_BINARY_DIR}/${config_targets_file}") ++ install(EXPORT spdlog DESTINATION ${export_dest_dir} NAMESPACE yr_spdlog:: FILE ${config_targets_file}) + + include(CMakePackageConfigHelpers) + configure_package_config_file("${project_config_in}" "${project_config_out}" INSTALL_DESTINATION ${export_dest_dir}) +diff -ruN example/example.cpp example/example.cpp +--- example/example.cpp 2025-07-02 15:16:11.164618390 +0800 ++++ example/example.cpp 2025-07-02 15:16:19.630618059 +0800 +@@ -37,36 +37,36 @@ + // Log levels can be loaded from argv/env using "SPDLOG_LEVEL" + load_levels_example(); + +- spdlog::info("Welcome to spdlog version {}.{}.{} !", SPDLOG_VER_MAJOR, SPDLOG_VER_MINOR, SPDLOG_VER_PATCH); ++ yr_spdlog::info("Welcome to spdlog version {}.{}.{} !", SPDLOG_VER_MAJOR, SPDLOG_VER_MINOR, SPDLOG_VER_PATCH); + +- spdlog::warn("Easy padding in numbers like {:08d}", 12); +- spdlog::critical("Support for int: {0:d}; hex: {0:x}; oct: {0:o}; bin: {0:b}", 42); +- spdlog::info("Support for floats {:03.2f}", 1.23456); +- spdlog::info("Positional args are {1} {0}..", "too", "supported"); +- spdlog::info("{:>8} aligned, {:<8} aligned", "right", "left"); ++ yr_spdlog::warn("Easy padding in numbers like {:08d}", 12); ++ yr_spdlog::critical("Support for int: {0:d}; hex: {0:x}; oct: {0:o}; bin: {0:b}", 42); ++ yr_spdlog::info("Support for floats {:03.2f}", 1.23456); ++ yr_spdlog::info("Positional args are {1} {0}..", "too", "supported"); ++ yr_spdlog::info("{:>8} aligned, {:<8} aligned", "right", "left"); + + // Runtime log levels +- spdlog::set_level(spdlog::level::info); // Set global log level to info +- spdlog::debug("This message should not be displayed!"); +- spdlog::set_level(spdlog::level::trace); // Set specific logger's log level +- spdlog::debug("This message should be displayed.."); ++ yr_spdlog::set_level(yr_spdlog::level::info); // Set global log level to info ++ yr_spdlog::debug("This message should not be displayed!"); ++ yr_spdlog::set_level(yr_spdlog::level::trace); // Set specific logger's log level ++ yr_spdlog::debug("This message should be displayed.."); + + // Customize msg format for all loggers +- spdlog::set_pattern("[%H:%M:%S %z] [%^%L%$] [thread %t] %v"); +- spdlog::info("This an info message with custom format"); +- spdlog::set_pattern("%+"); // back to default format +- spdlog::set_level(spdlog::level::info); ++ yr_spdlog::set_pattern("[%H:%M:%S %z] [%^%L%$] [thread %t] %v"); ++ yr_spdlog::info("This an info message with custom format"); ++ yr_spdlog::set_pattern("%+"); // back to default format ++ yr_spdlog::set_level(yr_spdlog::level::info); + + // Backtrace support + // Loggers can store in a ring buffer all messages (including debug/trace) for later inspection. + // When needed, call dump_backtrace() to see what happened: +- spdlog::enable_backtrace(10); // create ring buffer with capacity of 10 messages ++ yr_spdlog::enable_backtrace(10); // create ring buffer with capacity of 10 messages + for (int i = 0; i < 100; i++) + { +- spdlog::debug("Backtrace message {}", i); // not logged.. ++ yr_spdlog::debug("Backtrace message {}", i); // not logged.. + } + // e.g. if some error happened: +- spdlog::dump_backtrace(); // log them now! ++ yr_spdlog::dump_backtrace(); // log them now! + + try + { +@@ -90,18 +90,18 @@ + + // Flush all *registered* loggers using a worker thread every 3 seconds. + // note: registered loggers *must* be thread safe for this to work correctly! +- spdlog::flush_every(std::chrono::seconds(3)); ++ yr_spdlog::flush_every(std::chrono::seconds(3)); + + // Apply some function on all registered loggers +- spdlog::apply_all([&](std::shared_ptr l) { l->info("End of example."); }); ++ yr_spdlog::apply_all([&](std::shared_ptr l) { l->info("End of example."); }); + + // Release all spdlog resources, and drop all loggers in the registry. + // This is optional (only mandatory if using windows + async log). +- spdlog::shutdown(); ++ yr_spdlog::shutdown(); + } + + // Exceptions will only be thrown upon failed logger or sink construction (not during logging). +- catch (const spdlog::spdlog_ex &ex) ++ catch (const yr_spdlog::spdlog_ex &ex) + { + std::printf("Log initialization failed: %s\n", ex.what()); + return 1; +@@ -113,37 +113,37 @@ + void stdout_logger_example() + { + // Create color multi threaded logger. +- auto console = spdlog::stdout_color_mt("console"); ++ auto console = yr_spdlog::stdout_color_mt("console"); + // or for stderr: +- // auto console = spdlog::stderr_color_mt("error-logger"); ++ // auto console = yr_spdlog::stderr_color_mt("error-logger"); + } + + #include "spdlog/sinks/basic_file_sink.h" + void basic_example() + { + // Create basic file logger (not rotated). +- auto my_logger = spdlog::basic_logger_mt("file_logger", "logs/basic-log.txt", true); ++ auto my_logger = yr_spdlog::basic_logger_mt("file_logger", "logs/basic-log.txt", true); + } + + #include "spdlog/sinks/rotating_file_sink.h" + void rotating_example() + { + // Create a file rotating logger with 5mb size max and 3 rotated files. +- auto rotating_logger = spdlog::rotating_logger_mt("some_logger_name", "logs/rotating.txt", 1048576 * 5, 3); ++ auto rotating_logger = yr_spdlog::rotating_logger_mt("some_logger_name", "logs/rotating.txt", 1048576 * 5, 3); + } + + #include "spdlog/sinks/daily_file_sink.h" + void daily_example() + { + // Create a daily logger - a new file is created every day on 2:30am. +- auto daily_logger = spdlog::daily_logger_mt("daily_logger", "logs/daily.txt", 2, 30); ++ auto daily_logger = yr_spdlog::daily_logger_mt("daily_logger", "logs/daily.txt", 2, 30); + } + + #include "spdlog/sinks/callback_sink.h" + void callback_example() + { + // Create the logger +- auto logger = spdlog::callback_logger_mt("custom_callback_logger", [](const spdlog::details::log_msg & /*msg*/) { ++ auto logger = yr_spdlog::callback_logger_mt("custom_callback_logger", [](const yr_spdlog::details::log_msg & /*msg*/) { + // do what you need to do with msg + }); + } +@@ -153,21 +153,21 @@ + { + // Set the log level to "info" and mylogger to "trace": + // SPDLOG_LEVEL=info,mylogger=trace && ./example +- spdlog::cfg::load_env_levels(); ++ yr_spdlog::cfg::load_env_levels(); + // or from command line: + // ./example SPDLOG_LEVEL=info,mylogger=trace + // #include "spdlog/cfg/argv.h" // for loading levels from argv +- // spdlog::cfg::load_argv_levels(args, argv); ++ // yr_spdlog::cfg::load_argv_levels(args, argv); + } + + #include "spdlog/async.h" + void async_example() + { + // Default thread pool settings can be modified *before* creating the async logger: +- // spdlog::init_thread_pool(32768, 1); // queue with max 32k items 1 backing thread. +- auto async_file = spdlog::basic_logger_mt("async_file_logger", "logs/async_log.txt"); ++ // yr_spdlog::init_thread_pool(32768, 1); // queue with max 32k items 1 backing thread. ++ auto async_file = yr_spdlog::basic_logger_mt("async_file_logger", "logs/async_log.txt"); + // alternatively: +- // auto async_file = spdlog::create_async("async_file_logger", "logs/async_log.txt"); ++ // auto async_file = yr_spdlog::create_async("async_file_logger", "logs/async_log.txt"); + + for (int i = 1; i < 101; ++i) + { +@@ -193,14 +193,14 @@ + { + buf.push_back(static_cast(i & 0xff)); + } +- spdlog::info("Binary example: {}", spdlog::to_hex(buf)); +- spdlog::info("Another binary example:{:n}", spdlog::to_hex(std::begin(buf), std::begin(buf) + 10)); ++ yr_spdlog::info("Binary example: {}", yr_spdlog::to_hex(buf)); ++ yr_spdlog::info("Another binary example:{:n}", yr_spdlog::to_hex(std::begin(buf), std::begin(buf) + 10)); + // more examples: +- // logger->info("uppercase: {:X}", spdlog::to_hex(buf)); +- // logger->info("uppercase, no delimiters: {:Xs}", spdlog::to_hex(buf)); +- // logger->info("uppercase, no delimiters, no position info: {:Xsp}", spdlog::to_hex(buf)); +- // logger->info("hexdump style: {:a}", spdlog::to_hex(buf)); +- // logger->info("hexdump style, 20 chars per line {:a}", spdlog::to_hex(buf, 20)); ++ // logger->info("uppercase: {:X}", yr_spdlog::to_hex(buf)); ++ // logger->info("uppercase, no delimiters: {:Xs}", yr_spdlog::to_hex(buf)); ++ // logger->info("uppercase, no delimiters, no position info: {:Xsp}", yr_spdlog::to_hex(buf)); ++ // logger->info("hexdump style: {:a}", yr_spdlog::to_hex(buf)); ++ // logger->info("hexdump style, 20 chars per line {:a}", yr_spdlog::to_hex(buf, 20)); + } + #else + void binary_example() { +@@ -214,7 +214,7 @@ + void vector_example() + { + std::vector vec = {1, 2, 3}; +- spdlog::info("Vector example: {}", vec); ++ yr_spdlog::info("Vector example: {}", vec); + } + + #else +@@ -233,7 +233,7 @@ + SPDLOG_DEBUG("Some debug message.. {} ,{}", 1, 3.23); + + // trace from logger object +- auto logger = spdlog::get("file_logger"); ++ auto logger = yr_spdlog::get("file_logger"); + SPDLOG_LOGGER_TRACE(logger, "another trace message"); + } + +@@ -242,32 +242,32 @@ + #include + void stopwatch_example() + { +- spdlog::stopwatch sw; ++ yr_spdlog::stopwatch sw; + std::this_thread::sleep_for(std::chrono::milliseconds(123)); +- spdlog::info("Stopwatch: {} seconds", sw); ++ yr_spdlog::info("Stopwatch: {} seconds", sw); + } + + #include "spdlog/sinks/udp_sink.h" + void udp_example() + { +- spdlog::sinks::udp_sink_config cfg("127.0.0.1", 11091); +- auto my_logger = spdlog::udp_logger_mt("udplog", cfg); +- my_logger->set_level(spdlog::level::debug); ++ yr_spdlog::sinks::udp_sink_config cfg("127.0.0.1", 11091); ++ auto my_logger = yr_spdlog::udp_logger_mt("udplog", cfg); ++ my_logger->set_level(yr_spdlog::level::debug); + my_logger->info("hello world"); + } + + // A logger with multiple sinks (stdout and file) - each with a different format and log level. + void multi_sink_example() + { +- auto console_sink = std::make_shared(); +- console_sink->set_level(spdlog::level::warn); ++ auto console_sink = std::make_shared(); ++ console_sink->set_level(yr_spdlog::level::warn); + console_sink->set_pattern("[multi_sink_example] [%^%l%$] %v"); + +- auto file_sink = std::make_shared("logs/multisink.txt", true); +- file_sink->set_level(spdlog::level::trace); ++ auto file_sink = std::make_shared("logs/multisink.txt", true); ++ file_sink->set_level(yr_spdlog::level::trace); + +- spdlog::logger logger("multi_sink", {console_sink, file_sink}); +- logger.set_level(spdlog::level::debug); ++ yr_spdlog::logger logger("multi_sink", {console_sink, file_sink}); ++ logger.set_level(yr_spdlog::level::debug); + logger.warn("this should appear in both console and file"); + logger.info("this message should not appear in the console, only in the file"); + } +@@ -303,14 +303,14 @@ + + void user_defined_example() + { +- spdlog::info("user defined type: {}", my_type(14)); ++ yr_spdlog::info("user defined type: {}", my_type(14)); + } + + // Custom error handler. Will be triggered on log failure. + void err_handler_example() + { + // can be set globally or per logger(logger->set_error_handler(..)) +- spdlog::set_error_handler([](const std::string &msg) { printf("*** Custom log error handler: %s ***\n", msg.c_str()); }); ++ yr_spdlog::set_error_handler([](const std::string &msg) { printf("*** Custom log error handler: %s ***\n", msg.c_str()); }); + } + + // syslog example (linux/osx/freebsd) +@@ -319,7 +319,7 @@ + void syslog_example() + { + std::string ident = "spdlog-example"; +- auto syslog_logger = spdlog::syslog_logger_mt("syslog", ident, LOG_PID); ++ auto syslog_logger = yr_spdlog::syslog_logger_mt("syslog", ident, LOG_PID); + syslog_logger->warn("This is warning that will end up in syslog."); + } + #endif +@@ -330,7 +330,7 @@ + void android_example() + { + std::string tag = "spdlog-android"; +- auto android_logger = spdlog::android_logger_mt("android", tag); ++ auto android_logger = yr_spdlog::android_logger_mt("android", tag); + android_logger->critical("Use \"adb shell logcat\" to view this message."); + } + #endif +@@ -338,10 +338,10 @@ + // Log patterns can contain custom flags. + // this will add custom flag '%*' which will be bound to a instance + #include "spdlog/pattern_formatter.h" +-class my_formatter_flag : public spdlog::custom_flag_formatter ++class my_formatter_flag : public yr_spdlog::custom_flag_formatter + { + public: +- void format(const spdlog::details::log_msg &, const std::tm &, spdlog::memory_buf_t &dest) override ++ void format(const yr_spdlog::details::log_msg &, const std::tm &, yr_spdlog::memory_buf_t &dest) override + { + std::string some_txt = "custom-flag"; + dest.append(some_txt.data(), some_txt.data() + some_txt.size()); +@@ -349,50 +349,50 @@ + + std::unique_ptr clone() const override + { +- return spdlog::details::make_unique(); ++ return yr_spdlog::details::make_unique(); + } + }; + + void custom_flags_example() + { + +- using spdlog::details::make_unique; // for pre c++14 +- auto formatter = make_unique(); ++ using yr_spdlog::details::make_unique; // for pre c++14 ++ auto formatter = make_unique(); + formatter->add_flag('*').set_pattern("[%n] [%*] [%^%l%$] %v"); +- // set the new formatter using spdlog::set_formatter(formatter) or logger->set_formatter(formatter) +- // spdlog::set_formatter(std::move(formatter)); ++ // set the new formatter using yr_spdlog::set_formatter(formatter) or logger->set_formatter(formatter) ++ // yr_spdlog::set_formatter(std::move(formatter)); + } + + void file_events_example() + { +- // pass the spdlog::file_event_handlers to file sinks for open/close log file notifications +- spdlog::file_event_handlers handlers; +- handlers.before_open = [](spdlog::filename_t filename) { spdlog::info("Before opening {}", filename); }; +- handlers.after_open = [](spdlog::filename_t filename, std::FILE *fstream) { +- spdlog::info("After opening {}", filename); ++ // pass the yr_spdlog::file_event_handlers to file sinks for open/close log file notifications ++ yr_spdlog::file_event_handlers handlers; ++ handlers.before_open = [](yr_spdlog::filename_t filename) { yr_spdlog::info("Before opening {}", filename); }; ++ handlers.after_open = [](yr_spdlog::filename_t filename, std::FILE *fstream) { ++ yr_spdlog::info("After opening {}", filename); + fputs("After opening\n", fstream); + }; +- handlers.before_close = [](spdlog::filename_t filename, std::FILE *fstream) { +- spdlog::info("Before closing {}", filename); ++ handlers.before_close = [](yr_spdlog::filename_t filename, std::FILE *fstream) { ++ yr_spdlog::info("Before closing {}", filename); + fputs("Before closing\n", fstream); + }; +- handlers.after_close = [](spdlog::filename_t filename) { spdlog::info("After closing {}", filename); }; +- auto file_sink = std::make_shared("logs/events-sample.txt", true, handlers); +- spdlog::logger my_logger("some_logger", file_sink); ++ handlers.after_close = [](yr_spdlog::filename_t filename) { yr_spdlog::info("After closing {}", filename); }; ++ auto file_sink = std::make_shared("logs/events-sample.txt", true, handlers); ++ yr_spdlog::logger my_logger("some_logger", file_sink); + my_logger.info("Some log line"); + } + + void replace_default_logger_example() + { + // store the old logger so we don't break other examples. +- auto old_logger = spdlog::default_logger(); ++ auto old_logger = yr_spdlog::default_logger(); + +- auto new_logger = spdlog::basic_logger_mt("new_default_logger", "logs/new-default-log.txt", true); +- spdlog::set_default_logger(new_logger); +- spdlog::set_level(spdlog::level::info); +- spdlog::debug("This message should not be displayed!"); +- spdlog::set_level(spdlog::level::trace); +- spdlog::debug("This message should be displayed.."); ++ auto new_logger = yr_spdlog::basic_logger_mt("new_default_logger", "logs/new-default-log.txt", true); ++ yr_spdlog::set_default_logger(new_logger); ++ yr_spdlog::set_level(yr_spdlog::level::info); ++ yr_spdlog::debug("This message should not be displayed!"); ++ yr_spdlog::set_level(yr_spdlog::level::trace); ++ yr_spdlog::debug("This message should be displayed.."); + +- spdlog::set_default_logger(old_logger); ++ yr_spdlog::set_default_logger(old_logger); + } +diff -ruN include/spdlog/async.h include/spdlog/async.h +--- include/spdlog/async.h 2025-07-02 15:16:11.164618390 +0800 ++++ include/spdlog/async.h 2025-07-02 15:16:19.630618059 +0800 +@@ -22,7 +22,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + + namespace details { + static const size_t default_async_q_size = 8192; +@@ -61,13 +61,13 @@ + using async_factory_nonblock = async_factory_impl; + + template +-inline std::shared_ptr create_async(std::string logger_name, SinkArgs &&...sink_args) ++inline std::shared_ptr create_async(std::string logger_name, SinkArgs &&...sink_args) + { + return async_factory::create(std::move(logger_name), std::forward(sink_args)...); + } + + template +-inline std::shared_ptr create_async_nb(std::string logger_name, SinkArgs &&...sink_args) ++inline std::shared_ptr create_async_nb(std::string logger_name, SinkArgs &&...sink_args) + { + return async_factory_nonblock::create(std::move(logger_name), std::forward(sink_args)...); + } +@@ -92,8 +92,8 @@ + } + + // get the global thread pool. +-inline std::shared_ptr thread_pool() ++inline std::shared_ptr thread_pool() + { + return details::registry::instance().get_tp(); + } +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/async_logger.h include/spdlog/async_logger.h +--- include/spdlog/async_logger.h 2025-07-02 15:16:11.164618390 +0800 ++++ include/spdlog/async_logger.h 2025-07-02 15:16:19.630618059 +0800 +@@ -16,7 +16,7 @@ + + #include + +-namespace spdlog { ++namespace yr_spdlog { + + // Async overflow policy - block by default. + enum class async_overflow_policy +@@ -61,7 +61,7 @@ + std::weak_ptr thread_pool_; + async_overflow_policy overflow_policy_; + }; +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "async_logger-inl.h" +diff -ruN include/spdlog/async_logger-inl.h include/spdlog/async_logger-inl.h +--- include/spdlog/async_logger-inl.h 2025-07-02 15:16:11.164618390 +0800 ++++ include/spdlog/async_logger-inl.h 2025-07-02 15:16:19.630618059 +0800 +@@ -13,18 +13,18 @@ + #include + #include + +-SPDLOG_INLINE spdlog::async_logger::async_logger( ++SPDLOG_INLINE yr_spdlog::async_logger::async_logger( + std::string logger_name, sinks_init_list sinks_list, std::weak_ptr tp, async_overflow_policy overflow_policy) + : async_logger(std::move(logger_name), sinks_list.begin(), sinks_list.end(), std::move(tp), overflow_policy) + {} + +-SPDLOG_INLINE spdlog::async_logger::async_logger( ++SPDLOG_INLINE yr_spdlog::async_logger::async_logger( + std::string logger_name, sink_ptr single_sink, std::weak_ptr tp, async_overflow_policy overflow_policy) + : async_logger(std::move(logger_name), {std::move(single_sink)}, std::move(tp), overflow_policy) + {} + + // send the log message to the thread pool +-SPDLOG_INLINE void spdlog::async_logger::sink_it_(const details::log_msg &msg){ ++SPDLOG_INLINE void yr_spdlog::async_logger::sink_it_(const details::log_msg &msg){ + SPDLOG_TRY{if (auto pool_ptr = thread_pool_.lock()){pool_ptr->post_log(shared_from_this(), msg, overflow_policy_); + } + else +@@ -36,7 +36,7 @@ + } + + // send flush request to the thread pool +-SPDLOG_INLINE void spdlog::async_logger::flush_(){ ++SPDLOG_INLINE void yr_spdlog::async_logger::flush_(){ + SPDLOG_TRY{if (auto pool_ptr = thread_pool_.lock()){pool_ptr->post_flush(shared_from_this(), overflow_policy_); + } + else +@@ -50,7 +50,7 @@ + // + // backend functions - called from the thread pool to do the actual job + // +-SPDLOG_INLINE void spdlog::async_logger::backend_sink_it_(const details::log_msg &msg) ++SPDLOG_INLINE void yr_spdlog::async_logger::backend_sink_it_(const details::log_msg &msg) + { + for (auto &sink : sinks_) + { +@@ -70,7 +70,7 @@ + } + } + +-SPDLOG_INLINE void spdlog::async_logger::backend_flush_() ++SPDLOG_INLINE void yr_spdlog::async_logger::backend_flush_() + { + for (auto &sink : sinks_) + { +@@ -82,9 +82,9 @@ + } + } + +-SPDLOG_INLINE std::shared_ptr spdlog::async_logger::clone(std::string new_name) ++SPDLOG_INLINE std::shared_ptr yr_spdlog::async_logger::clone(std::string new_name) + { +- auto cloned = std::make_shared(*this); ++ auto cloned = std::make_shared(*this); + cloned->name_ = std::move(new_name); + return cloned; + } +diff -ruN include/spdlog/cfg/argv.h include/spdlog/cfg/argv.h +--- include/spdlog/cfg/argv.h 2025-07-02 15:16:11.164618390 +0800 ++++ include/spdlog/cfg/argv.h 2025-07-02 15:16:19.630618059 +0800 +@@ -17,7 +17,7 @@ + // turn off all logging except for logger1 and logger2: + // example.exe "SPDLOG_LEVEL=off,logger1=debug,logger2=info" + +-namespace spdlog { ++namespace yr_spdlog { + namespace cfg { + + // search for SPDLOG_LEVEL= in the args and use it to init the levels +@@ -41,4 +41,4 @@ + } + + } // namespace cfg +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/cfg/env.h include/spdlog/cfg/env.h +--- include/spdlog/cfg/env.h 2025-07-02 15:16:11.164618390 +0800 ++++ include/spdlog/cfg/env.h 2025-07-02 15:16:19.630618059 +0800 +@@ -23,7 +23,7 @@ + // turn off all logging except for logger1 and logger2: + // export SPDLOG_LEVEL="off,logger1=debug,logger2=info" + +-namespace spdlog { ++namespace yr_spdlog { + namespace cfg { + inline void load_env_levels() + { +@@ -35,4 +35,4 @@ + } + + } // namespace cfg +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/cfg/helpers.h include/spdlog/cfg/helpers.h +--- include/spdlog/cfg/helpers.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/cfg/helpers.h 2025-07-02 15:16:19.630618059 +0800 +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace cfg { + namespace helpers { + // +@@ -22,7 +22,7 @@ + } // namespace helpers + + } // namespace cfg +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "helpers-inl.h" +diff -ruN include/spdlog/cfg/helpers-inl.h include/spdlog/cfg/helpers-inl.h +--- include/spdlog/cfg/helpers-inl.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/cfg/helpers-inl.h 2025-07-02 15:16:19.630618059 +0800 +@@ -16,7 +16,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace cfg { + namespace helpers { + +@@ -117,4 +117,4 @@ + + } // namespace helpers + } // namespace cfg +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/common.h include/spdlog/common.h +--- include/spdlog/common.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/common.h 2025-07-02 15:16:19.630618059 +0800 +@@ -111,7 +111,7 @@ + {} + #endif + +-namespace spdlog { ++namespace yr_spdlog { + + class formatter; + +@@ -242,13 +242,13 @@ + n_levels + }; + +-#define SPDLOG_LEVEL_NAME_TRACE spdlog::string_view_t("trace", 5) +-#define SPDLOG_LEVEL_NAME_DEBUG spdlog::string_view_t("debug", 5) +-#define SPDLOG_LEVEL_NAME_INFO spdlog::string_view_t("info", 4) +-#define SPDLOG_LEVEL_NAME_WARNING spdlog::string_view_t("warning", 7) +-#define SPDLOG_LEVEL_NAME_ERROR spdlog::string_view_t("error", 5) +-#define SPDLOG_LEVEL_NAME_CRITICAL spdlog::string_view_t("critical", 8) +-#define SPDLOG_LEVEL_NAME_OFF spdlog::string_view_t("off", 3) ++#define SPDLOG_LEVEL_NAME_TRACE yr_spdlog::string_view_t("trace", 5) ++#define SPDLOG_LEVEL_NAME_DEBUG yr_spdlog::string_view_t("debug", 5) ++#define SPDLOG_LEVEL_NAME_INFO yr_spdlog::string_view_t("info", 4) ++#define SPDLOG_LEVEL_NAME_WARNING yr_spdlog::string_view_t("warning", 7) ++#define SPDLOG_LEVEL_NAME_ERROR yr_spdlog::string_view_t("error", 5) ++#define SPDLOG_LEVEL_NAME_CRITICAL yr_spdlog::string_view_t("critical", 8) ++#define SPDLOG_LEVEL_NAME_OFF yr_spdlog::string_view_t("off", 3) + + #if !defined(SPDLOG_LEVEL_NAMES) + # define SPDLOG_LEVEL_NAMES \ +@@ -266,9 +266,9 @@ + } + #endif + +-SPDLOG_API const string_view_t &to_string_view(spdlog::level::level_enum l) SPDLOG_NOEXCEPT; +-SPDLOG_API const char *to_short_c_str(spdlog::level::level_enum l) SPDLOG_NOEXCEPT; +-SPDLOG_API spdlog::level::level_enum from_str(const std::string &name) SPDLOG_NOEXCEPT; ++SPDLOG_API const string_view_t &to_string_view(yr_spdlog::level::level_enum l) SPDLOG_NOEXCEPT; ++SPDLOG_API const char *to_short_c_str(yr_spdlog::level::level_enum l) SPDLOG_NOEXCEPT; ++SPDLOG_API yr_spdlog::level::level_enum from_str(const std::string &name) SPDLOG_NOEXCEPT; + + } // namespace level + +@@ -346,23 +346,23 @@ + + // to_string_view + +-SPDLOG_CONSTEXPR_FUNC spdlog::string_view_t to_string_view(const memory_buf_t &buf) SPDLOG_NOEXCEPT ++SPDLOG_CONSTEXPR_FUNC yr_spdlog::string_view_t to_string_view(const memory_buf_t &buf) SPDLOG_NOEXCEPT + { +- return spdlog::string_view_t{buf.data(), buf.size()}; ++ return yr_spdlog::string_view_t{buf.data(), buf.size()}; + } + +-SPDLOG_CONSTEXPR_FUNC spdlog::string_view_t to_string_view(spdlog::string_view_t str) SPDLOG_NOEXCEPT ++SPDLOG_CONSTEXPR_FUNC yr_spdlog::string_view_t to_string_view(yr_spdlog::string_view_t str) SPDLOG_NOEXCEPT + { + return str; + } + + #if defined(SPDLOG_WCHAR_FILENAMES) || defined(SPDLOG_WCHAR_TO_UTF8_SUPPORT) +-SPDLOG_CONSTEXPR_FUNC spdlog::wstring_view_t to_string_view(const wmemory_buf_t &buf) SPDLOG_NOEXCEPT ++SPDLOG_CONSTEXPR_FUNC yr_spdlog::wstring_view_t to_string_view(const wmemory_buf_t &buf) SPDLOG_NOEXCEPT + { +- return spdlog::wstring_view_t{buf.data(), buf.size()}; ++ return yr_spdlog::wstring_view_t{buf.data(), buf.size()}; + } + +-SPDLOG_CONSTEXPR_FUNC spdlog::wstring_view_t to_string_view(spdlog::wstring_view_t str) SPDLOG_NOEXCEPT ++SPDLOG_CONSTEXPR_FUNC yr_spdlog::wstring_view_t to_string_view(yr_spdlog::wstring_view_t str) SPDLOG_NOEXCEPT + { + return str; + } +@@ -413,7 +413,7 @@ + } + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "common-inl.h" +diff -ruN include/spdlog/common-inl.h include/spdlog/common-inl.h +--- include/spdlog/common-inl.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/common-inl.h 2025-07-02 15:16:19.630618059 +0800 +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace level { + + #if __cplusplus >= 201703L +@@ -20,17 +20,17 @@ + + static const char *short_level_names[] SPDLOG_SHORT_LEVEL_NAMES; + +-SPDLOG_INLINE const string_view_t &to_string_view(spdlog::level::level_enum l) SPDLOG_NOEXCEPT ++SPDLOG_INLINE const string_view_t &to_string_view(yr_spdlog::level::level_enum l) SPDLOG_NOEXCEPT + { + return level_string_views[l]; + } + +-SPDLOG_INLINE const char *to_short_c_str(spdlog::level::level_enum l) SPDLOG_NOEXCEPT ++SPDLOG_INLINE const char *to_short_c_str(yr_spdlog::level::level_enum l) SPDLOG_NOEXCEPT + { + return short_level_names[l]; + } + +-SPDLOG_INLINE spdlog::level::level_enum from_str(const std::string &name) SPDLOG_NOEXCEPT ++SPDLOG_INLINE yr_spdlog::level::level_enum from_str(const std::string &name) SPDLOG_NOEXCEPT + { + auto it = std::find(std::begin(level_string_views), std::end(level_string_views), name); + if (it != std::end(level_string_views)) +@@ -79,4 +79,4 @@ + SPDLOG_THROW(spdlog_ex(std::move(msg))); + } + +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/backtracer.h include/spdlog/details/backtracer.h +--- include/spdlog/details/backtracer.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/backtracer.h 2025-07-02 15:16:19.631618059 +0800 +@@ -13,7 +13,7 @@ + // Store log messages in circular buffer. + // Useful for storing debug data in case of error/warning happens. + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + class SPDLOG_API backtracer + { +@@ -39,7 +39,7 @@ + }; + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "backtracer-inl.h" +diff -ruN include/spdlog/details/backtracer-inl.h include/spdlog/details/backtracer-inl.h +--- include/spdlog/details/backtracer-inl.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/backtracer-inl.h 2025-07-02 15:16:19.631618059 +0800 +@@ -6,7 +6,7 @@ + #ifndef SPDLOG_HEADER_ONLY + # include + #endif +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + SPDLOG_INLINE backtracer::backtracer(const backtracer &other) + { +@@ -72,4 +72,4 @@ + } + } + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/circular_q.h include/spdlog/details/circular_q.h +--- include/spdlog/details/circular_q.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/circular_q.h 2025-07-02 15:16:19.631618059 +0800 +@@ -7,7 +7,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + template + class circular_q +@@ -143,4 +143,4 @@ + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/console_globals.h include/spdlog/details/console_globals.h +--- include/spdlog/details/console_globals.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/console_globals.h 2025-07-02 15:16:19.631618059 +0800 +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + struct console_mutex +@@ -29,4 +29,4 @@ + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/file_helper.h include/spdlog/details/file_helper.h +--- include/spdlog/details/file_helper.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/file_helper.h 2025-07-02 15:16:19.631618059 +0800 +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + // Helper class for file sinks. +@@ -55,7 +55,7 @@ + file_event_handlers event_handlers_; + }; + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "file_helper-inl.h" +diff -ruN include/spdlog/details/file_helper-inl.h include/spdlog/details/file_helper-inl.h +--- include/spdlog/details/file_helper-inl.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/file_helper-inl.h 2025-07-02 15:16:19.631618059 +0800 +@@ -17,7 +17,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + SPDLOG_INLINE file_helper::file_helper(const file_event_handlers &event_handlers) +@@ -177,4 +177,4 @@ + } + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/fmt_helper.h include/spdlog/details/fmt_helper.h +--- include/spdlog/details/fmt_helper.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/fmt_helper.h 2025-07-02 15:16:19.631618059 +0800 +@@ -14,11 +14,11 @@ + #endif + + // Some fmt helpers to efficiently format and pad ints and strings +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + namespace fmt_helper { + +-inline void append_string_view(spdlog::string_view_t view, memory_buf_t &dest) ++inline void append_string_view(yr_spdlog::string_view_t view, memory_buf_t &dest) + { + auto *buf_ptr = view.data(); + dest.append(buf_ptr, buf_ptr + view.size()); +@@ -161,4 +161,4 @@ + + } // namespace fmt_helper + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/log_msg_buffer.h include/spdlog/details/log_msg_buffer.h +--- include/spdlog/details/log_msg_buffer.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/log_msg_buffer.h 2025-07-02 15:16:19.631618059 +0800 +@@ -5,7 +5,7 @@ + + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + // Extend log_msg with internal buffer to store its payload. +@@ -26,7 +26,7 @@ + }; + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "log_msg_buffer-inl.h" +diff -ruN include/spdlog/details/log_msg_buffer-inl.h include/spdlog/details/log_msg_buffer-inl.h +--- include/spdlog/details/log_msg_buffer-inl.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/log_msg_buffer-inl.h 2025-07-02 15:16:19.631618059 +0800 +@@ -7,7 +7,7 @@ + # include + #endif + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + SPDLOG_INLINE log_msg_buffer::log_msg_buffer(const log_msg &orig_msg) +@@ -55,4 +55,4 @@ + } + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/log_msg.h include/spdlog/details/log_msg.h +--- include/spdlog/details/log_msg.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/log_msg.h 2025-07-02 15:16:19.631618059 +0800 +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + struct SPDLOG_API log_msg + { +@@ -30,7 +30,7 @@ + string_view_t payload; + }; + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "log_msg-inl.h" +diff -ruN include/spdlog/details/log_msg-inl.h include/spdlog/details/log_msg-inl.h +--- include/spdlog/details/log_msg-inl.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/log_msg-inl.h 2025-07-02 15:16:19.631618059 +0800 +@@ -9,11 +9,11 @@ + + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + +-SPDLOG_INLINE log_msg::log_msg(spdlog::log_clock::time_point log_time, spdlog::source_loc loc, string_view_t a_logger_name, +- spdlog::level::level_enum lvl, spdlog::string_view_t msg) ++SPDLOG_INLINE log_msg::log_msg(yr_spdlog::log_clock::time_point log_time, yr_spdlog::source_loc loc, string_view_t a_logger_name, ++ yr_spdlog::level::level_enum lvl, yr_spdlog::string_view_t msg) + : logger_name(a_logger_name) + , level(lvl) + , time(log_time) +@@ -25,13 +25,13 @@ + {} + + SPDLOG_INLINE log_msg::log_msg( +- spdlog::source_loc loc, string_view_t a_logger_name, spdlog::level::level_enum lvl, spdlog::string_view_t msg) ++ yr_spdlog::source_loc loc, string_view_t a_logger_name, yr_spdlog::level::level_enum lvl, yr_spdlog::string_view_t msg) + : log_msg(os::now(), loc, a_logger_name, lvl, msg) + {} + +-SPDLOG_INLINE log_msg::log_msg(string_view_t a_logger_name, spdlog::level::level_enum lvl, spdlog::string_view_t msg) ++SPDLOG_INLINE log_msg::log_msg(string_view_t a_logger_name, yr_spdlog::level::level_enum lvl, yr_spdlog::string_view_t msg) + : log_msg(os::now(), source_loc{}, a_logger_name, lvl, msg) + {} + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/mpmc_blocking_q.h include/spdlog/details/mpmc_blocking_q.h +--- include/spdlog/details/mpmc_blocking_q.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/mpmc_blocking_q.h 2025-07-02 15:16:19.631618059 +0800 +@@ -15,7 +15,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + template +@@ -148,7 +148,7 @@ + std::mutex queue_mutex_; + std::condition_variable push_cv_; + std::condition_variable pop_cv_; +- spdlog::details::circular_q q_; ++ yr_spdlog::details::circular_q q_; + }; + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/null_mutex.h include/spdlog/details/null_mutex.h +--- include/spdlog/details/null_mutex.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/null_mutex.h 2025-07-02 15:16:19.631618059 +0800 +@@ -7,7 +7,7 @@ + #include + // null, no cost dummy "mutex" and dummy "atomic" int + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + struct null_mutex + { +@@ -42,4 +42,4 @@ + }; + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/os.h include/spdlog/details/os.h +--- include/spdlog/details/os.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/os.h 2025-07-02 15:16:19.631618059 +0800 +@@ -6,11 +6,11 @@ + #include + #include // std::time_t + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + namespace os { + +-SPDLOG_API spdlog::log_clock::time_point now() SPDLOG_NOEXCEPT; ++SPDLOG_API yr_spdlog::log_clock::time_point now() SPDLOG_NOEXCEPT; + + SPDLOG_API std::tm localtime(const std::time_t &time_tt) SPDLOG_NOEXCEPT; + +@@ -115,7 +115,7 @@ + + } // namespace os + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "os-inl.h" +diff -ruN include/spdlog/details/os-inl.h include/spdlog/details/os-inl.h +--- include/spdlog/details/os-inl.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/os-inl.h 2025-07-02 15:16:19.631618059 +0800 +@@ -70,11 +70,11 @@ + # define __has_feature(x) 0 // Compatibility with non-clang compilers. + #endif + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + namespace os { + +-SPDLOG_INLINE spdlog::log_clock::time_point now() SPDLOG_NOEXCEPT ++SPDLOG_INLINE yr_spdlog::log_clock::time_point now() SPDLOG_NOEXCEPT + { + + #if defined __linux__ && defined SPDLOG_CLOCK_COARSE +@@ -632,4 +632,4 @@ + + } // namespace os + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/periodic_worker.h include/spdlog/details/periodic_worker.h +--- include/spdlog/details/periodic_worker.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/periodic_worker.h 2025-07-02 15:16:19.631618059 +0800 +@@ -14,7 +14,7 @@ + #include + #include + #include +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + class SPDLOG_API periodic_worker +@@ -53,7 +53,7 @@ + std::condition_variable cv_; + }; + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "periodic_worker-inl.h" +diff -ruN include/spdlog/details/periodic_worker-inl.h include/spdlog/details/periodic_worker-inl.h +--- include/spdlog/details/periodic_worker-inl.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/periodic_worker-inl.h 2025-07-02 15:16:19.631618059 +0800 +@@ -7,7 +7,7 @@ + # include + #endif + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + // stop the worker thread and join it +@@ -25,4 +25,4 @@ + } + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/registry.h include/spdlog/details/registry.h +--- include/spdlog/details/registry.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/registry.h 2025-07-02 15:16:19.631618059 +0800 +@@ -18,7 +18,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + class logger; + + namespace details { +@@ -37,9 +37,9 @@ + std::shared_ptr default_logger(); + + // Return raw ptr to the default logger. +- // To be used directly by the spdlog default api (e.g. spdlog::info) ++ // To be used directly by the spdlog default api (e.g. yr_spdlog::info) + // This make the default API faster, but cannot be used concurrently with set_default_logger(). +- // e.g do not call set_default_logger() from one thread while calling spdlog::info() from another. ++ // e.g do not call set_default_logger() from one thread while calling yr_spdlog::info() from another. + logger *get_default_raw(); + + // set default logger. +@@ -105,7 +105,7 @@ + std::unordered_map> loggers_; + log_levels log_levels_; + std::unique_ptr formatter_; +- spdlog::level::level_enum global_log_level_ = level::info; ++ yr_spdlog::level::level_enum global_log_level_ = level::info; + level::level_enum flush_level_ = level::off; + err_handler err_handler_; + std::shared_ptr tp_; +@@ -116,7 +116,7 @@ + }; + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "registry-inl.h" +diff -ruN include/spdlog/details/registry-inl.h include/spdlog/details/registry-inl.h +--- include/spdlog/details/registry-inl.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/registry-inl.h 2025-07-02 15:16:19.631618059 +0800 +@@ -27,7 +27,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + SPDLOG_INLINE registry::registry() +@@ -43,7 +43,7 @@ + # endif + + const char *default_logger_name = ""; +- default_logger_ = std::make_shared(default_logger_name, std::move(color_sink)); ++ default_logger_ = std::make_shared(default_logger_name, std::move(color_sink)); + loggers_[default_logger_name] = default_logger_; + + #endif // SPDLOG_DISABLE_DEFAULT_LOGGER +@@ -99,9 +99,9 @@ + } + + // Return raw ptr to the default logger. +-// To be used directly by the spdlog default api (e.g. spdlog::info) ++// To be used directly by the spdlog default api (e.g. yr_spdlog::info) + // This make the default API faster, but cannot be used concurrently with set_default_logger(). +-// e.g do not call set_default_logger() from one thread while calling spdlog::info() from another. ++// e.g do not call set_default_logger() from one thread while calling yr_spdlog::info() from another. + SPDLOG_INLINE logger *registry::get_default_raw() + { + return default_logger_.get(); +@@ -312,4 +312,4 @@ + } + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/synchronous_factory.h include/spdlog/details/synchronous_factory.h +--- include/spdlog/details/synchronous_factory.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/synchronous_factory.h 2025-07-02 15:16:19.631618059 +0800 +@@ -5,7 +5,7 @@ + + #include "registry.h" + +-namespace spdlog { ++namespace yr_spdlog { + + // Default logger factory- creates synchronous loggers + class logger; +@@ -13,12 +13,12 @@ + struct synchronous_factory + { + template +- static std::shared_ptr create(std::string logger_name, SinkArgs &&...args) ++ static std::shared_ptr create(std::string logger_name, SinkArgs &&...args) + { + auto sink = std::make_shared(std::forward(args)...); +- auto new_logger = std::make_shared(std::move(logger_name), std::move(sink)); ++ auto new_logger = std::make_shared(std::move(logger_name), std::move(sink)); + details::registry::instance().initialize_logger(new_logger); + return new_logger; + } + }; +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/tcp_client.h include/spdlog/details/tcp_client.h +--- include/spdlog/details/tcp_client.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/tcp_client.h 2025-07-02 15:16:19.631618059 +0800 +@@ -20,7 +20,7 @@ + + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + class tcp_client + { +@@ -143,4 +143,4 @@ + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/tcp_client-windows.h include/spdlog/details/tcp_client-windows.h +--- include/spdlog/details/tcp_client-windows.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/tcp_client-windows.h 2025-07-02 15:16:19.631618059 +0800 +@@ -19,7 +19,7 @@ + #pragma comment(lib, "Mswsock.lib") + #pragma comment(lib, "AdvApi32.lib") + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + class tcp_client + { +@@ -157,4 +157,4 @@ + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/thread_pool.h include/spdlog/details/thread_pool.h +--- include/spdlog/details/thread_pool.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/thread_pool.h 2025-07-02 15:16:19.631618059 +0800 +@@ -13,12 +13,12 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + class async_logger; + + namespace details { + +-using async_logger_ptr = std::shared_ptr; ++using async_logger_ptr = std::shared_ptr; + + enum class async_msg_type + { +@@ -115,7 +115,7 @@ + }; + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "thread_pool-inl.h" +diff -ruN include/spdlog/details/thread_pool-inl.h include/spdlog/details/thread_pool-inl.h +--- include/spdlog/details/thread_pool-inl.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/thread_pool-inl.h 2025-07-02 15:16:19.631618059 +0800 +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + SPDLOG_INLINE thread_pool::thread_pool( +@@ -19,7 +19,7 @@ + { + if (threads_n == 0 || threads_n > 1000) + { +- throw_spdlog_ex("spdlog::thread_pool(): invalid threads_n param (valid " ++ throw_spdlog_ex("yr_spdlog::thread_pool(): invalid threads_n param (valid " + "range is 1-1000)"); + } + for (size_t i = 0; i < threads_n; i++) +@@ -134,4 +134,4 @@ + } + + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/udp_client.h include/spdlog/details/udp_client.h +--- include/spdlog/details/udp_client.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/udp_client.h 2025-07-02 15:16:19.631618059 +0800 +@@ -22,7 +22,7 @@ + + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + class udp_client +@@ -91,4 +91,4 @@ + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/details/udp_client-windows.h include/spdlog/details/udp_client-windows.h +--- include/spdlog/details/udp_client-windows.h 2025-07-02 15:16:11.165618390 +0800 ++++ include/spdlog/details/udp_client-windows.h 2025-07-02 15:16:19.631618059 +0800 +@@ -21,7 +21,7 @@ + # pragma comment(lib, "AdvApi32.lib") + #endif + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + class udp_client + { +@@ -110,4 +110,4 @@ + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/fmt/bin_to_hex.h include/spdlog/fmt/bin_to_hex.h +--- include/spdlog/fmt/bin_to_hex.h 2025-07-02 15:16:11.166618390 +0800 ++++ include/spdlog/fmt/bin_to_hex.h 2025-07-02 15:16:19.631618059 +0800 +@@ -31,12 +31,12 @@ + // Examples: + // + // std::vector v(200, 0x0b); +-// logger->info("Some buffer {}", spdlog::to_hex(v)); ++// logger->info("Some buffer {}", yr_spdlog::to_hex(v)); + // char buf[128]; +-// logger->info("Some buffer {:X}", spdlog::to_hex(std::begin(buf), std::end(buf))); +-// logger->info("Some buffer {:X}", spdlog::to_hex(std::begin(buf), std::end(buf), 16)); ++// logger->info("Some buffer {:X}", yr_spdlog::to_hex(std::begin(buf), std::end(buf))); ++// logger->info("Some buffer {:X}", yr_spdlog::to_hex(std::begin(buf), std::end(buf), 16)); + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + template +@@ -99,7 +99,7 @@ + return details::dump_info(range_begin, range_end, size_per_line); + } + +-} // namespace spdlog ++} // namespace yr_spdlog + + namespace + #ifdef SPDLOG_USE_STD_FORMAT +@@ -110,7 +110,7 @@ + { + + template +-struct formatter, char> ++struct formatter, char> + { + const char delimiter = ' '; + bool put_newlines = true; +@@ -156,7 +156,7 @@ + + // format the given bytes range as hex + template +- auto format(const spdlog::details::dump_info &the_range, FormatContext &ctx) const -> decltype(ctx.out()) ++ auto format(const yr_spdlog::details::dump_info &the_range, FormatContext &ctx) const -> decltype(ctx.out()) + { + SPDLOG_CONSTEXPR const char *hex_upper = "0123456789ABCDEF"; + SPDLOG_CONSTEXPR const char *hex_lower = "0123456789abcdef"; +@@ -241,7 +241,7 @@ + + if (put_positions) + { +- spdlog::fmt_lib::format_to(inserter, SPDLOG_FMT_STRING("{:04X}: "), pos); ++ yr_spdlog::fmt_lib::format_to(inserter, SPDLOG_FMT_STRING("{:04X}: "), pos); + } + } + }; +diff -ruN include/spdlog/formatter.h include/spdlog/formatter.h +--- include/spdlog/formatter.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/formatter.h 2025-07-02 15:16:19.633618059 +0800 +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + + class formatter + { +@@ -15,4 +15,4 @@ + virtual void format(const details::log_msg &msg, memory_buf_t &dest) = 0; + virtual std::unique_ptr clone() const = 0; + }; +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/fwd.h include/spdlog/fwd.h +--- include/spdlog/fwd.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/fwd.h 2025-07-02 15:16:19.633618059 +0800 +@@ -3,7 +3,7 @@ + + #pragma once + +-namespace spdlog { ++namespace yr_spdlog { + class logger; + class formatter; + +@@ -15,4 +15,4 @@ + enum level_enum : int; + } + +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/logger.h include/spdlog/logger.h +--- include/spdlog/logger.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/logger.h 2025-07-02 15:16:19.633618059 +0800 +@@ -49,7 +49,7 @@ + # define SPDLOG_LOGGER_CATCH(location) + #endif + +-namespace spdlog { ++namespace yr_spdlog { + + class SPDLOG_API logger + { +@@ -82,7 +82,7 @@ + logger(const logger &other); + logger(logger &&other) SPDLOG_NOEXCEPT; + logger &operator=(logger other) SPDLOG_NOEXCEPT; +- void swap(spdlog::logger &other) SPDLOG_NOEXCEPT; ++ void swap(yr_spdlog::logger &other) SPDLOG_NOEXCEPT; + + template + void log(source_loc loc, level::level_enum lvl, format_string_t fmt, Args &&...args) +@@ -350,8 +350,8 @@ + protected: + std::string name_; + std::vector sinks_; +- spdlog::level_t level_{level::info}; +- spdlog::level_t flush_level_{level::off}; ++ yr_spdlog::level_t level_{level::info}; ++ yr_spdlog::level_t flush_level_{level::off}; + err_handler custom_err_handler_{nullptr}; + details::backtracer tracer_; + +@@ -420,7 +420,7 @@ + + void swap(logger &a, logger &b); + +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "logger-inl.h" +diff -ruN include/spdlog/logger-inl.h include/spdlog/logger-inl.h +--- include/spdlog/logger-inl.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/logger-inl.h 2025-07-02 15:16:19.633618059 +0800 +@@ -13,7 +13,7 @@ + + #include + +-namespace spdlog { ++namespace yr_spdlog { + + // public methods + SPDLOG_INLINE logger::logger(const logger &other) +@@ -40,7 +40,7 @@ + return *this; + } + +-SPDLOG_INLINE void logger::swap(spdlog::logger &other) SPDLOG_NOEXCEPT ++SPDLOG_INLINE void logger::swap(yr_spdlog::logger &other) SPDLOG_NOEXCEPT + { + name_.swap(other.name_); + sinks_.swap(other.sinks_); +@@ -163,7 +163,7 @@ + } + + // protected methods +-SPDLOG_INLINE void logger::log_it_(const spdlog::details::log_msg &log_msg, bool log_enabled, bool traceback_enabled) ++SPDLOG_INLINE void logger::log_it_(const yr_spdlog::details::log_msg &log_msg, bool log_enabled, bool traceback_enabled) + { + if (log_enabled) + { +@@ -254,4 +254,4 @@ + #endif + } + } +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/pattern_formatter.h include/spdlog/pattern_formatter.h +--- include/spdlog/pattern_formatter.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/pattern_formatter.h 2025-07-02 15:16:19.633618059 +0800 +@@ -16,7 +16,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + // padding information. +@@ -80,10 +80,10 @@ + using custom_flags = std::unordered_map>; + + explicit pattern_formatter(std::string pattern, pattern_time_type time_type = pattern_time_type::local, +- std::string eol = spdlog::details::os::default_eol, custom_flags custom_user_flags = custom_flags()); ++ std::string eol = yr_spdlog::details::os::default_eol, custom_flags custom_user_flags = custom_flags()); + + // use default pattern is not given +- explicit pattern_formatter(pattern_time_type time_type = pattern_time_type::local, std::string eol = spdlog::details::os::default_eol); ++ explicit pattern_formatter(pattern_time_type time_type = pattern_time_type::local, std::string eol = yr_spdlog::details::os::default_eol); + + pattern_formatter(const pattern_formatter &other) = delete; + pattern_formatter &operator=(const pattern_formatter &other) = delete; +@@ -121,7 +121,7 @@ + + void compile_pattern_(const std::string &pattern); + }; +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "pattern_formatter-inl.h" +diff -ruN include/spdlog/pattern_formatter-inl.h include/spdlog/pattern_formatter-inl.h +--- include/spdlog/pattern_formatter-inl.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/pattern_formatter-inl.h 2025-07-02 15:16:19.633618059 +0800 +@@ -27,7 +27,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace details { + + /////////////////////////////////////////////////////////////////////// +@@ -1318,8 +1318,8 @@ + formatters_.push_back((std::move(unknown_flag))); + } + // fix issue #1617 (prev char was '!' and should have been treated as funcname flag instead of truncating flag) +- // spdlog::set_pattern("[%10!] %v") => "[ main] some message" +- // spdlog::set_pattern("[%3!!] %v") => "[mai] some message" ++ // yr_spdlog::set_pattern("[%10!] %v") => "[ main] some message" ++ // yr_spdlog::set_pattern("[%3!!] %v") => "[mai] some message" + else + { + padding.truncate_ = false; +@@ -1433,4 +1433,4 @@ + formatters_.push_back(std::move(user_chars)); + } + } +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/android_sink.h include/spdlog/sinks/android_sink.h +--- include/spdlog/sinks/android_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/android_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -22,7 +22,7 @@ + # define SPDLOG_ANDROID_RETRIES 2 + # endif + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + /* +@@ -92,21 +92,21 @@ + return __android_log_buf_write(ID, prio, tag, text); + } + +- static android_LogPriority convert_to_android_(spdlog::level::level_enum level) ++ static android_LogPriority convert_to_android_(yr_spdlog::level::level_enum level) + { + switch (level) + { +- case spdlog::level::trace: ++ case yr_spdlog::level::trace: + return ANDROID_LOG_VERBOSE; +- case spdlog::level::debug: ++ case yr_spdlog::level::debug: + return ANDROID_LOG_DEBUG; +- case spdlog::level::info: ++ case yr_spdlog::level::info: + return ANDROID_LOG_INFO; +- case spdlog::level::warn: ++ case yr_spdlog::level::warn: + return ANDROID_LOG_WARN; +- case spdlog::level::err: ++ case yr_spdlog::level::err: + return ANDROID_LOG_ERROR; +- case spdlog::level::critical: ++ case yr_spdlog::level::critical: + return ANDROID_LOG_FATAL; + default: + return ANDROID_LOG_DEFAULT; +@@ -129,18 +129,18 @@ + + // Create and register android syslog logger + +-template ++template + inline std::shared_ptr android_logger_mt(const std::string &logger_name, const std::string &tag = "spdlog") + { + return Factory::template create(logger_name, tag); + } + +-template ++template + inline std::shared_ptr android_logger_st(const std::string &logger_name, const std::string &tag = "spdlog") + { + return Factory::template create(logger_name, tag); + } + +-} // namespace spdlog ++} // namespace yr_spdlog + + #endif // __ANDROID__ +diff -ruN include/spdlog/sinks/ansicolor_sink.h include/spdlog/sinks/ansicolor_sink.h +--- include/spdlog/sinks/ansicolor_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/ansicolor_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -11,7 +11,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + /** +@@ -42,7 +42,7 @@ + void log(const details::log_msg &msg) override; + void flush() override; + void set_pattern(const std::string &pattern) final; +- void set_formatter(std::unique_ptr sink_formatter) override; ++ void set_formatter(std::unique_ptr sink_formatter) override; + + // Formatting codes + const string_view_t reset = "\033[m"; +@@ -83,7 +83,7 @@ + FILE *target_file_; + mutex_t &mutex_; + bool should_do_colors_; +- std::unique_ptr formatter_; ++ std::unique_ptr formatter_; + std::array colors_; + void print_ccode_(const string_view_t &color_code); + void print_range_(const memory_buf_t &formatted, size_t start, size_t end); +@@ -111,7 +111,7 @@ + using ansicolor_stderr_sink_st = ansicolor_stderr_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "ansicolor_sink-inl.h" +diff -ruN include/spdlog/sinks/ansicolor_sink-inl.h include/spdlog/sinks/ansicolor_sink-inl.h +--- include/spdlog/sinks/ansicolor_sink-inl.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/ansicolor_sink-inl.h 2025-07-02 15:16:19.633618059 +0800 +@@ -10,14 +10,14 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + template + SPDLOG_INLINE ansicolor_sink::ansicolor_sink(FILE *target_file, color_mode mode) + : target_file_(target_file) + , mutex_(ConsoleMutex::mutex()) +- , formatter_(details::make_unique()) ++ , formatter_(details::make_unique()) + + { + set_color_mode(mode); +@@ -76,11 +76,11 @@ + SPDLOG_INLINE void ansicolor_sink::set_pattern(const std::string &pattern) + { + std::lock_guard lock(mutex_); +- formatter_ = std::unique_ptr(new pattern_formatter(pattern)); ++ formatter_ = std::unique_ptr(new pattern_formatter(pattern)); + } + + template +-SPDLOG_INLINE void ansicolor_sink::set_formatter(std::unique_ptr sink_formatter) ++SPDLOG_INLINE void ansicolor_sink::set_formatter(std::unique_ptr sink_formatter) + { + std::lock_guard lock(mutex_); + formatter_ = std::move(sink_formatter); +@@ -142,4 +142,4 @@ + {} + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/base_sink.h include/spdlog/sinks/base_sink.h +--- include/spdlog/sinks/base_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/base_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -13,14 +13,14 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + template + class SPDLOG_API base_sink : public sink + { + public: + base_sink(); +- explicit base_sink(std::unique_ptr formatter); ++ explicit base_sink(std::unique_ptr formatter); + ~base_sink() override = default; + + base_sink(const base_sink &) = delete; +@@ -32,20 +32,20 @@ + void log(const details::log_msg &msg) final; + void flush() final; + void set_pattern(const std::string &pattern) final; +- void set_formatter(std::unique_ptr sink_formatter) final; ++ void set_formatter(std::unique_ptr sink_formatter) final; + + protected: + // sink formatter +- std::unique_ptr formatter_; ++ std::unique_ptr formatter_; + Mutex mutex_; + + virtual void sink_it_(const details::log_msg &msg) = 0; + virtual void flush_() = 0; + virtual void set_pattern_(const std::string &pattern); +- virtual void set_formatter_(std::unique_ptr sink_formatter); ++ virtual void set_formatter_(std::unique_ptr sink_formatter); + }; + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "base_sink-inl.h" +diff -ruN include/spdlog/sinks/base_sink-inl.h include/spdlog/sinks/base_sink-inl.h +--- include/spdlog/sinks/base_sink-inl.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/base_sink-inl.h 2025-07-02 15:16:19.633618059 +0800 +@@ -13,51 +13,51 @@ + #include + + template +-SPDLOG_INLINE spdlog::sinks::base_sink::base_sink() +- : formatter_{details::make_unique()} ++SPDLOG_INLINE yr_spdlog::sinks::base_sink::base_sink() ++ : formatter_{details::make_unique()} + {} + + template +-SPDLOG_INLINE spdlog::sinks::base_sink::base_sink(std::unique_ptr formatter) ++SPDLOG_INLINE yr_spdlog::sinks::base_sink::base_sink(std::unique_ptr formatter) + : formatter_{std::move(formatter)} + {} + + template +-void SPDLOG_INLINE spdlog::sinks::base_sink::log(const details::log_msg &msg) ++void SPDLOG_INLINE yr_spdlog::sinks::base_sink::log(const details::log_msg &msg) + { + std::lock_guard lock(mutex_); + sink_it_(msg); + } + + template +-void SPDLOG_INLINE spdlog::sinks::base_sink::flush() ++void SPDLOG_INLINE yr_spdlog::sinks::base_sink::flush() + { + std::lock_guard lock(mutex_); + flush_(); + } + + template +-void SPDLOG_INLINE spdlog::sinks::base_sink::set_pattern(const std::string &pattern) ++void SPDLOG_INLINE yr_spdlog::sinks::base_sink::set_pattern(const std::string &pattern) + { + std::lock_guard lock(mutex_); + set_pattern_(pattern); + } + + template +-void SPDLOG_INLINE spdlog::sinks::base_sink::set_formatter(std::unique_ptr sink_formatter) ++void SPDLOG_INLINE yr_spdlog::sinks::base_sink::set_formatter(std::unique_ptr sink_formatter) + { + std::lock_guard lock(mutex_); + set_formatter_(std::move(sink_formatter)); + } + + template +-void SPDLOG_INLINE spdlog::sinks::base_sink::set_pattern_(const std::string &pattern) ++void SPDLOG_INLINE yr_spdlog::sinks::base_sink::set_pattern_(const std::string &pattern) + { +- set_formatter_(details::make_unique(pattern)); ++ set_formatter_(details::make_unique(pattern)); + } + + template +-void SPDLOG_INLINE spdlog::sinks::base_sink::set_formatter_(std::unique_ptr sink_formatter) ++void SPDLOG_INLINE yr_spdlog::sinks::base_sink::set_formatter_(std::unique_ptr sink_formatter) + { + formatter_ = std::move(sink_formatter); + } +diff -ruN include/spdlog/sinks/basic_file_sink.h include/spdlog/sinks/basic_file_sink.h +--- include/spdlog/sinks/basic_file_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/basic_file_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -11,7 +11,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + /* + * Trivial file sink with single file as target +@@ -39,21 +39,21 @@ + // + // factory functions + // +-template ++template + inline std::shared_ptr basic_logger_mt( + const std::string &logger_name, const filename_t &filename, bool truncate = false, const file_event_handlers &event_handlers = {}) + { + return Factory::template create(logger_name, filename, truncate, event_handlers); + } + +-template ++template + inline std::shared_ptr basic_logger_st( + const std::string &logger_name, const filename_t &filename, bool truncate = false, const file_event_handlers &event_handlers = {}) + { + return Factory::template create(logger_name, filename, truncate, event_handlers); + } + +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "basic_file_sink-inl.h" +diff -ruN include/spdlog/sinks/basic_file_sink-inl.h include/spdlog/sinks/basic_file_sink-inl.h +--- include/spdlog/sinks/basic_file_sink-inl.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/basic_file_sink-inl.h 2025-07-02 15:16:19.633618059 +0800 +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + template +@@ -41,4 +41,4 @@ + } + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/callback_sink.h include/spdlog/sinks/callback_sink.h +--- include/spdlog/sinks/callback_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/callback_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + + // callbacks type + typedef std::function custom_log_callback; +@@ -46,16 +46,16 @@ + // + // factory functions + // +-template ++template + inline std::shared_ptr callback_logger_mt(const std::string &logger_name, const custom_log_callback &callback) + { + return Factory::template create(logger_name, callback); + } + +-template ++template + inline std::shared_ptr callback_logger_st(const std::string &logger_name, const custom_log_callback &callback) + { + return Factory::template create(logger_name, callback); + } + +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/daily_file_sink.h include/spdlog/sinks/daily_file_sink.h +--- include/spdlog/sinks/daily_file_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/daily_file_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -20,7 +20,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + /* +@@ -41,8 +41,8 @@ + /* + * Generator of daily log file names with strftime format. + * Usages: +- * auto sink = std::make_shared("myapp-%Y-%m-%d:%H:%M:%S.log", hour, minute);" +- * auto logger = spdlog::daily_logger_format_mt("loggername, "myapp-%Y-%m-%d:%X.log", hour, minute)" ++ * auto sink = std::make_shared("myapp-%Y-%m-%d:%H:%M:%S.log", hour, minute);" ++ * auto logger = yr_spdlog::daily_logger_format_mt("loggername, "myapp-%Y-%m-%d:%X.log", hour, minute)" + * + */ + struct daily_filename_format_calculator +@@ -155,7 +155,7 @@ + tm now_tm(log_clock::time_point tp) + { + time_t tnow = log_clock::to_time_t(tp); +- return spdlog::details::os::localtime(tnow); ++ return yr_spdlog::details::os::localtime(tnow); + } + + log_clock::time_point next_rotation_tp_() +@@ -215,14 +215,14 @@ + // + // factory functions + // +-template ++template + inline std::shared_ptr daily_logger_mt(const std::string &logger_name, const filename_t &filename, int hour = 0, int minute = 0, + bool truncate = false, uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { + return Factory::template create(logger_name, filename, hour, minute, truncate, max_files, event_handlers); + } + +-template ++template + inline std::shared_ptr daily_logger_format_mt(const std::string &logger_name, const filename_t &filename, int hour = 0, + int minute = 0, bool truncate = false, uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { +@@ -230,18 +230,18 @@ + logger_name, filename, hour, minute, truncate, max_files, event_handlers); + } + +-template ++template + inline std::shared_ptr daily_logger_st(const std::string &logger_name, const filename_t &filename, int hour = 0, int minute = 0, + bool truncate = false, uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { + return Factory::template create(logger_name, filename, hour, minute, truncate, max_files, event_handlers); + } + +-template ++template + inline std::shared_ptr daily_logger_format_st(const std::string &logger_name, const filename_t &filename, int hour = 0, + int minute = 0, bool truncate = false, uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { + return Factory::template create( + logger_name, filename, hour, minute, truncate, max_files, event_handlers); + } +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/dist_sink.h include/spdlog/sinks/dist_sink.h +--- include/spdlog/sinks/dist_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/dist_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -16,7 +16,7 @@ + // Distribution sink (mux). Stores a vector of sinks which get called when log + // is called + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + template +@@ -76,10 +76,10 @@ + + void set_pattern_(const std::string &pattern) override + { +- set_formatter_(details::make_unique(pattern)); ++ set_formatter_(details::make_unique(pattern)); + } + +- void set_formatter_(std::unique_ptr sink_formatter) override ++ void set_formatter_(std::unique_ptr sink_formatter) override + { + base_sink::formatter_ = std::move(sink_formatter); + for (auto &sub_sink : sinks_) +@@ -94,4 +94,4 @@ + using dist_sink_st = dist_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/dup_filter_sink.h include/spdlog/sinks/dup_filter_sink.h +--- include/spdlog/sinks/dup_filter_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/dup_filter_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -22,7 +22,7 @@ + // int main() { + // auto dup_filter = std::make_shared(std::chrono::seconds(5), level::info); + // dup_filter->add_sink(std::make_shared()); +-// spdlog::logger l("logger", dup_filter); ++// yr_spdlog::logger l("logger", dup_filter); + // l.info("Hello"); + // l.info("Hello"); + // l.info("Hello"); +@@ -34,7 +34,7 @@ + // [2019-06-25 17:50:56.512] [logger] [info] Skipped 3 duplicate messages.. + // [2019-06-25 17:50:56.512] [logger] [info] Different Hello + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + template + class dup_filter_sink : public dist_sink +@@ -93,4 +93,4 @@ + using dup_filter_sink_st = dup_filter_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/hourly_file_sink.h include/spdlog/sinks/hourly_file_sink.h +--- include/spdlog/sinks/hourly_file_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/hourly_file_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -18,7 +18,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + /* +@@ -132,7 +132,7 @@ + tm now_tm(log_clock::time_point tp) + { + time_t tnow = log_clock::to_time_t(tp); +- return spdlog::details::os::localtime(tnow); ++ return yr_spdlog::details::os::localtime(tnow); + } + + log_clock::time_point next_rotation_tp_() +@@ -188,17 +188,17 @@ + // + // factory functions + // +-template ++template + inline std::shared_ptr hourly_logger_mt(const std::string &logger_name, const filename_t &filename, bool truncate = false, + uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { + return Factory::template create(logger_name, filename, truncate, max_files, event_handlers); + } + +-template ++template + inline std::shared_ptr hourly_logger_st(const std::string &logger_name, const filename_t &filename, bool truncate = false, + uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { + return Factory::template create(logger_name, filename, truncate, max_files, event_handlers); + } +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/kafka_sink.h include/spdlog/sinks/kafka_sink.h +--- include/spdlog/sinks/kafka_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/kafka_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -21,7 +21,7 @@ + // kafka header + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + struct kafka_sink_config +@@ -102,32 +102,32 @@ + }; + + using kafka_sink_mt = kafka_sink; +-using kafka_sink_st = kafka_sink; ++using kafka_sink_st = kafka_sink; + + } // namespace sinks + +-template +-inline std::shared_ptr kafka_logger_mt(const std::string &logger_name, spdlog::sinks::kafka_sink_config config) ++template ++inline std::shared_ptr kafka_logger_mt(const std::string &logger_name, yr_spdlog::sinks::kafka_sink_config config) + { + return Factory::template create(logger_name, config); + } + +-template +-inline std::shared_ptr kafka_logger_st(const std::string &logger_name, spdlog::sinks::kafka_sink_config config) ++template ++inline std::shared_ptr kafka_logger_st(const std::string &logger_name, yr_spdlog::sinks::kafka_sink_config config) + { + return Factory::template create(logger_name, config); + } + +-template +-inline std::shared_ptr kafka_logger_async_mt(std::string logger_name, spdlog::sinks::kafka_sink_config config) ++template ++inline std::shared_ptr kafka_logger_async_mt(std::string logger_name, yr_spdlog::sinks::kafka_sink_config config) + { + return Factory::template create(logger_name, config); + } + +-template +-inline std::shared_ptr kafka_logger_async_st(std::string logger_name, spdlog::sinks::kafka_sink_config config) ++template ++inline std::shared_ptr kafka_logger_async_st(std::string logger_name, yr_spdlog::sinks::kafka_sink_config config) + { + return Factory::template create(logger_name, config); + } + +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/mongo_sink.h include/spdlog/sinks/mongo_sink.h +--- include/spdlog/sinks/mongo_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/mongo_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -23,7 +23,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + template + class mongo_sink : public base_sink +@@ -45,7 +45,7 @@ + { + try + { +- client_ = spdlog::details::make_unique(mongocxx::uri{uri}); ++ client_ = yr_spdlog::details::make_unique(mongocxx::uri{uri}); + } + catch (const std::exception &e) + { +@@ -86,22 +86,22 @@ + #include "spdlog/details/null_mutex.h" + #include + using mongo_sink_mt = mongo_sink; +-using mongo_sink_st = mongo_sink; ++using mongo_sink_st = mongo_sink; + + } // namespace sinks + +-template ++template + inline std::shared_ptr mongo_logger_mt(const std::string &logger_name, const std::string &db_name, + const std::string &collection_name, const std::string &uri = "mongodb://localhost:27017") + { + return Factory::template create(logger_name, db_name, collection_name, uri); + } + +-template ++template + inline std::shared_ptr mongo_logger_st(const std::string &logger_name, const std::string &db_name, + const std::string &collection_name, const std::string &uri = "mongodb://localhost:27017") + { + return Factory::template create(logger_name, db_name, collection_name, uri); + } + +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/msvc_sink.h include/spdlog/sinks/msvc_sink.h +--- include/spdlog/sinks/msvc_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/msvc_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -22,7 +22,7 @@ + # endif + extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent(); + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + /* + * MSVC sink (logging using OutputDebugStringA) +@@ -66,6 +66,6 @@ + using windebug_sink_st = msvc_sink_st; + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog + + #endif +diff -ruN include/spdlog/sinks/null_sink.h include/spdlog/sinks/null_sink.h +--- include/spdlog/sinks/null_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/null_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -9,7 +9,7 @@ + + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + template +@@ -25,7 +25,7 @@ + + } // namespace sinks + +-template ++template + inline std::shared_ptr null_logger_mt(const std::string &logger_name) + { + auto null_logger = Factory::template create(logger_name); +@@ -33,7 +33,7 @@ + return null_logger; + } + +-template ++template + inline std::shared_ptr null_logger_st(const std::string &logger_name) + { + auto null_logger = Factory::template create(logger_name); +@@ -41,4 +41,4 @@ + return null_logger; + } + +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/ostream_sink.h include/spdlog/sinks/ostream_sink.h +--- include/spdlog/sinks/ostream_sink.h 2025-07-02 15:16:11.167618390 +0800 ++++ include/spdlog/sinks/ostream_sink.h 2025-07-02 15:16:19.633618059 +0800 +@@ -9,7 +9,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + template + class ostream_sink final : public base_sink +@@ -47,4 +47,4 @@ + using ostream_sink_st = ostream_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/qt_sinks.h include/spdlog/sinks/qt_sinks.h +--- include/spdlog/sinks/qt_sinks.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/qt_sinks.h 2025-07-02 15:16:19.634618058 +0800 +@@ -24,7 +24,7 @@ + // + // qt_sink class + // +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + template + class qt_sink : public base_sink +@@ -237,56 +237,56 @@ + // + + // log to QTextEdit +-template ++template + inline std::shared_ptr qt_logger_mt(const std::string &logger_name, QTextEdit *qt_object, const std::string &meta_method = "append") + { + return Factory::template create(logger_name, qt_object, meta_method); + } + +-template ++template + inline std::shared_ptr qt_logger_st(const std::string &logger_name, QTextEdit *qt_object, const std::string &meta_method = "append") + { + return Factory::template create(logger_name, qt_object, meta_method); + } + + // log to QPlainTextEdit +-template ++template + inline std::shared_ptr qt_logger_mt( + const std::string &logger_name, QPlainTextEdit *qt_object, const std::string &meta_method = "appendPlainText") + { + return Factory::template create(logger_name, qt_object, meta_method); + } + +-template ++template + inline std::shared_ptr qt_logger_st( + const std::string &logger_name, QPlainTextEdit *qt_object, const std::string &meta_method = "appendPlainText") + { + return Factory::template create(logger_name, qt_object, meta_method); + } + // log to QObject +-template ++template + inline std::shared_ptr qt_logger_mt(const std::string &logger_name, QObject *qt_object, const std::string &meta_method) + { + return Factory::template create(logger_name, qt_object, meta_method); + } + +-template ++template + inline std::shared_ptr qt_logger_st(const std::string &logger_name, QObject *qt_object, const std::string &meta_method) + { + return Factory::template create(logger_name, qt_object, meta_method); + } + + // log to QTextEdit with colorize output +-template ++template + inline std::shared_ptr qt_color_logger_mt(const std::string &logger_name, QTextEdit *qt_text_edit, int max_lines) + { + return Factory::template create(logger_name, qt_text_edit, max_lines); + } + +-template ++template + inline std::shared_ptr qt_color_logger_st(const std::string &logger_name, QTextEdit *qt_text_edit, int max_lines) + { + return Factory::template create(logger_name, qt_text_edit, max_lines); + } + +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/ringbuffer_sink.h include/spdlog/sinks/ringbuffer_sink.h +--- include/spdlog/sinks/ringbuffer_sink.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/ringbuffer_sink.h 2025-07-02 15:16:19.634618058 +0800 +@@ -12,7 +12,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + /* + * Ring buffer sink +@@ -71,4 +71,4 @@ + + } // namespace sinks + +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/rotating_file_sink.h include/spdlog/sinks/rotating_file_sink.h +--- include/spdlog/sinks/rotating_file_sink.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/rotating_file_sink.h 2025-07-02 15:16:19.634618058 +0800 +@@ -12,7 +12,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + // +@@ -59,7 +59,7 @@ + // factory functions + // + +-template ++template + inline std::shared_ptr rotating_logger_mt(const std::string &logger_name, const filename_t &filename, size_t max_file_size, + size_t max_files, bool rotate_on_open = false, const file_event_handlers &event_handlers = {}) + { +@@ -67,14 +67,14 @@ + logger_name, filename, max_file_size, max_files, rotate_on_open, event_handlers); + } + +-template ++template + inline std::shared_ptr rotating_logger_st(const std::string &logger_name, const filename_t &filename, size_t max_file_size, + size_t max_files, bool rotate_on_open = false, const file_event_handlers &event_handlers = {}) + { + return Factory::template create( + logger_name, filename, max_file_size, max_files, rotate_on_open, event_handlers); + } +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "rotating_file_sink-inl.h" +diff -ruN include/spdlog/sinks/rotating_file_sink-inl.h include/spdlog/sinks/rotating_file_sink-inl.h +--- include/spdlog/sinks/rotating_file_sink-inl.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/rotating_file_sink-inl.h 2025-07-02 15:16:19.634618058 +0800 +@@ -20,7 +20,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + template +@@ -149,4 +149,4 @@ + } + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/sink.h include/spdlog/sinks/sink.h +--- include/spdlog/sinks/sink.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/sink.h 2025-07-02 15:16:19.634618058 +0800 +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + + namespace sinks { + class SPDLOG_API sink +@@ -16,7 +16,7 @@ + virtual void log(const details::log_msg &msg) = 0; + virtual void flush() = 0; + virtual void set_pattern(const std::string &pattern) = 0; +- virtual void set_formatter(std::unique_ptr sink_formatter) = 0; ++ virtual void set_formatter(std::unique_ptr sink_formatter) = 0; + + void set_level(level::level_enum log_level); + level::level_enum level() const; +@@ -28,7 +28,7 @@ + }; + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "sink-inl.h" +diff -ruN include/spdlog/sinks/sink-inl.h include/spdlog/sinks/sink-inl.h +--- include/spdlog/sinks/sink-inl.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/sink-inl.h 2025-07-02 15:16:19.634618058 +0800 +@@ -9,17 +9,17 @@ + + #include + +-SPDLOG_INLINE bool spdlog::sinks::sink::should_log(spdlog::level::level_enum msg_level) const ++SPDLOG_INLINE bool yr_spdlog::sinks::sink::should_log(yr_spdlog::level::level_enum msg_level) const + { + return msg_level >= level_.load(std::memory_order_relaxed); + } + +-SPDLOG_INLINE void spdlog::sinks::sink::set_level(level::level_enum log_level) ++SPDLOG_INLINE void yr_spdlog::sinks::sink::set_level(level::level_enum log_level) + { + level_.store(log_level, std::memory_order_relaxed); + } + +-SPDLOG_INLINE spdlog::level::level_enum spdlog::sinks::sink::level() const ++SPDLOG_INLINE yr_spdlog::level::level_enum yr_spdlog::sinks::sink::level() const + { +- return static_cast(level_.load(std::memory_order_relaxed)); ++ return static_cast(level_.load(std::memory_order_relaxed)); + } +diff -ruN include/spdlog/sinks/stdout_color_sinks.h include/spdlog/sinks/stdout_color_sinks.h +--- include/spdlog/sinks/stdout_color_sinks.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/stdout_color_sinks.h 2025-07-02 15:16:19.634618058 +0800 +@@ -11,7 +11,7 @@ + + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + #ifdef _WIN32 + using stdout_color_sink_mt = wincolor_stdout_sink_mt; +@@ -26,19 +26,19 @@ + #endif + } // namespace sinks + +-template ++template + std::shared_ptr stdout_color_mt(const std::string &logger_name, color_mode mode = color_mode::automatic); + +-template ++template + std::shared_ptr stdout_color_st(const std::string &logger_name, color_mode mode = color_mode::automatic); + +-template ++template + std::shared_ptr stderr_color_mt(const std::string &logger_name, color_mode mode = color_mode::automatic); + +-template ++template + std::shared_ptr stderr_color_st(const std::string &logger_name, color_mode mode = color_mode::automatic); + +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "stdout_color_sinks-inl.h" +diff -ruN include/spdlog/sinks/stdout_color_sinks-inl.h include/spdlog/sinks/stdout_color_sinks-inl.h +--- include/spdlog/sinks/stdout_color_sinks-inl.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/stdout_color_sinks-inl.h 2025-07-02 15:16:19.634618058 +0800 +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + + template + SPDLOG_INLINE std::shared_ptr stdout_color_mt(const std::string &logger_name, color_mode mode) +@@ -35,4 +35,4 @@ + { + return Factory::template create(logger_name, mode); + } +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/stdout_sinks.h include/spdlog/sinks/stdout_sinks.h +--- include/spdlog/sinks/stdout_sinks.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/stdout_sinks.h 2025-07-02 15:16:19.634618058 +0800 +@@ -12,7 +12,7 @@ + # include + #endif + +-namespace spdlog { ++namespace yr_spdlog { + + namespace sinks { + +@@ -34,12 +34,12 @@ + void flush() override; + void set_pattern(const std::string &pattern) override; + +- void set_formatter(std::unique_ptr sink_formatter) override; ++ void set_formatter(std::unique_ptr sink_formatter) override; + + protected: + mutex_t &mutex_; + FILE *file_; +- std::unique_ptr formatter_; ++ std::unique_ptr formatter_; + #ifdef _WIN32 + HANDLE handle_; + #endif // WIN32 +@@ -68,19 +68,19 @@ + } // namespace sinks + + // factory methods +-template ++template + std::shared_ptr stdout_logger_mt(const std::string &logger_name); + +-template ++template + std::shared_ptr stdout_logger_st(const std::string &logger_name); + +-template ++template + std::shared_ptr stderr_logger_mt(const std::string &logger_name); + +-template ++template + std::shared_ptr stderr_logger_st(const std::string &logger_name); + +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "stdout_sinks-inl.h" +diff -ruN include/spdlog/sinks/stdout_sinks-inl.h include/spdlog/sinks/stdout_sinks-inl.h +--- include/spdlog/sinks/stdout_sinks-inl.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/stdout_sinks-inl.h 2025-07-02 15:16:19.634618058 +0800 +@@ -24,7 +24,7 @@ + # include // _fileno(..) + #endif // WIN32 + +-namespace spdlog { ++namespace yr_spdlog { + + namespace sinks { + +@@ -32,7 +32,7 @@ + SPDLOG_INLINE stdout_sink_base::stdout_sink_base(FILE *file) + : mutex_(ConsoleMutex::mutex()) + , file_(file) +- , formatter_(details::make_unique()) ++ , formatter_(details::make_unique()) + { + #ifdef _WIN32 + // get windows handle from the FILE* object +@@ -44,7 +44,7 @@ + // throw only if non stdout/stderr target is requested (probably regular file and not console). + if (handle_ == INVALID_HANDLE_VALUE && file != stdout && file != stderr) + { +- throw_spdlog_ex("spdlog::stdout_sink_base: _get_osfhandle() failed", errno); ++ throw_spdlog_ex("yr_spdlog::stdout_sink_base: _get_osfhandle() failed", errno); + } + #endif // WIN32 + } +@@ -87,11 +87,11 @@ + SPDLOG_INLINE void stdout_sink_base::set_pattern(const std::string &pattern) + { + std::lock_guard lock(mutex_); +- formatter_ = std::unique_ptr(new pattern_formatter(pattern)); ++ formatter_ = std::unique_ptr(new pattern_formatter(pattern)); + } + + template +-SPDLOG_INLINE void stdout_sink_base::set_formatter(std::unique_ptr sink_formatter) ++SPDLOG_INLINE void stdout_sink_base::set_formatter(std::unique_ptr sink_formatter) + { + std::lock_guard lock(mutex_); + formatter_ = std::move(sink_formatter); +@@ -135,4 +135,4 @@ + { + return Factory::template create(logger_name); + } +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/syslog_sink.h include/spdlog/sinks/syslog_sink.h +--- include/spdlog/sinks/syslog_sink.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/syslog_sink.h 2025-07-02 15:16:19.634618058 +0800 +@@ -11,7 +11,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + /** + * Sink that write to syslog using the `syscall()` library call. +@@ -23,13 +23,13 @@ + public: + syslog_sink(std::string ident, int syslog_option, int syslog_facility, bool enable_formatting) + : enable_formatting_{enable_formatting} +- , syslog_levels_{{/* spdlog::level::trace */ LOG_DEBUG, +- /* spdlog::level::debug */ LOG_DEBUG, +- /* spdlog::level::info */ LOG_INFO, +- /* spdlog::level::warn */ LOG_WARNING, +- /* spdlog::level::err */ LOG_ERR, +- /* spdlog::level::critical */ LOG_CRIT, +- /* spdlog::level::off */ LOG_INFO}} ++ , syslog_levels_{{/* yr_spdlog::level::trace */ LOG_DEBUG, ++ /* yr_spdlog::level::debug */ LOG_DEBUG, ++ /* yr_spdlog::level::info */ LOG_INFO, ++ /* yr_spdlog::level::warn */ LOG_WARNING, ++ /* yr_spdlog::level::err */ LOG_ERR, ++ /* yr_spdlog::level::critical */ LOG_CRIT, ++ /* yr_spdlog::level::off */ LOG_INFO}} + , ident_{std::move(ident)} + { + // set ident to be program name if empty +@@ -93,17 +93,17 @@ + } // namespace sinks + + // Create and register a syslog logger +-template ++template + inline std::shared_ptr syslog_logger_mt(const std::string &logger_name, const std::string &syslog_ident = "", int syslog_option = 0, + int syslog_facility = LOG_USER, bool enable_formatting = false) + { + return Factory::template create(logger_name, syslog_ident, syslog_option, syslog_facility, enable_formatting); + } + +-template ++template + inline std::shared_ptr syslog_logger_st(const std::string &logger_name, const std::string &syslog_ident = "", int syslog_option = 0, + int syslog_facility = LOG_USER, bool enable_formatting = false) + { + return Factory::template create(logger_name, syslog_ident, syslog_option, syslog_facility, enable_formatting); + } +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/systemd_sink.h include/spdlog/sinks/systemd_sink.h +--- include/spdlog/sinks/systemd_sink.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/systemd_sink.h 2025-07-02 15:16:19.634618058 +0800 +@@ -14,7 +14,7 @@ + #endif + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + /** +@@ -27,13 +27,13 @@ + systemd_sink(std::string ident = "", bool enable_formatting = false) + : ident_{std::move(ident)} + , enable_formatting_{enable_formatting} +- , syslog_levels_{{/* spdlog::level::trace */ LOG_DEBUG, +- /* spdlog::level::debug */ LOG_DEBUG, +- /* spdlog::level::info */ LOG_INFO, +- /* spdlog::level::warn */ LOG_WARNING, +- /* spdlog::level::err */ LOG_ERR, +- /* spdlog::level::critical */ LOG_CRIT, +- /* spdlog::level::off */ LOG_INFO}} ++ , syslog_levels_{{/* yr_spdlog::level::trace */ LOG_DEBUG, ++ /* yr_spdlog::level::debug */ LOG_DEBUG, ++ /* yr_spdlog::level::info */ LOG_INFO, ++ /* yr_spdlog::level::warn */ LOG_WARNING, ++ /* yr_spdlog::level::err */ LOG_ERR, ++ /* yr_spdlog::level::critical */ LOG_CRIT, ++ /* yr_spdlog::level::off */ LOG_INFO}} + {} + + ~systemd_sink() override {} +@@ -110,17 +110,17 @@ + } // namespace sinks + + // Create and register a syslog logger +-template ++template + inline std::shared_ptr systemd_logger_mt( + const std::string &logger_name, const std::string &ident = "", bool enable_formatting = false) + { + return Factory::template create(logger_name, ident, enable_formatting); + } + +-template ++template + inline std::shared_ptr systemd_logger_st( + const std::string &logger_name, const std::string &ident = "", bool enable_formatting = false) + { + return Factory::template create(logger_name, ident, enable_formatting); + } +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/tcp_sink.h include/spdlog/sinks/tcp_sink.h +--- include/spdlog/sinks/tcp_sink.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/tcp_sink.h 2025-07-02 15:16:19.634618058 +0800 +@@ -24,7 +24,7 @@ + // Will attempt to reconnect if connection drops. + // If more complicated behaviour is needed (i.e get responses), you can inherit it and override the sink_it_ method. + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + struct tcp_sink_config +@@ -40,7 +40,7 @@ + }; + + template +-class tcp_sink : public spdlog::sinks::base_sink ++class tcp_sink : public yr_spdlog::sinks::base_sink + { + public: + // connect to tcp host/port or throw if failed +@@ -58,10 +58,10 @@ + ~tcp_sink() override = default; + + protected: +- void sink_it_(const spdlog::details::log_msg &msg) override ++ void sink_it_(const yr_spdlog::details::log_msg &msg) override + { +- spdlog::memory_buf_t formatted; +- spdlog::sinks::base_sink::formatter_->format(msg, formatted); ++ yr_spdlog::memory_buf_t formatted; ++ yr_spdlog::sinks::base_sink::formatter_->format(msg, formatted); + if (!client_.is_connected()) + { + client_.connect(config_.server_host, config_.server_port); +@@ -75,7 +75,7 @@ + }; + + using tcp_sink_mt = tcp_sink; +-using tcp_sink_st = tcp_sink; ++using tcp_sink_st = tcp_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/udp_sink.h include/spdlog/sinks/udp_sink.h +--- include/spdlog/sinks/udp_sink.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/udp_sink.h 2025-07-02 15:16:19.634618058 +0800 +@@ -20,7 +20,7 @@ + // Simple udp client sink + // Sends formatted log via udp + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + struct udp_sink_config +@@ -35,7 +35,7 @@ + }; + + template +-class udp_sink : public spdlog::sinks::base_sink ++class udp_sink : public yr_spdlog::sinks::base_sink + { + public: + // host can be hostname or ip address +@@ -46,10 +46,10 @@ + ~udp_sink() override = default; + + protected: +- void sink_it_(const spdlog::details::log_msg &msg) override ++ void sink_it_(const yr_spdlog::details::log_msg &msg) override + { +- spdlog::memory_buf_t formatted; +- spdlog::sinks::base_sink::formatter_->format(msg, formatted); ++ yr_spdlog::memory_buf_t formatted; ++ yr_spdlog::sinks::base_sink::formatter_->format(msg, formatted); + client_.send(formatted.data(), formatted.size()); + } + +@@ -58,17 +58,17 @@ + }; + + using udp_sink_mt = udp_sink; +-using udp_sink_st = udp_sink; ++using udp_sink_st = udp_sink; + + } // namespace sinks + + // + // factory functions + // +-template ++template + inline std::shared_ptr udp_logger_mt(const std::string &logger_name, sinks::udp_sink_config skin_config) + { + return Factory::template create(logger_name, skin_config); + } + +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/wincolor_sink.h include/spdlog/sinks/wincolor_sink.h +--- include/spdlog/sinks/wincolor_sink.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/wincolor_sink.h 2025-07-02 15:16:19.634618058 +0800 +@@ -14,7 +14,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + /* + * Windows color console sink. Uses WriteConsoleA to write to the console with +@@ -35,7 +35,7 @@ + void log(const details::log_msg &msg) final override; + void flush() final override; + void set_pattern(const std::string &pattern) override final; +- void set_formatter(std::unique_ptr sink_formatter) override final; ++ void set_formatter(std::unique_ptr sink_formatter) override final; + void set_color_mode(color_mode mode); + + protected: +@@ -43,7 +43,7 @@ + void *out_handle_; + mutex_t &mutex_; + bool should_do_colors_; +- std::unique_ptr formatter_; ++ std::unique_ptr formatter_; + std::array colors_; + + // set foreground color and return the orig console attributes (for resetting later) +@@ -78,7 +78,7 @@ + using wincolor_stderr_sink_mt = wincolor_stderr_sink; + using wincolor_stderr_sink_st = wincolor_stderr_sink; + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "wincolor_sink-inl.h" +diff -ruN include/spdlog/sinks/wincolor_sink-inl.h include/spdlog/sinks/wincolor_sink-inl.h +--- include/spdlog/sinks/wincolor_sink-inl.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/wincolor_sink-inl.h 2025-07-02 15:16:19.634618058 +0800 +@@ -13,13 +13,13 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + template + SPDLOG_INLINE wincolor_sink::wincolor_sink(void *out_handle, color_mode mode) + : out_handle_(out_handle) + , mutex_(ConsoleMutex::mutex()) +- , formatter_(details::make_unique()) ++ , formatter_(details::make_unique()) + { + + set_color_mode_impl(mode); +@@ -88,11 +88,11 @@ + void SPDLOG_INLINE wincolor_sink::set_pattern(const std::string &pattern) + { + std::lock_guard lock(mutex_); +- formatter_ = std::unique_ptr(new pattern_formatter(pattern)); ++ formatter_ = std::unique_ptr(new pattern_formatter(pattern)); + } + + template +-void SPDLOG_INLINE wincolor_sink::set_formatter(std::unique_ptr sink_formatter) ++void SPDLOG_INLINE wincolor_sink::set_formatter(std::unique_ptr sink_formatter) + { + std::lock_guard lock(mutex_); + formatter_ = std::move(sink_formatter); +@@ -172,4 +172,4 @@ + : wincolor_sink(::GetStdHandle(STD_ERROR_HANDLE), mode) + {} + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/sinks/win_eventlog_sink.h include/spdlog/sinks/win_eventlog_sink.h +--- include/spdlog/sinks/win_eventlog_sink.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/sinks/win_eventlog_sink.h 2025-07-02 15:16:19.634618058 +0800 +@@ -40,7 +40,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + namespace win_eventlog { +@@ -286,4 +286,4 @@ + using win_eventlog_sink_st = win_eventlog::win_eventlog_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/spdlog.h include/spdlog/spdlog.h +--- include/spdlog/spdlog.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/spdlog.h 2025-07-02 15:16:19.634618058 +0800 +@@ -20,7 +20,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + + using default_factory = synchronous_factory; + +@@ -29,9 +29,9 @@ + // global settings. + // + // Example: +-// spdlog::create("logger_name", "dailylog_filename", 11, 59); ++// yr_spdlog::create("logger_name", "dailylog_filename", 11, 59); + template +-inline std::shared_ptr create(std::string logger_name, SinkArgs &&...sink_args) ++inline std::shared_ptr create(std::string logger_name, SinkArgs &&...sink_args) + { + return default_factory::create(std::move(logger_name), std::forward(sink_args)...); + } +@@ -42,20 +42,20 @@ + // Useful for initializing manually created loggers with the global settings. + // + // Example: +-// auto mylogger = std::make_shared("mylogger", ...); +-// spdlog::initialize_logger(mylogger); ++// auto mylogger = std::make_shared("mylogger", ...); ++// yr_spdlog::initialize_logger(mylogger); + SPDLOG_API void initialize_logger(std::shared_ptr logger); + + // Return an existing logger or nullptr if a logger with such name doesn't + // exist. +-// example: spdlog::get("my_logger")->info("hello {}", "world"); ++// example: yr_spdlog::get("my_logger")->info("hello {}", "world"); + SPDLOG_API std::shared_ptr get(const std::string &name); + + // Set global formatter. Each sink in each logger will get a clone of this object +-SPDLOG_API void set_formatter(std::unique_ptr formatter); ++SPDLOG_API void set_formatter(std::unique_ptr formatter); + + // Set global format string. +-// example: spdlog::set_pattern("%Y-%m-%d %H:%M:%S.%e %l : %v"); ++// example: yr_spdlog::set_pattern("%Y-%m-%d %H:%M:%S.%e %l : %v"); + SPDLOG_API void set_pattern(std::string pattern, pattern_time_type time_type = pattern_time_type::local); + + // enable global backtrace support +@@ -95,7 +95,7 @@ + + // Apply a user defined function on all registered loggers + // Example: +-// spdlog::apply_all([&](std::shared_ptr l) {l->flush();}); ++// yr_spdlog::apply_all([&](std::shared_ptr l) {l->flush();}); + SPDLOG_API void apply_all(const std::function)> &fun); + + // Drop the reference to the given logger +@@ -107,37 +107,37 @@ + // stop any running threads started by spdlog and clean registry loggers + SPDLOG_API void shutdown(); + +-// Automatic registration of loggers when using spdlog::create() or spdlog::create_async ++// Automatic registration of loggers when using yr_spdlog::create() or yr_spdlog::create_async + SPDLOG_API void set_automatic_registration(bool automatic_registration); + + // API for using default logger (stdout_color_mt), +-// e.g: spdlog::info("Message {}", 1); ++// e.g: yr_spdlog::info("Message {}", 1); + // +-// The default logger object can be accessed using the spdlog::default_logger(): ++// The default logger object can be accessed using the yr_spdlog::default_logger(): + // For example, to add another sink to it: +-// spdlog::default_logger()->sinks().push_back(some_sink); ++// yr_spdlog::default_logger()->sinks().push_back(some_sink); + // +-// The default logger can replaced using spdlog::set_default_logger(new_logger). ++// The default logger can replaced using yr_spdlog::set_default_logger(new_logger). + // For example, to replace it with a file logger. + // + // IMPORTANT: + // The default API is thread safe (for _mt loggers), but: + // set_default_logger() *should not* be used concurrently with the default API. +-// e.g do not call set_default_logger() from one thread while calling spdlog::info() from another. ++// e.g do not call set_default_logger() from one thread while calling yr_spdlog::info() from another. + +-SPDLOG_API std::shared_ptr default_logger(); ++SPDLOG_API std::shared_ptr default_logger(); + +-SPDLOG_API spdlog::logger *default_logger_raw(); ++SPDLOG_API yr_spdlog::logger *default_logger_raw(); + +-SPDLOG_API void set_default_logger(std::shared_ptr default_logger); ++SPDLOG_API void set_default_logger(std::shared_ptr default_logger); + + // Initialize logger level based on environment configs. + // + // Useful for applying SPDLOG_LEVEL to manually created loggers. + // + // Example: +-// auto mylogger = std::make_shared("mylogger", ...); +-// spdlog::apply_logger_env_levels(mylogger); ++// auto mylogger = std::make_shared("mylogger", ...); ++// yr_spdlog::apply_logger_env_levels(mylogger); + SPDLOG_API void apply_logger_env_levels(std::shared_ptr logger); + + template +@@ -286,7 +286,7 @@ + default_logger_raw()->critical(msg); + } + +-} // namespace spdlog ++} // namespace yr_spdlog + + // + // enable/disable log calls at compile time according to global level. +@@ -303,54 +303,54 @@ + + #ifndef SPDLOG_NO_SOURCE_LOC + # define SPDLOG_LOGGER_CALL(logger, level, ...) \ +- (logger)->log(spdlog::source_loc{__FILE__, __LINE__, SPDLOG_FUNCTION}, level, __VA_ARGS__) ++ (logger)->log(yr_spdlog::source_loc{__FILE__, __LINE__, SPDLOG_FUNCTION}, level, __VA_ARGS__) + #else +-# define SPDLOG_LOGGER_CALL(logger, level, ...) (logger)->log(spdlog::source_loc{}, level, __VA_ARGS__) ++# define SPDLOG_LOGGER_CALL(logger, level, ...) (logger)->log(yr_spdlog::source_loc{}, level, __VA_ARGS__) + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_TRACE +-# define SPDLOG_LOGGER_TRACE(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::trace, __VA_ARGS__) +-# define SPDLOG_TRACE(...) SPDLOG_LOGGER_TRACE(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_TRACE(logger, ...) SPDLOG_LOGGER_CALL(logger, yr_spdlog::level::trace, __VA_ARGS__) ++# define SPDLOG_TRACE(...) SPDLOG_LOGGER_TRACE(yr_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_TRACE(logger, ...) (void)0 + # define SPDLOG_TRACE(...) (void)0 + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_DEBUG +-# define SPDLOG_LOGGER_DEBUG(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::debug, __VA_ARGS__) +-# define SPDLOG_DEBUG(...) SPDLOG_LOGGER_DEBUG(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_DEBUG(logger, ...) SPDLOG_LOGGER_CALL(logger, yr_spdlog::level::debug, __VA_ARGS__) ++# define SPDLOG_DEBUG(...) SPDLOG_LOGGER_DEBUG(yr_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_DEBUG(logger, ...) (void)0 + # define SPDLOG_DEBUG(...) (void)0 + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_INFO +-# define SPDLOG_LOGGER_INFO(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::info, __VA_ARGS__) +-# define SPDLOG_INFO(...) SPDLOG_LOGGER_INFO(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_INFO(logger, ...) SPDLOG_LOGGER_CALL(logger, yr_spdlog::level::info, __VA_ARGS__) ++# define SPDLOG_INFO(...) SPDLOG_LOGGER_INFO(yr_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_INFO(logger, ...) (void)0 + # define SPDLOG_INFO(...) (void)0 + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_WARN +-# define SPDLOG_LOGGER_WARN(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::warn, __VA_ARGS__) +-# define SPDLOG_WARN(...) SPDLOG_LOGGER_WARN(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_WARN(logger, ...) SPDLOG_LOGGER_CALL(logger, yr_spdlog::level::warn, __VA_ARGS__) ++# define SPDLOG_WARN(...) SPDLOG_LOGGER_WARN(yr_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_WARN(logger, ...) (void)0 + # define SPDLOG_WARN(...) (void)0 + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_ERROR +-# define SPDLOG_LOGGER_ERROR(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::err, __VA_ARGS__) +-# define SPDLOG_ERROR(...) SPDLOG_LOGGER_ERROR(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_ERROR(logger, ...) SPDLOG_LOGGER_CALL(logger, yr_spdlog::level::err, __VA_ARGS__) ++# define SPDLOG_ERROR(...) SPDLOG_LOGGER_ERROR(yr_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_ERROR(logger, ...) (void)0 + # define SPDLOG_ERROR(...) (void)0 + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_CRITICAL +-# define SPDLOG_LOGGER_CRITICAL(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::critical, __VA_ARGS__) +-# define SPDLOG_CRITICAL(...) SPDLOG_LOGGER_CRITICAL(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_CRITICAL(logger, ...) SPDLOG_LOGGER_CALL(logger, yr_spdlog::level::critical, __VA_ARGS__) ++# define SPDLOG_CRITICAL(...) SPDLOG_LOGGER_CRITICAL(yr_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_CRITICAL(logger, ...) (void)0 + # define SPDLOG_CRITICAL(...) (void)0 +diff -ruN include/spdlog/spdlog-inl.h include/spdlog/spdlog-inl.h +--- include/spdlog/spdlog-inl.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/spdlog-inl.h 2025-07-02 15:16:19.634618058 +0800 +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + + SPDLOG_INLINE void initialize_logger(std::shared_ptr logger) + { +@@ -22,14 +22,14 @@ + return details::registry::instance().get(name); + } + +-SPDLOG_INLINE void set_formatter(std::unique_ptr formatter) ++SPDLOG_INLINE void set_formatter(std::unique_ptr formatter) + { + details::registry::instance().set_formatter(std::move(formatter)); + } + + SPDLOG_INLINE void set_pattern(std::string pattern, pattern_time_type time_type) + { +- set_formatter(std::unique_ptr(new pattern_formatter(std::move(pattern), time_type))); ++ set_formatter(std::unique_ptr(new pattern_formatter(std::move(pattern), time_type))); + } + + SPDLOG_INLINE void enable_backtrace(size_t n_messages) +@@ -102,17 +102,17 @@ + details::registry::instance().set_automatic_registration(automatic_registration); + } + +-SPDLOG_INLINE std::shared_ptr default_logger() ++SPDLOG_INLINE std::shared_ptr default_logger() + { + return details::registry::instance().default_logger(); + } + +-SPDLOG_INLINE spdlog::logger *default_logger_raw() ++SPDLOG_INLINE yr_spdlog::logger *default_logger_raw() + { + return details::registry::instance().get_default_raw(); + } + +-SPDLOG_INLINE void set_default_logger(std::shared_ptr default_logger) ++SPDLOG_INLINE void set_default_logger(std::shared_ptr default_logger) + { + details::registry::instance().set_default_logger(std::move(default_logger)); + } +@@ -122,4 +122,4 @@ + details::registry::instance().apply_logger_env_levels(std::move(logger)); + } + +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN include/spdlog/stopwatch.h include/spdlog/stopwatch.h +--- include/spdlog/stopwatch.h 2025-07-02 15:16:11.168618390 +0800 ++++ include/spdlog/stopwatch.h 2025-07-02 15:16:19.634618058 +0800 +@@ -11,10 +11,10 @@ + // + // Usage: + // +-// spdlog::stopwatch sw; ++// yr_spdlog::stopwatch sw; + // ... +-// spdlog::debug("Elapsed: {} seconds", sw); => "Elapsed 0.005116733 seconds" +-// spdlog::info("Elapsed: {:.6} seconds", sw); => "Elapsed 0.005163 seconds" ++// yr_spdlog::debug("Elapsed: {} seconds", sw); => "Elapsed 0.005116733 seconds" ++// yr_spdlog::info("Elapsed: {:.6} seconds", sw); => "Elapsed 0.005163 seconds" + // + // + // If other units are needed (e.g. millis instead of double), include "fmt/chrono.h" and use "duration_cast<..>(sw.elapsed())": +@@ -23,9 +23,9 @@ + //.. + // using std::chrono::duration_cast; + // using std::chrono::milliseconds; +-// spdlog::info("Elapsed {}", duration_cast(sw.elapsed())); => "Elapsed 5ms" ++// yr_spdlog::info("Elapsed {}", duration_cast(sw.elapsed())); => "Elapsed 5ms" + +-namespace spdlog { ++namespace yr_spdlog { + class stopwatch + { + using clock = std::chrono::steady_clock; +@@ -46,7 +46,7 @@ + start_tp_ = clock::now(); + } + }; +-} // namespace spdlog ++} // namespace yr_spdlog + + // Support for fmt formatting (e.g. "{:012.9}" or just "{}") + namespace +@@ -58,10 +58,10 @@ + { + + template<> +-struct formatter : formatter ++struct formatter : formatter + { + template +- auto format(const spdlog::stopwatch &sw, FormatContext &ctx) const -> decltype(ctx.out()) ++ auto format(const yr_spdlog::stopwatch &sw, FormatContext &ctx) const -> decltype(ctx.out()) + { + return formatter::format(sw.elapsed().count(), ctx); + } +diff -ruN src/color_sinks.cpp src/color_sinks.cpp +--- src/color_sinks.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ src/color_sinks.cpp 2025-07-02 15:16:19.634618058 +0800 +@@ -14,38 +14,38 @@ + // + #ifdef _WIN32 + # include +-template class SPDLOG_API spdlog::sinks::wincolor_sink; +-template class SPDLOG_API spdlog::sinks::wincolor_sink; +-template class SPDLOG_API spdlog::sinks::wincolor_stdout_sink; +-template class SPDLOG_API spdlog::sinks::wincolor_stdout_sink; +-template class SPDLOG_API spdlog::sinks::wincolor_stderr_sink; +-template class SPDLOG_API spdlog::sinks::wincolor_stderr_sink; ++template class SPDLOG_API yr_spdlog::sinks::wincolor_sink; ++template class SPDLOG_API yr_spdlog::sinks::wincolor_sink; ++template class SPDLOG_API yr_spdlog::sinks::wincolor_stdout_sink; ++template class SPDLOG_API yr_spdlog::sinks::wincolor_stdout_sink; ++template class SPDLOG_API yr_spdlog::sinks::wincolor_stderr_sink; ++template class SPDLOG_API yr_spdlog::sinks::wincolor_stderr_sink; + #else + # include "spdlog/sinks/ansicolor_sink-inl.h" +-template class SPDLOG_API spdlog::sinks::ansicolor_sink; +-template class SPDLOG_API spdlog::sinks::ansicolor_sink; +-template class SPDLOG_API spdlog::sinks::ansicolor_stdout_sink; +-template class SPDLOG_API spdlog::sinks::ansicolor_stdout_sink; +-template class SPDLOG_API spdlog::sinks::ansicolor_stderr_sink; +-template class SPDLOG_API spdlog::sinks::ansicolor_stderr_sink; ++template class SPDLOG_API yr_spdlog::sinks::ansicolor_sink; ++template class SPDLOG_API yr_spdlog::sinks::ansicolor_sink; ++template class SPDLOG_API yr_spdlog::sinks::ansicolor_stdout_sink; ++template class SPDLOG_API yr_spdlog::sinks::ansicolor_stdout_sink; ++template class SPDLOG_API yr_spdlog::sinks::ansicolor_stderr_sink; ++template class SPDLOG_API yr_spdlog::sinks::ansicolor_stderr_sink; + #endif + + // factory methods for color loggers + #include "spdlog/sinks/stdout_color_sinks-inl.h" +-template SPDLOG_API std::shared_ptr spdlog::stdout_color_mt( ++template SPDLOG_API std::shared_ptr yr_spdlog::stdout_color_mt( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stdout_color_st( ++template SPDLOG_API std::shared_ptr yr_spdlog::stdout_color_st( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stderr_color_mt( ++template SPDLOG_API std::shared_ptr yr_spdlog::stderr_color_mt( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stderr_color_st( ++template SPDLOG_API std::shared_ptr yr_spdlog::stderr_color_st( + const std::string &logger_name, color_mode mode); + +-template SPDLOG_API std::shared_ptr spdlog::stdout_color_mt( ++template SPDLOG_API std::shared_ptr yr_spdlog::stdout_color_mt( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stdout_color_st( ++template SPDLOG_API std::shared_ptr yr_spdlog::stdout_color_st( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stderr_color_mt( ++template SPDLOG_API std::shared_ptr yr_spdlog::stderr_color_mt( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stderr_color_st( ++template SPDLOG_API std::shared_ptr yr_spdlog::stderr_color_st( + const std::string &logger_name, color_mode mode); +diff -ruN src/file_sinks.cpp src/file_sinks.cpp +--- src/file_sinks.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ src/file_sinks.cpp 2025-07-02 15:16:19.635618058 +0800 +@@ -12,9 +12,9 @@ + + #include + +-template class SPDLOG_API spdlog::sinks::basic_file_sink; +-template class SPDLOG_API spdlog::sinks::basic_file_sink; ++template class SPDLOG_API yr_spdlog::sinks::basic_file_sink; ++template class SPDLOG_API yr_spdlog::sinks::basic_file_sink; + + #include +-template class SPDLOG_API spdlog::sinks::rotating_file_sink; +-template class SPDLOG_API spdlog::sinks::rotating_file_sink; ++template class SPDLOG_API yr_spdlog::sinks::rotating_file_sink; ++template class SPDLOG_API yr_spdlog::sinks::rotating_file_sink; +diff -ruN src/spdlog.cpp src/spdlog.cpp +--- src/spdlog.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ src/spdlog.cpp 2025-07-02 15:16:19.635618058 +0800 +@@ -21,6 +21,6 @@ + #include + + // template instantiate logger constructor with sinks init list +-template SPDLOG_API spdlog::logger::logger(std::string name, sinks_init_list::iterator begin, sinks_init_list::iterator end); +-template class SPDLOG_API spdlog::sinks::base_sink; +-template class SPDLOG_API spdlog::sinks::base_sink; ++template SPDLOG_API yr_spdlog::logger::logger(std::string name, sinks_init_list::iterator begin, sinks_init_list::iterator end); ++template class SPDLOG_API yr_spdlog::sinks::base_sink; ++template class SPDLOG_API yr_spdlog::sinks::base_sink; +diff -ruN src/stdout_sinks.cpp src/stdout_sinks.cpp +--- src/stdout_sinks.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ src/stdout_sinks.cpp 2025-07-02 15:16:19.635618058 +0800 +@@ -11,19 +11,19 @@ + #include + #include + +-template class SPDLOG_API spdlog::sinks::stdout_sink_base; +-template class SPDLOG_API spdlog::sinks::stdout_sink_base; +-template class SPDLOG_API spdlog::sinks::stdout_sink; +-template class SPDLOG_API spdlog::sinks::stdout_sink; +-template class SPDLOG_API spdlog::sinks::stderr_sink; +-template class SPDLOG_API spdlog::sinks::stderr_sink; ++template class SPDLOG_API yr_spdlog::sinks::stdout_sink_base; ++template class SPDLOG_API yr_spdlog::sinks::stdout_sink_base; ++template class SPDLOG_API yr_spdlog::sinks::stdout_sink; ++template class SPDLOG_API yr_spdlog::sinks::stdout_sink; ++template class SPDLOG_API yr_spdlog::sinks::stderr_sink; ++template class SPDLOG_API yr_spdlog::sinks::stderr_sink; + +-template SPDLOG_API std::shared_ptr spdlog::stdout_logger_mt(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stdout_logger_st(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stderr_logger_mt(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stderr_logger_st(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr yr_spdlog::stdout_logger_mt(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr yr_spdlog::stdout_logger_st(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr yr_spdlog::stderr_logger_mt(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr yr_spdlog::stderr_logger_st(const std::string &logger_name); + +-template SPDLOG_API std::shared_ptr spdlog::stdout_logger_mt(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stdout_logger_st(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stderr_logger_mt(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stderr_logger_st(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr yr_spdlog::stdout_logger_mt(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr yr_spdlog::stdout_logger_st(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr yr_spdlog::stderr_logger_mt(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr yr_spdlog::stderr_logger_st(const std::string &logger_name); +diff -ruN tests/CMakeLists.txt tests/CMakeLists.txt +--- tests/CMakeLists.txt 2025-07-02 15:16:11.169618390 +0800 ++++ tests/CMakeLists.txt 2025-07-02 15:16:19.635618058 +0800 +@@ -77,10 +77,10 @@ + + # The compiled library tests + if(SPDLOG_BUILD_TESTS OR SPDLOG_BUILD_ALL) +- spdlog_prepare_test(spdlog-utests spdlog::spdlog) ++ spdlog_prepare_test(spdlog-utests yr_spdlog::spdlog) + endif() + + # The header-only library version tests + if(SPDLOG_BUILD_TESTS_HO OR SPDLOG_BUILD_ALL) +- spdlog_prepare_test(spdlog-utests-ho spdlog::spdlog_header_only) ++ spdlog_prepare_test(spdlog-utests-ho yr_spdlog::spdlog_header_only) + endif() +diff -ruN tests/test_async.cpp tests/test_async.cpp +--- tests/test_async.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_async.cpp 2025-07-02 15:16:19.635618058 +0800 +@@ -7,13 +7,13 @@ + + TEST_CASE("basic async test ", "[async]") + { +- auto test_sink = std::make_shared(); ++ auto test_sink = std::make_shared(); + size_t overrun_counter = 0; + size_t queue_size = 128; + size_t messages = 256; + { +- auto tp = std::make_shared(queue_size, 1); +- auto logger = std::make_shared("as", test_sink, tp, spdlog::async_overflow_policy::block); ++ auto tp = std::make_shared(queue_size, 1); ++ auto logger = std::make_shared("as", test_sink, tp, yr_spdlog::async_overflow_policy::block); + for (size_t i = 0; i < messages; i++) + { + logger->info("Hello message #{}", i); +@@ -28,13 +28,13 @@ + + TEST_CASE("discard policy ", "[async]") + { +- auto test_sink = std::make_shared(); ++ auto test_sink = std::make_shared(); + test_sink->set_delay(std::chrono::milliseconds(1)); + size_t queue_size = 4; + size_t messages = 1024; + +- auto tp = std::make_shared(queue_size, 1); +- auto logger = std::make_shared("as", test_sink, tp, spdlog::async_overflow_policy::overrun_oldest); ++ auto tp = std::make_shared(queue_size, 1); ++ auto logger = std::make_shared("as", test_sink, tp, yr_spdlog::async_overflow_policy::overrun_oldest); + for (size_t i = 0; i < messages; i++) + { + logger->info("Hello message"); +@@ -47,10 +47,10 @@ + { + size_t queue_size = 4; + size_t messages = 1024; +- spdlog::init_thread_pool(queue_size, 1); ++ yr_spdlog::init_thread_pool(queue_size, 1); + +- auto logger = spdlog::create_async_nb("as2"); +- auto test_sink = std::static_pointer_cast(logger->sinks()[0]); ++ auto logger = yr_spdlog::create_async_nb("as2"); ++ auto test_sink = std::static_pointer_cast(logger->sinks()[0]); + test_sink->set_delay(std::chrono::milliseconds(3)); + + for (size_t i = 0; i < messages; i++) +@@ -59,17 +59,17 @@ + } + + REQUIRE(test_sink->msg_counter() < messages); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("flush", "[async]") + { +- auto test_sink = std::make_shared(); ++ auto test_sink = std::make_shared(); + size_t queue_size = 256; + size_t messages = 256; + { +- auto tp = std::make_shared(queue_size, 1); +- auto logger = std::make_shared("as", test_sink, tp, spdlog::async_overflow_policy::block); ++ auto tp = std::make_shared(queue_size, 1); ++ auto logger = std::make_shared("as", test_sink, tp, yr_spdlog::async_overflow_policy::block); + for (size_t i = 0; i < messages; i++) + { + logger->info("Hello message #{}", i); +@@ -85,24 +85,24 @@ + TEST_CASE("async periodic flush", "[async]") + { + +- auto logger = spdlog::create_async("as"); +- auto test_sink = std::static_pointer_cast(logger->sinks()[0]); ++ auto logger = yr_spdlog::create_async("as"); ++ auto test_sink = std::static_pointer_cast(logger->sinks()[0]); + +- spdlog::flush_every(std::chrono::seconds(1)); ++ yr_spdlog::flush_every(std::chrono::seconds(1)); + std::this_thread::sleep_for(std::chrono::milliseconds(1700)); + REQUIRE(test_sink->flush_counter() == 1); +- spdlog::flush_every(std::chrono::seconds(0)); +- spdlog::drop_all(); ++ yr_spdlog::flush_every(std::chrono::seconds(0)); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("tp->wait_empty() ", "[async]") + { +- auto test_sink = std::make_shared(); ++ auto test_sink = std::make_shared(); + test_sink->set_delay(std::chrono::milliseconds(5)); + size_t messages = 100; + +- auto tp = std::make_shared(messages, 2); +- auto logger = std::make_shared("as", test_sink, tp, spdlog::async_overflow_policy::block); ++ auto tp = std::make_shared(messages, 2); ++ auto logger = std::make_shared("as", test_sink, tp, yr_spdlog::async_overflow_policy::block); + for (size_t i = 0; i < messages; i++) + { + logger->info("Hello message #{}", i); +@@ -116,13 +116,13 @@ + + TEST_CASE("multi threads", "[async]") + { +- auto test_sink = std::make_shared(); ++ auto test_sink = std::make_shared(); + size_t queue_size = 128; + size_t messages = 256; + size_t n_threads = 10; + { +- auto tp = std::make_shared(queue_size, 1); +- auto logger = std::make_shared("as", test_sink, tp, spdlog::async_overflow_policy::block); ++ auto tp = std::make_shared(queue_size, 1); ++ auto logger = std::make_shared("as", test_sink, tp, yr_spdlog::async_overflow_policy::block); + + std::vector threads; + for (size_t i = 0; i < n_threads; i++) +@@ -151,11 +151,11 @@ + prepare_logdir(); + size_t messages = 1024; + size_t tp_threads = 1; +- spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ yr_spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME); + { +- auto file_sink = std::make_shared(filename, true); +- auto tp = std::make_shared(messages, tp_threads); +- auto logger = std::make_shared("as", std::move(file_sink), std::move(tp)); ++ auto file_sink = std::make_shared(filename, true); ++ auto tp = std::make_shared(messages, tp_threads); ++ auto logger = std::make_shared("as", std::move(file_sink), std::move(tp)); + + for (size_t j = 0; j < messages; j++) + { +@@ -165,8 +165,8 @@ + + require_message_count(TEST_FILENAME, messages); + auto contents = file_contents(TEST_FILENAME); +- using spdlog::details::os::default_eol; +- REQUIRE(ends_with(contents, spdlog::fmt_lib::format("Hello message #1023{}", default_eol))); ++ using yr_spdlog::details::os::default_eol; ++ REQUIRE(ends_with(contents, yr_spdlog::fmt_lib::format("Hello message #1023{}", default_eol))); + } + + TEST_CASE("to_file multi-workers", "[async]") +@@ -174,11 +174,11 @@ + prepare_logdir(); + size_t messages = 1024 * 10; + size_t tp_threads = 10; +- spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ yr_spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME); + { +- auto file_sink = std::make_shared(filename, true); +- auto tp = std::make_shared(messages, tp_threads); +- auto logger = std::make_shared("as", std::move(file_sink), std::move(tp)); ++ auto file_sink = std::make_shared(filename, true); ++ auto tp = std::make_shared(messages, tp_threads); ++ auto logger = std::make_shared("as", std::move(file_sink), std::move(tp)); + + for (size_t j = 0; j < messages; j++) + { +@@ -190,9 +190,9 @@ + + TEST_CASE("bad_tp", "[async]") + { +- auto test_sink = std::make_shared(); +- std::shared_ptr const empty_tp; +- auto logger = std::make_shared("as", test_sink, empty_tp); ++ auto test_sink = std::make_shared(); ++ std::shared_ptr const empty_tp; ++ auto logger = std::make_shared("as", test_sink, empty_tp); + logger->info("Please throw an exception"); + REQUIRE(test_sink->msg_counter() == 0); + } +diff -ruN tests/test_backtrace.cpp tests/test_backtrace.cpp +--- tests/test_backtrace.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_backtrace.cpp 2025-07-02 15:16:19.635618058 +0800 +@@ -5,11 +5,11 @@ + TEST_CASE("bactrace1", "[bactrace]") + { + +- using spdlog::sinks::test_sink_st; ++ using yr_spdlog::sinks::test_sink_st; + auto test_sink = std::make_shared(); + size_t backtrace_size = 5; + +- spdlog::logger logger("test-backtrace", test_sink); ++ yr_spdlog::logger logger("test-backtrace", test_sink); + logger.set_pattern("%v"); + logger.enable_backtrace(backtrace_size); + +@@ -33,11 +33,11 @@ + + TEST_CASE("bactrace-empty", "[bactrace]") + { +- using spdlog::sinks::test_sink_st; ++ using yr_spdlog::sinks::test_sink_st; + auto test_sink = std::make_shared(); + size_t backtrace_size = 5; + +- spdlog::logger logger("test-backtrace", test_sink); ++ yr_spdlog::logger logger("test-backtrace", test_sink); + logger.set_pattern("%v"); + logger.enable_backtrace(backtrace_size); + logger.dump_backtrace(); +@@ -46,14 +46,14 @@ + + TEST_CASE("bactrace-async", "[bactrace]") + { +- using spdlog::sinks::test_sink_mt; ++ using yr_spdlog::sinks::test_sink_mt; + auto test_sink = std::make_shared(); +- using spdlog::details::os::sleep_for_millis; ++ using yr_spdlog::details::os::sleep_for_millis; + + size_t backtrace_size = 5; + +- spdlog::init_thread_pool(120, 1); +- auto logger = std::make_shared("test-bactrace-async", test_sink, spdlog::thread_pool()); ++ yr_spdlog::init_thread_pool(120, 1); ++ auto logger = std::make_shared("test-bactrace-async", test_sink, yr_spdlog::thread_pool()); + logger->set_pattern("%v"); + logger->enable_backtrace(backtrace_size); + +diff -ruN tests/test_bin_to_hex.cpp tests/test_bin_to_hex.cpp +--- tests/test_bin_to_hex.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_bin_to_hex.cpp 2025-07-02 15:16:19.635618058 +0800 +@@ -5,89 +5,89 @@ + TEST_CASE("to_hex", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ yr_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0xc, 0xff, 0xff}; +- oss_logger.info("{}", spdlog::to_hex(v)); ++ oss_logger.info("{}", yr_spdlog::to_hex(v)); + + auto output = oss.str(); +- REQUIRE(ends_with(output, "0000: 09 0a 0b 0c ff ff" + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(output, "0000: 09 0a 0b 0c ff ff" + std::string(yr_spdlog::details::os::default_eol))); + } + + TEST_CASE("to_hex_upper", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ yr_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0xc, 0xff, 0xff}; +- oss_logger.info("{:X}", spdlog::to_hex(v)); ++ oss_logger.info("{:X}", yr_spdlog::to_hex(v)); + + auto output = oss.str(); +- REQUIRE(ends_with(output, "0000: 09 0A 0B 0C FF FF" + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(output, "0000: 09 0A 0B 0C FF FF" + std::string(yr_spdlog::details::os::default_eol))); + } + + TEST_CASE("to_hex_no_delimiter", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ yr_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0xc, 0xff, 0xff}; +- oss_logger.info("{:sX}", spdlog::to_hex(v)); ++ oss_logger.info("{:sX}", yr_spdlog::to_hex(v)); + + auto output = oss.str(); +- REQUIRE(ends_with(output, "0000: 090A0B0CFFFF" + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(output, "0000: 090A0B0CFFFF" + std::string(yr_spdlog::details::os::default_eol))); + } + + TEST_CASE("to_hex_show_ascii", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ yr_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0x41, 0xc, 0x4b, 0xff, 0xff}; +- oss_logger.info("{:Xsa}", spdlog::to_hex(v, 8)); ++ oss_logger.info("{:Xsa}", yr_spdlog::to_hex(v, 8)); + +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF ...A.K.." + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF ...A.K.." + std::string(yr_spdlog::details::os::default_eol))); + } + + TEST_CASE("to_hex_different_size_per_line", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ yr_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0x41, 0xc, 0x4b, 0xff, 0xff}; + +- oss_logger.info("{:Xsa}", spdlog::to_hex(v, 10)); +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF ...A.K.." + std::string(spdlog::details::os::default_eol))); ++ oss_logger.info("{:Xsa}", yr_spdlog::to_hex(v, 10)); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF ...A.K.." + std::string(yr_spdlog::details::os::default_eol))); + +- oss_logger.info("{:Xs}", spdlog::to_hex(v, 10)); +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF" + std::string(spdlog::details::os::default_eol))); ++ oss_logger.info("{:Xs}", yr_spdlog::to_hex(v, 10)); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF" + std::string(yr_spdlog::details::os::default_eol))); + +- oss_logger.info("{:Xsa}", spdlog::to_hex(v, 6)); +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4B ...A.K" + std::string(spdlog::details::os::default_eol) + "0006: FFFF .." + +- std::string(spdlog::details::os::default_eol))); +- +- oss_logger.info("{:Xs}", spdlog::to_hex(v, 6)); +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4B" + std::string(spdlog::details::os::default_eol) + "0006: FFFF" + +- std::string(spdlog::details::os::default_eol))); ++ oss_logger.info("{:Xsa}", yr_spdlog::to_hex(v, 6)); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4B ...A.K" + std::string(yr_spdlog::details::os::default_eol) + "0006: FFFF .." + ++ std::string(yr_spdlog::details::os::default_eol))); ++ ++ oss_logger.info("{:Xs}", yr_spdlog::to_hex(v, 6)); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4B" + std::string(yr_spdlog::details::os::default_eol) + "0006: FFFF" + ++ std::string(yr_spdlog::details::os::default_eol))); + } + + TEST_CASE("to_hex_no_ascii", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ yr_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0x41, 0xc, 0x4b, 0xff, 0xff}; +- oss_logger.info("{:Xs}", spdlog::to_hex(v, 8)); ++ oss_logger.info("{:Xs}", yr_spdlog::to_hex(v, 8)); + +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF" + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF" + std::string(yr_spdlog::details::os::default_eol))); + +- oss_logger.info("{:Xsna}", spdlog::to_hex(v, 8)); ++ oss_logger.info("{:Xsna}", yr_spdlog::to_hex(v, 8)); + +- REQUIRE(ends_with(oss.str(), "090A0B410C4BFFFF" + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(oss.str(), "090A0B410C4BFFFF" + std::string(yr_spdlog::details::os::default_eol))); + } +diff -ruN tests/test_cfg.cpp tests/test_cfg.cpp +--- tests/test_cfg.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_cfg.cpp 2025-07-02 15:16:19.635618058 +0800 +@@ -5,179 +5,179 @@ + #include + #include + +-using spdlog::cfg::load_argv_levels; +-using spdlog::cfg::load_env_levels; +-using spdlog::sinks::test_sink_st; ++using yr_spdlog::cfg::load_argv_levels; ++using yr_spdlog::cfg::load_env_levels; ++using yr_spdlog::sinks::test_sink_st; + + TEST_CASE("env", "[cfg]") + { +- spdlog::drop("l1"); +- auto l1 = spdlog::create("l1"); ++ yr_spdlog::drop("l1"); ++ auto l1 = yr_spdlog::create("l1"); + #ifdef CATCH_PLATFORM_WINDOWS + _putenv_s("SPDLOG_LEVEL", "l1=warn"); + #else + setenv("SPDLOG_LEVEL", "l1=warn", 1); + #endif + load_env_levels(); +- REQUIRE(l1->level() == spdlog::level::warn); +- spdlog::set_default_logger(spdlog::create("cfg-default")); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ REQUIRE(l1->level() == yr_spdlog::level::warn); ++ yr_spdlog::set_default_logger(yr_spdlog::create("cfg-default")); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::info); + } + + TEST_CASE("argv1", "[cfg]") + { +- spdlog::drop("l1"); ++ yr_spdlog::drop("l1"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=warn"}; + load_argv_levels(2, argv); +- auto l1 = spdlog::create("l1"); +- REQUIRE(l1->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ auto l1 = yr_spdlog::create("l1"); ++ REQUIRE(l1->level() == yr_spdlog::level::warn); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::info); + } + + TEST_CASE("argv2", "[cfg]") + { +- spdlog::drop("l1"); ++ yr_spdlog::drop("l1"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=warn,trace"}; + load_argv_levels(2, argv); +- auto l1 = spdlog::create("l1"); +- REQUIRE(l1->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::trace); ++ auto l1 = yr_spdlog::create("l1"); ++ REQUIRE(l1->level() == yr_spdlog::level::warn); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::trace); + } + + TEST_CASE("argv3", "[cfg]") + { +- spdlog::set_level(spdlog::level::trace); ++ yr_spdlog::set_level(yr_spdlog::level::trace); + +- spdlog::drop("l1"); ++ yr_spdlog::drop("l1"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=junk_name=warn"}; + load_argv_levels(2, argv); +- auto l1 = spdlog::create("l1"); +- REQUIRE(l1->level() == spdlog::level::trace); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::trace); ++ auto l1 = yr_spdlog::create("l1"); ++ REQUIRE(l1->level() == yr_spdlog::level::trace); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::trace); + } + + TEST_CASE("argv4", "[cfg]") + { +- spdlog::set_level(spdlog::level::info); +- spdlog::drop("l1"); ++ yr_spdlog::set_level(yr_spdlog::level::info); ++ yr_spdlog::drop("l1"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=junk"}; + load_argv_levels(2, argv); +- auto l1 = spdlog::create("l1"); +- REQUIRE(l1->level() == spdlog::level::info); ++ auto l1 = yr_spdlog::create("l1"); ++ REQUIRE(l1->level() == yr_spdlog::level::info); + } + + TEST_CASE("argv5", "[cfg]") + { +- spdlog::set_level(spdlog::level::info); +- spdlog::drop("l1"); ++ yr_spdlog::set_level(yr_spdlog::level::info); ++ yr_spdlog::drop("l1"); + const char *argv[] = {"ignore", "ignore", "SPDLOG_LEVEL=l1=warn,trace"}; + load_argv_levels(3, argv); +- auto l1 = spdlog::create("l1"); +- REQUIRE(l1->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::trace); +- spdlog::set_level(spdlog::level::info); ++ auto l1 = yr_spdlog::create("l1"); ++ REQUIRE(l1->level() == yr_spdlog::level::warn); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::trace); ++ yr_spdlog::set_level(yr_spdlog::level::info); + } + + TEST_CASE("argv6", "[cfg]") + { +- spdlog::set_level(spdlog::level::err); ++ yr_spdlog::set_level(yr_spdlog::level::err); + const char *argv[] = {""}; + load_argv_levels(1, argv); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::err); +- spdlog::set_level(spdlog::level::info); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::err); ++ yr_spdlog::set_level(yr_spdlog::level::info); + } + + TEST_CASE("argv7", "[cfg]") + { +- spdlog::set_level(spdlog::level::err); ++ yr_spdlog::set_level(yr_spdlog::level::err); + const char *argv[] = {""}; + load_argv_levels(0, argv); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::err); +- spdlog::set_level(spdlog::level::info); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::err); ++ yr_spdlog::set_level(yr_spdlog::level::info); + } + + TEST_CASE("level-not-set-test1", "[cfg]") + { +- spdlog::drop("l1"); ++ yr_spdlog::drop("l1"); + const char *argv[] = {"ignore", ""}; + load_argv_levels(2, argv); +- auto l1 = spdlog::create("l1"); +- l1->set_level(spdlog::level::trace); +- REQUIRE(l1->level() == spdlog::level::trace); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ auto l1 = yr_spdlog::create("l1"); ++ l1->set_level(yr_spdlog::level::trace); ++ REQUIRE(l1->level() == yr_spdlog::level::trace); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::info); + } + + TEST_CASE("level-not-set-test2", "[cfg]") + { +- spdlog::drop("l1"); +- spdlog::drop("l2"); ++ yr_spdlog::drop("l1"); ++ yr_spdlog::drop("l2"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=trace"}; + +- auto l1 = spdlog::create("l1"); +- l1->set_level(spdlog::level::warn); +- auto l2 = spdlog::create("l2"); +- l2->set_level(spdlog::level::warn); ++ auto l1 = yr_spdlog::create("l1"); ++ l1->set_level(yr_spdlog::level::warn); ++ auto l2 = yr_spdlog::create("l2"); ++ l2->set_level(yr_spdlog::level::warn); + + load_argv_levels(2, argv); + +- REQUIRE(l1->level() == spdlog::level::trace); +- REQUIRE(l2->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ REQUIRE(l1->level() == yr_spdlog::level::trace); ++ REQUIRE(l2->level() == yr_spdlog::level::warn); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::info); + } + + TEST_CASE("level-not-set-test3", "[cfg]") + { +- spdlog::drop("l1"); +- spdlog::drop("l2"); ++ yr_spdlog::drop("l1"); ++ yr_spdlog::drop("l2"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=trace"}; + + load_argv_levels(2, argv); + +- auto l1 = spdlog::create("l1"); +- auto l2 = spdlog::create("l2"); ++ auto l1 = yr_spdlog::create("l1"); ++ auto l2 = yr_spdlog::create("l2"); + +- REQUIRE(l1->level() == spdlog::level::trace); +- REQUIRE(l2->level() == spdlog::level::info); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ REQUIRE(l1->level() == yr_spdlog::level::trace); ++ REQUIRE(l2->level() == yr_spdlog::level::info); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::info); + } + + TEST_CASE("level-not-set-test4", "[cfg]") + { +- spdlog::drop("l1"); +- spdlog::drop("l2"); ++ yr_spdlog::drop("l1"); ++ yr_spdlog::drop("l2"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=trace,warn"}; + + load_argv_levels(2, argv); + +- auto l1 = spdlog::create("l1"); +- auto l2 = spdlog::create("l2"); ++ auto l1 = yr_spdlog::create("l1"); ++ auto l2 = yr_spdlog::create("l2"); + +- REQUIRE(l1->level() == spdlog::level::trace); +- REQUIRE(l2->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::warn); ++ REQUIRE(l1->level() == yr_spdlog::level::trace); ++ REQUIRE(l2->level() == yr_spdlog::level::warn); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::warn); + } + + TEST_CASE("level-not-set-test5", "[cfg]") + { +- spdlog::drop("l1"); +- spdlog::drop("l2"); ++ yr_spdlog::drop("l1"); ++ yr_spdlog::drop("l2"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=junk,warn"}; + + load_argv_levels(2, argv); + +- auto l1 = spdlog::create("l1"); +- auto l2 = spdlog::create("l2"); ++ auto l1 = yr_spdlog::create("l1"); ++ auto l2 = yr_spdlog::create("l2"); + +- REQUIRE(l1->level() == spdlog::level::warn); +- REQUIRE(l2->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::warn); ++ REQUIRE(l1->level() == yr_spdlog::level::warn); ++ REQUIRE(l2->level() == yr_spdlog::level::warn); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::warn); + } + + TEST_CASE("restore-to-default", "[cfg]") + { +- spdlog::drop("l1"); +- spdlog::drop("l2"); ++ yr_spdlog::drop("l1"); ++ yr_spdlog::drop("l2"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=info"}; + load_argv_levels(2, argv); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ REQUIRE(yr_spdlog::default_logger()->level() == yr_spdlog::level::info); + } +diff -ruN tests/test_create_dir.cpp tests/test_create_dir.cpp +--- tests/test_create_dir.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_create_dir.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -3,10 +3,10 @@ + */ + #include "includes.h" + +-using spdlog::details::os::create_dir; +-using spdlog::details::os::path_exists; ++using yr_spdlog::details::os::create_dir; ++using yr_spdlog::details::os::path_exists; + +-bool try_create_dir(const spdlog::filename_t &path, const spdlog::filename_t &normalized_path) ++bool try_create_dir(const yr_spdlog::filename_t &path, const yr_spdlog::filename_t &normalized_path) + { + auto rv = create_dir(path); + REQUIRE(rv == true); +@@ -36,7 +36,7 @@ + TEST_CASE("create_invalid_dir", "[create_dir]") + { + REQUIRE(create_dir(SPDLOG_FILENAME_T("")) == false); +- REQUIRE(create_dir(spdlog::filename_t{}) == false); ++ REQUIRE(create_dir(yr_spdlog::filename_t{}) == false); + #ifdef __linux__ + REQUIRE(create_dir("/proc/spdlog-utest") == false); + #endif +@@ -44,7 +44,7 @@ + + TEST_CASE("dir_name", "[create_dir]") + { +- using spdlog::details::os::dir_name; ++ using yr_spdlog::details::os::dir_name; + REQUIRE(dir_name(SPDLOG_FILENAME_T("")).empty()); + REQUIRE(dir_name(SPDLOG_FILENAME_T("dir")).empty()); + +diff -ruN tests/test_custom_callbacks.cpp tests/test_custom_callbacks.cpp +--- tests/test_custom_callbacks.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_custom_callbacks.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -10,16 +10,16 @@ + TEST_CASE("custom_callback_logger", "[custom_callback_logger]") + { + std::vector lines; +- spdlog::pattern_formatter formatter; +- auto callback_logger = std::make_shared([&](const spdlog::details::log_msg &msg) { +- spdlog::memory_buf_t formatted; ++ yr_spdlog::pattern_formatter formatter; ++ auto callback_logger = std::make_shared([&](const yr_spdlog::details::log_msg &msg) { ++ yr_spdlog::memory_buf_t formatted; + formatter.format(msg, formatted); +- auto eol_len = strlen(spdlog::details::os::default_eol); ++ auto eol_len = strlen(yr_spdlog::details::os::default_eol); + lines.emplace_back(formatted.begin(), formatted.end() - eol_len); + }); +- std::shared_ptr test_sink(new spdlog::sinks::test_sink_st); ++ std::shared_ptr test_sink(new yr_spdlog::sinks::test_sink_st); + +- spdlog::logger logger("test-callback", {callback_logger, test_sink}); ++ yr_spdlog::logger logger("test-callback", {callback_logger, test_sink}); + + logger.info("test message 1"); + logger.info("test message 2"); +@@ -30,5 +30,5 @@ + REQUIRE(lines[0] == ref_lines[0]); + REQUIRE(lines[1] == ref_lines[1]); + REQUIRE(lines[2] == ref_lines[2]); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } +diff -ruN tests/test_daily_logger.cpp tests/test_daily_logger.cpp +--- tests/test_daily_logger.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_daily_logger.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -4,16 +4,16 @@ + #include "includes.h" + + #ifdef SPDLOG_USE_STD_FORMAT +-using filename_memory_buf_t = std::basic_string; ++using filename_memory_buf_t = std::basic_string; + #else +-using filename_memory_buf_t = fmt::basic_memory_buffer; ++using filename_memory_buf_t = fmt::basic_memory_buffer; + #endif + + #ifdef SPDLOG_WCHAR_FILENAMES + std::string filename_buf_to_utf8string(const filename_memory_buf_t &w) + { +- spdlog::memory_buf_t buf; +- spdlog::details::os::wstr_to_utf8buf(spdlog::wstring_view_t(w.data(), w.size()), buf); ++ yr_spdlog::memory_buf_t buf; ++ yr_spdlog::details::os::wstr_to_utf8buf(yr_spdlog::wstring_view_t(w.data(), w.size()), buf); + return SPDLOG_BUF_TO_STRING(buf); + } + #else +@@ -25,18 +25,18 @@ + + TEST_CASE("daily_logger with dateonly calculator", "[daily_logger]") + { +- using sink_type = spdlog::sinks::daily_file_sink; ++ using sink_type = yr_spdlog::sinks::daily_file_sink; + + prepare_logdir(); + + // calculate filename (time based) +- spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_dateonly"); +- std::tm tm = spdlog::details::os::localtime(); ++ yr_spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_dateonly"); ++ std::tm tm = yr_spdlog::details::os::localtime(); + filename_memory_buf_t w; +- spdlog::fmt_lib::format_to( ++ yr_spdlog::fmt_lib::format_to( + std::back_inserter(w), SPDLOG_FILENAME_T("{}_{:04d}-{:02d}-{:02d}"), basename, tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday); + +- auto logger = spdlog::create("logger", basename, 0, 0); ++ auto logger = yr_spdlog::create("logger", basename, 0, 0); + for (int i = 0; i < 10; ++i) + { + +@@ -49,10 +49,10 @@ + + struct custom_daily_file_name_calculator + { +- static spdlog::filename_t calc_filename(const spdlog::filename_t &basename, const tm &now_tm) ++ static yr_spdlog::filename_t calc_filename(const yr_spdlog::filename_t &basename, const tm &now_tm) + { + filename_memory_buf_t w; +- spdlog::fmt_lib::format_to(std::back_inserter(w), SPDLOG_FILENAME_T("{}{:04d}{:02d}{:02d}"), basename, now_tm.tm_year + 1900, ++ yr_spdlog::fmt_lib::format_to(std::back_inserter(w), SPDLOG_FILENAME_T("{}{:04d}{:02d}{:02d}"), basename, now_tm.tm_year + 1900, + now_tm.tm_mon + 1, now_tm.tm_mday); + + return SPDLOG_BUF_TO_STRING(w); +@@ -61,18 +61,18 @@ + + TEST_CASE("daily_logger with custom calculator", "[daily_logger]") + { +- using sink_type = spdlog::sinks::daily_file_sink; ++ using sink_type = yr_spdlog::sinks::daily_file_sink; + + prepare_logdir(); + + // calculate filename (time based) +- spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_dateonly"); +- std::tm tm = spdlog::details::os::localtime(); ++ yr_spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_dateonly"); ++ std::tm tm = yr_spdlog::details::os::localtime(); + filename_memory_buf_t w; +- spdlog::fmt_lib::format_to( ++ yr_spdlog::fmt_lib::format_to( + std::back_inserter(w), SPDLOG_FILENAME_T("{}{:04d}{:02d}{:02d}"), basename, tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday); + +- auto logger = spdlog::create("logger", basename, 0, 0); ++ auto logger = yr_spdlog::create("logger", basename, 0, 0); + for (int i = 0; i < 10; ++i) + { + logger->info("Test message {}", i); +@@ -89,19 +89,19 @@ + + TEST_CASE("rotating_file_sink::calc_filename1", "[rotating_file_sink]") + { +- auto filename = spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated.txt"), 3); ++ auto filename = yr_spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated.txt"), 3); + REQUIRE(filename == SPDLOG_FILENAME_T("rotated.3.txt")); + } + + TEST_CASE("rotating_file_sink::calc_filename2", "[rotating_file_sink]") + { +- auto filename = spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated"), 3); ++ auto filename = yr_spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated"), 3); + REQUIRE(filename == SPDLOG_FILENAME_T("rotated.3")); + } + + TEST_CASE("rotating_file_sink::calc_filename3", "[rotating_file_sink]") + { +- auto filename = spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated.txt"), 0); ++ auto filename = yr_spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated.txt"), 0); + REQUIRE(filename == SPDLOG_FILENAME_T("rotated.txt")); + } + +@@ -114,43 +114,43 @@ + { + // daily_YYYY-MM-DD_hh-mm.txt + auto filename = +- spdlog::sinks::daily_filename_calculator::calc_filename(SPDLOG_FILENAME_T("daily.txt"), spdlog::details::os::localtime()); ++ yr_spdlog::sinks::daily_filename_calculator::calc_filename(SPDLOG_FILENAME_T("daily.txt"), yr_spdlog::details::os::localtime()); + // date regex based on https://www.regular-expressions.info/dates.html +- std::basic_regex re( ++ std::basic_regex re( + SPDLOG_FILENAME_T(R"(^daily_(19|20)\d\d-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01])\.txt$)")); +- std::match_results match; ++ std::match_results match; + REQUIRE(std::regex_match(filename, match, re)); + } + #endif + + TEST_CASE("daily_file_sink::daily_filename_format_calculator", "[daily_file_sink]") + { +- std::tm tm = spdlog::details::os::localtime(); ++ std::tm tm = yr_spdlog::details::os::localtime(); + // example-YYYY-MM-DD.log +- auto filename = spdlog::sinks::daily_filename_format_calculator::calc_filename(SPDLOG_FILENAME_T("example-%Y-%m-%d.log"), tm); ++ auto filename = yr_spdlog::sinks::daily_filename_format_calculator::calc_filename(SPDLOG_FILENAME_T("example-%Y-%m-%d.log"), tm); + + REQUIRE(filename == +- spdlog::fmt_lib::format(SPDLOG_FILENAME_T("example-{:04d}-{:02d}-{:02d}.log"), tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday)); ++ yr_spdlog::fmt_lib::format(SPDLOG_FILENAME_T("example-{:04d}-{:02d}-{:02d}.log"), tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday)); + } + + /* Test removal of old files */ +-static spdlog::details::log_msg create_msg(std::chrono::seconds offset) ++static yr_spdlog::details::log_msg create_msg(std::chrono::seconds offset) + { +- using spdlog::log_clock; +- spdlog::details::log_msg msg{"test", spdlog::level::info, "Hello Message"}; ++ using yr_spdlog::log_clock; ++ yr_spdlog::details::log_msg msg{"test", yr_spdlog::level::info, "Hello Message"}; + msg.time = log_clock::now() + offset; + return msg; + } + + static void test_rotate(int days_to_run, uint16_t max_days, uint16_t expected_n_files) + { +- using spdlog::log_clock; +- using spdlog::details::log_msg; +- using spdlog::sinks::daily_file_sink_st; ++ using yr_spdlog::log_clock; ++ using yr_spdlog::details::log_msg; ++ using yr_spdlog::sinks::daily_file_sink_st; + + prepare_logdir(); + +- spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_rotate.txt"); ++ yr_spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_rotate.txt"); + daily_file_sink_st sink{basename, 2, 30, true, max_days}; + + // simulate messages with 24 intervals +diff -ruN tests/test_dup_filter.cpp tests/test_dup_filter.cpp +--- tests/test_dup_filter.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_dup_filter.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -4,8 +4,8 @@ + + TEST_CASE("dup_filter_test1", "[dup_filter_sink]") + { +- using spdlog::sinks::dup_filter_sink_st; +- using spdlog::sinks::test_sink_mt; ++ using yr_spdlog::sinks::dup_filter_sink_st; ++ using yr_spdlog::sinks::test_sink_mt; + + dup_filter_sink_st dup_sink{std::chrono::seconds{5}}; + auto test_sink = std::make_shared(); +@@ -13,7 +13,7 @@ + + for (int i = 0; i < 10; i++) + { +- dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"}); ++ dup_sink.log(yr_spdlog::details::log_msg{"test", yr_spdlog::level::info, "message1"}); + } + + REQUIRE(test_sink->msg_counter() == 1); +@@ -21,8 +21,8 @@ + + TEST_CASE("dup_filter_test2", "[dup_filter_sink]") + { +- using spdlog::sinks::dup_filter_sink_st; +- using spdlog::sinks::test_sink_mt; ++ using yr_spdlog::sinks::dup_filter_sink_st; ++ using yr_spdlog::sinks::test_sink_mt; + + dup_filter_sink_st dup_sink{std::chrono::seconds{0}}; + auto test_sink = std::make_shared(); +@@ -30,7 +30,7 @@ + + for (int i = 0; i < 10; i++) + { +- dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"}); ++ dup_sink.log(yr_spdlog::details::log_msg{"test", yr_spdlog::level::info, "message1"}); + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + +@@ -39,8 +39,8 @@ + + TEST_CASE("dup_filter_test3", "[dup_filter_sink]") + { +- using spdlog::sinks::dup_filter_sink_st; +- using spdlog::sinks::test_sink_mt; ++ using yr_spdlog::sinks::dup_filter_sink_st; ++ using yr_spdlog::sinks::test_sink_mt; + + dup_filter_sink_st dup_sink{std::chrono::seconds{1}}; + auto test_sink = std::make_shared(); +@@ -48,8 +48,8 @@ + + for (int i = 0; i < 10; i++) + { +- dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"}); +- dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message2"}); ++ dup_sink.log(yr_spdlog::details::log_msg{"test", yr_spdlog::level::info, "message1"}); ++ dup_sink.log(yr_spdlog::details::log_msg{"test", yr_spdlog::level::info, "message2"}); + } + + REQUIRE(test_sink->msg_counter() == 20); +@@ -57,33 +57,33 @@ + + TEST_CASE("dup_filter_test4", "[dup_filter_sink]") + { +- using spdlog::sinks::dup_filter_sink_mt; +- using spdlog::sinks::test_sink_mt; ++ using yr_spdlog::sinks::dup_filter_sink_mt; ++ using yr_spdlog::sinks::test_sink_mt; + + dup_filter_sink_mt dup_sink{std::chrono::milliseconds{10}}; + auto test_sink = std::make_shared(); + dup_sink.add_sink(test_sink); + +- dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message"}); ++ dup_sink.log(yr_spdlog::details::log_msg{"test", yr_spdlog::level::info, "message"}); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); +- dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message"}); ++ dup_sink.log(yr_spdlog::details::log_msg{"test", yr_spdlog::level::info, "message"}); + REQUIRE(test_sink->msg_counter() == 2); + } + + TEST_CASE("dup_filter_test5", "[dup_filter_sink]") + { +- using spdlog::sinks::dup_filter_sink_mt; +- using spdlog::sinks::test_sink_mt; ++ using yr_spdlog::sinks::dup_filter_sink_mt; ++ using yr_spdlog::sinks::test_sink_mt; + + dup_filter_sink_mt dup_sink{std::chrono::seconds{5}}; + auto test_sink = std::make_shared(); + test_sink->set_pattern("%v"); + dup_sink.add_sink(test_sink); + +- dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"}); +- dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"}); +- dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"}); +- dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message2"}); ++ dup_sink.log(yr_spdlog::details::log_msg{"test", yr_spdlog::level::info, "message1"}); ++ dup_sink.log(yr_spdlog::details::log_msg{"test", yr_spdlog::level::info, "message1"}); ++ dup_sink.log(yr_spdlog::details::log_msg{"test", yr_spdlog::level::info, "message1"}); ++ dup_sink.log(yr_spdlog::details::log_msg{"test", yr_spdlog::level::info, "message2"}); + + REQUIRE(test_sink->msg_counter() == 3); // skip 2 messages but log the "skipped.." message before message2 + REQUIRE(test_sink->lines()[1] == "Skipped 2 duplicate messages.."); +diff -ruN tests/test_errors.cpp tests/test_errors.cpp +--- tests/test_errors.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_errors.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -8,10 +8,10 @@ + #define SIMPLE_LOG "test_logs/simple_log.txt" + #define SIMPLE_ASYNC_LOG "test_logs/simple_async_log.txt" + +-class failing_sink : public spdlog::sinks::base_sink ++class failing_sink : public yr_spdlog::sinks::base_sink + { + protected: +- void sink_it_(const spdlog::details::log_msg &) final ++ void sink_it_(const yr_spdlog::details::log_msg &) final + { + throw std::runtime_error("some error happened during log"); + } +@@ -28,24 +28,24 @@ + TEST_CASE("default_error_handler", "[errors]") + { + prepare_logdir(); +- spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG); ++ yr_spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG); + +- auto logger = spdlog::create("test-error", filename, true); ++ auto logger = yr_spdlog::create("test-error", filename, true); + logger->set_pattern("%v"); + logger->info(SPDLOG_FMT_RUNTIME("Test message {} {}"), 1); + logger->info("Test message {}", 2); + logger->flush(); +- using spdlog::details::os::default_eol; +- REQUIRE(file_contents(SIMPLE_LOG) == spdlog::fmt_lib::format("Test message 2{}", default_eol)); ++ using yr_spdlog::details::os::default_eol; ++ REQUIRE(file_contents(SIMPLE_LOG) == yr_spdlog::fmt_lib::format("Test message 2{}", default_eol)); + REQUIRE(count_lines(SIMPLE_LOG) == 1); + } + + TEST_CASE("custom_error_handler", "[errors]") + { + prepare_logdir(); +- spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG); +- auto logger = spdlog::create("logger", filename, true); +- logger->flush_on(spdlog::level::info); ++ yr_spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG); ++ auto logger = yr_spdlog::create("logger", filename, true); ++ logger->flush_on(yr_spdlog::level::info); + logger->set_error_handler([=](const std::string &) { throw custom_ex(); }); + logger->info("Good message #1"); + +@@ -57,16 +57,16 @@ + + TEST_CASE("default_error_handler2", "[errors]") + { +- spdlog::drop_all(); +- auto logger = spdlog::create("failed_logger"); ++ yr_spdlog::drop_all(); ++ auto logger = yr_spdlog::create("failed_logger"); + logger->set_error_handler([=](const std::string &) { throw custom_ex(); }); + REQUIRE_THROWS_AS(logger->info("Some message"), custom_ex); + } + + TEST_CASE("flush_error_handler", "[errors]") + { +- spdlog::drop_all(); +- auto logger = spdlog::create("failed_logger"); ++ yr_spdlog::drop_all(); ++ auto logger = yr_spdlog::create("failed_logger"); + logger->set_error_handler([=](const std::string &) { throw custom_ex(); }); + REQUIRE_THROWS_AS(logger->flush(), custom_ex); + } +@@ -77,10 +77,10 @@ + prepare_logdir(); + std::string err_msg("log failed with some msg"); + +- spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_ASYNC_LOG); ++ yr_spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_ASYNC_LOG); + { +- spdlog::init_thread_pool(128, 1); +- auto logger = spdlog::create_async("logger", filename, true); ++ yr_spdlog::init_thread_pool(128, 1); ++ auto logger = yr_spdlog::create_async("logger", filename, true); + logger->set_error_handler([=](const std::string &) { + std::ofstream ofs("test_logs/custom_err.txt"); + if (!ofs) +@@ -92,9 +92,9 @@ + logger->info("Good message #1"); + logger->info(SPDLOG_FMT_RUNTIME("Bad format msg {} {}"), "xxx"); + logger->info("Good message #2"); +- spdlog::drop("logger"); // force logger to drain the queue and shutdown ++ yr_spdlog::drop("logger"); // force logger to drain the queue and shutdown + } +- spdlog::init_thread_pool(128, 1); ++ yr_spdlog::init_thread_pool(128, 1); + require_message_count(SIMPLE_ASYNC_LOG, 2); + REQUIRE(file_contents("test_logs/custom_err.txt") == err_msg); + } +@@ -106,9 +106,9 @@ + prepare_logdir(); + std::string err_msg("This is async handler error message"); + { +- spdlog::details::os::create_dir(SPDLOG_FILENAME_T("test_logs")); +- spdlog::init_thread_pool(128, 1); +- auto logger = spdlog::create_async("failed_logger"); ++ yr_spdlog::details::os::create_dir(SPDLOG_FILENAME_T("test_logs")); ++ yr_spdlog::init_thread_pool(128, 1); ++ auto logger = yr_spdlog::create_async("failed_logger"); + logger->set_error_handler([=](const std::string &) { + std::ofstream ofs("test_logs/custom_err2.txt"); + if (!ofs) +@@ -116,9 +116,9 @@ + ofs << err_msg; + }); + logger->info("Hello failure"); +- spdlog::drop("failed_logger"); // force logger to drain the queue and shutdown ++ yr_spdlog::drop("failed_logger"); // force logger to drain the queue and shutdown + } + +- spdlog::init_thread_pool(128, 1); ++ yr_spdlog::init_thread_pool(128, 1); + REQUIRE(file_contents("test_logs/custom_err2.txt") == err_msg); + } +diff -ruN tests/test_eventlog.cpp tests/test_eventlog.cpp +--- tests/test_eventlog.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_eventlog.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -46,16 +46,16 @@ + REQUIRE((expected_time_generated - record->TimeGenerated) <= 3u); + + std::string message_in_log(((char *)record + record->StringOffset)); +- REQUIRE(message_in_log == expected_contents + spdlog::details::os::default_eol); ++ REQUIRE(message_in_log == expected_contents + yr_spdlog::details::os::default_eol); + } + + TEST_CASE("eventlog", "[eventlog]") + { +- using namespace spdlog; ++ using namespace yr_spdlog; + + auto test_sink = std::make_shared(TEST_SOURCE); + +- spdlog::logger test_logger("eventlog", test_sink); ++ yr_spdlog::logger test_logger("eventlog", test_sink); + test_logger.set_level(level::trace); + + test_sink->set_pattern("%v"); +diff -ruN tests/test_file_helper.cpp tests/test_file_helper.cpp +--- tests/test_file_helper.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_file_helper.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -5,12 +5,12 @@ + + #define TEST_FILENAME "test_logs/file_helper_test.txt" + +-using spdlog::details::file_helper; ++using yr_spdlog::details::file_helper; + + static void write_with_helper(file_helper &helper, size_t howmany) + { +- spdlog::memory_buf_t formatted; +- spdlog::fmt_lib::format_to(std::back_inserter(formatted), "{}", std::string(howmany, '1')); ++ yr_spdlog::memory_buf_t formatted; ++ yr_spdlog::fmt_lib::format_to(std::back_inserter(formatted), "{}", std::string(howmany, '1')); + helper.write(formatted); + helper.flush(); + } +@@ -20,7 +20,7 @@ + prepare_logdir(); + + file_helper helper; +- spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ yr_spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME); + helper.open(target_filename); + REQUIRE(helper.filename() == target_filename); + } +@@ -28,7 +28,7 @@ + TEST_CASE("file_helper_size", "[file_helper::size()]") + { + prepare_logdir(); +- spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ yr_spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME); + size_t expected_size = 123; + { + file_helper helper; +@@ -42,7 +42,7 @@ + TEST_CASE("file_helper_reopen", "[file_helper::reopen()]") + { + prepare_logdir(); +- spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ yr_spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME); + file_helper helper; + helper.open(target_filename); + write_with_helper(helper, 12); +@@ -54,7 +54,7 @@ + TEST_CASE("file_helper_reopen2", "[file_helper::reopen(false)]") + { + prepare_logdir(); +- spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ yr_spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME); + size_t expected_size = 14; + file_helper helper; + helper.open(target_filename); +@@ -64,15 +64,15 @@ + REQUIRE(helper.size() == expected_size); + } + +-static void test_split_ext(const spdlog::filename_t::value_type *fname, const spdlog::filename_t::value_type *expect_base, +- const spdlog::filename_t::value_type *expect_ext) ++static void test_split_ext(const yr_spdlog::filename_t::value_type *fname, const yr_spdlog::filename_t::value_type *expect_base, ++ const yr_spdlog::filename_t::value_type *expect_ext) + { +- spdlog::filename_t filename(fname); +- spdlog::filename_t expected_base(expect_base); +- spdlog::filename_t expected_ext(expect_ext); ++ yr_spdlog::filename_t filename(fname); ++ yr_spdlog::filename_t expected_base(expect_base); ++ yr_spdlog::filename_t expected_ext(expect_ext); + +- spdlog::filename_t basename; +- spdlog::filename_t ext; ++ yr_spdlog::filename_t basename; ++ yr_spdlog::filename_t ext; + std::tie(basename, ext) = file_helper::split_by_extension(filename); + REQUIRE(basename == expected_base); + REQUIRE(ext == expected_ext); +@@ -111,32 +111,32 @@ + }; + prepare_logdir(); + +- spdlog::filename_t test_filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ yr_spdlog::filename_t test_filename = SPDLOG_FILENAME_T(TEST_FILENAME); + // define event handles that update vector of flags when called + std::vector events; +- spdlog::file_event_handlers handlers; +- handlers.before_open = [&](spdlog::filename_t filename) { ++ yr_spdlog::file_event_handlers handlers; ++ handlers.before_open = [&](yr_spdlog::filename_t filename) { + REQUIRE(filename == test_filename); + events.push_back(flags::before_open); + }; +- handlers.after_open = [&](spdlog::filename_t filename, std::FILE *fstream) { ++ handlers.after_open = [&](yr_spdlog::filename_t filename, std::FILE *fstream) { + REQUIRE(filename == test_filename); + REQUIRE(fstream); + fputs("after_open\n", fstream); + events.push_back(flags::after_open); + }; +- handlers.before_close = [&](spdlog::filename_t filename, std::FILE *fstream) { ++ handlers.before_close = [&](yr_spdlog::filename_t filename, std::FILE *fstream) { + REQUIRE(filename == test_filename); + REQUIRE(fstream); + fputs("before_close\n", fstream); + events.push_back(flags::before_close); + }; +- handlers.after_close = [&](spdlog::filename_t filename) { ++ handlers.after_close = [&](yr_spdlog::filename_t filename) { + REQUIRE(filename == test_filename); + events.push_back(flags::after_close); + }; + { +- spdlog::details::file_helper helper{handlers}; ++ yr_spdlog::details::file_helper helper{handlers}; + REQUIRE(events.empty()); + + helper.open(test_filename); +@@ -158,11 +158,11 @@ + TEST_CASE("file_helper_open", "[file_helper]") + { + prepare_logdir(); +- spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ yr_spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME); + file_helper helper; + helper.open(target_filename); + helper.close(); + + target_filename += SPDLOG_FILENAME_T("/invalid"); +- REQUIRE_THROWS_AS(helper.open(target_filename), spdlog::spdlog_ex); ++ REQUIRE_THROWS_AS(helper.open(target_filename), yr_spdlog::spdlog_ex); + } +diff -ruN tests/test_file_logging.cpp tests/test_file_logging.cpp +--- tests/test_file_logging.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_file_logging.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -9,9 +9,9 @@ + TEST_CASE("simple_file_logger", "[simple_logger]") + { + prepare_logdir(); +- spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG); ++ yr_spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG); + +- auto logger = spdlog::create("logger", filename); ++ auto logger = yr_spdlog::create("logger", filename); + logger->set_pattern("%v"); + + logger->info("Test message {}", 1); +@@ -19,19 +19,19 @@ + + logger->flush(); + require_message_count(SIMPLE_LOG, 2); +- using spdlog::details::os::default_eol; +- REQUIRE(file_contents(SIMPLE_LOG) == spdlog::fmt_lib::format("Test message 1{}Test message 2{}", default_eol, default_eol)); ++ using yr_spdlog::details::os::default_eol; ++ REQUIRE(file_contents(SIMPLE_LOG) == yr_spdlog::fmt_lib::format("Test message 1{}Test message 2{}", default_eol, default_eol)); + } + + TEST_CASE("flush_on", "[flush_on]") + { + prepare_logdir(); +- spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG); ++ yr_spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG); + +- auto logger = spdlog::create("logger", filename); ++ auto logger = yr_spdlog::create("logger", filename); + logger->set_pattern("%v"); +- logger->set_level(spdlog::level::trace); +- logger->flush_on(spdlog::level::info); ++ logger->set_level(yr_spdlog::level::trace); ++ logger->flush_on(yr_spdlog::level::info); + logger->trace("Should not be flushed"); + REQUIRE(count_lines(SIMPLE_LOG) == 0); + +@@ -39,17 +39,17 @@ + logger->info("Test message {}", 2); + + require_message_count(SIMPLE_LOG, 3); +- using spdlog::details::os::default_eol; ++ using yr_spdlog::details::os::default_eol; + REQUIRE(file_contents(SIMPLE_LOG) == +- spdlog::fmt_lib::format("Should not be flushed{}Test message 1{}Test message 2{}", default_eol, default_eol, default_eol)); ++ yr_spdlog::fmt_lib::format("Should not be flushed{}Test message 1{}Test message 2{}", default_eol, default_eol, default_eol)); + } + + TEST_CASE("rotating_file_logger1", "[rotating_logger]") + { + prepare_logdir(); + size_t max_size = 1024 * 10; +- spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG); +- auto logger = spdlog::rotating_logger_mt("logger", basename, max_size, 0); ++ yr_spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG); ++ auto logger = yr_spdlog::rotating_logger_mt("logger", basename, max_size, 0); + + for (int i = 0; i < 10; ++i) + { +@@ -64,21 +64,21 @@ + { + prepare_logdir(); + size_t max_size = 1024 * 10; +- spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG); ++ yr_spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG); + + { + // make an initial logger to create the first output file +- auto logger = spdlog::rotating_logger_mt("logger", basename, max_size, 2, true); ++ auto logger = yr_spdlog::rotating_logger_mt("logger", basename, max_size, 2, true); + for (int i = 0; i < 10; ++i) + { + logger->info("Test message {}", i); + } + // drop causes the logger destructor to be called, which is required so the + // next logger can rename the first output file. +- spdlog::drop(logger->name()); ++ yr_spdlog::drop(logger->name()); + } + +- auto logger = spdlog::rotating_logger_mt("logger", basename, max_size, 2, true); ++ auto logger = yr_spdlog::rotating_logger_mt("logger", basename, max_size, 2, true); + for (int i = 0; i < 10; ++i) + { + logger->info("Test message {}", i); +@@ -104,6 +104,6 @@ + { + prepare_logdir(); + size_t max_size = 0; +- spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG); +- REQUIRE_THROWS_AS(spdlog::rotating_logger_mt("logger", basename, max_size, 0), spdlog::spdlog_ex); ++ yr_spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG); ++ REQUIRE_THROWS_AS(yr_spdlog::rotating_logger_mt("logger", basename, max_size, 0), yr_spdlog::spdlog_ex); + } +diff -ruN tests/test_fmt_helper.cpp tests/test_fmt_helper.cpp +--- tests/test_fmt_helper.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_fmt_helper.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -1,13 +1,13 @@ + + #include "includes.h" + +-using spdlog::memory_buf_t; +-using spdlog::details::to_string_view; ++using yr_spdlog::memory_buf_t; ++using yr_spdlog::details::to_string_view; + + void test_pad2(int n, const char *expected) + { + memory_buf_t buf; +- spdlog::details::fmt_helper::pad2(n, buf); ++ yr_spdlog::details::fmt_helper::pad2(n, buf); + + REQUIRE(to_string_view(buf) == expected); + } +@@ -15,7 +15,7 @@ + void test_pad3(uint32_t n, const char *expected) + { + memory_buf_t buf; +- spdlog::details::fmt_helper::pad3(n, buf); ++ yr_spdlog::details::fmt_helper::pad3(n, buf); + + REQUIRE(to_string_view(buf) == expected); + } +@@ -23,7 +23,7 @@ + void test_pad6(std::size_t n, const char *expected) + { + memory_buf_t buf; +- spdlog::details::fmt_helper::pad6(n, buf); ++ yr_spdlog::details::fmt_helper::pad6(n, buf); + + REQUIRE(to_string_view(buf) == expected); + } +@@ -31,7 +31,7 @@ + void test_pad9(std::size_t n, const char *expected) + { + memory_buf_t buf; +- spdlog::details::fmt_helper::pad9(n, buf); ++ yr_spdlog::details::fmt_helper::pad9(n, buf); + + REQUIRE(to_string_view(buf) == expected); + } +diff -ruN tests/test_macros.cpp tests/test_macros.cpp +--- tests/test_macros.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_macros.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -14,30 +14,30 @@ + { + + prepare_logdir(); +- spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ yr_spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME); + +- auto logger = spdlog::create("logger", filename); ++ auto logger = yr_spdlog::create("logger", filename); + logger->set_pattern("%v"); +- logger->set_level(spdlog::level::trace); ++ logger->set_level(yr_spdlog::level::trace); + + SPDLOG_LOGGER_TRACE(logger, "Test message 1"); + SPDLOG_LOGGER_DEBUG(logger, "Test message 2"); + logger->flush(); + +- using spdlog::details::os::default_eol; +- REQUIRE(ends_with(file_contents(TEST_FILENAME), spdlog::fmt_lib::format("Test message 2{}", default_eol))); ++ using yr_spdlog::details::os::default_eol; ++ REQUIRE(ends_with(file_contents(TEST_FILENAME), yr_spdlog::fmt_lib::format("Test message 2{}", default_eol))); + REQUIRE(count_lines(TEST_FILENAME) == 1); + +- auto orig_default_logger = spdlog::default_logger(); +- spdlog::set_default_logger(logger); ++ auto orig_default_logger = yr_spdlog::default_logger(); ++ yr_spdlog::set_default_logger(logger); + + SPDLOG_TRACE("Test message 3"); + SPDLOG_DEBUG("Test message {}", 4); + logger->flush(); + + require_message_count(TEST_FILENAME, 2); +- REQUIRE(ends_with(file_contents(TEST_FILENAME), spdlog::fmt_lib::format("Test message 4{}", default_eol))); +- spdlog::set_default_logger(std::move(orig_default_logger)); ++ REQUIRE(ends_with(file_contents(TEST_FILENAME), yr_spdlog::fmt_lib::format("Test message 4{}", default_eol))); ++ yr_spdlog::set_default_logger(std::move(orig_default_logger)); + } + + TEST_CASE("disable param evaluation", "[macros]") +@@ -47,7 +47,7 @@ + + TEST_CASE("pass logger pointer", "[macros]") + { +- auto logger = spdlog::create("refmacro"); ++ auto logger = yr_spdlog::create("refmacro"); + auto &ref = *logger; + SPDLOG_LOGGER_TRACE(&ref, "Test message 1"); + SPDLOG_LOGGER_DEBUG(&ref, "Test message 2"); +diff -ruN tests/test_misc.cpp tests/test_misc.cpp +--- tests/test_misc.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_misc.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -2,18 +2,18 @@ + #include "test_sink.h" + + template +-std::string log_info(const T &what, spdlog::level::level_enum logger_level = spdlog::level::info) ++std::string log_info(const T &what, yr_spdlog::level::level_enum logger_level = yr_spdlog::level::info) + { + + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); ++ auto oss_sink = std::make_shared(oss); + +- spdlog::logger oss_logger("oss", oss_sink); ++ yr_spdlog::logger oss_logger("oss", oss_sink); + oss_logger.set_level(logger_level); + oss_logger.set_pattern("%v"); + oss_logger.info(what); + +- return oss.str().substr(0, oss.str().length() - strlen(spdlog::details::os::default_eol)); ++ return oss.str().substr(0, oss.str().length() - strlen(yr_spdlog::details::os::default_eol)); + } + + TEST_CASE("basic_logging ", "[basic_logging]") +@@ -36,66 +36,66 @@ + + TEST_CASE("log_levels", "[log_levels]") + { +- REQUIRE(log_info("Hello", spdlog::level::err).empty()); +- REQUIRE(log_info("Hello", spdlog::level::critical).empty()); +- REQUIRE(log_info("Hello", spdlog::level::info) == "Hello"); +- REQUIRE(log_info("Hello", spdlog::level::debug) == "Hello"); +- REQUIRE(log_info("Hello", spdlog::level::trace) == "Hello"); ++ REQUIRE(log_info("Hello", yr_spdlog::level::err).empty()); ++ REQUIRE(log_info("Hello", yr_spdlog::level::critical).empty()); ++ REQUIRE(log_info("Hello", yr_spdlog::level::info) == "Hello"); ++ REQUIRE(log_info("Hello", yr_spdlog::level::debug) == "Hello"); ++ REQUIRE(log_info("Hello", yr_spdlog::level::trace) == "Hello"); + } + + TEST_CASE("level_to_string_view", "[convert_to_string_view") + { +- REQUIRE(spdlog::level::to_string_view(spdlog::level::trace) == "trace"); +- REQUIRE(spdlog::level::to_string_view(spdlog::level::debug) == "debug"); +- REQUIRE(spdlog::level::to_string_view(spdlog::level::info) == "info"); +- REQUIRE(spdlog::level::to_string_view(spdlog::level::warn) == "warning"); +- REQUIRE(spdlog::level::to_string_view(spdlog::level::err) == "error"); +- REQUIRE(spdlog::level::to_string_view(spdlog::level::critical) == "critical"); +- REQUIRE(spdlog::level::to_string_view(spdlog::level::off) == "off"); ++ REQUIRE(yr_spdlog::level::to_string_view(yr_spdlog::level::trace) == "trace"); ++ REQUIRE(yr_spdlog::level::to_string_view(yr_spdlog::level::debug) == "debug"); ++ REQUIRE(yr_spdlog::level::to_string_view(yr_spdlog::level::info) == "info"); ++ REQUIRE(yr_spdlog::level::to_string_view(yr_spdlog::level::warn) == "warning"); ++ REQUIRE(yr_spdlog::level::to_string_view(yr_spdlog::level::err) == "error"); ++ REQUIRE(yr_spdlog::level::to_string_view(yr_spdlog::level::critical) == "critical"); ++ REQUIRE(yr_spdlog::level::to_string_view(yr_spdlog::level::off) == "off"); + } + + TEST_CASE("to_short_c_str", "[convert_to_short_c_str]") + { +- REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::trace)) == "T"); +- REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::debug)) == "D"); +- REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::info)) == "I"); +- REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::warn)) == "W"); +- REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::err)) == "E"); +- REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::critical)) == "C"); +- REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::off)) == "O"); ++ REQUIRE(std::string(yr_spdlog::level::to_short_c_str(yr_spdlog::level::trace)) == "T"); ++ REQUIRE(std::string(yr_spdlog::level::to_short_c_str(yr_spdlog::level::debug)) == "D"); ++ REQUIRE(std::string(yr_spdlog::level::to_short_c_str(yr_spdlog::level::info)) == "I"); ++ REQUIRE(std::string(yr_spdlog::level::to_short_c_str(yr_spdlog::level::warn)) == "W"); ++ REQUIRE(std::string(yr_spdlog::level::to_short_c_str(yr_spdlog::level::err)) == "E"); ++ REQUIRE(std::string(yr_spdlog::level::to_short_c_str(yr_spdlog::level::critical)) == "C"); ++ REQUIRE(std::string(yr_spdlog::level::to_short_c_str(yr_spdlog::level::off)) == "O"); + } + + TEST_CASE("to_level_enum", "[convert_to_level_enum]") + { +- REQUIRE(spdlog::level::from_str("trace") == spdlog::level::trace); +- REQUIRE(spdlog::level::from_str("debug") == spdlog::level::debug); +- REQUIRE(spdlog::level::from_str("info") == spdlog::level::info); +- REQUIRE(spdlog::level::from_str("warning") == spdlog::level::warn); +- REQUIRE(spdlog::level::from_str("warn") == spdlog::level::warn); +- REQUIRE(spdlog::level::from_str("error") == spdlog::level::err); +- REQUIRE(spdlog::level::from_str("critical") == spdlog::level::critical); +- REQUIRE(spdlog::level::from_str("off") == spdlog::level::off); +- REQUIRE(spdlog::level::from_str("null") == spdlog::level::off); ++ REQUIRE(yr_spdlog::level::from_str("trace") == yr_spdlog::level::trace); ++ REQUIRE(yr_spdlog::level::from_str("debug") == yr_spdlog::level::debug); ++ REQUIRE(yr_spdlog::level::from_str("info") == yr_spdlog::level::info); ++ REQUIRE(yr_spdlog::level::from_str("warning") == yr_spdlog::level::warn); ++ REQUIRE(yr_spdlog::level::from_str("warn") == yr_spdlog::level::warn); ++ REQUIRE(yr_spdlog::level::from_str("error") == yr_spdlog::level::err); ++ REQUIRE(yr_spdlog::level::from_str("critical") == yr_spdlog::level::critical); ++ REQUIRE(yr_spdlog::level::from_str("off") == yr_spdlog::level::off); ++ REQUIRE(yr_spdlog::level::from_str("null") == yr_spdlog::level::off); + } + + TEST_CASE("periodic flush", "[periodic_flush]") + { +- using spdlog::sinks::test_sink_mt; +- auto logger = spdlog::create("periodic_flush"); ++ using yr_spdlog::sinks::test_sink_mt; ++ auto logger = yr_spdlog::create("periodic_flush"); + auto test_sink = std::static_pointer_cast(logger->sinks()[0]); + +- spdlog::flush_every(std::chrono::seconds(1)); ++ yr_spdlog::flush_every(std::chrono::seconds(1)); + std::this_thread::sleep_for(std::chrono::milliseconds(1250)); + REQUIRE(test_sink->flush_counter() == 1); +- spdlog::flush_every(std::chrono::seconds(0)); +- spdlog::drop_all(); ++ yr_spdlog::flush_every(std::chrono::seconds(0)); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("clone-logger", "[clone]") + { +- using spdlog::sinks::test_sink_mt; ++ using yr_spdlog::sinks::test_sink_mt; + auto test_sink = std::make_shared(); +- auto logger = std::make_shared("orig", test_sink); ++ auto logger = std::make_shared("orig", test_sink); + logger->set_pattern("%v"); + auto cloned = logger->clone("clone"); + +@@ -110,15 +110,15 @@ + REQUIRE(test_sink->lines()[0] == "Some message 1"); + REQUIRE(test_sink->lines()[1] == "Some message 2"); + +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("clone async", "[clone]") + { +- using spdlog::sinks::test_sink_st; +- spdlog::init_thread_pool(4, 1); ++ using yr_spdlog::sinks::test_sink_st; ++ yr_spdlog::init_thread_pool(4, 1); + auto test_sink = std::make_shared(); +- auto logger = std::make_shared("orig", test_sink, spdlog::thread_pool()); ++ auto logger = std::make_shared("orig", test_sink, yr_spdlog::thread_pool()); + logger->set_pattern("%v"); + auto cloned = logger->clone("clone"); + +@@ -130,51 +130,51 @@ + logger->info("Some message 1"); + cloned->info("Some message 2"); + +- spdlog::details::os::sleep_for_millis(100); ++ yr_spdlog::details::os::sleep_for_millis(100); + + REQUIRE(test_sink->lines().size() == 2); + REQUIRE(test_sink->lines()[0] == "Some message 1"); + REQUIRE(test_sink->lines()[1] == "Some message 2"); + +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("default logger API", "[default logger]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); ++ auto oss_sink = std::make_shared(oss); + +- spdlog::set_default_logger(std::make_shared("oss", oss_sink)); +- spdlog::set_pattern("*** %v"); ++ yr_spdlog::set_default_logger(std::make_shared("oss", oss_sink)); ++ yr_spdlog::set_pattern("*** %v"); + +- spdlog::default_logger()->set_level(spdlog::level::trace); +- spdlog::trace("hello trace"); +- REQUIRE(oss.str() == "*** hello trace" + std::string(spdlog::details::os::default_eol)); ++ yr_spdlog::default_logger()->set_level(yr_spdlog::level::trace); ++ yr_spdlog::trace("hello trace"); ++ REQUIRE(oss.str() == "*** hello trace" + std::string(yr_spdlog::details::os::default_eol)); + + oss.str(""); +- spdlog::debug("hello debug"); +- REQUIRE(oss.str() == "*** hello debug" + std::string(spdlog::details::os::default_eol)); ++ yr_spdlog::debug("hello debug"); ++ REQUIRE(oss.str() == "*** hello debug" + std::string(yr_spdlog::details::os::default_eol)); + + oss.str(""); +- spdlog::info("Hello"); +- REQUIRE(oss.str() == "*** Hello" + std::string(spdlog::details::os::default_eol)); ++ yr_spdlog::info("Hello"); ++ REQUIRE(oss.str() == "*** Hello" + std::string(yr_spdlog::details::os::default_eol)); + + oss.str(""); +- spdlog::warn("Hello again {}", 2); +- REQUIRE(oss.str() == "*** Hello again 2" + std::string(spdlog::details::os::default_eol)); ++ yr_spdlog::warn("Hello again {}", 2); ++ REQUIRE(oss.str() == "*** Hello again 2" + std::string(yr_spdlog::details::os::default_eol)); + + oss.str(""); +- spdlog::error(123); +- REQUIRE(oss.str() == "*** 123" + std::string(spdlog::details::os::default_eol)); ++ yr_spdlog::error(123); ++ REQUIRE(oss.str() == "*** 123" + std::string(yr_spdlog::details::os::default_eol)); + + oss.str(""); +- spdlog::critical(std::string("some string")); +- REQUIRE(oss.str() == "*** some string" + std::string(spdlog::details::os::default_eol)); ++ yr_spdlog::critical(std::string("some string")); ++ REQUIRE(oss.str() == "*** some string" + std::string(yr_spdlog::details::os::default_eol)); + + oss.str(""); +- spdlog::set_level(spdlog::level::info); +- spdlog::debug("should not be logged"); ++ yr_spdlog::set_level(yr_spdlog::level::info); ++ yr_spdlog::debug("should not be logged"); + REQUIRE(oss.str().empty()); +- spdlog::drop_all(); +- spdlog::set_pattern("%v"); ++ yr_spdlog::drop_all(); ++ yr_spdlog::set_pattern("%v"); + } +diff -ruN tests/test_mpmc_q.cpp tests/test_mpmc_q.cpp +--- tests/test_mpmc_q.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_mpmc_q.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -11,7 +11,7 @@ + { + size_t q_size = 100; + milliseconds tolerance_wait(20); +- spdlog::details::mpmc_blocking_queue q(q_size); ++ yr_spdlog::details::mpmc_blocking_queue q(q_size); + int popped_item = 0; + + auto start = test_clock::now(); +@@ -30,7 +30,7 @@ + milliseconds wait_ms(250); + milliseconds tolerance_wait(250); + +- spdlog::details::mpmc_blocking_queue q(q_size); ++ yr_spdlog::details::mpmc_blocking_queue q(q_size); + int popped_item = 0; + auto start = test_clock::now(); + auto rv = q.dequeue_for(popped_item, wait_ms); +@@ -45,7 +45,7 @@ + + TEST_CASE("dequeue-full-nowait", "[mpmc_blocking_q]") + { +- spdlog::details::mpmc_blocking_queue q(1); ++ yr_spdlog::details::mpmc_blocking_queue q(1); + q.enqueue(42); + + int item = 0; +@@ -55,7 +55,7 @@ + + TEST_CASE("dequeue-full-wait", "[mpmc_blocking_q]") + { +- spdlog::details::mpmc_blocking_queue q(1); ++ yr_spdlog::details::mpmc_blocking_queue q(1); + q.enqueue(42); + + int item = 0; +@@ -67,7 +67,7 @@ + { + + size_t q_size = 1; +- spdlog::details::mpmc_blocking_queue q(q_size); ++ yr_spdlog::details::mpmc_blocking_queue q(q_size); + milliseconds tolerance_wait(10); + + q.enqueue(1); +@@ -85,7 +85,7 @@ + TEST_CASE("bad_queue", "[mpmc_blocking_q]") + { + size_t q_size = 0; +- spdlog::details::mpmc_blocking_queue q(q_size); ++ yr_spdlog::details::mpmc_blocking_queue q(q_size); + q.enqueue_nowait(1); + REQUIRE(q.overrun_counter() == 1); + int i = 0; +@@ -95,7 +95,7 @@ + TEST_CASE("empty_queue", "[mpmc_blocking_q]") + { + size_t q_size = 10; +- spdlog::details::mpmc_blocking_queue q(q_size); ++ yr_spdlog::details::mpmc_blocking_queue q(q_size); + int i = 0; + REQUIRE(q.dequeue_for(i, milliseconds(10)) == false); + } +@@ -103,7 +103,7 @@ + TEST_CASE("full_queue", "[mpmc_blocking_q]") + { + size_t q_size = 100; +- spdlog::details::mpmc_blocking_queue q(q_size); ++ yr_spdlog::details::mpmc_blocking_queue q(q_size); + for (int i = 0; i < static_cast(q_size); i++) + { + q.enqueue(i + 0); // i+0 to force rvalue and avoid tidy warnings on the same time if we std::move(i) instead +diff -ruN tests/test_pattern_formatter.cpp tests/test_pattern_formatter.cpp +--- tests/test_pattern_formatter.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_pattern_formatter.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -1,19 +1,19 @@ + #include "includes.h" + #include "test_sink.h" + +-using spdlog::memory_buf_t; +-using spdlog::details::to_string_view; ++using yr_spdlog::memory_buf_t; ++using yr_spdlog::details::to_string_view; + + // log to str and return it + template + static std::string log_to_str(const std::string &msg, const Args &...args) + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("pattern_tester", oss_sink); +- oss_logger.set_level(spdlog::level::info); ++ auto oss_sink = std::make_shared(oss); ++ yr_spdlog::logger oss_logger("pattern_tester", oss_sink); ++ oss_logger.set_level(yr_spdlog::level::info); + +- oss_logger.set_formatter(std::unique_ptr(new spdlog::pattern_formatter(args...))); ++ oss_logger.set_formatter(std::unique_ptr(new yr_spdlog::pattern_formatter(args...))); + + oss_logger.info(msg); + return oss.str(); +@@ -23,75 +23,75 @@ + { + std::string msg = "Hello custom eol test"; + std::string eol = ";)"; +- REQUIRE(log_to_str(msg, "%v", spdlog::pattern_time_type::local, ";)") == msg + eol); ++ REQUIRE(log_to_str(msg, "%v", yr_spdlog::pattern_time_type::local, ";)") == msg + eol); + } + + TEST_CASE("empty format", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "", spdlog::pattern_time_type::local, "").empty()); ++ REQUIRE(log_to_str("Some message", "", yr_spdlog::pattern_time_type::local, "").empty()); + } + + TEST_CASE("empty format2", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "", spdlog::pattern_time_type::local, "\n") == "\n"); ++ REQUIRE(log_to_str("Some message", "", yr_spdlog::pattern_time_type::local, "\n") == "\n"); + } + + TEST_CASE("level", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%l] %v", spdlog::pattern_time_type::local, "\n") == "[info] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%l] %v", yr_spdlog::pattern_time_type::local, "\n") == "[info] Some message\n"); + } + + TEST_CASE("short level", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%L] %v", spdlog::pattern_time_type::local, "\n") == "[I] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%L] %v", yr_spdlog::pattern_time_type::local, "\n") == "[I] Some message\n"); + } + + TEST_CASE("name", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%n] %v", spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%n] %v", yr_spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); + } + + TEST_CASE("date MM/DD/YY ", "[pattern_formatter]") + { +- auto now_tm = spdlog::details::os::localtime(); ++ auto now_tm = yr_spdlog::details::os::localtime(); + std::stringstream oss; + oss << std::setfill('0') << std::setw(2) << now_tm.tm_mon + 1 << "/" << std::setw(2) << now_tm.tm_mday << "/" << std::setw(2) + << (now_tm.tm_year + 1900) % 1000 << " Some message\n"; +- REQUIRE(log_to_str("Some message", "%D %v", spdlog::pattern_time_type::local, "\n") == oss.str()); ++ REQUIRE(log_to_str("Some message", "%D %v", yr_spdlog::pattern_time_type::local, "\n") == oss.str()); + } + + TEST_CASE("color range test1", "[pattern_formatter]") + { +- auto formatter = std::make_shared("%^%v%$", spdlog::pattern_time_type::local, "\n"); ++ auto formatter = std::make_shared("%^%v%$", yr_spdlog::pattern_time_type::local, "\n"); + + memory_buf_t buf; +- spdlog::fmt_lib::format_to(std::back_inserter(buf), "Hello"); ++ yr_spdlog::fmt_lib::format_to(std::back_inserter(buf), "Hello"); + memory_buf_t formatted; + std::string logger_name = "test"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, spdlog::string_view_t(buf.data(), buf.size())); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, yr_spdlog::string_view_t(buf.data(), buf.size())); + formatter->format(msg, formatted); + REQUIRE(msg.color_range_start == 0); + REQUIRE(msg.color_range_end == 5); +- REQUIRE(log_to_str("hello", "%^%v%$", spdlog::pattern_time_type::local, "\n") == "hello\n"); ++ REQUIRE(log_to_str("hello", "%^%v%$", yr_spdlog::pattern_time_type::local, "\n") == "hello\n"); + } + + TEST_CASE("color range test2", "[pattern_formatter]") + { +- auto formatter = std::make_shared("%^%$", spdlog::pattern_time_type::local, "\n"); ++ auto formatter = std::make_shared("%^%$", yr_spdlog::pattern_time_type::local, "\n"); + std::string logger_name = "test"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, ""); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, ""); + memory_buf_t formatted; + formatter->format(msg, formatted); + REQUIRE(msg.color_range_start == 0); + REQUIRE(msg.color_range_end == 0); +- REQUIRE(log_to_str("", "%^%$", spdlog::pattern_time_type::local, "\n") == "\n"); ++ REQUIRE(log_to_str("", "%^%$", yr_spdlog::pattern_time_type::local, "\n") == "\n"); + } + + TEST_CASE("color range test3", "[pattern_formatter]") + { +- auto formatter = std::make_shared("%^***%$"); ++ auto formatter = std::make_shared("%^***%$"); + std::string logger_name = "test"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "ignored"); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, "ignored"); + memory_buf_t formatted; + formatter->format(msg, formatted); + REQUIRE(msg.color_range_start == 0); +@@ -100,22 +100,22 @@ + + TEST_CASE("color range test4", "[pattern_formatter]") + { +- auto formatter = std::make_shared("XX%^YYY%$", spdlog::pattern_time_type::local, "\n"); ++ auto formatter = std::make_shared("XX%^YYY%$", yr_spdlog::pattern_time_type::local, "\n"); + std::string logger_name = "test"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "ignored"); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, "ignored"); + + memory_buf_t formatted; + formatter->format(msg, formatted); + REQUIRE(msg.color_range_start == 2); + REQUIRE(msg.color_range_end == 5); +- REQUIRE(log_to_str("ignored", "XX%^YYY%$", spdlog::pattern_time_type::local, "\n") == "XXYYY\n"); ++ REQUIRE(log_to_str("ignored", "XX%^YYY%$", yr_spdlog::pattern_time_type::local, "\n") == "XXYYY\n"); + } + + TEST_CASE("color range test5", "[pattern_formatter]") + { +- auto formatter = std::make_shared("**%^"); ++ auto formatter = std::make_shared("**%^"); + std::string logger_name = "test"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "ignored"); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, "ignored"); + memory_buf_t formatted; + formatter->format(msg, formatted); + REQUIRE(msg.color_range_start == 2); +@@ -124,9 +124,9 @@ + + TEST_CASE("color range test6", "[pattern_formatter]") + { +- auto formatter = std::make_shared("**%$"); ++ auto formatter = std::make_shared("**%$"); + std::string logger_name = "test"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "ignored"); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, "ignored"); + memory_buf_t formatted; + formatter->format(msg, formatted); + REQUIRE(msg.color_range_start == 0); +@@ -139,73 +139,73 @@ + + TEST_CASE("level_left_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%8l] %v", spdlog::pattern_time_type::local, "\n") == "[ info] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%8!l] %v", spdlog::pattern_time_type::local, "\n") == "[ info] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%8l] %v", yr_spdlog::pattern_time_type::local, "\n") == "[ info] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%8!l] %v", yr_spdlog::pattern_time_type::local, "\n") == "[ info] Some message\n"); + } + + TEST_CASE("level_right_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%-8l] %v", spdlog::pattern_time_type::local, "\n") == "[info ] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%-8!l] %v", spdlog::pattern_time_type::local, "\n") == "[info ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-8l] %v", yr_spdlog::pattern_time_type::local, "\n") == "[info ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-8!l] %v", yr_spdlog::pattern_time_type::local, "\n") == "[info ] Some message\n"); + } + + TEST_CASE("level_center_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%=8l] %v", spdlog::pattern_time_type::local, "\n") == "[ info ] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%=8!l] %v", spdlog::pattern_time_type::local, "\n") == "[ info ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=8l] %v", yr_spdlog::pattern_time_type::local, "\n") == "[ info ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=8!l] %v", yr_spdlog::pattern_time_type::local, "\n") == "[ info ] Some message\n"); + } + + TEST_CASE("short level_left_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%3L] %v", spdlog::pattern_time_type::local, "\n") == "[ I] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%3!L] %v", spdlog::pattern_time_type::local, "\n") == "[ I] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%3L] %v", yr_spdlog::pattern_time_type::local, "\n") == "[ I] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%3!L] %v", yr_spdlog::pattern_time_type::local, "\n") == "[ I] Some message\n"); + } + + TEST_CASE("short level_right_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%-3L] %v", spdlog::pattern_time_type::local, "\n") == "[I ] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%-3!L] %v", spdlog::pattern_time_type::local, "\n") == "[I ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-3L] %v", yr_spdlog::pattern_time_type::local, "\n") == "[I ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-3!L] %v", yr_spdlog::pattern_time_type::local, "\n") == "[I ] Some message\n"); + } + + TEST_CASE("short level_center_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%=3L] %v", spdlog::pattern_time_type::local, "\n") == "[ I ] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%=3!L] %v", spdlog::pattern_time_type::local, "\n") == "[ I ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=3L] %v", yr_spdlog::pattern_time_type::local, "\n") == "[ I ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=3!L] %v", yr_spdlog::pattern_time_type::local, "\n") == "[ I ] Some message\n"); + } + + TEST_CASE("left_padded_short", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%3n] %v", spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%3!n] %v", spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%3n] %v", yr_spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%3!n] %v", yr_spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); + } + + TEST_CASE("right_padded_short", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%-3n] %v", spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%-3!n] %v", spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-3n] %v", yr_spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-3!n] %v", yr_spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); + } + + TEST_CASE("center_padded_short", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%=3n] %v", spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%=3!n] %v", spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=3n] %v", yr_spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=3!n] %v", yr_spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); + } + + TEST_CASE("left_padded_huge", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%-300n] %v", spdlog::pattern_time_type::local, "\n") == ++ REQUIRE(log_to_str("Some message", "[%-300n] %v", yr_spdlog::pattern_time_type::local, "\n") == + "[pattern_tester ] Some message\n"); + +- REQUIRE(log_to_str("Some message", "[%-300!n] %v", spdlog::pattern_time_type::local, "\n") == ++ REQUIRE(log_to_str("Some message", "[%-300!n] %v", yr_spdlog::pattern_time_type::local, "\n") == + "[pattern_tester ] Some message\n"); + } + + TEST_CASE("left_padded_max", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%-64n] %v", spdlog::pattern_time_type::local, "\n") == ++ REQUIRE(log_to_str("Some message", "[%-64n] %v", yr_spdlog::pattern_time_type::local, "\n") == + "[pattern_tester ] Some message\n"); + +- REQUIRE(log_to_str("Some message", "[%-64!n] %v", spdlog::pattern_time_type::local, "\n") == ++ REQUIRE(log_to_str("Some message", "[%-64!n] %v", yr_spdlog::pattern_time_type::local, "\n") == + "[pattern_tester ] Some message\n"); + } + +@@ -213,61 +213,61 @@ + + TEST_CASE("paddinng_truncate", "[pattern_formatter]") + { +- REQUIRE(log_to_str("123456", "%6!v", spdlog::pattern_time_type::local, "\n") == "123456\n"); +- REQUIRE(log_to_str("123456", "%5!v", spdlog::pattern_time_type::local, "\n") == "12345\n"); +- REQUIRE(log_to_str("123456", "%7!v", spdlog::pattern_time_type::local, "\n") == " 123456\n"); +- +- REQUIRE(log_to_str("123456", "%-6!v", spdlog::pattern_time_type::local, "\n") == "123456\n"); +- REQUIRE(log_to_str("123456", "%-5!v", spdlog::pattern_time_type::local, "\n") == "12345\n"); +- REQUIRE(log_to_str("123456", "%-7!v", spdlog::pattern_time_type::local, "\n") == "123456 \n"); +- +- REQUIRE(log_to_str("123456", "%=6!v", spdlog::pattern_time_type::local, "\n") == "123456\n"); +- REQUIRE(log_to_str("123456", "%=5!v", spdlog::pattern_time_type::local, "\n") == "12345\n"); +- REQUIRE(log_to_str("123456", "%=7!v", spdlog::pattern_time_type::local, "\n") == "123456 \n"); ++ REQUIRE(log_to_str("123456", "%6!v", yr_spdlog::pattern_time_type::local, "\n") == "123456\n"); ++ REQUIRE(log_to_str("123456", "%5!v", yr_spdlog::pattern_time_type::local, "\n") == "12345\n"); ++ REQUIRE(log_to_str("123456", "%7!v", yr_spdlog::pattern_time_type::local, "\n") == " 123456\n"); ++ ++ REQUIRE(log_to_str("123456", "%-6!v", yr_spdlog::pattern_time_type::local, "\n") == "123456\n"); ++ REQUIRE(log_to_str("123456", "%-5!v", yr_spdlog::pattern_time_type::local, "\n") == "12345\n"); ++ REQUIRE(log_to_str("123456", "%-7!v", yr_spdlog::pattern_time_type::local, "\n") == "123456 \n"); ++ ++ REQUIRE(log_to_str("123456", "%=6!v", yr_spdlog::pattern_time_type::local, "\n") == "123456\n"); ++ REQUIRE(log_to_str("123456", "%=5!v", yr_spdlog::pattern_time_type::local, "\n") == "12345\n"); ++ REQUIRE(log_to_str("123456", "%=7!v", yr_spdlog::pattern_time_type::local, "\n") == "123456 \n"); + +- REQUIRE(log_to_str("123456", "%0!v", spdlog::pattern_time_type::local, "\n") == "\n"); ++ REQUIRE(log_to_str("123456", "%0!v", yr_spdlog::pattern_time_type::local, "\n") == "\n"); + } + + TEST_CASE("padding_truncate_funcname", "[pattern_formatter]") + { +- spdlog::sinks::test_sink_st test_sink; ++ yr_spdlog::sinks::test_sink_st test_sink; + + const char *pattern = "%v [%5!!]"; +- auto formatter = std::unique_ptr(new spdlog::pattern_formatter(pattern)); ++ auto formatter = std::unique_ptr(new yr_spdlog::pattern_formatter(pattern)); + test_sink.set_formatter(std::move(formatter)); + +- spdlog::details::log_msg msg1{spdlog::source_loc{"ignored", 1, "func"}, "test_logger", spdlog::level::info, "message"}; ++ yr_spdlog::details::log_msg msg1{yr_spdlog::source_loc{"ignored", 1, "func"}, "test_logger", yr_spdlog::level::info, "message"}; + test_sink.log(msg1); + REQUIRE(test_sink.lines()[0] == "message [ func]"); + +- spdlog::details::log_msg msg2{spdlog::source_loc{"ignored", 1, "function"}, "test_logger", spdlog::level::info, "message"}; ++ yr_spdlog::details::log_msg msg2{yr_spdlog::source_loc{"ignored", 1, "function"}, "test_logger", yr_spdlog::level::info, "message"}; + test_sink.log(msg2); + REQUIRE(test_sink.lines()[1] == "message [funct]"); + } + + TEST_CASE("padding_funcname", "[pattern_formatter]") + { +- spdlog::sinks::test_sink_st test_sink; ++ yr_spdlog::sinks::test_sink_st test_sink; + + const char *pattern = "%v [%10!]"; +- auto formatter = std::unique_ptr(new spdlog::pattern_formatter(pattern)); ++ auto formatter = std::unique_ptr(new yr_spdlog::pattern_formatter(pattern)); + test_sink.set_formatter(std::move(formatter)); + +- spdlog::details::log_msg msg1{spdlog::source_loc{"ignored", 1, "func"}, "test_logger", spdlog::level::info, "message"}; ++ yr_spdlog::details::log_msg msg1{yr_spdlog::source_loc{"ignored", 1, "func"}, "test_logger", yr_spdlog::level::info, "message"}; + test_sink.log(msg1); + REQUIRE(test_sink.lines()[0] == "message [ func]"); + +- spdlog::details::log_msg msg2{spdlog::source_loc{"ignored", 1, "func567890123"}, "test_logger", spdlog::level::info, "message"}; ++ yr_spdlog::details::log_msg msg2{yr_spdlog::source_loc{"ignored", 1, "func567890123"}, "test_logger", yr_spdlog::level::info, "message"}; + test_sink.log(msg2); + REQUIRE(test_sink.lines()[1] == "message [func567890123]"); + } + + TEST_CASE("clone-default-formatter", "[pattern_formatter]") + { +- auto formatter_1 = std::make_shared(); ++ auto formatter_1 = std::make_shared(); + auto formatter_2 = formatter_1->clone(); + std::string logger_name = "test"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "some message"); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, "some message"); + + memory_buf_t formatted_1; + memory_buf_t formatted_2; +@@ -279,10 +279,10 @@ + + TEST_CASE("clone-default-formatter2", "[pattern_formatter]") + { +- auto formatter_1 = std::make_shared("%+"); ++ auto formatter_1 = std::make_shared("%+"); + auto formatter_2 = formatter_1->clone(); + std::string logger_name = "test"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "some message"); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, "some message"); + + memory_buf_t formatted_1; + memory_buf_t formatted_2; +@@ -294,10 +294,10 @@ + + TEST_CASE("clone-formatter", "[pattern_formatter]") + { +- auto formatter_1 = std::make_shared("%D %X [%] [%n] %v"); ++ auto formatter_1 = std::make_shared("%D %X [%] [%n] %v"); + auto formatter_2 = formatter_1->clone(); + std::string logger_name = "test"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "some message"); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, "some message"); + + memory_buf_t formatted_1; + memory_buf_t formatted_2; +@@ -309,11 +309,11 @@ + + TEST_CASE("clone-formatter-2", "[pattern_formatter]") + { +- using spdlog::pattern_time_type; +- auto formatter_1 = std::make_shared("%D %X [%] [%n] %v", pattern_time_type::utc, "xxxxxx\n"); ++ using yr_spdlog::pattern_time_type; ++ auto formatter_1 = std::make_shared("%D %X [%] [%n] %v", pattern_time_type::utc, "xxxxxx\n"); + auto formatter_2 = formatter_1->clone(); + std::string logger_name = "test2"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "some message"); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, "some message"); + + memory_buf_t formatted_1; + memory_buf_t formatted_2; +@@ -323,29 +323,29 @@ + REQUIRE(to_string_view(formatted_1) == to_string_view(formatted_2)); + } + +-class custom_test_flag : public spdlog::custom_flag_formatter ++class custom_test_flag : public yr_spdlog::custom_flag_formatter + { + public: + explicit custom_test_flag(std::string txt) + : some_txt{std::move(txt)} + {} + +- void format(const spdlog::details::log_msg &, const std::tm &tm, spdlog::memory_buf_t &dest) override ++ void format(const yr_spdlog::details::log_msg &, const std::tm &tm, yr_spdlog::memory_buf_t &dest) override + { + if (some_txt == "throw_me") + { +- throw spdlog::spdlog_ex("custom_flag_exception_test"); ++ throw yr_spdlog::spdlog_ex("custom_flag_exception_test"); + } + else if (some_txt == "time") + { +- auto formatted = spdlog::fmt_lib::format("{:d}:{:02d}{:s}", tm.tm_hour % 12, tm.tm_min, tm.tm_hour / 12 ? "PM" : "AM"); ++ auto formatted = yr_spdlog::fmt_lib::format("{:d}:{:02d}{:s}", tm.tm_hour % 12, tm.tm_min, tm.tm_hour / 12 ? "PM" : "AM"); + dest.append(formatted.data(), formatted.data() + formatted.size()); + return; + } + some_txt = std::string(padinfo_.width_, ' ') + some_txt; + dest.append(some_txt.data(), some_txt.data() + some_txt.size()); + } +- spdlog::details::padding_info get_padding_info() ++ yr_spdlog::details::padding_info get_padding_info() + { + return padinfo_; + } +@@ -354,24 +354,24 @@ + + std::unique_ptr clone() const override + { +- return spdlog::details::make_unique(some_txt); ++ return yr_spdlog::details::make_unique(some_txt); + } + }; + // test clone with custom flag formatters + TEST_CASE("clone-custom_formatter", "[pattern_formatter]") + { +- auto formatter_1 = std::make_shared(); ++ auto formatter_1 = std::make_shared(); + formatter_1->add_flag('t', "custom_output").set_pattern("[%n] [%t] %v"); + auto formatter_2 = formatter_1->clone(); + std::string logger_name = "logger-name"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "some message"); ++ yr_spdlog::details::log_msg msg(logger_name, yr_spdlog::level::info, "some message"); + + memory_buf_t formatted_1; + memory_buf_t formatted_2; + formatter_1->format(msg, formatted_1); + formatter_2->format(msg, formatted_2); + +- auto expected = spdlog::fmt_lib::format("[logger-name] [custom_output] some message{}", spdlog::details::os::default_eol); ++ auto expected = yr_spdlog::fmt_lib::format("[logger-name] [custom_output] some message{}", yr_spdlog::details::os::default_eol); + + REQUIRE(to_string_view(formatted_1) == expected); + REQUIRE(to_string_view(formatted_2) == expected); +@@ -389,11 +389,11 @@ + + TEST_CASE("short filename formatter-1", "[pattern_formatter]") + { +- spdlog::pattern_formatter formatter("%s", spdlog::pattern_time_type::local, ""); ++ yr_spdlog::pattern_formatter formatter("%s", yr_spdlog::pattern_time_type::local, ""); + memory_buf_t formatted; + std::string logger_name = "logger-name"; +- spdlog::source_loc source_loc{test_path, 123, "some_func()"}; +- spdlog::details::log_msg msg(source_loc, "logger-name", spdlog::level::info, "Hello"); ++ yr_spdlog::source_loc source_loc{test_path, 123, "some_func()"}; ++ yr_spdlog::details::log_msg msg(source_loc, "logger-name", yr_spdlog::level::info, "Hello"); + formatter.format(msg, formatted); + + REQUIRE(to_string_view(formatted) == "myfile.cpp"); +@@ -401,11 +401,11 @@ + + TEST_CASE("short filename formatter-2", "[pattern_formatter]") + { +- spdlog::pattern_formatter formatter("%s:%#", spdlog::pattern_time_type::local, ""); ++ yr_spdlog::pattern_formatter formatter("%s:%#", yr_spdlog::pattern_time_type::local, ""); + memory_buf_t formatted; + std::string logger_name = "logger-name"; +- spdlog::source_loc source_loc{"myfile.cpp", 123, "some_func()"}; +- spdlog::details::log_msg msg(source_loc, "logger-name", spdlog::level::info, "Hello"); ++ yr_spdlog::source_loc source_loc{"myfile.cpp", 123, "some_func()"}; ++ yr_spdlog::details::log_msg msg(source_loc, "logger-name", yr_spdlog::level::info, "Hello"); + formatter.format(msg, formatted); + + REQUIRE(to_string_view(formatted) == "myfile.cpp:123"); +@@ -413,11 +413,11 @@ + + TEST_CASE("short filename formatter-3", "[pattern_formatter]") + { +- spdlog::pattern_formatter formatter("%s %v", spdlog::pattern_time_type::local, ""); ++ yr_spdlog::pattern_formatter formatter("%s %v", yr_spdlog::pattern_time_type::local, ""); + memory_buf_t formatted; + std::string logger_name = "logger-name"; +- spdlog::source_loc source_loc{"", 123, "some_func()"}; +- spdlog::details::log_msg msg(source_loc, "logger-name", spdlog::level::info, "Hello"); ++ yr_spdlog::source_loc source_loc{"", 123, "some_func()"}; ++ yr_spdlog::details::log_msg msg(source_loc, "logger-name", yr_spdlog::level::info, "Hello"); + formatter.format(msg, formatted); + + REQUIRE(to_string_view(formatted) == " Hello"); +@@ -425,11 +425,11 @@ + + TEST_CASE("full filename formatter", "[pattern_formatter]") + { +- spdlog::pattern_formatter formatter("%g", spdlog::pattern_time_type::local, ""); ++ yr_spdlog::pattern_formatter formatter("%g", yr_spdlog::pattern_time_type::local, ""); + memory_buf_t formatted; + std::string logger_name = "logger-name"; +- spdlog::source_loc source_loc{test_path, 123, "some_func()"}; +- spdlog::details::log_msg msg(source_loc, "logger-name", spdlog::level::info, "Hello"); ++ yr_spdlog::source_loc source_loc{test_path, 123, "some_func()"}; ++ yr_spdlog::details::log_msg msg(source_loc, "logger-name", yr_spdlog::level::info, "Hello"); + formatter.format(msg, formatted); + + REQUIRE(to_string_view(formatted) == test_path); +@@ -437,50 +437,50 @@ + + TEST_CASE("custom flags", "[pattern_formatter]") + { +- auto formatter = std::make_shared(); ++ auto formatter = std::make_shared(); + formatter->add_flag('t', "custom1").add_flag('u', "custom2").set_pattern("[%n] [%t] [%u] %v"); + + memory_buf_t formatted; + +- spdlog::details::log_msg msg(spdlog::source_loc{}, "logger-name", spdlog::level::info, "some message"); ++ yr_spdlog::details::log_msg msg(yr_spdlog::source_loc{}, "logger-name", yr_spdlog::level::info, "some message"); + formatter->format(msg, formatted); +- auto expected = spdlog::fmt_lib::format("[logger-name] [custom1] [custom2] some message{}", spdlog::details::os::default_eol); ++ auto expected = yr_spdlog::fmt_lib::format("[logger-name] [custom1] [custom2] some message{}", yr_spdlog::details::os::default_eol); + + REQUIRE(to_string_view(formatted) == expected); + } + + TEST_CASE("custom flags-padding", "[pattern_formatter]") + { +- auto formatter = std::make_shared(); ++ auto formatter = std::make_shared(); + formatter->add_flag('t', "custom1").add_flag('u', "custom2").set_pattern("[%n] [%t] [%5u] %v"); + + memory_buf_t formatted; + +- spdlog::details::log_msg msg(spdlog::source_loc{}, "logger-name", spdlog::level::info, "some message"); ++ yr_spdlog::details::log_msg msg(yr_spdlog::source_loc{}, "logger-name", yr_spdlog::level::info, "some message"); + formatter->format(msg, formatted); +- auto expected = spdlog::fmt_lib::format("[logger-name] [custom1] [ custom2] some message{}", spdlog::details::os::default_eol); ++ auto expected = yr_spdlog::fmt_lib::format("[logger-name] [custom1] [ custom2] some message{}", yr_spdlog::details::os::default_eol); + + REQUIRE(to_string_view(formatted) == expected); + } + + TEST_CASE("custom flags-exception", "[pattern_formatter]") + { +- auto formatter = std::make_shared(); ++ auto formatter = std::make_shared(); + formatter->add_flag('t', "throw_me").add_flag('u', "custom2").set_pattern("[%n] [%t] [%u] %v"); + + memory_buf_t formatted; +- spdlog::details::log_msg msg(spdlog::source_loc{}, "logger-name", spdlog::level::info, "some message"); +- CHECK_THROWS_AS(formatter->format(msg, formatted), spdlog::spdlog_ex); ++ yr_spdlog::details::log_msg msg(yr_spdlog::source_loc{}, "logger-name", yr_spdlog::level::info, "some message"); ++ CHECK_THROWS_AS(formatter->format(msg, formatted), yr_spdlog::spdlog_ex); + } + + TEST_CASE("override need_localtime", "[pattern_formatter]") + { +- auto formatter = std::make_shared(spdlog::pattern_time_type::local, "\n"); ++ auto formatter = std::make_shared(yr_spdlog::pattern_time_type::local, "\n"); + formatter->add_flag('t', "time").set_pattern("%t> %v"); + + { + memory_buf_t formatted; +- spdlog::details::log_msg msg(spdlog::source_loc{}, "logger-name", spdlog::level::info, "some message"); ++ yr_spdlog::details::log_msg msg(yr_spdlog::source_loc{}, "logger-name", yr_spdlog::level::info, "some message"); + formatter->format(msg, formatted); + REQUIRE(to_string_view(formatted) == "0:00AM> some message\n"); + } +@@ -488,13 +488,13 @@ + { + formatter->need_localtime(); + +- auto now_tm = spdlog::details::os::localtime(); ++ auto now_tm = yr_spdlog::details::os::localtime(); + std::stringstream oss; + oss << (now_tm.tm_hour % 12) << ":" << std::setfill('0') << std::setw(2) << now_tm.tm_min << (now_tm.tm_hour / 12 ? "PM" : "AM") + << "> some message\n"; + + memory_buf_t formatted; +- spdlog::details::log_msg msg(spdlog::source_loc{}, "logger-name", spdlog::level::info, "some message"); ++ yr_spdlog::details::log_msg msg(yr_spdlog::source_loc{}, "logger-name", yr_spdlog::level::info, "some message"); + formatter->format(msg, formatted); + REQUIRE(to_string_view(formatted) == oss.str()); + } +diff -ruN tests/test_registry.cpp tests/test_registry.cpp +--- tests/test_registry.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_registry.cpp 2025-07-02 15:16:19.636618058 +0800 +@@ -6,39 +6,39 @@ + #ifndef SPDLOG_NO_EXCEPTIONS + TEST_CASE("register_drop", "[registry]") + { +- spdlog::drop_all(); +- spdlog::create(tested_logger_name); +- REQUIRE(spdlog::get(tested_logger_name) != nullptr); ++ yr_spdlog::drop_all(); ++ yr_spdlog::create(tested_logger_name); ++ REQUIRE(yr_spdlog::get(tested_logger_name) != nullptr); + // Throw if registering existing name +- REQUIRE_THROWS_AS(spdlog::create(tested_logger_name), spdlog::spdlog_ex); ++ REQUIRE_THROWS_AS(yr_spdlog::create(tested_logger_name), yr_spdlog::spdlog_ex); + } + + TEST_CASE("explicit register", "[registry]") + { +- spdlog::drop_all(); +- auto logger = std::make_shared(tested_logger_name, std::make_shared()); +- spdlog::register_logger(logger); +- REQUIRE(spdlog::get(tested_logger_name) != nullptr); ++ yr_spdlog::drop_all(); ++ auto logger = std::make_shared(tested_logger_name, std::make_shared()); ++ yr_spdlog::register_logger(logger); ++ REQUIRE(yr_spdlog::get(tested_logger_name) != nullptr); + // Throw if registering existing name +- REQUIRE_THROWS_AS(spdlog::create(tested_logger_name), spdlog::spdlog_ex); ++ REQUIRE_THROWS_AS(yr_spdlog::create(tested_logger_name), yr_spdlog::spdlog_ex); + } + #endif + + TEST_CASE("apply_all", "[registry]") + { +- spdlog::drop_all(); +- auto logger = std::make_shared(tested_logger_name, std::make_shared()); +- spdlog::register_logger(logger); +- auto logger2 = std::make_shared(tested_logger_name2, std::make_shared()); +- spdlog::register_logger(logger2); ++ yr_spdlog::drop_all(); ++ auto logger = std::make_shared(tested_logger_name, std::make_shared()); ++ yr_spdlog::register_logger(logger); ++ auto logger2 = std::make_shared(tested_logger_name2, std::make_shared()); ++ yr_spdlog::register_logger(logger2); + + int counter = 0; +- spdlog::apply_all([&counter](std::shared_ptr) { counter++; }); ++ yr_spdlog::apply_all([&counter](std::shared_ptr) { counter++; }); + REQUIRE(counter == 2); + + counter = 0; +- spdlog::drop(tested_logger_name2); +- spdlog::apply_all([&counter](std::shared_ptr l) { ++ yr_spdlog::drop(tested_logger_name2); ++ yr_spdlog::apply_all([&counter](std::shared_ptr l) { + REQUIRE(l->name() == tested_logger_name); + counter++; + }); +@@ -47,70 +47,70 @@ + + TEST_CASE("drop", "[registry]") + { +- spdlog::drop_all(); +- spdlog::create(tested_logger_name); +- spdlog::drop(tested_logger_name); +- REQUIRE_FALSE(spdlog::get(tested_logger_name)); ++ yr_spdlog::drop_all(); ++ yr_spdlog::create(tested_logger_name); ++ yr_spdlog::drop(tested_logger_name); ++ REQUIRE_FALSE(yr_spdlog::get(tested_logger_name)); + } + + TEST_CASE("drop-default", "[registry]") + { +- spdlog::set_default_logger(spdlog::null_logger_st(tested_logger_name)); +- spdlog::drop(tested_logger_name); +- REQUIRE_FALSE(spdlog::default_logger()); +- REQUIRE_FALSE(spdlog::get(tested_logger_name)); ++ yr_spdlog::set_default_logger(yr_spdlog::null_logger_st(tested_logger_name)); ++ yr_spdlog::drop(tested_logger_name); ++ REQUIRE_FALSE(yr_spdlog::default_logger()); ++ REQUIRE_FALSE(yr_spdlog::get(tested_logger_name)); + } + + TEST_CASE("drop_all", "[registry]") + { +- spdlog::drop_all(); +- spdlog::create(tested_logger_name); +- spdlog::create(tested_logger_name2); +- spdlog::drop_all(); +- REQUIRE_FALSE(spdlog::get(tested_logger_name)); +- REQUIRE_FALSE(spdlog::get(tested_logger_name2)); +- REQUIRE_FALSE(spdlog::default_logger()); ++ yr_spdlog::drop_all(); ++ yr_spdlog::create(tested_logger_name); ++ yr_spdlog::create(tested_logger_name2); ++ yr_spdlog::drop_all(); ++ REQUIRE_FALSE(yr_spdlog::get(tested_logger_name)); ++ REQUIRE_FALSE(yr_spdlog::get(tested_logger_name2)); ++ REQUIRE_FALSE(yr_spdlog::default_logger()); + } + + TEST_CASE("drop non existing", "[registry]") + { +- spdlog::drop_all(); +- spdlog::create(tested_logger_name); +- spdlog::drop("some_name"); +- REQUIRE_FALSE(spdlog::get("some_name")); +- REQUIRE(spdlog::get(tested_logger_name)); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); ++ yr_spdlog::create(tested_logger_name); ++ yr_spdlog::drop("some_name"); ++ REQUIRE_FALSE(yr_spdlog::get("some_name")); ++ REQUIRE(yr_spdlog::get(tested_logger_name)); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("default logger", "[registry]") + { +- spdlog::drop_all(); +- spdlog::set_default_logger(spdlog::null_logger_st(tested_logger_name)); +- REQUIRE(spdlog::get(tested_logger_name) == spdlog::default_logger()); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); ++ yr_spdlog::set_default_logger(yr_spdlog::null_logger_st(tested_logger_name)); ++ REQUIRE(yr_spdlog::get(tested_logger_name) == yr_spdlog::default_logger()); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("set_default_logger(nullptr)", "[registry]") + { +- spdlog::set_default_logger(nullptr); +- REQUIRE_FALSE(spdlog::default_logger()); ++ yr_spdlog::set_default_logger(nullptr); ++ REQUIRE_FALSE(yr_spdlog::default_logger()); + } + + TEST_CASE("disable automatic registration", "[registry]") + { + // set some global parameters +- spdlog::level::level_enum log_level = spdlog::level::level_enum::warn; +- spdlog::set_level(log_level); ++ yr_spdlog::level::level_enum log_level = yr_spdlog::level::level_enum::warn; ++ yr_spdlog::set_level(log_level); + // but disable automatic registration +- spdlog::set_automatic_registration(false); +- auto logger1 = spdlog::create(tested_logger_name, SPDLOG_FILENAME_T("filename"), 11, 59); +- auto logger2 = spdlog::create_async(tested_logger_name2); ++ yr_spdlog::set_automatic_registration(false); ++ auto logger1 = yr_spdlog::create(tested_logger_name, SPDLOG_FILENAME_T("filename"), 11, 59); ++ auto logger2 = yr_spdlog::create_async(tested_logger_name2); + // loggers should not be part of the registry +- REQUIRE_FALSE(spdlog::get(tested_logger_name)); +- REQUIRE_FALSE(spdlog::get(tested_logger_name2)); ++ REQUIRE_FALSE(yr_spdlog::get(tested_logger_name)); ++ REQUIRE_FALSE(yr_spdlog::get(tested_logger_name2)); + // but make sure they are still initialized according to global defaults + REQUIRE(logger1->level() == log_level); + REQUIRE(logger2->level() == log_level); +- spdlog::set_level(spdlog::level::info); +- spdlog::set_automatic_registration(true); ++ yr_spdlog::set_level(yr_spdlog::level::info); ++ yr_spdlog::set_automatic_registration(true); + } +diff -ruN tests/test_sink.h tests/test_sink.h +--- tests/test_sink.h 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_sink.h 2025-07-02 15:16:19.641618058 +0800 +@@ -12,7 +12,7 @@ + #include + #include + +-namespace spdlog { ++namespace yr_spdlog { + namespace sinks { + + template +@@ -76,4 +76,4 @@ + using test_sink_st = test_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace yr_spdlog +diff -ruN tests/test_stdout_api.cpp tests/test_stdout_api.cpp +--- tests/test_stdout_api.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_stdout_api.cpp 2025-07-02 15:16:19.641618058 +0800 +@@ -6,93 +6,93 @@ + #include "spdlog/sinks/stdout_color_sinks.h" + TEST_CASE("stdout_st", "[stdout]") + { +- auto l = spdlog::stdout_logger_st("test"); ++ auto l = yr_spdlog::stdout_logger_st("test"); + l->set_pattern("%+"); +- l->set_level(spdlog::level::trace); ++ l->set_level(yr_spdlog::level::trace); + l->trace("Test stdout_st"); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("stdout_mt", "[stdout]") + { +- auto l = spdlog::stdout_logger_mt("test"); ++ auto l = yr_spdlog::stdout_logger_mt("test"); + l->set_pattern("%+"); +- l->set_level(spdlog::level::debug); ++ l->set_level(yr_spdlog::level::debug); + l->debug("Test stdout_mt"); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("stderr_st", "[stderr]") + { +- auto l = spdlog::stderr_logger_st("test"); ++ auto l = yr_spdlog::stderr_logger_st("test"); + l->set_pattern("%+"); + l->info("Test stderr_st"); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("stderr_mt", "[stderr]") + { +- auto l = spdlog::stderr_logger_mt("test"); ++ auto l = yr_spdlog::stderr_logger_mt("test"); + l->set_pattern("%+"); + l->info("Test stderr_mt"); + l->warn("Test stderr_mt"); + l->error("Test stderr_mt"); + l->critical("Test stderr_mt"); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + // color loggers + TEST_CASE("stdout_color_st", "[stdout]") + { +- auto l = spdlog::stdout_color_st("test"); ++ auto l = yr_spdlog::stdout_color_st("test"); + l->set_pattern("%+"); + l->info("Test stdout_color_st"); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("stdout_color_mt", "[stdout]") + { +- auto l = spdlog::stdout_color_mt("test"); ++ auto l = yr_spdlog::stdout_color_mt("test"); + l->set_pattern("%+"); +- l->set_level(spdlog::level::trace); ++ l->set_level(yr_spdlog::level::trace); + l->trace("Test stdout_color_mt"); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("stderr_color_st", "[stderr]") + { +- auto l = spdlog::stderr_color_st("test"); ++ auto l = yr_spdlog::stderr_color_st("test"); + l->set_pattern("%+"); +- l->set_level(spdlog::level::debug); ++ l->set_level(yr_spdlog::level::debug); + l->debug("Test stderr_color_st"); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + TEST_CASE("stderr_color_mt", "[stderr]") + { +- auto l = spdlog::stderr_color_mt("test"); ++ auto l = yr_spdlog::stderr_color_mt("test"); + l->set_pattern("%+"); + l->info("Test stderr_color_mt"); + l->warn("Test stderr_color_mt"); + l->error("Test stderr_color_mt"); + l->critical("Test stderr_color_mt"); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + #ifdef SPDLOG_WCHAR_TO_UTF8_SUPPORT + + TEST_CASE("wchar_api", "[stdout]") + { +- auto l = spdlog::stdout_logger_st("wchar_logger"); ++ auto l = yr_spdlog::stdout_logger_st("wchar_logger"); + l->set_pattern("%+"); +- l->set_level(spdlog::level::trace); ++ l->set_level(yr_spdlog::level::trace); + l->trace(L"Test wchar_api"); + l->trace(L"Test wchar_api {}", L"param"); + l->trace(L"Test wchar_api {}", 1); + l->trace(L"Test wchar_api {}", std::wstring{L"wstring param"}); + l->trace(std::wstring{L"Test wchar_api wstring"}); + SPDLOG_LOGGER_DEBUG(l, L"Test SPDLOG_LOGGER_DEBUG {}", L"param"); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } + + #endif +diff -ruN tests/test_stopwatch.cpp tests/test_stopwatch.cpp +--- tests/test_stopwatch.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_stopwatch.cpp 2025-07-02 15:16:19.641618058 +0800 +@@ -9,7 +9,7 @@ + milliseconds wait_ms(200); + milliseconds tolerance_ms(250); + auto start = clock::now(); +- spdlog::stopwatch sw; ++ yr_spdlog::stopwatch sw; + std::this_thread::sleep_for(wait_ms); + auto stop = clock::now(); + auto diff_ms = std::chrono::duration_cast(stop - start); +@@ -19,7 +19,7 @@ + + TEST_CASE("stopwatch2", "[stopwatch]") + { +- using spdlog::sinks::test_sink_st; ++ using yr_spdlog::sinks::test_sink_st; + using std::chrono::duration_cast; + using std::chrono::milliseconds; + using clock = std::chrono::steady_clock; +@@ -30,8 +30,8 @@ + auto test_sink = std::make_shared(); + + auto start = clock::now(); +- spdlog::stopwatch sw; +- spdlog::logger logger("test-stopwatch", test_sink); ++ yr_spdlog::stopwatch sw; ++ yr_spdlog::logger logger("test-stopwatch", test_sink); + logger.set_pattern("%v"); + std::this_thread::sleep_for(wait_duration); + auto stop = clock::now(); +diff -ruN tests/test_systemd.cpp tests/test_systemd.cpp +--- tests/test_systemd.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_systemd.cpp 2025-07-02 15:16:19.641618058 +0800 +@@ -3,9 +3,9 @@ + + TEST_CASE("systemd", "[all]") + { +- auto systemd_sink = std::make_shared(); +- spdlog::logger logger("spdlog_systemd_test", systemd_sink); +- logger.set_level(spdlog::level::trace); ++ auto systemd_sink = std::make_shared(); ++ yr_spdlog::logger logger("spdlog_systemd_test", systemd_sink); ++ logger.set_level(yr_spdlog::level::trace); + logger.trace("test spdlog trace"); + logger.debug("test spdlog debug"); + SPDLOG_LOGGER_INFO((&logger), "test spdlog info"); +diff -ruN tests/test_time_point.cpp tests/test_time_point.cpp +--- tests/test_time_point.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/test_time_point.cpp 2025-07-02 15:16:19.641618058 +0800 +@@ -4,10 +4,10 @@ + + TEST_CASE("time_point1", "[time_point log_msg]") + { +- std::shared_ptr test_sink(new spdlog::sinks::test_sink_st); +- spdlog::logger logger("test-time_point", test_sink); ++ std::shared_ptr test_sink(new yr_spdlog::sinks::test_sink_st); ++ yr_spdlog::logger logger("test-time_point", test_sink); + +- spdlog::source_loc source{}; ++ yr_spdlog::source_loc source{}; + std::chrono::system_clock::time_point tp{std::chrono::system_clock::now()}; + test_sink->set_pattern("%T.%F"); // interested in the time_point + +@@ -15,15 +15,15 @@ + test_sink->set_delay(std::chrono::milliseconds(10)); + for (int i = 0; i < 5; i++) + { +- spdlog::details::log_msg msg{tp, source, "test_logger", spdlog::level::info, "message"}; ++ yr_spdlog::details::log_msg msg{tp, source, "test_logger", yr_spdlog::level::info, "message"}; + test_sink->log(msg); + } + +- logger.log(tp, source, spdlog::level::info, "formatted message"); +- logger.log(tp, source, spdlog::level::info, "formatted message"); +- logger.log(tp, source, spdlog::level::info, "formatted message"); +- logger.log(tp, source, spdlog::level::info, "formatted message"); +- logger.log(source, spdlog::level::info, "formatted message"); // last line has different time_point ++ logger.log(tp, source, yr_spdlog::level::info, "formatted message"); ++ logger.log(tp, source, yr_spdlog::level::info, "formatted message"); ++ logger.log(tp, source, yr_spdlog::level::info, "formatted message"); ++ logger.log(tp, source, yr_spdlog::level::info, "formatted message"); ++ logger.log(source, yr_spdlog::level::info, "formatted message"); // last line has different time_point + + // now the real test... that the times are the same. + std::vector lines = test_sink->lines(); +@@ -32,5 +32,5 @@ + REQUIRE(lines[4] == lines[5]); + REQUIRE(lines[6] == lines[7]); + REQUIRE(lines[8] != lines[9]); +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + } +diff -ruN tests/utils.cpp tests/utils.cpp +--- tests/utils.cpp 2025-07-02 15:16:11.169618390 +0800 ++++ tests/utils.cpp 2025-07-02 15:16:19.641618058 +0800 +@@ -9,7 +9,7 @@ + + void prepare_logdir() + { +- spdlog::drop_all(); ++ yr_spdlog::drop_all(); + #ifdef _WIN32 + system("rmdir /S /Q test_logs"); + #else +@@ -48,7 +48,7 @@ + + void require_message_count(const std::string &filename, const std::size_t messages) + { +- if (strlen(spdlog::details::os::default_eol) == 0) ++ if (strlen(yr_spdlog::details::os::default_eol) == 0) + { + REQUIRE(count_lines(filename) == 1); + } diff --git a/scripts/package_yuanrong.sh b/scripts/package_yuanrong.sh index a743676..f416ce6 100644 --- a/scripts/package_yuanrong.sh +++ b/scripts/package_yuanrong.sh @@ -22,12 +22,13 @@ BASE_DIR=$( . ${BASE_DIR}/package/utils.sh OUTPUT_DIR="${BASE_DIR}/../output" function parse_args () { - getopt_cmd=$(getopt -o t:h -l tag:,help -- "$@") + getopt_cmd=$(getopt -o t:h -l tag:,python_bin_path:,help -- "$@") [ $? -ne 0 ] && exit 1 eval set -- "$getopt_cmd" while true; do case "$1" in -h|--help) SHOW_HELP="true" && shift ;; + --python_bin_path) PYTHON_BIN_PATH=$2 && shift 2 ;; -t|--tag) TAG=$2 && shift 2 ;; --) shift && break ;; *) die "Invalid option: $1" && exit 1 ;; diff --git a/src/dto/acquire_options.h b/src/dto/acquire_options.h index 92fc5c2..30f0bd0 100644 --- a/src/dto/acquire_options.h +++ b/src/dto/acquire_options.h @@ -46,5 +46,17 @@ struct InstanceResponse { std::string errorMessage; //`json:"errorMessage"` float schedulerTime; // `json:"schedulerTime"` }; + +struct InstanceAllocationFailedRsp { + int errorCode; //`json:"errorCode"` + std::string errorMessage; //`json:"errorMessage"` +}; + +struct BatchInstanceResponse { + std::unordered_map instanceAllocSucceed; + std::unordered_map instanceAllocFailed; + int tLeaseInterval; + float schedulerTime; // `json:"schedulerTime"` +}; } // namespace Libruntime } // namespace YR \ No newline at end of file diff --git a/src/dto/affinity.h b/src/dto/affinity.h index 6920912..54d2a1a 100644 --- a/src/dto/affinity.h +++ b/src/dto/affinity.h @@ -31,6 +31,8 @@ using LabelExpression = ::common::LabelExpression; using SubCondition = ::common::SubCondition; using Condition = ::common::Condition; using Selector = ::common::Selector; +using PBInstanceAffinity = ::common::InstanceAffinity; +using PBAffinityScope = ::common::AffinityScope; class LabelOperator { public: @@ -62,7 +64,7 @@ public: auto values = GetValues(); for (auto &value : values) { std::size_t h3 = std::hash()(value); - res = res ^ h3; + std::size_t res = res ^ h3; } return res; } @@ -203,6 +205,25 @@ public: return this->preferredAntiOtherLabels; } + std::string GetAffinityScope() const + { + return this->affinityScope; + } + + void SetAffinityScope(const std::string &affinityScope) + { + this->affinityScope = affinityScope; + } + + void UpdateAffinityScope(PBInstanceAffinity *pbInstanceAffinity) + { + if (affinityScope == AFFINITYSCOPE_NODE) { + pbInstanceAffinity->set_scope(common::AffinityScope::NODE); + } else if (affinityScope == AFFINITYSCOPE_POD) { + pbInstanceAffinity->set_scope(common::AffinityScope::POD); + } + } + virtual void UpdatePbAffinity(PBAffinity *pbAffinity) { auto *condition = pbAffinity->mutable_resource()->mutable_preferredaffinity()->mutable_condition(); @@ -258,6 +279,7 @@ protected: bool requiredPriority = false; bool preferredAntiOtherLabels = false; std::list> labelOperators; + std::string affinityScope; }; class ResourcePreferredAffinity : public Affinity { @@ -284,6 +306,8 @@ public: { auto *condition = pbAffinity->mutable_instance()->mutable_preferredaffinity()->mutable_condition(); UpdateCondition(condition); + auto *pbInstanceAffinity = pbAffinity->mutable_instance(); + UpdateAffinityScope(pbInstanceAffinity); } }; @@ -311,6 +335,8 @@ public: { auto *condition = pbAffinity->mutable_instance()->mutable_preferredantiaffinity()->mutable_condition(); UpdateCondition(condition); + auto *pbInstanceAffinity = pbAffinity->mutable_instance(); + UpdateAffinityScope(pbInstanceAffinity); } }; @@ -333,6 +359,8 @@ public: { auto *condition = pbAffinity->mutable_instance()->mutable_requiredaffinity()->mutable_condition(); UpdateCondition(condition); + auto *pbInstanceAffinity = pbAffinity->mutable_instance(); + UpdateAffinityScope(pbInstanceAffinity); } }; @@ -355,6 +383,8 @@ public: { auto *condition = pbAffinity->mutable_instance()->mutable_requiredantiaffinity()->mutable_condition(); UpdateCondition(condition); + auto *pbInstanceAffinity = pbAffinity->mutable_instance(); + UpdateAffinityScope(pbInstanceAffinity); } }; } // namespace Libruntime diff --git a/src/dto/config.h b/src/dto/config.h index ba85afe..c5648b4 100644 --- a/src/dto/config.h +++ b/src/dto/config.h @@ -23,7 +23,7 @@ namespace YR { namespace Libruntime { -const size_t REQUEST_ACK_TIMEOUT_SEC = 10; +const size_t REQUEST_ACK_TIMEOUT_SEC = 2; const char *const TRUE_STR = "true"; const char *const FALSE_STR = "false"; const char *const TRUE_NUM = "1"; @@ -99,7 +99,7 @@ public: \ CONFIG_DECLARE_VALID(size_t, REQUEST_ACK_ACC_MAX_SEC, 1800, [](const size_t &val) -> bool { return val >= REQUEST_ACK_TIMEOUT_SEC; }); - CONFIG_DECLARE_VALID(size_t, DS_CONNECT_TIMEOUT_SEC, 1800, + CONFIG_DECLARE_VALID(size_t, DS_CONNECT_TIMEOUT_SEC, 60, [](const size_t &val) -> bool { return val >= REQUEST_ACK_TIMEOUT_SEC; }); CONFIG_DECLARE(bool, AUTH_ENABLE, false); CONFIG_DECLARE(std::string, GRPC_SERVER_ADDRESS, "0.0.0.0:0"); @@ -109,7 +109,8 @@ public: \ CONFIG_DECLARE(std::string, FUNCTION_NAME, ""); CONFIG_DECLARE(std::string, FUNCTION_LIB_PATH, "/dcache/layer/func"); CONFIG_DECLARE(std::string, GLOG_log_dir, "/home/snuser/log"); - CONFIG_DECLARE(std::string, SNLIB_PATH, "/home/snuser/snlib"); + CONFIG_DECLARE(std::string, YR_LOG_PREFIX, ""); + CONFIG_DECLARE(std::string, SNUSER_LIB_PATH, "/home/snuser/snlib"); CONFIG_DECLARE(std::string, YR_LOG_LEVEL, "INFO"); CONFIG_DECLARE(std::string, YRFUNCID, ""); CONFIG_DECLARE(std::string, YR_PYTHON_FUNCID, ""); @@ -118,8 +119,8 @@ public: \ CONFIG_DECLARE(std::string, YR_SERVER_ADDRESS, ""); CONFIG_DECLARE(std::string, POSIX_LISTEN_ADDR, ""); CONFIG_DECLARE(std::string, YR_LOG_PATH, "./"); - CONFIG_DECLARE(uint32_t, YR_MAX_LOG_SIZE_MB, 40); - CONFIG_DECLARE(uint32_t, YR_MAX_LOG_FILE_NUM, 20); + CONFIG_DECLARE(uint32_t, YR_MAX_LOG_SIZE_MB, 500); + CONFIG_DECLARE(uint32_t, YR_MAX_LOG_FILE_NUM, 10); CONFIG_DECLARE(uint32_t, YR_HTTP_CONNECTION_NUM, 10); CONFIG_DECLARE(bool, YR_LOG_COMPRESS, true); CONFIG_DECLARE(std::string, HOST_IP, ""); @@ -129,6 +130,8 @@ public: \ CONFIG_DECLARE(bool, ENABLE_METRICS, false); CONFIG_DECLARE(std::string, METRICS_CONFIG, ""); CONFIG_DECLARE(std::string, METRICS_CONFIG_FILE, ""); + CONFIG_DECLARE(bool, ENABLE_TRACE, false); + CONFIG_DECLARE(std::string, RUNTIME_TRACE_CONFIG, ""); CONFIG_DECLARE(bool, ENABLE_DS_AUTH, false); CONFIG_DECLARE(bool, ENABLE_SERVER_AUTH, false); CONFIG_DECLARE(bool, ENABLE_SERVER_MODE, true); @@ -151,8 +154,16 @@ public: \ CONFIG_DECLARE(size_t, MEM_STORE_SIZE_THRESHOLD, 100 * 1024); CONFIG_DECLARE(size_t, FASS_SCHEDULE_TIMEOUT, 120); // 120 seconds CONFIG_DECLARE(int, YR_ASYNCIO_MAX_CONCURRENCY, 1000); // 1k + CONFIG_DECLARE(bool, ENABLE_CLEAN_STREAM_PRODUCER, true); + CONFIG_DECLARE(bool, ENABLE_PRIORITY, false); + CONFIG_DECLARE(bool, ENABLE_DS_HEALTH_CHECK, false); + CONFIG_DECLARE(int, MAX_HTTP_RETRY_TIME, 1); + CONFIG_DECLARE(int, MAX_HTTP_TIMEOUT_SEC, -1); + CONFIG_DECLARE(int, INITIAL_HTTP_CONNECT_SEC, -1); + CONFIG_DECLARE(int, YR_HTTP_IDLE_TIME, 30); + CONFIG_DECLARE(int, YR_NOTIFY_THREAD_POOL_SIZE, 5); CONFIG_DECLARE_VALID(std::string, RUN_MODE, "integrated", // integrated or standalone - [](const std::string &val) -> bool { return (val == INTEGRATED || val == STANDALONE); }); + [](const std::string &val) -> bool { return (val == INTEGRATED || val == STANDALONE); }); CONFIG_DECLARE(bool, ENABLE_FUNCTION_SCHEDULER, false); // whether start an in-memory scheduler CONFIG_DECLARE(int, FUNCTION_SCHEDULER_GRPC_PORT, 23770); // allow scheduler to interact with runtime CONFIG_DECLARE(int, FUNCTION_SCHEDULER_HTTP_PORT, 23771); // allow http access to function scheduler diff --git a/src/dto/constant.h b/src/dto/constant.h index 7bf0c76..b7a881c 100644 --- a/src/dto/constant.h +++ b/src/dto/constant.h @@ -42,6 +42,8 @@ const std::string LABEL_IN = "LabelIn"; const std::string LABEL_NOT_IN = "LabelNotIn"; const std::string LABEL_EXISTS = "LabelExists"; const std::string LABEL_DOES_NOT_EXIST = "LabelDoesNotExist"; +const std::string AFFINITYSCOPE_POD = "POD"; +const std::string AFFINITYSCOPE_NODE = "NODE"; const std::string DEFAULT_YR_NAMESPACE = "yr_defalut_namespace"; } // namespace Libruntime } // namespace YR diff --git a/src/dto/data_object.h b/src/dto/data_object.h index e4f7e66..9a11ef2 100644 --- a/src/dto/data_object.h +++ b/src/dto/data_object.h @@ -16,15 +16,18 @@ #pragma once +#include #include #include #include "buffer.h" +#include "src/libruntime/utils/serializer.h" #include "src/utility/logger/logger.h" namespace YR { namespace Libruntime { const uint64_t MetaDataLen = 16; +const uint64_t MetaDataTypeLen = 8; struct DataObject { DataObject() : putDone(false) {} DataObject(const std::string &objId) : id(objId), putDone(false) {} @@ -82,6 +85,16 @@ struct DataObject { nestedObjIds = ids; } + void SetMetaDataType(const uint8_t &type) + { + msgpack::sbuffer metaDataType = Serializer::Serialize(type); + if (metaDataType.size() <= MetaDataTypeLen) { + meta->MemoryCopy(metaDataType.data(), metaDataType.size()); + } else { + YRLOG_ERROR("unexpect metaDataType size {}", metaDataType.size()); + } + } + uint64_t totalSize = 0; std::string id; std::shared_ptr buffer; diff --git a/src/dto/debug_config.h b/src/dto/debug_config.h new file mode 100644 index 0000000..a3d5389 --- /dev/null +++ b/src/dto/debug_config.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace YR { +namespace Libruntime { +struct DebugConfig { + DebugConfig() = default; + ~DebugConfig() = default; + bool enable = false; +}; +} // namespace Libruntime +} // namespace YR + +namespace std { +template <> +class hash { +public: + size_t operator()(const YR::Libruntime::DebugConfig &d) const + { + return std::hash()(d.enable); + } +}; +} diff --git a/src/dto/invoke_options.h b/src/dto/invoke_options.h index 92bb4ee..9f8314a 100644 --- a/src/dto/invoke_options.h +++ b/src/dto/invoke_options.h @@ -18,8 +18,8 @@ #include #include #include - #include "src/dto/affinity.h" +#include "src/dto/debug_config.h" #include "src/dto/device.h" #include "src/libruntime/err_type.h" #include "src/libruntime/stacktrace/stack_trace_info.h" @@ -28,6 +28,8 @@ namespace YR { namespace Libruntime { +const size_t FAAS_DEFALUT_INVOKE_TIMEOUT = 900; // second +const size_t FAAS_DEFALUT_ACQUIRE_TIMEOUT = 120; // second enum class BundleAffinity : int { COMPACT, @@ -65,6 +67,30 @@ struct InstanceSession { int concurrency; }; +struct FaasInvokeData { + FaasInvokeData() = default; + FaasInvokeData(const std::string &teId, const std::string &funcName, const std::string &inputAliAs, + const std::string &inputTraceId, const long long inputSubmitTime) + : tenantId(teId), + functionName(funcName), + aliAs(inputAliAs), + traceId(inputTraceId), + submitTime(inputSubmitTime){}; + std::string businessId; + std::string tenantId; + std::string srcAppId; + std::string functionName; + std::string aliAs; + std::string version; + std::string traceId; + std::string code; + std::string innerCode; + std::string describeMsg; + long long submitTime = 0; + long long sendTime = 0; + long long endTime = 0; +}; + struct InvokeOptions { int cpu = 500; @@ -86,6 +112,8 @@ struct InvokeOptions { size_t retryTimes = 0; + int maxRetryTime = -1; + std::function retryChecker = nullptr; size_t priority = 0; @@ -112,6 +140,8 @@ struct InvokeOptions { std::string groupName; + bool isDataAffinity = false; + bool needOrder = false; int64_t scheduleTimeoutMs = 30000; @@ -142,6 +172,8 @@ struct InvokeOptions { std::shared_ptr instanceSession; + DebugConfig debug; + std::string workingDir; }; @@ -156,12 +188,16 @@ struct FunctionMeta { std::string poolLabel; libruntime::ApiType apiType; std::string functionId; - std::optional name; - std::optional ns; + std::string name; + std::string ns; std::string initializerCodeId; bool isAsync = false; bool isGenerator = false; bool needOrder = false; + bool IsServiceApiType() + { + return (apiType == libruntime::ApiType::Faas or apiType == libruntime::ApiType::Serve); + } }; struct GroupOpts { @@ -201,6 +237,7 @@ struct GaugeData { enum class AlarmSeverity { OFF, INFO, MINOR, MAJOR, CRITICAL }; struct AlarmInfo { + std::string id; std::string alarmName; AlarmSeverity alarmSeverity = AlarmSeverity::OFF; std::string locationInfo; @@ -210,5 +247,11 @@ struct AlarmInfo { long timeout = DEFAULT_ALARM_TIMEOUT; std::unordered_map customOptions; }; + +struct Credential { + std::string ak; + std::string sk; + std::string dk; +}; } // namespace Libruntime } // namespace YR diff --git a/src/dto/resource_unit.h b/src/dto/resource_unit.h index 9d9f079..a9d9fb0 100644 --- a/src/dto/resource_unit.h +++ b/src/dto/resource_unit.h @@ -24,6 +24,7 @@ struct ResourceUnit { std::string id; std::unordered_map capacity; std::unordered_map allocatable; + std::unordered_map> nodeLabels; uint32_t status; }; diff --git a/src/dto/stream_conf.h b/src/dto/stream_conf.h new file mode 100644 index 0000000..4f93606 --- /dev/null +++ b/src/dto/stream_conf.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "src/proto/libruntime.pb.h" + +namespace YR { +namespace Libruntime { +struct Element { + Element(uint8_t *ptr = nullptr, uint64_t size = 0, uint64_t id = ULONG_MAX) : ptr(ptr), size(size), id(id) {} + + ~Element() = default; + + uint8_t *ptr; + + uint64_t size; + + uint64_t id; +}; + +struct ProducerConf { + int64_t delayFlushTime = 5; + + int64_t pageSize = 1024 * 1024ul; + + uint64_t maxStreamSize = 100 * 1024 * 1024ul; + + bool autoCleanup = false; + + bool encryptStream = false; + + uint64_t retainForNumConsumers = 0; + + uint64_t reserveSize = 0; + + std::unordered_map extendConfig; + + std::string traceId; +}; + +struct SubscriptionConfig { + std::string subscriptionName; + + libruntime::SubscriptionType subscriptionType = libruntime::SubscriptionType::STREAM; + + std::string traceId; + + std::unordered_map extendConfig; + + SubscriptionConfig(std::string subName, const libruntime::SubscriptionType subType) + : subscriptionName(std::move(subName)), subscriptionType(subType) + { + } + + SubscriptionConfig() = default; +}; + +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/clientsmanager/clients_manager.cpp b/src/libruntime/clientsmanager/clients_manager.cpp index abd39a2..497ed45 100644 --- a/src/libruntime/clientsmanager/clients_manager.cpp +++ b/src/libruntime/clientsmanager/clients_manager.cpp @@ -15,29 +15,33 @@ */ #include "clients_manager.h" +#include "src/libruntime/fsclient/grpc/posix_auth_interceptor.h" namespace YR { namespace Libruntime { -std::pair, ErrorInfo> ClientsManager::GetFsConn(const std::string &ip, int port) +std::pair, ErrorInfo> ClientsManager::GetFsConn(const std::string &ip, int port, + const std::string &dstInstance) { auto addr = GetIpAddr(ip, port); - YRLOG_DEBUG("grpc client target is {}", addr); + auto connKey = dstInstance + ":" + addr; + YRLOG_DEBUG("grpc client target is {}", connKey); if (!RE2::FullMatch(addr, re2::RE2(IP_PORT_REGEX))) { YRLOG_ERROR("failed to get valid runtime-rpc server address({})", addr); return std::make_pair(nullptr, ErrorInfo(ErrorCode::ERR_CONNECTION_FAILED, "The server address is invalid.")); } std::lock_guard fsConnsLock(fsConnsMtx); - auto iter = fsConns.find(addr); + auto iter = fsConns.find(connKey); if (iter != fsConns.end()) { - fsConnsReferCounter[addr]++; + fsConnsReferCounter[connKey]++; return std::make_pair(iter->second, ErrorInfo()); } return std::make_pair(nullptr, ErrorInfo()); } std::pair, ErrorInfo> ClientsManager::NewFsConn(const std::string &ip, int port, - std::shared_ptr security) + std::shared_ptr security, + const std::string &dstInstance) { auto addr = GetIpAddr(ip, port); auto [res, error] = InitFunctionSystemConn(addr, security); @@ -45,29 +49,32 @@ std::pair, ErrorInfo> ClientsManager::NewFsConn(c return std::make_pair(nullptr, error); } std::lock_guard fsConnsLock(fsConnsMtx); - fsConns[addr] = res; - fsConnsReferCounter[addr]++; + auto connKey = dstInstance + ":" + addr; + fsConns[connKey] = res; + fsConnsReferCounter[connKey]++; return std::make_pair(res, ErrorInfo()); } -ErrorInfo ClientsManager::ReleaseFsConn(const std::string &ip, int port) +ErrorInfo ClientsManager::ReleaseFsConn(const std::string &ip, int port, const std::string &dstInstance) { auto addr = GetIpAddr(ip, port); std::lock_guard fsConnsLock(fsConnsMtx); - auto iter = fsConnsReferCounter.find(addr); + auto connKey = dstInstance + ":" + addr; + auto iter = fsConnsReferCounter.find(connKey); if (iter == fsConnsReferCounter.end()) { return ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "Cannot find function system conn's ref count info."); } - fsConnsReferCounter[addr]--; - if (fsConnsReferCounter[addr] == 0) { - fsConnsReferCounter.erase(addr); - fsConns.erase(addr); + fsConnsReferCounter[connKey]--; + if (fsConnsReferCounter[connKey] == 0) { + fsConnsReferCounter.erase(connKey); + fsConns.erase(connKey); } return ErrorInfo(); } std::pair ClientsManager::GetOrNewDsClient( - const std::shared_ptr librtCfg, std::int32_t connectTimeout) + const std::shared_ptr librtCfg, const std::string &ak, const datasystem::SensitiveValue &sk, + std::int32_t connectTimeout) { auto key = GetIpAddr(librtCfg->dataSystemIpAddr, librtCfg->dataSystemPort); std::lock_guard dsClientsLock(dsClientsMtx); @@ -78,7 +85,7 @@ std::pair ClientsManager::GetOrNewDsClient( } auto res = InitDatasystemClient(librtCfg->dataSystemIpAddr, librtCfg->dataSystemPort, librtCfg->enableAuth, librtCfg->encryptEnable, librtCfg->runtimePublicKey, librtCfg->runtimePrivateKey, - librtCfg->dsPublicKey, connectTimeout); + librtCfg->dsPublicKey, librtCfg->token, ak, sk, connectTimeout); if (res.second.OK()) { dsClients[key] = res.first; dsClientsReferCounter[key]++; @@ -106,6 +113,10 @@ ErrorInfo ClientsManager::ReleaseDsClient(const std::string &ip, int port) dsClients[key].dsStateStore->Shutdown(); YRLOG_DEBUG("Shutdown state store clients"); } + if (dsClients[key].dsStreamStore != nullptr) { + dsClients[key].dsStreamStore->Shutdown(); + YRLOG_DEBUG("Shutdown stream store clients"); + } if (dsClients[key].dsHeteroStore != nullptr) { dsClients[key].dsHeteroStore->Shutdown(); YRLOG_DEBUG("Shutdown hetero store clients"); @@ -171,12 +182,21 @@ std::pair, ErrorInfo> ClientsManager::InitFunctio } try { std::string prefix = "ipv4:///"; - channel = grpc::CreateCustomChannel(prefix + target, YR::GetChannelCreds(security), args); + if (security != nullptr && security->IsFsAuthEnable()) { + std::vector> interCeptorCreators; + auto interCeptorCreator = new PosixClientAuthInterceptorFactory(); + interCeptorCreator->RegisterSecurity(security); + interCeptorCreators.push_back(std::unique_ptr(interCeptorCreator)); + channel = grpc::experimental::CreateCustomChannelWithInterceptors( + prefix + target, YR::GetChannelCreds(security), args, std::move(interCeptorCreators)); + } else { + channel = grpc::CreateCustomChannel(prefix + target, YR::GetChannelCreds(security), args); + } auto tmout = gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), {WAIT_FOR_STAGE_CHANGE_TIMEOUT_SEC, 0, GPR_TIMESPAN}); auto isConnect = channel->WaitForConnected(tmout); auto state = channel->GetState(true); if (!isConnect) { - YRLOG_ERROR("failed to connect to grpc server {}, channel state: {}", target, state); + YRLOG_ERROR("failed to connect to grpc server {}, channel state: {}", target, fmt::underlying(state)); return std::make_pair(nullptr, ErrorInfo(ErrorCode::ERR_CONNECTION_FAILED, "failed to connect to grpc server")); } @@ -190,25 +210,34 @@ std::pair, ErrorInfo> ClientsManager::InitFunctio std::pair ClientsManager::InitDatasystemClient( const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, const std::string &runtimePublicKey, - const datasystem::SensitiveValue &runtimePrivateKey, const std::string &dsPublicKey, std::int32_t connectTimeout) + const datasystem::SensitiveValue &runtimePrivateKey, const std::string &dsPublicKey, + const datasystem::SensitiveValue &token, const std::string &ak, const datasystem::SensitiveValue &sk, + std::int32_t connectTimeout) { datasystem::ConnectOptions connectOptions; connectOptions.host = ip; connectOptions.port = port; - connectOptions.connectTimeoutMs = ToMs(connectTimeout); + connectOptions.connectTimeoutMs = connectTimeout * S_TO_MS; if (encryptEnable) { connectOptions.clientPublicKey = runtimePublicKey; connectOptions.clientPrivateKey = runtimePrivateKey; connectOptions.serverPublicKey = dsPublicKey; } + if (enableDsAuth) { + if (!ak.empty() && !sk.Empty()) { + connectOptions.accessKey = ak; + connectOptions.secretKey = sk; + } + } std::string tenantId = Config::Instance().YR_TENANT_ID(); if (!tenantId.empty()) { connectOptions.tenantId = tenantId; } YRLOG_DEBUG( - "start init datasystem client connect param, ip is {}, port is {}, enableDsAuth is {}, " - "encryptEnableis {}, runtimePublicKey is empty {}, timeout is {}", - ip, port, enableDsAuth, encryptEnable, runtimePublicKey.empty(), connectTimeout); + "start init datasystem client connect param, tenant id is {}, ip is {}, port is {}, enableDsAuth is {}, " + "encryptEnableis {}, runtimePublicKey is empty {}, ak is empty {}, token is empty {}, timeout is {}", + tenantId, ip, port, enableDsAuth, encryptEnable, runtimePublicKey.empty(), ak.empty(), token.Empty(), + connectTimeout); DatasystemClients clients; clients.dsObjectStore = std::make_shared(); ErrorInfo infoObjectStore = clients.dsObjectStore->Init(connectOptions); @@ -222,6 +251,12 @@ std::pair ClientsManager::InitDatasystemClient( return std::make_pair(clients, infoObjectStore); } + clients.dsStreamStore = std::make_shared(); + ErrorInfo infoStreamStore = clients.dsStreamStore->Init(connectOptions, clients.dsStateStore); + if (!infoStreamStore.OK()) { + return std::make_pair(clients, infoStreamStore); + } + clients.dsHeteroStore = std::make_shared(); auto infoHeteroStore = clients.dsHeteroStore->Init(connectOptions); if (!infoHeteroStore.OK()) { @@ -235,7 +270,7 @@ std::pair, ErrorInfo> ClientsManager::InitHttpCli const std::string &ip, int port, const std::shared_ptr &config) { auto httpClient = std::make_shared(config); - ErrorInfo error = httpClient->Init(ConnectionParam{ip, std::to_string(port)}); + ErrorInfo error = httpClient->Init(ConnectionParam{ip, std::to_string(port), config->httpIdleTime}); if (!error.OK()) { return std::make_pair(nullptr, error); } diff --git a/src/libruntime/clientsmanager/clients_manager.h b/src/libruntime/clientsmanager/clients_manager.h index 53707fd..f40d016 100644 --- a/src/libruntime/clientsmanager/clients_manager.h +++ b/src/libruntime/clientsmanager/clients_manager.h @@ -26,6 +26,8 @@ #include "src/libruntime/objectstore/object_store.h" #include "src/libruntime/statestore/datasystem_state_store.h" #include "src/libruntime/statestore/state_store.h" +#include "src/libruntime/streamstore/datasystem_stream_store.h" +#include "src/libruntime/streamstore/stream_store.h" #include "src/libruntime/utils/security.h" #include "src/libruntime/utils/utils.h" #include "src/utility/logger/logger.h" @@ -42,6 +44,7 @@ const std::string IP_PORT_REGEX = R"(((\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{1,5 struct DatasystemClients { std::shared_ptr dsObjectStore; std::shared_ptr dsStateStore; + std::shared_ptr dsStreamStore; std::shared_ptr dsHeteroStore; }; @@ -50,13 +53,17 @@ public: ClientsManager() = default; std::pair, ErrorInfo> NewFsConn(const std::string &ip, int port, - std::shared_ptr security); + std::shared_ptr security, + const std::string &dstInstance); - std::pair, ErrorInfo> GetFsConn(const std::string &ip, int port); + std::pair, ErrorInfo> GetFsConn(const std::string &ip, int port, + const std::string &dstInstance); - ErrorInfo ReleaseFsConn(const std::string &ip, int port); + ErrorInfo ReleaseFsConn(const std::string &ip, int port, const std::string &dstInstance); std::pair GetOrNewDsClient(const std::shared_ptr librtCfg, + const std::string &ak, + const datasystem::SensitiveValue &sk, std::int32_t connectTimeout); ErrorInfo ReleaseDsClient(const std::string &ip, int port); @@ -72,6 +79,7 @@ private: std::pair InitDatasystemClient( const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, const std::string &dsPublicKey, + const datasystem::SensitiveValue &token, const std::string &ak, const datasystem::SensitiveValue &sk, std::int32_t connectTimeout); std::pair, ErrorInfo> InitHttpClient( diff --git a/src/libruntime/driverlog/driverlog_receiver.cpp b/src/libruntime/driverlog/driverlog_receiver.cpp new file mode 100644 index 0000000..9578270 --- /dev/null +++ b/src/libruntime/driverlog/driverlog_receiver.cpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "driverlog_receiver.h" + +#include "src/libruntime/streamstore/stream_producer_consumer.h" + +namespace YR { +namespace Libruntime { +const uint32_t CONSUMER_TIMEOUT = 1000; +const int AGG_WINDOWS_S = 5; +const int LOGINFO_MATCH_LEN = 2; +const size_t MAX_LOG_SCAN_LENGTH = 256; // only scan 256 charactors in log message to parse runtime id + +DriverLogReceiver::DriverLogReceiver() {} + +void DriverLogReceiver::Init(std::shared_ptr store, std::string &jobId, bool dedup) +{ + if (jobId.empty() || !store) { + YRLOG_ERROR("failed to init driverlog receiver {}", jobId); + } + jobId_ = "/log/runtime/std/" + jobId; + dedup_ = dedup; + dsStreamStore_ = store; + receiverThread_ = std::thread(&DriverLogReceiver::Receive, this); +} + +DriverLogReceiver::~DriverLogReceiver() +{ + Stop(); + if (receiverThread_.joinable()) { + receiverThread_.join(); + } +} + +void DriverLogReceiver::Stop() +{ + stopped_ = true; + if (consumer_ != nullptr) { + auto err = consumer_->Close(); + if (!err.OK()) { + YRLOG_WARN("failed to close log consumer {}, err code: {}, err message: {}", jobId_, + fmt::underlying(err.Code()), err.Msg()); + } + YRLOG_DEBUG("log consummer {} closed", jobId_); + if (!jobId_.empty()) { + err = dsStreamStore_->DeleteStream(jobId_); + if (!err.OK()) { + YRLOG_WARN("failed to delete stream {}, err code: {}, err message: {}", jobId_, + fmt::underlying(err.Code()), err.Msg()); + } + YRLOG_DEBUG("log stream {} delete", jobId_); + } + } + Flush(-1); +} + +std::string DriverLogReceiver::Format(const DedupState &state) +{ + std::ostringstream oss; + oss << "[repeated " << state.count << "x across cluster] " << state.line; + return oss.str(); +} + +std::pair DriverLogReceiver::ParseLine(const std::string &input) +{ + size_t scanLen = input.length() > MAX_LOG_SCAN_LENGTH ? MAX_LOG_SCAN_LENGTH : input.length(); + int leftBraceCounter = 0; + int rightBraceCounter = 0; + int runtimeIdStartPos = -1; + int runtimeIdEndPos = -1; + for (size_t i = 0; i < scanLen; i++) { + auto c = input[i]; + if (c == '(') { + leftBraceCounter++; + if (leftBraceCounter == 1) { + runtimeIdStartPos = i + 1; + } + } else if (c == ')') { + rightBraceCounter++; + if (rightBraceCounter == leftBraceCounter) { + runtimeIdEndPos = i; + break; + } + } + } + if (runtimeIdStartPos == -1 || runtimeIdEndPos == -1) { + return {"", ""}; + } + return {input.substr(runtimeIdStartPos, runtimeIdEndPos - runtimeIdStartPos), input.substr(runtimeIdEndPos + 1)}; +} + +void DriverLogReceiver::Deduplicate(std::shared_ptr> lines) +{ + auto now = std::chrono::time_point_cast(std::chrono::steady_clock::now()) + .time_since_epoch() + .count(); + Flush(now); + std::lock_guard lk(mtx); + for (const auto &line : *lines) { + auto [runtimeId, dedupKey] = ParseLine(line); + if (runtimeId.empty() || dedupKey.empty()) { + std::cout << line << std::flush; + continue; + } + if (recent_.find(dedupKey) != recent_.end()) { + recent_[dedupKey].sources.insert(runtimeId); + if (recent_[dedupKey].sources.size() > 1) { + recent_[dedupKey].timestamp = now; + recent_[dedupKey].count++; + recent_[dedupKey].line = line; + } else { + std::cout << line << std::flush; + } + } else { + recent_[dedupKey] = DedupState{now, 0, line, {runtimeId}}; + std::cout << line << std::flush; + } + } +} + +void DriverLogReceiver::Flush(int64_t now) +{ + std::lock_guard lk(mtx); + auto it = recent_.begin(); + while (it != recent_.end()) { + auto duration = now - it->second.timestamp; + if (duration >= 0 && duration < AGG_WINDOWS_S) { + ++it; + continue; + } + if (it->second.count > 1) { + std::cout << Format(it->second) << std::flush; + auto parseRes = ParseLine(it->second.line); + it->second.timestamp = now; + it->second.count = 0; + it->second.sources.clear(); + it->second.sources.insert(std::move(parseRes.first)); + } else if (it->second.count > 0) { + std::cout << it->second.line << std::flush; + // Aggregation wasn't fruitful, print the line and stop aggregating. + it = recent_.erase(it); + continue; + } else if (it->second.count == 0) { + it = recent_.erase(it); + continue; + } + ++it; + } +} + +void DriverLogReceiver::Receive() +{ + SubscriptionConfig config; + auto consumer = std::make_shared(); + auto initErr = dsStreamStore_->CreateStreamConsumer(jobId_, config, consumer, true); + if (!initErr.OK()) { + YRLOG_ERROR("failed to create log stream consumer {}, err code: {}, err message: {}", jobId_, + fmt::underlying(initErr.Code()), initErr.Msg()); + return; + } + consumer_ = consumer; + YRLOG_INFO("begin to receive log from topic: {}, dedup {}", jobId_, dedup_); + + while (!stopped_) { + std::vector elements; + auto err = consumer_->Receive(CONSUMER_TIMEOUT, elements); + if (!err.OK()) { + YRLOG_ERROR("failed to receive log from consumer, err code: {}, err message: {}", + fmt::underlying(err.Code()), err.Msg()); + return; + } + auto lines = std::make_shared>(); + for (auto &ele : elements) { + auto logInfo = std::string(reinterpret_cast(ele.ptr), ele.size); + if (!dedup_) { + std::cout << logInfo << std::flush; + } else { + lines->push_back(logInfo); + } + } + if (dedup_) { + Deduplicate(lines); + } + } + YRLOG_INFO("finishe to receive message from topic: {}", jobId_); +} +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/driverlog/driverlog_receiver.h b/src/libruntime/driverlog/driverlog_receiver.h new file mode 100644 index 0000000..0cf57a3 --- /dev/null +++ b/src/libruntime/driverlog/driverlog_receiver.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/libruntime/err_type.h" +#include "src/libruntime/utils/utils.h" +#include "src/libruntime/objectstore/memory_store.h" +#include "src/libruntime/streamstore/stream_store.h" + +namespace YR { +namespace Libruntime { +struct DedupState { + int64_t timestamp; + int count; + std::string line; + std::unordered_set sources; +}; + +class DriverLogReceiver { +public: + DriverLogReceiver(); + ~DriverLogReceiver(); + void Init(std::shared_ptr store, std::string &jobID, bool dedup); + void Stop(); + +private: + void Receive(); + void Deduplicate(std::shared_ptr> lines); + std::pair ParseLine(const std::string& input); + std::string Format(const DedupState &state); + void Flush(int64_t now); + std::shared_ptr dsStreamStore_; + std::shared_ptr consumer_; + std::thread receiverThread_; + std::atomic stopped_{false}; + std::string jobId_; + bool dedup_ = false; + mutable std::mutex mtx; + std::unordered_map recent_; +}; +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/err_type.h b/src/libruntime/err_type.h index 2524368..c552f89 100644 --- a/src/libruntime/err_type.h +++ b/src/libruntime/err_type.h @@ -82,8 +82,62 @@ enum ErrorCode : int { ERR_FUNCTION_MASTER_NOT_CONFIGURED = 9006, ERR_FUNCTION_MASTER_TIMEOUT = 9007, ERR_CLIENT_TERMINAL_KILLED = 9008, + ERR_ALL_SCHEDULER_UNAVALIABLE = 9009, + ERR_REQ_TIMEOUT_WITHOUT_ACK = 9010, }; +const static std::unordered_map errCodeToString = { + {ErrorCode::ERR_OK, "ERR_OK"}, + {ErrorCode::ERR_PARAM_INVALID, "ERR_PARAM_INVALID"}, + {ErrorCode::ERR_RESOURCE_NOT_ENOUGH, "ERR_RESOURCE_NOT_ENOUGH"}, + {ErrorCode::ERR_INSTANCE_NOT_FOUND, "ERR_INSTANCE_NOT_FOUND"}, + {ErrorCode::ERR_INSTANCE_DUPLICATED, "ERR_INSTANCE_DUPLICATED"}, + {ErrorCode::ERR_INVOKE_RATE_LIMITED, "ERR_INVOKE_RATE_LIMITED"}, + {ErrorCode::ERR_RESOURCE_CONFIG_ERROR, "ERR_RESOURCE_CONFIG_ERROR"}, + {ErrorCode::ERR_INSTANCE_EXITED, "ERR_INSTANCE_EXITED"}, + {ErrorCode::ERR_EXTENSION_META_ERROR, "ERR_EXTENSION_META_ERROR"}, + {ErrorCode::ERR_INSTANCE_SUB_HEALTH, "ERR_INSTANCE_SUB_HEALTH"}, + {ErrorCode::ERR_GROUP_SCHEDULE_FAILED, "ERR_GROUP_SCHEDULE_FAILED"}, + {ErrorCode::ERR_INSTANCE_EVICTED, "ERR_INSTANCE_EVICTED"}, + {ErrorCode::ERR_USER_CODE_LOAD, "ERR_USER_CODE_LOAD"}, + {ErrorCode::ERR_USER_FUNCTION_EXCEPTION, "ERR_USER_FUNCTION_EXCEPTION"}, + {ErrorCode::ERR_REQUEST_BETWEEN_RUNTIME_BUS, "ERR_REQUEST_BETWEEN_RUNTIME_BUS"}, + {ErrorCode::ERR_INNER_COMMUNICATION, "ERR_INNER_COMMUNICATION"}, + {ErrorCode::ERR_INNER_SYSTEM_ERROR, "ERR_INNER_SYSTEM_ERROR"}, + {ErrorCode::ERR_DISCONNECT_FRONTEND_BUS, "ERR_DISCONNECT_FRONTEND_BUS"}, + {ErrorCode::ERR_ETCD_OPERATION_ERROR, "ERR_ETCD_OPERATION_ERROR"}, + {ErrorCode::ERR_BUS_DISCONNECTION, "ERR_BUS_DISCONNECTION"}, + {ErrorCode::ERR_REQUEST_BETWEEN_RUNTIME_FRONTEND, "ERR_REQUEST_BETWEEN_RUNTIME_FRONTEND"}, + {ErrorCode::ERR_REDIS_OPERATION_ERROR, "ERR_REDIS_OPERATION_ERROR"}, + {ErrorCode::ERR_INCORRECT_INIT_USAGE, "ERR_INCORRECT_INIT_USAGE"}, + {ErrorCode::ERR_INIT_CONNECTION_FAILED, "ERR_INIT_CONNECTION_FAILED"}, + {ErrorCode::ERR_DESERIALIZATION_FAILED, "ERR_DESERIALIZATION_FAILED"}, + {ErrorCode::ERR_INSTANCE_ID_EMPTY, "ERR_INSTANCE_ID_EMPTY"}, + {ErrorCode::ERR_GET_OPERATION_FAILED, "ERR_GET_OPERATION_FAILED"}, + {ErrorCode::ERR_INCORRECT_FUNCTION_USAGE, "ERR_INCORRECT_FUNCTION_USAGE"}, + {ErrorCode::ERR_INCORRECT_CREATE_USAGE, "ERR_INCORRECT_CREATE_USAGE"}, + {ErrorCode::ERR_INCORRECT_INVOKE_USAGE, "ERR_INCORRECT_INVOKE_USAGE"}, + {ErrorCode::ERR_INCORRECT_KILL_USAGE, "ERR_INCORRECT_KILL_USAGE"}, + {ErrorCode::ERR_ROCKSDB_FAILED, "ERR_ROCKSDB_FAILED"}, + {ErrorCode::ERR_SHARED_MEMORY_LIMITED, "ERR_SHARED_MEMORY_LIMITED"}, + {ErrorCode::ERR_OPERATE_DISK_FAILED, "ERR_OPERATE_DISK_FAILED"}, + {ErrorCode::ERR_INSUFFICIENT_DISK_SPACE, "ERR_INSUFFICIENT_DISK_SPACE"}, + {ErrorCode::ERR_CONNECTION_FAILED, "ERR_CONNECTION_FAILED"}, + {ErrorCode::ERR_KEY_ALREADY_EXIST, "ERR_KEY_ALREADY_EXIST"}, + {ErrorCode::ERR_CLIENT_ALREADY_CLOSED, "ERR_CLIENT_ALREADY_CLOSED"}, + {ErrorCode::ERR_DATASYSTEM_FAILED, "ERR_DATASYSTEM_FAILED"}, + {ErrorCode::ERR_DEPENDENCY_FAILED, "ERR_DEPENDENCY_FAILED"}, + {ErrorCode::ERR_ACQUIRE_TIMEOUT, "ERR_ACQUIRE_TIMEOUT"}, + {ErrorCode::ERR_FINALIZED, "ERR_FINALIZED"}, + {ErrorCode::ERR_CREATE_RETURN_BUFFER, "ERR_CREATE_RETURN_BUFFER"}, + {ErrorCode::ERR_HEALTH_CHECK_HEALTHY, "ERR_HEALTH_CHECK_HEALTHY"}, + {ErrorCode::ERR_HEALTH_CHECK_FAILED, "ERR_HEALTH_CHECK_FAILED"}, + {ErrorCode::ERR_HEALTH_CHECK_SUBHEALTH, "ERR_HEALTH_CHECK_SUBHEALTH"}, + {ErrorCode::ERR_GENERATOR_FINISHED, "ERR_GENERATOR_FINISHED"}, + {ErrorCode::ERR_FUNCTION_MASTER_NOT_CONFIGURED, "ERR_FUNCTION_MASTER_NOT_CONFIGURED"}, + {ErrorCode::ERR_FUNCTION_MASTER_TIMEOUT, "ERR_FUNCTION_MASTER_TIMEOUT"}, + {ErrorCode::ERR_CLIENT_TERMINAL_KILLED, "ERR_CLIENT_TERMINAL_KILLED"}}; + const static std::unordered_map datasystemErrCodeMap = { {1, ErrorCode::ERR_PARAM_INVALID}, {2, ErrorCode::ERR_PARAM_INVALID}, {3, ErrorCode::ERR_GET_OPERATION_FAILED}, {4, ErrorCode::ERR_ROCKSDB_FAILED}, @@ -115,6 +169,7 @@ public: msg(err.Msg()), isCreate(err.IsCreate()), isTimeout(err.IsTimeout()), + isAckTimeout(err.IsAckTimeout()), stackTraceInfos_(std::move(err.GetStackTraceInfos())), dsStatusCode_(err.GetDsStatusCode()) { @@ -126,12 +181,17 @@ public: msg = err.Msg(); isCreate = err.IsCreate(); isTimeout = err.IsTimeout(); + isAckTimeout = err.IsAckTimeout(); stackTraceInfos_ = std::move(err.GetStackTraceInfos()); dsStatusCode_ = err.GetDsStatusCode(); return *this; } ErrorInfo(ErrorCode errCode, const std::string &errMsg) : code(errCode), msg(errMsg) {} + ErrorInfo(ErrorCode errCode, const std::string &errMsg, bool isAckTimeoutInput) + : code(errCode), msg(errMsg), isAckTimeout(isAckTimeoutInput) + { + } ErrorInfo(ErrorCode errCode, ModuleCode moduleCode, const std::string &errMsg) : code(errCode), mCode(moduleCode), msg(errMsg) { @@ -245,11 +305,21 @@ public: return isTimeout; } + bool IsAckTimeout() const + { + return isAckTimeout; + } + void SetIsTimeout(bool isTimeoutInput) { isTimeout = isTimeoutInput; } + void SetIsAckTimeout(bool isTimeoutInput) + { + isAckTimeout = isTimeoutInput; + } + std::vector GetStackTraceInfos() const { return stackTraceInfos_; @@ -267,6 +337,7 @@ private: bool isCreate = false; bool isTimeout = false; // This information is used to exclude the timeout error when the get operation fails due // to exception IDs. + bool isAckTimeout = false; std::vector stackTraceInfos_; int dsStatusCode_{0}; }; diff --git a/src/libruntime/fmclient/fm_client.cpp b/src/libruntime/fmclient/fm_client.cpp index bf99b97..d112e38 100644 --- a/src/libruntime/fmclient/fm_client.cpp +++ b/src/libruntime/fmclient/fm_client.cpp @@ -79,14 +79,14 @@ std::unordered_map ProcessResources( } result[resource.first] = value; } else { - YRLOG_DEBUG("unknow type {}: of {}", resource.second.type(), resource.first); + YRLOG_DEBUG("unknow type {}: of {}", fmt::underlying(resource.second.type()), resource.first); continue; } } return result; } -ErrorInfo ParseQueryResponseToRgUnit(const std::string &result, ResourceGroupUnit &rgUnit) +ErrorInfo ParseQueryResponseToRgUnit(const std::string &result, std::shared_ptr rgUnit) { QueryResourceGroupResponse resp; auto success = resp.ParseFromString(result); @@ -144,11 +144,23 @@ ErrorInfo ParseQueryResponseToRgUnit(const std::string &result, ResourceGroupUni } rgInfo.bundles.push_back(std::move(bdInfo)); } - rgUnit.resourceGroups[rgInfo.name] = std::move(rgInfo); + rgUnit->resourceGroups[rgInfo.name] = std::move(rgInfo); } return ErrorInfo(); } +std::unordered_map> ProcessNodeLabels( + const ::google::protobuf::Map &nodeLabels) +{ + std::unordered_map> result; + for (auto &counter : nodeLabels) { + for (auto &labels : counter.second.items()) { + result[counter.first].push_back(labels.first); + } + } + return result; +} + ErrorInfo ParseQueryResponse(const std::string &result, std::vector &res) { QueryResourcesInfoResponse resp; @@ -162,6 +174,7 @@ ErrorInfo ParseQueryResponse(const std::string &result, std::vector GetResourceGroupTableByHttpClient(std::s headers[std::string("Type")] = std::string("protobuf"); std::string body; req.SerializeToString(&body); + auto isExit = std::make_shared(false); auto asyncNotify = std::make_shared(); - ResourceGroupUnit res{}; + auto res = std::make_shared(); c->SubmitInvokeRequest( POST, GLOBAL_QUERY_RESOURCE_GROUP_TABLE, headers, body, reqId, - [&res, reqId, asyncNotify](const std::string &result, const boost::beast::error_code &errorCode, + [res, reqId, asyncNotify, isExit](const std::string &result, const boost::beast::error_code &errorCode, const uint statusCode) { + if (*isExit) { + return; + } auto err = CheckResponseCode(errorCode, statusCode, result, *reqId); if (err.OK()) { err = ParseQueryResponseToRgUnit(result, res); @@ -194,15 +211,13 @@ std::pair GetResourceGroupTableByHttpClient(std::s ss << "get request timeout: " << HTTP_REQUEST_TIMEOUT << ", requestId: " << *reqId; auto notifyErr = asyncNotify->WaitForNotificationWithTimeout( absl::Seconds(HTTP_REQUEST_TIMEOUT), ErrorInfo(ErrorCode::ERR_FUNCTION_MASTER_TIMEOUT, ss.str())); - if (notifyErr.Code() == ErrorCode::ERR_FUNCTION_MASTER_TIMEOUT) { - c->Cancel(); - } - return std::make_pair(notifyErr, res); + *isExit = true; + return std::make_pair(notifyErr, *res); } -ErrorInfo ParseQueryNamedInstancesResponse(const std::string &result, QueryNamedInsResponse &resp) +ErrorInfo ParseQueryNamedInstancesResponse(const std::string &result, std::shared_ptr resp) { - if (!resp.ParseFromString(result)) { + if (!resp->ParseFromString(result)) { YRLOG_WARN("Failed to parse QueryNamedInstances response: {}", result); return ErrorInfo(ErrorCode::ERR_PARAM_INVALID, "failed to parse QueryNamedInstances response"); } @@ -216,16 +231,19 @@ std::pair GetNamedInstancesByHttpClient(std::s req.set_requestid(*reqId); std::string body; req.SerializeToString(&body); - + auto isExit = std::make_shared(false); std::unordered_map headers = {{"Content-Type", "application/protobuf"}}; - QueryNamedInsResponse resp; + auto resp = std::make_shared(); auto asyncNotify = std::make_shared(); c->SubmitInvokeRequest( GET, INSTANCE_MANAGER_QUERY_NAMED_INSTANCES, headers, body, reqId, - [&resp, asyncNotify, reqId](const std::string &result, const boost::beast::error_code &errorCode, + [resp, asyncNotify, reqId, isExit](const std::string &result, const boost::beast::error_code &errorCode, const uint statusCode) { + if (*isExit) { + return; + } auto err = CheckResponseCode(errorCode, statusCode, result, *reqId); if (err.OK()) { err = ParseQueryNamedInstancesResponse(result, resp); @@ -235,13 +253,10 @@ std::pair GetNamedInstancesByHttpClient(std::s std::stringstream ss; ss << "get named instances request timeout: " << HTTP_REQUEST_TIMEOUT << ", requestId: " << *reqId; - auto notifyErr = asyncNotify->WaitForNotificationWithTimeout( absl::Seconds(HTTP_REQUEST_TIMEOUT), ErrorInfo(ErrorCode::ERR_FUNCTION_MASTER_TIMEOUT, ss.str())); - if (notifyErr.Code() == ErrorCode::ERR_FUNCTION_MASTER_TIMEOUT) { - c->Cancel(); - } - return {notifyErr, resp}; + *isExit = true; + return {notifyErr, *resp}; } std::pair> GetResourcesByHttpClient(std::shared_ptr c) @@ -249,6 +264,7 @@ std::pair> GetResourcesByHttpClient(std::sh auto requestId = std::make_shared(YR::utility::IDGenerator::GenRequestId()); auto req = BuildGetResourcesReq(*requestId); auto headers = BuildGetResourcesHeaders(); + auto isExit = std::make_shared(false); std::string body; req.SerializeToString(&body); std::vector res; @@ -256,8 +272,11 @@ std::pair> GetResourcesByHttpClient(std::sh YRLOG_DEBUG("start to get resources by http client, request id: {}.", *requestId); c->SubmitInvokeRequest( GET, GLOBAL_SCHEDULER_QUERY_RESOURCES, headers, body, requestId, - [&res, requestId, asyncNotify](const std::string &result, const boost::beast::error_code &errorCode, - const uint statusCode) { + [&res, requestId, asyncNotify, isExit](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + if (*isExit) { + return; + } auto err = CheckResponseCode(errorCode, statusCode, result, *requestId); if (err.OK()) { err = ParseQueryResponse(result, res); @@ -268,9 +287,7 @@ std::pair> GetResourcesByHttpClient(std::sh ss << "get request timeout: " << HTTP_REQUEST_TIMEOUT << ", requestId: " << *requestId; auto notifyErr = asyncNotify->WaitForNotificationWithTimeout( absl::Seconds(HTTP_REQUEST_TIMEOUT), ErrorInfo(ErrorCode::ERR_FUNCTION_MASTER_TIMEOUT, ss.str())); - if (notifyErr.Code() == ErrorCode::ERR_FUNCTION_MASTER_TIMEOUT) { - c->Cancel(); - } + *isExit = true; return std::make_pair(notifyErr, res); } @@ -302,7 +319,7 @@ void FMClient::Stop() if (stopped) { return; } - work_.reset(); + work_->reset(); ioc_->stop(); if (iocThread_) { if (iocThread_->joinable()) { @@ -325,7 +342,7 @@ ErrorInfo FMClient::ActivateMasterClientIfNeed() } if (activeMasterHttpClient_ == nullptr || !activeMasterHttpClient_->Available() || !activeMasterHttpClient_->IsConnActive()) { - activeMasterHttpClient_ = std::make_shared(ioc_); + activeMasterHttpClient_ = InitCtxAndHttpClient(); std::vector result; YR::utility::Split(activeMasterAddr_, result, ':'); if (result.size() != IP_ADDR_SIZE) { @@ -406,83 +423,6 @@ std::pair> FMClient::GetResources() std::vector()); } -std::shared_ptr FMClient::GetCurrentHttpClient() -{ - InitHttpClientIfNeeded(); - - if (httpClients_.empty()) { - YRLOG_DEBUG("no http client available"); - return nullptr; - } - - if (!currentMaster_.empty() && httpClients_.find(currentMaster_) != httpClients_.end()) { - return httpClients_[currentMaster_]; - } - - auto it = httpClients_.begin(); - currentMaster_ = it->first; - return it->second; -} - -std::shared_ptr FMClient::GetNextHttpClient() -{ - if (currentMaster_.empty()) { - return GetCurrentHttpClient(); - } - - auto it = httpClients_.find(currentMaster_); - if (it == httpClients_.end()) { - return GetCurrentHttpClient(); - } - - ++it; - if (it == httpClients_.end()) { - it = httpClients_.begin(); - } - - currentMaster_ = it->first; - it->second->ReInit(); // for http client fault tolerance - return it->second; -} - -void FMClient::InitHttpClientIfNeeded(void) -{ - if (libConfig_->functionMasters.empty()) { - YRLOG_DEBUG("functiom masters addresses are not configured"); - return; - } - - if (libConfig_->functionMasters.size() == httpClients_.size()) { - YRLOG_DEBUG("all functiom masters cliets are already initialized, size: {}", - libConfig_->functionMasters.size()); - return; - } - - if (!iocThread_) { - iocThread_ = std::make_unique([&] { ioc_->run(); }); - } - - for (const auto &m : libConfig_->functionMasters) { - if (httpClients_.find(m) != httpClients_.end()) { - YRLOG_DEBUG("function master {} is already initialized", m); - continue; - } - - auto c = InitCtxAndHttpClient(); - std::vector result; - YR::utility::Split(m, result, ':'); - if (result.size() != IP_ADDR_SIZE) { - YRLOG_ERROR("Invalid ip addr {}", m); - continue; - } - ConnectionParam param; - param.ip = result[0]; - param.port = result[1]; - c->Init(param); - httpClients_[m] = c; - } -} - std::shared_ptr FMClient::InitCtxAndHttpClient(void) { ErrorInfo err; diff --git a/src/libruntime/fmclient/fm_client.h b/src/libruntime/fmclient/fm_client.h index 07ccc90..b0724db 100644 --- a/src/libruntime/fmclient/fm_client.h +++ b/src/libruntime/fmclient/fm_client.h @@ -37,7 +37,8 @@ using QueryResourceGroupRequest = ::messages::QueryResourceGroupRequest; using QueryResourceGroupResponse = ::messages::QueryResourceGroupResponse; using ResourceGroupInfo = ::messages::ResourceGroupInfo; using SubscribeActiveMasterCb = std::function; - +ErrorInfo CheckResponseCode(const boost::beast::error_code &errorCode, const uint statusCode, const std::string &result, + const std::string &requestId); const std::string GLOBAL_SCHEDULER_QUERY_RESOURCES = "/global-scheduler/resources"; const std::string INSTANCE_MANAGER_QUERY_NAMED_INSTANCES = "/instance-manager/named-ins"; const std::string GLOBAL_QUERY_RESOURCE_GROUP_TABLE = "/resource-group/rgroup"; @@ -46,12 +47,14 @@ class FMClient { public: FMClient() : ioc_(std::make_shared()) { - work_ = std::make_unique(*ioc_); + work_ = std::make_unique>( + boost::asio::make_work_guard(*ioc_)); } FMClient(const std::shared_ptr config) : libConfig_(config), ioc_(std::make_shared()) { - work_ = std::make_unique(*ioc_); + work_ = std::make_unique>( + boost::asio::make_work_guard(*ioc_)); enableMTLS_ = config->enableMTLS; } ~FMClient() @@ -72,9 +75,6 @@ public: void RemoveActiveMaster(); private: - std::shared_ptr GetCurrentHttpClient(); - std::shared_ptr GetNextHttpClient(); - void InitHttpClientIfNeeded(void); std::shared_ptr InitCtxAndHttpClient(void); void MockResp(); @@ -83,8 +83,8 @@ private: std::string currentMaster_; std::shared_ptr ioc_; std::unique_ptr iocThread_; - std::unique_ptr work_; - bool enableMTLS_; + std::unique_ptr> work_; + bool enableMTLS_{false}; std::mutex activeMasterMu_; std::condition_variable condVar_; std::string activeMasterAddr_; diff --git a/src/libruntime/fsclient/fs_client.cpp b/src/libruntime/fsclient/fs_client.cpp index 20abcae..2ef8e98 100644 --- a/src/libruntime/fsclient/fs_client.cpp +++ b/src/libruntime/fsclient/fs_client.cpp @@ -16,6 +16,7 @@ #include "src/libruntime/fsclient/fs_client.h" #include "src/libruntime/fsclient/fs_intf_impl.h" +#include "src/libruntime/gwclient/gw_client.h" namespace YR { namespace Libruntime { @@ -143,5 +144,18 @@ void FSClient::CreateRGroupAsync(const CreateResourceGroupRequest &req, CreateRe { return this->fsIntf->CreateRGroupAsync(req, callback, timeoutSec); } + +void FSClient::EraseIntf(const std::string &id) +{ + fsIntf->RemoveInsRtIntf(id); +} + +bool FSClient::IsHealth() +{ + if (!fsIntf) { + return false; + } + return fsIntf->IsHealth(); +} } // namespace Libruntime } // namespace YR \ No newline at end of file diff --git a/src/libruntime/fsclient/fs_client.h b/src/libruntime/fsclient/fs_client.h index b876150..fff31bf 100644 --- a/src/libruntime/fsclient/fs_client.h +++ b/src/libruntime/fsclient/fs_client.h @@ -54,6 +54,8 @@ public: void RemoveInsRtIntf(const std::string &instanceId); void CreateRGroupAsync(const CreateResourceGroupRequest &req, CreateResourceGroupCallBack callback, int timeoutSec = -1); + void EraseIntf(const std::string &id); + bool IsHealth(); private: std::shared_ptr fsIntf; diff --git a/src/libruntime/fsclient/fs_intf.cpp b/src/libruntime/fsclient/fs_intf.cpp index 605082e..6cdd5e7 100644 --- a/src/libruntime/fsclient/fs_intf.cpp +++ b/src/libruntime/fsclient/fs_intf.cpp @@ -18,8 +18,8 @@ #include +#include "src/dto/config.h" #include "src/utility/logger/logger.h" - namespace YR { namespace Libruntime { FSIntf::FSIntf(const FSIntfHandlers &handlers) : handlers(handlers) @@ -35,8 +35,7 @@ FSIntf::FSIntf(const FSIntfHandlers &handlers) : handlers(handlers) this->handlers.heartbeat = std::bind(&FSIntf::HandleHeartbeat, this, std::placeholders::_1); } - this->noitfyExecutor.Init(NOTIFY_THREAD_POOL_SIZE, "fs.notify"); - this->checkpointRecoverExecutor.Init(CKPT_RCVR_THREAD_POOL_SIZE, "fs.ckpt_rcvr"); + this->noitfyExecutor.Init(YR::Libruntime::Config::Instance().YR_NOTIFY_THREAD_POOL_SIZE(), "fs.notify"); this->shutdownExecutor.Init(SHUTDOWN_THREAD_POOL_SIZE, "fs.shutdown"); this->signalExecutor.Init(SIGNAL_THREAD_POOL_SIZE, "fs.signal"); if (!this->syncHeartbeat) { @@ -69,6 +68,7 @@ void FSIntf::Clear() void FSIntf::ReceiveRequestLoop(void) { + this->checkpointRecoverExecutor.Init(CKPT_RCVR_THREAD_POOL_SIZE, "fs.ckpt_rcvr"); this->callReceiver.InitAndRun(); } @@ -138,13 +138,13 @@ void FSIntf::HandleCallRequest(const std::shared_ptr &req, Call resp.set_code(code); resp.set_message(msg); YRLOG_DEBUG("send init call response, request ID: {}, code {}, message {}", - req->Immutable().requestid(), resp.code(), resp.message()); + req->Immutable().requestid(), fmt::underlying(resp.code()), resp.message()); callback(resp); } else { callback(resp); this->handlers.init(req); YRLOG_DEBUG("send init call response , request ID: {}, code {}, message {}", - req->Immutable().requestid(), resp.code(), resp.message()); + req->Immutable().requestid(), fmt::underlying(resp.code()), resp.message()); } } else { if (!status.WaitInitialized()) { @@ -152,13 +152,13 @@ void FSIntf::HandleCallRequest(const std::shared_ptr &req, Call resp.set_code(code); resp.set_message(msg); YRLOG_DEBUG("after wait initialized, send call response, request ID: {}, code {}, message {}", - req->Immutable().requestid(), resp.code(), resp.message()); + req->Immutable().requestid(), fmt::underlying(resp.code()), resp.message()); callback(resp); } else { callback(resp); this->handlers.call(req); YRLOG_DEBUG("send call response , request ID: {}, code {}, message {}", - req->Immutable().requestid(), resp.code(), resp.message()); + req->Immutable().requestid(), fmt::underlying(resp.code()), resp.message()); } } if (resp.code() != common::ERR_NONE) { @@ -249,7 +249,7 @@ void FSIntf::HandleSignalRequest(const SignalRequest &req, SignalCallBack callba { this->signalExecutor.Handle( [this, req, callback]() { - YRLOG_DEBUG("recieve signal req, signal is {}, payload is {}", req.signal(), req.payload()); + YRLOG_DEBUG("receive signal req, signal is {}, payload is {}", req.signal(), req.payload()); auto resp = this->handlers.signal(req); callback(resp); }, diff --git a/src/libruntime/fsclient/fs_intf.h b/src/libruntime/fsclient/fs_intf.h index 3f8033a..3dbe30a 100644 --- a/src/libruntime/fsclient/fs_intf.h +++ b/src/libruntime/fsclient/fs_intf.h @@ -83,7 +83,7 @@ using SubscriptionPayload = ::core_service::SubscriptionPayload; using NotificationPayload = ::core_service::NotificationPayload; using InstanceTermination = ::core_service::InstanceTermination; using FunctionMasterObserve = ::core_service::FunctionMasterObserve; -using KillCallBack = std::function; +using KillCallBack = std::function; using ExitRequest = ::core_service::ExitRequest; using ExitResponse = ::core_service::ExitResponse; @@ -123,7 +123,6 @@ using HeartbeatResponse = ::runtime_service::HeartbeatResponse; using Arg = common::Arg; using Arg_ArgType = common::Arg_ArgType; -const int NOTIFY_THREAD_POOL_SIZE = 2; const int CKPT_RCVR_THREAD_POOL_SIZE = 1; const int SHUTDOWN_THREAD_POOL_SIZE = 1; const int SIGNAL_THREAD_POOL_SIZE = 10; @@ -283,6 +282,7 @@ public: void HandleHeartbeatRequest(const HeartbeatRequest &req, HeartbeatCallBack callback); int WaitRequestEmpty(uint64_t gracePeriodSec); void SetInitialized(); + virtual bool IsHealth() = 0; protected: void Clear(); @@ -369,6 +369,9 @@ private: bool SetShuttingDown() { absl::MutexLock lock(&this->mu); + if (err.first == common::ErrorCode::ERR_NONE) { + err = std::make_pair(common::ErrorCode::ERR_INSTANCE_EXITED, "instance has already exited"); + } if (state != SHUTDOWN) { state = SHUTTING_DOWN; } diff --git a/src/libruntime/fsclient/fs_intf_impl.cpp b/src/libruntime/fsclient/fs_intf_impl.cpp index 856d35c..8deab51 100644 --- a/src/libruntime/fsclient/fs_intf_impl.cpp +++ b/src/libruntime/fsclient/fs_intf_impl.cpp @@ -18,6 +18,7 @@ #include "src/dto/config.h" #include "src/dto/status.h" +#include "src/libruntime/traceadaptor/trace_adapter.h" #include "src/libruntime/utils/utils.h" #include "src/utility/logger/logger.h" #include "src/utility/notification_utility.h" @@ -162,7 +163,8 @@ std::pair, bool> FSIntfImpl::UpdateRetryInterval(c auto wr = it->second; ++wr->retryCount; wr->remainTimeoutSec -= wr->retryIntervalSec; - if (wr->remainTimeoutSec <= 0) { // Current is not need to retry, because there is no time to wait response. + // Current is not need to retry, because there is no time to wait response. + if (wr->remainTimeoutSec <= 0 && !wr->ackReceived) { wiredRequests.erase(it); return std::make_pair(wr, true); } @@ -205,8 +207,9 @@ bool FSIntfImpl::NeedRepeat(const std::string requestId) auto [wr, expired] = UpdateRetryInterval(requestId); if (expired) { if (wr != nullptr && wr->callback != nullptr) { - YRLOG_ERROR("RPC request retry expired. request ID: {}", requestId); + YRLOG_ERROR("RPC request retry expired, request ID: {}, is ack received: {}", requestId, wr->ackReceived); ErrorInfo err(ERR_REQUEST_BETWEEN_RUNTIME_BUS, "Response timeout, request ID is " + requestId); + err.SetIsAckTimeout(true); StreamingMessage fake; wr->callback(fake, err, [this, requestId](bool needEraseWiredReq) { if (needEraseWiredReq) { @@ -216,12 +219,10 @@ bool FSIntfImpl::NeedRepeat(const std::string requestId) } return false; } - if (wr && wr->ackReceived) { YRLOG_DEBUG(" {} has received ack, no need retry", requestId); return false; } - return true; } @@ -232,11 +233,12 @@ void FSIntfImpl::WriteCallback(const std::string requestId, const ErrorInfo &err } if (IsCommunicationError(::common::ErrorCode(err.Code()))) { - YRLOG_ERROR("Communicate fails for request({}) errcode({}), msg({})", requestId, err.Code(), err.Msg()); + YRLOG_ERROR("Communicate fails for request({}) errcode({}), msg({})", requestId, fmt::underlying(err.Code()), + err.Msg()); return; } - YRLOG_DEBUG("send grpc request failed for request: {}, err code is {}, err msg is {}", requestId, err.Code(), - err.Msg()); + YRLOG_DEBUG("send grpc request failed for request: {}, err code is {}, err msg is {}", requestId, + fmt::underlying(err.Code()), err.Msg()); auto wr = EraseWiredRequest(requestId); if (wr != nullptr && wr->callback != nullptr) { StreamingMessage fakeMsg; @@ -248,11 +250,14 @@ void FSIntfImpl::GroupCreateAsync(const CreateRequests &reqs, CreateRespsCallbac CreateCallBack callback, int timeoutSec) { auto reqId = reqs.requestid(); + auto tenantId = reqs.tenantid(); auto traceId = reqs.traceid(); - auto respCallback = [reqId, traceId, createRespCallback](const StreamingMessage &createResps, ErrorInfo status, - std::function needEraseWiredReq) { - YRLOG_DEBUG("Receive group create responses, request ID:{}, trace ID:{}", reqId, traceId); + auto respCallback = [reqId, tenantId, traceId, createRespCallback](const StreamingMessage &createResps, + ErrorInfo status, + std::function needEraseWiredReq) { + YRLOG_DEBUG("Receive group create responses, request ID:{}, tenant id {}, trace ID:{}", reqId, tenantId, + traceId); if (status.OK() && createResps.has_creatersps()) { if (createResps.creatersps().code() == common::ERR_NONE) { createRespCallback(createResps.creatersps()); @@ -272,8 +277,8 @@ void FSIntfImpl::GroupCreateAsync(const CreateRequests &reqs, CreateRespsCallbac return; }; auto notifyCb = [callback](const NotifyRequest &req, const ErrorInfo &err) { - YRLOG_DEBUG("Receive group create notify request, request ID:{}, error code: {}, error message: {}", - req.requestid(), req.code(), req.message()); + YRLOG_DEBUG("Receive group create notify request, request ID:{}, code: {}, message: {}", req.requestid(), + fmt::underlying(req.code()), req.message()); callback(req); }; auto wr = std::make_shared(respCallback, notifyCb, timerWorker); @@ -310,10 +315,18 @@ void FSIntfImpl::CreateAsync(const CreateRequest &req, CreateRespCallback create auto reqId = std::make_shared(req.requestid()); auto funcName = req.function(); auto traceId = std::make_shared(req.traceid()); - auto respCallback = [this, reqId, funcName, traceId, createRespCallback]( + auto designatedInstanceID = req.designatedinstanceid(); + auto span = TraceAdapter::GetInstance().StartSpan( + "Create", {{"requestID", *reqId}, {"funcName", funcName}, {"designatedInstanceID", designatedInstanceID}}); + auto respCallback = [this, reqId, funcName, traceId, createRespCallback, span]( const StreamingMessage &createResp, ErrorInfo status, std::function needEraseWiredReq) { YRLOG_DEBUG("Receive create response, function: {}, request ID:{}, trace ID:{}", funcName, *reqId, *traceId); + if (createResp.has_creatersp()) { + span->SetAttribute("respCode", createResp.creatersp().code()); + span->SetAttribute("respMessage", createResp.creatersp().message()); + span->SetAttribute("respInstanceID", createResp.creatersp().instanceid()); + } if (status.OK() && createResp.has_creatersp()) { if (createResp.creatersp().code() == common::ERR_NONE) { createRespCallback(createResp.creatersp()); @@ -334,9 +347,12 @@ void FSIntfImpl::CreateAsync(const CreateRequest &req, CreateRespCallback create return; }; - auto notifyCallback = [callback](const NotifyRequest &req, const ErrorInfo &err) { - YRLOG_DEBUG("Receive create notify request, request ID:{}, error code: {}, error message: {}", req.requestid(), - req.code(), req.message()); + auto notifyCallback = [callback, span](const NotifyRequest &req, const ErrorInfo &err) { + YRLOG_DEBUG("Receive create notify request, request ID:{}, code: {}, message: {}", req.requestid(), + fmt::underlying(req.code()), req.message()); + span->SetAttribute("notifyCode", req.code()); + span->SetAttribute("notifyMessage", req.message()); + span->End(); callback(req); }; @@ -379,10 +395,17 @@ void FSIntfImpl::InvokeAsync(const std::shared_ptr &req, Invo auto reqId = std::make_shared(req->Immutable().requestid()); auto instanceId = std::make_shared(req->Immutable().instanceid()); auto traceId = std::make_shared(req->Immutable().traceid()); - auto respCallback = [this, callback, reqId, instanceId, traceId](const StreamingMessage &invokeResp, - ErrorInfo status, - std::function needEraseWiredReq) { + auto funcName = std::make_shared(req->Immutable().function()); + auto span = TraceAdapter::GetInstance().StartSpan( + "Invoke", {{"requestID", *reqId}, {"funcName", *funcName}, {"instanceId", *instanceId}}); + auto respCallback = [this, callback, reqId, instanceId, traceId, span]( + const StreamingMessage &invokeResp, ErrorInfo status, + std::function needEraseWiredReq) { YRLOG_DEBUG("Receive invoke response, instance: {}, request ID:{}, trace ID:{}", *instanceId, *reqId, *traceId); + if (invokeResp.has_invokersp()) { + span->SetAttribute("respCode", invokeResp.invokersp().code()); + span->SetAttribute("respMessage", invokeResp.invokersp().message()); + } if (status.OK() && invokeResp.has_invokersp()) { if (invokeResp.invokersp().code() == common::ERR_NONE) { needEraseWiredReq(false); @@ -397,15 +420,25 @@ void FSIntfImpl::InvokeAsync(const std::shared_ptr &req, Invo notifyRequest.set_code(common::ErrorCode(status.Code())); notifyRequest.set_message("invoke response failed, request id: " + (*reqId) + ", msg: " + status.Msg()); notifyRequest.set_requestid(*reqId); - YRLOG_ERROR( - "Receive invoke response, instance: {}, request ID:{}, trace ID:{}, error code: {}, error message: {}", - *instanceId, *reqId, *traceId, status.Code(), status.Msg()); + YRLOG_ERROR("Receive invoke response, instance: {}, request ID:{}, trace ID:{}, code: {}, message: {}", + *instanceId, *reqId, *traceId, fmt::underlying(status.Code()), status.Msg()); needEraseWiredReq(true); - callback(notifyRequest, ErrorInfo()); + + this->HandleNotifyRequest( + notifyRequest, + [callback, status, notifyRequest]() -> NotifyResponse { + callback(notifyRequest, status); + return NotifyResponse(); + }, + [](const NotifyResponse &resp) -> void { return; }); return; }; - auto notifyCallback = [callback](const NotifyRequest &req, const ErrorInfo &err) { - YRLOG_DEBUG("Receive invoke notify request, request ID:{}, code: {}", req.requestid(), req.code()); + auto notifyCallback = [callback, span](const NotifyRequest &req, const ErrorInfo &err) { + YRLOG_DEBUG("Receive invoke notify request, request ID:{}, code: {}", req.requestid(), + fmt::underlying(req.code())); + span->SetAttribute("notifyCode", req.code()); + span->SetAttribute("notifyMessage", req.message()); + span->End(); callback(req, err); }; auto wr = std::make_shared(respCallback, notifyCallback, timerWorker, req->Immutable().instanceid()); @@ -413,9 +446,14 @@ void FSIntfImpl::InvokeAsync(const std::shared_ptr &req, Invo wr->SetRequestID(*reqId); wr = SaveWiredRequest(req->Immutable().requestid(), wr); std::weak_ptr weakWr(wr); - - auto sendMsgHandler = [self(shared_from_this()), reqId, req, weakWr]() { - if (auto wr = weakWr.lock(); wr) { + std::shared_ptr self; + try { + self = shared_from_this(); + } catch (const std::exception &e) { + YRLOG_ERROR("FSIntfImpl has been destructed"); + } + auto sendMsgHandler = [self, reqId, req, weakWr]() { + if (auto wr = weakWr.lock(); wr && self) { auto messageId = YR::utility::IDGenerator::GenMessageId(req->Immutable().requestid(), static_cast(wr->retryCount)); YRLOG_DEBUG("Send invoke message, message id {}", messageId); @@ -475,8 +513,9 @@ void FSIntfImpl::CallResultAsync(const std::shared_ptr re CallResultAck resp; resp.set_code(common::ErrorCode(status.Code())); resp.set_message(status.Msg()); - YRLOG_DEBUG("Receive call result ack, instance: {}, request ID:{}, error code: {}, error message: {}", - req->Immutable().instanceid(), req->Immutable().requestid(), status.Code(), status.Msg()); + YRLOG_DEBUG("Receive call result ack, instance: {}, request ID:{}, code: {}, message: {}", + req->Immutable().instanceid(), req->Immutable().requestid(), fmt::underlying(status.Code()), + status.Msg()); needEraseWiredReq(true); callback(resp); return; @@ -516,12 +555,13 @@ void FSIntfImpl::CallResultAsync(const std::shared_ptr re } if (self->IsCommunicationError(::common::ErrorCode(status.Code()))) { (void)self->SaveWiredRequest(*reqId, wr); - YRLOG_ERROR("Communicate fails for request({}) errcode({}), msg({})", *reqId, status.Code(), - status.Msg()); + YRLOG_ERROR("Communicate fails for request({}) errcode({}), msg({})", *reqId, + fmt::underlying(status.Code()), status.Msg()); return; } YRLOG_DEBUG_IF(!status.OK(), "send grpc call result failed for {}, err code is {}, err msg is {}", - *reqId, status.Code(), status.Msg()); + *reqId, fmt::underlying(status.Code()), + status.Msg()); (void)self->EraseWiredRequest(*reqId); if (wr->callback != nullptr) { wr->callback(CALL_RESULT_ACK, status, [](bool) {}); @@ -535,12 +575,22 @@ void FSIntfImpl::CallResultAsync(const std::shared_ptr re void FSIntfImpl::KillAsync(const KillRequest &req, KillCallBack callback, int timeoutSec) { - auto reqId = YR::utility::IDGenerator::GenRequestId(); - auto respCallback = [callback, reqId](const StreamingMessage &killResp, ErrorInfo status, - std::function needEraseWiredReq) { + std::string reqId = req.requestid(); + if (reqId.empty()) { + reqId = YR::utility::IDGenerator::GenRequestId(); + } + auto span = TraceAdapter::GetInstance().StartSpan( + "Kill", {{"requestID", reqId}, {"instanceID", req.instanceid()}, {"signal", req.signal()}}); + auto respCallback = [callback, reqId, span](const StreamingMessage &killResp, ErrorInfo status, + std::function needEraseWiredReq) { YRLOG_DEBUG("Receive kill response, request ID:{}", reqId); + if (killResp.has_killrsp()) { + span->SetAttribute("respCode", killResp.killrsp().code()); + span->SetAttribute("respMessage", killResp.killrsp().message()); + } + span->End(); if (status.OK() && killResp.has_killrsp()) { - callback(killResp.killrsp(), ErrorInfo()); + callback(killResp.killrsp(), status); needEraseWiredReq(true); return; } @@ -548,7 +598,7 @@ void FSIntfImpl::KillAsync(const KillRequest &req, KillCallBack callback, int ti KillResponse resp; resp.set_code(common::ErrorCode(status.Code())); resp.set_message(status.Msg()); - YRLOG_DEBUG("Receive kill response, request ID:{}, error code: {}, error message: {}", reqId, status.Code(), + YRLOG_DEBUG("Receive kill response, request ID:{}, code: {}, msg: {}", reqId, fmt::underlying(status.Code()), status.Msg()); callback(resp, status); needEraseWiredReq(true); @@ -561,8 +611,8 @@ void FSIntfImpl::KillAsync(const KillRequest &req, KillCallBack callback, int ti std::weak_ptr weak(wr); auto sendMsgHandler = [this, req, reqId, weak]() { if (auto thisPtr = weak.lock(); thisPtr) { - auto messageId = YR::utility::IDGenerator::GenMessageId(reqId, static_cast(thisPtr->retryCount)); - this->Write(GenStreamMsg(messageId, req), std::bind(&FSIntfImpl::WriteCallback, this, reqId, _1)); + auto msgId = YR::utility::IDGenerator::GenMessageId(reqId, static_cast(thisPtr->retryCount)); + this->Write(GenStreamMsg(msgId, req), std::bind(&FSIntfImpl::WriteCallback, this, reqId, _1)); } }; @@ -574,11 +624,11 @@ void FSIntfImpl::KillAsync(const KillRequest &req, KillCallBack callback, int ti if (wiredReq != nullptr) { YRLOG_ERROR("Request timeout, start exec notify callback, request ID : {}", reqId); StreamingMessage fake; + ErrorInfo errorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, ModuleCode::CORE, + "kill request timeout, requestId: " + reqId); + errorInfo.SetIsTimeout(true); if (wiredReq->callback != nullptr) { - auto timeoutErr = ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, ModuleCode::CORE, - "kill request timeout, requestId: " + reqId); - timeoutErr.SetIsTimeout(true); - wiredReq->callback(fake, timeoutErr, [](bool needEraseWiredReq) {}); + wiredReq->callback(fake, errorInfo, [](bool needEraseWiredReq) {}); } EraseWiredRequest(reqId); } @@ -599,8 +649,8 @@ void FSIntfImpl::ExitAsync(const ExitRequest &req, ExitCallBack callback) } ExitResponse resp; - YRLOG_DEBUG("Receive exit response, request ID:{}, error code: {}, error message: {}", reqId, status.Code(), - status.Msg()); + YRLOG_DEBUG("Receive exit response, request ID:{}, code: {}, message: {}", reqId, + fmt::underlying(status.Code()), status.Msg()); needEraseWiredReq(true); callback(resp); return; @@ -636,8 +686,8 @@ void FSIntfImpl::StateSaveAsync(const StateSaveRequest &req, StateSaveCallBack c StateSaveResponse resp; resp.set_code(common::ErrorCode(status.Code())); resp.set_message(status.Msg()); - YRLOG_DEBUG("Receive save response, request ID:{}, error code: {}, error message: {}", reqId, status.Code(), - status.Msg()); + YRLOG_DEBUG("Receive save response, request ID:{}, code: {}, message: {}", reqId, + fmt::underlying(status.Code()), status.Msg()); callback(resp); needEraseWiredReq(true); return; @@ -673,8 +723,8 @@ void FSIntfImpl::StateLoadAsync(const StateLoadRequest &req, StateLoadCallBack c StateLoadResponse resp; resp.set_code(common::ErrorCode(status.Code())); resp.set_message(status.Msg()); - YRLOG_DEBUG("Receive load response, request ID:{}, error code: {}, error message: {}", reqId, status.Code(), - status.Msg()); + YRLOG_DEBUG("Receive load response, request ID:{}, code: {}, message: {}", reqId, + fmt::underlying(status.Code()), status.Msg()); callback(resp); needEraseWiredReq(true); return; @@ -712,7 +762,7 @@ void FSIntfImpl::CreateRGroupAsync(const CreateResourceGroupRequest &req, Create resp.set_code(common::ErrorCode(status.Code())); resp.set_message(status.Msg()); YRLOG_DEBUG("Receive create resource group response, request ID:{}, error code: {}, error message: {}", reqId, - status.Code(), status.Msg()); + fmt::underlying(status.Code()), status.Msg()); callback(resp); needEraseWiredReq(true); return; @@ -795,8 +845,8 @@ void FSIntfImpl::NewRTIntfClient(const std::string &dstInstanceID, const NotifyR .disconnectedCb = std::bind(&FSIntfImpl::NotifyDisconnected, this, _1)}, ProtocolType::GRPC); rtIntf->RegisterMessageHandler(rtMsgHdlrs); - (void)fsInrfMgr->Emplace(dstInstanceID, rtIntf); (void)rtIntf->Start(); + (void)fsInrfMgr->Emplace(dstInstanceID, rtIntf); } void FSIntfImpl::RecvNotifyRequest(const std::string &from, const std::shared_ptr &message) @@ -808,7 +858,7 @@ void FSIntfImpl::RecvNotifyRequest(const std::string &from, const std::shared_pt if (wr != nullptr) { dstInstanceID = wr->dstInstanceID; } - if (dstInstanceID != FUNCTION_PROXY && enableDirectCall && message->notifyreq().has_runtimeinfo() && + if (dstInstanceID != FUNCTION_PROXY && message->notifyreq().has_runtimeinfo() && !message->notifyreq().runtimeinfo().serveripaddr().empty() && wr != nullptr) { NewRTIntfClient(wr->dstInstanceID, message->notifyreq()); } @@ -883,7 +933,7 @@ bool FSIntfImpl::NeedResendReq(const std::shared_ptr &message) return IsCommunicationError(message->rgrouprsp().code()); default: YRLOG_ERROR("grpc body not match, messageid: {}, body case: {}", message->messageid(), - message->body_case()); + fmt::underlying(message->body_case())); return false; } } @@ -893,8 +943,8 @@ void FSIntfImpl::RecvCreateOrInvokeResponse(const std::string &, const std::shar auto reqId = YR::utility::IDGenerator::GetRequestIdFromMsg(message->messageid()); YRLOG_DEBUG("receive create or invoke response, msg id {}, req id {}", message->messageid(), reqId); if (NeedResendReq(message)) { - YRLOG_DEBUG("create or invoke response has communication error, need resend req, meesage id is {}", - message->messageid()); + YRLOG_WARN("create or invoke response has communication error, need resend req, meesage id is {}", + message->messageid()); return; } auto wr = GetWiredRequest(reqId, true); @@ -912,7 +962,7 @@ void FSIntfImpl::RecvResponse(const std::string &, const std::shared_ptrmessageid()); YRLOG_DEBUG("req id {}", reqId); if (NeedResendReq(message)) { - YRLOG_DEBUG("response has communication error, need resend req, meesage id is {}", message->messageid()); + YRLOG_WARN("response has communication error, need resend req, meesage id is {}", message->messageid()); return; } auto wr = EraseWiredRequest(reqId); @@ -1059,13 +1109,14 @@ std::pair FSIntfImpl::NotifyDriv status = stub->DiscoverDriver(&ctx, req, &resp); if (status.error_code() != grpc::StatusCode::OK) { i++; - YRLOG_DEBUG("Discover driver call grpc status code: {}, retry index: {}", status.error_code(), i); + YRLOG_DEBUG("Discover driver call grpc status code: {}, retry index: {}", + fmt::underlying(status.error_code()), i); sleep(retryInternal); } } while (status.error_code() != grpc::StatusCode::OK && i < maxRetryTime); if (status.error_code() != grpc::StatusCode::OK) { - YRLOG_ERROR("Discover driver call grpc status code: {}", status.error_code()); + YRLOG_ERROR("Discover driver call grpc status code: {}", fmt::underlying(status.error_code())); return std::make_pair(resp, ErrorInfo(ErrorCode::ERR_INIT_CONNECTION_FAILED, ModuleCode::RUNTIME, "failed to connect to cluster " + fsIp + ":" + std::to_string(fsPort))); } @@ -1103,7 +1154,7 @@ ErrorInfo FSIntfImpl::Start(const std::string &jobID, const std::string &instanc "POD_IP env should be properly set, while client mode & direct call enabled on cloud"); } listeningIpAddr = Config::Instance().POD_IP(); - selfPort = Config::Instance().DERICT_RUNTIME_SERVER_PORT(); + selfPort = 0; // service listened port is dynamic set by grpc } this->instanceID = instanceID.empty() ? "driver-" + jobID : instanceID; this->runtimeID = runtimeID; @@ -1184,8 +1235,21 @@ void FSIntfImpl::Stop(void) void FSIntfImpl::RemoveInsRtIntf(const std::string &instanceId) { - YRLOG_DEBUG("{} remove rt intf", instanceId); fsInrfMgr->Remove(instanceId); } + +bool FSIntfImpl::IsHealth() +{ + if (!fsInrfMgr || stopped) { + return false; + } + auto fsIntf = fsInrfMgr->GetSystemIntf(); + if (!fsIntf) { + YRLOG_WARN("function system client reader writer is nullptr, return false directly"); + return false; + } + return fsIntf->IsHealth(); +} + } // namespace Libruntime } // namespace YR diff --git a/src/libruntime/fsclient/fs_intf_impl.h b/src/libruntime/fsclient/fs_intf_impl.h index 61a8cec..5bcfd7e 100644 --- a/src/libruntime/fsclient/fs_intf_impl.h +++ b/src/libruntime/fsclient/fs_intf_impl.h @@ -253,6 +253,16 @@ inline std::shared_ptr GenStreamMsg(const std::string &message return streamMsg; } +template <> +inline std::shared_ptr GenStreamMsg(const std::string &messageId, + const CreateResourceGroupResponse &msg) +{ + auto streamMsg = std::make_shared(); + streamMsg->mutable_rgrouprsp()->CopyFrom(msg); + streamMsg->set_messageid(messageId); + return streamMsg; +} + template <> inline std::shared_ptr GenStreamMsg(const std::string &messageId, const CreateRequests &msg) { @@ -318,6 +328,7 @@ inline std::shared_ptr GenStreamMsg(const std::string &message } struct WiredRequest : public std::enable_shared_from_this { + WiredRequest() = default; WiredRequest(std::function)> cb, std::shared_ptr tw) : callback(cb), notifyCallback(nullptr), retryCount(0), timer_(nullptr), timerWorkerWeak(tw) @@ -422,7 +433,7 @@ struct WiredRequest : public std::enable_shared_from_this { int currentRetryInterval = std::min(requestACKTimeout, Config::Instance().REQUEST_ACK_ACC_MAX_SEC()); auto weakThis = weak_from_this(); if (timer_ != nullptr) { - timer_->cancel(); + return; } this->timer_ = timerWorker->CreateTimer(currentRetryInterval * YR::Libruntime::MILLISECOND_UNIT, 1, [weakThis] { if (auto thisPtr = weakThis.lock(); thisPtr) { @@ -520,6 +531,7 @@ public: enableDirectCall = true; } void RemoveInsRtIntf(const std::string &instanceId) override; + bool IsHealth() override; protected: void Write(const std::shared_ptr &msg, std::function callback = nullptr); diff --git a/src/libruntime/fsclient/fs_intf_manager.cpp b/src/libruntime/fsclient/fs_intf_manager.cpp index d55ac99..1e27ac2 100644 --- a/src/libruntime/fsclient/fs_intf_manager.cpp +++ b/src/libruntime/fsclient/fs_intf_manager.cpp @@ -109,8 +109,8 @@ void FSIntfManager::Remove(const std::string &instanceID) absl::WriterMutexLock lock(&this->mu); if (auto iter = this->rtIntfs.find(instanceID); iter != this->rtIntfs.end()) { intfNeedStop = iter->second; + (void)this->rtIntfs.erase(instanceID); } - (void)this->rtIntfs.erase(instanceID); } if (intfNeedStop) { intfNeedStop->Stop(); diff --git a/src/libruntime/fsclient/fs_intf_reader_writer.h b/src/libruntime/fsclient/fs_intf_reader_writer.h index b64ba6a..34c8dcb 100644 --- a/src/libruntime/fsclient/fs_intf_reader_writer.h +++ b/src/libruntime/fsclient/fs_intf_reader_writer.h @@ -50,6 +50,7 @@ public: virtual void Write(const std::shared_ptr &msg, std::function callback = nullptr, std::function preWrite = nullptr) = 0; + virtual bool IsHealth() = 0; void SetDiscoverDriverCb(const DiscoverDriverCb &cb) { discoverDriverCb = cb; @@ -65,7 +66,7 @@ public: if (it != msgHdlrs.end()) { it->second(dstInstance, message); } else { - YRLOG_ERROR("Invalid received message body type {} from {}", static_cast(message->body_case()), + YRLOG_ERROR("Invalid received message body type {} from {}", fmt::underlying(message->body_case()), dstInstance); } } diff --git a/src/libruntime/fsclient/grpc/fs_intf_grpc_client_reader_writer.cpp b/src/libruntime/fsclient/grpc/fs_intf_grpc_client_reader_writer.cpp index c8dc83e..3558386 100644 --- a/src/libruntime/fsclient/grpc/fs_intf_grpc_client_reader_writer.cpp +++ b/src/libruntime/fsclient/grpc/fs_intf_grpc_client_reader_writer.cpp @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include #include "fs_intf_grpc_client_reader_writer.h" +#include "src/libruntime/utils/grpc_utils.h" #include "src/utility/string_utility.h" namespace YR::Libruntime { @@ -62,7 +62,10 @@ bool FSIntfGrpcClientReaderWriter::GrpcRead(const std::shared_ptrRead(message.get()); + if (stream_) { + return stream_->Read(message.get()); + } + return false; } bool FSIntfGrpcClientReaderWriter::GrpcWrite(const std::shared_ptr &request) @@ -118,7 +121,7 @@ void FSIntfGrpcClientReaderWriter::ReconnectHandler() if (!StreamEmpty()) { WritesDone(); auto status = Finish(); - YRLOG_INFO("grpc status code: {}, msg: {}", status.error_code(), status.error_message()); + YRLOG_INFO("grpc status code: {}, msg: {}", fmt::underlying(status.error_code()), status.error_message()); // instance id not match if (status.error_code() == grpc::StatusCode::INVALID_ARGUMENT) { abnormal_.store(true); @@ -182,8 +185,8 @@ ErrorInfo FSIntfGrpcClientReaderWriter::Start() ErrorInfo FSIntfGrpcClientReaderWriter::Reconnect() { - YRLOG_INFO("begin to reconnect {}, abnormal_ {}", dstInstance, abnormal_); - clientsMgr->ReleaseFsConn(ip, port); + YRLOG_INFO("begin to reconnect {}, abnormal_ {}", dstInstance, abnormal_.load()); + clientsMgr->ReleaseFsConn(ip, port, dstInstance); return NewGrpcClientWithRetry(1); } @@ -211,9 +214,10 @@ void FSIntfGrpcClientReaderWriter::Stop() } catch (std::exception &e) { YRLOG_ERROR("failed to finalize grpc stream, exception: {}", e.what()); } - auto err = clientsMgr->ReleaseFsConn(ip, port); + auto err = clientsMgr->ReleaseFsConn(ip, port, dstInstance); if (!err.OK()) { - YRLOG_ERROR("failed to release function system conn, code:({}), message({})", err.Code(), err.Msg()); + YRLOG_ERROR("failed to release function system conn, code:({}), message({})", fmt::underlying(err.Code()), + err.Msg()); } YRLOG_DEBUG("connection of {} closed, ip={}, port={}", dstInstance, ip, port); } @@ -221,7 +225,7 @@ void FSIntfGrpcClientReaderWriter::Stop() ErrorInfo FSIntfGrpcClientReaderWriter::NewGrpcClientWithRetry(const int retryTimes) { grpc::EnableDefaultHealthCheckService(true); - auto [channel, error] = clientsMgr->GetFsConn(ip, port); + auto [channel, error] = clientsMgr->GetFsConn(ip, port, dstInstance); if (!error.OK()) { YRLOG_ERROR( "failed to get grpc connection from fsconns to instance({}), " @@ -237,7 +241,7 @@ ErrorInfo FSIntfGrpcClientReaderWriter::NewGrpcClientWithRetry(const int retryTi ErrorInfo FSIntfGrpcClientReaderWriter::BuildStream(std::shared_ptr channel) { stub_ = runtime_rpc::RuntimeRPC::NewStub(channel); - if (dstInstance == FUNCTION_PROXY) { + if (dstInstance == FUNCTION_PROXY || security->IsFsAuthEnable()) { stream_ = stub_->MessageStream(context.get()); if (stream_ == nullptr) { return ErrorInfo(ErrorCode::ERR_CONNECTION_FAILED, "failed to build posix stream"); @@ -259,16 +263,37 @@ ErrorInfo FSIntfGrpcClientReaderWriter::BuildStreamWithRetry(std::shared_ptr(); context->AddMetadata(INSTANCE_ID_META, srcInstance); context->AddMetadata(RUNTIME_ID_META, runtimeID); + SensitiveValue token; + std::string ak; + SensitiveValue sk; + if (security != nullptr) { + security->GetToken(token); + security->GetAKSK(ak, sk); + } + YRLOG_INFO( + "start build grpc stream, src instance is {}, dst instance is {}, token is empty: {}, ak is: {}, sk is empty: " + "{}", + srcInstance, dstInstance, token.Empty(), ak, sk.Empty()); + if (!token.Empty()) { + context->AddMetadata(TOKEN_META, std::string(token.GetData(), token.GetSize())); + } + if (!ak.empty() && !sk.Empty()) { + context->AddMetadata(TENANT_ACCESS_KEY, ak); + auto signature = SignStreamingMessage(ak, sk); + context->AddMetadata(SIGNATURE, signature); + auto timestamp = GetCurrentUTCTime(); + context->AddMetadata(TIMESTAMP, timestamp); + } context->AddMetadata(SOURCE_ID_META, srcInstance); context->AddMetadata(DST_ID_META, dstInstance); ErrorInfo err; for (int32_t retry = 0; retry < retryTimes; ++retry) { if (channel == nullptr) { - auto ret = clientsMgr->NewFsConn(ip, port, security); + auto ret = clientsMgr->NewFsConn(ip, port, security, dstInstance); if (!ret.second.OK()) { err = ret.second; YRLOG_ERROR("get new fs connection err, ip is {}, port is {}, err code is {}, err msg is {}", ip, port, - err.Code(), err.Msg()); + fmt::underlying(err.Code()), err.Msg()); continue; } channel = ret.first; @@ -291,7 +316,7 @@ ErrorInfo FSIntfGrpcClientReaderWriter::BuildStreamWithRetry(std::shared_ptrReleaseFsConn(ip, port); + clientsMgr->ReleaseFsConn(ip, port, dstInstance); } if (!err.OK()) { isConnect_.store(false); @@ -299,4 +324,25 @@ ErrorInfo FSIntfGrpcClientReaderWriter::BuildStreamWithRetry(std::shared_ptrGetFsConn(ip, port, dstInstance); + if (!res.first) { + YRLOG_WARN("there is no channel of dstInstance:{}, address is {}:{}, return false directly", dstInstance, ip, + port); + return false; + } + clientsMgr->ReleaseFsConn(ip, port, dstInstance); + if (res.first->GetState(false) != GRPC_CHANNEL_READY) { + YRLOG_WARN("channel of dstInstance:{}, address is {}:{}, cheeck health failed, return false directly", + dstInstance, ip, port); + return false; + } + return true; +} } // namespace YR::Libruntime diff --git a/src/libruntime/fsclient/grpc/fs_intf_grpc_client_reader_writer.h b/src/libruntime/fsclient/grpc/fs_intf_grpc_client_reader_writer.h index c397d35..5aa3a4b 100644 --- a/src/libruntime/fsclient/grpc/fs_intf_grpc_client_reader_writer.h +++ b/src/libruntime/fsclient/grpc/fs_intf_grpc_client_reader_writer.h @@ -30,6 +30,9 @@ public: const std::string INSTANCE_ID_META = "instance_id"; const std::string RUNTIME_ID_META = "runtime_id"; const std::string TOKEN_META = "authorization"; + const std::string TENANT_ACCESS_KEY = "access_key"; + const std::string SIGNATURE = "signature"; + const std::string TIMESTAMP = "timestamp"; const std::string SOURCE_ID_META = "source_id"; const std::string DST_ID_META = "dst_id"; const std::string JOB_ID_META = "job_id"; @@ -60,6 +63,7 @@ public: ErrorInfo BuildStreamWithRetry(std::shared_ptr channel, const int retryTimes = RETRY_TIME); void PreStart() override {} ErrorInfo Start() override; + bool IsHealth() override; ErrorInfo Reconnect(); void Stop() override; bool GrpcRead(const std::shared_ptr &message) override; diff --git a/src/libruntime/fsclient/grpc/fs_intf_grpc_reader_writer.h b/src/libruntime/fsclient/grpc/fs_intf_grpc_reader_writer.h index b88377e..e384b10 100644 --- a/src/libruntime/fsclient/grpc/fs_intf_grpc_reader_writer.h +++ b/src/libruntime/fsclient/grpc/fs_intf_grpc_reader_writer.h @@ -47,6 +47,7 @@ public: virtual void PreStart() = 0; virtual ErrorInfo Start() = 0; void Init(); + virtual bool IsHealth() = 0; void Stop() override; bool Available() override; bool Abnormal() override; diff --git a/src/libruntime/fsclient/grpc/fs_intf_grpc_server_reader_writer.cpp b/src/libruntime/fsclient/grpc/fs_intf_grpc_server_reader_writer.cpp index 779fc03..98754ad 100644 --- a/src/libruntime/fsclient/grpc/fs_intf_grpc_server_reader_writer.cpp +++ b/src/libruntime/fsclient/grpc/fs_intf_grpc_server_reader_writer.cpp @@ -23,14 +23,14 @@ namespace YR { namespace Libruntime { -void FSIntfGrpcServerReaderWriter::PreStart() +bool FSIntfGrpcServerReaderWriter::IsBatched() { - isConnect_.store(true); + return batchStream_ != nullptr; } -bool FSIntfGrpcServerReaderWriter::IsBatched() +void FSIntfGrpcServerReaderWriter::PreStart() { - return batchStream_ != nullptr; + isConnect_.store(true); } ErrorInfo FSIntfGrpcServerReaderWriter::Start() @@ -92,6 +92,11 @@ bool FSIntfGrpcServerReaderWriter::GrpcBatchWrite(const std::shared_ptrWrite(*request.get()); } +bool FSIntfGrpcServerReaderWriter::IsHealth() +{ + return true; +} + void FSIntfGrpcServerReaderWriter::ClearStream() { this->stream_ = nullptr; diff --git a/src/libruntime/fsclient/grpc/fs_intf_grpc_server_reader_writer.h b/src/libruntime/fsclient/grpc/fs_intf_grpc_server_reader_writer.h index 7fb59b7..52afaa8 100644 --- a/src/libruntime/fsclient/grpc/fs_intf_grpc_server_reader_writer.h +++ b/src/libruntime/fsclient/grpc/fs_intf_grpc_server_reader_writer.h @@ -53,6 +53,7 @@ public: bool GrpcBatchRead(const std::shared_ptr &message) override; bool GrpcBatchWrite(const std::shared_ptr &request) override; bool IsBatched() override; + bool IsHealth() override; private: void ClearStream(); diff --git a/src/libruntime/fsclient/grpc/grpc_posix_service.cpp b/src/libruntime/fsclient/grpc/grpc_posix_service.cpp index 9bb6352..f80f3e6 100644 --- a/src/libruntime/fsclient/grpc/grpc_posix_service.cpp +++ b/src/libruntime/fsclient/grpc/grpc_posix_service.cpp @@ -15,14 +15,28 @@ */ #include "grpc_posix_service.h" +#include +#include "src/libruntime/utils/hash_utils.h" #include "src/libruntime/utils/utils.h" #include "fs_intf_grpc_server_reader_writer.h" +#include "src/libruntime/fsclient/grpc/posix_auth_interceptor.h" namespace YR { namespace Libruntime { +const double TIMESTAMP_EXPIRE_DURATION_SECONDS = 60; +using SensitiveValue = datasystem::SensitiveValue; const std::string FUNCTION_PROXY = "function-proxy"; +struct PosixMetaData { + std::string instanceID; + std::string runtimeID; + std::string token; + std::string tenantAccessKey; + std::string timestamp; + std::string signature; +}; + bool GrpcPosixService::CompareInstanceID(grpc::ServerContext *context) const { const std::multimap metadata = context->client_metadata(); @@ -64,12 +78,25 @@ ErrorInfo GrpcPosixService::Start() builder.SetMaxSendMessageSize(maxGrpcSize); builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0); builder.SetDefaultCompressionLevel(GRPC_COMPRESS_LEVEL_NONE); + if (security_ != nullptr && security_->IsFsAuthEnable()) { + std::vector> interceptorCreators; + auto interceptorCreator = new PosixServerAuthInterceptorFactory(); + interceptorCreator->RegisterSecurity(this->security_); + // Only StreamingMessage support AKSK. + interceptorCreators.push_back(std::unique_ptr(interceptorCreator)); + builder.experimental().SetInterceptorCreators(std::move(interceptorCreators)); + } + if (security_ != nullptr) { + YRLOG_INFO("start grpc service with auth, fs auth enable: {}", security_->IsFsAuthEnable()); + } server = builder.BuildAndStart(); if (server == nullptr) { YRLOG_ERROR("Failed to start grpc server, errno: {}, listeningIpAddr: {}, selfPort: {}, listeningPort: {}", errno, listeningIpAddr, selfPort, listeningPort); return ErrorInfo(ErrorCode::ERR_INIT_CONNECTION_FAILED, ModuleCode::RUNTIME, "failed to start grpc server"); } + YRLOG_INFO("successful start grpc server, listeningIpAddr: {}, selfPort: {}, listeningPort: {}", listeningIpAddr, + selfPort, listeningPort); return {}; } @@ -100,15 +127,50 @@ void GrpcPosixService::Stop() YRLOG_INFO("service of {}. listening({}:{}) is stopped", instanceID, listeningIpAddr, listeningPort); } +bool GrpcPosixService::VerifySrcWihtAkSk(const std::multimap &metadata) +{ + PosixMetaData data; + for (const auto &it : metadata) { + auto key = std::string(it.first.data(), it.first.length()); + if (key == ACCESS_KEY) { + data.tenantAccessKey = std::string(it.second.data(), it.second.length()); + } + if (key == TIMESTAMP) { + data.timestamp = std::string(it.second.data(), it.second.length()); + } + if (key == SIGNATURE) { + data.signature = std::string(it.second.data(), it.second.length()); + } + } + auto currentTimeStamp = GetCurrentUTCTime(); + if (IsLaterThan(currentTimeStamp, data.timestamp, TIMESTAMP_EXPIRE_DURATION_SECONDS)) { + YRLOG_ERROR("failed to verify timestamp, difference is more than 1 min {} vs {}", currentTimeStamp, + data.timestamp); + return false; + } + std::string ak; + SensitiveValue sk; + security_->GetAKSK(ak, sk); + std::string signKey = ak + ":" + data.timestamp; + if (GetHMACSha256(sk, signKey) != data.signature) { + YRLOG_ERROR("failed to verify timestamp, signature isn't the same, ak is: {}, sk is empty: {}", ak, sk.Empty()); + return false; + } + return true; +} + grpc::Status GrpcPosixService::MessageStream(grpc::ServerContext *context, grpc::ServerReaderWriter *stream) { // guard for this to avoid deconstructed [[maybe_unused]] auto raii = shared_from_this(); if (stopped) { - return grpc::Status(grpc::StatusCode::UNAVAILABLE, "service was already closed"); + return grpc::Status(grpc::StatusCode::UNAVAILABLE, "MessageStream was already closed"); } const std::multimap metadata = context->client_metadata(); + if (this->security_->IsFsAuthEnable() && !VerifySrcWihtAkSk(metadata)) { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "MessageStream: the ak sk is not correct"); + } auto iter = metadata.find("source_id"); bool isDirect = iter != metadata.end(); if (!isDirect) { @@ -155,9 +217,12 @@ grpc::Status GrpcPosixService::BatchMessageStream( // guard for this to avoid deconstructed [[maybe_unused]] auto raii = shared_from_this(); if (stopped) { - return grpc::Status(grpc::StatusCode::UNAVAILABLE, "service was already closed"); + return grpc::Status(grpc::StatusCode::UNAVAILABLE, "BatchMessageStream was already closed"); } const std::multimap metadata = context->client_metadata(); + if (this->security_->IsFsAuthEnable() && !VerifySrcWihtAkSk(metadata)) { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "BatchMessageStream: the ak sk is not correct"); + } auto iter = metadata.find("source_id"); bool isDirect = iter != metadata.end(); if (!isDirect) { diff --git a/src/libruntime/fsclient/grpc/grpc_posix_service.h b/src/libruntime/fsclient/grpc/grpc_posix_service.h index b97b1a5..29765f2 100644 --- a/src/libruntime/fsclient/grpc/grpc_posix_service.h +++ b/src/libruntime/fsclient/grpc/grpc_posix_service.h @@ -105,6 +105,7 @@ private: int disconnectedTimeout); void StartDisconnectTimer(const std::string &remote, int disconnectedTimeout); void StopDisconnectTimer(const std::string &remote); + bool VerifySrcWihtAkSk(const std::multimap &metadata); std::string instanceID; std::string runtimeID; diff --git a/src/libruntime/fsclient/grpc/posix_auth_interceptor.cpp b/src/libruntime/fsclient/grpc/posix_auth_interceptor.cpp index e69de29..f7aef84 100644 --- a/src/libruntime/fsclient/grpc/posix_auth_interceptor.cpp +++ b/src/libruntime/fsclient/grpc/posix_auth_interceptor.cpp @@ -0,0 +1,183 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/libruntime/fsclient/grpc/posix_auth_interceptor.h" +#include "src/libruntime/utils/grpc_utils.h" +#include "src/utility/logger/logger.h" +namespace YR { +namespace Libruntime { + +void PosixAuthInterceptor::InterceptCommon(::grpc::experimental::InterceptorBatchMethods *methods) +{ + if (stopped) { + return; + } + if (methods->QueryInterceptionHookPoint(::grpc::experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + const auto *message = dynamic_cast( + static_cast(methods->GetSendMessage())); + if (message == nullptr) { + methods->Proceed(); + return; + } + StreamingMessage signedMessage; + if (message->has_heartbeatrsp()) { + methods->Proceed(); + return; + } + signedMessage.CopyFrom(*message); + if (SignWithAKSK(signedMessage)) { + methods->ModifySendMessage(&signedMessage); + methods->Proceed(); + } else { + YRLOG_ERROR("failed to sign message: {}, instance: {}, runtime: {}", message->DebugString(), instanceID_, + runtimeID_); + methods->FailHijackedSendMessage(); + } + return; + } + + if (methods->QueryInterceptionHookPoint(::grpc::experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + auto *message = + dynamic_cast(static_cast<::google::protobuf::Message *>(methods->GetRecvMessage())); + if (message == nullptr) { + methods->Proceed(); + return; + } + if (message->has_heartbeatreq()) { + methods->Proceed(); + return; + } + if (!VerifyAKSK(*message)) { + YRLOG_ERROR("failed to verify message: {}, instance: {}, runtime: {}", message->DebugString(), instanceID_, + runtimeID_); + // clear message, if return directly, the server will be blocked + message->Clear(); + } + } + methods->Proceed(); +} + +void PosixClientAuthInterceptor::Intercept(::grpc::experimental::InterceptorBatchMethods *clientMethods) +{ + if (clientMethods->QueryInterceptionHookPoint( + ::grpc::experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto metadata = clientMethods->GetSendInitialMetadata(); + if (metadata == nullptr) { + clientMethods->Proceed(); + return; + } + + for (const auto &metaIter : *metadata) { + auto clientKey = std::string(metaIter.first.data(), metaIter.first.length()); + if (clientKey == INSTANCE_ID) { + instanceID_ = std::string(metaIter.second.data(), metaIter.second.length()); + } + if (clientKey == RUNTIME_ID) { + runtimeID_ = std::string(metaIter.second.data(), metaIter.second.length()); + } + if (clientKey == ACCESS_KEY) { + tenantAccessKey_ = std::string(metaIter.second.data(), metaIter.second.length()); + } + } + + clientMethods->Proceed(); + return; + } + InterceptCommon(clientMethods); +} + +void PosixServerAuthInterceptor::Intercept(::grpc::experimental::InterceptorBatchMethods *methods) +{ + if (methods->QueryInterceptionHookPoint(::grpc::experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto metadata = methods->GetSendInitialMetadata(); + if (metadata == nullptr) { + methods->Proceed(); + return; + } + + for (const auto &metaIte : *metadata) { + auto key = std::string(metaIte.first.data(), metaIte.first.length()); + if (key == INSTANCE_ID) { + instanceID_ = std::string(metaIte.second.data(), metaIte.second.length()); + } + if (key == RUNTIME_ID) { + runtimeID_ = std::string(metaIte.second.data(), metaIte.second.length()); + } + if (key == ACCESS_KEY) { + tenantAccessKey_ = std::string(metaIte.second.data(), metaIte.second.length()); + } + } + + methods->Proceed(); + return; + } + InterceptCommon(methods); +} + +bool PosixAuthInterceptor::VerifyAKSK(const StreamingMessage &message) +{ + if (security_ != nullptr && !security_->IsFsAuthEnable()) { + return true; + } + + auto tenantAccessKey = message.metadata().find(ACCESS_KEY); + if (tenantAccessKey == message.metadata().end() || tenantAccessKey->second.empty()) { + YRLOG_ERROR("failed to verify message: {}, failed to find access_key in meta-data, instance: {}, runtime: {}", + message.DebugString(), instanceID_, runtimeID_); + return false; + } + + std::string ak; + SensitiveValue sk; + security_->GetAKSK(ak, sk); + + if (ak.empty() || sk.Empty()) { + YRLOG_ERROR( + "failed to verify message {}, ak or sk is emptgy, failed to get cred from security, instance {}, " + "runtime {}", + message.DebugString(), instanceID_, runtimeID_); + return false; + } + + return VerifyStreamingMessage(ak, sk, message); +} + +bool PosixAuthInterceptor::SignWithAKSK(StreamingMessage &message) +{ + if (security_ != nullptr && !security_->IsFsAuthEnable()) { + YRLOG_WARN("is fs auth enable: {}, no need sign with ak sk", security_->IsFsAuthEnable()); + return true; + } + + std::string ak; + SensitiveValue sk; + security_->GetAKSK(ak, sk); + + if (ak.empty() || sk.Empty()) { + YRLOG_ERROR("failed to sign message: {}, failed to get cred from iam, instance: {}, runtime: {}", + message.DebugString(), instanceID_, runtimeID_); + return false; + } + + if (!SignStreamingMessage(ak, sk, message)) { + YRLOG_ERROR("failed to sign message: {}, instance: {}, runtime: {}", message.DebugString(), instanceID_, + runtimeID_); + return false; + } + return true; +} +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/fsclient/grpc/posix_auth_interceptor.h b/src/libruntime/fsclient/grpc/posix_auth_interceptor.h new file mode 100644 index 0000000..6badf38 --- /dev/null +++ b/src/libruntime/fsclient/grpc/posix_auth_interceptor.h @@ -0,0 +1,129 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "src/libruntime/fsclient/fs_intf_manager.h" +#include "src/libruntime/fsclient/protobuf/runtime_rpc.grpc.pb.h" +#include "src/libruntime/utils/security.h" + +namespace YR { +namespace Libruntime { +const std::string ACCESS_KEY = "access_key"; +const std::string TIMESTAMP = "timestamp"; +const std::string SIGNATURE = "signature"; +const std::string INSTANCE_ID = "instance_id"; +const std::string RUNTIME_ID = "runtime_id"; +using StreamingMessage = ::runtime_rpc::StreamingMessage; + +class PosixAuthInterceptor : public ::grpc::experimental::Interceptor { +public: + PosixAuthInterceptor() = default; + virtual ~PosixAuthInterceptor() + { + stopped.store(true); + } + + bool VerifyAKSK(const StreamingMessage &message); + + bool SignWithAKSK(StreamingMessage &message); + + void InterceptCommon(::grpc::experimental::InterceptorBatchMethods *methods); + + void RegisterSecurity(std::shared_ptr security) + { + this->security_ = security; + } + +protected: + std::string runtimeID_; + std::string instanceID_; + std::string tenantAccessKey_; + std::shared_ptr security_; + std::atomic stopped{false}; +}; + +class PosixClientAuthInterceptor : public PosixAuthInterceptor { +public: + explicit PosixClientAuthInterceptor(::grpc::experimental::ClientRpcInfo *info) : info_(info) {} + ~PosixClientAuthInterceptor() override + { + stopped.store(true); + } + void Intercept(::grpc::experimental::InterceptorBatchMethods *methods) override; + +private: + ::grpc::experimental::ClientRpcInfo *info_; +}; + +class PosixServerAuthInterceptor : public PosixAuthInterceptor { +public: + explicit PosixServerAuthInterceptor(::grpc::experimental::ServerRpcInfo *info) : info_(info) {} + ~PosixServerAuthInterceptor() override + { + stopped.store(true); + } + void Intercept(::grpc::experimental::InterceptorBatchMethods *methods) override; + +private: + ::grpc::experimental::ServerRpcInfo *info_; +}; + +class PosixClientAuthInterceptorFactory : public ::grpc::experimental::ClientInterceptorFactoryInterface { +public: + PosixClientAuthInterceptorFactory() = default; + + ~PosixClientAuthInterceptorFactory() override = default; + + void RegisterSecurity(std::shared_ptr security) + { + this->security_ = security; + } + + ::grpc::experimental::Interceptor *CreateClientInterceptor(::grpc::experimental::ClientRpcInfo *info) override + { + auto interceptor = new PosixClientAuthInterceptor(info); + interceptor->RegisterSecurity(security_); + return interceptor; + } + +private: + std::shared_ptr security_; +}; + +class PosixServerAuthInterceptorFactory : public ::grpc::experimental::ServerInterceptorFactoryInterface { +public: + PosixServerAuthInterceptorFactory() = default; + + ~PosixServerAuthInterceptorFactory() override = default; + + void RegisterSecurity(std::shared_ptr security) + { + this->security_ = security; + } + + ::grpc::experimental::Interceptor *CreateServerInterceptor(::grpc::experimental::ServerRpcInfo *info) override + { + auto interceptor = new PosixServerAuthInterceptor(info); + interceptor->RegisterSecurity(security_); + return interceptor; + } + +private: + std::shared_ptr security_; +}; + +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/fsclient/protobuf/bus_service.proto b/src/libruntime/fsclient/protobuf/bus_service.proto index 0f42f4f..8394d5c 100644 --- a/src/libruntime/fsclient/protobuf/bus_service.proto +++ b/src/libruntime/fsclient/protobuf/bus_service.proto @@ -24,6 +24,10 @@ option go_package = "grpc/pb/bus;bus"; // bus service provides APIs to runtime, service BusService { + // notify bus to connect frontend + rpc DiscoverFrontend (DiscoverFrontendRequest) returns (DiscoverFrontendResponse) {} + // query instance info from frontend + rpc QueryInstance (QueryInstanceRequest) returns (QueryInstanceResponse) {} // notify bus to connect driver rpc DiscoverDriver (DiscoverDriverRequest) returns (DiscoverDriverResponse) {} } @@ -43,6 +47,13 @@ message DiscoverDriverResponse { string hostIp = 4; } +message DiscoverFrontendRequest { + string frontendIP = 1; + string frontendPort = 2; +} + +message DiscoverFrontendResponse {} + message QueryInstanceRequest { string instanceID = 1; } diff --git a/src/libruntime/fsclient/protobuf/common.proto b/src/libruntime/fsclient/protobuf/common.proto index 73922a8..069955d 100644 --- a/src/libruntime/fsclient/protobuf/common.proto +++ b/src/libruntime/fsclient/protobuf/common.proto @@ -140,6 +140,7 @@ enum ErrorCode { ERR_RUNTIME_MANAGER_OPERATION_ERROR = 3012; ERR_INSTANCE_MANAGER_OPERATION_ERROR= 3013; ERR_LOCAL_SCHEDULER_ABNORMAL = 3014; + ERR_NPU_FAULT_ERROR = 3016; } enum HealthCheckCode { @@ -168,6 +169,13 @@ message StackTraceElement { map extensions = 5; // extensions for different language } +message TenantCredentials { + bytes accessKey = 1; + bytes secretKey = 2; + bytes dataKey = 3; + bool isCredential = 4; +} + message TLSConfig { bool dsAuthEnable = 1; bool dsEncryptEnable = 2; @@ -185,6 +193,7 @@ message TLSConfig { string salt = 14; string accessKey = 15; // component-level access key string securityKey = 16; // component-level security key + TenantCredentials tenantCredentials = 17; } message HeteroDeviceInfo { diff --git a/src/libruntime/fsclient/protobuf/core_service.proto b/src/libruntime/fsclient/protobuf/core_service.proto index 437667e..5ecfb44 100644 --- a/src/libruntime/fsclient/protobuf/core_service.proto +++ b/src/libruntime/fsclient/protobuf/core_service.proto @@ -23,57 +23,57 @@ import "src/libruntime/fsclient/protobuf/common.proto"; option go_package = "grpc/pb/core;core"; enum AffinityType { - PreferredAffinity = 0; + PreferredAffinity = 0; PreferredAntiAffinity = 1; - RequiredAffinity = 2; - RequiredAntiAffinity = 3; + RequiredAffinity = 2; + RequiredAntiAffinity = 3; } message SchedulingOptions { - int32 priority = 1; + int32 priority = 1; map resources = 2; map extension = 3; // will deprecate in future - map affinity = 4; - common.Affinity scheduleAffinity = 5; - InstanceRange range = 6; - int64 scheduleTimeoutMs = 7; - bool preemptedAllowed = 8; + map affinity = 4; + common.Affinity scheduleAffinity = 5; + InstanceRange range = 6; + int64 scheduleTimeoutMs = 7; + bool preemptedAllowed = 8; // indicated which rgroup submit to string rGroupName = 9; } message InstanceRange { - int32 min = 1; - int32 max = 2; + int32 min = 1; + int32 max = 2; int32 step = 3; } message CreateRequest { - string function = 1; - repeated common.Arg args = 2; + string function = 1; + repeated common.Arg args = 2; SchedulingOptions schedulingOps = 3; - string requestID = 4; - string traceID = 5; - repeated string labels = 6; // "key:value" or "key2" + string requestID = 4; + string traceID = 5; + repeated string labels = 6; // "key:value" or "key2" // optional. if designated instanceID is not empty, the created instance id will be assigned designatedInstanceID string designatedInstanceID = 7; map createOptions = 8; } message CreateResponse { - common.ErrorCode code = 1; - string message = 2; - string instanceID = 3; + common.ErrorCode code = 1; + string message = 2; + string instanceID = 3; } message CreateResponses { - common.ErrorCode code = 1; - string message = 2; + common.ErrorCode code = 1; + string message = 2; repeated string instanceIDs = 3; // used for life cycle management and the unique ID of the corresponding group. // when you want to recycle a group, use signal 4 to send a kill request for the ID. - string groupID = 4; + string groupID = 4; } message CreateRequests { @@ -86,18 +86,18 @@ message CreateRequests { message CreateResourceGroupRequest { common.ResourceGroupSpec rGroupSpec = 1; - string requestID = 2; - string traceID = 3; + string requestID = 2; + string traceID = 3; } message CreateResourceGroupResponse { - common.ErrorCode code = 1; - string message = 2; + common.ErrorCode code = 1; + string message = 2; } message GroupOptions { // group schedule timeout (sec) - int64 timeout = 1; + int64 timeout = 1; // group alias name, this field cannot be used for life cycle management. string groupName = 2; bool sameRunningLifecycle = 3; @@ -107,39 +107,39 @@ message GroupOptions { } message InvokeOptions { - map customTag = 1; + map customTag = 1; } message InvokeRequest { - string function = 1; - repeated common.Arg args = 2; - string instanceID = 3; - string requestID = 4; - string traceID = 5; - repeated string returnObjectIDs = 6; - string spanID = 7; - InvokeOptions invokeOptions = 8; + string function = 1; + repeated common.Arg args = 2; + string instanceID = 3; + string requestID = 4; + string traceID = 5; + repeated string returnObjectIDs = 6; + string spanID = 7; + InvokeOptions invokeOptions = 8; } message InvokeResponse { - common.ErrorCode code = 1; - string message = 2; + common.ErrorCode code = 1; + string message = 2; string returnObjectID = 3; } message CallResult { - common.ErrorCode code = 1; - string message = 2; - string instanceID = 3; - string requestID = 4; - repeated common.SmallObject smallObjects = 5; + common.ErrorCode code = 1; + string message = 2; + string instanceID = 3; + string requestID = 4; + repeated common.SmallObject smallObjects = 5; repeated common.StackTraceInfo stackTraceInfos = 6; - common.RuntimeInfo runtimeInfo = 7; + common.RuntimeInfo runtimeInfo = 7; } message CallResultAck { - common.ErrorCode code = 1; - string message = 2; + common.ErrorCode code = 1; + string message = 2; } message TerminateRequest { @@ -147,38 +147,47 @@ message TerminateRequest { } message TerminateResponse { - common.ErrorCode code = 1; + common.ErrorCode code = 1; string message = 2; } -message ExitRequest {} +message ExitRequest { + common.ErrorCode code = 1; + string message = 2; +} -message ExitResponse {} +message ExitResponse { + common.ErrorCode code = 1; + string message = 2; +} message StateSaveRequest { - bytes state = 1; + bytes state = 1; + string requestID = 2; } message StateSaveResponse { - common.ErrorCode code = 1; - string message = 2; + common.ErrorCode code = 1; + string message = 2; string checkpointID = 3; } message StateLoadRequest { string checkpointID = 1; + string requestID = 2; } message StateLoadResponse { - common.ErrorCode code = 1; + common.ErrorCode code = 1; string message = 2; - bytes state = 3; + bytes state = 3; } message KillRequest { string instanceID = 1; - int32 signal = 2; - bytes payload = 3; + int32 signal = 2; + bytes payload = 3; + string requestID = 4; } message InstanceTermination { @@ -193,15 +202,15 @@ message FunctionMasterEvent { message SubscriptionPayload { oneof Content { - InstanceTermination instanceTermination = 1; // Subscribe to instance termination event - FunctionMasterObserve functionMaster = 2; // Subscribe to function-master election changed + InstanceTermination instanceTermination = 1; // Subscribe to instance termination event + FunctionMasterObserve functionMaster = 2; // Subscribe to function-master election changed } } message UnsubscriptionPayload { oneof Content { - InstanceTermination instanceTermination = 1; // Unsubscribe specified instance's termination event - FunctionMasterObserve functionMaster = 2; // UnSubscribe to function-master election changed + InstanceTermination instanceTermination = 1; // Unsubscribe specified instance's termination event + FunctionMasterObserve functionMaster = 2; // UnSubscribe to function-master election changed } } @@ -213,6 +222,6 @@ message NotificationPayload { } message KillResponse { - common.ErrorCode code = 1; + common.ErrorCode code = 1; string message = 2; } \ No newline at end of file diff --git a/src/libruntime/generator/generator_id_map.h b/src/libruntime/generator/generator_id_map.h new file mode 100644 index 0000000..d4e447d --- /dev/null +++ b/src/libruntime/generator/generator_id_map.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +#include "absl/synchronization/mutex.h" + +namespace YR { +namespace Libruntime { +class GeneratorIdMap { +public: + void AddRecord(const std::string &key, const std::string &value) + { + absl::MutexLock lock(&mu_); + records_.emplace(key, value); + } + + void RemoveRecord(const std::string &key) + { + absl::MutexLock lock(&mu_); + records_.erase(key); + } + + void GetRecord(const std::string &key, std::string &value) + { + absl::MutexLock lock(&mu_); + if (records_.find(key) != records_.end()) { + value = records_[key]; + } + } + + void UpdateRecord(const std::string &key, const std::string &value) + { + absl::MutexLock lock(&mu_); + records_[key] = value; + } + + void GetRecordKeys(std::vector &keys) + { + absl::MutexLock lock(&mu_); + for (auto &record : records_) { + keys.push_back(record.first); + } + } + +private: + absl::Mutex mu_; + std::unordered_map records_ ABSL_GUARDED_BY(mu_); +}; + +class GeneratorIdRecorder { +public: + GeneratorIdRecorder(const std::string &genId, const std::string &srcRuntimeId, std::shared_ptr map) + : genId_(genId), map_(map) + { + if (map_ && !genId_.empty()) { + map_->AddRecord(genId_, srcRuntimeId); + } + } + + virtual ~GeneratorIdRecorder(void) {} + +private: + const std::string genId_; + std::shared_ptr map_; +}; +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/generator/generator_notifier.h b/src/libruntime/generator/generator_notifier.h new file mode 100644 index 0000000..d6c71f3 --- /dev/null +++ b/src/libruntime/generator/generator_notifier.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include "src/dto/data_object.h" +#include "src/libruntime/err_type.h" +#include "src/utility/timer_worker.h" + +namespace YR { +namespace Libruntime { +class GeneratorNotifier { +public: + GeneratorNotifier() = default; + + virtual ~GeneratorNotifier() = default; + + virtual ErrorInfo NotifyResult(const std::string &generatorId, int index, std::shared_ptr resultObj, + const ErrorInfo &resultErr) = 0; + + virtual ErrorInfo NotifyFinished(const std::string &generatorId, int numResults) = 0; + virtual void Initialize(void) = 0; + virtual void Stop(void) = 0; +}; +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/generator/generator_receiver.h b/src/libruntime/generator/generator_receiver.h new file mode 100644 index 0000000..5b43c42 --- /dev/null +++ b/src/libruntime/generator/generator_receiver.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/libruntime/err_type.h" +#include "src/libruntime/utils/utils.h" +#include "src/utility/timer_worker.h" + +namespace YR { +namespace Libruntime { +class GeneratorReceiver { +public: + virtual void Initialize(void) = 0; + virtual void Stop(void) = 0; + virtual void MarkEndOfStream(const std::string &genId, const ErrorInfo &errInfo) = 0; + virtual void AddRecord(const std::string &genId) = 0; +}; +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/generator/stream_generator_notifier.cpp b/src/libruntime/generator/stream_generator_notifier.cpp new file mode 100644 index 0000000..55ffb98 --- /dev/null +++ b/src/libruntime/generator/stream_generator_notifier.cpp @@ -0,0 +1,359 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "stream_generator_notifier.h" + +namespace YR { +namespace Libruntime { + +void StreamGeneratorNotifier::Initialize(void) +{ + std::call_once(flag_, std::bind(&StreamGeneratorNotifier::DoInitializeOnce, this)); +} + +StreamGeneratorNotifier::~StreamGeneratorNotifier(void) +{ + Stop(); + if (timer_ != nullptr) { + timer_->cancel(); + } + if (notifyThread_.joinable()) { + notifyThread_.join(); + } +} + +void StreamGeneratorNotifier::Stop(void) +{ + stopped = true; +} + +void StreamGeneratorNotifier::DoInitializeOnce(void) +{ + notifyThread_ = std::thread(&StreamGeneratorNotifier::Notify, this); + timer_ = timerWorker_->CreateTimer(10000, -1, [this]() { NotifyHeartbeat(); }); // 10s interval +} + +void StreamGeneratorNotifier::Notify() +{ + ErrorInfo err; + while (!stopped) { + { + std::unique_lock lock(mux_); + cv.wait(lock, [this] { return !genQueue_.empty(); }); + } + YRLOG_DEBUG("start process notify info"); + std::vector> datas; + PopBatch(datas); + for (size_t i = 0; i < datas.size(); i++) { + if (datas[i]->isHeartbeat) { + YRLOG_DEBUG("start send heartbeat, generator id: {}", datas[i]->generatorId); + err = NotifyHeartbeatByStream(datas[i]->generatorId); + } else if (datas[i]->finish) { + YRLOG_DEBUG("start send finish {}-{})", datas[i]->generatorId, datas[i]->numberResult); + err = NotifyFinishedByStream(datas[i]->generatorId, datas[i]->numberResult); + } else { + YRLOG_DEBUG("start send result {}-{})", datas[i]->generatorId, datas[i]->index); + err = NotifyResultByStream(datas[i]->generatorId, datas[i]->index, datas[i]->data, datas[i]->err); + } + if (!err.OK()) { + YRLOG_ERROR("failed to notify {} of {}-{}, err code: {}, err message: {}", + datas[i]->isHeartbeat ? "heartbeat" : "result", datas[i]->generatorId, datas[i]->index, + fmt::underlying(err.Code()), err.Msg()); + } + } + } +} + +void StreamGeneratorNotifier::PopBatch(std::vector> &datas) +{ + datas.clear(); + std::lock_guard lk(mux_); + datas.swap(genQueue_); +} + +ErrorInfo StreamGeneratorNotifier::NotifyResult(const std::string &generatorId, int index, + std::shared_ptr resultObj, const ErrorInfo &resultErr) +{ + auto notifyData = std::make_shared(); + notifyData->generatorId = generatorId; + notifyData->index = index; + notifyData->data = resultObj; + notifyData->err = resultErr; + std::lock_guard lk(mux_); + genQueue_.push_back(notifyData); + cv.notify_one(); + return ErrorInfo(); +} + +ErrorInfo StreamGeneratorNotifier::NotifyFinished(const std::string &generatorId, int numResults) +{ + auto notifyData = std::make_shared(); + notifyData->generatorId = generatorId; + notifyData->numberResult = numResults; + notifyData->finish = true; + std::lock_guard lk(mux_); + genQueue_.push_back(notifyData); + cv.notify_one(); + return ErrorInfo(); +} + +ErrorInfo StreamGeneratorNotifier::NotifyHeartbeat() +{ + if (!map_) { + YRLOG_ERROR("null map pointer"); + return ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "null pointer"); + } + + std::vector generatorIds; + map_->GetRecordKeys(generatorIds); + + std::lock_guard lk(mux_); + for (auto &generatorId : generatorIds) { + auto notifyData = std::make_shared(); + notifyData->generatorId = generatorId; + notifyData->isHeartbeat = true; + genQueue_.push_back(notifyData); + } + cv.notify_one(); + return ErrorInfo(); +} + +ErrorInfo StreamGeneratorNotifier::NotifyResultByStream(const std::string &generatorId, int index, + std::shared_ptr resultObj, + const ErrorInfo &resultErr) +{ + std::string topic; + auto err = CheckAndGetTopic(generatorId, topic); + if (!err.OK()) { + return err; + } + std::shared_ptr producer; + err = GetOrCreateProducer(topic, producer); + if (!err.OK()) { + YRLOG_ERROR("failed to get producer when notify result, err code: {}, err message: {}", + fmt::underlying(err.Code()), err.Msg()); + return err; + } + + if (index == 0) { + IncreaseProducerReference(topic); + } + + auto res = BuildGeneratorResult(generatorId, index, resultObj, resultErr); + Element ele(reinterpret_cast(res.data()), res.size()); + err = producer->Send(ele); + if (!err.OK()) { + YRLOG_ERROR("failed to send notify result to stream, err code: {}, err message: {}", + fmt::underlying(err.Code()), err.Msg()); + } else { + err = producer->Flush(); + if (!err.OK()) { + YRLOG_ERROR("failed to flush stream, err code: {}, err message: {}", fmt::underlying(err.Code()), + err.Msg()); + } + } + + if (!resultErr.OK()) { + if (!DecreaseProducerReference(topic)) { + RemoveProducer(topic); + } + map_->RemoveRecord(generatorId); + } + + return err; +} + +ErrorInfo StreamGeneratorNotifier::NotifyFinishedByStream(const std::string &generatorId, int numResults) +{ + std::string topic; + auto err = CheckAndGetTopic(generatorId, topic); + if (!err.OK()) { + return err; + } + std::shared_ptr producer; + err = GetOrCreateProducer(topic, producer); + if (!err.OK()) { + YRLOG_ERROR("failed to get producer when notify finished, err code: {}, err message: {}", + fmt::underlying(err.Code()), err.Msg()); + return err; + } + + auto res = BuildGeneratorFinished(generatorId, numResults); + Element ele(reinterpret_cast(res.data()), res.size()); + err = producer->Send(ele); + if (!err.OK()) { + YRLOG_ERROR("failed to send notify finished to stream, err code: {}, err message: {}", + fmt::underlying(err.Code()), err.Msg()); + } else { + err = producer->Flush(); + if (!err.OK()) { + YRLOG_ERROR("failed to flush stream, err code: {}, err message: {}", fmt::underlying(err.Code()), + err.Msg()); + } + } + + if (!DecreaseProducerReference(topic)) { + RemoveProducer(topic); + } + map_->RemoveRecord(generatorId); + return err; +} + +ErrorInfo StreamGeneratorNotifier::NotifyHeartbeatByStream(const std::string &generatorId) +{ + std::string topic; + auto err = CheckAndGetTopic(generatorId, topic); + if (!err.OK()) { + return err; + } + std::shared_ptr producer; + err = GetOrCreateProducer(topic, producer); + if (!err.OK()) { + YRLOG_ERROR("failed to get producer when notify heartbeat, err code: {}, err message: {}", + fmt::underlying(err.Code()), err.Msg()); + return err; + } + IncreaseProducerReference(topic); + + auto res = BuildGeneratorHeartbeat(generatorId); + Element ele(reinterpret_cast(res.data()), res.size()); + err = producer->Send(ele); + if (!err.OK()) { + YRLOG_ERROR("failed to send notify heartbeat to stream, err code: {}, err message: {}", + fmt::underlying(err.Code()), err.Msg()); + } else { + err = producer->Flush(); + if (!err.OK()) { + YRLOG_ERROR("failed to flush stream, err code: {}, err message: {}", fmt::underlying(err.Code()), + err.Msg()); + } + } + + if (!DecreaseProducerReference(topic)) { + RemoveProducer(topic); + } + return err; +} + +std::string StreamGeneratorNotifier::BuildGeneratorResult(const std::string &generatorId, int index, + std::shared_ptr resultObj, + const ErrorInfo &resultErr) +{ + libruntime::NotifyGeneratorResult result; + result.set_genid(generatorId); + result.set_index(index); + result.set_objectid(resultObj->id); + if (resultObj->buffer->IsNative()) { + result.set_data(resultObj->buffer->ImmutableData(), resultObj->buffer->GetSize()); + } + result.set_errorcode(resultErr.Code()); + result.set_errormessage(resultErr.Msg()); + return result.SerializeAsString(); +} + +std::string StreamGeneratorNotifier::BuildGeneratorFinished(const std::string &generatorId, int numResults) +{ + libruntime::NotifyGeneratorResult result; + result.set_genid(generatorId); + result.set_finished(true); + result.set_numresults(numResults); + return result.SerializeAsString(); +} + +std::string StreamGeneratorNotifier::BuildGeneratorHeartbeat(const std::string &generatorId) +{ + libruntime::NotifyGeneratorResult result; + result.set_genid(generatorId); + result.set_isheartbeat(true); + return result.SerializeAsString(); +} + +ErrorInfo StreamGeneratorNotifier::GetOrCreateProducer(const std::string &topic, + std::shared_ptr &producer) +{ + absl::MutexLock lock(&mu_); + if (producers_.find(topic) != producers_.end()) { + producer = producers_[topic]; + return ErrorInfo(); + } + + if (!dsStreamStore_) { + return ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "invalid stream store"); + } + + producer = std::make_shared(); + ProducerConf producerConf; + producerConf.retainForNumConsumers = 1; + // for improving stream speed + producerConf.delayFlushTime = YR::Libruntime::Config::Instance().DS_DELAY_FLUSH_TIME(); + auto err = dsStreamStore_->CreateStreamProducer(topic, producer, producerConf); + if (err.OK()) { + producers_[topic] = producer; + } + YRLOG_DEBUG("success create producer of {}", topic); + return err; +} + +ErrorInfo StreamGeneratorNotifier::RemoveProducer(const std::string &topic) +{ + if (!YR::Libruntime::Config::Instance().ENABLE_CLEAN_STREAM_PRODUCER()) { + YRLOG_DEBUG("Don't remove producer {}", topic); + return ErrorInfo(); + } + absl::MutexLock lock(&mu_); + if (producers_.find(topic) != producers_.end()) { + producers_.erase(topic); + } + return ErrorInfo(); +} + +ErrorInfo StreamGeneratorNotifier::CheckAndGetTopic(const std::string &generatorId, std::string &topic) +{ + if (!map_) { + YRLOG_ERROR("null map pointer"); + return ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "null pointer"); + } + map_->GetRecord(generatorId, topic); + if (topic.empty()) { + YRLOG_ERROR("topic not found, generator id: {}", generatorId); + return ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "invalid topic"); + } + return ErrorInfo(); +} + +void StreamGeneratorNotifier::IncreaseProducerReference(const std::string &topic) +{ + if (producerReferences_.find(topic) != producerReferences_.end()) { + producerReferences_[topic]++; + return; + } + producerReferences_[topic] = 1; +} + +bool StreamGeneratorNotifier::DecreaseProducerReference(const std::string &topic) +{ + if (producerReferences_.find(topic) == producerReferences_.end()) { + return false; + } + producerReferences_[topic]--; + if (producerReferences_[topic] <= 0) { + producerReferences_.erase(topic); + return false; + } + return true; +} +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/generator/stream_generator_notifier.h b/src/libruntime/generator/stream_generator_notifier.h new file mode 100644 index 0000000..04984ab --- /dev/null +++ b/src/libruntime/generator/stream_generator_notifier.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include "generator_id_map.h" +#include "generator_notifier.h" +#include "src/dto/config.h" +#include "src/dto/invoke_options.h" +#include "src/libruntime/streamstore/stream_producer_consumer.h" +#include "src/libruntime/streamstore/stream_store.h" + +namespace YR { +namespace Libruntime { +using namespace YR::utility; + +struct GeneratorNotifyData { + std::string generatorId; + int index = 0; + int numberResult = 0; + std::shared_ptr data; + ErrorInfo err = ErrorInfo(); + bool finish = false; + bool isHeartbeat = false; +}; + +class StreamGeneratorNotifier : public GeneratorNotifier { +public: + StreamGeneratorNotifier(std::shared_ptr store, std::shared_ptr map) + : GeneratorNotifier(), dsStreamStore_(store), map_(map) + { + this->timerWorker_ = std::make_shared(); + } + + virtual ~StreamGeneratorNotifier(); + void Initialize(void) override; + void Stop(void) override; + + virtual ErrorInfo NotifyResult(const std::string &generatorId, int index, std::shared_ptr resultObj, + const ErrorInfo &resultErr) override; + + virtual ErrorInfo NotifyFinished(const std::string &generatorId, int numResults) override; + + ErrorInfo NotifyResultByStream(const std::string &generatorId, int index, std::shared_ptr resultObj, + const ErrorInfo &resultErr); + ErrorInfo NotifyFinishedByStream(const std::string &generatorId, int numResults); + ErrorInfo NotifyHeartbeatByStream(const std::string &generatorId); + +private: + void Notify(); + void DoInitializeOnce(void); + ErrorInfo NotifyHeartbeat(); + std::string BuildGeneratorResult(const std::string &generatorId, int index, std::shared_ptr resultObj, + const ErrorInfo &resultErr); + std::string BuildGeneratorFinished(const std::string &generatorId, int numResults); + std::string BuildGeneratorHeartbeat(const std::string &generatorId); + void PopBatch(std::vector> &datas); + + ErrorInfo GetOrCreateProducer(const std::string &topic, std::shared_ptr &producer); + ErrorInfo RemoveProducer(const std::string &topic); + ErrorInfo CheckAndGetTopic(const std::string &generatorId, std::string &topic); + void IncreaseProducerReference(const std::string &topic); + bool DecreaseProducerReference(const std::string &topic); + + std::unordered_map> producers_; + std::unordered_map producerReferences_; + std::shared_ptr dsStreamStore_; + std::shared_ptr map_; + absl::Mutex mu_; + std::atomic stopped{false}; + std::once_flag flag_; + std::thread notifyThread_; + mutable std::mutex mux_; + std::condition_variable cv; + std::vector> genQueue_; + std::shared_ptr timerWorker_; + std::shared_ptr timer_; +}; +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/generator/stream_generator_receiver.cpp b/src/libruntime/generator/stream_generator_receiver.cpp new file mode 100644 index 0000000..f762d10 --- /dev/null +++ b/src/libruntime/generator/stream_generator_receiver.cpp @@ -0,0 +1,254 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "stream_generator_receiver.h" + +namespace YR { +namespace Libruntime { +const size_t RECEIVE_INTERVAL = 1000; +const uint32_t CONSUMER_TIMEOUT = 5; +const long long HEARTBEAT_TIMEOUT = 30000; +const int HEARTBEAT_TIMER_TIMEOUT = 10000; +void StreamGeneratorReceiver::Initialize(void) +{ + std::call_once(flag_, std::bind(&StreamGeneratorReceiver::DoInitializeOnce, this)); +} + +StreamGeneratorReceiver::~StreamGeneratorReceiver(void) +{ + Stop(); + if (timer_ != nullptr) { + timer_->cancel(); + } + if (receiverThread_.joinable()) { + receiverThread_.join(); + } +} + +void StreamGeneratorReceiver::Stop(void) +{ + stopped = true; + if (consumer_ != nullptr) { + auto err = consumer_->Close(); + if (!err.OK()) { + YRLOG_ERROR("failed to close consumer, err code: {}, err message: {}", fmt::underlying(err.Code()), + err.Msg()); + } + YRLOG_INFO("consummer closed"); + if (!topic_.empty()) { + err = dsStreamStore_->DeleteStream(topic_); + if (!err.OK()) { + YRLOG_ERROR("failed to delete stream {}, err code: {}, err message: {}", topic_, + fmt::underlying(err.Code()), err.Msg()); + } + } + } +} + +void StreamGeneratorReceiver::DoInitializeOnce(void) +{ + receiverThread_ = std::thread(&StreamGeneratorReceiver::Receive, this); + timer_ = timerWorker_->CreateTimer(HEARTBEAT_TIMER_TIMEOUT, -1, [this]() { DetectHeartbeat(); }); +} + +void StreamGeneratorReceiver::DetectHeartbeat(void) +{ + std::vector generatorIds; + map_->GetRecordKeys(generatorIds); + for (auto &genId : generatorIds) { + std::string timestamp; + map_->GetRecord(genId, timestamp); + if (timestamp.empty()) { + YRLOG_WARN("timestamp not found, generator id: {}", genId); + continue; + } + long long now = GetCurrentTimestampMs(); + long long lastTime = std::stoll(timestamp); + if (now - lastTime >= HEARTBEAT_TIMEOUT) { + auto errMsg = + "stream generator receiver detect heartbeat failed, the connection between runtime and ds worker maybe " + "disconnected, generator id: " + genId; + MarkEndOfStream(genId, ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, errMsg)); + map_->RemoveRecord(genId); + } + } +} + +void StreamGeneratorReceiver::Receive(void) +{ + auto topic = libruntimeConfig_->runtimeId; + if (topic == "driver") { + topic = libruntimeConfig_->runtimeId + "_" + libruntimeConfig_->jobId; + } + SubscriptionConfig config; + std::shared_ptr consumer = std::make_shared(); + initErr_ = dsStreamStore_->CreateStreamConsumer(topic, config, consumer, true); + if (!initErr_.OK()) { + YRLOG_ERROR("failed to create stream consumer, err code: {}, err message: {}", fmt::underlying(initErr_.Code()), + initErr_.Msg()); + return; + } + this->consumer_ = consumer; + this->topic_ = topic; + + YRLOG_INFO("begin to receive message from topic: {}", topic); + + while (!stopped) { + std::vector elements; + auto err = consumer->Receive(1, CONSUMER_TIMEOUT, elements); + if (!err.OK()) { + YRLOG_ERROR("failed to receive from consumer, err code: {}, err message: {}", fmt::underlying(err.Code()), + err.Msg()); + if (err.Code() == ErrorCode::ERR_CLIENT_ALREADY_CLOSED) { + std::this_thread::sleep_for(std::chrono::milliseconds(RECEIVE_INTERVAL)); + } + continue; + } + for (auto ele : elements) { + libruntime::NotifyGeneratorResult result; + if (!result.ParseFromArray(ele.ptr, ele.size)) { + YRLOG_ERROR("failed to parse element, topic: {}, id: {}", topic, ele.id); + continue; + } + + if (result.finished()) { + HandleGeneratorFinished(result); + } else if (result.isheartbeat()) { + HandleGeneratorHeartbeat(result.genid()); + } else { + HandleGeneratorResult(result); + } + } + }; + YRLOG_INFO("finishe to receive message from topic: {}", topic); +} + +void StreamGeneratorReceiver::HandleGeneratorFinished(const libruntime::NotifyGeneratorResult &result) +{ + absl::MutexLock lock(&mu_); + auto &genId = result.genid(); + map_->RemoveRecord(genId); + if (generatorResultsCounter_.find(genId) == generatorResultsCounter_.end()) { + generatorResultsCounter_[genId] = 0; + } + auto genObjectID = memoryStore_->GenerateObjectId(genId, result.numresults()); + YRLOG_DEBUG("received finished {}", genObjectID); + SetError(genId, genObjectID, result.numresults(), + ErrorInfo(ErrorCode::ERR_GENERATOR_FINISHED, ModuleCode::RUNTIME, "")); + if (generatorResultsCounter_[genId] == result.numresults()) { + memoryStore_->GeneratorFinished(genId); + ClearCountersByGenId(genId); + } else { + YRLOG_INFO("stream message maybe unordered, gen id: {}, num of results: {}, counter of results: {}", genId, + result.numresults(), generatorResultsCounter_[genId]); + numGeneratorResults_[genId] = result.numresults(); + } +} + +void StreamGeneratorReceiver::HandleGeneratorResult(const libruntime::NotifyGeneratorResult &result) +{ + absl::MutexLock lock(&mu_); + auto &genId = result.genid(); + if (generatorResultsCounter_.find(genId) == generatorResultsCounter_.end()) { + generatorResultsCounter_[genId] = 0; + } + generatorResultsCounter_[genId] += 1; + auto genObjectID = memoryStore_->GenerateObjectId(genId, result.index()); + YRLOG_DEBUG("received result {}", genObjectID); + if (result.errorcode()) { + map_->RemoveRecord(genId); + SetError(genId, genObjectID, result.index(), + ErrorInfo(static_cast(result.errorcode()), ModuleCode::RUNTIME, result.errormessage())); + memoryStore_->GeneratorFinished(genId); + ClearCountersByGenId(genId); + return; + } + if (!result.data().empty()) { + std::shared_ptr buf = std::make_shared(result.data().size()); + memoryStore_->AddReturnObject(genObjectID); + buf->MemoryCopy(result.data().data(), result.data().size()); + auto err = memoryStore_->Put(buf, genObjectID, {}, false); + if (err.OK()) { + memoryStore_->SetReady(genObjectID); + } else { + memoryStore_->AddOutput(genId, genObjectID, result.index(), err); + YRLOG_ERROR("failed to put stream result, err code: {}, err message: {}", fmt::underlying(err.Code()), + err.Msg()); + return; + } + } else { + YRLOG_ERROR("receive empty result data {}", genId); + } + memoryStore_->AddOutput(genId, genObjectID, result.index()); + if (numGeneratorResults_.find(genId) != numGeneratorResults_.end() && + numGeneratorResults_[genId] == generatorResultsCounter_[genId]) { + memoryStore_->GeneratorFinished(genId); + ClearCountersByGenId(genId); + } +} + +void StreamGeneratorReceiver::AddRecord(const std::string &genId) +{ + long long now = GetCurrentTimestampMs(); + map_->AddRecord(genId, std::to_string(now)); +} + +void StreamGeneratorReceiver::HandleGeneratorHeartbeat(const std::string &genId) +{ + long long now = GetCurrentTimestampMs(); + map_->UpdateRecord(genId, std::to_string(now)); +} + +void StreamGeneratorReceiver::MarkEndOfStream(const std::string &genId, const ErrorInfo &errInfo) +{ + absl::MutexLock lock(&mu_); + auto index = 0; + if (generatorResultsCounter_.find(genId) != generatorResultsCounter_.end()) { + index = generatorResultsCounter_[genId]; + } + auto genObjectID = memoryStore_->GenerateObjectId(genId, index); + YRLOG_DEBUG("mark end of stream {}", genObjectID); + SetError(genId, genObjectID, index, errInfo); + memoryStore_->GeneratorFinished(genId); + ClearCountersByGenId(genId); +} + +void StreamGeneratorReceiver::SetError(const std::string &genId, const std::string &genObjectID, uint64_t index, + const ErrorInfo &errInfo) +{ + YRLOG_DEBUG("set error stream result of {}, err code: {}, err message: {}", genObjectID, + fmt::underlying(errInfo.Code()), errInfo.Msg()); + std::shared_ptr buf = std::make_shared(genId.size()); + memoryStore_->AddReturnObject(genObjectID); + auto err = memoryStore_->Put(buf, genObjectID, {}, false, errInfo); + if (err.OK()) { + memoryStore_->SetError(genObjectID, errInfo); + } else { + memoryStore_->AddOutput(genId, genObjectID, index, err); + YRLOG_ERROR("failed to put stream finished result, err code: {}, err message: {}", + fmt::underlying(err.Code()), err.Msg()); + return; + } + memoryStore_->AddOutput(genId, genObjectID, index, errInfo); +} + +void StreamGeneratorReceiver::ClearCountersByGenId(const std::string &genId) +{ + generatorResultsCounter_.erase(genId); + numGeneratorResults_.erase(genId); +} +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/generator/stream_generator_receiver.h b/src/libruntime/generator/stream_generator_receiver.h new file mode 100644 index 0000000..081a36c --- /dev/null +++ b/src/libruntime/generator/stream_generator_receiver.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +#include "generator_id_map.h" +#include "generator_receiver.h" +#include "src/libruntime/libruntime_config.h" +#include "src/libruntime/objectstore/memory_store.h" +#include "src/libruntime/streamstore/stream_store.h" + +namespace YR { +namespace Libruntime { +using namespace YR::utility; +class StreamGeneratorReceiver : public GeneratorReceiver { +public: + StreamGeneratorReceiver(std::shared_ptr config, std::shared_ptr store, + std::shared_ptr mStore) + : GeneratorReceiver(), libruntimeConfig_(config), dsStreamStore_(store), memoryStore_(mStore) + { + this->timerWorker_ = std::make_shared(); + this->map_ = std::make_shared(); + } + virtual ~StreamGeneratorReceiver(); + void Initialize(void) override; + void Stop(void) override; + void MarkEndOfStream(const std::string &genId, const ErrorInfo &errInfo) override; + void AddRecord(const std::string &genId) override; + +private: + void DoInitializeOnce(void); + void Receive(void); + void HandleGeneratorFinished(const libruntime::NotifyGeneratorResult &result); + void HandleGeneratorResult(const libruntime::NotifyGeneratorResult &result); + void HandleGeneratorHeartbeat(const std::string &genId); + void DetectHeartbeat(void); + void ClearCountersByGenId(const std::string &genId); + void SetError(const std::string &genId, const std::string &genObjectID, uint64_t index, const ErrorInfo &errInfo); + + std::once_flag flag_; + std::shared_ptr libruntimeConfig_; + std::shared_ptr dsStreamStore_; + std::shared_ptr memoryStore_; + std::thread receiverThread_; + std::shared_ptr consumer_; + std::string topic_; + std::atomic stopped{false}; + ErrorInfo initErr_; + std::unordered_map generatorResultsCounter_; + std::unordered_map numGeneratorResults_; + absl::Mutex mu_; + std::shared_ptr map_; + std::shared_ptr timerWorker_; + std::shared_ptr timer_; +}; +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/groupmanager/function_group.cpp b/src/libruntime/groupmanager/function_group.cpp index d4cc4c0..b5ccf1d 100644 --- a/src/libruntime/groupmanager/function_group.cpp +++ b/src/libruntime/groupmanager/function_group.cpp @@ -76,7 +76,8 @@ void FunctionGroup::CreateRespHandler(const CreateResponses &resps) auto errorInfo = memStore_->DecreGlobalReference(ids); if (!errorInfo.OK()) { YRLOG_WARN("failed to decrease by requestid {}. Code: {}, MCode: {}, Msg: {}", - createSpecs[0]->requestId, errorInfo.Code(), errorInfo.MCode(), errorInfo.Msg()); + createSpecs[0]->requestId, fmt::underlying(errorInfo.Code()), + fmt::underlying(errorInfo.MCode()), errorInfo.Msg()); } } } @@ -92,7 +93,7 @@ void FunctionGroup::CreateNotifyHandler(const NotifyRequest &req) auto errorInfo = memStore_->DecreGlobalReference(ids); if (!errorInfo.OK()) { YRLOG_WARN("failed to decrease by requestid {}. Code: {}, MCode: {}, Msg: {}", createSpecs[0]->requestId, - errorInfo.Code(), errorInfo.MCode(), errorInfo.Msg()); + fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), errorInfo.Msg()); } } } @@ -169,7 +170,7 @@ void FunctionGroup::InvokeByInstanceIds(const std::shared_ptr &spec, auto errorInfo = memStore_->DecreGlobalReference(ids); if (!errorInfo.OK()) { YRLOG_WARN("failed to decrease by requestid {}. Code: {}, MCode: {}, Msg: {}", req.requestid(), - errorInfo.Code(), errorInfo.MCode(), errorInfo.Msg()); + fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), errorInfo.Msg()); } this->Terminate(); }); @@ -214,10 +215,12 @@ ErrorInfo FunctionGroup::Accelerate(const AccelerateMsgQueueHandle &handle, Hand auto killPromise = std::make_shared>(); killFutures.emplace_back(killPromise->get_future()); KillRequest killReq; + killReq.set_requestid(YR::utility::IDGenerator::GenRequestId()); killReq.set_instanceid(instanceIdList[i]); killReq.set_payload(payload); killReq.set_signal(libruntime::Signal::Accelerate); - fsClient->KillAsync(killReq, [killPromise](KillResponse rsp, ErrorInfo err) { killPromise->set_value(rsp); }); + fsClient->KillAsync(killReq, + [killPromise](KillResponse rsp, const ErrorInfo &err) { killPromise->set_value(rsp); }); } std::vector handles; std::vector objIds; diff --git a/src/libruntime/groupmanager/group.cpp b/src/libruntime/groupmanager/group.cpp index 38dec4f..7eafb3f 100644 --- a/src/libruntime/groupmanager/group.cpp +++ b/src/libruntime/groupmanager/group.cpp @@ -110,11 +110,13 @@ void Group::Terminate() runFlag = false; YRLOG_DEBUG("start terminate group ins, group name is {}, group id is {}", groupName, groupId); KillRequest killReq; + killReq.set_requestid(YR::utility::IDGenerator::GenRequestId()); killReq.set_instanceid(groupId); killReq.set_payload(""); killReq.set_signal(libruntime::Signal::KillGroupInstance); - this->fsClient->KillAsync(killReq, [](KillResponse resp, ErrorInfo err) -> void { - YRLOG_ERROR("get termiate group ins response, resp code is {}, resp msg is {}", resp.code(), resp.message()); + this->fsClient->KillAsync(killReq, [](KillResponse resp, const ErrorInfo &err) -> void { + YRLOG_ERROR("get termiate group ins response, resp code is {}, resp msg is {}", fmt::underlying(resp.code()), + resp.message()); }); SetTerminateError(); } diff --git a/src/libruntime/groupmanager/named_group.cpp b/src/libruntime/groupmanager/named_group.cpp index f753a43..fbf8d70 100644 --- a/src/libruntime/groupmanager/named_group.cpp +++ b/src/libruntime/groupmanager/named_group.cpp @@ -27,8 +27,8 @@ NamedGroup::NamedGroup(const std::string &name, const std::string &inputTenantId void NamedGroup::CreateRespHandler(const CreateResponses &resps) { - YRLOG_DEBUG("recieve group create response, resp code is {}, message is {}, runflag is {}", resps.code(), - resps.message(), runFlag); + YRLOG_DEBUG("recieve group create response, resp code is {}, message is {}, runflag is {}", + fmt::underlying(resps.code()), resps.message(), runFlag.load()); if (!runFlag) { return; } @@ -50,8 +50,8 @@ void NamedGroup::CreateRespHandler(const CreateResponses &resps) void NamedGroup::CreateNotifyHandler(const NotifyRequest &req) { - YRLOG_DEBUG("recieve group create notify, req code is {}, message is {}, runflag is {}", req.code(), req.message(), - runFlag); + YRLOG_DEBUG("recieve group create notify, req code is {}, message is {}, runflag is {}", + fmt::underlying(req.code()), req.message(), runFlag.load()); if (!runFlag) { return; } diff --git a/src/libruntime/groupmanager/range_group.cpp b/src/libruntime/groupmanager/range_group.cpp index cdc7e3a..2b36086 100644 --- a/src/libruntime/groupmanager/range_group.cpp +++ b/src/libruntime/groupmanager/range_group.cpp @@ -44,8 +44,8 @@ void RangeGroup::CreateNotifyHandler(const NotifyRequest &req) void RangeGroup::HandleCreateResp(const CreateResponses &resps) { - YRLOG_DEBUG("recieve group create response, resp code is {}, message is {}, runflag is {}", resps.code(), - resps.message(), runFlag); + YRLOG_DEBUG("recieve group create response, resp code is {}, message is {}, runflag is {}", + fmt::underlying(resps.code()), resps.message(), runFlag.load()); if (!runFlag) { return; } @@ -77,8 +77,8 @@ void RangeGroup::HandleCreateResp(const CreateResponses &resps) void RangeGroup::HandleCreateNotify(const NotifyRequest &req) { - YRLOG_DEBUG("recieve group create notify, req code is {}, message is {}, runflag is {}", req.code(), req.message(), - runFlag); + YRLOG_DEBUG("recieve group create notify, req code is {}, message is {}, runflag is {}", + fmt::underlying(req.code()), req.message(), runFlag.load()); if (!runFlag) { return; } diff --git a/src/libruntime/gwclient/gw_client.cpp b/src/libruntime/gwclient/gw_client.cpp new file mode 100644 index 0000000..9fef770 --- /dev/null +++ b/src/libruntime/gwclient/gw_client.cpp @@ -0,0 +1,1104 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/libruntime/gwclient/gw_client.h" +#include "datasystem/object_client.h" +#include "json.hpp" +#include "src/libruntime/gwclient/gw_datasystem_client_wrapper.h" +#include "src/libruntime/utils/http_utils.h" +#include "src/utility/logger/logger.h" +#include "src/utility/notification_utility.h" +namespace YR { +namespace Libruntime { +const std::string REMOTE_CLIENT_ID_KEY = "remoteClientId"; +const std::string REMOTE_CLIENT_ID_KEY_NEW = "X-Remote-Client-Id"; +const std::string TRACE_ID_KEY = "traceId"; +const std::string TENANT_ID_KEY = "tenantId"; +const std::string TENANT_ID_KEY_NEW = "X-Tenant-Id"; + +using json = nlohmann::json; +using YR::utility::NotificationUtility; +ErrorInfo ClientBuffer::Seal(const std::unordered_set &nestedIds) +{ + std::shared_ptr gwClient = gwClientWeak_.lock(); + if (gwClient == nullptr) { + return ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, ModuleCode::RUNTIME, "gwClient not existed"); + } + PutRequest req; + req.set_objectid(this->objectId_); + req.set_objectdata(reinterpret_cast(this->ImmutableData()), this->GetSize()); + for (const auto &id : nestedIds) { + req.add_nestedobjectids(id); + } + req.set_writemode(static_cast(this->createParam_.writeMode)); + req.set_consistencytype(static_cast(this->createParam_.consistencyType)); + req.set_cachetype(static_cast(this->createParam_.cacheType)); + return gwClient->PosixObjPut(req); +} + +ErrorInfo GwClient::Init(std::shared_ptr httpClient, std::int32_t connectTimeout) +{ + if (init_) { + return ErrorInfo(); + } + this->httpClient_ = std::move(httpClient); + return Init("", 0, connectTimeout); +} + +ErrorInfo GwClient::Init(const std::string &ip, int port) +{ + return Init(ip, port); +} + +void GwClient::Init(std::shared_ptr httpClient) +{ + if (init_) { + return; + } + this->httpClient_ = std::move(httpClient); + init_ = true; +} + +ErrorInfo GwClient::Init(const std::string &addr, int port, std::int32_t connectTimeout) +{ + return Init(addr, port, false, false, "", datasystem::SensitiveValue{}, "", datasystem::SensitiveValue{}, "", + datasystem::SensitiveValue{}, connectTimeout); +} + +ErrorInfo GwClient::Init(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, + const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, const std::string &ak, + const datasystem::SensitiveValue &sk, std::int32_t connectTimeout) +{ + std::shared_ptr dsClientWrapper = + std::make_shared(shared_from_this()); + asyncDecreRef_.Init(dsClientWrapper); + init_ = true; + connectTimeout_ = connectTimeout; + return ErrorInfo(); +} + +ErrorInfo GwClient::Start(const std::string &jobID, const std::string &instanceID, const std::string &runtimeID, + const std::string &functionName, const SubscribeFunc &reSubscribeCb) +{ + if (!init_) { + return ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, ModuleCode::RUNTIME, "should init before client start"); + } + this->jobId_ = jobID; + if (start_) { + return ErrorInfo(); + } + auto error = Lease(); + if (!error.OK()) { + return error; + } + timer_ = timerWorker_->CreateTimer(KEEPALIVE_INTERVAL, KEEPALIVE_TIMES, [this](void) { + auto error = this->KeepLease(); + if (!error.OK()) { + YRLOG_WARN("Keepalive fails for {} consecutive times.", ++lostLeaseTimes_); + } else { + lostLeaseTimes_ = 0; + } + }); + start_ = true; + return error; +} + +void GwClient::Clear() +{ + std::vector objectIds = refCountMap_.ToArray(); + refCountMap_.Clear(); + if (asyncDecreRef_.Push(objectIds, threadLocalTenantId)) { + asyncDecreRef_.Stop(); + } +} + +void GwClient::Stop(void) +{ + YRLOG_DEBUG("GwClient Stop"); + Clear(); + if (timer_ != nullptr) { + timer_->cancel(); + } + Release(); + init_ = false; + start_ = false; +} + +ErrorInfo GwClient::Lease() +{ + return HandleLease(POSIX_LEASE, PUT); +} + +ErrorInfo GwClient::KeepLease() +{ + return HandleLease(POSIX_LEASE_KEEPALIVE, POST); +} + +std::string VerbToString(const boost::beast::http::verb &v) +{ + switch (v) { + case boost::beast::http::verb::get: + return "GET"; + case boost::beast::http::verb::post: + return "POST"; + case boost::beast::http::verb::put: + return "PUT"; + case boost::beast::http::verb::delete_: + return "DELETE"; + default: + return "UNKNOWN"; + } +} + +ErrorInfo GwClient::HandleLease(const std::string &url, const http::verb &verb) +{ + auto req = this->BuildLeaseRequest(); + auto requestId = std::make_shared(YR::utility::IDGenerator::GenRequestId()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, ""); + std::string body; + req.SerializeToString(&body); + auto asyncNotify = std::make_shared(); + YRLOG_DEBUG("{} lease request, requestId :{}", VerbToString(verb), *requestId); + auto logError = [verb, requestId](const ErrorInfo &err) -> ErrorInfo { + YRLOG_DEBUG("{} lease Code: {}, MCode: {}, Msg: {}, requestId: {}", VerbToString(verb), + fmt::underlying(err.Code()), fmt::underlying(err.MCode()), err.Msg(), *requestId); + return err; + }; + httpClient_->SubmitInvokeRequest( + verb, url, headers, body, requestId, + [this, asyncNotify, requestId](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + auto err = GenRspError(errorCode, statusCode, result, requestId); + if (err.OK()) { + err = ParseLeaseResponse(result); + } + asyncNotify->Notify(err); + }); + int32_t leaseTimeout = url == POSIX_LEASE ? connectTimeout_ : KEEPALIVE_INTERVAL / S_TO_MS; + std::stringstream ss; + ss << "lease http request timeout s: " << leaseTimeout << ", requestId: " << *requestId; + return logError(asyncNotify->WaitForNotificationWithTimeout( + absl::Seconds(leaseTimeout), ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, ModuleCode::RUNTIME, ss.str()))); +} + +ErrorInfo GwClient::Release() +{ + return HandleLease(POSIX_LEASE, DELETE); +} + +void GwClient::CreateAsync(const CreateRequest &req, CreateRespCallback createRespCallback, CreateCallBack callback, + int timeoutSec) +{ + auto requestId = std::make_shared(req.requestid()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req.SerializeToString(&body); + YRLOG_DEBUG("create request, requestId :{}", *requestId); + httpClient_->SubmitInvokeRequest( + POST, POSIX_CREATE, headers, body, requestId, + [requestId, createRespCallback, callback](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + CreateResponse createRsp; + NotifyRequest notifyReq; + std::stringstream ss; + if (errorCode) { + ss << "network error between client and frontend, error_code: " << errorCode.message() + << ", requestId: " << *requestId; + createRsp.set_code(common::ERR_INNER_COMMUNICATION); + createRsp.set_message(ss.str()); + } else if (!IsResponseSuccessful(statusCode)) { + ss << "failed response status_code: " << std::to_string(statusCode) << ", result: " << result + << ", requestId: " << *requestId; + createRsp.set_code(common::ERR_PARAM_INVALID); + createRsp.set_message(ss.str()); + } else { + notifyReq.ParseFromString(result); + createRsp.set_code(notifyReq.code()); + createRsp.set_message(notifyReq.message()); + createRsp.set_instanceid(notifyReq.instanceid()); + } + YRLOG_DEBUG("create response, code: {}, requestId :{}, instanceId : {}, msg: {} ", + fmt::underlying(createRsp.code()), *requestId, createRsp.instanceid(), createRsp.message()); + createRespCallback(createRsp); + if (createRsp.code() != common::ERR_NONE) { + return; + } + notifyReq.set_requestid(*requestId); + callback(notifyReq); + }); +} + +void GwClient::InvokeAsync(const std::shared_ptr &req, InvokeCallBack callback, int timeoutSec) +{ + auto requestId = std::make_shared(req->Immutable().requestid()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req->Immutable().SerializeToString(&body); + YRLOG_DEBUG("invoke request, requestId :{}, instanceId: {}", *requestId, req->Immutable().instanceid()); + httpClient_->SubmitInvokeRequest( + POST, POSIX_INVOKE, headers, body, requestId, + [requestId, callback](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + NotifyRequest notifyReq; + std::stringstream ss; + if (errorCode) { + ss << "network error between client and frontend, error_code: " << errorCode.message() + << ", requestId: " << *requestId; + notifyReq.set_requestid(*requestId); + notifyReq.set_code(common::ERR_INNER_COMMUNICATION); + notifyReq.set_message(ss.str()); + } else if (!IsResponseSuccessful(statusCode)) { + ss << "failed response status_code: " << std::to_string(statusCode) << ", result: " << result + << ", requestId: " << *requestId; + notifyReq.set_requestid(*requestId); + notifyReq.set_code(common::ERR_PARAM_INVALID); + notifyReq.set_message(ss.str()); + } else { + notifyReq.ParseFromString(result); + } + YRLOG_DEBUG("invoke response, code: {}, requestId: {}, msg: {}, small objects size: {}", + fmt::underlying(notifyReq.code()), *requestId, notifyReq.message(), + notifyReq.smallobjects_size()); + callback(notifyReq, ErrorInfo()); + }); +} + +void GwClient::KillAsync(const KillRequest &req, KillCallBack callback, int timeoutSec) +{ + auto requestId = std::make_shared(req.requestid()); + auto instanceId = req.instanceid(); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req.SerializeToString(&body); + YRLOG_DEBUG("kill request, requestId: {}, instanceId: {}", *requestId, instanceId); + httpClient_->SubmitInvokeRequest( + POST, POSIX_KILL, headers, body, requestId, + [requestId, callback](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + KillResponse killRsp; + std::stringstream ss; + if (errorCode) { + ss << "network error between client and frontend, error_code: " << errorCode.message() + << ", requestId: " << *requestId; + killRsp.set_code(common::ERR_INNER_COMMUNICATION); + killRsp.set_message(ss.str()); + } else if (!IsResponseSuccessful(statusCode)) { + ss << "failed response status_code: " << std::to_string(statusCode) << ", result: " << result + << ", requestId: " << *requestId; + killRsp.set_code(common::ERR_PARAM_INVALID); + killRsp.set_message(ss.str()); + } else { + killRsp.ParseFromString(result); + } + YRLOG_DEBUG("kill response, code: {}, requestId :{}, msg: {}", fmt::underlying(killRsp.code()), *requestId, + killRsp.message()); + callback(killRsp, ErrorInfo()); + }); +} + +void GwClient::ExitAsync(const ExitRequest &req, ExitCallBack callback) +{ + STDERR_AND_THROW_EXCEPTION(ERR_INNER_SYSTEM_ERROR, RUNTIME, + "ExitAsync method not implemented when inCluster is false"); +} + +void GwClient::StateSaveAsync(const StateSaveRequest &req, StateSaveCallBack callback) +{ + STDERR_AND_THROW_EXCEPTION(ERR_INNER_SYSTEM_ERROR, RUNTIME, + "StateSaveAsync method not implemented when inCluster is false"); +} + +void GwClient::StateLoadAsync(const StateLoadRequest &req, StateLoadCallBack callback) +{ + STDERR_AND_THROW_EXCEPTION(ERR_INNER_SYSTEM_ERROR, RUNTIME, "StateLoadAsync is not supported with gateway client"); +} + +ErrorInfo GwClient::CreateBuffer(const std::string &objectId, size_t dataSize, std::shared_ptr &dataBuf, + const CreateParam &createParam) +{ + dataBuf = std::make_shared(dataSize, objectId, createParam, weak_from_this()); + return ErrorInfo(); +} + +std::pair>> GwClient::GetBuffers(const std::vector &ids, + int timeoutMS) +{ + auto [err, results] = Get(ids, timeoutMS); + return std::make_pair(err, results); +} + +std::pair>> GwClient::GetBuffersWithoutRetry( + const std::vector &ids, int timeoutMS) +{ + auto [err, results] = Get(ids, timeoutMS); + RetryInfo retryInfo{err, RetryType::NO_RETRY}; + return std::make_pair(retryInfo, results); +} + +ErrorInfo GwClient::Put(std::shared_ptr data, const std::string &objID, + const std::unordered_set &nestedID, const CreateParam &createParam) +{ + auto req = this->BuildObjPutRequest(data, objID, nestedID, createParam); + return PosixObjPut(req); +} + +SingleResult GwClient::Get(const std::string &objID, int timeoutMS) +{ + ErrorInfo err; + std::vector ids = {objID}; + auto multiRes = Get(ids, timeoutMS); + if (multiRes.first.Code() != ErrorCode::ERR_OK) { + return std::make_pair(multiRes.first, std::shared_ptr()); + } + return std::make_pair(err, multiRes.second[0]); +} + +MultipleResult GwClient::Get(const std::vector &ids, int timeoutMS) +{ + auto result = std::make_shared>>(); + result->resize(ids.size()); + ErrorInfo err = PosixObjGet(ids, result, timeoutMS); + return std::make_pair(err, *result); +} + +ErrorInfo GwClient::UpdateToken(datasystem::SensitiveValue token) +{ + return ErrorInfo(); +}; + +ErrorInfo GwClient::UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) +{ + return ErrorInfo(); +}; + +ErrorInfo GwClient::GenerateKey(std::string &key, const std::string &prefix, bool isPut) +{ + key = prefix; + return ErrorInfo(); +} + +ErrorInfo GwClient::GetPrefix(const std::string &key, std::string &prefix) +{ + prefix = key; + return ErrorInfo(); +} + +ErrorInfo GwClient::IncreGlobalReference(const std::vector &objectIds) +{ + auto failedObjectIds = std::make_shared>(); + auto err = PosixGInCreaseRef(objectIds, failedObjectIds); + if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { + YRLOG_ERROR(err.Msg()); + return err; + } + refCountMap_.IncreRefCount(objectIds); + if (!failedObjectIds->empty()) { + YRLOG_WARN("Datasystem failed to increase all objectRefs, fail count: {}", failedObjectIds->size()); + refCountMap_.DecreRefCount(*failedObjectIds); + } + return err; +} + +ErrorInfo GwClient::DecreGlobalReference(const std::vector &objectIds) +{ + ErrorInfo err; + // if the objectId is not in the map, not decrease + std::vector needDecreObjectIds = refCountMap_.DecreRefCount(objectIds); + if (!needDecreObjectIds.empty()) { + bool success = asyncDecreRef_.Push(needDecreObjectIds, threadLocalTenantId); + if (!success) { + err.SetErrCodeAndMsg(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, + YR::Libruntime::ModuleCode::DATASYSTEM, "async decrease thread has exited"); + } + } + return err; +} + +ErrorInfo GwClient::Write(const std::string &key, std::shared_ptr value, SetParam setParam) +{ + return PosixKvSet(key, value, setParam); +} + +ErrorInfo GwClient::MSetTx(const std::vector &keys, const std::vector> &vals, + const MSetParam &mSetParam) +{ + return PosixKvMSetTx(keys, vals, mSetParam); +} + +SingleReadResult GwClient::Read(const std::string &key, int timeoutMS) +{ + std::vector keys = {key}; + auto result = Read(keys, timeoutMS, false); + return std::make_pair(result.first[0], result.second); +} + +MultipleReadResult GwClient::Read(const std::vector &keys, int timeoutMS, bool allowPartial) +{ + auto result = std::make_shared>>(); + result->resize(keys.size()); + auto err = PosixKvGet(keys, result, timeoutMS); + if (err.Code() != ErrorCode::ERR_OK) { + YRLOG_ERROR("GetValueWithTimeout error: Code:{}, MCode:{}, Msg:{}.", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); + } + if (!allowPartial) { + auto errInfo = ProcessKeyPartialResult(keys, *result, err, timeoutMS); + if (errInfo.Code() != ErrorCode::ERR_OK) { + err = errInfo; + } + } + return std::make_pair(*result, err); +} + +ErrorInfo GwClient::Del(const std::string &key) +{ + std::vector keys = {key}; + auto result = Del(keys); + return result.second; +} + +MultipleDelResult GwClient::Del(const std::vector &keys) +{ + auto failedKeys = std::make_shared>(); + auto err = PosixKvDel(keys, failedKeys); + return std::make_pair(*failedKeys, err); +} + +ErrorInfo GwClient::GenRspError(const boost::beast::error_code &errorCode, const uint statusCode, + const std::string &result, const std::shared_ptr requestId) +{ + ErrorInfo err; + std::stringstream ss; + if (errorCode) { + ss << "network error between client and frontend, error_code: " << errorCode.message() + << ", requestId: " << *requestId; + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_COMMUNICATION, ss.str()); + } + if (!IsResponseSuccessful(statusCode)) { + ss << "failed response status_code: " << std::to_string(statusCode) << ", result: " << result + << ", requestId: " << *requestId; + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, ss.str()); + } + return err; +} + +ErrorInfo GwClient::PosixKvGet(const std::vector &keys, + std::shared_ptr>> results, int32_t timeoutMs) +{ + auto req = this->BuildKvGetRequest(keys, timeoutMs); + auto requestId = std::make_shared(YR::utility::IDGenerator::GenRequestId()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req.SerializeToString(&body); + auto asyncNotify = std::make_shared(); + YRLOG_DEBUG("kv get request, requestId :{}", *requestId); + auto logError = [requestId](const ErrorInfo &err) -> ErrorInfo { + YRLOG_DEBUG("kv get Code: {}, MCode: {}, Msg: {}, requestId: {}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg(), *requestId); + return err; + }; + httpClient_->SubmitInvokeRequest( + POST, POSIX_KV_GET, headers, body, requestId, + [this, asyncNotify, results, requestId](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + auto err = GenRspError(errorCode, statusCode, result, requestId); + if (err.OK()) { + err = ParseKvGetResponse(result, results); + } + asyncNotify->Notify(err); + }); + std::stringstream ss; + ss << "kv get http request timeout s: " << connectTimeout_ << ", requestId: " << *requestId; + return logError(asyncNotify->WaitForNotificationWithTimeout( + absl::Seconds(connectTimeout_), ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, ModuleCode::RUNTIME, ss.str()))); +} + +ErrorInfo GwClient::PosixKvSet(const std::string &key, const std::shared_ptr value, const SetParam &setParam) +{ + auto req = this->BuildKvSetRequest(key, value, setParam); + auto requestId = std::make_shared(YR::utility::IDGenerator::GenRequestId()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req.SerializeToString(&body); + auto asyncNotify = std::make_shared(); + YRLOG_DEBUG("kv set request, requestId :{}", *requestId); + auto logError = [requestId](const ErrorInfo &err) -> ErrorInfo { + YRLOG_DEBUG("kv set Code: {}, MCode: {}, Msg: {}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg(), *requestId); + return err; + }; + httpClient_->SubmitInvokeRequest( + POST, POSIX_KV_SET, headers, body, requestId, + [this, asyncNotify, requestId](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + auto err = GenRspError(errorCode, statusCode, result, requestId); + if (err.OK()) { + err = ParseKvSetResponse(result); + } + asyncNotify->Notify(err); + }); + std::stringstream ss; + ss << "kv set http request timeout s: " << connectTimeout_ << ", requestId: " << *requestId; + return logError(asyncNotify->WaitForNotificationWithTimeout( + absl::Seconds(connectTimeout_), ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, ModuleCode::RUNTIME, ss.str()))); +} + +ErrorInfo GwClient::PosixKvMSetTx(const std::vector &keys, + const std::vector> &vals, const MSetParam &mSetParam) +{ + auto req = this->BuildKvMSetTxRequest(keys, vals, mSetParam); + auto requestId = std::make_shared(YR::utility::IDGenerator::GenRequestId()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req.SerializeToString(&body); + auto asyncNotify = std::make_shared(); + YRLOG_DEBUG("kv multi set tx request, requestId :{}", *requestId); + auto logError = [requestId](const ErrorInfo &err) -> ErrorInfo { + YRLOG_DEBUG("kv multi set tx Code: {}, MCode: {}, Msg: {}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg(), *requestId); + return err; + }; + httpClient_->SubmitInvokeRequest( + POST, POSIX_KV_MSET_TX, headers, body, requestId, + [this, asyncNotify, requestId](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + auto err = GenRspError(errorCode, statusCode, result, requestId); + if (err.OK()) { + err = ParseKvMSetTxResponse(result); + } + asyncNotify->Notify(err); + }); + std::stringstream ss; + ss << "kv multi set tx http request timeout: " << connectTimeout_ << "s, requestId: " << *requestId; + return logError(asyncNotify->WaitForNotificationWithTimeout( + absl::Seconds(connectTimeout_), ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, ModuleCode::RUNTIME, ss.str()))); +} + +ErrorInfo GwClient::PosixObjPut(const PutRequest &req) +{ + auto requestId = std::make_shared(YR::utility::IDGenerator::GenRequestId()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req.SerializeToString(&body); + auto asyncNotify = std::make_shared(); + YRLOG_DEBUG("obj put request, requestId :{}", *requestId); + auto logError = [requestId](const ErrorInfo &err) -> ErrorInfo { + YRLOG_DEBUG("obj put Code: {}, MCode: {}, Msg: {}, requestId: {}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg(), *requestId); + return err; + }; + httpClient_->SubmitInvokeRequest( + POST, POSIX_OBJ_PUT, headers, body, requestId, + [this, asyncNotify, requestId](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + auto err = GenRspError(errorCode, statusCode, result, requestId); + if (err.OK()) { + err = ParseObjPutResponse(result); + } + asyncNotify->Notify(err); + }); + std::stringstream ss; + ss << "obj put http request timeout s: " << connectTimeout_ << ", requestId: " << *requestId; + return logError(asyncNotify->WaitForNotificationWithTimeout( + absl::Seconds(connectTimeout_), ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, ModuleCode::RUNTIME, ss.str()))); +} + +ErrorInfo GwClient::PosixObjGet(const std::vector &keys, + std::shared_ptr>> results, int32_t timeoutMs) +{ + auto req = this->BuildObjGetRequest(keys, timeoutMs); + auto requestId = std::make_shared(YR::utility::IDGenerator::GenRequestId()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req.SerializeToString(&body); + auto asyncNotify = std::make_shared(); + YRLOG_DEBUG("obj get request, requestId :{}", *requestId); + auto logError = [requestId](const ErrorInfo &err) -> ErrorInfo { + YRLOG_DEBUG("obj get Code: {}, MCode: {}, Msg: {}, requestId: {}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg(), *requestId); + return err; + }; + httpClient_->SubmitInvokeRequest( + POST, POSIX_OBJ_GET, headers, body, requestId, + [this, asyncNotify, results, requestId](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + auto err = GenRspError(errorCode, statusCode, result, requestId); + if (err.OK()) { + err = ParseObjGetResponse(result, results); + } + asyncNotify->Notify(err); + }); + std::stringstream ss; + ss << "obj get http request timeout s: " << connectTimeout_ << ", requestId: " << *requestId; + return logError(asyncNotify->WaitForNotificationWithTimeout( + absl::Seconds(connectTimeout_), ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, ModuleCode::RUNTIME, ss.str()))); +} + +ErrorInfo GwClient::PosixKvDel(const std::vector &keys, + std::shared_ptr> failedKeys) +{ + auto req = this->BuildKvDelRequest(keys); + auto requestId = std::make_shared(YR::utility::IDGenerator::GenRequestId()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req.SerializeToString(&body); + auto asyncNotify = std::make_shared(); + YRLOG_DEBUG("kv del request, requestId :{}", *requestId); + auto logError = [requestId](const ErrorInfo &err) -> ErrorInfo { + YRLOG_DEBUG("kv del Code: {}, MCode: {}, Msg: {}, requestId: {}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg(), *requestId); + return err; + }; + httpClient_->SubmitInvokeRequest( + POST, POSIX_KV_DEL, headers, body, requestId, + [this, asyncNotify, failedKeys, requestId](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + auto err = GenRspError(errorCode, statusCode, result, requestId); + if (err.OK()) { + err = ParseKvDelResponse(result, failedKeys); + } + asyncNotify->Notify(err); + }); + std::stringstream ss; + ss << "kv del http request timeout s: " << connectTimeout_ << ", requestId: " << *requestId; + return logError(asyncNotify->WaitForNotificationWithTimeout( + absl::Seconds(connectTimeout_), ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, ModuleCode::RUNTIME, ss.str()))); +} + +ErrorInfo GwClient::PosixGDecreaseRef(const std::vector &objectIds, + std::shared_ptr> failedObjectIds) +{ + auto req = this->BuildDecreaseRefRequest(objectIds); + auto requestId = std::make_shared(YR::utility::IDGenerator::GenRequestId()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req.SerializeToString(&body); + auto asyncNotify = std::make_shared(); + YRLOG_DEBUG("decrease ref request, requestId :{}", *requestId); + auto logError = [requestId](const ErrorInfo &err) -> ErrorInfo { + YRLOG_DEBUG("decrease ref Code: {}, MCode: {}, Msg: {}, requestId: {}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg(), *requestId); + return err; + }; + httpClient_->SubmitInvokeRequest( + POST, POSIX_OBJ_DECREASE, headers, body, requestId, + [this, asyncNotify, failedObjectIds, requestId]( + const std::string &result, const boost::beast::error_code &errorCode, const uint statusCode) { + auto err = GenRspError(errorCode, statusCode, result, requestId); + if (err.OK()) { + err = ParseDecreaseRefResponse(result, failedObjectIds); + } + asyncNotify->Notify(err); + }); + std::stringstream ss; + ss << "decrease ref http request timeout s: " << connectTimeout_ << ", requestId: " << *requestId; + return logError(asyncNotify->WaitForNotificationWithTimeout( + absl::Seconds(connectTimeout_), ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, ModuleCode::RUNTIME, ss.str()))); +} + +ErrorInfo GwClient::PosixGInCreaseRef(const std::vector &objectIds, + std::shared_ptr> failedObjectIds) +{ + auto req = this->BuildIncreaseRefRequest(objectIds); + auto requestId = std::make_shared(YR::utility::IDGenerator::GenRequestId()); + auto headers = this->BuildHeaders(this->jobId_, *requestId, threadLocalTenantId); + std::string body; + req.SerializeToString(&body); + auto asyncNotify = std::make_shared(); + YRLOG_DEBUG("increase ref request, requestId :{}", *requestId); + auto logError = [requestId](const ErrorInfo &err) -> ErrorInfo { + YRLOG_DEBUG("increase ref Code: {}, MCode: {}, Msg: {}, requestId: {}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg(), *requestId); + return err; + }; + httpClient_->SubmitInvokeRequest( + POST, POSIX_OBJ_INCREASE, headers, body, requestId, + [this, asyncNotify, failedObjectIds, requestId]( + const std::string &result, const boost::beast::error_code &errorCode, const uint statusCode) { + auto err = GenRspError(errorCode, statusCode, result, requestId); + if (err.OK()) { + err = ParseIncreaseRefResponse(result, failedObjectIds); + } + asyncNotify->Notify(err); + }); + std::stringstream ss; + ss << "increase ref request by http has timed out: " << connectTimeout_ << ", requestId: " << *requestId; + return logError(asyncNotify->WaitForNotificationWithTimeout( + absl::Seconds(connectTimeout_), ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, ModuleCode::RUNTIME, ss.str()))); +} + +void GwClient::SetTenantId(const std::string &tenantId) +{ + threadLocalTenantId = tenantId; +} + +ErrorInfo GwClient::ParseLeaseResponse(const std::string &result) +{ + LeaseResponse leaseRsp; + leaseRsp.ParseFromString(result); + return leaseRsp.code() == common::ERR_NONE + ? ErrorInfo() + : ErrorInfo(static_cast(leaseRsp.code()), ModuleCode::RUNTIME, leaseRsp.message()); +} + +ErrorInfo GwClient::ParseObjGetResponse(const std::string &result, + std::shared_ptr>> results) +{ + GetResponse getRsp; + getRsp.ParseFromString(result); + for (int i = 0; i < getRsp.buffers_size(); i++) { + if (getRsp.buffers(i).size() == 0) { + continue; + } + auto buf = std::make_shared(getRsp.buffers(i).size()); + buf->MemoryCopy(getRsp.buffers(i).data(), getRsp.buffers(i).size()); + (*results)[i] = std::move(buf); + } + return getRsp.code() == common::ERR_NONE + ? ErrorInfo() + : ErrorInfo(static_cast(getRsp.code()), ModuleCode::DATASYSTEM, getRsp.message()); +} + +ErrorInfo GwClient::ParseObjPutResponse(const std::string &result) +{ + PutResponse putRsp; + putRsp.ParseFromString(result); + return putRsp.code() == common::ERR_NONE + ? ErrorInfo() + : ErrorInfo(static_cast(putRsp.code()), ModuleCode::DATASYSTEM, putRsp.message()); +} + +ErrorInfo GwClient::ParseKvGetResponse(const std::string &result, + std::shared_ptr>> results) +{ + KvGetResponse kvRsp; + kvRsp.ParseFromString(result); + for (int i = 0; i < kvRsp.values_size(); i++) { + if (kvRsp.values(i).size() == 0) { + continue; + } + auto buf = std::make_shared(kvRsp.values(i).size()); + buf->MemoryCopy(kvRsp.values(i).data(), kvRsp.values(i).size()); + (*results)[i] = std::move(buf); + } + return kvRsp.code() == common::ERR_NONE + ? ErrorInfo() + : ErrorInfo(static_cast(kvRsp.code()), ModuleCode::DATASYSTEM, kvRsp.message()); +} + +ErrorInfo GwClient::ParseKvSetResponse(const std::string &result) +{ + KvSetResponse kvRsp; + kvRsp.ParseFromString(result); + return kvRsp.code() == common::ERR_NONE + ? ErrorInfo() + : ErrorInfo(static_cast(kvRsp.code()), ModuleCode::DATASYSTEM, kvRsp.message()); +} + +ErrorInfo GwClient::ParseKvMSetTxResponse(const std::string &result) +{ + KvMSetTxResponse mSetRsp; + mSetRsp.ParseFromString(result); + return mSetRsp.code() == common::ERR_NONE + ? ErrorInfo() + : ErrorInfo(static_cast(mSetRsp.code()), ModuleCode::DATASYSTEM, mSetRsp.message()); +} + +ErrorInfo GwClient::ParseKvDelResponse(const std::string &result, std::shared_ptr> failedKeys) +{ + KvDelResponse kvRsp; + kvRsp.ParseFromString(result); + failedKeys->resize(kvRsp.failedkeys_size()); + for (int i = 0; i < kvRsp.failedkeys_size(); i++) { + (*failedKeys)[i] = kvRsp.failedkeys(i); + } + return kvRsp.code() == common::ERR_NONE + ? ErrorInfo() + : ErrorInfo(static_cast(kvRsp.code()), ModuleCode::DATASYSTEM, kvRsp.message()); +} + +ErrorInfo GwClient::ParseIncreaseRefResponse(const std::string &result, + std::shared_ptr> failedObjectIds) +{ + IncreaseRefResponse increaseRefRsp; + increaseRefRsp.ParseFromString(result); + failedObjectIds->resize(increaseRefRsp.failedobjectids_size()); + for (int i = 0; i < increaseRefRsp.failedobjectids_size(); i++) { + (*failedObjectIds)[i] = increaseRefRsp.failedobjectids(i); + } + return increaseRefRsp.code() == common::ERR_NONE ? ErrorInfo() + : ErrorInfo(static_cast(increaseRefRsp.code()), + ModuleCode::DATASYSTEM, increaseRefRsp.message()); +} + +ErrorInfo GwClient::ParseDecreaseRefResponse(const std::string &result, + std::shared_ptr> failedObjectIds) +{ + DecreaseRefResponse decreaseRefRsp; + decreaseRefRsp.ParseFromString(result); + failedObjectIds->resize(decreaseRefRsp.failedobjectids_size()); + for (int i = 0; i < decreaseRefRsp.failedobjectids_size(); i++) { + (*failedObjectIds)[i] = decreaseRefRsp.failedobjectids(i); + } + return decreaseRefRsp.code() == common::ERR_NONE ? ErrorInfo() + : ErrorInfo(static_cast(decreaseRefRsp.code()), + ModuleCode::DATASYSTEM, decreaseRefRsp.message()); +} + +ErrorInfo GwClient::ReleaseGRefs(const std::string &remoteId) +{ + return ErrorInfo(ErrorCode::ERR_PARAM_INVALID, ModuleCode::RUNTIME, "not support out of cluster"); +} + +LeaseRequest GwClient::BuildLeaseRequest() +{ + LeaseRequest leaseReq; + leaseReq.set_remoteclientid(this->jobId_); + return leaseReq; +} + +GetRequest GwClient::BuildObjGetRequest(const std::vector &keys, int32_t timeoutMs) +{ + GetRequest req; + for (const auto &key : keys) { + req.add_objectids(key); + req.set_timeoutms(timeoutMs); + } + return req; +} + +PutRequest GwClient::BuildObjPutRequest(std::shared_ptr data, const std::string &objID, + const std::unordered_set &nestedID, const CreateParam &createParam) +{ + PutRequest req; + req.set_objectid(objID); + req.set_objectdata(data->ImmutableData(), data->GetSize()); + for (const auto &id : nestedID) { + req.add_nestedobjectids(id); + } + req.set_writemode(static_cast(createParam.writeMode)); + req.set_consistencytype(static_cast(createParam.consistencyType)); + req.set_cachetype(static_cast(createParam.cacheType)); + return req; +} + +KvSetRequest GwClient::BuildKvSetRequest(const std::string &key, const std::shared_ptr value, + const SetParam &setParam) +{ + KvSetRequest req; + req.set_key(key); + req.set_value(value->ImmutableData(), value->GetSize()); + req.set_existence(static_cast(setParam.existence)); + req.set_writemode(static_cast(setParam.writeMode)); + req.set_ttlsecond(setParam.ttlSecond); + req.set_cachetype(static_cast(setParam.cacheType)); + return req; +} + +KvMSetTxRequest GwClient::BuildKvMSetTxRequest(const std::vector &keys, + const std::vector> &vals, + const MSetParam &mSetParam) +{ + KvMSetTxRequest req; + for (const auto &key : keys) { + req.add_keys(key); + } + for (const auto &val : vals) { + req.add_values(val->ImmutableData(), val->GetSize()); + } + req.set_existence(static_cast(mSetParam.existence)); + req.set_writemode(static_cast(mSetParam.writeMode)); + req.set_ttlsecond(mSetParam.ttlSecond); + req.set_cachetype(static_cast(mSetParam.cacheType)); + return req; +} + +KvGetRequest GwClient::BuildKvGetRequest(const std::vector &keys, int32_t timeoutMs) +{ + KvGetRequest req; + for (const auto &key : keys) { + req.add_keys(key); + } + req.set_timeoutms(timeoutMs); + return req; +} + +KvDelRequest GwClient::BuildKvDelRequest(const std::vector &keys) +{ + KvDelRequest req; + for (const auto &key : keys) { + req.add_keys(key); + } + return req; +} + +IncreaseRefRequest GwClient::BuildIncreaseRefRequest(const std::vector &objectIds) +{ + IncreaseRefRequest req; + for (const auto &objId : objectIds) { + req.add_objectids(objId); + } + req.set_remoteclientid(YR::utility::ParseRealJobId(this->jobId_)); + return req; +} + +DecreaseRefRequest GwClient::BuildDecreaseRefRequest(const std::vector &objectIds) +{ + DecreaseRefRequest req; + for (size_t i = 0; i < objectIds.size(); i++) { + req.add_objectids(objectIds[i]); + } + req.set_remoteclientid(YR::utility::ParseRealJobId(this->jobId_)); + return req; +} + +std::unordered_map GwClient::BuildHeaders(const std::string &remoteClientId, + const std::string &traceId, + const std::string &tenantId) +{ + std::unordered_map headers; + if (!remoteClientId.empty()) { + headers.emplace(REMOTE_CLIENT_ID_KEY, remoteClientId); + headers.emplace(REMOTE_CLIENT_ID_KEY_NEW, remoteClientId); + } + if (!traceId.empty()) { + headers.emplace(TRACE_ID_KEY, traceId); + headers.emplace(TRACE_ID_KEY_NEW, traceId); + } + if (!tenantId.empty()) { + headers.emplace(TENANT_ID_KEY, tenantId); + headers.emplace(TENANT_ID_KEY_NEW, tenantId); + } + return headers; +} + +std::pair, std::string> GwClient::BuildRequestWithAkSk( + const std::shared_ptr spec, const std::string &url) +{ + std::string callReq; + const size_t userEventIndex = 1; + if (spec->invokeArgs.size() > userEventIndex) { + auto &invokeArg = spec->invokeArgs[userEventIndex]; + callReq = std::string(static_cast(invokeArg.dataObj->data->ImmutableData()), + invokeArg.dataObj->data->GetSize()); + } else { + YRLOG_ERROR("invoke args size not valid"); + } + std::string event; + try { + json j = json::parse(callReq); + if (j.contains("body")) { + if (j["body"].is_string()) { + event = j["body"].get(); + } else { + event = j["body"].dump(); + } + } else { + YRLOG_ERROR("event is empty, callReq {}", callReq); + } + } catch (const std::exception &e) { + YRLOG_ERROR("{} JSON parse error: {}", callReq, e.what()); + } + std::unordered_map headers; + if (!spec->traceId.empty()) { + headers.emplace(TRACE_ID_KEY_NEW, spec->traceId); + } + headers.emplace(REMOTE_CLIENT_ID_KEY, jobId_); // for llt test + headers.emplace(INSTANCE_CPU_KEY, std::to_string(spec->opts.cpu)); + headers.emplace(INSTANCE_MEMORY_KEY, std::to_string(spec->opts.memory)); + std::string ak; + datasystem::SensitiveValue sk; + security_->GetAKSK(ak, sk); + if (ak.empty() || sk.Empty()) { + YRLOG_WARN("ak or sk is empty"); + return std::make_pair(headers, event); + } + SignHttpRequest(ak, sk, headers, event, url); + return std::make_pair(headers, event); +} + +void GwClient::GroupCreateAsync(const CreateRequests &reqs, CreateRespsCallback respCallback, CreateCallBack callback, + int timeoutSec) +{ + // wait inplement not in cluster +} + +std::string TransformJson(uint statusCode, const std::string &input) +{ + if (IsResponseSuccessful(statusCode)) { + json out; + out["innerCode"] = "0"; + try { + json j = json::parse(input); + out["body"] = j; + } catch (const std::exception &e) { + YRLOG_WARN("json parse error {}", e.what()); + out["body"] = input; + } + return out.dump(); + } + if (IsResponseServerError(statusCode)) { + try { + json j = json::parse(input); + json out; + if (j.contains("code")) { + out["innerCode"] = std::to_string(j["code"].get()); + } + if (j.contains("message")) { + out["body"] = j["message"]; + } + return out.dump(); + } catch (const std::exception &e) { + YRLOG_WARN("{} JSON parse error: {}", input, e.what()); + return "{}"; + } + } + return input; +} + +void GwClient::InvocationAsync(const std::string &url, const std::shared_ptr spec, + const InvocationCallback &callback) +{ + auto requestId = std::make_shared(spec->requestId); + auto [headers, body] = BuildRequestWithAkSk(spec, url); + YRLOG_DEBUG("invocation request, url: {}, requestId: {}", url, *requestId); + httpClient_->SubmitInvokeRequest( + POST, url, headers, body, requestId, + [requestId, callback](const std::string &result, const boost::beast::error_code &errorCode, + const uint statusCode) { + if (errorCode) { + std::stringstream ss; + ss << "invocation network error between client and frontend, error_code: " << errorCode.message() + << ", requestId: " << *requestId; + YRLOG_ERROR(ss.str()); + callback(*requestId, ErrorCode::ERR_INNER_COMMUNICATION, ss.str()); + } else if (!IsResponseSuccessful(statusCode)) { + YRLOG_ERROR("invocation response, status_code: {}, result: {}, requestId: {}", statusCode, result, + *requestId); + if (IsResponseServerError(statusCode)) { + callback(*requestId, ErrorCode::ERR_OK, TransformJson(statusCode, result)); + } else { + callback(*requestId, ErrorCode::ERR_INNER_SYSTEM_ERROR, result); + } + } else { + YRLOG_DEBUG("invocation response, http status code: {}, requestId: {}", statusCode, *requestId); + callback(*requestId, ErrorCode::ERR_OK, TransformJson(statusCode, result)); + } + }); +} +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/gwclient/gw_client.h b/src/libruntime/gwclient/gw_client.h new file mode 100644 index 0000000..944ffd4 --- /dev/null +++ b/src/libruntime/gwclient/gw_client.h @@ -0,0 +1,364 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/libruntime/fsclient/fs_intf.h" +#include "src/libruntime/gwclient/http/client_manager.h" +#include "src/libruntime/heterostore/hetero_store.h" +#include "src/libruntime/invoke_spec.h" +#include "src/libruntime/objectstore/async_decre_ref.h" +#include "src/libruntime/objectstore/object_store.h" +#include "src/libruntime/statestore/state_store.h" +#include "src/libruntime/streamstore/stream_store.h" +#include "src/utility/timer_worker.h" +namespace YR { +namespace Libruntime { +using YR::utility::TimerWorker; +extern const std::string REMOTE_CLIENT_ID_KEY; +extern const std::string REMOTE_CLIENT_ID_KEY_NEW; +extern const std::string TRACE_ID_KEY; +extern const std::string TRACE_ID_KEY_NEW; +extern const std::string TENANT_ID_KEY; +extern const std::string TENANT_ID_KEY_NEW; + +const std::string POSIX_CREATE = "/serverless/v1/posix/instance/create"; +const std::string POSIX_INVOKE = "/serverless/v1/posix/instance/invoke"; +const std::string POSIX_KILL = "/serverless/v1/posix/instance/kill"; + +const std::string POSIX_LEASE = "/client/v1/lease"; +const std::string POSIX_LEASE_KEEPALIVE = "/client/v1/lease/keepalive"; +const std::string POSIX_OBJ_PUT = "/datasystem/v1/obj/put"; +const std::string POSIX_OBJ_GET = "/datasystem/v1/obj/get"; +const std::string POSIX_OBJ_INCREASE = "/datasystem/v1/obj/increaseref"; +const std::string POSIX_OBJ_DECREASE = "/datasystem/v1/obj/decreaseref"; +const std::string POSIX_KV_SET = "/datasystem/v1/kv/set"; +const std::string POSIX_KV_MSET_TX = "/datasystem/v1/kv/msettx"; +const std::string POSIX_KV_GET = "/datasystem/v1/kv/get"; +const std::string POSIX_KV_DEL = "/datasystem/v1/kv/del"; +const int KEEPALIVE_INTERVAL = 60 * 1000; // 1 min +const int KEEPALIVE_TIMES = -1; // unlimited retry + +using InvocationCallback = + std::function; + +thread_local static std::string threadLocalTenantId; + +class GwClient : public FSIntf, + public ObjectStore, + public StateStore, + public StreamStore, + public HeteroStore, + public std::enable_shared_from_this { + friend class ClientBuffer; + +public: + GwClient(const std::string &funcId, FSIntfHandlers handlers, std::shared_ptr security = nullptr) + : FSIntf(handlers), funcId_(funcId), security_(security) + { + this->funcName_ = funcId.substr(funcId.find('/') + 1, funcId.find_last_of('/')); + this->funcVersion_ = funcId.substr(funcId.find_last_of('/') + 1, funcId.size()); + this->timerWorker_ = std::make_shared(); + } + ~GwClient() = default; + ErrorInfo Start(const std::string &jobID, const std::string &instanceID = "", const std::string &runtimeID = "", + const std::string &functionName = "", const SubscribeFunc &reSubscribeCb = nullptr) override; + void Stop(void) override; + void Clear() override; + void GroupCreateAsync(const CreateRequests &reqs, CreateRespsCallback respCallback, CreateCallBack callback, + int timeoutSec = -1) override; + void CreateAsync(const CreateRequest &req, CreateRespCallback respCallback, CreateCallBack callback, + int timeoutSec = -1) override; + void InvokeAsync(const std::shared_ptr &req, InvokeCallBack callback, + int timeoutSec = -1) override; + void InvocationAsync(const std::string &url, const std::shared_ptr spec, + const InvocationCallback &callback); + void CallResultAsync(const std::shared_ptr req, CallResultCallBack callback) + { + STDERR_AND_THROW_EXCEPTION(ERR_INNER_SYSTEM_ERROR, RUNTIME, + "CallResultAsync method not implemented when inCluster is false"); + } + void KillAsync(const KillRequest &req, KillCallBack callback, int timeoutSec = -1) override; + void ExitAsync(const ExitRequest &req, ExitCallBack callback) override; + void StateSaveAsync(const StateSaveRequest &req, StateSaveCallBack callback) override; + void StateLoadAsync(const StateLoadRequest &req, StateLoadCallBack callback) override; + void CreateRGroupAsync(const CreateResourceGroupRequest &req, CreateResourceGroupCallBack callback, + int timeoutSec = -1) + { + STDERR_AND_THROW_EXCEPTION(ERR_INNER_SYSTEM_ERROR, RUNTIME, + "CreateRGroupAsync method not implemented when inCluster is false"); + } + ErrorInfo Init(std::shared_ptr httpClient, std::int32_t connectTimeout); + void Init(std::shared_ptr httpClient); + ErrorInfo Init(const std::string &ip, int port) override; + ErrorInfo Init(const std::string &addr, int port, std::int32_t connectTimeout); + ErrorInfo Init(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, + const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, const std::string &ak, + const datasystem::SensitiveValue &sk) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, + "Init method with the nine params is not supported when inCluster is false"); + } + + ErrorInfo Init(datasystem::ConnectOptions &options) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, + "Init method with the ConnectOptions is not supported when inCluster is false"); + } + ErrorInfo Init(datasystem::ConnectOptions &options, std::shared_ptr dsStateStore) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, + "Init method with the ConnectOptions is not supported when inCluster is false"); + } + ErrorInfo Init(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, + const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, const std::string &ak, + const datasystem::SensitiveValue &sk, std::int32_t connectTimeout) override; + + ErrorInfo CreateBuffer(const std::string &objectId, size_t dataSize, std::shared_ptr &dataBuf, + const CreateParam &createParam) override; + std::pair>> GetBuffers(const std::vector &ids, + int timeoutMS) override; + std::pair>> GetBuffersWithoutRetry( + const std::vector &ids, int timeoutMS) override; + ErrorInfo Put(std::shared_ptr data, const std::string &objID, + const std::unordered_set &nestedID, const CreateParam &createParam) override; + SingleResult Get(const std::string &objID, int timeoutMS) override; + MultipleResult Get(const std::vector &ids, int timeoutMS) override; + ErrorInfo IncreGlobalReference(const std::vector &objectIds) override; + ErrorInfo DecreGlobalReference(const std::vector &objectIds) override; + ErrorInfo UpdateToken(datasystem::SensitiveValue token) override; + ErrorInfo UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) override; + std::vector QueryGlobalReference(const std::vector &objectIds) override + { + STDERR_AND_THROW_EXCEPTION(ERR_INNER_SYSTEM_ERROR, RUNTIME, + "QueryGlobalReference method is not supported when inCluster is false"); + } + ErrorInfo ReleaseGRefs(const std::string &remoteId) override; + ErrorInfo GenerateKey(std::string &key, const std::string &prefix, bool isPut) override; + ErrorInfo GetPrefix(const std::string &key, std::string &prefix) override; + ErrorInfo Write(const std::string &key, std::shared_ptr value, SetParam setParam) override; + ErrorInfo MSetTx(const std::vector &keys, const std::vector> &vals, + const MSetParam &mSetParam) override; + SingleReadResult Read(const std::string &key, int timeoutMS) override; + MultipleReadResult Read(const std::vector &keys, int timeoutMS, bool allowPartial) override; + ErrorInfo QuerySize(const std::vector &keys, std::vector &outSizes) override + { + STDERR_AND_THROW_EXCEPTION(ERR_INNER_SYSTEM_ERROR, RUNTIME, + "QuerySize method is not supported when inCluster is false"); + } + ErrorInfo Del(const std::string &key) override; + MultipleDelResult Del(const std::vector &keys) override; + MultipleExistResult Exist(const std::vector &keys) override + { + auto result = std::make_shared>(); + result->resize(keys.size()); + ErrorInfo err(ERR_INNER_SYSTEM_ERROR, "Exist method is not supported when inCluster is false"); + return std::make_pair(*result, err); + } + ErrorInfo PosixGDecreaseRef(const std::vector &objectIds, + std::shared_ptr> failedObjectIds); + void Shutdown() override {} + void SetTenantId(const std::string &tenantId) override; + ErrorInfo Init(const DsConnectOptions &options) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, + "Init method with DsConnectOptions not implemented when inCluster is false"); + } + ErrorInfo GenerateKey(std::string &returnKey) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, "GenerateKey method is not supported when inCluster is false"); + } + ErrorInfo StartHealthCheck() override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, "HealthCheck method is not supported when inCluster is false"); + } + ErrorInfo Write(std::shared_ptr value, SetParam setParam, std::string &returnKey) + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, "Write method is not supported when inCluster is false"); + } + ErrorInfo CreateStreamProducer(const std::string &streamName, std::shared_ptr &producer, + ProducerConf producerConf = {}) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, + "CreateStreamProducer method is not supported when inCluster is false"); + } + ErrorInfo CreateStreamConsumer(const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck = false) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, + "CreateStreamConsumer method is not supported when inCluster is false"); + } + ErrorInfo DeleteStream(const std::string &streamName) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, "DeleteStream method is not supported when inCluster is false"); + } + ErrorInfo QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, + "QueryGlobalProducersNum method is not supported when inCluster is false"); + } + ErrorInfo QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, + "QueryGlobalConsumersNum method is not supported when inCluster is false"); + } + + ErrorInfo DevDelete(const std::vector &objectIds, std::vector &failedObjectIds) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, "Delete method is not supported when inCluster is false"); + } + + ErrorInfo DevLocalDelete(const std::vector &objectIds, + std::vector &failedObjectIds) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, "DevLocalDelete method is not supported when inCluster is false"); + } + + ErrorInfo DevSubscribe(const std::vector &keys, const std::vector &blob2dList, + std::vector> &futureVec) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, "DevSubscribe method is not supported when inCluster is false"); + } + + ErrorInfo DevPublish(const std::vector &keys, const std::vector &blob2dList, + std::vector> &futureVec) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, "DevPublish method is not supported when inCluster is false"); + } + + ErrorInfo DevMSet(const std::vector &keys, const std::vector &blob2dList, + std::vector &failedKeys) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, "DevMSet method is not supported when inCluster is false"); + } + + ErrorInfo DevMGet(const std::vector &keys, const std::vector &blob2dList, + std::vector &failedKeys, int32_t timeoutMs) override + { + return ErrorInfo(ERR_INNER_SYSTEM_ERROR, "DevMGet method is not supported when inCluster is false"); + } + MultipleReadResult GetWithParam(const std::vector &keys, const GetParams ¶ms, + int timeoutMs) override + { + auto result = std::make_shared>>(); + result->resize(keys.size()); + ErrorInfo err(ERR_INNER_SYSTEM_ERROR, "GetWithParam method is not supported when inCluster is false"); + return std::make_pair(*result, err); + } + + bool IsHealth() override + { + return true; + }; + + ErrorInfo HealthCheck() override + { + return ErrorInfo(); + } + +private: + ErrorInfo Lease(); + ErrorInfo Release(); + ErrorInfo KeepLease(); + + std::unordered_map BuildHeaders(const std::string &remoteClientId, + const std::string &traceId, const std::string &tenantId); + std::pair, std::string> BuildRequestWithAkSk( + const std::shared_ptr spec, const std::string &url); + LeaseRequest BuildLeaseRequest(); + PutRequest BuildObjPutRequest(std::shared_ptr data, const std::string &objID, + const std::unordered_set &nestedID, const CreateParam &createParam); + GetRequest BuildObjGetRequest(const std::vector &keys, int32_t timeoutMs); + KvGetRequest BuildKvGetRequest(const std::vector &keys, int32_t timeoutMs); + KvSetRequest BuildKvSetRequest(const std::string &key, const std::shared_ptr value, + const SetParam &setParam); + KvMSetTxRequest BuildKvMSetTxRequest(const std::vector &keys, + const std::vector> &vals, const MSetParam &mSetParam); + KvDelRequest BuildKvDelRequest(const std::vector &keys); + IncreaseRefRequest BuildIncreaseRefRequest(const std::vector &objectIds); + DecreaseRefRequest BuildDecreaseRefRequest(const std::vector &objectIds); + + ErrorInfo ParseLeaseResponse(const std::string &result); + ErrorInfo ParseObjGetResponse(const std::string &result, + std::shared_ptr>> results); + ErrorInfo ParseObjPutResponse(const std::string &result); + ErrorInfo ParseKvGetResponse(const std::string &result, + std::shared_ptr>> results); + ErrorInfo ParseKvSetResponse(const std::string &result); + ErrorInfo ParseKvMSetTxResponse(const std::string &result); + ErrorInfo ParseKvDelResponse(const std::string &result, std::shared_ptr> failedKeys); + + ErrorInfo ParseIncreaseRefResponse(const std::string &result, + std::shared_ptr> failedObjectIds); + + ErrorInfo ParseDecreaseRefResponse(const std::string &result, + std::shared_ptr> failedObjectIds); + + ErrorInfo PosixObjPut(const PutRequest &req); + ErrorInfo PosixObjGet(const std::vector &keys, + std::shared_ptr>> results, int32_t timeoutMs = 0); + ErrorInfo PosixKvSet(const std::string &key, const std::shared_ptr value, const SetParam &setParam); + ErrorInfo PosixKvMSetTx(const std::vector &keys, const std::vector> &vals, + const MSetParam &mSetParam); + ErrorInfo PosixKvGet(const std::vector &keys, + std::shared_ptr>> results, int32_t timeoutMs = 0); + ErrorInfo PosixKvDel(const std::vector &keys, std::shared_ptr> failedKeys); + + ErrorInfo PosixGInCreaseRef(const std::vector &objectIds, + std::shared_ptr> failedObjectIds); + + ErrorInfo GenRspError(const boost::beast::error_code &errorCode, const uint statusCode, const std::string &result, + const std::shared_ptr requestId); + + ErrorInfo HandleLease(const std::string &url, const http::verb &verb); + +private: + std::shared_ptr httpClient_; + bool init_ = false; + bool start_ = false; + std::string funcName_; + std::string funcVersion_; + std::string funcId_; + std::string jobId_; + AsyncDecreRef asyncDecreRef_; + RefCountMap refCountMap_; + std::shared_ptr timerWorker_; + std::shared_ptr timer_; + int lostLeaseTimes_ = 0; + std::int32_t connectTimeout_ = DS_CONNECT_TIMEOUT; + std::shared_ptr security_; +}; + +class ClientBuffer : public NativeBuffer { +public: + ClientBuffer(uint64_t size, const std::string &objectId, const CreateParam &createParam, std::weak_ptr c) + : NativeBuffer(size), gwClientWeak_(c), objectId_(objectId), createParam_(createParam) + { + } + virtual ~ClientBuffer() = default; + + virtual ErrorInfo Seal(const std::unordered_set &nestedIds) override; + +private: + std::weak_ptr gwClientWeak_; + std::string objectId_; + CreateParam createParam_; +}; + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/gwclient/gw_datasystem_client_wrapper.h b/src/libruntime/gwclient/gw_datasystem_client_wrapper.h new file mode 100644 index 0000000..f985bd8 --- /dev/null +++ b/src/libruntime/gwclient/gw_datasystem_client_wrapper.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "datasystem/object_client.h" + +#include "src/libruntime/objectstore/datasystem_client_wrapper.h" + +#include "src/libruntime/gwclient/gw_client.h" + +namespace YR { +namespace Libruntime { +datasystem::StatusCode oneOfNoRetryCode = datasystem::StatusCode::K_RUNTIME_ERROR; +class GwDatasystemClientWrapper : public DatasystemClientWrapper { +public: + GwDatasystemClientWrapper(std::shared_ptr client) + { + gwClient = client; + } + + datasystem::Status GDecreaseRef(const std::vector &objectIds, + std::vector &failedObjectIds) + { + if (auto locked = gwClient.lock()) { + auto failedIdsPtr = std::make_shared>(); + auto err = locked->PosixGDecreaseRef(objectIds, failedIdsPtr); + if (err.OK()) { + failedObjectIds.assign(failedIdsPtr->begin(), failedIdsPtr->end()); + return datasystem::Status(datasystem::StatusCode::K_OK, err.Msg()); + } else { + return datasystem::Status(oneOfNoRetryCode, err.Msg()); + } + } else { + YRLOG_DEBUG("gw client pointer is expired."); + } + return {}; + } + + void SetTenantId(const std::string &tenantId) + { + if (auto locked = gwClient.lock()) { + locked->SetTenantId(tenantId); + } + } + +private: + std::weak_ptr gwClient; +}; + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/gwclient/http/async_http_client.cpp b/src/libruntime/gwclient/http/async_http_client.cpp index 3f91c25..1b52f5a 100644 --- a/src/libruntime/gwclient/http/async_http_client.cpp +++ b/src/libruntime/gwclient/http/async_http_client.cpp @@ -19,10 +19,9 @@ namespace YR { namespace Libruntime { +const char *HTTP_CONNECTION_ERROR_MSG = "connection error"; const int DEFAULT_HTTP_VERSION = 11; -int g_defaultIdleTime = 600; const uint HTTP_CONNECTION_ERROR_CODE = 999; -const char *HTTP_CONNECTION_ERROR_MSG = "connection error"; AsyncHttpClient::AsyncHttpClient(std::shared_ptr ctx) : resolver_(asio::make_strand(*ctx)), stream_(asio::make_strand(*ctx)) { @@ -30,20 +29,20 @@ AsyncHttpClient::AsyncHttpClient(std::shared_ptr ctx) AsyncHttpClient::~AsyncHttpClient() { - if (isConnectionAlive_) { + if (IsConnActive()) { GracefulExit(); } } void AsyncHttpClient::GracefulExit() noexcept { + SetConnInActive(); beast::error_code ec; stream_.socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec); if (ec) { YRLOG_WARN("failed to shutdown stream: {}", ec.message().c_str()); } stream_.close(); - isConnectionAlive_ = false; } void AsyncHttpClient::SubmitInvokeRequest(const http::verb &method, const std::string &target, @@ -51,8 +50,8 @@ void AsyncHttpClient::SubmitInvokeRequest(const http::verb &method, const std::s const std::string &body, const std::shared_ptr requestId, const HttpCallbackFunction &receiver) { - retried_ = false; callback_ = receiver; + retried_ = false; req_ = {method, target, DEFAULT_HTTP_VERSION}; for (auto &iter : headers) { req_.set(iter.first, iter.second); @@ -71,13 +70,20 @@ ErrorInfo AsyncHttpClient::Init(const ConnectionParam ¶m) { YRLOG_DEBUG("Http init, serverAddr = {}:{}", param.ip, param.port); connParam_ = param; + idleTime_ = param.idleTime; // sync connection try { auto const resolveRes = resolver_.resolve(param.ip, param.port); + if (param.timeoutSec != CONNECTION_NO_TIMEOUT) { + stream_.expires_after(std::chrono::seconds(param.timeoutSec)); + } stream_.connect(resolveRes); + if (param.timeoutSec != CONNECTION_NO_TIMEOUT) { + stream_.expires_never(); + } } catch (const std::exception &e) { std::stringstream ss; - ss << "failed to connect to all addresses, target: "; + ss << "failed to connect to cluster, target: "; ss << param.ip; ss << ":"; ss << param.port; @@ -86,9 +92,7 @@ ErrorInfo AsyncHttpClient::Init(const ConnectionParam ¶m) YRLOG_DEBUG(ss.str()); return ErrorInfo(ErrorCode::ERR_INIT_CONNECTION_FAILED, ModuleCode::RUNTIME, ss.str()); } - lastActiveTime_ = std::chrono::high_resolution_clock::now(); - isConnectionAlive_ = true; - isUsed_ = false; + ResetConnActive(); return ErrorInfo(); } @@ -97,63 +101,47 @@ void AsyncHttpClient::OnRead(const std::shared_ptr requestId, const { boost::ignore_unused(bytesTransferred); if (ec) { - YRLOG_ERROR("requestId {} failed to read response , err message: {}, this client disconnect", *requestId, + YRLOG_ERROR("requestId {} failed to read response , err message: {}, client disconnect", *requestId, ec.message().c_str()); - isConnectionAlive_ = false; + SetConnInActive(); } if (callback_) { callback_(resParser_->get().body(), ec, resParser_->get().result_int()); } - resParser_.reset(); - req_.clear(); - buf_.clear(); - isUsed_ = false; - lastActiveTime_ = std::chrono::high_resolution_clock::now(); -} - -ErrorInfo AsyncHttpClient::ReInit() -{ - GracefulExit(); - isConnectionAlive_ = false; - return Init(connParam_); + CheckResponseHeaderAndReset(); } void AsyncHttpClient::OnWrite(const std::shared_ptr requestId, beast::error_code &ec, std::size_t bytesTransferred) { boost::ignore_unused(bytesTransferred); - if (ec) { - YRLOG_ERROR("requestId {} failed to write, err message: {}, this client disconnect", *requestId, - ec.message().c_str()); - if (!retried_) { - YRLOG_DEBUG("requestId {} start to retry once", *requestId); - retried_ = true; - if (ReInit().OK()) { - http::async_write(stream_, req_, - beast::bind_front_handler(&AsyncHttpClient::OnWrite, shared_from_this(), requestId)); - return; - } - } - isConnectionAlive_ = false; - if (callback_) { - callback_(HTTP_CONNECTION_ERROR_MSG, ec, HTTP_CONNECTION_ERROR_CODE); - } - isUsed_ = false; + if (!ec) { + http::async_read(stream_, buf_, *resParser_, + beast::bind_front_handler(&AsyncHttpClient::OnRead, shared_from_this(), requestId)); return; } - http::async_read(stream_, buf_, *resParser_, - beast::bind_front_handler(&AsyncHttpClient::OnRead, shared_from_this(), requestId)); -} - -void AsyncHttpClient::Cancel() -{ - boost::system::error_code ec; - stream_.socket().cancel(ec); + YRLOG_ERROR("requestId {} failed to write, err message: {}, this client disconnect", *requestId, + ec.message().c_str()); + if (!retried_) { + YRLOG_DEBUG("requestId {} start to retry once", *requestId); + retried_ = true; + if (ReInit().OK()) { + http::async_write(stream_, req_, + beast::bind_front_handler(&AsyncHttpClient::OnWrite, shared_from_this(), requestId)); + return; + } + } + if (callback_) { + callback_(HTTP_CONNECTION_ERROR_MSG, ec, HTTP_CONNECTION_ERROR_CODE); + } + SetConnInActive(); + SetAvailable(); + return; } void AsyncHttpClient::Stop() { - if (isConnectionAlive_) { + if (IsConnActive()) { GracefulExit(); } } diff --git a/src/libruntime/gwclient/http/async_http_client.h b/src/libruntime/gwclient/http/async_http_client.h index 8055908..8a345ac 100644 --- a/src/libruntime/gwclient/http/async_http_client.h +++ b/src/libruntime/gwclient/http/async_http_client.h @@ -50,13 +50,11 @@ public: std::size_t bytesTransferred); void OnWrite(const std::shared_ptr requestId, beast::error_code &ec, std::size_t bytesTransferred); - ErrorInfo ReInit() override; + void Stop() override; - void Cancel() override; + void GracefulExit() noexcept override; - void Stop() override; private: - void GracefulExit() noexcept; asio::ip::tcp::resolver resolver_; beast::tcp_stream stream_; }; diff --git a/src/libruntime/gwclient/http/async_https_client.cpp b/src/libruntime/gwclient/http/async_https_client.cpp index 1922bf4..157414d 100644 --- a/src/libruntime/gwclient/http/async_https_client.cpp +++ b/src/libruntime/gwclient/http/async_https_client.cpp @@ -27,7 +27,7 @@ AsyncHttpsClient::AsyncHttpsClient(const std::shared_ptr &ioc, AsyncHttpsClient::~AsyncHttpsClient() { - if (isConnectionAlive_) { + if (IsConnActive()) { GracefulExit(); } } @@ -60,20 +60,29 @@ ErrorInfo AsyncHttpsClient::Init(const ConnectionParam ¶m) // Set SNI Hostname (hosts need this to handshake successfully) YRLOG_INFO("Https init, serverAddr = {}:{}", param.ip, param.port); connParam_ = param; + idleTime_ = connParam_.idleTime; std::string msg; - if (!SSL_set_tlsext_host_name(stream_->native_handle(), serverName_.c_str())) { - YRLOG_ERROR("failed to set servername: {}", serverName_); - msg = "failed to set servername during initing invoke client, serverName:" + serverName_; - return ErrorInfo(ErrorCode::ERR_INIT_CONNECTION_FAILED, ModuleCode::RUNTIME, msg); + if (!serverName_.empty()) { + if (!SSL_set_tlsext_host_name(stream_->native_handle(), serverName_.c_str())) { + YRLOG_ERROR("failed to set servername: {}", serverName_); + msg = "failed to set servername during initing invoke client, serverName:" + serverName_; + return ErrorInfo(ErrorCode::ERR_INIT_CONNECTION_FAILED, ModuleCode::RUNTIME, msg); + } } - try { // sync connect auto const results = resolver_.resolve(param.ip, param.port); - (void)beast::get_lowest_layer(*stream_).connect(results); + auto &lowgest = beast::get_lowest_layer(*stream_); + if (param.timeoutSec != CONNECTION_NO_TIMEOUT) { + lowgest.expires_after(std::chrono::seconds(param.timeoutSec)); + } + (void)lowgest.connect(results); + if (param.timeoutSec != CONNECTION_NO_TIMEOUT) { + lowgest.expires_never(); + } } catch (const std::exception &e) { std::stringstream ss; - ss << "failed to connect to all addresses, target: " << param.ip << ":" << param.port; + ss << "failed to connect to cluster, target: " << param.ip << ":" << param.port; ss << ", exception: " << e.what(); YRLOG_DEBUG(ss.str()); return ErrorInfo(ErrorCode::ERR_INIT_CONNECTION_FAILED, ModuleCode::RUNTIME, ss.str()); @@ -86,9 +95,7 @@ ErrorInfo AsyncHttpsClient::Init(const ConnectionParam ¶m) msg = "failed to handshake error during initing invoke client, err:" + serverName_; return ErrorInfo(ErrorCode::ERR_INIT_CONNECTION_FAILED, ModuleCode::RUNTIME, msg); } - lastActiveTime_ = std::chrono::high_resolution_clock::now(); - isConnectionAlive_ = true; - isUsed_ = false; + ResetConnActive(); return ErrorInfo(); } @@ -115,9 +122,8 @@ void AsyncHttpsClient::OnWrite(const std::shared_ptr requestId, con if (callback_) { callback_(HTTP_CONNECTION_ERROR_MSG, ec, HTTP_CONNECTION_ERROR_CODE); } - isConnectionAlive_ = false; - isUsed_ = false; - return; + SetConnInActive(); + SetAvailable(); } void AsyncHttpsClient::OnRead(const std::shared_ptr requestId, const beast::error_code &ec, @@ -127,35 +133,37 @@ void AsyncHttpsClient::OnRead(const std::shared_ptr requestId, cons if (ec) { YRLOG_ERROR("requestId {} failed to read response , err message: {}, this client disconnect", *requestId, ec.message().c_str()); - this->isConnectionAlive_ = false; + SetConnInActive(); } if (callback_) { callback_(resParser_->get().body(), ec, resParser_->get().result_int()); } - resParser_.reset(); - buf_.clear(); - req_.clear(); - isUsed_ = false; - lastActiveTime_ = std::chrono::high_resolution_clock::now(); + CheckResponseHeaderAndReset(); } void AsyncHttpsClient::GracefulExit() noexcept { + SetConnInActive(); boost::system::error_code ec = {}; if (stream_) { stream_->shutdown(ec); + if (ec) { + YRLOG_WARN("SSL shutdown failed: {}", ec.message().c_str()); + return; + } + auto &sock = stream_->next_layer().socket(); + sock.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec); + if (ec) { + YRLOG_WARN("Socket shutdown failed: {}", ec.message().c_str()); + return; + } + sock.close(ec); + if (ec) { + YRLOG_WARN("Socket close failed: {}", ec.message().c_str()); + return; + } } - if (ec) { - YRLOG_WARN("shutdown fail {}", ec.message().c_str()); - } - isConnectionAlive_ = false; -} - -ErrorInfo AsyncHttpsClient::ReInit() -{ - GracefulExit(); - return Init(connParam_); } void AsyncHttpsClient::Stop() diff --git a/src/libruntime/gwclient/http/async_https_client.h b/src/libruntime/gwclient/http/async_https_client.h index ee36a5b..2cedd8a 100644 --- a/src/libruntime/gwclient/http/async_https_client.h +++ b/src/libruntime/gwclient/http/async_https_client.h @@ -59,11 +59,11 @@ public: void OnWrite(const std::shared_ptr requestId, const beast::error_code &ec, std::size_t bytesTransferred); - ErrorInfo ReInit() override; - void Stop() override; + + void GracefulExit() noexcept override; + private: - void GracefulExit() noexcept; std::shared_ptr ioc_; std::shared_ptr ctx_; asio::ip::tcp::resolver resolver_; diff --git a/src/libruntime/gwclient/http/client_manager.cpp b/src/libruntime/gwclient/http/client_manager.cpp index 3f4f6f9..b0ab6be 100644 --- a/src/libruntime/gwclient/http/client_manager.cpp +++ b/src/libruntime/gwclient/http/client_manager.cpp @@ -25,7 +25,6 @@ #include "src/utility/notification_utility.h" namespace { -const uint32_t MAX_CONN_SIZE = 10000; const int RETRY_TIME = 3; const int INTERVAL_TIME = 2; } // namespace @@ -37,20 +36,25 @@ using YR::utility::NotificationUtility; ClientManager::ClientManager(const std::shared_ptr &libruntimeConfig) : librtCfg(libruntimeConfig) { this->ioc = std::make_shared(); - this->work = std::make_unique(*ioc); + this->work = std::make_unique>( + boost::asio::make_work_guard(*ioc)); this->maxIocThread = libruntimeConfig->httpIocThreadsNum; this->enableMTLS = libruntimeConfig->enableMTLS; + this->maxConnSize_ = libruntimeConfig->maxConnSize; + this->enableTLS_ = libruntimeConfig->enableTLS; } ClientManager::~ClientManager() { - this->work.reset(); + this->work->reset(); ioc->stop(); for (auto client : clients) { client->Stop(); } for (uint32_t i = 0; i < asyncRunners.size(); i++) { - this->asyncRunners[i]->join(); + if (this->asyncRunners[i]->get_id() != std::this_thread::get_id()) { + this->asyncRunners[i]->join(); + } } } @@ -65,8 +69,11 @@ ErrorInfo ClientManager::InitCtxAndIocThread() ctx->set_verify_mode(ssl::verify_peer); ctx->load_verify_file(librtCfg->verifyFilePath); ctx->use_certificate_chain_file(librtCfg->certificateFilePath); + ctx->set_password_callback([&](std::size_t max_length, ssl::context_base::password_purpose purpose) { + return std::string(librtCfg->privateKeyPaaswd); + }); ctx->use_private_key_file(librtCfg->privateKeyPath, ssl::context::pem); - for (uint32_t i = 0; i < MAX_CONN_SIZE; i++) { + for (uint32_t i = 0; i < maxConnSize_; i++) { this->clients.emplace_back(std::make_shared(this->ioc, ctx, librtCfg->serverName)); } } catch (const std::exception &e) { @@ -79,8 +86,21 @@ ErrorInfo ClientManager::InitCtxAndIocThread() "caught unknown exception when init context"); return err; } + } else if (enableTLS_) { + auto ctx = std::make_shared(ssl::context::tlsv12_client); + ctx->set_options(ssl::context::default_workarounds | ssl::context::no_sslv2 | ssl::context::no_sslv3 | + ssl::context::no_tlsv1 | ssl::context::no_tlsv1_1); + if (librtCfg->verifyFilePath.empty()) { + ctx->set_default_verify_paths(); + } else { + ctx->load_verify_file(librtCfg->verifyFilePath); + } + ctx->set_verify_mode(ssl::verify_peer); + for (uint32_t i = 0; i < maxConnSize_; i++) { + this->clients.emplace_back(std::make_shared(this->ioc, ctx, librtCfg->serverName)); + } } else { - for (uint32_t i = 0; i < MAX_CONN_SIZE; i++) { + for (uint32_t i = 0; i < maxConnSize_; i++) { this->clients.emplace_back(std::make_shared(this->ioc)); } } @@ -100,15 +120,12 @@ ErrorInfo ClientManager::Init(const ConnectionParam ¶m) return error; } this->connParam = param; - connectedClientsCnt = YR::Libruntime::Config::Instance().YR_HTTP_CONNECTION_NUM(); - YRLOG_INFO("http initial connection num {}", connectedClientsCnt); - if (connectedClientsCnt > MAX_CONN_SIZE) { - YRLOG_WARN("Requested connections exceed maximum allowed ({}). Clamping to maximum.", MAX_CONN_SIZE); - connectedClientsCnt = MAX_CONN_SIZE; - } - for (uint32_t i = 0; i < connectedClientsCnt; i++) { + connectedClientsCnt_ = YR::Libruntime::Config::Instance().YR_HTTP_CONNECTION_NUM(); + YRLOG_INFO("http initial connection num {}", connectedClientsCnt_); + for (uint32_t i = 0; i < connectedClientsCnt_; i++) { for (int j = 0; j < RETRY_TIME; j++) { error = clients[i]->Init(param); + clients[i]->SetAvailable(); if (error.OK()) { break; } @@ -126,48 +143,57 @@ void ClientManager::SubmitInvokeRequest(const http::verb &method, const std::str const std::string &body, const std::shared_ptr requestId, const HttpCallbackFunction &receiver) { - std::unique_lock lk(connMtx); - if (clients.empty()) { - YRLOG_ERROR("Clients are not initialized."); - receiver(HTTP_CONNECTION_ERROR_MSG, boost::asio::error::make_error_code(boost::asio::error::connection_reset), - HTTP_CONNECTION_ERROR_CODE); - return; - } for (;;) { - for (uint32_t i = 0; i < connectedClientsCnt; i++) { - if (!this->clients[i]->Available()) { - continue; + if (SubmitRequest(method, target, headers, body, requestId, receiver)) { + break; + } + std::this_thread::yield(); + } +} + +bool ClientManager::SubmitRequest(const http::verb &method, const std::string &target, + const std::unordered_map &headers, const std::string &body, + const std::shared_ptr requestId, const HttpCallbackFunction &receiver) +{ + for (uint32_t i = 0;; i++) { + { + absl::ReaderMutexLock l(&connCntMu_); + if (i >= connectedClientsCnt_) { + break; } - YRLOG_DEBUG("httpclient {} is available, will use this client", i); - // while the connection idletime exceed setup timeout, the server may close the connection - // in this situation, client should try to reconnect - if (!this->clients[i]->IsActive()) { - YRLOG_DEBUG("httpclient {} is not active, reinit now", i); - auto err = this->clients[i]->ReInit(); - if (!err.OK()) { - YRLOG_DEBUG("httpclient {} is reInit failed, err: {}", i, err.CodeAndMsg()); - receiver(HTTP_CONNECTION_ERROR_MSG, - boost::asio::error::make_error_code(boost::asio::error::connection_reset), - HTTP_CONNECTION_ERROR_CODE); - return; - } + } + if (!this->clients[i]->SetUnavailable()) { + continue; + } + YRLOG_DEBUG("httpclient {} is available, will use this client", i); + // while the connection idletime exceed setup timeout, the server may close the connection + // in this situation, client should try to reconnect + if (!this->clients[i]->IsConnActive()) { + YRLOG_DEBUG("httpclient {} is not active, reinit now", i); + auto err = this->clients[i]->ReInit(); + if (!err.OK()) { + YRLOG_DEBUG("httpclient {} is reInit failed, err: {}", i, err.CodeAndMsg()); + receiver(err.CodeAndMsg(), boost::asio::error::make_error_code(boost::asio::error::connection_reset), + HTTP_CONNECTION_ERROR_CODE); + this->clients[i]->SetAvailable(); + return true; } - this->clients[i]->SetUnavailable(); - lk.unlock(); - this->clients[i]->SubmitInvokeRequest(method, target, headers, body, requestId, receiver); - return; } - - if (connectedClientsCnt < MAX_CONN_SIZE) { - connectedClientsCnt++; - this->clients[connectedClientsCnt - 1]->Init(this->connParam); - this->clients[connectedClientsCnt - 1]->SetUnavailable(); - this->clients[connectedClientsCnt - 1]->SubmitInvokeRequest(method, target, headers, body, requestId, - receiver); - return; + this->clients[i]->SubmitInvokeRequest(method, target, headers, body, requestId, receiver); + return true; + } + uint32_t newClientIdx = 0; + { + absl::WriterMutexLock l(&connCntMu_); + if (connectedClientsCnt_ >= maxConnSize_) { + return false; } - std::this_thread::yield(); + newClientIdx = connectedClientsCnt_++; } + YRLOG_DEBUG("init httpclient {}", newClientIdx); + this->clients[newClientIdx]->Init(this->connParam); + this->clients[newClientIdx]->SubmitInvokeRequest(method, target, headers, body, requestId, receiver); + return true; } } // namespace Libruntime } // namespace YR diff --git a/src/libruntime/gwclient/http/client_manager.h b/src/libruntime/gwclient/http/client_manager.h index 3603528..6d9fbc9 100644 --- a/src/libruntime/gwclient/http/client_manager.h +++ b/src/libruntime/gwclient/http/client_manager.h @@ -32,25 +32,30 @@ public: ClientManager(const std::shared_ptr &librtCfg); ~ClientManager() override; - ErrorInfo Init(const ConnectionParam ¶m) override; - void SubmitInvokeRequest(const http::verb &method, const std::string &target, - const std::unordered_map &headers, const std::string &body, - const std::shared_ptr requestId, - const HttpCallbackFunction &receiver) override; + virtual ErrorInfo Init(const ConnectionParam ¶m) override; + virtual void SubmitInvokeRequest(const http::verb &method, const std::string &target, + const std::unordered_map &headers, + const std::string &body, const std::shared_ptr requestId, + const HttpCallbackFunction &receiver) override; private: + bool SubmitRequest(const http::verb &method, const std::string &target, + const std::unordered_map &headers, const std::string &body, + const std::shared_ptr requestId, const HttpCallbackFunction &receiver); ErrorInfo InitCtxAndIocThread(); std::shared_ptr ioc; - std::unique_ptr work; + std::unique_ptr> work; std::vector> asyncRunners; ConnectionParam connParam; std::vector> clients; - uint32_t connectedClientsCnt; + uint32_t connectedClientsCnt_ ABSL_GUARDED_BY(connCntMu); + mutable absl::Mutex connCntMu_; std::shared_ptr librtCfg; - std::mutex connMtx; uint32_t maxIocThread; bool enableMTLS; + bool enableTLS_{false}; + uint32_t maxConnSize_; }; } // namespace Libruntime } // namespace YR diff --git a/src/libruntime/gwclient/http/http_client.h b/src/libruntime/gwclient/http/http_client.h index 5035532..b306121 100644 --- a/src/libruntime/gwclient/http/http_client.h +++ b/src/libruntime/gwclient/http/http_client.h @@ -20,12 +20,12 @@ #include #include -#include "absl/synchronization/mutex.h" #include #include #include +#include "absl/synchronization/mutex.h" +#include "src/dto/config.h" #include "src/libruntime/err_type.h" - namespace http = boost::beast::http; namespace ssl = boost::asio::ssl; namespace beast = boost::beast; @@ -38,12 +38,14 @@ const http::verb DELETE = http::verb::delete_; const http::verb GET = http::verb::get; const http::verb PUT = http::verb::put; extern const int DEFAULT_HTTP_VERSION; -extern int g_defaultIdleTime; extern const uint HTTP_CONNECTION_ERROR_CODE; +const int CONNECTION_NO_TIMEOUT = -1; extern const char *HTTP_CONNECTION_ERROR_MSG; struct ConnectionParam { std::string ip; std::string port; + int idleTime{120}; + int timeoutSec = CONNECTION_NO_TIMEOUT; }; class HttpClient { @@ -54,28 +56,47 @@ public: const std::unordered_map &headers, const std::string &body, const std::shared_ptr requestId, const HttpCallbackFunction &receiver) = 0; - virtual void RegisterHeartbeat(const std::string &jobID, int timeout) {}; - virtual bool Available() const + virtual ErrorInfo ReInit() { - absl::ReaderMutexLock l(&mu_); - return !this->isUsed_; - }; + GracefulExit(); + const int totalRetryCount = YR::Libruntime::Config::Instance().MAX_HTTP_RETRY_TIME(); + const int maxTimeoutSec = YR::Libruntime::Config::Instance().MAX_HTTP_TIMEOUT_SEC(); + int timeoutSec = YR::Libruntime::Config::Instance().INITIAL_HTTP_CONNECT_SEC(); + int retryCount = 0; + int backoffFactor = 2; + ErrorInfo err; + while (retryCount < totalRetryCount) { + connParam_.timeoutSec = timeoutSec; + err = Init(connParam_); + if (err.OK()) { + YRLOG_DEBUG("client reinit success"); + connParam_.timeoutSec = CONNECTION_NO_TIMEOUT; + return err; + } + retryCount++; + if (timeoutSec != CONNECTION_NO_TIMEOUT) { + timeoutSec = std::min(timeoutSec * backoffFactor, maxTimeoutSec); + } + YRLOG_DEBUG("retry count {}, init err: {}", retryCount, err.Msg()); + } + connParam_.timeoutSec = CONNECTION_NO_TIMEOUT; + return err; + } - virtual bool IsActive() const - { - auto current = std::chrono::high_resolution_clock::now(); - auto idle = std::chrono::duration_cast(current - this->lastActiveTime_).count(); - return isConnectionAlive_ && idle < g_defaultIdleTime; - }; + virtual void Stop() {} + + virtual void GracefulExit() noexcept {} - virtual bool IsConnActive() const + bool SetUnavailable() { - absl::ReaderMutexLock l(&mu_); - auto current = std::chrono::high_resolution_clock::now(); - auto idle = std::chrono::duration_cast(current - this->lastActiveTime_).count(); - return isConnectionAlive_ && idle < idleTime_; - }; + absl::WriterMutexLock l(&mu_); + if (isUsed_) { + return false; + } + isUsed_ = true; + return true; + } void SetAvailable() { @@ -83,30 +104,64 @@ public: isUsed_ = false; } - virtual ErrorInfo ReInit() + void ResetConnActive() { - return ErrorInfo(); + absl::WriterMutexLock l(&mu_); + lastActiveTime_ = std::chrono::high_resolution_clock::now(); + isConnectionAlive_ = true; } - virtual void Cancel() {} - virtual void Stop() {} - void SetUnavailable() + void ResetConnActiveTime() { - isUsed_ = true; + absl::WriterMutexLock l(&mu_); + lastActiveTime_ = std::chrono::high_resolution_clock::now(); + } + + void SetConnInActive() + { + absl::WriterMutexLock l(&mu_); + isConnectionAlive_ = false; } + bool Available() const + { + absl::ReaderMutexLock l(&mu_); + return !this->isUsed_; + }; + + bool IsConnActive() const + { + absl::ReaderMutexLock l(&mu_); + auto current = std::chrono::high_resolution_clock::now(); + auto idle = std::chrono::duration_cast(current - this->lastActiveTime_).count(); + return isConnectionAlive_ && idle < idleTime_; + }; + + void CheckResponseHeaderAndReset() + { + auto headers = resParser_->get().base(); + if (auto it = headers.find("Connection"); it != headers.end() && it->value() == "close") { + SetConnInActive(); + } + resParser_.reset(); + buf_.clear(); + req_.clear(); + ResetConnActiveTime(); + SetAvailable(); + }; + protected: ConnectionParam connParam_; HttpCallbackFunction callback_; beast::flat_buffer buf_; std::shared_ptr> resParser_; http::request req_; - std::atomic isUsed_{false}; - std::atomic isConnectionAlive_{false}; - std::chrono::time_point lastActiveTime_; + bool isUsed_{true} ABSL_GUARDED_BY(mu_); + bool isConnectionAlive_{false} ABSL_GUARDED_BY(mu_); + std::chrono::time_point lastActiveTime_ ABSL_GUARDED_BY(mu_); bool retried_{false}; - mutable absl::Mutex mu_; int idleTime_{120}; + mutable absl::Mutex mu_; }; inline bool IsResponseSuccessful(const uint statusCode) @@ -115,5 +170,19 @@ inline bool IsResponseSuccessful(const uint statusCode) const uint SUCCESS_CODE_MAX = 299; return (statusCode >= SUCCESS_CODE_MIN && statusCode <= SUCCESS_CODE_MAX); } + +inline bool IsResponseServerError(const uint statusCode) +{ + const uint SUCCESS_CODE_MIN = 500; + const uint SUCCESS_CODE_MAX = 599; + return (statusCode >= SUCCESS_CODE_MIN && statusCode <= SUCCESS_CODE_MAX); +} + +inline bool IsResponseClientError(const uint statusCode) +{ + const uint SUCCESS_CODE_MIN = 400; + const uint SUCCESS_CODE_MAX = 499; + return (statusCode >= SUCCESS_CODE_MIN && statusCode <= SUCCESS_CODE_MAX); +} } // namespace Libruntime } // namespace YR \ No newline at end of file diff --git a/src/libruntime/heterostore/datasystem_hetero_store.cpp b/src/libruntime/heterostore/datasystem_hetero_store.cpp index 6577964..c12ff15 100644 --- a/src/libruntime/heterostore/datasystem_hetero_store.cpp +++ b/src/libruntime/heterostore/datasystem_hetero_store.cpp @@ -77,21 +77,22 @@ void DatasystemHeteroStore::Shutdown() datasystem::Status status = dsHeteroClient->ShutDown(); std::string msg = "shutdown hetero client failed, errMsg:" + status.ToString(); if (!status.IsOk()) { - YRLOG_WARN("hetero object client Shutdown fail. Status code: {}, Msg: {}", status.GetCode(), status.ToString()); + YRLOG_WARN("hetero object client Shutdown fail. Status code: {}, Msg: {}", fmt::underlying(status.GetCode()), + status.ToString()); } } -ErrorInfo DatasystemHeteroStore::Delete(const std::vector &objectIds, +ErrorInfo DatasystemHeteroStore::DevDelete(const std::vector &objectIds, std::vector &failedObjectIds) { HETERO_STORE_INIT_ONCE(); - datasystem::Status status = dsHeteroClient->Delete(objectIds, failedObjectIds); + datasystem::Status status = dsHeteroClient->DevDelete(objectIds, failedObjectIds); std::string msg = "delete hetero object failed, errMsg:" + status.ToString(); RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); return ErrorInfo(); } -ErrorInfo DatasystemHeteroStore::LocalDelete(const std::vector &objectIds, +ErrorInfo DatasystemHeteroStore::DevLocalDelete(const std::vector &objectIds, std::vector &failedObjectIds) { HETERO_STORE_INIT_ONCE(); diff --git a/src/libruntime/heterostore/datasystem_hetero_store.h b/src/libruntime/heterostore/datasystem_hetero_store.h index 00c96a9..8985a39 100644 --- a/src/libruntime/heterostore/datasystem_hetero_store.h +++ b/src/libruntime/heterostore/datasystem_hetero_store.h @@ -27,9 +27,9 @@ public: ~DatasystemHeteroStore() override = default; void Shutdown() override; ErrorInfo Init(datasystem::ConnectOptions &connectOptions) override; - ErrorInfo Delete(const std::vector &objectIds, std::vector &failedObjectIds) override; - ErrorInfo LocalDelete(const std::vector &objectIds, - std::vector &failedObjectIds) override; + ErrorInfo DevDelete(const std::vector &objectIds, std::vector &failedObjectIds) override; + ErrorInfo DevLocalDelete(const std::vector &objectIds, + std::vector &failedObjectIds) override; ErrorInfo DevSubscribe(const std::vector &keys, const std::vector &blob2dList, std::vector> &futureVec) override; ErrorInfo DevPublish(const std::vector &keys, const std::vector &blob2dList, diff --git a/src/libruntime/heterostore/hetero_future.cpp b/src/libruntime/heterostore/hetero_future.cpp index cdf3807..d03403d 100644 --- a/src/libruntime/heterostore/hetero_future.cpp +++ b/src/libruntime/heterostore/hetero_future.cpp @@ -19,10 +19,27 @@ namespace YR { namespace Libruntime { + +YR::Libruntime::AsyncResult ConverDsAsyncResultToLib(datasystem::AsyncResult dsResult) +{ + YR::Libruntime::AsyncResult result; + YRLOG_DEBUG("convert async result, code is {}, msg is {}, failedList size is {}", + fmt::underlying(dsResult.status.GetCode()), dsResult.status.GetMsg(), result.failedList.size()); + if (dsResult.status.IsOk()) { + result.error = YR::Libruntime::ErrorInfo(); + } else { + result.error = ErrorInfo(YR::Libruntime::ConvertDatasystemErrorToCore(dsResult.status.GetCode()), + dsResult.status.GetMsg()); + } + result.failedList = dsResult.failedList; + return result; +} + YR::Libruntime::AsyncResult ConverDsStatusToAsyncRes(datasystem::Status dsStatus) { YR::Libruntime::AsyncResult result; - YRLOG_DEBUG("convert async result from status, code is {}, msg is {}", dsStatus.GetCode(), dsStatus.GetMsg()); + YRLOG_DEBUG("convert async result from status, code is {}, msg is {}", fmt::underlying(dsStatus.GetCode()), + dsStatus.GetMsg()); if (dsStatus.IsOk()) { result.error = YR::Libruntime::ErrorInfo(); } else { @@ -32,6 +49,11 @@ YR::Libruntime::AsyncResult ConverDsStatusToAsyncRes(datasystem::Status dsStatus return result; } +HeteroFuture::HeteroFuture(std::shared_ptr> dsFuture) +{ + this->sharedFuture_ = dsFuture; +} + HeteroFuture::HeteroFuture(std::shared_ptr dsFuture) { this->dsFuture_ = dsFuture; @@ -45,6 +67,10 @@ bool HeteroFuture::IsDsFuture() YR::Libruntime::AsyncResult HeteroFuture::Get() { + if (!this->isDsFuture_) { + auto dsResult = this->sharedFuture_->get(); + return ConverDsAsyncResultToLib(dsResult); + } datasystem::Status status = this->dsFuture_->Get(); return ConverDsStatusToAsyncRes(status); } diff --git a/src/libruntime/heterostore/hetero_future.h b/src/libruntime/heterostore/hetero_future.h index 2287db7..c150cf3 100644 --- a/src/libruntime/heterostore/hetero_future.h +++ b/src/libruntime/heterostore/hetero_future.h @@ -26,11 +26,13 @@ class HeteroFuture { public: HeteroFuture() = default; ~HeteroFuture() = default; + explicit HeteroFuture(std::shared_ptr> dsFuture); explicit HeteroFuture(std::shared_ptr dsFuture); YR::Libruntime::AsyncResult Get(); bool IsDsFuture(); private: + std::shared_ptr> sharedFuture_; std::shared_ptr dsFuture_; bool isDsFuture_ = false; }; diff --git a/src/libruntime/heterostore/hetero_store.h b/src/libruntime/heterostore/hetero_store.h index f96eb67..f159754 100644 --- a/src/libruntime/heterostore/hetero_store.h +++ b/src/libruntime/heterostore/hetero_store.h @@ -42,13 +42,29 @@ public: */ virtual void Shutdown() = 0; + /** + * @brief For device object, to async get multiple objects + * @param[in] objectIds multiple keys support + * @param[out] devBlobList vector of blobs, only modify the data pointed to by the pointer. + * @param[in] timeoutMs max waiting time of getting data + * @return future of AsyncResult, describe get ErrorInfo and failed list. + */ + + /** + * @brief For device object Async set multiple objects, and return before publish rpc called. + * @param[in] objIds multiple keys support + * @param[in] devBlobList vector of blobs + * @return future of AsyncResult, describe set ErrorInfo and failed list. + */ + /** * @brief Invoke worker client to delete all the given objectId. * @param[in] objectIds The vector of the objId. * @param[out] failedObjectIds The failed delete objIds. * @return ERR_OK on any key success; the error code otherwise. */ - virtual ErrorInfo Delete(const std::vector &objectIds, std::vector &failedObjectIds) = 0; + virtual ErrorInfo DevDelete(const std::vector &objectIds, + std::vector &failedObjectIds) = 0; /** * @brief LocalDelete interface. After calling this interface, the data replica stored in the data system by the @@ -57,9 +73,19 @@ public: * @param[out] failedObjectIds Partial failures will be returned through this parameter. * @return ERR_OK on when return sucesss; the error code otherwise. */ - virtual ErrorInfo LocalDelete(const std::vector &objectIds, + virtual ErrorInfo DevLocalDelete(const std::vector &objectIds, std::vector &failedObjectIds) = 0; + /** + * @brief Initialize multipath transfer for a given target device (current context) + * @param[in] devices The device ids participating in the multipath transfer + * @return ErrorInfo of the call. + */ + + /** + * @brief Destroys multipath transfer for the current target device (context) + * @return ErrorInfo of the call. + */ /** * @brief Subscribe data from device. * @param[in] keys A list of keys corresponding to the blob2dList. @@ -91,6 +117,14 @@ public: virtual ErrorInfo DevMSet(const std::vector &keys, const std::vector &blob2dList, std::vector &failedKeys) = 0; + /** + * @brief Retrieves Device data through prefetching, completing the operation and returning immediately, requiring + * combination with DevMGet + * @param[in] keys Keys corresponding to blob2dList + * @param[in] blob2dList List describing the structure of Device memory + * @return ERR_OK on when return sucesssfully; the error code otherwise. + */ + /** * @brief Retrieves data from the Device through the data system, storing it in the corresponding DeviceBlobList * @param[in] keys Keys corresponding to blob2dList diff --git a/src/libruntime/invoke_order_manager.cpp b/src/libruntime/invoke_order_manager.cpp index 1701d57..093776e 100644 --- a/src/libruntime/invoke_order_manager.cpp +++ b/src/libruntime/invoke_order_manager.cpp @@ -30,15 +30,13 @@ void InvokeOrderManager::CreateInstance(std::shared_ptr spec) if (instanceId.empty()) { return; } - YRLOG_DEBUG("instanceid is {}, function meta name is {}, function meta ns is {}", instanceId, - spec->functionMeta.name.value_or("NONE"), spec->functionMeta.ns.value_or("NONE")); absl::MutexLock lock(&mu); if (instances.find(instanceId) == instances.end()) { - YRLOG_DEBUG("insert instance for ordering, instance id: {}", instanceId); auto [it, inserted] = instances.insert({instanceId, ConstuctInstOrder()}); (void)inserted; spec->invokeSeqNo = it->second->orderingCounter++; + YRLOG_DEBUG("current order of instance {} is : {}", instanceId, instances[instanceId]->orderingCounter.load()); } else { YRLOG_DEBUG("insert instance for ordering, instance already exists, instance id: {}", instanceId); } @@ -93,6 +91,21 @@ void InvokeOrderManager::RegisterInstance(const std::string &instanceId) } } +void InvokeOrderManager::RegisterInstanceAndUpdateOrder(const std::string &instanceId) +{ + if (instanceId.empty()) { + return; + } + absl::MutexLock lock(&mu); + if (instances.find(instanceId) == instances.end()) { + instances.insert({instanceId, ConstuctInstOrder()}); + instances[instanceId]->orderingCounter++; + YRLOG_DEBUG("current order of instance {} is : {}", instanceId, instances[instanceId]->orderingCounter.load()); + } else { + YRLOG_DEBUG("register instance for ordering, instance already exists, instance id: {}", instanceId); + } +} + void InvokeOrderManager::RemoveInstance(std::shared_ptr spec) { if (!spec->opts.needOrder || spec->returnIds.empty()) { @@ -106,9 +119,6 @@ void InvokeOrderManager::RemoveInstance(std::shared_ptr spec) return; } - YRLOG_DEBUG( - "start Romove instanceid from order manager, id is {}, function meta name is {}, function meta ns is {}", - instanceId, spec->functionMeta.name.value_or("NONE"), spec->functionMeta.ns.value_or("NONE")); absl::MutexLock lock(&mu); if (instances.find(instanceId) != instances.end()) { YRLOG_DEBUG("remove instance for ordering, instance id: {}", instanceId); @@ -131,8 +141,11 @@ void InvokeOrderManager::Invoke(std::shared_ptr spec) if (instances.find(instanceId) != instances.end()) { auto instOrder = instances[instanceId]; spec->invokeSeqNo = instOrder->orderingCounter++; - YRLOG_DEBUG("instance invoke with order, instance id: {}, request id: {}, sequence No.: {}, unfinished: {}", - instanceId, spec->requestId, spec->invokeSeqNo, instOrder->unfinishedSeqNo); + YRLOG_DEBUG( + "instance id: {}, request id: {}, invoke sequence No.: {}, ordered count: {}, " + "unfinished: {}", + instanceId, spec->requestId, spec->invokeSeqNo, instOrder->orderingCounter.load(), + instOrder->unfinishedSeqNo); } else { if (spec->opts.isGetInstance) { YRLOG_DEBUG("when inovke type is get named instance, need insert instance for ordering, instance id: {}", @@ -202,22 +215,25 @@ void InvokeOrderManager::NotifyInvokeSuccess(std::shared_ptr spec) if (instanceId.empty()) { return; } + this->UpdateFinishReqSeqNo(instanceId, spec->invokeSeqNo); +} +void InvokeOrderManager::UpdateFinishReqSeqNo(const std::string &instanceId, int64_t invokeSeqNo) +{ absl::MutexLock lock(&mu); if (instances.find(instanceId) != instances.end()) { auto instOrder = instances[instanceId]; - instOrder->finishedUnorderedInvokeSpecs.insert({spec->invokeSeqNo, spec}); - auto it = instOrder->finishedUnorderedInvokeSpecs.begin(); - while (it != instOrder->finishedUnorderedInvokeSpecs.end()) { - if (it->first == instOrder->unfinishedSeqNo) { + instOrder->finishedOrderedReqSeqNo.emplace(invokeSeqNo); + auto it = instOrder->finishedOrderedReqSeqNo.begin(); + while (it != instOrder->finishedOrderedReqSeqNo.end()) { + if (*it == instOrder->unfinishedSeqNo) { instOrder->unfinishedSeqNo++; - it = instOrder->finishedUnorderedInvokeSpecs.erase(it); + it = instOrder->finishedOrderedReqSeqNo.erase(it); } else { break; } } - YRLOG_DEBUG("current unfinished sequence No. is {}, instance id: {}, finished unordered spec size: {}", - instOrder->unfinishedSeqNo, instanceId, instOrder->finishedUnorderedInvokeSpecs.size()); + YRLOG_DEBUG("current unfinished sequence No. is {}, instance id: {}", instOrder->unfinishedSeqNo, instanceId); } } diff --git a/src/libruntime/invoke_order_manager.h b/src/libruntime/invoke_order_manager.h index 3feb97a..0b9efa9 100644 --- a/src/libruntime/invoke_order_manager.h +++ b/src/libruntime/invoke_order_manager.h @@ -27,10 +27,12 @@ public: virtual ~InvokeOrderManager() = default; void CreateInstance(std::shared_ptr spec); void RegisterInstance(const std::string &instanceId); + void RegisterInstanceAndUpdateOrder(const std::string &instanceId); void RemoveInstance(std::shared_ptr spec); void Invoke(std::shared_ptr spec); void UpdateUnfinishedSeq(std::shared_ptr spec); void NotifyInvokeSuccess(std::shared_ptr spec); + void UpdateFinishReqSeqNo(const std::string &instanceId, int64_t invokeSeqNo); void ClearInsOrderMsg(const std::string &insId, int signal); void RemoveGroupInstance(const std::string &instanceId); void CreateGroupInstance(const std::string &instanceId); @@ -49,6 +51,7 @@ struct InstanceOrdering { std::atomic orderingCounter{0}; int64_t unfinishedSeqNo = 0; std::map> finishedUnorderedInvokeSpecs; + std::set finishedOrderedReqSeqNo; }; } // namespace Libruntime } // namespace YR \ No newline at end of file diff --git a/src/libruntime/invoke_spec.cpp b/src/libruntime/invoke_spec.cpp index b407c20..e496290 100644 --- a/src/libruntime/invoke_spec.cpp +++ b/src/libruntime/invoke_spec.cpp @@ -23,11 +23,14 @@ const char *STORAGE_TYPE = "storage_type"; const char *CODE_PATH = "code_path"; const char *WORKING_DIR = "working_dir"; const char *DELEGATE_DOWNLOAD = "DELEGATE_DOWNLOAD"; +const char *ENABLE_DEBUG_KEY = "enable"; +const char *ENABLE_DEBUG = "true"; +const char *DEBUG_CONFIG_KEY = "debug_config"; InvokeSpec::InvokeSpec(const std::string &jobId, const FunctionMeta &functionMeta, const std::vector &returnObjs, std::vector invokeArgs, const libruntime::InvokeType invokeType, std::string traceId, std::string requestId, - const std::string &instanceId, const InvokeOptions &opts) + const std::string &instanceId, InvokeOptions opts) : jobId(jobId), functionMeta(functionMeta), returnIds(returnObjs), @@ -39,6 +42,7 @@ InvokeSpec::InvokeSpec(const std::string &jobId, const FunctionMeta &functionMet opts(std::move(opts)), requestInvoke(std::make_shared()) { + schedulerInstanceIdMtx_ = std::make_shared(); } void InvokeSpec::ConsumeRetryTime() @@ -48,6 +52,18 @@ void InvokeSpec::ConsumeRetryTime() } } +bool InvokeSpec::ExceedMaxRetryTime() +{ + if (opts.maxRetryTime == -1) { + return false; + } + if (opts.maxRetryTime == 0) { + return true; + } + --opts.maxRetryTime; + return false; +} + void InvokeSpec::IncrementSeq() { seq++; @@ -60,11 +76,11 @@ std::string InvokeSpec::ConstructRequestID() std::string InvokeSpec::GetNamedInstanceId() { - if (functionMeta.name && !functionMeta.name.value().empty()) { - if (functionMeta.ns && !functionMeta.ns.value().empty()) { - return functionMeta.ns.value() + "-" + functionMeta.name.value(); + if (!functionMeta.name.empty()) { + if (!functionMeta.ns.empty()) { + return functionMeta.ns + "-" + functionMeta.name; } else { - return DEFAULT_YR_NAMESPACE + "-" + functionMeta.name.value(); + return DEFAULT_YR_NAMESPACE + "-" + functionMeta.name; } } return ""; @@ -86,16 +102,41 @@ std::string InvokeSpec::GetInstanceId(std::shared_ptr config) void InvokeSpec::InitDesignatedInstanceId(const LibruntimeConfig &config) { - if (functionMeta.name && !functionMeta.name.value().empty()) { - auto ns = (!functionMeta.ns.has_value() || functionMeta.ns->empty()) ? config.ns : functionMeta.ns.value(); + if (!functionMeta.name.empty()) { + auto ns = functionMeta.ns.empty() ? config.ns : functionMeta.ns; if (ns.empty()) { - designatedInstanceID = *functionMeta.name; + designatedInstanceID = functionMeta.name; } else { - designatedInstanceID = ns + "-" + *functionMeta.name; + designatedInstanceID = ns + "-" + functionMeta.name; } } } +std::vector InvokeSpec::GetSchedulerInstanceIds() +{ + absl::ReaderMutexLock lock(schedulerInstanceIdMtx_.get()); + return this->opts.schedulerInstanceIds; +} + +std::string InvokeSpec::GetSchedulerInstanceId() +{ + absl::ReaderMutexLock lock(schedulerInstanceIdMtx_.get()); + if (this->opts.schedulerInstanceIds.empty()) { + return ""; + } + return this->opts.schedulerInstanceIds[0]; +} + +void InvokeSpec::SetSchedulerInstanceId(const std::string &schedulerId) +{ + absl::WriterMutexLock lock(schedulerInstanceIdMtx_.get()); + if (this->opts.schedulerInstanceIds.empty()) { + this->opts.schedulerInstanceIds.push_back(schedulerId); + } else { + this->opts.schedulerInstanceIds[0] = schedulerId; + } +} + bool InvokeSpec::IsStaleDuplicateNotify(uint8_t sequence) { if (sequence != this->seq) { // stale duplicate notify @@ -173,6 +214,7 @@ void InvokeSpec::BuildRequestPbCreateOptions(InvokeOptions &opts, const Librunti CreateRequest &request) { auto *createOptions = request.mutable_createoptions(); + createOptions->insert({"DATA_AFFINITY_ENABLED", opts.isDataAffinity ? "true" : "false"}); for (auto &opt : opts.createOptions) { createOptions->insert({opt.first, opt.second}); } @@ -234,6 +276,11 @@ void InvokeSpec::BuildRequestPbCreateOptions(InvokeOptions &opts, const Librunti : HIGH_RELIABILITY_TYPE; createOptions->insert({RELIABILITY_TYPE, reliability}); } + if (opts.debug.enable) { + nlohmann::json debugJson; + debugJson[std::string(ENABLE_DEBUG_KEY)] = std::string(ENABLE_DEBUG); + createOptions->insert({std::string(DEBUG_CONFIG_KEY), debugJson.dump()}); + } } void InvokeSpec::BuildInstanceCreateRequest(const LibruntimeConfig &config) @@ -306,10 +353,12 @@ std::string InvokeSpec::BuildCreateMetaData(const LibruntimeConfig &config) funcMeta->set_language(this->functionMeta.languageType); funcMeta->set_modulename(this->functionMeta.moduleName); funcMeta->set_signature(this->functionMeta.signature); - funcMeta->set_name(this->functionMeta.name.value_or("")); - funcMeta->set_ns((!this->functionMeta.ns.has_value() || this->functionMeta.ns->empty()) - ? config.ns - : this->functionMeta.ns.value_or("")); + if (!this->functionMeta.name.empty()) { + funcMeta->set_name(this->functionMeta.name); + } + if (!this->functionMeta.ns.empty()) { + funcMeta->set_ns(this->functionMeta.ns); + } auto metaConfig = meta.mutable_config(); config.BuildMetaConfig(*metaConfig); if (!this->opts.codePaths.empty()) { @@ -339,6 +388,7 @@ std::string InvokeSpec::BuildCreateMetaData(const LibruntimeConfig &config) metaConfig->add_schedulerinstanceids(id); } } + auto invocationMeta = meta.mutable_invocationmeta(); if (config.runtimeId == "driver") { invocationMeta->set_invokerruntimeid(config.runtimeId + "_" + config.jobId); @@ -347,7 +397,7 @@ std::string InvokeSpec::BuildCreateMetaData(const LibruntimeConfig &config) } invocationMeta->set_invocationsequenceno(this->invokeSeqNo); invocationMeta->set_minunfinishedsequenceno(this->invokeUnfinishedSeqNo); - YRLOG_DEBUG("create meta data is {}", meta.DebugString()); + YRLOG_DEBUG("create meta data: >{}<", meta.DebugString()); return meta.SerializeAsString(); } @@ -403,11 +453,16 @@ bool RequestResource::operator==(const RequestResource &r) const std::size_t h2 = (*it)->GetAffinityHash(); affinityHash = affinityHash ^ h2; } - if (opts.instanceSession != r.opts.instanceSession) { - return false; - } - if (opts.instanceSession != nullptr && opts.instanceSession->sessionID != r.opts.instanceSession->sessionID) { - return false; + if (opts.instanceSession || r.opts.instanceSession) { + if (!opts.instanceSession || !r.opts.instanceSession) { + return false; + } + if (opts.instanceSession->sessionID != r.opts.instanceSession->sessionID) { + return false; + } + if (opts.instanceSession->sessionTTL != r.opts.instanceSession->sessionTTL) { + return false; + } } if (r.opts.invokeLabels.size() != opts.invokeLabels.size()) { return false; @@ -439,9 +494,22 @@ bool RequestResource::operator==(const RequestResource &r) const opts.resourceGroupOpts.resourceGroupName == r.opts.resourceGroupOpts.resourceGroupName) && affinityHash == 0; } + +std::size_t FaasInfoForBatchRenewFn::operator()(const FaasInfoForBatchRenew &i) const +{ + return std::hash{}(i.schedulerFunctionID) ^ std::hash{}(i.schedulerInstanceID) ^ + std::hash{}(i.functionId) ^ std::hash{}(i.batchIndex); +} + +bool FaasInfoForBatchRenew::operator==(const FaasInfoForBatchRenew &i) const +{ + return schedulerInstanceID == i.schedulerInstanceID && schedulerFunctionID == i.schedulerFunctionID && + functionId == i.functionId && batchIndex == i.batchIndex; +} + void RequestResource::Print(void) const { - YRLOG_DEBUG("function meta: {} {}", functionMeta.languageType, functionMeta.functionId); + YRLOG_DEBUG("function meta: {} {}", fmt::underlying(functionMeta.languageType), functionMeta.functionId); } std::size_t HashFn::operator()(const RequestResource &r) const @@ -469,13 +537,15 @@ std::size_t HashFn::operator()(const RequestResource &r) const } if (r.opts.instanceSession) { std::size_t h10 = std::hash()(r.opts.instanceSession->sessionID); - result = result ^ h10; + std::size_t h11 = std::hash()(r.opts.instanceSession->sessionTTL); + result = result ^ h10 ^ h11; } for (const auto &envPair: r.opts.envVars) { - std::size_t h11 = std::hash()(envPair.first); - std::size_t h12 = std::hash()(envPair.second); - result = result ^ h11 ^ h12; + std::size_t h12 = std::hash()(envPair.first); + std::size_t h13 = std::hash()(envPair.second); + result = result ^ h12 ^ h13; } + result = result ^ std::hash()(r.opts.debug); return result; } } // namespace Libruntime diff --git a/src/libruntime/invoke_spec.h b/src/libruntime/invoke_spec.h index 4ff9929..ffe2c8b 100644 --- a/src/libruntime/invoke_spec.h +++ b/src/libruntime/invoke_spec.h @@ -20,6 +20,7 @@ #include #include +#include "absl/synchronization/mutex.h" #include "src/dto/acquire_options.h" #include "src/dto/config.h" @@ -29,6 +30,7 @@ #include "src/libruntime/fsclient/fs_intf.h" #include "src/libruntime/fsclient/protobuf/common.grpc.pb.h" #include "src/libruntime/invokeadaptor/report_record.h" +#include "src/libruntime/invokeadaptor/scheduler_instance_info.h" #include "src/libruntime/libruntime_config.h" #include "src/libruntime/utils/serializer.h" #include "src/libruntime/utils/utils.h" @@ -54,13 +56,21 @@ extern const char *CODE_PATH; extern const char *WORKING_DIR; extern const char *DELEGATE_DOWNLOAD; const std::string NEED_ORDER = "need_order"; +const std::string FAAS_INVOKE_TIMEOUT = "INVOKE_TIMEOUT"; const std::string RECOVER_RETRY_TIMES = "RecoverRetryTimes"; +extern const char *ENABLE_DEBUG_KEY; +extern const char *ENABLE_DEBUG; +extern const char *DEBUG_CONFIG_KEY; struct InvokeSpec { InvokeSpec(const std::string &jobId, const FunctionMeta &functionMeta, const std::vector &returnObjs, - std::vector invokeArgs, const libruntime::InvokeType invokeType, std::string traceId, - std::string requestId, const std::string &instanceId, const InvokeOptions &opts); - InvokeSpec() : requestInvoke(std::make_shared()) {} + std::vector invokeArgs, const libruntime::InvokeType invokeType, + std::string traceId, std::string requestId, const std::string &instanceId, + InvokeOptions opts); + InvokeSpec() : requestInvoke(std::make_shared()) + { + schedulerInstanceIdMtx_ = std::make_shared(); + } ~InvokeSpec() = default; std::string jobId; FunctionMeta functionMeta; @@ -82,9 +92,9 @@ struct InvokeSpec { std::shared_ptr requestInvoke; uint8_t seq = 0; std::string instanceRoute = ""; - + bool downgradeFlag_{false}; void ConsumeRetryTime(void); - + bool ExceedMaxRetryTime(); void IncrementSeq(); std::string ConstructRequestID(); @@ -114,6 +124,8 @@ struct InvokeSpec { std::string BuildInvokeMetaData(const LibruntimeConfig &config); + std::shared_ptr schedulerInfos = std::make_shared(); + template void BuildRequestPbArgs(const LibruntimeConfig &config, T &request, bool isCreate) { @@ -152,9 +164,15 @@ struct InvokeSpec { } } + std::vector GetSchedulerInstanceIds(); + std::string GetSchedulerInstanceId(); + void SetSchedulerInstanceId(const std::string &schedulerId); + private: void BuildRequestPbScheduleOptions(InvokeOptions &opts, const LibruntimeConfig &config, CreateRequest &request); void BuildRequestPbCreateOptions(InvokeOptions &opts, const LibruntimeConfig &config, CreateRequest &request); + + std::shared_ptr schedulerInstanceIdMtx_; }; struct FaasAllocationInfo { @@ -166,6 +184,25 @@ struct FaasAllocationInfo { // ReportRecord record; // metrics, faas scheduler弹性依据 }; +struct FaasInfoForBatchRenew { + std::string schedulerInstanceID; + std::string schedulerFunctionID; + std::string functionId; + int64_t batchIndex; + bool operator==(const FaasInfoForBatchRenew &r) const; + FaasInfoForBatchRenew(const FaasAllocationInfo &faasInfo, int64_t index) + : schedulerInstanceID(faasInfo.schedulerInstanceID), + schedulerFunctionID(faasInfo.schedulerFunctionID), + functionId(faasInfo.functionId), + batchIndex(index) + { + } +}; + +struct FaasInfoForBatchRenewFn { + std::size_t operator()(const FaasInfoForBatchRenew &i) const; +}; + struct InstanceInfo { std::string instanceId = "" ABSL_GUARDED_BY(mtx); std::string leaseId = "" ABSL_GUARDED_BY(mtx); @@ -178,6 +215,7 @@ struct InstanceInfo { std::string stateId = "" ABSL_GUARDED_BY(mtx); std::shared_ptr scaleDownTimer ABSL_GUARDED_BY(mtx); int64_t claimTime = 0 ABSL_GUARDED_BY(mtx); + bool needReacquire = false ABSL_GUARDED_BY(mtx); mutable absl::Mutex mtx; }; @@ -188,20 +226,9 @@ struct CreatingInsInfo { CreatingInsInfo(const std::string &id = "", int64_t time = 0) : instanceId(id), startTime(time) {} }; -struct RequestResourceInfo { - std::unordered_map> instanceInfos ABSL_GUARDED_BY(mtx); - std::unordered_map> avaliableInstanceInfos ABSL_GUARDED_BY(mtx); - std::vector> creatingIns ABSL_GUARDED_BY(mtx); - int createFailInstanceNum ABSL_GUARDED_BY(mtx); - // The time to create an instance. when cancle a creating instance - // the waiting time should not be less than the createTime. - int createTime ABSL_GUARDED_BY(mtx); - mutable absl::Mutex mtx; -}; - struct RequestResource { FunctionMeta functionMeta; - size_t concurrency; + size_t concurrency = 1; InvokeOptions opts; bool operator==(const RequestResource &r) const; void Print(void) const; @@ -211,6 +238,18 @@ struct HashFn { std::size_t operator()(const RequestResource &r) const; }; +struct RequestResourceInfo { + std::unordered_map> instanceInfos ABSL_GUARDED_BY(mtx); + std::unordered_map> avaliableInstanceInfos ABSL_GUARDED_BY(mtx); + std::vector> creatingIns ABSL_GUARDED_BY(mtx); + int createFailInstanceNum ABSL_GUARDED_BY(mtx); + // The time to create an instance. when cancle a creating instance + // the waiting time should not be less than the createTime. + int createTime ABSL_GUARDED_BY(mtx); + int tLeaseInterval ABSL_GUARDED_BY(mtx); + mutable absl::Mutex mtx; +}; + struct ConcurrencyGroup { std::string name; uint32_t maxConcurrency; diff --git a/src/libruntime/invokeadaptor/alias_element.cpp b/src/libruntime/invokeadaptor/alias_element.cpp new file mode 100644 index 0000000..97048e1 --- /dev/null +++ b/src/libruntime/invokeadaptor/alias_element.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "alias_element.h" + +namespace YR { +namespace Libruntime { +using nlohmann::json; +using namespace YR::utility; +void from_json(const json &j, std::vector &schedulerInstanceList) +{ + for (auto &ele : j) { + AliasElement aliasEle; + JsonGetTo(ele, "aliasUrn", aliasEle.aliasUrn); + JsonGetTo(ele, "functionUrn", aliasEle.functionUrn); + JsonGetTo(ele, "functionVersionUrn", aliasEle.functionVersionUrn); + JsonGetTo(ele, "name", aliasEle.name); + JsonGetTo(ele, "functionVersion", aliasEle.functionVersion); + JsonGetTo(ele, "revisionId", aliasEle.revisionId); + JsonGetTo(ele, "description", aliasEle.description); + JsonGetTo(ele, "routingType", aliasEle.routingType); + if (aliasEle.routingType == "rule") { + JsonGetTo(ele, "routingRules", aliasEle.routingRules); + } else { + JsonGetTo(ele, "routingconfig", aliasEle.routingConfig); + } + schedulerInstanceList.push_back(aliasEle); + } +} + +void from_json(const json &j, RoutingRules &routingRules) +{ + JsonGetTo(j, "ruleLogic", routingRules.ruleLogic); + JsonGetTo(j, "rules", routingRules.rules); + JsonGetTo(j, "grayVersion", routingRules.grayVersion); +} + +void from_json(const json &j, std::vector &routingConfig) +{ + for (auto &ele : j) { + RoutingConfig conf; + JsonGetTo(ele, "functionVersionUrn", conf.functionVersionUrn); + JsonGetTo(ele, "weight", conf.weight); + routingConfig.push_back(conf); + } +} + +void to_json(nlohmann::json &j, const RoutingRules &routingRules) +{ + j = json{{"ruleLogic", routingRules.ruleLogic}, + {"rules", routingRules.rules}, + {"grayVersion", routingRules.grayVersion}}; +} + +void to_json(nlohmann::json &j, const RoutingConfig &routingConfig) +{ + j = json{{"functionVersionUrn", routingConfig.functionVersionUrn}, {"weight", routingConfig.weight}}; +} + +void to_json(nlohmann::json &j, const AliasElement &aliasElement) +{ + j = json{{"aliasUrn", aliasElement.aliasUrn}, + {"functionUrn", aliasElement.functionUrn}, + {"functionVersionUrn", aliasElement.functionVersionUrn}, + {"name", aliasElement.name}, + {"functionVersion", aliasElement.functionVersion}, + {"revisionId", aliasElement.revisionId}, + {"description", aliasElement.description}, + {"routingType", aliasElement.routingType}, + {"routingRules", aliasElement.routingRules}, + {"routingConfig", aliasElement.routingConfig}}; +} + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/alias_element.h b/src/libruntime/invokeadaptor/alias_element.h new file mode 100644 index 0000000..001fda3 --- /dev/null +++ b/src/libruntime/invokeadaptor/alias_element.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include "src/utility/json_utility.h" + +namespace YR { +namespace Libruntime { +struct RoutingRules { + std::string ruleLogic; + std::vector rules; + std::string grayVersion; +}; + +struct RoutingConfig { + std::string functionVersionUrn; + double weight; +}; + +struct AliasElement { + std::string aliasUrn; + std::string functionUrn; + std::string functionVersionUrn; + std::string name; + std::string functionVersion; + std::string revisionId; + std::string description; + std::string routingType; + RoutingRules routingRules; + std::vector routingConfig; +}; + +void from_json(const nlohmann::json &j, std::vector &schedulerInstanceList); + +void from_json(const nlohmann::json &j, RoutingRules &routingRules); + +void from_json(const nlohmann::json &j, std::vector &schedulerInstanceList); + +void to_json(nlohmann::json &j, const RoutingRules &routingRules); + +void to_json(nlohmann::json &j, const RoutingConfig &routingConfig); + +void to_json(nlohmann::json &j, const AliasElement &aliasElement); + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/alias_routing.cpp b/src/libruntime/invokeadaptor/alias_routing.cpp new file mode 100644 index 0000000..1b9e28a --- /dev/null +++ b/src/libruntime/invokeadaptor/alias_routing.cpp @@ -0,0 +1,290 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "alias_routing.h" +#include "re2/re2.h" +#include "load_balancer.h" +#include "src/utility/logger/logger.h" +#include "src/utility/string_utility.h" + +namespace YR { +namespace Libruntime { +using namespace YR::utility; + +const std::string RoutingTypeRule = "rule"; +const size_t SplitedRuleNum = 3; +const int WEIGHT_MULTIPLIER = 100; + +const std::string ALIAS_PATTERN_STRING = "^[a-zA-Z]([a-zA-Z0-9_-]*[a-zA-Z0-9])?$"; +const re2::RE2 ALIAS_PATTERN(ALIAS_PATTERN_STRING); +const int ALIAS_LENGTH_LIMIT = 32; + +void AliasRouting::UpdateAliasInfo(const std::vector &aliasInfo) +{ + YRLOG_INFO("recv aliases info"); + std::unordered_map> newAliasInfo; + for (auto &alias : aliasInfo) { + if (newAliasInfo.find(alias.aliasUrn) == newAliasInfo.end()) { + std::unique_ptr lb(LoadBalancer::Factory(LoadBalancerType::WeightedRoundRobin)); + auto entry = std::make_shared(alias, std::move(lb)); + newAliasInfo.emplace(alias.aliasUrn, entry); + } else { + auto entry = newAliasInfo[alias.aliasUrn]; + entry->Update(alias); + } + } + + absl::WriterMutexLock lock(&mu_); + aliasInfo_ = newAliasInfo; +} + +std::string AliasRouting::GetFuncVersionUrnWithParams(const std::string &aliasUrn, + const std::unordered_map ¶ms) +{ + std::shared_ptr aliasEntry; + { + absl::ReaderMutexLock lock(&mu_); + if (aliasInfo_.find(aliasUrn) == aliasInfo_.end()) { + return aliasUrn; + } + aliasEntry = aliasInfo_[aliasUrn]; + } + + if (aliasEntry == nullptr) { + YRLOG_ERROR("empty alias entry for alias URN: {}", aliasUrn); + return aliasUrn; + } + + if (aliasEntry->IsRoutingByRules()) { + return aliasEntry->GetFuncVersionUrnByRule(params); + } + + return aliasEntry->GetFuncVersionUrn(); +} + +static std::vector ParseRules(const AliasElement &alias) +{ + std::vector parsed(alias.routingRules.rules.size()); + for (size_t i = 0; i < alias.routingRules.rules.size(); ++i) { + auto &rule = alias.routingRules.rules[i]; + std::vector result; + Split(rule, result, ':'); + if (result.size() != SplitedRuleNum) { + YRLOG_ERROR("rule is not splited to size: {}, rule: {}", SplitedRuleNum, rule); + return std::vector(); + } + + parsed[i].left = result[0]; // left + parsed[i].op = result[1]; // operation + parsed[i].right = result[2]; // right + } + return parsed; +} + +std::tuple parseFunctionId(const std::string &functionId) +{ + std::vector result; + Split(functionId, result, '/'); + if (result.size() != 3) { // {tenantId}/0@default@{functionName}/{aliasOrVersion} + return {false, "", "", ""}; + } + + return {true, result.at(0), result.at(1), result.at(2)}; +}; + +bool AliasRouting::CheckAlias(const std::string &functionId) +{ + auto [parseOk, tenantId, functionName, aliasOrVersion] = parseFunctionId(functionId); + if (!parseOk) { + return false; + } + YRLOG_DEBUG("functionId {}, tenantId {}, functionName {}, aliasOrVersion {}", + functionId, tenantId, functionName, aliasOrVersion); + if (aliasOrVersion == "latest" || aliasOrVersion == "$latest") { + return false; + } + return RE2::FullMatch(aliasOrVersion, ALIAS_PATTERN); +} + +std::tuple functionIdToAliasUrn(const std::string &functionId) +{ + auto [parseOk, tenantId, functionFullName, alias] = parseFunctionId(functionId); + if (!parseOk) { + YRLOG_ERROR("functionId format error {}", functionId); + return {false, ""}; + } + + std::ostringstream oss; + oss << "sn:cn:yrk:" << tenantId << ":function:" << functionFullName << ":" << alias; + return {true, oss.str()}; +} + +std::tuple functionVersionUrnToFunctionId(const std::string &functionVersionUrn) +{ + std::vector result; + + // example: sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest + Split(functionVersionUrn, result, ':'); + if (result.size() != 7) { // + return {false, ""}; + } + std::ostringstream oss; + oss << result.at(3) << "/" << result.at(5) << "/" << result.at(6); // {tenantId}/{functionName}/{version} + return {true, oss.str()}; +} + +std::string AliasRouting::ParseAlias(const std::string &functionId, + const std::unordered_map ¶ms) +{ + auto [ok, aliasUrn] = functionIdToAliasUrn(functionId); + if (!ok) { + YRLOG_ERROR("empty alias entry for alias URN: {}", aliasUrn); + return functionId; + } + std::string functionVersionUrn = this->GetFuncVersionUrnWithParams(aliasUrn, params); + + auto [ok1, parsedFunctionId] = functionVersionUrnToFunctionId(functionVersionUrn); + if (!ok1) { + YRLOG_WARN("functionVersionUrnToFunctionId failed, urn: {}", functionVersionUrn); // just for clean check + } + + YRLOG_INFO("parse alias {} to {}", functionId, parsedFunctionId); + return parsedFunctionId; +} + +static bool MatchRuleString(const std::string &val, const std::string &inStrings) +{ + std::vector result; + Split(inStrings, result, ','); + for (auto &inStr : result) { + if (!inStr.empty() && TrimSpace(val) == TrimSpace(inStr)) { + return true; + } + } + return false; +} + +static bool MatchOneRule(const AliasRule &rule, const std::unordered_map ¶ms) +{ + if (params.find(rule.left) == params.end()) { + YRLOG_ERROR("cannot find rule left {} in params", rule.left); + return false; + } + + const auto &val = params.at(rule.left); + if (rule.op == "=") { + return TrimSpace(val) == TrimSpace(rule.right); + } else if (rule.op == "!=") { + return TrimSpace(val) != TrimSpace(rule.right); + } else if (rule.op == ">") { + auto [result, err] = CompareIntFromString(val, rule.right); + return err.empty() && result == 1; + } else if (rule.op == "<") { + auto [result, err] = CompareIntFromString(val, rule.right); + return err.empty() && result == -1; + } else if (rule.op == ">=") { + auto [result, err] = CompareIntFromString(val, rule.right); + return err.empty() && (result == 1 || result == 0); + } else if (rule.op == "<=") { + auto [result, err] = CompareIntFromString(val, rule.right); + return err.empty() && (result == -1 || result == 0); + } else if (rule.op == "in") { + return MatchRuleString(val, rule.right); + } else { + YRLOG_ERROR("unknown operator in alias rule, op: {}", rule.op); + return false; + } +} + +static bool MergeMatchResults(const std::list &results, const std::string &ruleLogic) +{ + if (results.empty()) { + return false; + } + + if (ruleLogic == "or") { + auto it = std::find(results.begin(), results.end(), true); + if (it != results.end()) { + return true; + } + return false; + } else if (ruleLogic == "and") { + auto it = std::find(results.begin(), results.end(), false); + if (it != results.end()) { + return false; + } + return true; + } + YRLOG_ERROR("unknown alias rule logic: {}", ruleLogic); + return false; +} + +static bool MatchRules(const std::unordered_map ¶ms, const std::vector &rules, + const std::string &ruleLogic) +{ + std::list results; + for (auto &rule : rules) { + results.push_back(MatchOneRule(rule, params)); + } + return MergeMatchResults(results, ruleLogic); +} + +AliasEntry::AliasEntry(const AliasElement &alias, std::unique_ptr &&lb) : lb_(std::move(lb)) +{ + Update(alias); +} + +void AliasEntry::Update(const AliasElement &alias) +{ + absl::WriterMutexLock lock(&mu_); + aliasElement_ = alias; + aliasParsedRules_ = ParseRules(alias); + lb_->RemoveAll(); + for (auto &conf : alias.routingConfig) { + // rate + lb_->Add(conf.functionVersionUrn, int(conf.weight * WEIGHT_MULTIPLIER)); + } +} + +bool AliasEntry::IsRoutingByRules(void) +{ + return aliasElement_.routingType == RoutingTypeRule; +} + +std::string AliasEntry::GetFuncVersionUrnByRule(const std::unordered_map ¶ms) +{ + absl::ReaderMutexLock lock(&mu_); + if (params.empty() || aliasElement_.routingRules.rules.empty() || aliasParsedRules_.empty()) { + YRLOG_ERROR("params or alias rules is empty, params size: {}, alias rules size: {}, parsed rule size: {}", + params.size(), aliasElement_.routingRules.rules.size(), aliasParsedRules_.size()); + return aliasElement_.functionVersionUrn; + } + + if (MatchRules(params, aliasParsedRules_, aliasElement_.routingRules.ruleLogic)) { + return aliasElement_.routingRules.grayVersion; + } + + return aliasElement_.functionVersionUrn; +} + +std::string AliasEntry::GetFuncVersionUrn() +{ + absl::ReaderMutexLock lock(&mu_); + return lb_->Next("", false); +} + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/alias_routing.h b/src/libruntime/invokeadaptor/alias_routing.h new file mode 100644 index 0000000..8289998 --- /dev/null +++ b/src/libruntime/invokeadaptor/alias_routing.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +#include "absl/synchronization/mutex.h" + +#include "alias_element.h" +#include "load_balancer.h" + +namespace YR { +namespace Libruntime { + +class AliasEntry; +class AliasRouting { +public: + AliasRouting() = default; + virtual ~AliasRouting() = default; + void UpdateAliasInfo(const std::vector &aliasInfo); + std::string GetFuncVersionUrnWithParams(const std::string &aliasUrn, + const std::unordered_map ¶ms); + std::string ParseAlias(const std::string &functionId, const std::unordered_map ¶ms); + bool CheckAlias(const std::string &functionId); + +private: + absl::Mutex mu_; + std::unordered_map> aliasInfo_ ABSL_GUARDED_BY(mu_); +}; + +struct AliasRule { + std::string left; + std::string op; + std::string right; +}; + +class AliasEntry { +public: + AliasEntry(const AliasElement &alias, std::unique_ptr &&lb); + virtual ~AliasEntry() = default; + + void Update(const AliasElement &alias); + bool IsRoutingByRules(void); + std::string GetFuncVersionUrnByRule(const std::unordered_map ¶ms); + std::string GetFuncVersionUrn(); + +private: + absl::Mutex mu_; + AliasElement aliasElement_; + std::vector aliasParsedRules_; + std::unique_ptr lb_; +}; +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/faas_instance_manager.cpp b/src/libruntime/invokeadaptor/faas_instance_manager.cpp new file mode 100644 index 0000000..0f7744c --- /dev/null +++ b/src/libruntime/invokeadaptor/faas_instance_manager.cpp @@ -0,0 +1,1042 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "faas_instance_manager.h" +#include +#include "src/dto/acquire_options.h" + +namespace YR { +namespace Libruntime { +const std::string INSTANCE_REQUIREMENT_RESOURKEY = "resourcesData"; +const std::string INSTANCE_REQUIREMENT_INSKEY = "designateInstanceID"; +const std::string INSTANCE_REQUIREMENT_POOLLABELKEY = "poolLabel"; +const std::string INSTANCE_REQUIREMENT_INVOKE_LABEL = "instanceInvokeLabel"; +const std::string META_PREFIX = "0000000000000000"; +const std::string INSTANCETRAFFICLIMITED = "instanceTrafficLimited"; +const std::string HOST_NAME_ENV = "HOSTNAME"; +const std::string POD_NAME_ENV = "POD_NAME"; +const std::string INSTANCE_CALLER_POD_NAME = "instanceCallerPodName"; +const std::string INSTANCE_SESSION_CONFIG = "instanceSessionConfig"; +const int64_t BEFOR_RETAIN_TIME = 30; // millisecond +const int64_t RETAIN_TIME_RATE = 2; +const int RELEASE_DELAYTIME = 100; // millisecond +const int FAAS_INS_REQ_SUCCESS_CODE = 6030; +const int FAAS_USER_CODE_THRESHOLD = 10000; +const int ERR_FUNC_META_NOT_FOUND = 150424; +const int ERR_TARGET_INSTANCE_NOT_FOUND = 150425; +const int ERR_RENEW_INSTANCE_LEASE_NOT_FOUND = 150463; +const int ERR_NO_INSTANCE_AVAILABLE = 150431; +const int INSTANCE_LABEL_NOT_FOUND = 150444; +const int ERR_USER_FUNC_ENTRY_NOT_FOUND = 4001; +const int WISECLOUD_NUWA_INVOKE_ERROR = 161915; +const int64_t BATCH_RENEW_LEASE_NUM = 1000; +const std::string SPEC_ERR_MSG = "invoke request timeout"; + +const static std::set userErrSet = + std::set{ERR_FUNC_META_NOT_FOUND, INSTANCE_LABEL_NOT_FOUND, ERR_USER_FUNC_ENTRY_NOT_FOUND, + WISECLOUD_NUWA_INVOKE_ERROR, ERR_TARGET_INSTANCE_NOT_FOUND}; + +void FaasInsManager::UpdateConfig(int recycleTimeMs) {} + +void FaasInsManager::RecordRequest(const RequestResource &resource, const std::shared_ptr spec, + bool isInstanceNormal) +{ + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return; + } + std::shared_ptr faasInfo; + { + absl::ReaderMutexLock lock(&info->mtx); + if (info->instanceInfos.find(spec->invokeLeaseId) == info->instanceInfos.end()) { + return; + } + faasInfo = info->instanceInfos[spec->invokeLeaseId]; + } + absl::WriterMutexLock insLk(&faasInfo->mtx); + // Except for user option errors, ins will only be immediately deleted in case of instance abnormal + if (!isInstanceNormal) { + faasInfo->reporter->RecordAbnormal(); + faasInfo->claimTime = 0; + return; + } + long long claimTime = 0; + if (faasInfo->claimTime != 0) { + claimTime = faasInfo->claimTime; + auto currentMillTime = std::chrono::time_point_cast(std::chrono::steady_clock::now()) + .time_since_epoch() + .count(); + faasInfo->claimTime = currentMillTime; + faasInfo->reporter->RecordRequest(currentMillTime - claimTime); + } +} + +void FaasInsManager::DelRelatedLease(const std::string &insId, const RequestResource &resource) +{ + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return; + } + absl::WriterMutexLock lock(&info->mtx); + auto &insInfos = info->instanceInfos; + for (auto it = insInfos.begin(); it != insInfos.end();) { + std::string leaseId; + std::string instanceId; + { + absl::ReaderMutexLock lock(&it->second->mtx); + leaseId = it->second->leaseId; + instanceId = it->second->instanceId; + } + YRLOG_DEBUG("lease id is {}, instance id is {}, input ins id is {}", leaseId, instanceId, insId); + if (instanceId != insId) { + ++it; + continue; + } + ReleaseInstanceAsync(it->second); + it = insInfos.erase(it); + info->avaliableInstanceInfos.erase(leaseId); + DecreaseCreatedInstanceNum(); + } +} + +void FaasInsManager::ScaleCancel(const RequestResource &resource, size_t reqNum, bool cleanAll) +{ + return; +} + +void FaasInsManager::ScaleDown(const std::shared_ptr spec, bool isInstanceNormal) +{ + auto resource = GetRequestResource(spec); + this->RecordRequest(resource, spec, isInstanceNormal); + auto id = spec->invokeLeaseId; + auto instanceId = spec->invokeInstanceId; + if (isInstanceNormal) { + return StartReleaseTimer(resource, id); + } + this->DelRelatedLease(instanceId, resource); + EraseResourceInfoMap(resource); +} + +bool FaasInsManager::ScaleUp(std::shared_ptr invokeSpec, size_t reqNum) +{ + AddRequestResourceInfo(invokeSpec); + return this->AcquireFaasIns(invokeSpec, reqNum); +} + +bool FaasInsManager::AcquireFaasIns(const std::shared_ptr invokeSpec, size_t reqNum) +{ + auto resource = GetRequestResource(invokeSpec); + auto needCreatePair = NeedCreateNewIns(resource, reqNum, false); + if (!needCreatePair.first) { + YRLOG_INFO("No need create new ins for reqid: {}, trace id: {}", invokeSpec->requestId, invokeSpec->traceId); + return false; + } + SendAcquireReq(invokeSpec, needCreatePair.second); + return true; +} + +void FaasInsManager::SendAcquireReq(const std::shared_ptr invokeSpec, size_t delayTime) +{ + auto resource = GetRequestResource(invokeSpec); + this->AddCreatingInsInfo(resource, std::make_shared("", 0)); + auto weak_this = weak_from_this(); + tw_->CreateTimer(delayTime * MILLISECOND_UNIT, 1, [weak_this, invokeSpec]() { + if (auto this_ptr = weak_this.lock(); this_ptr) { + this_ptr->AcquireInstanceAsync(invokeSpec); + } + }); +} + +void FaasInsManager::ProcessInstanceInfo(std::shared_ptr spec, const InstanceAllocation &inst) +{ + auto resource = GetRequestResource(spec); + auto faasInsInfo = std::make_shared(); + faasInsInfo->instanceId = inst.instanceId; + faasInsInfo->leaseId = inst.leaseId; + faasInsInfo->idleTime = 0; + faasInsInfo->unfinishReqNum = 0; + faasInsInfo->available = true; + faasInsInfo->traceId = spec->traceId; + faasInsInfo->faasInfo = FaasAllocationInfo{inst.functionId, inst.funcSig, inst.tLeaseInterval, + spec->GetSchedulerInstanceId(), spec->opts.schedulerFunctionId}; + faasInsInfo->reporter = std::make_shared(); + this->HandleFaasInsInfo(faasInsInfo, resource); +} + +std::pair GetFaasInstanceRsp(const NotifyRequest ¬ifyReq) +{ + InstanceResponse instanceResp; + if (notifyReq.smallobjects_size() == 0) { + return std::make_pair(instanceResp, + ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, + "invalid acquire notify result, req id is " + notifyReq.requestid())); + } + auto &smallObj = notifyReq.smallobjects(0); + if (smallObj.value().size() <= MetaDataLen) { + return std::make_pair(instanceResp, + ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, + "invalid acquire notify result, req id is " + notifyReq.requestid())); + } + const std::string &bufStr = smallObj.value().substr(MetaDataLen); + YRLOG_DEBUG("get acquire notify result, req id is {}, value is {}", notifyReq.requestid(), bufStr); + auto err = ConvertStringToInsResp(instanceResp, bufStr); + if (!err.OK()) { + YRLOG_ERROR("failed to convert acquire notify req to instance info, req id is {}, err msg is {}", + notifyReq.requestid(), err.Msg()); + } + return std::make_pair(instanceResp, err); +} + +std::pair GetFaasBatchInstanceRsp(const NotifyRequest ¬ifyReq) +{ + BatchInstanceResponse instanceResp; + if (notifyReq.smallobjects_size() == 0) { + return std::make_pair(instanceResp, + ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, + "invalid acquire notify result, req id is " + notifyReq.requestid())); + } + auto &smallObj = notifyReq.smallobjects(0); + if (smallObj.value().size() <= MetaDataLen) { + return std::make_pair(instanceResp, + ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, + "invalid acquire notify result, req id is " + notifyReq.requestid())); + } + const std::string &bufStr = smallObj.value().substr(MetaDataLen); + YRLOG_DEBUG("get retain notify result, req id is {}, value is {}", notifyReq.requestid(), bufStr); + ConvertStringToBatchInsResp(instanceResp, bufStr); + return std::make_pair(instanceResp, ErrorInfo()); +} + +void FaasInsManager::ProcessAsynAcquireResult(const NotifyRequest ¬ifyReq, std::shared_ptr acquireSpec, + const ErrorInfo &errInput, std::shared_ptr invokeSpec) +{ + auto err = ErrorInfo(ErrorCode(notifyReq.code()), notifyReq.message(), errInput.IsAckTimeout()); + if (errInput.IsTimeout()) { + err.SetErrorCode(ErrorCode::ERR_ACQUIRE_TIMEOUT); + err.SetErrorMsg("acquire instance timeout"); + } + if (!err.OK()) { + AcquireCallback(acquireSpec, err, InstanceResponse{}, invokeSpec); + return; + } + auto [instanceResp, errInfo] = GetFaasInstanceRsp(notifyReq); + AcquireCallback(acquireSpec, errInfo, instanceResp, invokeSpec); +} + +void FaasInsManager::ProcecssAcquireResult( + const NotifyRequest ¬ifyReq, std::shared_ptr spec, + std::shared_ptr>> acquirePromise) +{ + auto err = ErrorInfo(ErrorCode(notifyReq.code()), notifyReq.message()); + if (notifyReq.code() != common::ERR_NONE) { + acquirePromise->set_value(std::make_pair(InstanceAllocation{}, err)); + return; + } + + auto [instanceResp, errInfo] = GetFaasInstanceRsp(notifyReq); + if (!errInfo.OK()) { + acquirePromise->set_value(std::make_pair(InstanceAllocation{}, errInfo)); + return; + } + + ProcessInstanceInfo(spec, instanceResp.info); + acquirePromise->set_value(std::make_pair(instanceResp.info, err)); +} + +std::pair FaasInsManager::AcquireInstance(const std::string &stateId, + std::shared_ptr spec) +{ + YRLOG_DEBUG("start acquire instance, functon id is {}, request id is {}, state id is {}", + spec->functionMeta.functionId, spec->requestId, stateId); + auto acquireSpec = BuildAcquireRequest(spec, stateId); + auto acquirePromise = std::make_shared>>(); + auto acquireFuture = acquirePromise->get_future().share(); + auto weak_this = weak_from_this(); + this->fsClient->InvokeAsync( + acquireSpec->requestInvoke, + [weak_this, spec, acquirePromise](const NotifyRequest &req, const ErrorInfo &err) -> void { + if (auto this_ptr = weak_this.lock(); this_ptr) { + this_ptr->ProcecssAcquireResult(req, spec, acquirePromise); + } + }); + auto status = acquireFuture.wait_for(std::chrono::seconds( + spec->opts.acquireTimeout == 0 ? FAAS_DEFALUT_ACQUIRE_TIMEOUT : spec->opts.acquireTimeout)); + if (status != std::future_status::ready) { + YRLOG_ERROR("acquire instance timeout, state id is {}, req id is {}", stateId, spec->requestId); + return std::make_pair(InstanceAllocation{}, ErrorInfo(ErrorCode::ERR_ACQUIRE_TIMEOUT, ModuleCode::RUNTIME, + "acquire instance timeout, req id is " + spec->requestId + + " , state id is " + stateId)); + } + return acquireFuture.get(); +} + +void FaasInsManager::AcquireInstanceAsync(std::shared_ptr invokeSpec) +{ + if (invokeSpec->opts.acquireTimeout == 0) { + invokeSpec->opts.acquireTimeout = FAAS_DEFALUT_ACQUIRE_TIMEOUT; + } + YRLOG_DEBUG("start acquire instance async, function: {}, timeout: {}, request: {}, trace: {}", + invokeSpec->functionMeta.functionId, invokeSpec->opts.acquireTimeout, invokeSpec->requestId, + invokeSpec->traceId); + auto acquireSpec = BuildAcquireRequest(invokeSpec); + YRLOG_INFO("acquire instance to {} for {}, trace: {}, acquire req id :{}, invoke req id : {}", + acquireSpec->invokeInstanceId, invokeSpec->functionMeta.functionId, invokeSpec->traceId, + acquireSpec->requestId, invokeSpec->requestId); + auto weak_this = weak_from_this(); + this->fsClient->InvokeAsync( + acquireSpec->requestInvoke, + [weak_this, acquireSpec, invokeSpec](const NotifyRequest ¬ifyReq, const ErrorInfo &err) -> void { + if (auto this_ptr = weak_this.lock(); this_ptr) { + this_ptr->ProcessAsynAcquireResult(notifyReq, acquireSpec, err, invokeSpec); + } + }, + acquireSpec->opts.timeout); +} + +/*********************************************************************************************************************** +The acquire request can fail for the following reasons: + 1. The function system reports an error with error codes 1003 or 1007. This is considered a FaaSScheduler failure, + and the scheduler ID needs to be the next from the Hash ring. + 2. The function system reports an error with an error code other than 1003 or 1007. No action is taken. + 3. The FaaS reports an error with an error code less than 10000. This is considered a user error. + Scheduling is interrupted, and the error is throw. + 4. The FaaS reports an error with an error code greater than 10000. No action is taken. +***********************************************************************************************************************/ +void FaasInsManager::AcquireCallback(const std::shared_ptr acquireSpec, const ErrorInfo &errInfo, + const InstanceResponse &resp, const std::shared_ptr invokeSpec) +{ + YRLOG_DEBUG("finished to acquire, acquire req id: {}, invoke req id: {}, invoke trace id: {}", + acquireSpec->requestId, invokeSpec->requestId, invokeSpec->traceId); + auto resource = GetRequestResource(invokeSpec); + auto returnErr = errInfo; + + if (!errInfo.OK() || resp.errorCode != FAAS_INS_REQ_SUCCESS_CODE) { + auto code = errInfo.OK() ? resp.errorCode : errInfo.Code(); + auto msg = errInfo.OK() ? resp.errorMessage : errInfo.Msg(); + YRLOG_WARN( + "failed to acquire, acquire req id: {}, invoke trace id: {}, invoke req id: {}, invoke scheduler id: {}, " + "code: {}, msg: {}", + acquireSpec->requestId, acquireSpec->traceId, invokeSpec->requestId, acquireSpec->invokeInstanceId, code, + msg); + this->EraseCreatingInsInfo(resource, ""); + this->ChangeCreateFailNum(resource, true); + } + + if (!errInfo.OK()) { + // When the caller specifies a scheduler ID, but the ID is invalid, obtain the new scheduler ID in the hash + // ring. + if (errInfo.Code() == YR::Libruntime::ErrorCode::ERR_INSTANCE_EXITED || + errInfo.Code() == YR::Libruntime::ErrorCode::ERR_INSTANCE_NOT_FOUND || + (errInfo.Code() == YR::Libruntime::ERR_REQUEST_BETWEEN_RUNTIME_BUS && errInfo.IsAckTimeout()) || + errInfo.Code() == YR::Libruntime::ERR_FINALIZED) { + UpdateSpecSchedulerIds(invokeSpec, acquireSpec->invokeInstanceId); + auto schedulerId = + csHash->NextRetry(invokeSpec->functionMeta.functionId, invokeSpec->schedulerInfos, true); + if (schedulerId == ALL_SCHEDULER_UNAVAILABLE || schedulerId.empty()) { + returnErr = + ErrorInfo(ErrorCode::ERR_ALL_SCHEDULER_UNAVALIABLE, "all scheduler instance is unavailable"); + } else { + invokeSpec->SetSchedulerInstanceId(schedulerId); + returnErr = ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, + "send acquire to faas scheduler failed, id: " + acquireSpec->invokeInstanceId); + } + } + return scheduleInsCb(resource, returnErr, IsRemainIns(resource)); + } + if (resp.errorCode != FAAS_INS_REQ_SUCCESS_CODE) { + nlohmann::json faasMessageJson; + faasMessageJson["code"] = resp.errorCode; + faasMessageJson["message"] = resp.errorMessage; + std::string faasMessage = faasMessageJson.dump(); + if (resp.errorCode < FAAS_USER_CODE_THRESHOLD || userErrSet.find(resp.errorCode) != userErrSet.end()) { + returnErr = ErrorInfo(ErrorCode::ERR_USER_FUNCTION_EXCEPTION, faasMessage); + } else if (resp.errorCode == ERR_NO_INSTANCE_AVAILABLE) { + UpdateSpecSchedulerIds(invokeSpec, acquireSpec->invokeInstanceId); + auto schedulerId = csHash->Next(invokeSpec->functionMeta.functionId, invokeSpec->schedulerInfos, true); + if (schedulerId == ALL_SCHEDULER_UNAVAILABLE || schedulerId.empty()) { + returnErr = + ErrorInfo(ErrorCode::ERR_ALL_SCHEDULER_UNAVALIABLE, "all scheduler instance is unavailable"); + } else { + invokeSpec->SetSchedulerInstanceId(schedulerId); + returnErr = ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, faasMessage); + } + } else { + returnErr = ErrorInfo(ErrorCode::ERR_INNER_COMMUNICATION, faasMessage); + } + return scheduleInsCb(resource, returnErr, IsRemainIns(resource)); + } + + auto faasInsInfo = std::make_shared(); + faasInsInfo->instanceId = resp.info.instanceId; + faasInsInfo->leaseId = resp.info.leaseId; + faasInsInfo->idleTime = 0; + faasInsInfo->unfinishReqNum = 0; + faasInsInfo->available = true; + faasInsInfo->traceId = invokeSpec->traceId; + faasInsInfo->faasInfo = FaasAllocationInfo{resp.info.functionId, resp.info.funcSig, resp.info.tLeaseInterval, + acquireSpec->invokeInstanceId, acquireSpec->functionMeta.functionId}; + faasInsInfo->reporter = std::make_shared(); + this->HandleFaasInsInfo(faasInsInfo, resource); + if (faasInsInfo->faasInfo.tLeaseInterval > 0) { + this->StartReleaseTimer(resource, faasInsInfo->leaseId); + } + this->ChangeCreateFailNum(resource, false); + YRLOG_INFO("succeed to acquire, lease: {} req: {} trace: {}", faasInsInfo->leaseId, acquireSpec->requestId, + acquireSpec->traceId); + return scheduleInsCb(resource, returnErr, IsRemainIns(resource)); +} + +void FaasInsManager::UpdateSpecSchedulerIds(std::shared_ptr spec, const std::string &schedulerId) +{ + absl::WriterMutexLock insLk(&spec->schedulerInfos->schedulerMtx); + if (schedulerId.empty()) { + YRLOG_WARN("scheduler id for req: {} is empty, no need update, trace id is {}", spec->requestId, spec->traceId); + return; + } + bool updated = false; + for (auto &scheduler : spec->schedulerInfos->schedulerInstanceList) { + if (scheduler->InstanceID == schedulerId) { + scheduler->isAvailable = false; + scheduler->updateTime = YR::GetCurrentTimestampNs(); + updated = true; + YRLOG_WARN("success update scheduler info, update status of scheduler:{} to unavailable, trace id is {}", + schedulerId, spec->traceId); + break; + } + } + if (!updated) { + spec->schedulerInfos->schedulerInstanceList.push_back(std::make_shared( + SchedulerInstance{"", schedulerId, YR::GetCurrentTimestampNs(), false})); + YRLOG_WARN("success update scheduler info, update status of scheduler:{} to unavailable, trace id is {}", + schedulerId, spec->traceId); + } +} + +void FaasInsManager::HandleFaasInsInfo(std::shared_ptr &faasInsInfo, const RequestResource &resource) +{ + YRLOG_DEBUG("start handler fass acquire info, instance id: {}, lease id: {}", faasInsInfo->instanceId, + faasInsInfo->leaseId); + auto info = GetOrAddRequestResourceInfo(resource); + // ensure atomicity of erase creating and add instances. + // avoid creating unnecessary instances when judge NeedCreateNewIns + { + absl::WriterMutexLock lock(&info->mtx); + this->EraseCreatingInsInfoBare(info, ""); + this->AddInsInfoBare(info, faasInsInfo); + } + { + if (faasInsInfo->faasInfo.tLeaseInterval <= 0) { + return; + } + absl::WriterMutexLock lock(&this->leaseMtx); + this->globalLeases[faasInsInfo->leaseId] = resource; + } +} + +void FaasInsManager::ReleaseHandler(const RequestResource &resource, const std::string &leaseId) +{ + std::shared_ptr faasInfoDel; + bool needDel = false; + + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return; + } + std::shared_ptr faasInfo; + { + absl::ReaderMutexLock lock(&info->mtx); + if (info->instanceInfos.find(leaseId) == info->instanceInfos.end()) { + YRLOG_DEBUG("lease has already release, id: {}", leaseId); + return; + } + faasInfo = info->instanceInfos[leaseId]; + } + + std::string functionId = ""; + { + absl::WriterMutexLock resInfoLock(&info->mtx); + absl::ReaderMutexLock faasInfoLock(&faasInfo->mtx); + if (faasInfo->unfinishReqNum >= 1) { + YRLOG_DEBUG("instance is not available, do not release, lease id{}", leaseId); + return; + } + info->avaliableInstanceInfos.erase(leaseId); + functionId = faasInfo->faasInfo.functionId; + } + faasInfoDel = faasInfo; + needDel = true; + { + absl::WriterMutexLock lock(&info->mtx); + info->instanceInfos.erase(leaseId); + } + DecreaseCreatedInstanceNum(); + EraseResourceInfoMap(resource, REQUEST_RESOURCE_USE_COUNT); + + if (needDel) { + YRLOG_DEBUG("start send release instance req, function id {}, lease id {}", functionId, leaseId); + this->ReleaseInstanceAsync(faasInfoDel); + } +} + +void FaasInsManager::StartReleaseTimer(const RequestResource &resource, const std::string &leaseId) +{ + YRLOG_DEBUG("start release timer, lease id: {}", leaseId); + auto weakThis = weak_from_this(); + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return; + } + int tLeaseInterval; + { + absl::ReaderMutexLock lock(&info->mtx); + tLeaseInterval = info->tLeaseInterval; + } + if (tLeaseInterval <= 0) { + return ReleaseHandler(resource, leaseId); + } + tw_->CreateTimer(RELEASE_DELAYTIME, 1, [weakThis, this, leaseId, resource]() { + auto thisPtr = weakThis.lock(); + if (!thisPtr) { + return; + } + ReleaseHandler(resource, leaseId); + }); +} + +void FaasInsManager::ReleaseInstanceAsync(const std::shared_ptr &ins) +{ + std::string leaseId; + { + absl::ReaderMutexLock lock(&ins->mtx); + if (ins->faasInfo.tLeaseInterval <= 0) { + return; + } + leaseId = ins->leaseId; + YRLOG_DEBUG("start aysnc release instance, leaseId id is {}", leaseId); + } + auto req = BuildReleaseReq(ins); + auto messageSpec = std::make_shared(std::move(req)); + this->fsClient->InvokeAsync(messageSpec, [leaseId](const NotifyRequest ¬ifyReq, const ErrorInfo &err) -> void { + if (notifyReq.code() != common::ERR_NONE) { + YRLOG_ERROR("release instance async failed, leaseId id is {}, req id is {}, msg is {}", leaseId, + notifyReq.requestid(), notifyReq.message()); + } else { + YRLOG_DEBUG("release instance async success, lease id is {}", leaseId); + } + }); +} + +ErrorInfo FaasInsManager::ReleaseInstance(const std::string &leaseId, const std::string &stateId, bool abnormal, + std::shared_ptr spec) +{ + ErrorInfo err; + if (!stateId.empty()) { + auto insInfo = std::make_shared(); + insInfo->faasInfo.schedulerInstanceID = spec->GetSchedulerInstanceId(); + insInfo->faasInfo.schedulerFunctionID = spec->opts.schedulerFunctionId; + insInfo->reporter = std::make_shared(); + insInfo->stateId = stateId; + + err = SendReleaseInstanceReq(insInfo, spec); + return err; + } + + auto resource = GetRequestResource(spec); + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return err; + } + absl::WriterMutexLock lock(&info->mtx); + auto &insInfos = info->instanceInfos; + for (auto it = insInfos.begin(); it != insInfos.end();) { + YRLOG_DEBUG("state id is {}, lease id is {}, instance id is {}", it->second->stateId, it->second->leaseId, + it->second->instanceId); + if (it->second->leaseId == leaseId) { + err = SendReleaseInstanceReq(it->second, spec); + it = insInfos.erase(it); + info->avaliableInstanceInfos.erase(leaseId); + DecreaseCreatedInstanceNum(); + break; + } else { + ++it; + } + } + return err; +} + +ErrorInfo FaasInsManager::SendReleaseInstanceReq(const std::shared_ptr &ins, + std::shared_ptr spec) +{ + std::string leaseId; + std::string stateId; + { + absl::ReaderMutexLock lock(&ins->mtx); + leaseId = ins->leaseId; + stateId = ins->stateId; + } + auto req = BuildReleaseReq(ins); + auto messageSpec = std::make_shared(std::move(req)); + auto releasePromise = std::make_shared>(); + auto releaseFuture = releasePromise->get_future().share(); + YRLOG_DEBUG("start release instance, leaseId id is {}, state id is {}, req id is {}", leaseId, stateId, + req.requestid()); + this->fsClient->InvokeAsync( + messageSpec, [releasePromise, stateId](const NotifyRequest ¬ifyReq, const ErrorInfo &err) -> void { + if (notifyReq.code() != common::ERR_NONE) { + YRLOG_ERROR("receive release notify, req id is {}, state id is {} code is {}, msg is {}", + notifyReq.requestid(), stateId, fmt::underlying(notifyReq.code()), notifyReq.message()); + } + releasePromise->set_value(ErrorInfo(ErrorCode(notifyReq.code()), notifyReq.message())); + }); + auto status = releaseFuture.wait_for( + std::chrono::milliseconds(spec->opts.timeout == 0 ? FAAS_DEFALUT_INVOKE_TIMEOUT : spec->opts.timeout)); + if (status != std::future_status::ready) { + YRLOG_ERROR("release instance timeout, state id is {}, lease id is {}, req id is {}", stateId, leaseId, + spec->requestId); + return ErrorInfo(ErrorCode::ERR_INIT_CONNECTION_FAILED, + "release instance timeout, req id is " + spec->requestId + " , state id is " + stateId); + } + return releaseFuture.get(); +} + +InvokeRequest FaasInsManager::BuildReleaseReq(const std::shared_ptr &ins) +{ + absl::ReaderMutexLock lock(&ins->mtx); + YRLOG_DEBUG("start build release instance req, instance id {}, lease id {}", ins->instanceId, ins->leaseId); + InvokeRequest req; + auto requestId = YR::utility::IDGenerator::GenRequestId(); + req.set_requestid(requestId); + req.set_traceid(ins->traceId); + req.set_instanceid(ins->faasInfo.schedulerInstanceID); + req.set_function(ins->faasInfo.schedulerFunctionID); + req.add_returnobjectids(YR::utility::IDGenerator::GenObjectId(requestId, 0)); + std::string acquireOps; + if (ins->stateId.empty()) { + acquireOps = "release#" + ins->leaseId; + } else { + acquireOps = "release#" + ins->stateId; + } + YRLOG_DEBUG("requst id is {}, instance id is {}, function is {}, acquire ops is {}", req.requestid(), + req.instanceid(), req.function(), acquireOps); + Arg *pbArg = req.add_args(); + pbArg->set_type(Arg_ArgType::Arg_ArgType_VALUE); + pbArg->set_value(META_PREFIX + acquireOps); + auto report = ins->reporter->Report(true); + nlohmann::json instanceReport; + instanceReport["ProcReqNum"] = report.procReqNum; + instanceReport["AvgProcTime"] = report.avgProcTime; + instanceReport["MaxProcTime"] = report.maxProcTime; + instanceReport["IsAbnormal"] = report.isAbnormal; + YRLOG_DEBUG( + "lease id is {}, instance id is {}, procReqNum is {}, avgProcTime is {}, maxProcTime is {}, isAbnormal: {}", + ins->leaseId, ins->instanceId, report.procReqNum, report.avgProcTime, report.maxProcTime, report.isAbnormal); + pbArg = req.add_args(); + pbArg->set_type(Arg_ArgType::Arg_ArgType_VALUE); + pbArg->set_value(META_PREFIX + instanceReport.dump()); + pbArg = req.add_args(); + pbArg->set_type(Arg_ArgType::Arg_ArgType_VALUE); + pbArg->set_value(META_PREFIX + ins->traceId); + return req; +} + +std::shared_ptr FaasInsManager::BuildAcquireRequest(std::shared_ptr invokeSpec, + const std::string &stateId) +{ + auto acquireSpec = std::make_shared(); + acquireSpec->jobId = invokeSpec->jobId; + acquireSpec->requestId = YR::utility::IDGenerator::GenRequestId(); + auto dsObjectId = YR::utility::IDGenerator::GenObjectId(acquireSpec->requestId, 0); + acquireSpec->returnIds.push_back(DataObject(dsObjectId)); + acquireSpec->traceId = invokeSpec->traceId; + + acquireSpec->functionMeta.functionId = + invokeSpec->opts.schedulerFunctionId.empty() ? this->schedulerFuncKey : invokeSpec->opts.schedulerFunctionId; + acquireSpec->functionMeta.apiType = libruntime::ApiType::Posix; + std::string schedulerId = invokeSpec->GetSchedulerInstanceId(); + if (!schedulerId.empty()) { + acquireSpec->invokeInstanceId = schedulerId; + } else { + schedulerId = csHash->NextRetry(invokeSpec->functionMeta.functionId); + acquireSpec->invokeInstanceId = schedulerId; + } + + std::string acquireFunction; + if (stateId == "") { + acquireFunction = "acquire#" + invokeSpec->functionMeta.functionId; + } else { + acquireFunction = "acquire#" + invokeSpec->functionMeta.functionId + ";" + stateId; + } + InvokeArg acquireFunctionArg; + acquireFunctionArg.dataObj = std::make_shared(0, acquireFunction.size()); + acquireFunctionArg.dataObj->data->MemoryCopy(acquireFunction.data(), acquireFunction.size()); + acquireSpec->invokeArgs.push_back(acquireFunctionArg); + + std::unordered_map> instanceRequirement; + if (invokeSpec->designatedInstanceID.empty()) { + std::unordered_map resourceSpecs; + resourceSpecs.insert({"CPU", invokeSpec->opts.cpu}); + resourceSpecs.insert({"Memory", invokeSpec->opts.memory}); + for (auto &resource : invokeSpec->opts.customResources) { + resourceSpecs.insert({resource.first, static_cast(resource.second)}); + } + nlohmann::json resourceMapJson = resourceSpecs; + std::string resourceMapString = resourceMapJson.dump(); + YRLOG_DEBUG("instance resource : {}", resourceMapString); + std::vector vec(resourceMapString.begin(), resourceMapString.end()); + instanceRequirement[INSTANCE_REQUIREMENT_RESOURKEY] = vec; + } else { + std::vector vec(invokeSpec->designatedInstanceID.begin(), + invokeSpec->designatedInstanceID.end()); + instanceRequirement[INSTANCE_REQUIREMENT_INSKEY] = vec; + } + if (invokeSpec->opts.trafficLimited) { + std::string strTrue = "true"; + std::vector vec(strTrue.begin(), strTrue.end()); + instanceRequirement[INSTANCETRAFFICLIMITED] = vec; + } + std::string podName = + Config::Instance().POD_NAME().empty() ? Config::Instance().HOSTNAME() : Config::Instance().POD_NAME(); + if (!podName.empty()) { + std::vector vec(podName.begin(), podName.end()); + instanceRequirement[INSTANCE_CALLER_POD_NAME] = vec; + } + + if (!invokeSpec->functionMeta.poolLabel.empty()) { + std::vector vec(invokeSpec->functionMeta.poolLabel.begin(), + invokeSpec->functionMeta.poolLabel.end()); + instanceRequirement[INSTANCE_REQUIREMENT_POOLLABELKEY] = vec; + } + if (!invokeSpec->opts.invokeLabels.empty()) { + nlohmann::json invokeLabelsJson = invokeSpec->opts.invokeLabels; + std::string invokeLabelsString = invokeLabelsJson.dump(); + std::vector vec(invokeLabelsString.begin(), invokeLabelsString.end()); + instanceRequirement[INSTANCE_REQUIREMENT_INVOKE_LABEL] = vec; + } + if (invokeSpec->opts.instanceSession) { + nlohmann::json instanceSessionJson; + instanceSessionJson["sessionID"] = invokeSpec->opts.instanceSession->sessionID; + instanceSessionJson["sessionTTL"] = invokeSpec->opts.instanceSession->sessionTTL; + instanceSessionJson["concurrency"] = invokeSpec->opts.instanceSession->concurrency; + std::string instanceSessionJsonString = instanceSessionJson.dump(); + std::vector vec(instanceSessionJsonString.begin(), instanceSessionJsonString.end()); + instanceRequirement[INSTANCE_SESSION_CONFIG] = vec; + } + nlohmann::json instanceRequirementJson = instanceRequirement; + std::string instanceRequirementStr = instanceRequirementJson.dump(); + InvokeArg instanceRequirementArg; + instanceRequirementArg.dataObj = std::make_shared(0, instanceRequirementStr.size()); + instanceRequirementArg.dataObj->data->MemoryCopy(instanceRequirementStr.data(), instanceRequirementStr.size()); + acquireSpec->invokeArgs.push_back(instanceRequirementArg); + + InvokeArg traceIdArg; + traceIdArg.dataObj = std::make_shared(0, invokeSpec->traceId.size()); + traceIdArg.dataObj->data->MemoryCopy(invokeSpec->traceId.data(), invokeSpec->traceId.size()); + acquireSpec->invokeArgs.push_back(traceIdArg); + + acquireSpec->opts.timeout = invokeSpec->opts.acquireTimeout; + + acquireSpec->BuildInstanceInvokeRequest(*libRuntimeConfig); + + return acquireSpec; +} + +void FaasInsManager::AddInsInfoBare(std::shared_ptr info, + std::shared_ptr &faasInsInfo) +{ + bool needIncreaseInsNum = true; + if (auto it = info->instanceInfos.find(faasInsInfo->leaseId); it != info->instanceInfos.end()) { + YRLOG_DEBUG("already exist, instace id {}, current unfinish req num {}, lease id {}", faasInsInfo->instanceId, + faasInsInfo->unfinishReqNum, faasInsInfo->leaseId); + needIncreaseInsNum = false; + } + info->instanceInfos[faasInsInfo->leaseId] = faasInsInfo; + info->tLeaseInterval = faasInsInfo->faasInfo.tLeaseInterval; + info->avaliableInstanceInfos[faasInsInfo->leaseId] = faasInsInfo; + if (needIncreaseInsNum) { + IncreaseCreatedInstanceNum(); + } +} + +void FaasInsManager::StartBatchRenewTimer() +{ + absl::ReaderMutexLock lock(&this->leaseMtx); + if (!this->leaseTimer) { + YRLOG_DEBUG("start batch renew timer"); + auto timer = CreateBatchRenewTimer(); + this->leaseTimer = timer; + } +} + +void FaasInsManager::ProcessBatchRenewResult(const NotifyRequest ¬ifyReq, const std::string &functionId, + const ErrorInfo &err, std::vector leaseIds) +{ + bool failedFlag = false; + if (!err.OK()) { + YRLOG_WARN("failed to batch renew instance, req id {}, {}", notifyReq.requestid(), err.CodeAndMsg()); + failedFlag = true; + } + if (notifyReq.code() != common::ERR_NONE) { + YRLOG_WARN("failed to batch renew instance, req id {}, code {}, msg {}", notifyReq.requestid(), + fmt::underlying(notifyReq.code()), notifyReq.message()); + if (notifyReq.code() == common::ERR_INSTANCE_NOT_FOUND || notifyReq.code() == common::ERR_INSTANCE_EXITED) { + ChangeInstanceSchedulerId(functionId, leaseIds); + return; + } + failedFlag = true; + } + auto [instanceResp, errInfo] = GetFaasBatchInstanceRsp(notifyReq); + if (!errInfo.OK()) { + YRLOG_WARN("failed to batch renew instance, req id {}, {}", notifyReq.requestid(), errInfo.CodeAndMsg()); + failedFlag = true; + } + std::vector decreaseLeaseIds; + std::vector reacquireLeaseIds; + if (failedFlag) { + decreaseLeaseIds.resize(leaseIds.size()); + std::copy(leaseIds.begin(), leaseIds.end(), decreaseLeaseIds.begin()); + } else { + for (const auto &[leaseId, allocErrInfo] : instanceResp.instanceAllocFailed) { + YRLOG_WARN("failed to renew instance, lease id {}, req id {}, errCode: {}, errMsg: {}", leaseId, + notifyReq.requestid(), allocErrInfo.errorCode, allocErrInfo.errorMessage); + if (allocErrInfo.errorCode == ERR_RENEW_INSTANCE_LEASE_NOT_FOUND) { + reacquireLeaseIds.push_back(leaseId); + } else { + decreaseLeaseIds.push_back(leaseId); + } + } + } + absl::WriterMutexLock lock(&this->leaseMtx); + for (std::string &leaseId : reacquireLeaseIds) { + if (this->globalLeases.find(leaseId) != this->globalLeases.end()) { + auto info = GetRequestResourceInfo(this->globalLeases.find(leaseId)->second); + if (info == nullptr) { + continue; + } + absl::ReaderMutexLock infoLock(&info->mtx); + if (info->instanceInfos.find(leaseId) != info->instanceInfos.end()) { + auto insInfo = info->instanceInfos[leaseId]; + absl::WriterMutexLock instanceLock(&insInfo->mtx); + insInfo->needReacquire = true; + } + } + } + for (std::string &leaseId : decreaseLeaseIds) { + if (this->globalLeases.find(leaseId) != this->globalLeases.end()) { + DelInsInfo(leaseId, this->globalLeases.find(leaseId)->second); + this->globalLeases.erase(leaseId); + } + } + if (!failedFlag) { + this->tLeaseInterval = instanceResp.tLeaseInterval; + } +} + +void FaasInsManager::ChangeInstanceSchedulerId(const std::string &functionId, std::vector &leaseIds) +{ + auto otherSchedulerInstanceId = this->csHash->Next(functionId, true); + for (std::string &leaseId : leaseIds) { + auto it = this->globalLeases.find(leaseId); + if (it != this->globalLeases.end()) { + auto instanceInfo = GetInstanceInfo(it->second, it->first); + if (instanceInfo == nullptr) { + it = this->globalLeases.erase(it); + continue; + } + { + YRLOG_WARN("failed to renew instance {}, scheduler {} change to {}", leaseId, + instanceInfo->faasInfo.schedulerInstanceID, otherSchedulerInstanceId); + absl::WriterMutexLock insInfoLock(&instanceInfo->mtx); + instanceInfo->faasInfo.schedulerInstanceID = otherSchedulerInstanceId; + } + } + } +} + +std::vector BuildReacquireInstanceData(const RequestResource &resource) +{ + nlohmann::json reacquireData; + std::unordered_map resourceSpecs; + resourceSpecs.insert({"CPU", resource.opts.cpu}); + resourceSpecs.insert({"Memory", resource.opts.memory}); + for (auto &cr : resource.opts.customResources) { + resourceSpecs.insert({cr.first, static_cast(cr.second)}); + } + nlohmann::json resourceMapJson = resourceSpecs; + std::string resourceMapString = resourceMapJson.dump(); + std::vector resourceKeyVec(resourceMapString.begin(), resourceMapString.end()); + reacquireData[INSTANCE_REQUIREMENT_RESOURKEY] = resourceKeyVec; + if (resource.opts.instanceSession) { + nlohmann::json instanceSessionJson; + instanceSessionJson["sessionID"] = resource.opts.instanceSession->sessionID; + instanceSessionJson["sessionTTL"] = resource.opts.instanceSession->sessionTTL; + instanceSessionJson["concurrency"] = resource.opts.instanceSession->concurrency; + std::string instanceSessionJsonString = instanceSessionJson.dump(); + std::vector vec(instanceSessionJsonString.begin(), instanceSessionJsonString.end()); + reacquireData["instanceSessionConfig"] = vec; + } + if (!resource.opts.invokeLabels.empty()) { + nlohmann::json invokeLabelsJson = resource.opts.invokeLabels; + std::string invokeLabelsString = invokeLabelsJson.dump(); + std::vector vec(invokeLabelsString.begin(), invokeLabelsString.end()); + reacquireData[INSTANCE_REQUIREMENT_INVOKE_LABEL] = vec; + } + if (!resource.functionMeta.poolLabel.empty()) { + std::vector vec(resource.functionMeta.poolLabel.begin(), resource.functionMeta.poolLabel.end()); + reacquireData[INSTANCE_REQUIREMENT_POOLLABELKEY] = vec; + } + std::string reacquireDataStr = reacquireData.dump(); + std::vector reacquireDataVec(reacquireDataStr.begin(), reacquireDataStr.end()); + return reacquireDataVec; +} + +void FaasInsManager::BatchRenewHandler() +{ + std::unordered_map instanceReportMap; + std::unordered_map targetNamesMap; + std::unordered_map, FaasInfoForBatchRenewFn> leaseIdsMap; + { + absl::ReaderMutexLock lock(&this->leaseMtx); + if (this->leaseTimer) { + auto weakPtr = weak_from_this(); + tw_->ExecuteByTimer(this->leaseTimer, this->tLeaseInterval / RETAIN_TIME_RATE, [weakPtr]() { + if (auto thisPtr = weakPtr.lock(); thisPtr) { + thisPtr->BatchRenewHandler(); + } + }); + } else { + YRLOG_DEBUG("batch renew is cancelled"); + return; + } + auto it = this->globalLeases.begin(); + while (it != this->globalLeases.end()) { + auto instanceInfo = GetInstanceInfo(it->second, it->first); + if (instanceInfo == nullptr) { + it = this->globalLeases.erase(it); + continue; + } + int i = 0; + FaasInfoForBatchRenew faasInfo = FaasInfoForBatchRenew(instanceInfo->faasInfo, i); + while (leaseIdsMap[faasInfo].size() >= BATCH_RENEW_LEASE_NUM) { + i++; + faasInfo = FaasInfoForBatchRenew(instanceInfo->faasInfo, i); + } + auto report = instanceInfo->reporter->Report(false); + nlohmann::json instanceReport; + instanceReport["procReqNum"] = report.procReqNum; + instanceReport["avgProcTime"] = report.avgProcTime; + instanceReport["maxProcTime"] = report.maxProcTime; + instanceReport["isAbnormal"] = report.isAbnormal; + if (instanceInfo->needReacquire) { + instanceReport["reacquireData"] = BuildReacquireInstanceData(it->second); + instanceReport["functionKey"] = instanceInfo->faasInfo.functionId; + absl::WriterMutexLock instanceLock(&instanceInfo->mtx); + instanceInfo->needReacquire = false; + } + nlohmann::json &reportMap = instanceReportMap[faasInfo]; + reportMap[it->first] = instanceReport; + std::string &targetNames = targetNamesMap[faasInfo]; + if (targetNames.empty()) { + targetNames = it->first; + } else { + targetNames += "," + it->first; + } + std::vector &leaseIds = leaseIdsMap[faasInfo]; + leaseIds.push_back(it->first); + it++; + } + } + for (const auto &[faasInfo, targetNames] : targetNamesMap) { + InvokeRequest req; + req.set_requestid(YR::utility::IDGenerator::GenRequestId()); + std::string traceId = YR::utility::IDGenerator::GenTraceId(); + req.set_traceid(traceId); + req.set_instanceid(faasInfo.schedulerInstanceID); + req.set_function(faasInfo.schedulerFunctionID); + req.add_returnobjectids("obj-" + YR::utility::IDGenerator::GenObjectId()); + + auto acquireOps = "batchRetain#" + targetNames; + Arg *pbArg = req.add_args(); + pbArg->set_type(Arg_ArgType::Arg_ArgType_VALUE); + pbArg->set_value(META_PREFIX + acquireOps); + pbArg = req.add_args(); + pbArg->set_type(Arg_ArgType::Arg_ArgType_VALUE); + pbArg->set_value(META_PREFIX + instanceReportMap[faasInfo].dump()); + pbArg = req.add_args(); + pbArg->set_type(Arg_ArgType::Arg_ArgType_VALUE); + pbArg->set_value(META_PREFIX + traceId); + YRLOG_DEBUG("batch renew to {} {}, batchIndex: {}, req: {}, lease:{}", faasInfo.schedulerFunctionID, + faasInfo.schedulerInstanceID, faasInfo.batchIndex, req.requestid(), targetNames); + auto messageSpec = std::make_shared(std::move(req)); + auto weak_this = weak_from_this(); + std::vector &leaseIds = leaseIdsMap[faasInfo]; + auto functionId = faasInfo.functionId; + this->fsClient->InvokeAsync( + messageSpec, + [weak_this, leaseIds, functionId](const NotifyRequest ¬ifyReq, const ErrorInfo &err) -> void { + if (auto this_ptr = weak_this.lock(); this_ptr) { + this_ptr->ProcessBatchRenewResult(notifyReq, functionId, err, leaseIds); + } + }); + } +} + +std::shared_ptr FaasInsManager::CreateBatchRenewTimer() +{ + auto weakPtr = weak_from_this(); + return tw_->CreateTimer(this->tLeaseInterval / RETAIN_TIME_RATE, 1, [weakPtr]() { + if (auto thisPtr = weakPtr.lock(); thisPtr) { + thisPtr->BatchRenewHandler(); + } + }); +} + +void FaasInsManager::UpdateSchedulerInfo(const std::string &schedulerFuncKey, + const std::vector &schedulerInstanceList) +{ + YRLOG_INFO("recv update scheduler info"); + this->csHash->ResetAll(schedulerInstanceList); + absl::WriterMutexLock lock(&this->schedulerFuncKeyMtx); + if (this->schedulerFuncKey.empty()) { + this->schedulerFuncKey = schedulerFuncKey; + } +} + +std::string FaasInsManager::GetSchedulerKey() +{ + absl::ReaderMutexLock lock(&this->schedulerFuncKeyMtx); + return this->schedulerFuncKey; +} + +std::string FaasInsManager::GetFunctionIdWithLabel(const RequestResource &resource) +{ + std::ostringstream functionKey; + functionKey << resource.functionMeta.functionId << "-"; + functionKey << resource.opts.cpu << "-" << resource.opts.memory; + if (!resource.opts.invokeLabels.empty()) { + functionKey << "-{ "; + for (auto it = resource.opts.invokeLabels.begin(); it != resource.opts.invokeLabels.end();) { + functionKey << it->first << ":" << it->second; + if (++it != resource.opts.invokeLabels.end()) { + functionKey << ", "; + } + } + functionKey << " }"; + } + return functionKey.str(); +} + +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/invokeadaptor/faas_instance_manager.h b/src/libruntime/invokeadaptor/faas_instance_manager.h new file mode 100644 index 0000000..57f5ebd --- /dev/null +++ b/src/libruntime/invokeadaptor/faas_instance_manager.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "src/libruntime/invokeadaptor/instance_manager.h" + +namespace YR { +namespace Libruntime { + +std::pair GetFaasInstanceRsp(const NotifyRequest ¬ifyReq); + +class FaasInsManager : public InsManager, public std::enable_shared_from_this { +public: + FaasInsManager() = default; + FaasInsManager(ScheduleInsCallback cb, std::shared_ptr client, std::shared_ptr store, + std::shared_ptr reqMgr, std::shared_ptr config) + : InsManager(cb, client, store, reqMgr, config) + { + std::shared_ptr lb(LoadBalancer::Factory(LoadBalancerType::ConsistantRoundRobin)); + this->csHash = std::make_shared(lb); + } + bool ScaleUp(std::shared_ptr spec, size_t reqNum) override; + void ScaleDown(const std::shared_ptr spec, bool isInstanceNormal = false) override; + void ScaleCancel(const RequestResource &resource, size_t reqNum, bool cleanAll = false) override; + void StartBatchRenewTimer() override; + virtual void UpdateConfig(int recycleTimeMs) override; + void UpdateSchedulerInfo(const std::string &schedulerFuncKey, + const std::vector &schedulerInstanceList) override; + void RecordRequest(const RequestResource &resource, const std::shared_ptr spec, bool isInstanceNormal); + void DelRelatedLease(const std::string &insId, const RequestResource &resource); + std::pair AcquireInstance(const std::string &stateId, + std::shared_ptr spec); + void ProcessInstanceInfo(std::shared_ptr spec, const InstanceAllocation &inst); + ErrorInfo ReleaseInstance(const std::string &leaseId, const std::string &stateId, bool abnormal, + std::shared_ptr spec); + void ProcessBatchRenewResult(const NotifyRequest ¬ifyReq, const std::string &functionId, const ErrorInfo &err, + std::vector leaseIds); + // add instance info without locking; the caller must ensure locking. + void AddInsInfoBare(std::shared_ptr info, std::shared_ptr &faasInsInfo); + void UpdateSpecSchedulerIds(std::shared_ptr spec, const std::string &schedulerId); + void AcquireCallback(const std::shared_ptr acquireSpec, const ErrorInfo &errInfo, + const InstanceResponse &resp, const std::shared_ptr invokeSpec); + +private: + std::shared_ptr CreateBatchRenewTimer(); + void RenewHandler(std::shared_ptr insInfo); + void BatchRenewHandler(); + void SendAcquireReq(const std::shared_ptr spec, size_t delayTime); + std::shared_ptr BuildAcquireRequest(std::shared_ptr invokeSpec, + const std::string &stateId = ""); + InvokeRequest BuildReleaseReq(const std::shared_ptr &ins); + bool AcquireFaasIns(const std::shared_ptr spec, size_t reqNum); + void HandleFaasInsInfo(std::shared_ptr &faasInsInfo, const RequestResource &resource); + void ProcessAsynAcquireResult(const NotifyRequest ¬ifyReq, std::shared_ptr acquireSpec, + const ErrorInfo &errInput, std::shared_ptr invokeSpec); + void ProcecssAcquireResult(const NotifyRequest &req, std::shared_ptr spec, + std::shared_ptr>> acquirePromise); + void AcquireInstanceAsync(std::shared_ptr spec); + void StartReleaseTimer(const RequestResource &resource, const std::string &leaseId); + void ReleaseHandler(const RequestResource &resource, const std::string &leaseId); + void ReleaseInstanceAsync(const std::shared_ptr &ins); + ErrorInfo SendReleaseInstanceReq(const std::shared_ptr &ins, std::shared_ptr spec); + std::string GetSchedulerKey(); + std::string GetFunctionIdWithLabel(const RequestResource &resource); + void ChangeInstanceSchedulerId(const std::string &functionId, std::vector &leaseIds); + mutable absl::Mutex schedulerFuncKeyMtx; + std::string schedulerFuncKey; +}; +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/instance_manager.cpp b/src/libruntime/invokeadaptor/instance_manager.cpp index e404e53..ae7f341 100644 --- a/src/libruntime/invokeadaptor/instance_manager.cpp +++ b/src/libruntime/invokeadaptor/instance_manager.cpp @@ -18,6 +18,60 @@ namespace YR { namespace Libruntime { +using json = nlohmann::json; +ErrorInfo ConvertStringToInsResp(InstanceResponse &resp, const std::string &bufStr) +{ + /* + bufstr pattern + {"funcKey":"","funcSig":"","instanceID":"","threadID":"","leaseInterval":0, + "errorCode":150428,"errorMessage":"","schedulerTime":30} + */ + ErrorInfo err; + try { + json jsonData = json::parse(bufStr); + InstanceAllocation info; + info.functionId = jsonData["funcKey"].get(); + info.funcSig = jsonData["funcSig"].get(); + info.instanceId = jsonData["instanceID"].get(); + info.leaseId = jsonData["threadID"].get(); + info.tLeaseInterval = jsonData["leaseInterval"].get(); + resp.info = info; + resp.errorCode = jsonData["errorCode"].get(); + resp.errorMessage = jsonData["errorMessage"].get(); + resp.schedulerTime = jsonData["schedulerTime"].get(); + } catch (const std::exception &e) { + err.SetErrCodeAndMsg(ErrorCode::ERR_INNER_SYSTEM_ERROR, ModuleCode::RUNTIME, + "failed parse acquire notify request"); + } + return err; +} + +void ConvertStringToBatchInsResp(BatchInstanceResponse &resp, const std::string &bufStr) +{ + try { + json jsonData = json::parse(bufStr); + resp.tLeaseInterval = jsonData["leaseInterval"].get(); + resp.schedulerTime = jsonData["schedulerTime"].get(); + auto allocSucceed = jsonData["instanceAllocSucceed"].get>(); + auto allocFailed = jsonData["instanceAllocFailed"].get>(); + for (const auto &item : allocSucceed) { + InstanceAllocation info; + info.functionId = item.second["funcKey"].get(); + info.funcSig = item.second["funcSig"].get(); + info.instanceId = item.second["instanceID"].get(); + info.leaseId = item.second["threadID"].get(); + resp.instanceAllocSucceed[item.first] = info; + } + for (const auto &item : allocFailed) { + InstanceAllocationFailedRsp info; + info.errorCode = item.second["errorCode"].get(); + info.errorMessage = item.second["errorMessage"].get(); + resp.instanceAllocFailed[item.first] = info; + } + } catch (const std::exception &e) { + YRLOG_WARN("failed to convert renew or release notify req to instance info, err msg is {}", e.what()); + } +} size_t GetDelayTime(size_t failedCnt) { @@ -51,6 +105,19 @@ void CancelScaleDownTimer(std::shared_ptr insInfo) } } +std::shared_ptr InsManager::GetOrAddRequestResourceInfo(const RequestResource &resource) +{ + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + absl::WriterMutexLock lock(&insMtx); + if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { + requestResourceInfoMap[resource] = std::make_shared(); + } + info = requestResourceInfoMap[resource]; + } + return info; +} + std::shared_ptr InsManager::GetRequestResourceInfo(const RequestResource &resource) { std::shared_ptr info; @@ -60,19 +127,15 @@ std::shared_ptr InsManager::GetRequestResourceInfo(const Re info = requestResourceInfoMap[resource]; } } - if (info == nullptr) { - absl::WriterMutexLock lock(&insMtx); - if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { - requestResourceInfoMap[resource] = std::make_shared(); - } - info = requestResourceInfoMap[resource]; - } return info; } std::shared_ptr InsManager::GetInstanceInfo(const RequestResource &resource, const std::string &insId) { auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return nullptr; + } std::shared_ptr insInfo; { absl::ReaderMutexLock lock(&info->mtx); @@ -87,7 +150,7 @@ std::shared_ptr InsManager::GetInstanceInfo(const RequestResource void InsManager::AddRequestResourceInfo(std::shared_ptr spec) { auto resource = GetRequestResource(spec); - GetRequestResourceInfo(resource); + GetOrAddRequestResourceInfo(resource); } std::pair InsManager::ScheduleInsWithDevice(const RequestResource &resource, @@ -123,19 +186,14 @@ std::pair InsManager::ScheduleInsWithDevice(const Requ } // return std::pair -std::pair InsManager::ScheduleIns(const RequestResource &resource) +std::pair InsManager::GetAvailableIns(const RequestResource &resource) { if (!runFlag) { return std::make_pair("", ""); } - std::shared_ptr resourceInfo; - { - absl::ReaderMutexLock lock(&insMtx); - auto pair = requestResourceInfoMap.find(resource); - if (pair == requestResourceInfoMap.end()) { - return std::make_pair("", ""); - } - resourceInfo = pair->second; + auto resourceInfo = GetRequestResourceInfo(resource); + if (resourceInfo == nullptr) { + return std::make_pair("", ""); } if (!resource.opts.device.name.empty()) { return ScheduleInsWithDevice(resource, resourceInfo); @@ -165,13 +223,9 @@ std::pair> InsManager::NeedCancelCreatingIns(cons size_t reqNum, bool cleanAll) { auto cancelIns = std::vector(); - std::shared_ptr info; - { - absl::ReaderMutexLock lock(&insMtx); - if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { - return std::make_pair(false, cancelIns); - } - info = requestResourceInfoMap[resource]; + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return std::make_pair(false, cancelIns); } { absl::ReaderMutexLock lock(&info->mtx); @@ -226,15 +280,12 @@ std::pair> InsManager::NeedCancelCreatingIns(cons return std::make_pair(true, cancelIns); } -std::pair InsManager::NeedCreateNewIns(const RequestResource &resource, size_t reqNum) +std::pair InsManager::NeedCreateNewIns(const RequestResource &resource, size_t reqNum, + bool considerWithFailNum) { - std::shared_ptr resourceInsInfo; - { - absl::ReaderMutexLock lock(&insMtx); - if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { - return std::make_pair(false, 0); - } - resourceInsInfo = requestResourceInfoMap[resource]; + auto resourceInsInfo = GetRequestResourceInfo(resource); + if (resourceInsInfo == nullptr) { + return std::make_pair(false, 0); } size_t creatingInsNum = 0; int createFailNum = 0; @@ -253,22 +304,24 @@ std::pair InsManager::NeedCreateNewIns(const RequestResource &reso YRLOG_DEBUG("ins info: required({}) creating({}) available({}) total({}) current resource({}) ", requiredInsNum, creatingInsNum, availableInsNum, totalInsNum, currentResourceInsNum); - if (createFailNum > 0 && creatingInsNum > 0) { - YRLOG_INFO("current createfailnum is {}, creating num is {}, no need create more ins", createFailNum, - creatingInsNum); + if (considerWithFailNum && (createFailNum > 0 && creatingInsNum > 0)) { + YRLOG_INFO("current createfailnum is {}, creating num is {}, no need create more ins for function: {}", + createFailNum, creatingInsNum, resource.functionMeta.functionId); return std::make_pair(false, 0); } if (requiredInsNum <= static_cast(creatingInsNum) + availableInsNum) { - YRLOG_INFO("required ({}) < creating ({}) + available ({}); no need to create more", requiredInsNum, - creatingInsNum, availableInsNum); + YRLOG_INFO("required ({}) < creating ({}) + available ({}); no need to create more for function: {}", + requiredInsNum, creatingInsNum, availableInsNum, resource.functionMeta.functionId); return std::make_pair(false, 0); } - if (GetCreatingInstanceNum() >= libRuntimeConfig->maxConcurrencyCreateNum) { + auto totalCreatingNum = GetCreatingInstanceNum(); + if (totalCreatingNum >= libRuntimeConfig->maxConcurrencyCreateNum) { YRLOG_INFO( - "total creating ins num : {} is more than max concurrency create num : {}, should not create more ins", - creatingInsNum, libRuntimeConfig->maxConcurrencyCreateNum); + "total creating ins num : {} is more than max concurrency create num : {}, should not create more ins for " + "function: {}", + totalCreatingNum, libRuntimeConfig->maxConcurrencyCreateNum, resource.functionMeta.functionId); return std::make_pair(false, 0); } @@ -279,9 +332,11 @@ std::pair InsManager::NeedCreateNewIns(const RequestResource &reso : false; if (exceedMaxTaskInsNum) { YRLOG_INFO( - "creating ins num : {} is more than max concurrency create num: {} or resource ins num limit: {}, should " - "not create more ins", - creatingInsNum, libRuntimeConfig->maxTaskInstanceNum, resource.opts.maxInstances); + "total ins num : {} is more than max concurrency create num: {} or {} is more than resource ins num limit: " + "{}, should " + "not create more ins for function: {}", + totalInsNum, libRuntimeConfig->maxTaskInstanceNum, currentResourceInsNum, resource.opts.maxInstances, + resource.functionMeta.functionId); return std::make_pair(false, 0); } return std::make_pair(true, GetDelayTime(createFailNum)); @@ -297,14 +352,7 @@ int InsManager::GetRequiredInstanceNum(int reqNum, int concurrency) const void InsManager::AddCreatingInsInfo(const RequestResource &resource, std::shared_ptr insInfo) { - std::shared_ptr resourceInfo; - { - absl::ReaderMutexLock lock(&insMtx); - if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { - return; - } - resourceInfo = requestResourceInfoMap[resource]; - } + auto resourceInfo = GetOrAddRequestResourceInfo(resource); IncreaseCreatingInstanceNum(); absl::WriterMutexLock lock(&resourceInfo->mtx); auto &creatingInfo = resourceInfo->creatingIns; @@ -321,16 +369,18 @@ void InsManager::AddCreatingInsInfo(const RequestResource &resource, std::shared bool InsManager::EraseCreatingInsInfo(const RequestResource &resource, const std::string &instanceId, bool createSuccess) { - std::shared_ptr info; + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return false; + } + bool isErase; { - absl::ReaderMutexLock lock(&insMtx); - if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { - return false; - } - info = requestResourceInfoMap[resource]; + absl::WriterMutexLock lock(&info->mtx); + isErase = EraseCreatingInsInfoBare(info, instanceId, createSuccess); } - absl::WriterMutexLock lock(&info->mtx); - return EraseCreatingInsInfoBare(info, instanceId, createSuccess); + info.reset(); + EraseResourceInfoMap(resource); + return isErase; } bool InsManager::EraseCreatingInsInfoBare(std::shared_ptr info, const std::string &instanceId, @@ -380,13 +430,9 @@ bool InsManager::EraseCreatingInsInfoBare(std::shared_ptr i void InsManager::ChangeCreateFailNum(const RequestResource &resource, bool isIncreaseOps) { - std::shared_ptr info; - { - absl::ReaderMutexLock lock(&insMtx); - if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { - return; - } - info = requestResourceInfoMap[resource]; + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return; } absl::WriterMutexLock lock(&info->mtx); if (isIncreaseOps) { @@ -398,20 +444,59 @@ void InsManager::ChangeCreateFailNum(const RequestResource &resource, bool isInc void InsManager::DelInsInfo(const std::string &insId, const RequestResource &resource) { - std::shared_ptr info; + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return; + } { - absl::ReaderMutexLock lock(&insMtx); - if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { - return; + absl::WriterMutexLock infoLock(&info->mtx); + YRLOG_DEBUG("start delete ins info, ins id is {}, current totol ins num is {}", insId, GetCreatedInstanceNum()); + if (info->instanceInfos.find(insId) != info->instanceInfos.end()) { + auto &insInfo = info->instanceInfos[insId]; + CancelScaleDownTimer(insInfo); + DelInsInfoBare(insId, info); } - info = requestResourceInfoMap[resource]; } + info.reset(); + EraseResourceInfoMap(resource); +} + +void InsManager::EraseResourceInfoMap() +{ + auto resources = GetScheduleResources(); + for (const auto &resource : resources) { + EraseResourceInfoMap(resource); + } +} + +std::vector InsManager::GetScheduleResources() +{ + absl::ReaderMutexLock lock(&insMtx); + std::vector resources; + for (const auto &pair : requestResourceInfoMap) { + resources.push_back(pair.first); + } + return resources; +} + +void InsManager::EraseResourceInfoMap(const RequestResource &resource, int currentCount) +{ + // yield avoid frequent cleaning requestResourceInfoMap + std::this_thread::yield(); + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return; + } + absl::WriterMutexLock lock(&insMtx); absl::WriterMutexLock infoLock(&info->mtx); - YRLOG_DEBUG("start delete ins info, ins id is {}, current totol ins num is {}", insId, GetCreatedInstanceNum()); - if (info->instanceInfos.find(insId) != info->instanceInfos.end()) { - auto &insInfo = info->instanceInfos[insId]; - CancelScaleDownTimer(insInfo); - DelInsInfoBare(insId, info); + if (info->instanceInfos.empty() && info->creatingIns.empty()) { + // If the reference count is greater than the specified value, + // it indicates that other threads are operating on RequestResourceInfo, + // and the deletion operation should not be performed. + if (info.use_count() <= currentCount) { + YRLOG_DEBUG("remove resource info"); + requestResourceInfoMap.erase(resource); + } } } @@ -427,23 +512,21 @@ void InsManager::DelInsInfoBare(const std::string &insId, std::shared_ptr spec, bool isInstanceNormal) { auto resource = GetRequestResource(spec); - std::shared_ptr info; - { - absl::ReaderMutexLock lock(&insMtx); - if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { - return; - } - info = requestResourceInfoMap[resource]; + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return; } absl::WriterMutexLock lock(&info->mtx); - auto id = spec->invokeInstanceId; + auto id = spec->functionMeta.IsServiceApiType() ? spec->invokeLeaseId : spec->invokeInstanceId; if (info->instanceInfos.find(id) == info->instanceInfos.end()) { return; } absl::WriterMutexLock insLock(&info->instanceInfos[id]->mtx); auto insInfo = info->instanceInfos[id]; insInfo->unfinishReqNum--; - if (insInfo->unfinishReqNum < static_cast(resource.concurrency) && isInstanceNormal) { + if (!insInfo->leaseId.empty() && insInfo->faasInfo.tLeaseInterval <= 0) { + insInfo->available = false; + } else if (insInfo->unfinishReqNum < static_cast(resource.concurrency) && isInstanceNormal) { insInfo->available = true; if (info->avaliableInstanceInfos.find(id) == info->avaliableInstanceInfos.end()) { info->avaliableInstanceInfos.emplace(id, insInfo); @@ -455,13 +538,21 @@ void InsManager::DecreaseUnfinishReqNum(const std::shared_ptr spec, void InsManager::Stop() { runFlag = false; + { + absl::WriterMutexLock lock(&leaseMtx); + if (leaseTimer) { + leaseTimer->cancel(); + leaseTimer.reset(); + } + } + if (tw_ != nullptr) { + tw_->Stop(); + } + tw_.reset(); absl::WriterMutexLock lock(&insMtx); for (auto &pair : requestResourceInfoMap) { auto &requestResourceInfo = pair.second; absl::WriterMutexLock infoLock(&requestResourceInfo->mtx); - for (auto &insInfo : requestResourceInfo->instanceInfos) { - CancelScaleDownTimer(insInfo.second); - } requestResourceInfo->instanceInfos.clear(); requestResourceInfo->avaliableInstanceInfos.clear(); } @@ -497,13 +588,9 @@ std::vector InsManager::GetCreatingInsIds() bool InsManager::IsRemainIns(const RequestResource &resource) { - std::shared_ptr resourceInsInfo; - { - absl::ReaderMutexLock lock(&insMtx); - if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { - return false; - } - resourceInsInfo = requestResourceInfoMap[resource]; + auto resourceInsInfo = GetRequestResourceInfo(resource); + if (resourceInsInfo == nullptr) { + return false; } absl::ReaderMutexLock lock(&resourceInsInfo->mtx); if (resourceInsInfo->creatingIns.size() > 0 || resourceInsInfo->instanceInfos.size() > 0) { @@ -516,5 +603,22 @@ void InsManager::SetDeleleInsCallback(const DeleteInsCallback &cb) { deleteInsCallback_ = cb; } + +void InsManager::UpdateSchdulerInfo(const std::string &schedulerName, const std::string &schedulerId, + const std::string &option) +{ + UpdateSchedulerOption opt = stringToOption(option); + switch (opt) { + case UpdateSchedulerOption::ADD: + this->csHash->Add(schedulerName, schedulerId); + break; + case UpdateSchedulerOption::REMOVE: + this->csHash->Remove(schedulerName); + break; + case UpdateSchedulerOption::UNKNOWN: + YRLOG_ERROR("option: {} is not correct, do nothing about scheudler: {}", option, schedulerName); + break; + } +} } // namespace Libruntime } // namespace YR diff --git a/src/libruntime/invokeadaptor/instance_manager.h b/src/libruntime/invokeadaptor/instance_manager.h index d0aba09..6fe2175 100644 --- a/src/libruntime/invokeadaptor/instance_manager.h +++ b/src/libruntime/invokeadaptor/instance_manager.h @@ -19,6 +19,7 @@ #include "src/libruntime/err_type.h" #include "src/libruntime/fsclient/fs_client.h" #include "src/libruntime/invoke_spec.h" +#include "src/libruntime/invokeadaptor/limiter_consistant_hash.h" #include "src/libruntime/invokeadaptor/request_manager.h" #include "src/libruntime/invokeadaptor/request_queue.h" #include "src/libruntime/objectstore/memory_store.h" @@ -29,6 +30,8 @@ namespace YR { namespace Libruntime { const int DEFAULT_INVOKE_DURATION = 1000; const int DEFAULT_CREATE_DURATION = 1000; // ms +const int64_t DEFAULT_LEASE_INTERVAL = 500; +const int REQUEST_RESOURCE_USE_COUNT = 3; using YR::utility::TimeMeasurement; using ScheduleInsCallback = std::function; @@ -37,40 +40,44 @@ void CancelScaleDownTimer(std::shared_ptr insInfo); using DeleteInsCallback = std::function; class InsManager { public: - enum class UpdateSchedulerOption { - ADD, - REMOVE, - UNKNOWN - }; - UpdateSchedulerOption stringToOption(const std::string& s) + enum class UpdateSchedulerOption { ADD, REMOVE, UNKNOWN }; + UpdateSchedulerOption stringToOption(const std::string &s) { static const std::unordered_map mapping = { - {"ADD", UpdateSchedulerOption::ADD}, - {"REMOVE", UpdateSchedulerOption::REMOVE} - }; + {"ADD", UpdateSchedulerOption::ADD}, {"REMOVE", UpdateSchedulerOption::REMOVE}}; auto it = mapping.find(s); return (it != mapping.end()) ? it->second : UpdateSchedulerOption::UNKNOWN; } - InsManager() = default; + InsManager() + { + tw_ = std::make_shared(); + }; InsManager(ScheduleInsCallback cb, std::shared_ptr client, std::shared_ptr store, std::shared_ptr reqMgr, std::shared_ptr config) : scheduleInsCb(cb), fsClient(client), memoryStore(store), requestManager(reqMgr), libRuntimeConfig(config) { + tw_ = std::make_shared(); } ~InsManager() = default; void DelInsInfo(const std::string &insId, const RequestResource &resource); - std::pair ScheduleIns(const RequestResource &resource); + std::pair GetAvailableIns(const RequestResource &resource); virtual bool ScaleUp(std::shared_ptr spec, size_t reqNum) = 0; virtual void ScaleDown(const std::shared_ptr spec, bool isInstanceNormal = false) = 0; virtual void ScaleCancel(const RequestResource &resource, size_t reqNum, bool cleanAll = false) = 0; - virtual void StartRenewTimer(const RequestResource &resource, const std::string &insId) = 0; + virtual void StartBatchRenewTimer() = 0; virtual void UpdateConfig(int recycleTimeMs) = 0; + virtual void UpdateSchedulerInfo(const std::string &schedulerFuncKey, + const std::vector &schedulerInfoList) + { + } void DecreaseUnfinishReqNum(const std::shared_ptr spec, bool isInstanceNormal = true); void Stop(); std::vector GetInstanceIds(); std::vector GetCreatingInsIds(); void SetDeleleInsCallback(const DeleteInsCallback &cb); - + void UpdateSchdulerInfo(const std::string &schedulerName, const std::string &schedulerId, + const std::string &option); + void EraseResourceInfoMap(); protected: int recycleTimeMs; ScheduleInsCallback scheduleInsCb; @@ -83,7 +90,8 @@ protected: void DelInsInfoBare(const std::string &insId, std::shared_ptr info); std::pair> NeedCancelCreatingIns(const RequestResource &resource, size_t reqNum, bool cleanAll); - std::pair NeedCreateNewIns(const RequestResource &resource, size_t reqNum); + std::pair NeedCreateNewIns(const RequestResource &resource, size_t reqNum, + bool considerWithFailNum = true); void AddCreatingInsInfo(const RequestResource &resource, std::shared_ptr insInfo); bool EraseCreatingInsInfo(const RequestResource &resource, const std::string &instanceId, bool createSuccess = true); @@ -94,6 +102,7 @@ protected: int GetRequiredInstanceNum(int reqNum, int concurrency) const; bool IsRemainIns(const RequestResource &resource); std::shared_ptr GetRequestResourceInfo(const RequestResource &resource); + std::shared_ptr GetOrAddRequestResourceInfo(const RequestResource &resource); std::shared_ptr GetInstanceInfo(const RequestResource &resource, const std::string &insId); int GetCreatedInstanceNum() { @@ -115,6 +124,11 @@ protected: absl::WriterMutexLock lock(&createInstanceNumMutex); totalCreatedInstanceNum_--; } + void DecreaseCreatedInstanceNums(int instanceNum) + { + absl::WriterMutexLock lock(&createInstanceNumMutex); + totalCreatedInstanceNum_ -= instanceNum; + } void DecreaseCreatingInstanceNum() { absl::WriterMutexLock lock(&createInstanceNumMutex); @@ -130,9 +144,11 @@ protected: absl::WriterMutexLock lock(&createInstanceNumMutex); totalCreatingInstanceNum_++; } - + std::vector GetScheduleResources(); + void EraseResourceInfoMap(const RequestResource &resource, int currentCount = 2); mutable absl::Mutex insMtx; std::atomic runFlag{true}; + std::shared_ptr csHash; int totalCreatedInstanceNum_{0} ABSL_GUARDED_BY(createInstanceNumMutex); int totalCreatingInstanceNum_{0} ABSL_GUARDED_BY(createInstanceNumMutex); mutable absl::Mutex createInstanceNumMutex; @@ -142,6 +158,12 @@ protected: mutable absl::Mutex invokeCostMtx; std::unordered_map, HashFn> requestResourceInfoMap ABSL_GUARDED_BY(insMtx); + std::unordered_map globalLeases ABSL_GUARDED_BY(leaseMtx); + int64_t tLeaseInterval{DEFAULT_LEASE_INTERVAL} ABSL_GUARDED_BY(leaseMtx); + std::shared_ptr leaseTimer ABSL_GUARDED_BY(leaseMtx); + mutable absl::Mutex leaseMtx; + mutable absl::Mutex schedulerMtx; + std::shared_ptr tw_; private: std::pair ScheduleInsWithDevice(const RequestResource &resource, @@ -150,6 +172,8 @@ private: RequestResource GetRequestResource(std::shared_ptr spec); size_t GetDelayTime(size_t failedCnt); +ErrorInfo ConvertStringToInsResp(InstanceResponse &resp, const std::string &bufStr); +void ConvertStringToBatchInsResp(BatchInstanceResponse &resp, const std::string &bufStr); } // namespace Libruntime } // namespace YR diff --git a/src/libruntime/invokeadaptor/invoke_adaptor.cpp b/src/libruntime/invokeadaptor/invoke_adaptor.cpp index b487da8..b7e2e52 100644 --- a/src/libruntime/invokeadaptor/invoke_adaptor.cpp +++ b/src/libruntime/invokeadaptor/invoke_adaptor.cpp @@ -77,8 +77,9 @@ libruntime::FunctionMeta convertFuncMetaToProto(std::shared_ptr spec meta.set_language(spec->functionMeta.languageType); meta.set_modulename(spec->functionMeta.moduleName); meta.set_signature(spec->functionMeta.signature); - meta.set_name(spec->functionMeta.name.value_or("")); - meta.set_ns(spec->functionMeta.ns.value_or("")); + meta.set_needorder(spec->opts.needOrder); + meta.set_name(spec->functionMeta.name); + meta.set_ns(spec->functionMeta.ns); return meta; } @@ -142,58 +143,59 @@ bool ParseMetaData(const CallRequest &request, bool isPosix, libruntime::MetaDat return true; } -bool ParseRequest(const CallRequest &request, std::vector> &rawArgs, - std::shared_ptr memStore, bool isPosix) -{ - int argStart = isPosix ? METADATA_INDEX : ARGS_INDEX; - for (int i = argStart; i < request.args_size(); i++) { - std::shared_ptr rawArg; - if (request.args(i).type() == common::Arg::OBJECT_REF) { - // get arg by argid from ds - std::string argId = std::string(request.args(i).value().data(), request.args(i).value().size()); - auto [err, argBuf] = memStore->GetBuffer(argId, NO_TIMEOUT); - if (err.Code() != ErrorCode::ERR_OK || argBuf == nullptr) { - YRLOG_ERROR("Get arg {} from DS err! Code {}, MCode {}, info {}.", argId, err.Code(), err.MCode(), - err.Msg()); - return false; - } - rawArg = std::make_shared(argId, argBuf); - } else { - auto argBuf = - std::make_shared(request.args(i).value().data(), request.args(i).value().size()); - rawArg = std::make_shared("", argBuf); - } - rawArgs.emplace_back(rawArg); - } - return true; -} - -InvokeAdaptor::InvokeAdaptor(std::shared_ptr config, - std::shared_ptr dependencyResolver, - std::shared_ptr &fsClient, std::shared_ptr memStore, - std::shared_ptr rtCtx, FinalizeCallback cb, - std::shared_ptr waitManager, - std::shared_ptr invokeOrderMgr, - std::shared_ptr clientsMgr, std::shared_ptr metricsAdaptor) +InvokeAdaptor::InvokeAdaptor( + std::shared_ptr config, std::shared_ptr dependencyResolver, + std::shared_ptr &fsClient, std::shared_ptr memStore, std::shared_ptr rtCtx, + FinalizeCallback cb, std::shared_ptr waitManager, + std::shared_ptr invokeOrderMgr, std::shared_ptr clientsMgr, + std::shared_ptr inputMetricsAdaptor, std::shared_ptr genIdMapper, + std::shared_ptr generatorReceiver, std::shared_ptr generatorNotifier, + std::shared_ptr downgrade) : dependencyResolver(dependencyResolver), runtimeContext(rtCtx), finalizeCb_(cb), invokeOrderMgr(invokeOrderMgr), clientsMgr(clientsMgr), - metricsAdaptor(metricsAdaptor) + metricsAdaptor(inputMetricsAdaptor), + map_(genIdMapper) { + ar = std::make_shared(); this->fsClient = fsClient; this->librtConfig = config; this->memStore = memStore; this->requestManager = std::make_shared(); - this->taskSubmitter = - std::make_shared(config, memStore, fsClient, requestManager, - std::bind(&InvokeAdaptor::KillAsync, this, std::placeholders::_1, - std::placeholders::_2, std::placeholders::_3)); + this->taskSubmitter = std::make_shared( + config, memStore, fsClient, requestManager, + std::bind(&InvokeAdaptor::KillAsync, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3), + ar, inputMetricsAdaptor, downgrade); this->groupManager = std::make_shared(); this->waitingObjectManager = waitManager; + this->generatorReceiver_ = generatorReceiver; + this->generatorNotifier_ = generatorNotifier; this->functionMasterClient_ = std::make_shared(); this->functionMasterClient_->SetSubscribeActiveMasterCb(std::bind(&InvokeAdaptor::SubscribeActiveMaster, this)); + + // for debug instance, a built-in breakpoint is set before executing init call, so that the user can set their + // own breakpoint before executing any user code. However, there's no guarantee that init call response is sent + // to proxy, since it's handled by TryDirectWrite in an async threadpool. If proxy receives no init call + // response, it repeats until failure. So we did a dirty trick here. Hopefully after 1 second, the response has + // been sent to proxy + this->setDebugBreakpoint_ = []() { + std::this_thread::sleep_for(std::chrono::milliseconds(SLEEP_INTERVAL_BEFORE_TRACEPOINT_MS)); + // SIGSTOP requires two "c" in gdb to continue process execution + std::raise(SIGSTOP); + }; +} + +void InvokeAdaptor::CheckAndSetDebugBreakpoint(const std::shared_ptr &req) +{ + const auto &createOpts = req->Immutable().createoptions(); + if (createOpts.find(std::string(DEBUG_CONFIG_KEY)) != createOpts.end()) { + if (setDebugBreakpoint_ != nullptr) { + YRLOG_INFO("debug instance is stopped, waiting for remote debug connection"); + setDebugBreakpoint_(); + } + } } void InvokeAdaptor::SetRGroupManager(std::shared_ptr rGroupManager) @@ -214,7 +216,8 @@ void InvokeAdaptor::InitHandler(const std::shared_ptr &req) callResult.set_instanceid(req->Immutable().senderid()); auto callResultCallback = [](const CallResultAck &resp) { if (resp.code() != common::ERR_NONE) { - YRLOG_WARN("failed to send CallResult, code: {}, message: {}", resp.code(), resp.message()); + YRLOG_WARN("failed to send CallResult, code: {}, message: {}", fmt::underlying(resp.code()), + resp.message()); } }; @@ -242,6 +245,11 @@ void InvokeAdaptor::InitHandler(const std::shared_ptr &req) this->fiberPool_ = std::make_shared(FIBER_STACK_SIZE, YR::Libruntime::Config::Instance().YR_ASYNCIO_MAX_CONCURRENCY()); } + YRLOG_DEBUG("enable metrics is {}, api type is {}", Config::Instance().ENABLE_METRICS(), + fmt::underlying(this->librtConfig->selfApiType)); + if (Config::Instance().ENABLE_METRICS() && !isPosix) { + InitMetricsAdaptor(metaData.config().enablemetrics()); + } if (this->librtConfig->selfApiType != libruntime::ApiType::Posix) { auto res = InitCall(req->Immutable(), metaData); if (res.code() != common::ERR_NONE) { @@ -251,6 +259,7 @@ void InvokeAdaptor::InitHandler(const std::shared_ptr &req) } librtConfig->InitFunctionGroupRunningInfo(runningInfo); } + CheckAndSetDebugBreakpoint(req); librtConfig->funcMeta = metaData.functionmeta(); librtConfig->funcMeta.set_needorder(librtConfig->needOrder); YRLOG_DEBUG("update instance function meta, req id is {}, value is {}", req->Immutable().requestid(), @@ -264,7 +273,8 @@ void InvokeAdaptor::InitHandler(const std::shared_ptr &req) result->Mutable() = std::move(res); fsClient->ReturnCallResult(result, true, [](const CallResultAck &resp) { if (resp.code() != common::ERR_NONE) { - YRLOG_WARN("failed to send CallResult, code: {}, message: {}", resp.code(), resp.message()); + YRLOG_WARN("failed to send CallResult, code: {}, message: {}", fmt::underlying(resp.code()), + resp.message()); } }); }, @@ -279,7 +289,8 @@ void InvokeAdaptor::CallHandler(const std::shared_ptr &req) callResult.set_instanceid(req->Immutable().senderid()); auto callResultCallback = [](const CallResultAck &resp) { if (resp.code() != common::ERR_NONE) { - YRLOG_WARN("failed to send CallResult, code: {}, message: {}", resp.code(), resp.message()); + YRLOG_WARN("failed to send CallResult, code: {}, message: {}", fmt::underlying(resp.code()), + resp.message()); } }; if (!this->execMgr) { @@ -294,11 +305,11 @@ void InvokeAdaptor::CallHandler(const std::shared_ptr &req) libruntime::MetaData metaData; bool isPosix = this->librtConfig->selfApiType == libruntime::ApiType::Posix; if (!ParseMetaData(req->Immutable(), isPosix, metaData)) { + callResult.set_code(common::ERR_INNER_SYSTEM_ERROR); std::string errMsg = "Invalid request, requestid:" + req->Immutable().requestid() + ", traceid:" + req->Immutable().traceid() + ", senderid:" + req->Immutable().senderid() + ", function:" + req->Immutable().function(); callResult.set_message(errMsg); - callResult.set_code(common::ERR_INNER_SYSTEM_ERROR); fsClient->ReturnCallResult(result, true, callResultCallback); return; } @@ -319,25 +330,30 @@ void InvokeAdaptor::CallHandler(const std::shared_ptr &req) this->fsClient->ReturnCallResult(result, false, callResultCallback); return; } + this->CreateCallTimer(req->Immutable().requestid(), req->Immutable().senderid(), + this->librtConfig->invokeTimeoutSec); this->execMgr->Handle( metaData.invocationmeta(), [this, req, metaData]() { std::function handler = [this, req, metaData]() { + threadLocalTraceId = req->Immutable().traceid(); auto startTime = std::chrono::high_resolution_clock::now(); std::vector objectsInDs; auto res = Call(req->Immutable(), metaData, librtConfig->libruntimeOptions, objectsInDs); auto endTime = std::chrono::high_resolution_clock::now(); auto durationCast = std::chrono::duration_cast(endTime - startTime).count(); - YRLOG_INFO("funcname: {}, call elapsed time: {}ms, requestid: {}, traceid: {}", + YRLOG_INFO("func name: {}, call elapsed time: {}ms, request id: {}, trace id: {}", metaData.functionmeta().functionname(), durationCast, req->Immutable().requestid(), req->Immutable().traceid()); + this->EraseCallTimer(req->Immutable().requestid()); ReportMetrics(req->Immutable().requestid(), req->Immutable().traceid(), durationCast); auto result = std::make_shared(); result->Mutable() = std::move(res); result->existObjInDs = !objectsInDs.empty(); fsClient->ReturnCallResult(result, false, [this, objectsInDs](const CallResultAck &resp) { if (resp.code() != common::ERR_NONE) { - YRLOG_WARN("failed to send CallResult, code: {}, message: {}", resp.code(), resp.message()); + YRLOG_WARN("failed to send CallResult, code: {}, message: {}", fmt::underlying(resp.code()), + resp.message()); } this->memStore->DecreGlobalReference(objectsInDs); return; @@ -461,7 +477,9 @@ std::pair InvokeAdaptor::Init(RuntimeContext &runtimeCon if (librtConfig->libruntimeOptions.healthCheckCallback) { handlers.heartbeat = std::bind(&InvokeAdaptor::HeartbeatHandler, this, _1); } - this->librtConfig->enableServerMode = true; + if (Config::Instance().ENABLE_SERVER_MODE()) { + this->librtConfig->enableServerMode = Config::Instance().ENABLE_SERVER_MODE(); + } YRLOG_DEBUG("when start fsclient isDriver {}, enableServerMode {}", this->librtConfig->isDriver, this->librtConfig->enableServerMode); // If this process is pulled up by function system, server listening address is specified by runtime-manager; @@ -498,6 +516,48 @@ std::pair InvokeAdaptor::Init(RuntimeContext &runtimeCon return std::make_pair("", err); } +bool ParseFaasController(const CallRequest &request, std::vector> &rawArgs, bool isPosix) +{ + int argStart = isPosix ? METADATA_INDEX : ARGS_INDEX; + for (int i = argStart; i < request.args_size(); i++) { + auto rawArg = std::make_shared(); + auto argBuf = + std::make_shared(request.args(i).value().data(), request.args(i).value().size()); + rawArg->SetDataBuf(argBuf); + rawArgs.emplace_back(rawArg); + } + return true; +} + +bool InvokeAdaptor::ParseRequest(const CallRequest &request, std::vector> &rawArgs, + bool isPosix) +{ + int argStart = isPosix ? METADATA_INDEX : ARGS_INDEX; + for (int i = argStart; i < request.args_size(); i++) { + std::shared_ptr rawArg; + if (request.args(i).type() == common::Arg::OBJECT_REF) { + if (setTenantIdCb_) { + setTenantIdCb_(); + } + // get arg by argid from ds + std::string argId = std::string(request.args(i).value().data(), request.args(i).value().size()); + auto [err, argBuf] = memStore->GetBuffer(argId, NO_TIMEOUT); + if (err.Code() != ErrorCode::ERR_OK || argBuf == nullptr) { + YRLOG_ERROR("Get arg {} from DS err! Code {}, MCode {}, info {}.", argId, fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); + return false; + } + rawArg = std::make_shared(argId, argBuf); + } else { + auto argBuf = + std::make_shared(request.args(i).value().data(), request.args(i).value().size()); + rawArg = std::make_shared("", argBuf); + } + rawArgs.emplace_back(rawArg); + } + return true; +} + CallResult InvokeAdaptor::Call(const CallRequest &req, const libruntime::MetaData &metaData, const LibruntimeOptions &options, std::vector &objectsInDs) { @@ -508,10 +568,14 @@ CallResult InvokeAdaptor::Call(const CallRequest &req, const libruntime::MetaDat std::vector> rawArgs; bool isPosix = this->librtConfig->selfApiType == libruntime::ApiType::Posix; bool returnByMsg = req.returnobjectids_size() == 0; - if (setTenantIdCb_) { - setTenantIdCb_(); + bool isFaasController = req.requestid().find("faascontroller") != std::string::npos; + bool ok; + if (isFaasController) { + YRLOG_DEBUG("receive request about faas controller {}", req.requestid()); + ok = ParseFaasController(req, rawArgs, isPosix); + } else { + ok = ParseRequest(req, rawArgs, isPosix); } - bool ok = ParseRequest(req, rawArgs, memStore, isPosix); if (!ok) { callResult.set_code(common::ERR_NONE); callResult.set_message(ErrMsgMap.at(common::ERR_NONE)); @@ -519,6 +583,11 @@ CallResult InvokeAdaptor::Call(const CallRequest &req, const libruntime::MetaDat } std::string genId; + if (metaData.functionmeta().isgenerator() && req.returnobjectids_size() > 0) { + generatorNotifier_->Initialize(); + genId = req.returnobjectids(0); + } + GeneratorIdRecorder r(genId, metaData.invocationmeta().invokerruntimeid(), map_); size_t returnIdsize = req.returnobjectids_size() == 0 ? 1 : req.returnobjectids_size(); @@ -548,6 +617,24 @@ CallResult InvokeAdaptor::Call(const CallRequest &req, const libruntime::MetaDat returnObjects[0]->alwaysNative = true; } + if (functionMeta.IsServiceApiType() && rawArgs.size() > SCHEDULER_DATA_INDEX && req.iscreate()) { + auto schedulerDataStr = + std::string(reinterpret_cast(rawArgs[SCHEDULER_DATA_INDEX]->data->MutableData()), + rawArgs[SCHEDULER_DATA_INDEX]->data->GetSize()); + SchedulerInfo schedulerInfo; + + auto err = ParseSchedulerInfo(schedulerDataStr, schedulerInfo); + if (!err.OK()) { + YRLOG_WARN( + "parse schedulerInfo failed, schedulerDataStr is {}, err is {}, in init_handler call another " + "function will failed", + schedulerDataStr, err.Msg()); + } else { + this->taskSubmitter->UpdateFaaSSchedulerInfo(schedulerInfo.schedulerFuncKey, + schedulerInfo.schedulerInstanceList); + } + } + auto err = options.functionExecuteCallback(functionMeta, metaData.invoketype(), rawArgs, returnObjects); for (size_t i = 0; i < returnObjects.size(); i++) { if (returnObjects[i]->buffer != nullptr && returnObjects[i]->buffer->IsNative() && @@ -608,10 +695,6 @@ CallResult InvokeAdaptor::InitCall(const CallRequest &req, const libruntime::Met } else { librtConfig->InitConfig(metaData.config()); taskSubmitter->UpdateConfig(); - if (Config::Instance().ENABLE_METRICS()) { - InitMetricsAdaptor(metaData.config().enablemetrics()); - return callResult; - } } return callResult; } @@ -638,7 +721,10 @@ std::pair InvokeAdaptor::PrepareCallExecutor(con YRLOG_ERROR("{}, concurrency: {}", err, concurrency); return std::make_pair(common::ERR_PARAM_INVALID, err); } - + if (req.createoptions().find(FAAS_INVOKE_TIMEOUT) != req.createoptions().end()) { + this->librtConfig->invokeTimeoutSec = + static_cast(std::stoull(req.createoptions().at(FAAS_INVOKE_TIMEOUT))); + } YRLOG_INFO("Call executor pool size: {}, need order: {}", concurrency, this->librtConfig->needOrder); if (this->librtConfig->needOrder) { this->execMgr = std::make_shared(concurrency, librtConfig->funcExecSubmitHook); @@ -653,11 +739,14 @@ SignalResponse InvokeAdaptor::SignalHandler(const SignalRequest &req) { YRLOG_DEBUG("receive signal {}", req.signal()); SignalResponse resp; + if (!isRunning) { + return resp; + } switch (req.signal()) { case libruntime::Signal::Cancel: { auto objIds = requestManager->GetObjIds(); Cancel(objIds, true, true); - Exit(); + Exit(0, ""); break; } case libruntime::Signal::ErasePendingThread: { @@ -667,6 +756,19 @@ SignalResponse InvokeAdaptor::SignalHandler(const SignalRequest &req) } break; } + case libruntime::Signal::UpdateAlias: { + std::vector aliasInfo; + auto err = ParseAliasInfo(req, aliasInfo); + if (!err.OK()) { + YRLOG_INFO("recv alias update signal, but parse failed, {}", err.Msg()); + resp.set_code(static_cast(err.Code())); + resp.set_message(err.Msg()); + } else { + ar->UpdateAliasInfo(aliasInfo); + } + resp = ExecSignalCallback(req); + break; + } case libruntime::Signal::Update: { const std::string &payload = req.payload(); NotificationPayload notifyPayload; @@ -674,7 +776,9 @@ SignalResponse InvokeAdaptor::SignalHandler(const SignalRequest &req) if (notifyPayload.has_instancetermination()) { this->RemoveInsMetaInfo(notifyPayload.instancetermination().instanceid()); } else if (notifyPayload.has_functionmasterevent()) { - this->functionMasterClient_->UpdateActiveMaster(notifyPayload.functionmasterevent().address()); + if (functionMasterClient_) { + this->functionMasterClient_->UpdateActiveMaster(notifyPayload.functionmasterevent().address()); + } } break; } @@ -697,13 +801,32 @@ SignalResponse InvokeAdaptor::SignalHandler(const SignalRequest &req) AccelerateMsgQueueHandle::FromJson(payload), outputHandle); if (!err.OK()) { resp.set_code(static_cast(err.Code())); - YRLOG_WARN("execute accelerate callback err code: {}, msg: {}", err.Code(), err.Msg()); + YRLOG_WARN("execute accelerate callback err code: {}, msg: {}", fmt::underlying(err.Code()), err.Msg()); resp.set_message(err.Msg()); } else { resp.set_message(outputHandle.ToJson()); } break; } + case libruntime::Signal::UpdateScheduler: { + YRLOG_INFO("recv faascheduler update signal"); + resp = ExecSignalCallback(req); + break; + } + case libruntime::Signal::UpdateSchedulerHash: { + YRLOG_INFO("recv faaschedulerHash update signal"); + SchedulerInfo schedulerInfo; + auto err = ParseSchedulerInfo(req.payload(), schedulerInfo); + if (!err.OK()) { + YRLOG_INFO("recv faascheduler update signal, but parse failed"); + resp.set_code(static_cast(err.Code())); + resp.set_message(err.Msg()); + } else { + this->taskSubmitter->UpdateFaaSSchedulerInfo(schedulerInfo.schedulerFuncKey, + schedulerInfo.schedulerInstanceList); + } + break; + } case libruntime::Signal::GetInstance: { std::string serializedMeta; if (!this->librtConfig->funcMeta.SerializeToString(&serializedMeta)) { @@ -739,6 +862,17 @@ HeartbeatResponse InvokeAdaptor::HeartbeatHandler(const HeartbeatRequest &req) return resp; } +ErrorInfo InvokeAdaptor::ParseAliasInfo(const SignalRequest &req, std::vector &aliasInfo) +{ + try { + json j = json::parse(req.payload()); + aliasInfo = j.get>(); + } catch (std::exception &e) { + return ErrorInfo(ErrorCode::ERR_PARAM_INVALID, std::string("parse alias info: ") + e.what()); + } + return ErrorInfo(); +} + SignalResponse InvokeAdaptor::ExecSignalCallback(const SignalRequest &req) { SignalResponse resp; @@ -771,7 +905,7 @@ ShutdownResponse InvokeAdaptor::ShutdownHandler(const ShutdownRequest &req) ErrorInfo InvokeAdaptor::ExecShutdownCallback(uint64_t gracePeriodSec) { - YRLOG_DEBUG("graceful shutdown period is {}", gracePeriodSec); + YRLOG_INFO("graceful shutdown period is {}", gracePeriodSec); auto notification = std::make_shared(); auto thread = std::thread(&InvokeAdaptor::ExecUserShutdownCallback, this, gracePeriodSec, notification); @@ -816,8 +950,8 @@ void InvokeAdaptor::ExecUserShutdownCallback(uint64_t gracePeriodSec, YRLOG_DEBUG("Start to call user shutdown callback, graceful shutdown time: {}", gracePeriodSec); err = librtConfig->libruntimeOptions.shutdownCallback(gracePeriodSec); if (!err.OK()) { - YRLOG_ERROR("Failed to call user shutdown callback, error: {}, error code: {}, error message: {}", - err.Msg(), err.Code(), static_cast(err.Code())); + YRLOG_ERROR("Failed to call user shutdown callback, code: {}, message: {}", fmt::underlying(err.Code()), + err.Msg()); } else { YRLOG_DEBUG("Succeeded to call user shutdown callback"); } @@ -901,6 +1035,13 @@ void InvokeAdaptor::InvokeInstanceFunction(std::shared_ptr spec) void InvokeAdaptor::SubmitFunction(std::shared_ptr spec) { + if (ar->CheckAlias(spec->functionMeta.functionId)) { + std::string functionId = ar->ParseAlias(spec->functionMeta.functionId, spec->opts.aliasParams); + spec->downgradeFlag_ = (spec->functionMeta.functionId == functionId); + spec->functionMeta.functionId = functionId; + YRLOG_INFO("functionId contain alias, parseAlias to {}, requestId {}", functionId, spec->requestId); + } + if (FunctionGroupEnabled(spec->opts.functionGroupOpts)) { YRLOG_DEBUG("Begin to create instances by function group scheduling, request ID: {}, group name is {}", spec->requestId, spec->opts.groupName); @@ -939,11 +1080,11 @@ void InvokeAdaptor::CreateInstanceRaw(std::shared_ptr reqRaw, RawCallbac this->fsClient->CreateAsync( req, [insId, cb](const CreateResponse &resp) -> void { - YRLOG_DEBUG("recieve create raw response, code is {}, instance id is {}, msg is {}", resp.code(), - resp.instanceid(), resp.message()); + YRLOG_DEBUG("recieve create raw response, code is {}, instance id is {}, msg is {}", + fmt::underlying(resp.code()), resp.instanceid(), resp.message()); if (resp.code() != common::ERR_NONE) { YRLOG_ERROR("start handle failed raw create response, code is {}, instance id is {}, msg is {}", - resp.code(), resp.instanceid(), resp.message()); + fmt::underlying(resp.code()), resp.instanceid(), resp.message()); NotifyRequest notify; notify.set_code(resp.code()); notify.set_message(resp.message()); @@ -957,8 +1098,8 @@ void InvokeAdaptor::CreateInstanceRaw(std::shared_ptr reqRaw, RawCallbac *insId = resp.instanceid(); }, [insId, cb](const NotifyRequest &req) -> void { - YRLOG_DEBUG("recieve create raw notify, code is {}, req id is {}, msg is {}, instance id is {}", req.code(), - req.requestid(), req.message(), *insId); + YRLOG_DEBUG("recieve create raw notify, code is {}, req id is {}, msg is {}, instance id is {}", + fmt::underlying(req.code()), req.requestid(), req.message(), *insId); NotifyRequest notify; notify.set_code(req.code()); notify.set_message(req.message()); @@ -985,8 +1126,8 @@ void InvokeAdaptor::InvokeByInstanceIdRaw(std::shared_ptr reqRaw, RawCal } auto messageSpec = std::make_shared(std::move(req)); this->fsClient->InvokeAsync(messageSpec, [this, cb](const NotifyRequest &req, const ErrorInfo &err) -> void { - YRLOG_DEBUG("recieve invoke raw notify, code is {}, req id is {}, msg is {}", req.code(), req.requestid(), - req.message()); + YRLOG_DEBUG("recieve invoke raw notify, code is {}, req id is {}, msg is {}", fmt::underlying(req.code()), + req.requestid(), req.message()); size_t size = req.ByteSizeLong(); auto respRaw = std::make_shared(size); req.SerializeToArray(respRaw->MutableData(), size); @@ -997,9 +1138,11 @@ void InvokeAdaptor::InvokeByInstanceIdRaw(std::shared_ptr reqRaw, RawCal void InvokeAdaptor::KillRaw(std::shared_ptr reqRaw, RawCallback cb) { KillRequest req; + req.set_requestid(YR::utility::IDGenerator::GenRequestId()); req.ParseFromString(std::string(static_cast(reqRaw->MutableData()), reqRaw->GetSize())); - this->fsClient->KillAsync(req, [this, cb](const KillResponse &resp, ErrorInfo err) -> void { - YRLOG_DEBUG("recieve kill raw response, code is {}", resp.code()); + EraseFsIntf(req.instanceid()); + this->fsClient->KillAsync(req, [this, cb](const KillResponse &resp, const ErrorInfo &err) -> void { + YRLOG_DEBUG("recieve kill raw response, code is {}", fmt::underlying(resp.code())); size_t size = resp.ByteSizeLong(); auto respRaw = std::make_shared(size); resp.SerializeToArray(respRaw->MutableData(), size); @@ -1063,8 +1206,8 @@ void InvokeAdaptor::CreateNotifyHandler(const NotifyRequest &req) if (req.code() != common::ERR_NONE) { bool isConsumeRetryTime = false; if (!NeedRetry(static_cast(req.code()), spec, isConsumeRetryTime)) { - YRLOG_ERROR("Failed to create instance, request ID: {}, code: {}, message: {}", req.requestid(), req.code(), - req.message()); + YRLOG_ERROR("Failed to create instance, request ID: {}, code: {}, message: {}", req.requestid(), + fmt::underlying(req.code()), req.message()); auto isCreate = spec->invokeType == libruntime::InvokeType::CreateInstanceStateless || spec->invokeType == libruntime::InvokeType::CreateInstance; std::vector stackTraceInfos = GetStackTraceInfos(req); @@ -1072,7 +1215,7 @@ void InvokeAdaptor::CreateNotifyHandler(const NotifyRequest &req) stackTraceInfos)); } else { YRLOG_ERROR("Failed to create instance, need retry, request ID: {}, code: {}, message: {}", req.requestid(), - req.code(), req.message()); + fmt::underlying(req.code()), req.message()); RetryCreateInstance(spec, isConsumeRetryTime); return; } @@ -1094,7 +1237,7 @@ void InvokeAdaptor::CreateNotifyHandler(const NotifyRequest &req) auto errorInfo = memStore->DecreGlobalReference(ids); if (!errorInfo.OK()) { YRLOG_WARN("failed to decrease by requestid {}. Code: {}, MCode: {}, Msg: {}", req.requestid(), - errorInfo.Code(), errorInfo.MCode(), errorInfo.Msg()); + fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), errorInfo.Msg()); } (void)requestManager->RemoveRequest(rawRequestId); @@ -1132,7 +1275,7 @@ void InvokeAdaptor::HandleReturnedObject(const NotifyRequest &req, const std::sh auto err = memStore->IncreDSGlobalReference(dsObjs); if (!err.OK()) { YRLOG_WARN("failed to increase obj ref [{},...] by requestid {}, Code: {}, Msg: {}", dsObjs[0], - req.requestid(), err.Code(), err.Msg()); + req.requestid(), fmt::underlying(err.Code()), err.Msg()); } } memStore->SetReady(spec->returnIds); @@ -1188,13 +1331,16 @@ void InvokeAdaptor::InvokeNotifyHandler(const NotifyRequest &req, const ErrorInf auto errorInfo = memStore->DecreGlobalReference(ids); if (!errorInfo.OK()) { YRLOG_WARN("failed to decrease by requestid {}. Code: {}, MCode: {}, Msg: {}", req.requestid(), - errorInfo.Code(), errorInfo.MCode(), errorInfo.Msg()); + fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), errorInfo.Msg()); } (void)requestManager->RemoveRequest(rawRequestId); } void InvokeAdaptor::ProcessErr(const std::shared_ptr &spec, const ErrorInfo &errInfo) { + if (spec->functionMeta.isGenerator && generatorReceiver_) { + generatorReceiver_->MarkEndOfStream(spec->returnIds[0].id, errInfo); + } memStore->SetError(spec->returnIds, errInfo); } @@ -1237,10 +1383,13 @@ ErrorInfo InvokeAdaptor::Cancel(const std::vector &objids, bool isF return taskSubmitter->CancelStatelessRequest(objids, f, isForce, isRecursive); } -void InvokeAdaptor::Exit(void) +void InvokeAdaptor::Exit(const int code, const std::string &message) { absl::Notification notification; ExitRequest req; + req.set_code(static_cast(code)); + req.set_message(message); + YRLOG_DEBUG("exit with{}, {}", fmt::underlying(req.code()), req.message()); fsClient->ExitAsync(req, [¬ification](const ExitResponse &resp) { notification.Notify(); }); // default to wait 30s notification.WaitForNotificationWithTimeout(absl::Seconds(30)); @@ -1312,6 +1461,7 @@ void InvokeAdaptor::Finalize(bool isDriver) YRLOG_WARN("Failed to kill all instance, msg: {}", err.Msg()); } } + isRunning = false; if (groupManager != nullptr) { groupManager->Stop(); } @@ -1321,13 +1471,17 @@ void InvokeAdaptor::Finalize(bool isDriver) if (functionMasterClient_) { functionMasterClient_->Stop(); } - isRunning = false; taskSubmitter->Finalize(); if (isDriver) { fsClient->Stop(); } } +void InvokeAdaptor::EraseFsIntf(const std::string &id) +{ + fsClient->EraseIntf(id); +} + void InvokeAdaptor::PushInvokeSpec(std::shared_ptr spec) { this->requestManager->PushRequest(spec); @@ -1342,13 +1496,15 @@ ErrorInfo InvokeAdaptor::Kill(const std::string &instanceId, const std::string & } YRLOG_DEBUG("start kill instance, instance id is {}, signal is {}", instanceId, signal); KillRequest killReq; + killReq.set_requestid(YR::utility::IDGenerator::GenRequestId()); killReq.set_instanceid(instanceId); killReq.set_payload(payload); killReq.set_signal(signal); auto killPromise = std::make_shared>(); std::shared_future killFuture = killPromise->get_future().share(); - this->fsClient->KillAsync(killReq, [killPromise](KillResponse rsp, ErrorInfo err) -> void { killPromise->set_value(rsp); }); + this->fsClient->KillAsync( + killReq, [killPromise](KillResponse rsp, const ErrorInfo &err) -> void { killPromise->set_value(rsp); }); ErrorInfo errInfo; if (signal == libruntime::Signal::killInstanceSync) { errInfo = WaitAndCheckResp(killFuture, instanceId, NO_TIMEOUT); @@ -1359,18 +1515,34 @@ ErrorInfo InvokeAdaptor::Kill(const std::string &instanceId, const std::string & } void InvokeAdaptor::KillAsync(const std::string &instanceId, const std::string &payload, int signal) +{ + this->KillAsyncCB(instanceId, payload, signal, [instanceId, signal](const ErrorInfo &err) -> void { + if (!err.OK()) { + YRLOG_WARN("kill request failed, ins id is {}, signal is {}, err: ", instanceId, signal, err.CodeAndMsg()); + } + }); +} + +void InvokeAdaptor::KillAsyncCB(const std::string &instanceId, const std::string &payload, int signal, + std::function cb) { YRLOG_DEBUG("start kill instance async, instance id is {}, signal is {}, payload is {}", instanceId, signal, payload); KillRequest killReq; + killReq.set_requestid(YR::utility::IDGenerator::GenRequestId()); killReq.set_instanceid(instanceId); killReq.set_payload(payload); killReq.set_signal(signal); - this->fsClient->KillAsync(killReq, [killReq](KillResponse rsp, ErrorInfo err) -> void { + this->fsClient->KillAsync(killReq, [killReq, cb](KillResponse rsp, const ErrorInfo &err) -> void { + if (!err.OK()) { + cb(err); + return; + } if (rsp.code() != common::ERR_NONE) { - YRLOG_WARN("kill request failed, ins id is {}, signal is {}, err code is {}, err msg is {}", - killReq.instanceid(), killReq.signal(), rsp.code(), rsp.message()); + cb(ErrorInfo(static_cast(rsp.code()), ModuleCode::RUNTIME, rsp.message())); + return; } + cb({}); }); } @@ -1490,7 +1662,6 @@ ErrorInfo InvokeAdaptor::ReadDataFromState(const std::string &instanceId, const std::shared_ptr &data) { // deserialize state buffer. format: [uint_8(size of buf1)|buf1|buf2] - YRLOG_DEBUG("Start to read instance state, instance ID: {}", instanceId); const char *statePtr = state.data(); size_t stateSize = state.size(); if (stateSize == 0) { @@ -1501,6 +1672,10 @@ ErrorInfo InvokeAdaptor::ReadDataFromState(const std::string &instanceId, const size_t headerSize = sizeof(size_t); size_t bufInstanceSize = *reinterpret_cast(statePtr); size_t bufMetaSize = stateSize - headerSize - bufInstanceSize; + YRLOG_INFO( + "Start to read instance state, instance ID: {}, state size is {}, buf ins size is {}, header size is {}, buf " + "meta size is {}", + instanceId, stateSize, bufInstanceSize, headerSize, bufMetaSize); auto bufInstance = std::make_shared(bufInstanceSize); bufInstance->MemoryCopy(static_cast(statePtr + headerSize), bufInstanceSize); data = bufInstance; @@ -1566,7 +1741,7 @@ void InvokeAdaptor::ReportMetrics(const std::string &requestId, const std::strin data.labels["requestid"] = requestId; data.labels["traceid"] = traceId; data.value = value; - auto err = metricsAdaptor->ReportMetrics(data); + auto err = MetricsAdaptor::GetInstance()->ReportMetrics(data); if (!err.OK()) { YRLOG_WARN("failed to report metrics, requestid: {}, traceid: {}, value: {}", requestId, traceId, value); } @@ -1593,7 +1768,7 @@ void InvokeAdaptor::InitMetricsAdaptor(bool userEnable) { if (!Config::Instance().METRICS_CONFIG().empty()) { try { - metricsAdaptor->Init(nlohmann::json::parse(Config::Instance().METRICS_CONFIG()), userEnable); + MetricsAdaptor::GetInstance()->Init(nlohmann::json::parse(Config::Instance().METRICS_CONFIG()), userEnable); return; } catch (std::exception &e) { YRLOG_ERROR("parse config json failed, error: {}", e.what()); @@ -1610,11 +1785,53 @@ void InvokeAdaptor::InitMetricsAdaptor(bool userEnable) return; } nlohmann::json config; - f >> config; - metricsAdaptor->Init(config, userEnable); + try { + f >> config; + } catch (const nlohmann::json::parse_error &e) { + YRLOG_ERROR("JSON parse error: {}", e.what()); + } + MetricsAdaptor::GetInstance()->Init(config, userEnable); f.close(); } +std::pair InvokeAdaptor::AcquireInstance(const std::string &stateId, + const FunctionMeta &functionMeta, + InvokeOptions &opts) +{ + auto spec = std::make_shared(); + spec->requestId = YR::utility::IDGenerator::GenRequestId(); + spec->functionMeta = functionMeta; + if (!functionMeta.name.empty()) { + spec->designatedInstanceID = functionMeta.name; + } + if (opts.schedulerInstanceIds.size() == 0) { + opts.schedulerInstanceIds.push_back(librtConfig->schedulerInstanceIds[0]); + } + + spec->opts = opts; + spec->traceId = opts.traceId; + auto [allocation, err] = taskSubmitter->AcquireInstance(stateId, spec); + if (err.OK()) { + requestManager->PushFaasRequest(allocation.leaseId, spec); + } + return std::make_pair(allocation, err); +} + +ErrorInfo InvokeAdaptor::ReleaseInstance(const std::string &leaseId, const std::string &stateId, bool abnormal, + InvokeOptions &opts) +{ + auto spec = std::make_shared(); + if (!stateId.empty()) { + spec->opts = opts; + } else { + if (!requestManager->PopRequest(leaseId, spec)) { + YRLOG_DEBUG("spec of leaseid: {} not exist, do not need release instance.", leaseId); + return ErrorInfo(); + } + } + return taskSubmitter->ReleaseInstance(leaseId, stateId, abnormal, spec); +} + void InvokeAdaptor::CreateResourceGroup(std::shared_ptr spec) { this->rGroupManager_->StoreRGDetail(spec->rGroupSpec.name, spec->requestId, spec->rGroupSpec.bundles.size()); @@ -1628,6 +1845,7 @@ void InvokeAdaptor::CreateResourceGroup(std::shared_ptr this_ptr->rGroupManager_->SetRGCreateErrInfo(spec->rGroupSpec.name, spec->requestId, err); } }; + fsClient->CreateRGroupAsync(spec->requestCreateRGroup, rspHandler); YRLOG_DEBUG("Create resource group request has been sent, req id is {}, Details: {}", spec->requestId, spec->requestCreateRGroup.DebugString()); @@ -1651,27 +1869,36 @@ std::pair InvokeAdaptor::GetInstance(co return std::make_pair(convertProtoToFuncMeta(metaCached), ErrorInfo()); } KillRequest killReq; + killReq.set_requestid(YR::utility::IDGenerator::GenRequestId()); killReq.set_instanceid(insId); killReq.set_signal(libruntime::Signal::GetInstance); auto promise = std::promise>(); auto future = promise.get_future(); - this->fsClient->KillAsync(killReq, [&promise](const KillResponse &rsp, const ErrorInfo &err) -> void { - if (rsp.code() != common::ERR_NONE) { - YR::Libruntime::ErrorInfo errInfo(static_cast(rsp.code()), ModuleCode::RUNTIME, - rsp.message()); - errInfo.SetIsTimeout(err.IsTimeout()); - promise.set_value(std::make_pair(libruntime::FunctionMeta{}, errInfo)); - } else { - libruntime::FunctionMeta funcMeta; - funcMeta.ParseFromString(rsp.message()); - promise.set_value(std::make_pair(funcMeta, YR::Libruntime::ErrorInfo())); - } - }, timeoutSec); + this->invokeOrderMgr->RegisterInstanceAndUpdateOrder(insId); + this->fsClient->KillAsync( + killReq, + [insId, &promise](const KillResponse &response, const ErrorInfo &err) -> void { + if (response.code() != common::ERR_NONE) { + YRLOG_ERROR("get instance failed, instance id is {}, errcode is {}, err msg is {}", insId, + fmt::underlying(response.code()), response.message()); + YR::Libruntime::ErrorInfo errInfo(static_cast(response.code()), ModuleCode::RUNTIME, + response.message()); + errInfo.SetIsTimeout(err.IsTimeout()); + promise.set_value(std::make_pair(libruntime::FunctionMeta{}, errInfo)); + } else { + libruntime::FunctionMeta funcMeta; + funcMeta.ParseFromString(response.message()); + promise.set_value(std::make_pair(funcMeta, YR::Libruntime::ErrorInfo())); + } + }, + timeoutSec); auto [funcMeta, errorInfo] = future.get(); - YRLOG_DEBUG("get instance finished, err code is {}, err msg is {}, function meta is {}", errorInfo.Code(), - errorInfo.Msg(), funcMeta.DebugString()); + YRLOG_DEBUG("get instance finished, err code is {}, err msg is {}, function meta is {}", + fmt::underlying(errorInfo.Code()), errorInfo.Msg(), funcMeta.DebugString()); if (errorInfo.OK()) { this->UpdateAndSubcribeInsStatus(insId, funcMeta); + // for get instance req, invoke seq no is always 0 or no seq no + this->invokeOrderMgr->UpdateFinishReqSeqNo(insId, 0); } else { this->RemoveInsMetaInfo(insId); } @@ -1698,7 +1925,8 @@ void InvokeAdaptor::UpdateAndSubcribeInsStatus(const std::string &insId, librunt YRLOG_DEBUG( "start add ins meta into metamap, ins id is: {}, class name is {}, module name is {}, function id is {}, " "language is {}", - insId, funcMeta.classname(), funcMeta.modulename(), funcMeta.functionid(), funcMeta.language()); + insId, funcMeta.classname(), funcMeta.modulename(), funcMeta.functionid(), + fmt::underlying(funcMeta.language())); if (!funcMeta.name().empty() && funcMeta.ns().empty()) { funcMeta.set_ns(DEFAULT_YR_NAMESPACE); } @@ -1726,12 +1954,12 @@ void InvokeAdaptor::Subscribe(const std::string &insId) std::string serializedPayload; subscription.SerializeToString(&serializedPayload); killReq.set_payload(serializedPayload); + killReq.set_requestid(YR::utility::IDGenerator::GenRequestId()); auto weakThis = weak_from_this(); YRLOG_DEBUG("start send subscribe req of instance: {}", insId); - this->fsClient->KillAsync(killReq, [insId, weakThis](KillResponse rsp, ErrorInfo err) -> void { + this->fsClient->KillAsync(killReq, [insId, weakThis](KillResponse rsp, const ErrorInfo &err) -> void { if (rsp.code() != common::ERR_NONE) { - YRLOG_WARN("subcribe ins status failed, ins id is : {}, code is {}, msg is {},", insId, rsp.code(), - rsp.message()); + YRLOG_WARN("subscribe ins status failed, ins id is : {}, code is {},", insId, fmt::underlying(rsp.code())); } if (rsp.code() == common::ERR_SCHEDULE_PLUGIN_CONFIG || rsp.code() == common::ERR_SUB_STATE_INVALID) { if (auto thisPtr = weakThis.lock(); thisPtr) { @@ -1783,6 +2011,7 @@ void InvokeAdaptor::SubscribeActiveMaster() auto insId = Config::Instance().INSTANCE_ID(); auto instanceId = insId.empty() ? "driver-" + runtimeContext->GetJobId() : insId; KillRequest killReq; + killReq.set_requestid(YR::utility::IDGenerator::GenRequestId()); killReq.set_instanceid(instanceId); killReq.set_signal(libruntime::Signal::Subsribe); SubscriptionPayload subscription; @@ -1793,9 +2022,74 @@ void InvokeAdaptor::SubscribeActiveMaster() killReq.set_payload(serializedPayload); auto weakThis = weak_from_this(); YRLOG_DEBUG("start send subscribe function master req of instance: {}", instanceId); - this->fsClient->KillAsync(killReq, [instanceId](KillResponse rsp, ErrorInfo err) -> void { - YRLOG_DEBUG("get subcribe function master response, ins id is : {}, code is {},", instanceId, rsp.code()); + this->fsClient->KillAsync(killReq, [instanceId, weakThis](KillResponse rsp, const ErrorInfo &err) -> void { + YRLOG_DEBUG("get subcribe function master response, ins id is : {}, code is {},", instanceId, + fmt::underlying(rsp.code())); }); } + +void InvokeAdaptor::UpdateSchdulerInfo(const std::string &scheduleName, const std::string &schedulerId, + const std::string &option) +{ + this->taskSubmitter->UpdateSchdulerInfo(scheduleName, schedulerId, option); +} + +void InvokeAdaptor::CreateCallTimer(const std::string &reqId, const std::string &instanceId, int64_t invokeTimeoutSec) +{ + absl::MutexLock lock(&callTimerMtx_); + YRLOG_DEBUG("start add call req exec timer, invoke timeout is {}, req id is: {}, sender id is: {}", + invokeTimeoutSec, reqId, instanceId); + if (invokeTimeoutSec <= 0) { + return; + } + auto weakThis = weak_from_this(); + callTimeoutTimerMap_[reqId] = YR::utility::ExecuteByGlobalTimer( + [weakThis, reqId, instanceId, invokeTimeoutSec]() { + auto thisPtr = weakThis.lock(); + if (!thisPtr) { + return; + } + YRLOG_WARN("call req execution exceeds the set time: {}, req id is {}, sender id is {}", invokeTimeoutSec, + reqId, instanceId); + auto result = std::make_shared(); + auto &callResult = result->Mutable(); + callResult.set_requestid(reqId); + callResult.set_instanceid(instanceId); + callResult.set_code(common::ERR_INNER_SYSTEM_ERROR); + std::string errMsg = "call req execution exceeds the set time: " + std::to_string(invokeTimeoutSec) + + " req id : " + reqId + ", senderid: " + instanceId; + callResult.set_message(errMsg); + auto callResultCallback = [reqId](const CallResultAck &resp) { + if (resp.code() != common::ERR_NONE) { + YRLOG_WARN("failed to send CallResult, req id: {}, code: {}, message: {}", reqId, + fmt::underlying(resp.code()), resp.message()); + } + }; + thisPtr->fsClient->ReturnCallResult(result, false, callResultCallback); + thisPtr->EraseCallTimer(reqId); + }, + invokeTimeoutSec * MILLISECOND_UNIT, 1); +} + +void InvokeAdaptor::EraseCallTimer(const std::string &reqId) +{ + absl::MutexLock lock(&callTimerMtx_); + auto it = callTimeoutTimerMap_.find(reqId); + if (it != callTimeoutTimerMap_.end()) { + YRLOG_DEBUG("start remove call timer of req: {}", reqId); + it->second->cancel(); + it->second.reset(); + callTimeoutTimerMap_.erase(it); + } +} + +bool InvokeAdaptor::IsHealth() +{ + if (!fsClient) { + return false; + } + return fsClient->IsHealth(); +} + } // namespace Libruntime } // namespace YR diff --git a/src/libruntime/invokeadaptor/invoke_adaptor.h b/src/libruntime/invokeadaptor/invoke_adaptor.h index 11b4472..adc06a7 100644 --- a/src/libruntime/invokeadaptor/invoke_adaptor.h +++ b/src/libruntime/invokeadaptor/invoke_adaptor.h @@ -21,17 +21,23 @@ #include #include +#include "alias_element.h" +#include "alias_routing.h" #include "execution_manager.h" #include "request_manager.h" #include "src/dto/config.h" #include "src/dto/status.h" +#include "src/libruntime/fiber.h" #include "src/libruntime/dependency_resolver.h" #include "src/libruntime/err_type.h" -#include "src/libruntime/fiber.h" -#include "src/libruntime/fsclient/fs_client.h" #include "src/libruntime/fmclient/fm_client.h" +#include "src/libruntime/fsclient/fs_client.h" +#include "src/libruntime/generator/generator_id_map.h" +#include "src/libruntime/generator/generator_notifier.h" +#include "src/libruntime/generator/generator_receiver.h" #include "src/libruntime/groupmanager/function_group.h" #include "src/libruntime/groupmanager/group_manager.h" +#include "src/libruntime/gwclient/gw_client.h" #include "src/libruntime/invoke_order_manager.h" #include "src/libruntime/invoke_spec.h" #include "src/libruntime/invokeadaptor/task_submitter.h" @@ -39,17 +45,19 @@ #include "src/libruntime/metricsadaptor/metrics_adaptor.h" #include "src/libruntime/objectstore/memory_store.h" #include "src/libruntime/objectstore/object_store.h" +#include "src/libruntime/rgroupmanager/resource_group_create_spec.h" +#include "src/libruntime/rgroupmanager/resource_group_manager.h" #include "src/libruntime/runtime_context.h" #include "src/libruntime/utils/constants.h" #include "src/libruntime/utils/exception.h" #include "src/libruntime/utils/utils.h" -#include "src/libruntime/rgroupmanager/resource_group_create_spec.h" -#include "src/libruntime/rgroupmanager/resource_group_manager.h" #include "src/utility/notification_utility.h" namespace YR { namespace Libruntime { +extern thread_local std::string threadLocalTraceId; using FinalizeCallback = std::function; +using DebugBreakpointHook = std::function; using SetTenantIdCallback = std::function; using RawCallback = std::function resultRaw)>; @@ -60,7 +68,10 @@ public: std::shared_ptr &fsClient, std::shared_ptr memStore, std::shared_ptr rtCtx, FinalizeCallback cb, std::shared_ptr waitManager, std::shared_ptr invokeOrderMgr, - std::shared_ptr clientsMgr, std::shared_ptr metricsAdaptor); + std::shared_ptr clientsMgr, std::shared_ptr metricsAdaptor, + std::shared_ptr genIdMapper, std::shared_ptr generatorReceiver, + std::shared_ptr generatorNotifier, + std::shared_ptr downgrade = nullptr); std::pair Init(RuntimeContext &runtimeContext, std::shared_ptr security); @@ -94,13 +105,15 @@ public: virtual ErrorInfo Cancel(const std::vector &objids, bool isForce, bool isRecursive); - virtual void Exit(void); + virtual void Exit(const int code, const std::string &message); virtual void Finalize(bool isDriver = true); virtual ErrorInfo Kill(const std::string &instanceId, const std::string &payload, int signal); virtual void KillAsync(const std::string &instanceId, const std::string &payload, int signal); + virtual void KillAsyncCB(const std::string &instanceId, const std::string &payload, int signal, + std::function cb); CallResponse CallReqProcess(const CallRequest &req); @@ -120,13 +133,20 @@ public: virtual void GroupTerminate(const std::string &groupName); virtual std::pair, ErrorInfo> GetInstanceIds(const std::string &objId, - const std::string &groupName); + const std::string &groupName); virtual ErrorInfo SaveState(const std::shared_ptr data, const int &timeout); virtual ErrorInfo LoadState(std::shared_ptr &data, const int &timeout); virtual ErrorInfo ExecShutdownCallback(uint64_t gracePeriodSec); + + virtual std::pair AcquireInstance(const std::string &stateId, + const FunctionMeta &functionMeta, + InvokeOptions &opts); + + virtual ErrorInfo ReleaseInstance(const std::string &leaseId, const std::string &stateId, bool abnormal, + InvokeOptions &opts); void InitHandler(const std::shared_ptr &req); void CallHandler(const std::shared_ptr &req); CheckpointResponse CheckpointHandler(const CheckpointRequest &req); @@ -134,20 +154,27 @@ public: void CreateResourceGroup(std::shared_ptr spec); virtual std::pair GetInstance(const std::string &name, - const std::string &nameSpace, int timeoutSec); + const std::string &nameSpace, + int timeoutSec); void SubscribeAll(); void Subscribe(const std::string &insId); + void SubscribeActiveMaster(); ErrorInfo Accelerate(const std::string &groupName, const AccelerateMsgQueueHandle &handle, HandleReturnObjectCallback callback); - + virtual void UpdateSchdulerInfo(const std::string &schedulerName, const std::string &schedulerId, + const std::string &option); + void CreateCallTimer(const std::string &reqId, const std::string &instanceId, int64_t invokeTimeoutSec); + void EraseCallTimer(const std::string &reqId); + virtual void EraseFsIntf(const std::string &id); + void PushInvokeSpec(std::shared_ptr spec); + virtual bool IsHealth(); virtual std::pair GetNodeIpAddress(); virtual std::pair GetNodeId(); virtual std::pair> GetResources(void); virtual std::pair GetResourceGroupTable(const std::string &resourceGroupId);\ virtual std::pair QueryNamedInstances(); - void PushInvokeSpec(std::shared_ptr spec); - void SubscribeActiveMaster(); + private: void CreateResponseHandler(std::shared_ptr spec, const CreateResponse &resp); void CreateNotifyHandler(const NotifyRequest &req); @@ -159,6 +186,7 @@ private: HeartbeatResponse HeartbeatHandler(const HeartbeatRequest &req); void ExecUserShutdownCallback(uint64_t gracePeriodSec, const std::shared_ptr ¬ification); + ErrorInfo ParseAliasInfo(const SignalRequest &req, std::vector &aliasInfo); template static ErrorInfo WaitAndCheckResp(std::shared_future &future, const std::string &instanceId, @@ -172,7 +200,7 @@ private: std::pair PrepareCallExecutor(const RequestType &req); CallResult CallProcess(const CallRequest &req); - + bool ParseRequest(const CallRequest &request, std::vector> &rawArgs, bool isPosix); CallResult Call(const CallRequest &request, const libruntime::MetaData &metaData, const LibruntimeOptions &options, std::vector &objectsInDs); // load user function libraries @@ -185,6 +213,8 @@ private: void ProcessErr(const std::shared_ptr &spec, const ErrorInfo &errInfo); + void CheckAndSetDebugBreakpoint(const std::shared_ptr &req); + void UpdateAndSubcribeInsStatus(const std::string &insId, libruntime::FunctionMeta &funcMeta); void RemoveInsMetaInfo(const std::string &insId); std::pair GetCachedInsMeta(const std::string &insId); @@ -204,14 +234,21 @@ private: std::shared_ptr execMgr; std::shared_ptr clientsMgr; std::shared_ptr metricsAdaptor; + std::shared_ptr ar; std::shared_ptr fiberPool_; + std::shared_ptr map_; + std::shared_ptr generatorReceiver_; + std::shared_ptr generatorNotifier_; std::shared_ptr rGroupManager_; + std::shared_ptr functionMasterClient_; std::mutex finishTaskMtx; + DebugBreakpointHook setDebugBreakpoint_ = nullptr; SetTenantIdCallback setTenantIdCb_; std::mutex metaMapMtx; std::unordered_map metaMap; std::atomic accelerateRunFlag_{false}; - std::shared_ptr functionMasterClient_; + std::unordered_map> callTimeoutTimerMap_; + mutable absl::Mutex callTimerMtx_; }; const static std::unordered_map ErrMsgMap = { diff --git a/src/libruntime/invokeadaptor/limiter_consistant_hash.cpp b/src/libruntime/invokeadaptor/limiter_consistant_hash.cpp new file mode 100644 index 0000000..cac4842 --- /dev/null +++ b/src/libruntime/invokeadaptor/limiter_consistant_hash.cpp @@ -0,0 +1,183 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "limiter_consistant_hash.h" + +namespace YR { +namespace Libruntime { +using namespace std::chrono; +void LimiterCsHash::Add(const std::string &schedulerName, const std::string &schedulerId, int weight) +{ + std::lock_guard lk(limiterMtx_); + this->AddWithoutLock(schedulerName, schedulerId, weight); +} + +void LimiterCsHash::AddWithoutLock(const std::string &schedulerName, const std::string &schedulerId, int weight) +{ + if (schedulerName.empty() || schedulerId.empty()) { + YRLOG_WARN("scheduler name: {} or scheduler id {} is empty, no need add and return directly", schedulerName, + schedulerId); + return; + } + nodeCount_++; + loadBalancer_->Add(schedulerName, weight); + schedulerInfoMap[schedulerName] = schedulerId; + YRLOG_DEBUG("add scheduler name {}, id: {}, current node count is {}", schedulerName, schedulerId, + nodeCount_.load()); +} + +std::string LimiterCsHash::NextRetry(const std::string &functionId, bool move) +{ + std::vector> vec; + return NextRetry(functionId, nullptr, move); +} + +std::string LimiterCsHash::NextRetry(const std::string &functionId, + std::shared_ptr schedulerInfos, bool move) +{ + auto schedulerInstanceId = Next(functionId, schedulerInfos, move); + if (schedulerInstanceId.empty()) { + return Next(functionId, schedulerInfos, true); + } + return schedulerInstanceId; +} + +std::string LimiterCsHash::Next(const std::string &functionId, bool move) +{ + return Next(functionId, nullptr, move); +} + +std::string LimiterCsHash::Next(const std::string &functionId, std::shared_ptr schedulerInfos, + bool moveFlag) +{ + std::lock_guard lk(limiterMtx_); + auto res = loadBalancer_->Next(functionId, moveFlag); + size_t pos = res.find(":"); + if (pos != std::string::npos) { + std::string schedulerName = res.substr(0, pos); + long long updateTime = std::stoll(res.substr(pos + 1)); + YRLOG_DEBUG("res is {}, schedulerName is {}, update time is {}", res, schedulerName, updateTime); + if (schedulerName.empty()) { + return ""; + } + if (schedulerInfoMap.find(schedulerName) != schedulerInfoMap.end()) { + if (moveFlag && isAllInsUnavailable(schedulerInfoMap[schedulerName], updateTime, schedulerInfos)) { + return ALL_SCHEDULER_UNAVAILABLE; + } + return schedulerInfoMap[schedulerName]; + } + } + return ""; +} + +bool LimiterCsHash::isAllInsUnavailable(const std::string &instanceKey, long long updateTime, + std::shared_ptr schedulerInfos) +{ + if (schedulerInfos == nullptr) { + return false; + } + absl::WriterMutexLock insLk(&schedulerInfos->schedulerMtx); + YRLOG_DEBUG("instanceKey is {}, update time is {}", instanceKey, updateTime); + bool hasAvailableInsInVec = false; + bool hasInsIdMatch = false; + std::shared_ptr matchedIns; + for (auto &scheduler : schedulerInfos->schedulerInstanceList) { + if (scheduler->isAvailable) { + hasAvailableInsInVec = true; + } + if (scheduler->InstanceID == instanceKey) { + hasInsIdMatch = true; + matchedIns = scheduler; + } + } + + if (hasAvailableInsInVec) { + if (!hasInsIdMatch) { + schedulerInfos->schedulerInstanceList.push_back( + std::make_shared(SchedulerInstance{"", instanceKey, updateTime, true})); + } else { + matchedIns->updateTime = updateTime; + } + return false; + } + + if (!hasInsIdMatch) { + schedulerInfos->schedulerInstanceList.push_back( + std::make_shared(SchedulerInstance{"", instanceKey, updateTime, true})); + return false; + } else { + if (updateTime <= matchedIns->updateTime) { + YRLOG_WARN( + "current scheduler vecs has no available ins, next scheduler res is {}, add into hash pool at {}, used " + "at {}", + instanceKey, updateTime, matchedIns->updateTime); + return true; + } + matchedIns->isAvailable = true; + matchedIns->updateTime = updateTime; + } + return false; +} + +void LimiterCsHash::Remove(const std::string &name) +{ + std::lock_guard lk(limiterMtx_); + nodeCount_--; + YRLOG_DEBUG("start remove schedule id: {}, current node count is {}", name, nodeCount_.load()); + loadBalancer_->Remove(name); + schedulerInfoMap.erase(name); +} + +void LimiterCsHash::RemoveAll() +{ + std::lock_guard lk(limiterMtx_); + nodeCount_ = 0; + YRLOG_DEBUG("start remove all scheduler"); + loadBalancer_->RemoveAll(); + schedulerInfoMap.clear(); +} + +bool LimiterCsHash::IsSameWithHashPool(const std::vector &schedulerInfoList) +{ + if (schedulerInfoList.size() != schedulerInfoMap.size()) { + return false; + } + for (auto &schedulerIns : schedulerInfoList) { + if (schedulerInfoMap.find(schedulerIns.InstanceName) == schedulerInfoMap.end() || + schedulerInfoMap[schedulerIns.InstanceName] != schedulerIns.InstanceID) { + return false; + } + } + return true; +} + +void LimiterCsHash::ResetAll(const std::vector &schedulerInfoList, int weight) +{ + std::lock_guard lk(limiterMtx_); + if (auto isSame = IsSameWithHashPool(schedulerInfoList); isSame) { + return; + } + nodeCount_ = 0; + YRLOG_DEBUG("start remove all scheduler"); + loadBalancer_->RemoveAll(); + schedulerInfoMap.clear(); + for (auto info : schedulerInfoList) { + this->AddWithoutLock(info.InstanceName, info.InstanceID, weight); + } +} + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/limiter_consistant_hash.h b/src/libruntime/invokeadaptor/limiter_consistant_hash.h new file mode 100644 index 0000000..e93f3fb --- /dev/null +++ b/src/libruntime/invokeadaptor/limiter_consistant_hash.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "load_balancer.h" +#include "src/libruntime/invokeadaptor/scheduler_instance_info.h" +#include "src/utility/logger/logger.h" + +const int DEFAULT_CH_CONCURRENCY = 100; +namespace YR { +namespace Libruntime { +using namespace std::chrono; + +struct LimiterNode { + std::string schedulerId; + time_point lastTime; + std::shared_ptr next; + LimiterNode() = default; + LimiterNode(std::string &key) : schedulerId(key), lastTime(steady_clock::now()), next(nullptr) {} +}; + +struct ConcurrentLimiter { + std::shared_ptr head; +}; + +class LimiterCsHash { +public: + LimiterCsHash() = default; + LimiterCsHash(std::shared_ptr loadBalancer) : loadBalancer_(loadBalancer) {} + ~LimiterCsHash() = default; + void Add(const std::string &schedulerName, const std::string &schedulerId, int weight = 0); + std::string Next(const std::string &functionId, std::shared_ptr schedulerInfos, + bool move = false); + std::string Next(const std::string &functionId, bool move = false); + std::string NextRetry(const std::string &functionId, bool move = false); + std::string NextRetry(const std::string &functionId, std::shared_ptr schedulerInfos, + bool move = false); + bool isAllInsUnavailable(const std::string &instanceKey, long long instanceHash, + std::shared_ptr schedulerInfos); + void Remove(const std::string &name); + void RemoveAll(); + void ResetAll(const std::vector &schedulerInfoList, int weight = 0); + bool IsSameWithHashPool(const std::vector &schedulerInfoList); + +private: + void AddWithoutLock(const std::string &schedulerName, const std::string &schedulerId, int weight = 0); + std::shared_ptr loadBalancer_; + std::unordered_map schedulerInfoMap; + std::unordered_map> limiters_; + std::mutex limiterMtx_; + std::atomic nodeCount_{0}; + int limiterTime_ = 1; +}; + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/load_balancer.cpp b/src/libruntime/invokeadaptor/load_balancer.cpp new file mode 100644 index 0000000..e19c34e --- /dev/null +++ b/src/libruntime/invokeadaptor/load_balancer.cpp @@ -0,0 +1,242 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "load_balancer.h" + +#include +#include +#include + +#include "src/libruntime/utils/utils.h" +#include "src/utility/logger/logger.h" + +namespace YR { +namespace Libruntime { +std::hash g_hashFunc; +struct WeightedNode { + std::string name; + int weight; + int currentWeight; +}; + +struct AnchorInfo { + size_t instanceHash; + std::string instanceKey; +}; + +class WeightedRoundRobin : public LoadBalancer { +public: + WeightedRoundRobin() = default; + virtual ~WeightedRoundRobin() = default; + + std::string Next(const std::string &name, bool move = false) override + { + if (nodes_.size() == 0) { + return ""; + } + return NextNode().name; + } + + void Add(const std::string &node, int weight) override + { + nodes_.emplace_back(WeightedNode{ + .name = node, + .weight = weight, + .currentWeight = 0, + }); + } + + void RemoveAll(void) override + { + nodes_.clear(); + } + + void Remove(const std::string &name) override {} + +private: + WeightedNode &NextNode(void) + { + int total = 0; + WeightedNode *best = &nodes_[0]; + for (auto &node : nodes_) { + node.currentWeight += node.weight; + total += node.weight; + if (node.currentWeight > best->currentWeight) { + best = &node; + } + } + best->currentWeight -= total; + return *best; + } + + std::vector nodes_; +}; + +class CsHashRoundRobin : public LoadBalancer { +public: + CsHashRoundRobin() = default; + virtual ~CsHashRoundRobin() = default; + + std::string Next(const std::string &name, bool move = false) override + { + if (hashPool_.size() == 0) { + YRLOG_WARN("current hash pool size is empty, return empty res directly"); + return ":0"; + } + std::unique_lock lk(insMtx_); + if (auto it = anchorPoint_.find(name); it == anchorPoint_.end()) { + YRLOG_DEBUG("current anchor point map has no value of id: {}", name); + auto anchor = AddAnchorPoint(name); + lk.unlock(); + if (hashPool_.find(anchor.instanceHash) != hashPool_.end()) { + return anchor.instanceKey + ":" + std::to_string(hashPool_[anchor.instanceHash]); + } + return anchor.instanceKey + ":" + std::to_string(0); + } + auto anchor = anchorPoint_[name]; + if (move) { + YRLOG_DEBUG("id: {} need remove anchor point, current anchor ins hash is {}, ins key is {}", name, + anchor.instanceHash, anchor.instanceKey); + MoveAnchorPoint(name, anchor.instanceHash); + } + if (instanceMap_.find(anchor.instanceHash) == instanceMap_.end()) { + YRLOG_DEBUG( + "scheduler: {} not exist in instance map, need move anchor point, current ins hash is {}, " + "ins key is {}", + name, anchor.instanceHash, anchor.instanceKey); + MoveAnchorPoint(name, anchor.instanceHash); + } + lk.unlock(); + if (hashPool_.find(anchorPoint_[name].instanceHash) != hashPool_.end()) { + return anchorPoint_[name].instanceKey + ":" + std::to_string(hashPool_[anchorPoint_[name].instanceHash]); + } + return anchorPoint_[name].instanceKey + ":" + std::to_string(0); + } + + void Add(const std::string &name, int weight) override + { + std::unique_lock lk(insMtx_); + auto hashKey = g_hashFunc(name); + if (instanceMap_.find(hashKey) != instanceMap_.end()) { + hashPool_[hashKey] = YR::GetCurrentTimestampNs(); + lk.unlock(); + YRLOG_DEBUG("scheduler: {} has already exist in instance map", name); + return; + } + + instanceMap_[hashKey] = name; + + hashPool_.emplace(hashKey, YR::GetCurrentTimestampNs()); + YRLOG_DEBUG("start add scheduler: {}, hashKey {} to hash ring, total nodes is {}", name, hashKey, totalNodes); + totalNodes++; + lk.unlock(); + } + + void RemoveAll(void) override + { + std::unique_lock lk(insMtx_); + anchorPoint_.clear(); + instanceMap_.clear(); + hashPool_.clear(); + totalNodes = 0; + lk.unlock(); + } + + void Remove(const std::string &name) override + { + auto hashKey = g_hashFunc(name); + std::unique_lock lk(insMtx_); + if (instanceMap_.find(hashKey) != instanceMap_.end()) { + instanceMap_.erase(hashKey); + } + auto it = hashPool_.find(hashKey); + if (it != hashPool_.end()) { + YRLOG_DEBUG("remove scheduler : {}, total noedes is {}", name, totalNodes); + hashPool_.erase(it); + totalNodes--; + } + lk.unlock(); + } + +private: + std::unordered_map anchorPoint_ = std::unordered_map(); + std::unordered_map instanceMap_ = std::unordered_map(); + std::map hashPool_ = std::map(); + std::mutex insMtx_; + size_t totalNodes = 0; + + void MoveAnchorPoint(const std::string &name, size_t currentHash) + { + auto instanceHash = GetNextHashKey(currentHash); + if (auto it = anchorPoint_.find(name); it != anchorPoint_.end()) { + if (instanceMap_.find(instanceHash) != instanceMap_.end()) { + it->second.instanceKey = instanceMap_[instanceHash]; + } + it->second.instanceHash = instanceHash; + YRLOG_DEBUG("after move anchor point, instance hash of {} is {}, instance key is {}", name, + it->second.instanceHash, it->second.instanceKey); + } + } + + size_t GetNextHashKey(const size_t hashKey) + { + if (hashPool_.empty()) { + return 0; + } + + auto nextHashKey = hashPool_.begin()->first; + for (auto v : hashPool_) { + if (v.first > hashKey) { + nextHashKey = v.first; + break; + } + } + return nextHashKey; + } + + AnchorInfo AddAnchorPoint(const std::string &key) + { + size_t hashKey = g_hashFunc(key); + + size_t nextHashKey = GetNextHashKey(hashKey); + AnchorInfo anchorInfo{.instanceHash = nextHashKey}; + if (instanceMap_.find(nextHashKey) != instanceMap_.end()) { + anchorInfo.instanceKey = instanceMap_[nextHashKey]; + } + anchorPoint_.emplace(key, anchorInfo); + YRLOG_DEBUG("end add name: {}, instance hash key is {}, instance name is {}", key, anchorInfo.instanceHash, + anchorInfo.instanceKey); + return anchorInfo; + } +}; + +LoadBalancer *LoadBalancer::Factory(LoadBalancerType type) +{ + switch (type) { + case LoadBalancerType::WeightedRoundRobin: { + return new WeightedRoundRobin(); + } + case LoadBalancerType::ConsistantRoundRobin: { + return new CsHashRoundRobin(); + } + default: { + YRLOG_ERROR("unknown load balancer type {}", static_cast(type)); + return new WeightedRoundRobin(); + } + } +} +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/load_balancer.h b/src/libruntime/invokeadaptor/load_balancer.h new file mode 100644 index 0000000..836eaaf --- /dev/null +++ b/src/libruntime/invokeadaptor/load_balancer.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include "src/libruntime/invokeadaptor/scheduler_instance_info.h" + +namespace YR { +namespace Libruntime { +const std::string ALL_SCHEDULER_UNAVAILABLE = "AllSchedulerUnavailable"; +enum class LoadBalancerType : int { + WeightedRoundRobin, + ConsistantRoundRobin, +}; + +class LoadBalancer { +public: + static LoadBalancer *Factory(LoadBalancerType type); + + LoadBalancer() = default; + virtual ~LoadBalancer() = default; + + virtual std::string Next(const std::string &name, bool move = false) = 0; + virtual void Add(const std::string &node, int weight) = 0; + virtual void RemoveAll(void) = 0; + virtual void Remove(const std::string &name) = 0; +}; + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/normal_instance_manager.cpp b/src/libruntime/invokeadaptor/normal_instance_manager.cpp index fb3f7fd..5b4a4ac 100644 --- a/src/libruntime/invokeadaptor/normal_instance_manager.cpp +++ b/src/libruntime/invokeadaptor/normal_instance_manager.cpp @@ -28,11 +28,12 @@ void NormalInsManager::UpdateConfig(int inputRecycleTimeMs) void NormalInsManager::SendKillReq(const std::string &insId) { KillRequest killReq; + killReq.set_requestid(YR::utility::IDGenerator::GenRequestId()); killReq.set_instanceid(insId); killReq.set_payload(""); killReq.set_signal(libruntime::Signal::KillInstance); YRLOG_DEBUG("start send kill req, ins id is {}", insId); - this->fsClient->KillAsync(killReq, [insId](KillResponse rsp, ErrorInfo err) -> void { + this->fsClient->KillAsync(killReq, [insId](KillResponse rsp, const ErrorInfo &err) -> void { if (rsp.code() != common::ERR_NONE) { YRLOG_WARN("kill req send failed, instance id is {}", insId); } @@ -51,6 +52,7 @@ void NormalInsManager::ScaleCancel(const RequestResource &resource, size_t reqNu SendKillReq(ins); } } + EraseResourceInfoMap(resource); return; } @@ -111,31 +113,29 @@ void NormalInsManager::SendCreateReq(std::shared_ptr spec, size_t de auto insInfo = std::make_shared("", 0); this->AddCreatingInsInfo(resource, insInfo); auto weakThis = weak_from_this(); - YR::utility::ExecuteByGlobalTimer( - [this, createSpec, insInfo, weakThis]() { + tw_->CreateTimer(deleyTime * MILLISECOND_UNIT, 1, [this, createSpec, insInfo, weakThis]() { + auto thisPtr = weakThis.lock(); + if (thisPtr == nullptr) { + return; + } + YRLOG_DEBUG("send create instance request, req id is {}", createSpec->requestId); + requestManager->PushRequest(createSpec); + auto respCallback = [this, createSpec, insInfo, weakThis](const CreateResponse &resp) -> void { auto thisPtr = weakThis.lock(); if (thisPtr == nullptr) { return; } - YRLOG_DEBUG("send create instance request, req id is {}", createSpec->requestId); - requestManager->PushRequest(createSpec); - auto respCallback = [this, createSpec, insInfo, weakThis](const CreateResponse &resp) -> void { - auto thisPtr = weakThis.lock(); - if (thisPtr == nullptr) { - return; - } - HandleCreateResponse(createSpec, resp, insInfo); - }; - if (!createSpec->opts.device.name.empty()) { - absl::ReaderMutexLock lock(&createCostMtx); - createCostMap[createSpec->opts.device.name] = TimeMeasurement(DEFAULT_CREATE_DURATION); - createCostMap[createSpec->opts.device.name].StartTimer(createSpec->requestId); - YRLOG_DEBUG("start timer for {}, reqID: {}", createSpec->opts.device.name, createSpec->requestId); - } - this->fsClient->CreateAsync(createSpec->requestCreate, respCallback, - std::bind(&NormalInsManager::HandleCreateNotify, this, std::placeholders::_1)); - }, - deleyTime * MILLISECOND_UNIT, 1); + HandleCreateResponse(createSpec, resp, insInfo); + }; + if (!createSpec->opts.device.name.empty()) { + absl::ReaderMutexLock lock(&createCostMtx); + createCostMap[createSpec->opts.device.name] = TimeMeasurement(DEFAULT_CREATE_DURATION); + createCostMap[createSpec->opts.device.name].StartTimer(createSpec->requestId); + YRLOG_DEBUG("start timer for {}, reqID: {}", createSpec->opts.device.name, createSpec->requestId); + } + this->fsClient->CreateAsync(createSpec->requestCreate, respCallback, + std::bind(&NormalInsManager::HandleCreateNotify, this, std::placeholders::_1)); + }); } void NormalInsManager::HandleCreateResponse(std::shared_ptr spec, const CreateResponse &resp, @@ -151,7 +151,7 @@ void NormalInsManager::HandleCreateResponse(std::shared_ptr spec, co YRLOG_ERROR( "start handle fail create response, req id is {}, trace id is {}, instance id is {}, code is {}, message " "is {}", - spec->requestId, spec->traceId, instanceId, resp.code(), resp.message()); + spec->requestId, spec->traceId, instanceId, fmt::underlying(resp.code()), resp.message()); auto resource = GetRequestResource(spec); HandleFailCreateNotify(spec, resource); scheduleInsCb(resource, ErrorInfo(static_cast(resp.code()), ModuleCode::CORE, resp.message(), true), @@ -187,7 +187,7 @@ void NormalInsManager::HandleCreateNotify(const NotifyRequest &req) YRLOG_ERROR( "handle normal function instance create failed notify or response, request id is: {}, instance id is: {}, " "trace id is: {},err code is {}, err msg is {}", - reqId, createSpec->instanceId, createSpec->traceId, errInfo.Code(), errInfo.Msg()); + reqId, createSpec->instanceId, createSpec->traceId, fmt::underlying(errInfo.Code()), errInfo.Msg()); HandleFailCreateNotify(createSpec, resource); } else { HandleSuccessCreateNotify(createSpec, resource, req); @@ -223,7 +223,7 @@ void NormalInsManager::HandleSuccessCreateNotify(const std::shared_ptrChangeCreateFailNum(resource, false); - auto info = GetRequestResourceInfo(resource); + auto info = GetOrAddRequestResourceInfo(resource); // ensure atomicity of erase creating and add instances. // avoid creating unnecessary instances when judge NeedCreateNewIns bool isErased; @@ -242,18 +242,15 @@ void NormalInsManager::HandleSuccessCreateNotify(const std::shared_ptrStartNormalInsScaleDownTimer(resource, createSpec->instanceId); } else { SendKillReq(createSpec->instanceId); + EraseResourceInfoMap(resource, REQUEST_RESOURCE_USE_COUNT); } } void NormalInsManager::ScaleDownHandler(const RequestResource &resource, const std::string &id) { - std::shared_ptr info; - { - absl::ReaderMutexLock lock(&insMtx); - if (requestResourceInfoMap.find(resource) == requestResourceInfoMap.end()) { - return; - } - info = requestResourceInfoMap[resource]; + auto info = GetRequestResourceInfo(resource); + if (info == nullptr) { + return; } std::shared_ptr insInfo; { @@ -271,24 +268,22 @@ void NormalInsManager::ScaleDownHandler(const RequestResource &resource, const s DelInsInfoBare(id, info); } SendKillReq(id); + EraseResourceInfoMap(resource, REQUEST_RESOURCE_USE_COUNT); scheduleInsCb(resource, ErrorInfo(), IsRemainIns(resource)); } void NormalInsManager::StartNormalInsScaleDownTimer(const RequestResource &resource, const std::string &id) { auto weakPtr = weak_from_this(); - auto info = GetRequestResourceInfo(resource); std::shared_ptr insInfo = GetInstanceInfo(resource, id); if (insInfo == nullptr) { return; } - auto timer = YR::utility::ExecuteByGlobalTimer( - [this, weakPtr, resource, id]() { - if (auto thisPtr = weakPtr.lock(); thisPtr) { - ScaleDownHandler(resource, id); - } - }, - libRuntimeConfig->recycleTime * MILLISECOND_UNIT, 1); + auto timer = tw_->CreateTimer(libRuntimeConfig->recycleTime * MILLISECOND_UNIT, 1, [this, weakPtr, resource, id]() { + if (auto thisPtr = weakPtr.lock(); thisPtr) { + ScaleDownHandler(resource, id); + } + }); CancelScaleDownTimer(insInfo); absl::WriterMutexLock insLock(&insInfo->mtx); insInfo->scaleDownTimer = timer; @@ -296,7 +291,7 @@ void NormalInsManager::StartNormalInsScaleDownTimer(const RequestResource &resou void NormalInsManager::AddInsInfo(const std::shared_ptr createSpec, const RequestResource &resource) { - auto info = GetRequestResourceInfo(resource); + auto info = GetOrAddRequestResourceInfo(resource); absl::WriterMutexLock lock(&info->mtx); AddInsInfoBare(createSpec, info); } @@ -314,6 +309,6 @@ void NormalInsManager::AddInsInfoBare(const std::shared_ptr createSp } } -void NormalInsManager::StartRenewTimer(const RequestResource &resource, const std::string &insId) {} +void NormalInsManager::StartBatchRenewTimer() {} } // namespace Libruntime } // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/normal_instance_manager.h b/src/libruntime/invokeadaptor/normal_instance_manager.h index 5656627..aea1ce6 100644 --- a/src/libruntime/invokeadaptor/normal_instance_manager.h +++ b/src/libruntime/invokeadaptor/normal_instance_manager.h @@ -32,7 +32,7 @@ public: bool ScaleUp(std::shared_ptr spec, size_t reqNum) override; void ScaleDown(const std::shared_ptr spec, bool isInstanceNormal = false) override; void ScaleCancel(const RequestResource &resource, size_t reqNum, bool cleanAll = false) override; - void StartRenewTimer(const RequestResource &resource, const std::string &insId) override; + void StartBatchRenewTimer() override; void StartNormalInsScaleDownTimer(const RequestResource &resource, const std::string &id); void HandleCreateResponse(std::shared_ptr spec, const CreateResponse &resp, std::shared_ptr insInfo); diff --git a/src/libruntime/invokeadaptor/request_manager.cpp b/src/libruntime/invokeadaptor/request_manager.cpp index 605f9f9..a1ef5c0 100644 --- a/src/libruntime/invokeadaptor/request_manager.cpp +++ b/src/libruntime/invokeadaptor/request_manager.cpp @@ -19,6 +19,12 @@ namespace YR { namespace Libruntime { +void RequestManager::PushFaasRequest(const std::string &leaseId, const std::shared_ptr spec) +{ + absl::WriterMutexLock lock(&reqMtx_); + requestMap[leaseId] = spec; +} + void RequestManager::PushRequest(const std::shared_ptr spec) { absl::WriterMutexLock lock(&reqMtx_); diff --git a/src/libruntime/invokeadaptor/request_manager.h b/src/libruntime/invokeadaptor/request_manager.h index eea6fc1..015d66d 100644 --- a/src/libruntime/invokeadaptor/request_manager.h +++ b/src/libruntime/invokeadaptor/request_manager.h @@ -24,6 +24,7 @@ namespace YR { namespace Libruntime { class RequestManager { public: + void PushFaasRequest(const std::string &leaseId, const std::shared_ptr spec); void PushRequest(const std::shared_ptr spec); diff --git a/src/libruntime/invokeadaptor/request_queue.cpp b/src/libruntime/invokeadaptor/request_queue.cpp index acdfd76..2ed9ef4 100644 --- a/src/libruntime/invokeadaptor/request_queue.cpp +++ b/src/libruntime/invokeadaptor/request_queue.cpp @@ -50,5 +50,36 @@ bool PriorityQueue::Empty() const return this->queue.empty(); } +void Queue::Pop() +{ + absl::WriterMutexLock lock(&this->mutex); + this->queue.pop(); +} + +std::shared_ptr Queue::Top() +{ + // 读写锁 + absl::ReaderMutexLock lock(&this->mutex); + return this->queue.front(); +} + +void Queue::Push(std::shared_ptr spec) +{ + absl::WriterMutexLock lock(&this->mutex); + this->queue.push(spec); +} + +size_t Queue::Size() const +{ + absl::ReaderMutexLock lock(&this->mutex); + return this->queue.size(); +} + +bool Queue::Empty() const +{ + absl::ReaderMutexLock lock(&this->mutex); + return this->queue.empty(); +} + } // namespace Libruntime } // namespace YR diff --git a/src/libruntime/invokeadaptor/request_queue.h b/src/libruntime/invokeadaptor/request_queue.h index d345f92..99dfba8 100644 --- a/src/libruntime/invokeadaptor/request_queue.h +++ b/src/libruntime/invokeadaptor/request_queue.h @@ -28,6 +28,7 @@ namespace Libruntime { class BaseQueue { public: + virtual ~BaseQueue() = default; virtual void Pop() = 0; virtual std::shared_ptr Top() = 0; virtual void Push(std::shared_ptr spec) = 0; @@ -57,5 +58,18 @@ private: mutable absl::Mutex mutex; }; +class Queue : public BaseQueue { +public: + void Pop() override; + std::shared_ptr Top() override; + void Push(std::shared_ptr spec) override; + size_t Size() const override; + bool Empty() const override; + +private: + std::queue> queue; + mutable absl::Mutex mutex; +}; + } // namespace Libruntime } // namespace YR diff --git a/src/libruntime/invokeadaptor/scheduler_instance_info.cpp b/src/libruntime/invokeadaptor/scheduler_instance_info.cpp new file mode 100644 index 0000000..cc8390e --- /dev/null +++ b/src/libruntime/invokeadaptor/scheduler_instance_info.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "scheduler_instance_info.h" + +namespace YR { +namespace Libruntime { +using nlohmann::json; +using namespace YR::utility; + + +const static std::string SCHEDULER_FUNC_KEY = "schedulerFuncKey"; +const static std::string SCHEDULER_INSTANCE_LIST = "schedulerInstanceList"; + +void from_json(const json &j, std::vector &schedulerInstanceList) +{ + for (auto &info : j) { + SchedulerInstance instanceInfo; + JsonGetTo(info, "instanceName", instanceInfo.InstanceName); + JsonGetTo(info, "instanceId", instanceInfo.InstanceID); + schedulerInstanceList.push_back(instanceInfo); + } +} + +ErrorInfo ParseSchedulerInfo(const std::string &payload, SchedulerInfo &schedulerInfo) +{ + try { + json j = json::parse(payload); + schedulerInfo.schedulerFuncKey = j[SCHEDULER_FUNC_KEY].get(); + schedulerInfo.schedulerInstanceList = j[SCHEDULER_INSTANCE_LIST].get>(); + } catch (std::exception &e) { + return ErrorInfo(ErrorCode::ERR_PARAM_INVALID, + std::string("parse schedulerInfo info: ") + e.what()); + } + return ErrorInfo(); +} + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/scheduler_instance_info.h b/src/libruntime/invokeadaptor/scheduler_instance_info.h new file mode 100644 index 0000000..4a8da64 --- /dev/null +++ b/src/libruntime/invokeadaptor/scheduler_instance_info.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include "absl/synchronization/mutex.h" + +#include "src/libruntime/err_type.h" +#include "src/utility/json_utility.h" + +namespace YR { +namespace Libruntime { +struct SchedulerInstance { + std::string InstanceName; + std::string InstanceID; + long long updateTime; + bool isAvailable = true; +}; + +struct SchedulerInfo { + std::string schedulerFuncKey; + std::vector schedulerInstanceList; +}; + +struct AvailableSchedulerInfos { + std::vector> schedulerInstanceList; + mutable absl::Mutex schedulerMtx; +}; + +void from_json(const nlohmann::json &j, std::vector &schedulerInstanceList); + +ErrorInfo ParseSchedulerInfo(const std::string &payload, SchedulerInfo &schedulerInfo); + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/task_scheduler.cpp b/src/libruntime/invokeadaptor/task_scheduler.cpp index 2796561..24539e8 100644 --- a/src/libruntime/invokeadaptor/task_scheduler.cpp +++ b/src/libruntime/invokeadaptor/task_scheduler.cpp @@ -20,6 +20,7 @@ namespace YR { namespace Libruntime { void TaskScheduler::Run() { + runFlag_ = true; // It is possible to optimize to have multiple schedulers share a single thread. t = std::thread([this]() { Schedule(); }); pthread_setname_np(t.native_handle(), "task_scheduler"); @@ -35,7 +36,7 @@ void TaskScheduler::Schedule() } if (scheduleFlag_) { scheduleFlag_ = false; - mtx_.unlock(); + lockGuard.unlock(); if (func_) { func_(); } @@ -45,6 +46,9 @@ void TaskScheduler::Schedule() void TaskScheduler::Stop() { + if (!runFlag_) { + return; + } { std::unique_lock lockGuard(mtx_); runFlag_ = false; @@ -54,11 +58,40 @@ void TaskScheduler::Stop() t.join(); } } + void TaskScheduler::Notify() { + if (scheduleFlag_) { + return; + } std::unique_lock lockGuard(mtx_); scheduleFlag_ = true; condVar_.notify_one(); } + +void TaskSchedulerWrapper::SetLastError(const YR::Libruntime::ErrorInfo &err) +{ + std::lock_guard lockGuard(errLock_); + this->lastError_ = YR::Libruntime::ErrorInfo(err.Code(), err.MCode(), err.Msg()); +} + +YR::Libruntime::ErrorInfo TaskSchedulerWrapper::GetLastError() +{ + std::lock_guard lockGuard(errLock_); + return this->lastError_; +} + +bool TaskSchedulerWrapper::IsLastErrorOk() +{ + std::lock_guard lockGuard(errLock_); + return this->lastError_.OK(); +} + +void TaskSchedulerWrapper::Notify() +{ + if (taskScheduler_) { + taskScheduler_->Notify(); + } +} } // namespace Libruntime } // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/task_scheduler.h b/src/libruntime/invokeadaptor/task_scheduler.h index fbaf28a..00fa6df 100644 --- a/src/libruntime/invokeadaptor/task_scheduler.h +++ b/src/libruntime/invokeadaptor/task_scheduler.h @@ -24,6 +24,7 @@ #include #include "src/libruntime/err_type.h" +#include "src/libruntime/invokeadaptor/request_queue.h" namespace YR { namespace Libruntime { @@ -40,13 +41,45 @@ public: private: void Schedule(); - std::atomic runFlag_{true}; + std::atomic runFlag_{false}; bool scheduleFlag_ = false; std::mutex mtx_; std::condition_variable condVar_; ScheduleFunc func_; std::thread t; +}; + +class TaskSchedulerWrapper { +public: + TaskSchedulerWrapper(bool enablePriority) + { + if (enablePriority) { + queue = std::make_shared(); + } else { + queue = std::make_shared(); + } + } + + explicit TaskSchedulerWrapper(std::shared_ptr taskScheduler, bool enablePriority) + : taskScheduler_(taskScheduler) + { + if (enablePriority) { + queue = std::make_shared(); + } else { + queue = std::make_shared(); + } + } + + ~TaskSchedulerWrapper() = default; + void Notify(); + void SetLastError(const YR::Libruntime::ErrorInfo &err); + YR::Libruntime::ErrorInfo GetLastError(); + bool IsLastErrorOk(); + std::shared_ptr queue; +private: + std::shared_ptr taskScheduler_; YR::Libruntime::ErrorInfo lastError_ = ErrorInfo(); + std::mutex errLock_; }; } // namespace Libruntime } // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/task_submitter.cpp b/src/libruntime/invokeadaptor/task_submitter.cpp index 9e451ec..62cfce0 100644 --- a/src/libruntime/invokeadaptor/task_submitter.cpp +++ b/src/libruntime/invokeadaptor/task_submitter.cpp @@ -24,11 +24,13 @@ #include "src/dto/config.h" #include "src/dto/constant.h" #include "src/libruntime/fsclient/protobuf/core_service.grpc.pb.h" +#include "src/libruntime/invokeadaptor/faas_instance_manager.h" #include "src/libruntime/invokeadaptor/normal_instance_manager.h" #include "src/libruntime/utils/constants.h" #include "src/libruntime/utils/exception.h" #include "src/libruntime/utils/serializer.h" #include "src/proto/libruntime.pb.h" +#include "src/scene/downgrade.h" #include "src/utility/id_generator.h" #include "src/utility/timer_worker.h" #include "task_submitter.h" @@ -37,11 +39,11 @@ namespace Libruntime { const std::string INSTANCE_REQUIREMENT_RESOURKEY = "resourcesData"; const std::string INSTANCE_REQUIREMENT_INSKEY = "designateInstanceID"; const std::string INSTANCE_REQUIREMENT_POOLLABELKEY = "poolLabel"; -const int64_t BEFOR_RETAIN_TIME = 30; // millisecond +const int64_t BEFOR_RETAIN_TIME = 30; // millisecond const int SECONDS_TO_MILLISECONDS_UNIT = 1000; // millisecond const int64_t IDLE_TIMER_INTERNAL = 10; -const int DEFALUT_CANCEL_DELAY_TIME = 5; // second -const std::string DS_OBJECTID_SEPERATOR = ";"; +const int DEFALUT_CANCEL_DELAY_TIME = 5; // second +const int ERASE_DELAY_TIME = 30; using namespace YR::utility; using json = nlohmann::json; @@ -56,9 +58,19 @@ ErrorInfo PackageNotifyErr(const NotifyRequest ¬ifyReq, bool isCreate) TaskSubmitter::TaskSubmitter(std::shared_ptr config, std::shared_ptr store, std::shared_ptr client, std::shared_ptr reqMgr, - CancelFunc cancelFunc) - : libRuntimeConfig(config), memoryStore(store), fsClient(client), requestManager(reqMgr), cancelCb(cancelFunc) + CancelFunc cancelFunc, std::shared_ptr ar, + std::shared_ptr adaptor, + std::shared_ptr downgrade) + : libRuntimeConfig(config), + memoryStore(store), + fsClient(client), + requestManager(reqMgr), + cancelCb(cancelFunc), + ar_(ar), + metricsAdaptor_(adaptor), + downgrade_(downgrade) { + enablePriority_ = Config::Instance().ENABLE_PRIORITY(); this->Init(); } @@ -69,8 +81,32 @@ void TaskSubmitter::Init() std::placeholders::_2, std::placeholders::_3), fsClient, memoryStore, requestManager, libRuntimeConfig); normalInsManager->SetDeleleInsCallback(std::bind(&TaskSubmitter::DeleteInsCallback, this, std::placeholders::_1)); + auto faasInsManager = + std::make_shared(std::bind(&TaskSubmitter::ScheduleIns, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3), + fsClient, memoryStore, requestManager, libRuntimeConfig); + insManagers[libruntime::ApiType::Faas] = faasInsManager; + insManagers[libruntime::ApiType::Serve] = faasInsManager; insManagers[libruntime::ApiType::Function] = normalInsManager; this->UpdateConfig(); + if (Config::Instance().ENABLE_METRICS()) { + this->metricsEnable_ = true; + } + faasInsManager->StartBatchRenewTimer(); + taskScheduler_ = std::make_shared(std::bind(&TaskSubmitter::ScheduleFunctions, this)); + taskScheduler_->Run(); + if (downgrade_) { + downgrade_->Init(); + } + eraseTimer_ = YR::utility::ExecuteByGlobalTimer([this]() { this->EraseInsResourceMap(); }, + ERASE_DELAY_TIME * MILLISECOND_UNIT, -1); +} + +void TaskSubmitter::EraseInsResourceMap() +{ + for (auto it = insManagers.begin(); it != insManagers.end(); ++it) { + it->second->EraseResourceInfoMap(); + } } void TaskSubmitter::UpdateConfig() @@ -130,6 +166,40 @@ void TaskSubmitter::HandleInvokeNotify(const NotifyRequest &req, const ErrorInfo } } +void TaskSubmitter::DowngradeCallback(const std::string &requestId, Libruntime::ErrorCode code, + const std::string &result) +{ + auto spec = requestManager->GetRequest(requestId); + if (spec == nullptr) { + YRLOG_WARN( + "request id : {} did not exit in request manager, may be the invoke request has been canceled or finished.", + requestId); + return; + } + ErrorInfo errInfo; + if (code != ErrorCode::ERR_OK) { + errInfo = ErrorInfo(code, result); + memoryStore->SetError(spec->returnIds, errInfo); + } else { + auto &dataObj = spec->returnIds.front(); + std::shared_ptr buf = std::make_shared(result.size() + MetaDataLen); + dataObj.SetBuffer(buf); + (void)memset_s(dataObj.meta->MutableData(), dataObj.meta->GetSize(), 0, dataObj.meta->GetSize()); + dataObj.data->MemoryCopy(result.data(), result.size()); + memoryStore->Put(dataObj.buffer, dataObj.id, {}, false); + memoryStore->SetReady(dataObj.id); + } + auto ids = this->memoryStore->UnbindObjRefInReq(spec->requestId); + auto errorInfo = this->memoryStore->DecreGlobalReference(ids); + if (!errorInfo.OK()) { + YRLOG_WARN("failed to decrease obj ref [{},...] by requestid {}. Code: {}, MCode: {}, Msg: {}", ids[0], + spec->requestId, fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), + errorInfo.Msg()); + } + requestManager->RemoveRequest(spec->requestId); + this->UpdateFaasInvokeLog(spec->requestId, errInfo); +} + bool TaskSubmitter::HandleFailInvokeIsDelayScaleDown(const NotifyRequest &req, const ErrorInfo &err) { // If error code is less than 2000, it means user operation error, and the instance is normal, @@ -137,7 +207,8 @@ bool TaskSubmitter::HandleFailInvokeIsDelayScaleDown(const NotifyRequest &req, c // Otherwise, invoke failure is either because of fault or only user code error, // it cannot be distinguished here, so just scale down the instance with no delay, // to avoid instance/node fault. - YRLOG_INFO("check if invoke is abnormal notify request code {} requestid {}", req.code(), req.requestid()); + YRLOG_INFO("check if invoke is abnormal notify request code {} requestid {}", fmt::underlying(req.code()), + req.requestid()); if (req.code() == common::ErrorCode::ERR_INSTANCE_NOT_FOUND || req.code() == common::ErrorCode::ERR_INSTANCE_EXITED || req.code() == common::ErrorCode::ERR_INSTANCE_EVICTED) { return false; @@ -164,25 +235,23 @@ void TaskSubmitter::HandleFailInvokeNotify(const NotifyRequest &req, const std:: YRLOG_ERROR( "normal invoke requet fail, need retry, raw request id is {}, code is: {}, trace id is {}, seq is {}, " "complete request id is {}", - req.requestid(), req.code(), spec->traceId, spec->seq, spec->requestInvoke->Mutable().requestid()); + req.requestid(), fmt::underlying(req.code()), spec->traceId, spec->seq, + spec->requestInvoke->Mutable().requestid()); if (isConsumeRetryTime) { spec->ConsumeRetryTime(); YRLOG_DEBUG("consumed invoke retry time to {}, req id is {}", spec->opts.retryTimes, req.requestid()); } - std::shared_ptr requestQueue; + auto taskScheduler = GetOrAddTaskScheduler(resource); { - absl::ReaderMutexLock lock(&reqMtx_); - requestQueue = waitScheduleReqMap_[resource]; - } - { - std::lock_guard lockGuard(requestQueue->atomicMtx); - requestQueue->Push(spec); + std::lock_guard lockGuard(taskScheduler->queue->atomicMtx); + taskScheduler->queue->Push(spec); } } else { YRLOG_ERROR( "normal invoke requet fail, don't need retry, raw request id is {}, code is: {}, trace id is {}, seq is " "{}, complete request id is {}", - req.requestid(), req.code(), spec->traceId, spec->seq, spec->requestInvoke->Mutable().requestid()); + req.requestid(), fmt::underlying(req.code()), spec->traceId, spec->seq, + spec->requestInvoke->Mutable().requestid()); if (this->libRuntimeConfig->inCluster) { std::vector dsObjs; for (auto it = spec->returnIds.begin(); it != spec->returnIds.end(); ++it) { @@ -195,15 +264,14 @@ void TaskSubmitter::HandleFailInvokeNotify(const NotifyRequest &req, const std:: auto errorInfo = this->memoryStore->DecreGlobalReference(ids); if (!errorInfo.OK()) { YRLOG_WARN("failed to decrease obj ref [{},...] by requestid {}. Code: {}, MCode: {}, Msg: {}", ids[0], - spec->requestId, errorInfo.Code(), errorInfo.MCode(), errorInfo.Msg()); + spec->requestId, fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), + errorInfo.Msg()); } this->memoryStore->SetError(spec->returnIds, errInfo); (void)requestManager->RemoveRequest(spec->requestId); + this->UpdateFaasInvokeLog(spec->requestId, errInfo); } - absl::ReaderMutexLock lock(&reqMtx_); - if (taskSchedulerMap_.find(resource) != taskSchedulerMap_.end()) { - taskSchedulerMap_[resource]->Notify(); - } + NotifyScheduler(resource); } void TaskSubmitter::HandleSuccessInvokeNotify(const NotifyRequest &req, const std::shared_ptr spec, @@ -248,97 +316,229 @@ void TaskSubmitter::HandleSuccessInvokeNotify(const NotifyRequest &req, const st auto errorInfo = this->memoryStore->DecreGlobalReference(ids); if (!errorInfo.OK()) { YRLOG_WARN("failed to decrease obj ref [{},...] by requestid {}. Code: {}, MCode: {}, Msg: {}", ids[0], - spec->requestId, errorInfo.Code(), errorInfo.MCode(), errorInfo.Msg()); + spec->requestId, fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), + errorInfo.Msg()); } - absl::ReaderMutexLock lock(&reqMtx_); - if (taskSchedulerMap_.find(resource) != taskSchedulerMap_.end()) { - taskSchedulerMap_[resource]->Notify(); + if (spec->functionMeta.apiType == libruntime::ApiType::Faas) { + this->UpdateFaasInvokeLog(spec->requestId, ErrorInfo()); } + NotifyScheduler(resource); } -void TaskSubmitter::SubmitFunction(std::shared_ptr spec) +void TaskSubmitter::AddFaasCancelTimer(std::shared_ptr spec) { - YRLOG_DEBUG("start submit stateless function, req id is {}, return obj id is {}, trace id is {}", spec->requestId, - spec->returnIds[0].id, spec->traceId); - RequestResource resource = GetRequestResource(spec); - std::shared_ptr queue; - std::shared_ptr taskScheduler; + if (spec->functionMeta.apiType == libruntime::ApiType::Faas) { + auto weakThis = weak_from_this(); + auto reqId = spec->requestId; + auto timeoutSec = + spec->opts.acquireTimeout > 0 ? spec->opts.acquireTimeout : Config::Instance().FASS_SCHEDULE_TIMEOUT(); + // The Purpose of this timer is to remove the corresponding spec from request manager and set the invoke req + // as failed if it do not send invoke req after set timeout seconds(default 120s). + { + absl::WriterMutexLock lock(&cancelTimerMtx_); + faasCancelTimerWorkers[reqId] = YR::utility::ExecuteByGlobalTimer( + [weakThis, reqId, timeoutSec]() { + auto thisPtr = weakThis.lock(); + if (!thisPtr) { + return; + } + thisPtr->CancelFaasScheduleTimeoutReq(reqId, timeoutSec); + }, + (timeoutSec + DEFALUT_CANCEL_DELAY_TIME) * MILLISECOND_UNIT, 1); + } + /** + * Record faas invoke function data in 3 steps: + 1. initializes the relevant data when submit the faas function + 2. Record the time of sending faas invoke request after the success of the acquire request, and record + the end time of the faas request when the acquire fails + 3. After receiving notify (whether successful or failed), record the end time of the faas request + */ + this->RecordFaasInvokeData(spec); + } +} + +std::shared_ptr TaskSubmitter::GetTaskScheduler(const RequestResource &resource) +{ + std::shared_ptr taskScheduler; { absl::ReaderMutexLock lock(&reqMtx_); - if (waitScheduleReqMap_.find(resource) != waitScheduleReqMap_.end()) { - queue = waitScheduleReqMap_[resource]; - } if (taskSchedulerMap_.find(resource) != taskSchedulerMap_.end()) { taskScheduler = taskSchedulerMap_[resource]; } } - if (queue == nullptr && taskScheduler == nullptr) { + return taskScheduler; +} + +std::shared_ptr TaskSubmitter::GetOrAddTaskScheduler(const RequestResource &resource) +{ + auto taskScheduler = GetTaskScheduler(resource); + if (taskScheduler == nullptr) { absl::WriterMutexLock lock(&reqMtx_); - if (waitScheduleReqMap_.find(resource) == waitScheduleReqMap_.end()) { - queue = std::make_shared(); - auto weakPtr = weak_from_this(); - auto cb = [weakPtr, resource, this]() { - auto thisPtr = weakPtr.lock(); - if (!thisPtr) { - return; - } - ScheduleFunction(resource); - }; - taskScheduler = std::make_shared(cb); - taskScheduler->Run(); - waitScheduleReqMap_[resource] = queue; + if (taskSchedulerMap_.find(resource) == taskSchedulerMap_.end()) { + taskScheduler = std::make_shared(taskScheduler_, enablePriority_); taskSchedulerMap_[resource] = taskScheduler; } else { - queue = waitScheduleReqMap_[resource]; taskScheduler = taskSchedulerMap_[resource]; } } + return taskScheduler; +} + +void TaskSubmitter::SubmitFunction(std::shared_ptr spec) +{ + YRLOG_DEBUG("start submit stateless function, req id is {}, return obj id is {}, trace id is {}", spec->requestId, + spec->returnIds[0].id, spec->traceId); + RequestResource resource = GetRequestResource(spec); + auto taskScheduler = GetOrAddTaskScheduler(resource); { - std::lock_guard lockGuard(queue->atomicMtx); - queue->Push(spec); + std::lock_guard lockGuard(taskScheduler->queue->atomicMtx); + taskScheduler->queue->Push(spec); } + AddFaasCancelTimer(spec); taskScheduler->Notify(); } -void TaskSubmitter::ScheduleIns(const RequestResource &resource, const ErrorInfo &errInfo, bool isRemainIns) +void TaskSubmitter::RecordFaasInvokeData(const std::shared_ptr spec) { - if (errInfo.OK()) { - absl::ReaderMutexLock lock(&reqMtx_); - if (taskSchedulerMap_.find(resource) != taskSchedulerMap_.end()) { - taskSchedulerMap_[resource]->Notify(); + YRLOG_DEBUG("start recorde value, function id is {}, req id is {}", spec->functionMeta.functionId, spec->requestId); + absl::WriterMutexLock lock(&invokeDataMtx_); + if (!metricsEnable_) { + return; + } + if (faasInvokeDataMap_.find(spec->requestId) == faasInvokeDataMap_.end()) { + FaasInvokeData data(this->libRuntimeConfig->tenantId, spec->functionMeta.funcName, + ar_ ? this->ar_->ParseAlias(spec->functionMeta.functionId, spec->opts.aliasParams) : "", + spec->traceId, GetCurrentTimestampNs()); + auto pos = data.aliAs.find_last_of('/'); + if (pos != std::string::npos && pos != data.aliAs.length() - 1) { + data.version = data.aliAs.substr(pos + 1); } + faasInvokeDataMap_[spec->requestId] = std::make_shared(data); + } +} + +void TaskSubmitter::CancelFaasScheduleTimeoutReq(const std::string &reqId, int timeoutSeconds) +{ + auto spec = requestManager->GetRequest(reqId); + if (!spec) { + YRLOG_DEBUG("spec of request {} is nullptr, return directly", reqId); return; } - if (NeedRetryCreate(errInfo)) { - YRLOG_INFO("start retry create task instance, code: {}, msg: {}", errInfo.Code(), errInfo.Msg()); - absl::ReaderMutexLock lock(&reqMtx_); - if (taskSchedulerMap_.find(resource) != taskSchedulerMap_.end()) { - taskSchedulerMap_[resource]->Notify(); + if (spec->invokeInstanceId.empty()) { + (void)requestManager->RemoveRequest(reqId); + auto ids = memoryStore->UnbindObjRefInReq(spec->requestId); + auto errorInfo = memoryStore->DecreGlobalReference(ids); + if (!errorInfo.OK()) { + YRLOG_WARN("failed to decrease obj ref [{},...] by requestid {}. Code: {}, MCode: {}, Msg: {}", ids[0], + spec->requestId, fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), + errorInfo.Msg()); } + auto resource = GetRequestResource(spec); + ErrorInfo err = ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, ModuleCode::RUNTIME, + "request is still scheduled within " + std::to_string(timeoutSeconds) + + " seconds, cancel the faas invoke request"); + { + absl::ReaderMutexLock lock(&reqMtx_); + auto it = taskSchedulerMap_.find(resource); + if (it != taskSchedulerMap_.end()) { + auto scheduler = it->second; + if (!scheduler->IsLastErrorOk()) { + err = scheduler->GetLastError(); + } + } + } + memoryStore->SetError(spec->returnIds, err); + YRLOG_ERROR("cancel req {} with {}s no scheduled, code: {}, message: {}, trace: {}", spec->requestId, + timeoutSeconds, fmt::underlying(err.Code()), err.Msg(), spec->traceId); + this->UpdateFaasInvokeLog(reqId, err); + } +} + +void TaskSubmitter::EraseFaasCancelTimer(const std::string &reqId) +{ + // three situations need to erase faas cancel timer + // 1. create ins failed and no need retry, there are no running or creating instance in current resource, then the + // faas req is considered a failure and then need to erase the cancel timer. Refer to TaskSubmitter::ScheduleIns + // method else branch for details. + // 2. RequestManager does not have a spec for reqId, indicating that timer has already been executed or the request + // has been cancelled for other reasons, so need to erase the cancel timer. Refer to TaskSubmitter::ScheduleRequest + // method for details. + // 3. Before send faas invoke req, need to stop and erase the cancel timer. Refer to TaskSubmitter::ScheduleRequest + // method for details. + absl::WriterMutexLock lock(&cancelTimerMtx_); + if (auto it = faasCancelTimerWorkers.find(reqId); it != faasCancelTimerWorkers.end()) { + if (it->second) { + YRLOG_DEBUG("start stop and erase the cancel timer of faas req: {}", reqId); + it->second->cancel(); + it->second.reset(); + faasCancelTimerWorkers.erase(it); + } + } +} + +void TaskSubmitter::NotifyScheduler(const RequestResource &resource, const ErrorInfo &err) +{ + auto taskScheduler = GetTaskScheduler(resource); + if (taskScheduler) { + if (!err.OK()) { + taskScheduler->SetLastError(err); + } + taskScheduler->Notify(); return; } - std::shared_ptr requestQueue; - std::shared_ptr taskScheduler; - { - absl::ReaderMutexLock lock(&reqMtx_); - if (waitScheduleReqMap_.find(resource) == waitScheduleReqMap_.end()) { + auto resources = GetScheduleResources(); + for (const auto &resource : resources) { + taskScheduler = GetTaskScheduler(resource); + if (taskScheduler) { + taskScheduler->Notify(); return; } - requestQueue = waitScheduleReqMap_[resource]; - taskScheduler = taskSchedulerMap_[resource]; + } +} + +void TaskSubmitter::ScheduleIns(const RequestResource &resource, const ErrorInfo &errInfo, bool isRemainIns) +{ + if (errInfo.OK()) { + NotifyScheduler(resource); + return; + } + if (NeedRetryCreate(errInfo)) { + YRLOG_INFO("start retry create task instance, code: {}, msg: {}", fmt::underlying(errInfo.Code()), + errInfo.Msg()); + NotifyScheduler(resource, errInfo); + return; + } + auto taskScheduler = GetTaskScheduler(resource); + if (taskScheduler == nullptr) { + return; } // If there are still other instances existing or being created under resource, equals the // isRemainIns value is true, then all requests under that resource should not be set as failed if (!isRemainIns) { - std::lock_guard lockGuard(requestQueue->atomicMtx); - for (; !requestQueue->Empty(); requestQueue->Pop()) { - auto reqId = requestQueue->Top()->requestId; + std::lock_guard lockGuard(taskScheduler->queue->atomicMtx); + for (; !taskScheduler->queue->Empty(); taskScheduler->queue->Pop()) { + auto reqId = taskScheduler->queue->Top()->requestId; + this->EraseFaasCancelTimer(reqId); + if (downgrade_->ShouldFaultDowngrade()) { + auto spec = requestManager->GetRequest(reqId); + downgrade_->Downgrade(spec, std::bind(&TaskSubmitter::DowngradeCallback, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3)); + continue; + } std::shared_ptr invokeSpecNeedFailed; bool specExits = requestManager->PopRequest(reqId, invokeSpecNeedFailed); if (!specExits) { continue; } + if (invokeSpecNeedFailed) { + YRLOG_ERROR( + "there is no available ins of resource, start set error, req id is {}, trace id is {}, invoke ins " + "id is {}", + invokeSpecNeedFailed->requestId, invokeSpecNeedFailed->traceId, + invokeSpecNeedFailed->invokeInstanceId); + } this->memoryStore->SetError(invokeSpecNeedFailed->returnIds[0].id, errInfo); + this->UpdateFaasInvokeLog(reqId, errInfo); } } if (taskScheduler) { @@ -348,11 +548,8 @@ void TaskSubmitter::ScheduleIns(const RequestResource &resource, const ErrorInfo void TaskSubmitter::SendInvokeReq(const RequestResource &resource, std::shared_ptr invokeSpec) { - YRLOG_DEBUG( - "start send stateless function invoke req, instance id is: {}, lease id is: {}, req id is: {}, return obj " - "id is: {}, function name is: {}, trace id is: {}", - invokeSpec->invokeInstanceId, invokeSpec->invokeLeaseId, invokeSpec->requestId, invokeSpec->returnIds[0].id, - invokeSpec->functionMeta.funcName, invokeSpec->traceId); + YRLOG_INFO("invoke function {}, instance: {}, lease: {}, req: {}, trace: {}", invokeSpec->functionMeta.funcName, + invokeSpec->invokeInstanceId, invokeSpec->invokeLeaseId, invokeSpec->requestId, invokeSpec->traceId); if (!invokeSpec->opts.device.name.empty()) { absl::WriterMutexLock lock(&invokeCostMtx_); @@ -360,6 +557,9 @@ void TaskSubmitter::SendInvokeReq(const RequestResource &resource, std::shared_p YRLOG_DEBUG("start timer for instance: {}, reqID: {}", invokeSpec->invokeInstanceId, invokeSpec->requestId); } + if (invokeSpec->functionMeta.IsServiceApiType()) { + this->UpdateFaasInvokeSendTime(invokeSpec->requestId); + } auto weakThis = weak_from_this(); auto insId = invokeSpec->invokeInstanceId; this->fsClient->InvokeAsync( @@ -389,16 +589,31 @@ bool TaskSubmitter::ScheduleRequest(const RequestResource &resource, std::shared auto invokeSpec = requestManager->GetRequest(requestId); if (invokeSpec == nullptr) { YRLOG_WARN("The request {} has been cancelled", requestId); + this->EraseFaasCancelTimer(requestId); requestQueue->Pop(); return false; } - auto [instanceId, leaseId] = insManagers[invokeSpec->functionMeta.apiType]->ScheduleIns(resource); + if (downgrade_->ShouldDowngrade(invokeSpec)) { + this->EraseFaasCancelTimer(requestId); + requestQueue->Pop(); + atomicLock.unlock(); + auto weakPtr = weak_from_this(); + downgrade_->Downgrade( + invokeSpec, [weakPtr](const std::string &requestId, Libruntime::ErrorCode code, const std::string &result) { + if (auto thisPtr = weakPtr.lock(); thisPtr) { + thisPtr->DowngradeCallback(requestId, code, result); + } + }); + return false; + } + auto [instanceId, leaseId] = insManagers[invokeSpec->functionMeta.apiType]->GetAvailableIns(resource); if (instanceId.empty()) { atomicLock.unlock(); YRLOG_DEBUG("invoke request {} can not be scheduled, instanceId is empty", requestId); bool needCreate = insManagers[invokeSpec->functionMeta.apiType]->ScaleUp(invokeSpec, requestQueueSize); return !needCreate; } + this->EraseFaasCancelTimer(requestId); requestQueue->Pop(); atomicLock.unlock(); invokeSpec->invokeInstanceId = instanceId; @@ -412,24 +627,21 @@ bool TaskSubmitter::ScheduleRequest(const RequestResource &resource, std::shared bool TaskSubmitter::CancelAndScheOtherRes(const RequestResource &resource) { absl::ReaderMutexLock lock(&reqMtx_); - auto it = waitScheduleReqMap_.find(resource); - if (it == waitScheduleReqMap_.end()) { + auto it = taskSchedulerMap_.find(resource); + if (it == taskSchedulerMap_.end()) { return true; } - if (!it->second->Empty()) { + if (!it->second->queue->Empty()) { return false; } - YRLOG_DEBUG( - "current resource req queue is empty, try scheduler other resource req. func name is {}, class " - "name is " - "{}", - resource.functionMeta.funcName, resource.functionMeta.className); + YRLOG_DEBUG("queue is empty, try schedule other queue. func name {}, class name {}", resource.functionMeta.funcName, + resource.functionMeta.className); insManagers[resource.functionMeta.apiType]->ScaleCancel(resource, 0, true); - for (auto &entry : waitScheduleReqMap_) { - if (waitScheduleReqMap_.find(entry.first) == it) { + for (auto &entry : taskSchedulerMap_) { + if (taskSchedulerMap_.find(entry.first) == it) { continue; } - if (!entry.second->Empty()) { + if (!entry.second->queue->Empty()) { taskSchedulerMap_[entry.first]->Notify(); return true; } @@ -437,6 +649,24 @@ bool TaskSubmitter::CancelAndScheOtherRes(const RequestResource &resource) return true; } +void TaskSubmitter::ScheduleFunctions() +{ + auto resources = GetScheduleResources(); + for (const auto &resource : resources) { + ScheduleFunction(resource); + } +} + +std::vector TaskSubmitter::GetScheduleResources() +{ + absl::ReaderMutexLock lock(&reqMtx_); + std::vector resources; + for (const auto &pair : taskSchedulerMap_) { + resources.push_back(pair.first); + } + return resources; +} + void TaskSubmitter::ScheduleFunction(const RequestResource &resource) { YRLOG_DEBUG("schedule resource req. func name is {}, class name is {}", resource.functionMeta.funcName, @@ -445,18 +675,39 @@ void TaskSubmitter::ScheduleFunction(const RequestResource &resource) return; } if (CancelAndScheOtherRes(resource)) { + EraseTaskSchedulerMap(resource); return; } - std::shared_ptr requestQueue; - { - absl::ReaderMutexLock lock(&reqMtx_); - requestQueue = waitScheduleReqMap_[resource]; + auto taskScheduler = GetTaskScheduler(resource); + if (taskScheduler == nullptr) { + return; } - while (!requestQueue->Empty()) { - if (auto finish = ScheduleRequest(resource, requestQueue); finish) { + while (!taskScheduler->queue->Empty()) { + if (auto finish = ScheduleRequest(resource, taskScheduler->queue); finish) { break; } } + EraseTaskSchedulerMap(resource); +} + +void TaskSubmitter::EraseTaskSchedulerMap(const RequestResource &resource) +{ + // yield make SubmitFunction can be executed first avoid frequent cleaning taskSchedulerMap + std::this_thread::yield(); + auto taskScheduler = GetTaskScheduler(resource); + if (taskScheduler == nullptr) { + return; + } + const static int currentCount = 2; + std::unique_lock atomicLock(taskScheduler->queue->atomicMtx); + if (taskScheduler->queue->Empty()) { + absl::WriterMutexLock lock(&reqMtx_); + // if there are still references elsewhere, do not erase. + if (taskScheduler.use_count() <= currentCount) { + YRLOG_DEBUG("remove taskScheduler"); + taskSchedulerMap_.erase(resource); + } + } } bool TaskSubmitter::NeedRetryCreate(const ErrorInfo &errInfo) @@ -476,6 +727,9 @@ bool TaskSubmitter::NeedRetry(const ErrorInfo &errInfo, const std::shared_ptrinvokeType == libruntime::InvokeType::InvokeFunctionStateless && (errCode == ErrorCode::ERR_INSTANCE_EVICTED || errCode == ErrorCode::ERR_INSTANCE_NOT_FOUND || errCode == ErrorCode::ERR_INSTANCE_EXITED)) { + if (spec->ExceedMaxRetryTime()) { + return false; + } isConsumeRetryTime = false; // the only case to retry without consuming return true; } @@ -560,6 +814,7 @@ ErrorInfo TaskSubmitter::CancelStatelessRequest(const std::vector & } if (cancelReq.find(requestId) != cancelReq.end()) { memoryStore->SetError(objids[i], cancelErr); + this->UpdateFaasInvokeLog(requestId, cancelErr); } } return ErrorInfo(); @@ -587,20 +842,56 @@ void TaskSubmitter::Finalize() return; } runFlag = false; - { - absl::ReaderMutexLock lock(&reqMtx_); - for (auto &pair : taskSchedulerMap_) { - pair.second->Stop(); - } - } { absl::WriterMutexLock lock(&reqMtx_); taskSchedulerMap_.clear(); - waitScheduleReqMap_.clear(); } + if (taskScheduler_) { + taskScheduler_->Stop(); + } + for (auto &pair : insManagers) { pair.second->Stop(); } + { + absl::WriterMutexLock lock(&cancelTimerMtx_); + for (auto it = faasCancelTimerWorkers.begin(); it != faasCancelTimerWorkers.end(); ++it) { + if (it->second) { + it->second->cancel(); + it->second.reset(); + } + } + faasCancelTimerWorkers.clear(); + } + if (eraseTimer_) { + eraseTimer_->cancel(); + eraseTimer_.reset(); + } +} + +std::pair TaskSubmitter::AcquireInstance(const std::string &stateId, + std::shared_ptr spec) +{ + return std::static_pointer_cast(insManagers[spec->functionMeta.apiType]) + ->AcquireInstance(stateId, spec); +} + +ErrorInfo TaskSubmitter::ReleaseInstance(const std::string &leaseId, const std::string &stateId, bool abnormal, + std::shared_ptr spec) +{ + return std::static_pointer_cast(insManagers[spec->functionMeta.apiType]) + ->ReleaseInstance(leaseId, stateId, abnormal, spec); +} + +void TaskSubmitter::UpdateFaaSSchedulerInfo(std::string schedulerFuncKey, + const std::vector &schedulerInfoList) +{ + if (insManagers[libruntime::ApiType::Faas]) { + insManagers[libruntime::ApiType::Faas]->UpdateSchedulerInfo(schedulerFuncKey, schedulerInfoList); + } + if (insManagers[libruntime::ApiType::Serve]) { + insManagers[libruntime::ApiType::Serve]->UpdateSchedulerInfo(schedulerFuncKey, schedulerInfoList); + } } void TaskSubmitter::DeleteInsCallback(const std::string &instanceId) @@ -609,5 +900,91 @@ void TaskSubmitter::DeleteInsCallback(const std::string &instanceId) fsClient->RemoveInsRtIntf(instanceId); } } + +void TaskSubmitter::UpdateSchdulerInfo(const std::string &schedulerName, const std::string &schedulerId, + const std::string &option) +{ + if (insManagers[libruntime::ApiType::Faas]) { + YRLOG_INFO("start update scheduler info, scheduler name is {}, schduler id is {}, option is {}", schedulerName, + schedulerId, option); + insManagers[libruntime::ApiType::Faas]->UpdateSchdulerInfo(schedulerName, schedulerId, option); + } +} + +YR::Libruntime::GaugeData TaskSubmitter::ConvertToGaugeData(const std::shared_ptr data, + const std::string &reqId) +{ + YR::Libruntime::GaugeData gaugeData; + gaugeData.name = "report_faas_invoke_data"; + gaugeData.labels["requestId"] = reqId; + gaugeData.labels["businessId"] = data->businessId; + gaugeData.labels["tenantId"] = data->tenantId; + gaugeData.labels["srcAppId"] = data->srcAppId; + gaugeData.labels["functionName"] = data->functionName; + gaugeData.labels["traceId"] = data->traceId; + gaugeData.labels["version"] = data->version; + gaugeData.labels["aliAs"] = data->aliAs; + gaugeData.labels["code"] = data->code; + gaugeData.labels["innerCode"] = data->innerCode; + gaugeData.labels["describeMsg"] = data->describeMsg; + gaugeData.labels["exec"] = + data->endTime - data->sendTime > 0 ? std::to_string(data->endTime - data->sendTime) : "0"; + gaugeData.labels["cost"] = + data->endTime - data->submitTime > 0 ? std::to_string(data->endTime - data->submitTime) : "0"; + return gaugeData; +} + +void TaskSubmitter::UpdateFaasInvokeSendTime(const std::string &reqId) +{ + absl::WriterMutexLock lock(&invokeDataMtx_); + if (!metricsEnable_) { + return; + } + if (faasInvokeDataMap_.find(reqId) != faasInvokeDataMap_.end()) { + faasInvokeDataMap_[reqId]->sendTime = GetCurrentTimestampNs(); + } +} + +void TaskSubmitter::UpdateFaasInvokeLog(const std::string &reqId, const ErrorInfo &err) +{ + { + absl::WriterMutexLock lock(&invokeDataMtx_); + if (!metricsEnable_) { + return; + } + } + std::shared_ptr data; + { + absl::WriterMutexLock lock(&invokeDataMtx_); + auto it = faasInvokeDataMap_.find(reqId); + if (it == faasInvokeDataMap_.end()) { + YRLOG_DEBUG("there is no invoke data of req: {}, no need update", reqId); + return; + } + it->second->endTime = GetCurrentTimestampNs(); + if (err.OK()) { + it->second->code = "200"; + } else { + if (it->second->sendTime > 0) { + it->second->code = "500"; + } else { + it->second->code = "400"; + } + } + it->second->innerCode = std::to_string(err.Code()); + auto errIt = errCodeToString.find(err.Code()); + it->second->describeMsg = errIt == errCodeToString.end() ? "UNKOWN" : errIt->second; + data = it->second; + faasInvokeDataMap_.erase(it); + } + if (this->metricsAdaptor_ && Config::Instance().ENABLE_METRICS()) { + YRLOG_DEBUG("start report faas invoke data, req id is {}, function id is {}", reqId, data->aliAs); + auto reportErr = this->metricsAdaptor_->ReportMetrics(ConvertToGaugeData(data, reqId)); + if (!reportErr.OK()) { + YRLOG_WARN("failed to report metrics, req id is {}, trace id is {}, err code is {}, msg is {}", reqId, + data->traceId, fmt::underlying(reportErr.Code()), reportErr.Msg()); + } + } +} } // namespace Libruntime } // namespace YR \ No newline at end of file diff --git a/src/libruntime/invokeadaptor/task_submitter.h b/src/libruntime/invokeadaptor/task_submitter.h index e6e03d0..8384da7 100644 --- a/src/libruntime/invokeadaptor/task_submitter.h +++ b/src/libruntime/invokeadaptor/task_submitter.h @@ -19,12 +19,15 @@ #include #include "src/dto/acquire_options.h" +#include "src/libruntime/invokeadaptor/alias_routing.h" #include "src/libruntime/invokeadaptor/instance_manager.h" #include "src/libruntime/invokeadaptor/request_manager.h" #include "src/libruntime/invokeadaptor/request_queue.h" #include "src/libruntime/invokeadaptor/task_scheduler.h" +#include "src/libruntime/metricsadaptor/metrics_adaptor.h" #include "src/libruntime/objectstore/memory_store.h" #include "src/libruntime/utils/utils.h" +#include "src/scene/downgrade.h" #include "src/utility/time_measurement.h" namespace YR { namespace Libruntime { @@ -35,10 +38,13 @@ class TaskSubmitter : public std::enable_shared_from_this { public: TaskSubmitter() = default; TaskSubmitter(std::shared_ptr config, std::shared_ptr store, - std::shared_ptr client, std::shared_ptr reqMgr, CancelFunc cancelFunc); + std::shared_ptr client, std::shared_ptr reqMgr, CancelFunc cancelFunc, + std::shared_ptr ar = nullptr, std::shared_ptr adaptor = nullptr, + std::shared_ptr downgrade = nullptr); virtual ~TaskSubmitter(); void SubmitFunction(std::shared_ptr spec); void ScheduleFunction(const RequestResource &resource); + void ScheduleFunctions(); void HandleInvokeNotify(const NotifyRequest &req, const ErrorInfo &err); void HandleFailInvokeNotify(const NotifyRequest &req, std::shared_ptr spec, const RequestResource &resource, const ErrorInfo &err); @@ -53,14 +59,36 @@ public: std::vector GetCreatingInsIds(); void UpdateConfig(); void Init(); + virtual std::pair AcquireInstance(const std::string &stateId, + std::shared_ptr spec); + virtual ErrorInfo ReleaseInstance(const std::string &leaseId, const std::string &stateId, bool abnormal, + std::shared_ptr spec); + virtual void UpdateFaaSSchedulerInfo(std::string schedulerFuncKey, + const std::vector &schedulerInstanceList); + void RecordFaasInvokeData(const std::shared_ptr spec); + void CancelFaasScheduleTimeoutReq(const std::string &reqId, int timeoutSeconds); + void EraseFaasCancelTimer(const std::string &reqId); + virtual void UpdateSchdulerInfo(const std::string &schedulerName, const std::string &schedulerId, + const std::string &option); + YR::Libruntime::GaugeData ConvertToGaugeData(const std::shared_ptr data, const std::string &reqId); + void UpdateFaasInvokeSendTime(const std::string &reqId); + void UpdateFaasInvokeLog(const std::string &reqId, const ErrorInfo &err); private: + void EraseInsResourceMap(); + std::vector GetScheduleResources(); + void AddFaasCancelTimer(std::shared_ptr spec); + void EraseTaskSchedulerMap(const RequestResource &resource); bool CancelAndScheOtherRes(const RequestResource &resource); + std::shared_ptr GetTaskScheduler(const RequestResource &resource); + std::shared_ptr GetOrAddTaskScheduler(const RequestResource &resource); void ScheduleIns(const RequestResource &resource, const ErrorInfo &err, bool isRemainIns); + void NotifyScheduler(const RequestResource &resource, const ErrorInfo &err = ErrorInfo()); bool HandleFailInvokeIsDelayScaleDown(const NotifyRequest &req, const ErrorInfo &err); void DeleteInsCallback(const std::string &instanceId); bool ScheduleRequest(const RequestResource &resource, std::shared_ptr requestQueue); void SendInvokeReq(const RequestResource &resource, std::shared_ptr invokeSpec); + void DowngradeCallback(const std::string &requestId, Libruntime::ErrorCode code, const std::string &result); std::shared_ptr libRuntimeConfig; std::atomic runFlag{true}; mutable absl::Mutex reqMtx_; @@ -72,11 +100,20 @@ private: std::unordered_map> insManagers; std::unordered_map invokeCostMap ABSL_GUARDED_BY(invokeCostMtx_); mutable absl::Mutex invokeCostMtx_; - std::unordered_map, HashFn> waitScheduleReqMap_ - ABSL_GUARDED_BY(reqMtx_); - std::unordered_map, HashFn> taskSchedulerMap_ + std::unordered_map, HashFn> taskSchedulerMap_ ABSL_GUARDED_BY(reqMtx_); + std::unordered_map> faasCancelTimerWorkers + ABSL_GUARDED_BY(cancelTimerMtx_); CancelFunc cancelCb; + std::unordered_map> faasInvokeDataMap_ ABSL_GUARDED_BY(invokeDataMtx_); + mutable absl::Mutex invokeDataMtx_; + std::shared_ptr ar_; + std::shared_ptr metricsAdaptor_; + std::atomic metricsEnable_{false}; + std::shared_ptr taskScheduler_; + bool enablePriority_ = false; + std::shared_ptr downgrade_; + std::shared_ptr eraseTimer_; }; } // namespace Libruntime diff --git a/src/libruntime/libruntime.cpp b/src/libruntime/libruntime.cpp index 600191f..1bbc1d7 100755 --- a/src/libruntime/libruntime.cpp +++ b/src/libruntime/libruntime.cpp @@ -14,17 +14,19 @@ * limitations under the License. */ +#include "src/libruntime/libruntime.h" #include -#include "re2/re2.h" #include "invoke_order_manager.h" +#include "re2/re2.h" #include "src/dto/config.h" #include "src/dto/data_object.h" #include "src/dto/status.h" #include "src/libruntime/err_type.h" -#include "src/libruntime/fmclient/fm_client.h" #include "src/libruntime/fsclient/fs_client.h" +#include "src/libruntime/generator/stream_generator_notifier.h" +#include "src/libruntime/generator/stream_generator_receiver.h" +#include "src/libruntime/gwclient/gw_client.h" #include "src/libruntime/invokeadaptor/request_manager.h" -#include "src/libruntime/libruntime.h" #include "src/libruntime/metricsadaptor/metrics_adaptor.h" #include "src/libruntime/objectstore/memory_store.h" #include "src/libruntime/utils/serializer.h" @@ -38,6 +40,7 @@ const int MAX_INS_ID_LENGTH = 64; const std::string DELEGATE_DIRECTORY_QUOTA = "DELEGATE_DIRECTORY_QUOTA"; const std::string DELEGATE_DIRECTORY_INFO = "DELEGATE_DIRECTORY_INFO"; const std::string DEFALUT_DELEGATE_DIRECTORY_INFO = "/tmp"; +const std::string FAAS_INSTANCE_TYPE = "faas"; const std::string ACTOR_INSTANCE_TYPE = "actor"; const char *DEFAULT_DELEGATE_DIRECTORY_QUOTA = "512"; // 512MB const int MAX_DELEGATE_DIRECTORY_QUOTA = 1024 * 1024; // 1TB @@ -46,6 +49,7 @@ const re2::RE2 POD_LABELS_KEY_REGEX("^[a-zA-Z0-9]([-a-zA-Z0-9]{0,61}[a-zA-Z0-9]) const re2::RE2 POD_LABELS_VALUE_REGEX("^[a-zA-Z0-9]([-a-zA-Z0-9]{0,61}[a-zA-Z0-9])?$|^$"); const std::string DISPATCHER = "dis"; const size_t NUM_DISPATCHER = 2; +thread_local std::string threadLocalTraceId; Libruntime::Libruntime(std::shared_ptr librtCfg, std::shared_ptr clientsMgr, std::shared_ptr metricsAdaptor, std::shared_ptr security, @@ -67,6 +71,7 @@ ErrorInfo Libruntime::Init(std::shared_ptr fsClient, YR::Libruntime::D this->dispatcherThread_ = std::make_shared(); this->dispatcherThread_->Init(NUM_DISPATCHER, config->jobId + "." + DISPATCHER); this->dsClients.dsObjectStore = datasystemClients.dsObjectStore; + this->dsClients.dsStreamStore = datasystemClients.dsStreamStore; this->dsClients.dsStateStore = datasystemClients.dsStateStore; this->dsClients.dsHeteroStore = datasystemClients.dsHeteroStore; this->waitingObjectManager = std::make_shared(config->checkSignals_); @@ -75,18 +80,32 @@ ErrorInfo Libruntime::Init(std::shared_ptr fsClient, YR::Libruntime::D this->waitingObjectManager->SetMemoryStore(this->memStore); this->dependencyResolver = std::make_shared(memStore); this->objectIdPool = std::make_shared(memStore); + auto mapper = std::make_shared(); + this->generatorNotifier_ = std::make_shared(datasystemClients.dsStreamStore, mapper); + this->generatorReceiver_ = + std::make_shared(config, datasystemClients.dsStreamStore, this->memStore); this->rGroupManager_ = std::make_shared(); - this->invokeAdaptor = - std::make_shared(config, dependencyResolver, fsClient, memStore, runtimeContext, cb, - waitingObjectManager, invokeOrderMgr, clientsMgr, metricsAdaptor); + auto functionId = config->functionIds[config->selfLanguage]; + if (functionId.empty()) { + functionId = config->functionName; + } + this->downgrade_ = std::make_shared(functionId, clientsMgr, security_); + this->invokeAdaptor = std::make_shared( + config, dependencyResolver, fsClient, memStore, runtimeContext, cb, waitingObjectManager, invokeOrderMgr, + clientsMgr, metricsAdaptor, mapper, generatorReceiver_, generatorNotifier_, downgrade_); invokeAdaptor->SetRGroupManager(rGroupManager_); auto setTenantIdCb = [this]() { this->SetTenantIdWithPriority(); }; invokeAdaptor->SetCallbackOfSetTenantId(setTenantIdCb); auto [serverVersion, err] = invokeAdaptor->Init(*runtimeContext, security_); - if (err.OK()) { - this->config->serverVersion = serverVersion; + if (!err.OK()) { + return err; } - return err; + this->config->serverVersion = serverVersion; + if (config->logToDriver) { + this->driverLogReceiver_ = std::make_shared(); + this->driverLogReceiver_->Init(datasystemClients.dsStreamStore, config->jobId, config->dedupLogs); + } + return ErrorInfo(); } void Libruntime::FinalizeHandler() @@ -267,7 +286,7 @@ ErrorInfo Libruntime::PreProcessArgs(const std::shared_ptr &spec) if (tmpTotalSize > uint64_t(Config::Instance().MAX_ARGS_IN_MSG_BYTES())) { auto [err, objId] = Put(arg.dataObj, arg.nestedObjects); if (err.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("Put arg, error code: {}, error message: {}", err.Code(), err.Msg()); + YRLOG_ERROR("Put arg, code: {}, message: {}", fmt::underlying(err.Code()), err.Msg()); return err; } YRLOG_DEBUG("Put arg, object ID: {}", objId); @@ -282,6 +301,7 @@ ErrorInfo Libruntime::PreProcessArgs(const std::shared_ptr &spec) } objIds.assign(objIdSet.begin(), objIdSet.end()); if (!objIds.empty()) { + SetTraceId(); err = memStore->IncreaseObjRef(objIds); if (!err.OK()) { YRLOG_ERROR("increase ids[{}, ....] failed", objIds[0]); @@ -319,8 +339,8 @@ std::pair Libruntime::CreateInstance(const YR::Libruntim std::vector returnObjs{DataObject("")}; auto err = GenerateReturnObjectIds(requestId, returnObjs); if (err.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("generate return obj id failed, req id: {}, error code: {}, error message: {}", requestId, - err.Code(), err.Msg()); + YRLOG_ERROR("generate return obj id failed, req id: {}, code: {}, message: {}", requestId, + fmt::underlying(err.Code()), err.Msg()); return std::make_pair(err, ""); } std::string traceId = ConstructTraceId(opts); @@ -329,14 +349,14 @@ std::pair Libruntime::CreateInstance(const YR::Libruntim std::move(traceId), std::move(requestId), "", opts); err = this->CheckSpec(spec); if (err.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("check invoke spec failed, req id: {}, error code: {}, error message: {}", requestId, err.Code(), - err.Msg()); + YRLOG_ERROR("check invoke spec failed, req id: {}, code: {}, message: {}", requestId, + fmt::underlying(err.Code()), err.Msg()); return std::make_pair(err, ""); } err = PreProcessArgs(spec); if (err.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("pre process args failed, req id: {}, error code: {}, error message: {}", requestId, err.Code(), - err.Msg()); + YRLOG_ERROR("pre process args failed, req id: {}, code: {}, message: {}", requestId, + fmt::underlying(err.Code()), err.Msg()); return std::make_pair(err, ""); } @@ -348,8 +368,12 @@ std::pair Libruntime::CreateInstance(const YR::Libruntim memStore->AddReturnObject(spec->returnIds); dependencyResolver->ResolveDependencies(spec, [this, spec, returnObjs](const ErrorInfo &err) { if (err.OK()) { + AddGeneratorReceiver(spec); if (PutRefArgToDs(spec)) { - spec->opts.labels.push_back(ACTOR_INSTANCE_TYPE); + if (std::find(spec->opts.labels.begin(), spec->opts.labels.end(), FAAS_INSTANCE_TYPE) == + spec->opts.labels.end()) { + spec->opts.labels.push_back(ACTOR_INSTANCE_TYPE); + } spec->BuildInstanceCreateRequest(*config); this->invokeAdaptor->CreateInstance(spec); } @@ -362,10 +386,11 @@ std::pair Libruntime::CreateInstance(const YR::Libruntim ProcessErr(spec, dependencyErr); auto ids = memStore->UnbindObjRefInReq(spec->requestId); + SetTraceId(); auto errorInfo = memStore->DecreGlobalReference(ids); if (!errorInfo.OK()) { YRLOG_WARN("failed to decrease by requestid {}. Code: {}, MCode: {}, Msg: {}", spec->requestId, - errorInfo.Code(), errorInfo.MCode(), errorInfo.Msg()); + fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), errorInfo.Msg()); } invokeOrderMgr->RemoveInstance(spec); }); @@ -379,6 +404,7 @@ bool Libruntime::PutRefArgToDs(std::shared_ptr spec) auto &arg = spec->invokeArgs[i]; if (arg.isRef) { this->SetTenantId(arg.tenantId); + SetTraceId(); errInfo = this->memStore->AlsoPutToDS(arg.objId); if (!errInfo.OK()) { break; @@ -386,6 +412,7 @@ bool Libruntime::PutRefArgToDs(std::shared_ptr spec) } if (!arg.nestedObjects.empty()) { this->SetTenantId(arg.tenantId); + SetTraceId(); errInfo = this->memStore->AlsoPutToDS(arg.nestedObjects); if (!errInfo.OK()) { break; @@ -394,7 +421,7 @@ bool Libruntime::PutRefArgToDs(std::shared_ptr spec) } if (!errInfo.OK()) { YRLOG_ERROR("put ref arg to ds failed, reqid is {}, err code is {}, err msg is {}", spec->requestId, - errInfo.Code(), errInfo.Msg()); + fmt::underlying(errInfo.Code()), errInfo.Msg()); ProcessErr(spec, errInfo); return false; } @@ -409,7 +436,7 @@ ErrorInfo Libruntime::InvokeByInstanceId(const YR::Libruntime::FunctionMeta &fun auto err = GenerateReturnObjectIds(requestId, returnObjs); if (err.Code() != ErrorCode::ERR_OK) { YRLOG_ERROR("generate return obj id failed, req id: {}, error code: {}, error message: {}", requestId, - err.Code(), err.Msg()); + fmt::underlying(err.Code()), err.Msg()); return err; } std::string traceId = ConstructTraceId(opts); @@ -418,10 +445,16 @@ ErrorInfo Libruntime::InvokeByInstanceId(const YR::Libruntime::FunctionMeta &fun auto spec = std::make_shared(runtimeContext->GetJobId(), funcMeta, returnObjs, std::move(invokeArgs), libruntime::InvokeType::InvokeFunction, std::move(traceId), std::move(requestId), instanceId, opts); + if (spec->opts.isGetInstance) { + YRLOG_DEBUG("this is not normal member function invoke, redefine invoke type, name is {}, ns is {}", + funcMeta.name, funcMeta.ns); + spec->invokeType = libruntime::InvokeType::GetNamedInstanceMeta; + spec->invokeInstanceId = instanceId; + } err = PreProcessArgs(spec); if (err.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("pre process args failed, req id: {}, code: {}, message: {}", spec->requestId, err.Code(), - err.Msg()); + YRLOG_ERROR("pre process failed, req id: {}, code: {}, message: {}", spec->requestId, + fmt::underlying(err.Code()), err.Msg()); return err; } @@ -433,18 +466,18 @@ ErrorInfo Libruntime::InvokeByInstanceId(const YR::Libruntime::FunctionMeta &fun objIds.push_back(obj.id); } } - YRLOG_DEBUG("start increase ds global reference, req id is {} , obj ids: [{}, ...]", spec->requestId, - objIds[0]); + YRLOG_DEBUG("start increase ds global reference, req id is {}, obj ids: [{}, ...]", spec->requestId, objIds[0]); auto errInfo = memStore->IncreDSGlobalReference(objIds); if (!errInfo.OK()) { YRLOG_ERROR("failed to increase ds global reference, req id is {}, error code is {}, error msg is {}", - spec->requestId, errInfo.Code(), errInfo.Msg()); + spec->requestId, fmt::underlying(errInfo.Code()), errInfo.Msg()); } } invokeOrderMgr->Invoke(spec); auto func = [this, spec, returnObjs](const ErrorInfo &err) { if (err.OK()) { invokeOrderMgr->UpdateUnfinishedSeq(spec); + AddGeneratorReceiver(spec); if (PutRefArgToDs(spec)) { auto namedId = spec->GetNamedInstanceId(); if (namedId.empty()) { @@ -477,7 +510,7 @@ ErrorInfo Libruntime::InvokeByInstanceId(const YR::Libruntime::FunctionMeta &fun auto errorInfo = memStore->DecreGlobalReference(ids); if (!errorInfo.OK()) { YRLOG_WARN("failed to decrease by requestid {}. Code: {}, MCode: {}, Msg: {}", spec->requestId, - errorInfo.Code(), errorInfo.MCode(), errorInfo.Msg()); + fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), errorInfo.Msg()); } }; dependencyResolver->ResolveDependencies(spec, [func, dispatcherThread(dispatcherThread_)](const ErrorInfo &err) { @@ -515,7 +548,8 @@ std::string Libruntime::GetGroupInstanceIds(const std::string &objectId, int tim { auto [instanceIds, err] = memStore->GetInstanceIds(objectId, timeout); if (!err.OK()) { - YRLOG_WARN("get group instance ids failed, error code: {}, error message: {}", err.Code(), err.Msg()); + YRLOG_WARN("get group instance ids failed, error code: {}, error message: {}", fmt::underlying(err.Code()), + err.Msg()); return ""; } return YR::utility::Join(instanceIds, ";"); @@ -578,6 +612,7 @@ ErrorInfo Libruntime::InvokeByFunctionName(const YR::Libruntime::FunctionMeta &f this->invokeAdaptor->PushInvokeSpec(spec); auto func = [this, spec](const ErrorInfo &err) { if (err.OK()) { + AddGeneratorReceiver(spec); if (PutRefArgToDs(spec)) { spec->BuildInstanceInvokeRequest(*config); this->invokeAdaptor->SubmitFunction(spec); @@ -593,7 +628,7 @@ ErrorInfo Libruntime::InvokeByFunctionName(const YR::Libruntime::FunctionMeta &f auto errorInfo = memStore->DecreGlobalReference(ids); if (!errorInfo.OK()) { YRLOG_WARN("failed to decrease by requestid {}. Code: {}, MCode: {}, Msg: {}", spec->requestId, - errorInfo.Code(), errorInfo.MCode(), errorInfo.Msg()); + fmt::underlying(errorInfo.Code()), fmt::underlying(errorInfo.MCode()), errorInfo.Msg()); } }; dependencyResolver->ResolveDependencies(spec, [func, dispatcherThread(dispatcherThread_)](const ErrorInfo &err) { @@ -606,9 +641,20 @@ ErrorInfo Libruntime::InvokeByFunctionName(const YR::Libruntime::FunctionMeta &f void Libruntime::ProcessErr(const std::shared_ptr &spec, const ErrorInfo &errInfo) { + if (spec->functionMeta.isGenerator && generatorReceiver_) { + generatorReceiver_->MarkEndOfStream(spec->returnIds[0].id, errInfo); + } memStore->SetError(spec->returnIds, errInfo); } +void Libruntime::AddGeneratorReceiver(std::shared_ptr spec) +{ + if (spec->functionMeta.isGenerator && generatorReceiver_) { + generatorReceiver_->Initialize(); + generatorReceiver_->AddRecord(spec->returnIds[0].id); + } +} + void Libruntime::CreateInstanceRaw(std::shared_ptr reqRaw, RawCallback cb) { this->invokeAdaptor->CreateInstanceRaw(reqRaw, cb); @@ -634,6 +680,7 @@ std::pair Libruntime::Put(std::shared_ptr da if (!err.OK()) { return std::make_pair(err, objId); } + SetTraceId(); err = memStore->Put(dataobj->buffer, objId, nestedIds, createParam); return std::make_pair(err, objId); } @@ -641,6 +688,7 @@ std::pair Libruntime::Put(std::shared_ptr da ErrorInfo Libruntime::Put(const std::string &objId, std::shared_ptr dataObj, const std::unordered_set &nestedId, const CreateParam &createParam) { + SetTraceId(); return memStore->Put(dataObj->buffer, objId, nestedId, createParam); } @@ -648,6 +696,7 @@ ErrorInfo Libruntime::Put(std::shared_ptr data, const std::string &objID const std::unordered_set &nestedID, bool toDataSystem, const CreateParam &createParam) { + SetTraceId(); auto err = memStore->Put(data, objID, nestedID, toDataSystem, createParam); if (!err.OK()) { return err; @@ -662,17 +711,20 @@ ErrorInfo Libruntime::PutRaw(const std::string &objId, std::shared_ptr d if (!dsClients.dsObjectStore) { return ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "PutRaw dsClients.dsObjectStore is nullptr!"); } + SetTraceId(); return dsClients.dsObjectStore->Put(data, objId, nestedId, createParam); } -ErrorInfo Libruntime::IncreaseReference(const std::vector &objIds) +ErrorInfo Libruntime::IncreaseReference(const std::vector &objIds, bool toDatasystem) { - return memStore->IncreGlobalReference(objIds); + SetTraceId(); + return memStore->IncreGlobalReference(objIds, toDatasystem); } std::pair> Libruntime::IncreaseReference(const std::vector &objIds, const std::string &remoteId) { + SetTraceId(); return memStore->IncreGlobalReference(objIds, remoteId); } @@ -684,6 +736,7 @@ ErrorInfo Libruntime::IncreaseReferenceRaw(const std::vector &objId if (!dsClients.dsObjectStore) { return ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "IncreaseReferenceRaw dsObjectStore is nullptr!"); } + SetTraceId(); return dsClients.dsObjectStore->IncreGlobalReference(objIds); } @@ -698,6 +751,7 @@ std::pair> Libruntime::IncreaseReferenceRaw( ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "IncreaseReferenceRaw dsObjectStore is nullptr!"), std::vector()); } + SetTraceId(); return dsClients.dsObjectStore->IncreGlobalReference(objIds, remoteId); } @@ -707,9 +761,11 @@ void Libruntime::DecreaseReference(const std::vector &objIds) std::cerr << "Libruntime::DecreaseReference memStore is nullptr." << std::endl; return; } + SetTraceId(); ErrorInfo err = memStore->DecreGlobalReference(objIds); if (err.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("ErrCode: {}, ModuleCode: {}, ErrMsg: {}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("ErrCode: {}, ModuleCode: {}, ErrMsg: {}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); } return; } @@ -717,6 +773,7 @@ void Libruntime::DecreaseReference(const std::vector &objIds) std::pair> Libruntime::DecreaseReference(const std::vector &objIds, const std::string &remoteId) { + SetTraceId(); return memStore->DecreGlobalReference(objIds, remoteId); } @@ -729,9 +786,11 @@ void Libruntime::DecreaseReferenceRaw(const std::vector &objIds) YRLOG_ERROR("DecreaseReferenceRaw dsObjectStore is nullptr!"); return; } + SetTraceId(); ErrorInfo err = dsClients.dsObjectStore->DecreGlobalReference(objIds); if (err.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("ErrCode: {}, ModuleCode: {}, ErrMsg: {}", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("ErrCode: {}, ModuleCode: {}, ErrMsg: {}", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); } return; } @@ -747,9 +806,16 @@ std::pair> Libruntime::DecreaseReferenceRaw( ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "DecreaseReferenceRaw dsObjectStore is nullptr!"), std::vector()); } + SetTraceId(); return dsClients.dsObjectStore->DecreGlobalReference(objIds, remoteId); } +ErrorInfo Libruntime::ReleaseGRefs(const std::string &remoteId) +{ + SetTraceId(); + return memStore->ReleaseGRefs(remoteId); +} + // timeout < 0 : wait without timeout std::shared_ptr Libruntime::Wait(const std::vector &objs, std::size_t waitNum, int timeoutSec) @@ -842,7 +908,7 @@ std::pair>> Libruntime::Get(c if (!err.OK()) { return std::make_pair(err, std::vector>{}); } - + SetTraceId(); MultipleResult res = memStore->Get(ids, remainingTimePeriod); res = MakeGetResult(res, ids, timeoutMs, allowPartial); std::vector> result(ids.size()); @@ -862,6 +928,7 @@ std::pair>> Libruntime::GetRaw(co return std::make_pair(ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "GetRaw dsObjectStore is nullptr!"), std::vector>()); } + SetTraceId(); MultipleResult res = dsClients.dsObjectStore->Get(ids, timeoutMs); return MakeGetResult(res, ids, timeoutMs, allowPartial); } @@ -886,6 +953,7 @@ ErrorInfo Libruntime::AllocReturnObject(DataObject *returnObj, size_t metaSize, totalNativeBufferSize += bufferSize; dataBuf = std::make_shared(bufferSize); } else { + SetTraceId(); auto err = memStore->IncreGlobalReference({returnObj->id}, true); if (!err.OK()) { return err; @@ -897,9 +965,8 @@ ErrorInfo Libruntime::AllocReturnObject(DataObject *returnObj, size_t metaSize, } err = CreateBuffer(returnObj->id, bufferSize, dataBuf); if (!err.OK()) { - YRLOG_ERROR( - "Failed to create return value, object Id: {}, data size: {}, error code: {}, error message: {}.", - returnObj->id, dataSize, err.Code(), err.Msg()); + YRLOG_ERROR("Failed to create return value, object Id: {}, data size: {}, ec : {}, err message: {}.", + returnObj->id, dataSize, fmt::underlying(err.Code()), err.Msg()); return err; } } @@ -915,6 +982,7 @@ ErrorInfo Libruntime::AllocReturnObject(DataObject *returnObj, size_t metaSize, ErrorInfo Libruntime::CreateBuffer(const std::string &objectId, size_t dataSize, std::shared_ptr &dataBuf) { + SetTraceId(); return memStore->CreateBuffer(objectId, dataSize, dataBuf); } @@ -922,6 +990,7 @@ std::pair Libruntime::CreateBuffer(size_t dataSize, std: { // small data -> MemoryStore // Get an id from pool + SetTraceId(); auto [err, objectId] = objectIdPool->Pop(); if (!err.OK()) { return std::make_pair(err, objectId); @@ -935,11 +1004,12 @@ std::pair>> Libruntime::GetBuffer { auto [errWait, remainingTimePeriod] = WaitBeforeGet(ids, timeoutMs, allowPartial); if (!errWait.OK()) { - YRLOG_ERROR("Failed to WaitBeforeGet, ids: {}, error code: {}, error message: {}", - YR::utility::Join(std::vector{ids[0]}, "..."), errWait.Code(), errWait.Msg()); + YRLOG_DEBUG("Failed to WaitBeforeGet, ids: {}, code: {}, message: {}", + YR::utility::Join(std::vector{ids[0]}, "..."), fmt::underlying(errWait.Code()), + errWait.Msg()); return std::make_pair(errWait, std::vector>{}); } - + SetTraceId(); auto [err, results] = memStore->GetBuffers(ids, remainingTimePeriod); if (err.Code() == YR::Libruntime::ErrorCode::ERR_OK) { YRLOG_DEBUG("Succeeded to GetBuffers, ids:{}, ids size: {}, results size: {}", @@ -953,8 +1023,8 @@ std::pair>> Libruntime::GetBuffer YR::Libruntime::ModuleCode::RUNTIME, checkObjPartialResult.second); } } else { - YRLOG_ERROR("Failed to GetBuffers, ids: {}, error code: {}, error message: {}", - YR::utility::Join(std::vector{ids[0]}, "..."), err.Code(), err.Msg()); + YRLOG_ERROR("Failed to GetBuffers, ids: {}, code: {}, message: {}", + YR::utility::Join(std::vector{ids[0]}, "..."), fmt::underlying(err.Code()), err.Msg()); } return std::make_pair(err, results); @@ -977,6 +1047,7 @@ std::pair>> Libruntime::GetDa std::pair>> Libruntime::GetBuffersWithoutWait( const std::vector &ids, int timeoutMS) { + SetTraceId(); return memStore->GetBuffersWithoutRetry(ids, timeoutMS); } @@ -985,6 +1056,7 @@ std::pair Libruntime::CreateDataObject(size_t metaSize, const std::vector &nestedObjIds, const CreateParam &createParam) { + SetTraceId(); auto [err, objId] = objectIdPool->Pop(); if (!err.OK()) { return std::make_pair(err, objId); @@ -1002,12 +1074,13 @@ ErrorInfo Libruntime::CreateDataObject(const std::string &objId, size_t metaSize return ErrorInfo(ErrorCode::ERR_PARAM_INVALID, "check circular references detected, obj id: " + objId); } } - - auto ret = Wait(nestedObjIds, nestedObjIds.size(), DEFAULT_TIMEOUT_SEC); - if (!ret->unreadyIds.empty() || !ret->exceptionIds.empty()) { - return ErrorInfo(ErrorCode::ERR_USER_FUNCTION_EXCEPTION, "wait nested objects timeout or exception"); + if (nestedObjIds.size() != 0) { + auto ret = Wait(nestedObjIds, nestedObjIds.size(), DEFAULT_TIMEOUT_SEC); + if (!ret->unreadyIds.empty() || !ret->exceptionIds.empty()) { + return ErrorInfo(ErrorCode::ERR_USER_FUNCTION_EXCEPTION, "wait nested objects timeout or exception"); + } } - + SetTraceId(); ErrorInfo dsErr = memStore->AlsoPutToDS(nestedObjIds, createParam); if (dsErr.Code() != ErrorCode::ERR_OK) { YRLOG_ERROR("put nested obj to datasystem error"); @@ -1024,8 +1097,8 @@ ErrorInfo Libruntime::CreateDataObject(const std::string &objId, size_t metaSize } auto err = memStore->CreateBuffer(objId, metaSize + dataSize, buf, createParam); if (!err.OK()) { - YRLOG_ERROR("Failed to create dataObject, object Id: {}, data size: {}, error code: {}, error message: {}.", - dataObj->id, dataSize, err.Code(), err.Msg()); + YRLOG_ERROR("Failed to create dataObject, object Id: {}, data size: {}, code: {}, message: {}.", dataObj->id, + dataSize, fmt::underlying(err.Code()), err.Msg()); return err; } if (buf) { @@ -1040,10 +1113,11 @@ ErrorInfo Libruntime::CreateDataObject(const std::string &objId, size_t metaSize std::pair>> Libruntime::GetDataObjects( const std::vector &ids, int timeoutMs, bool allowPartial) { + SetTraceId(); auto [err, buffers] = GetBuffers(ids, timeoutMs, allowPartial); if (!err.OK()) { - YRLOG_ERROR("Failed to GetDataObjects, ids: {}, error code: {}, error message: {}", - YR::utility::Join(std::vector{ids[0]}, "..."), err.Code(), err.Msg()); + YRLOG_DEBUG("Failed to GetDataObjects, ids: {}, code: {}, message: {}", + YR::utility::Join(std::vector{ids[0]}, "..."), fmt::underlying(err.Code()), err.Msg()); return std::make_pair(err, std::vector>{}); } std::vector> result(ids.size()); @@ -1058,73 +1132,146 @@ std::pair>> Libruntime::GetDa ErrorInfo Libruntime::KVWrite(const std::string &key, std::shared_ptr value, SetParam setParam) { + SetTraceId(); return dsClients.dsStateStore->Write(key, value, setParam); } ErrorInfo Libruntime::KVMSetTx(const std::vector &keys, const std::vector> &vals, const MSetParam &mSetParam) { + SetTraceId(); return dsClients.dsStateStore->MSetTx(keys, vals, mSetParam); } SingleReadResult Libruntime::KVRead(const std::string &key, int timeoutMS) { + SetTraceId(); return dsClients.dsStateStore->Read(key, timeoutMS); } MultipleReadResult Libruntime::KVRead(const std::vector &keys, int timeoutMS, bool allowPartial) { + SetTraceId(); return dsClients.dsStateStore->Read(keys, timeoutMS, allowPartial); } MultipleReadResult Libruntime::KVGetWithParam(const std::vector &keys, const GetParams ¶ms, int timeoutMs) { + SetTraceId(); return dsClients.dsStateStore->GetWithParam(keys, params, timeoutMs); } ErrorInfo Libruntime::KVDel(const std::string &key) { + SetTraceId(); return dsClients.dsStateStore->Del(key); } MultipleDelResult Libruntime::KVDel(const std::vector &keys) { + SetTraceId(); return dsClients.dsStateStore->Del(keys); } -ErrorInfo Libruntime::Delete(const std::vector &objectIds, std::vector &failedObjectIds) +MultipleExistResult Libruntime::KVExist(const std::vector &keys) +{ + return dsClients.dsStateStore->Exist(keys); +} + +ErrorInfo Libruntime::DevDelete(const std::vector &objectIds, std::vector &failedObjectIds) { - return dsClients.dsHeteroStore->Delete(objectIds, failedObjectIds); + SetTraceId(); + return dsClients.dsHeteroStore->DevDelete(objectIds, failedObjectIds); } -ErrorInfo Libruntime::LocalDelete(const std::vector &objectIds, std::vector &failedObjectIds) +ErrorInfo Libruntime::DevLocalDelete(const std::vector &objectIds, + std::vector &failedObjectIds) { - return dsClients.dsHeteroStore->LocalDelete(objectIds, failedObjectIds); + SetTraceId(); + return dsClients.dsHeteroStore->DevLocalDelete(objectIds, failedObjectIds); } ErrorInfo Libruntime::DevSubscribe(const std::vector &keys, const std::vector &blob2dList, std::vector> &futureVec) { + SetTraceId(); return dsClients.dsHeteroStore->DevSubscribe(keys, blob2dList, futureVec); } ErrorInfo Libruntime::DevPublish(const std::vector &keys, const std::vector &blob2dList, std::vector> &futureVec) { + SetTraceId(); return dsClients.dsHeteroStore->DevPublish(keys, blob2dList, futureVec); } ErrorInfo Libruntime::DevMSet(const std::vector &keys, const std::vector &blob2dList, std::vector &failedKeys) { + SetTraceId(); return dsClients.dsHeteroStore->DevMSet(keys, blob2dList, failedKeys); } ErrorInfo Libruntime::DevMGet(const std::vector &keys, const std::vector &blob2dList, std::vector &failedKeys, int32_t timeoutSec) { - return dsClients.dsHeteroStore->DevMGet(keys, blob2dList, failedKeys, ToMs(timeoutSec)); + SetTraceId(); + return dsClients.dsHeteroStore->DevMGet(keys, blob2dList, failedKeys, timeoutSec * S_TO_MS); +} + +ErrorInfo Libruntime::CreateStreamProducer(const std::string &streamName, ProducerConf producerConf, + std::shared_ptr &producer) +{ + producer = std::make_shared(); + if (dsClients.dsStreamStore == nullptr) { + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "the datasystem stream store is empty, outcluster operations are not supported by now"); + } + SetTraceId(producerConf.traceId); + return dsClients.dsStreamStore->CreateStreamProducer(streamName, producer, producerConf); +} + +ErrorInfo Libruntime::CreateStreamConsumer(const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck) +{ + consumer = std::make_shared(); + if (dsClients.dsStreamStore == nullptr) { + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "the datasystem stream store is empty, outcluster operations are not supported by now"); + } + SetTraceId(config.traceId); + return dsClients.dsStreamStore->CreateStreamConsumer(streamName, config, consumer, autoAck); +} + +ErrorInfo Libruntime::DeleteStream(const std::string &streamName) +{ + if (dsClients.dsStreamStore == nullptr) { + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "the datasystem stream store is empty, outcluster operations are not supported by now"); + } + SetTraceId(); + return dsClients.dsStreamStore->DeleteStream(streamName); +} + +ErrorInfo Libruntime::QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum) +{ + if (dsClients.dsStreamStore == nullptr) { + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "the datasystem stream store is empty, outcluster operations are not supported by now"); + } + SetTraceId(); + return dsClients.dsStreamStore->QueryGlobalProducersNum(streamName, gProducerNum); +} + +ErrorInfo Libruntime::QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum) +{ + if (dsClients.dsStreamStore == nullptr) { + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "the datasystem stream store is empty, outcluster operations are not supported by now"); + } + SetTraceId(); + return dsClients.dsStreamStore->QueryGlobalConsumersNum(streamName, gConsumerNum); } std::string Libruntime::GetInvokingRequestId(void) @@ -1140,12 +1287,19 @@ ErrorInfo Libruntime::Cancel(const std::vector &objids, bool isForc void Libruntime::Exit(void) { // exit data system - invokeAdaptor->Exit(); + invokeAdaptor->Exit(0, ""); +} + +void Libruntime::Exit(const int code, const std::string &message) +{ + // exit data system + invokeAdaptor->Exit(code, message); } ErrorInfo Libruntime::Kill(const std::string &instanceId, int sigNo) { auto realInsId = memStore->GetInstanceId(instanceId); + invokeAdaptor->EraseFsIntf(realInsId); return invokeAdaptor->Kill(realInsId, "", sigNo); } @@ -1158,6 +1312,16 @@ ErrorInfo Libruntime::Kill(const std::string &instanceId, int sigNo, std::shared void Libruntime::Finalize(bool isDriver) { + if (generatorNotifier_ != nullptr) { + generatorNotifier_.reset(); + } + if (generatorReceiver_ != nullptr) { + generatorReceiver_->Stop(); + generatorReceiver_.reset(); + } + if (driverLogReceiver_ != nullptr) { + driverLogReceiver_.reset(); + } if (memStore) { memStore->Clear(); } @@ -1167,6 +1331,9 @@ void Libruntime::Finalize(bool isDriver) if (dsClients.dsStateStore != nullptr) { dsClients.dsStateStore.reset(); } + if (dsClients.dsStreamStore != nullptr) { + dsClients.dsStreamStore.reset(); + } if (!config->inCluster) { auto err = clientsMgr->ReleaseHttpClient(config->functionSystemIpAddr, config->functionSystemPort); if (!err.OK()) { @@ -1269,7 +1436,12 @@ ErrorInfo Libruntime::CreateStateStore(const DsConnectOptions &opts, std::shared ErrorInfo Libruntime::SetTraceId(const std::string &traceId) { - datasystem::Status rc = datasystem::Context::SetTraceId(traceId); + datasystem::Status rc; + if (!traceId.empty()) { + rc = datasystem::Context::SetTraceId(traceId); + } else if (!threadLocalTraceId.empty()) { + rc = datasystem::Context::SetTraceId(threadLocalTraceId); + } if (rc.IsError()) { return ErrorInfo(ConvertDatasystemErrorToCore(rc.GetCode()), YR::Libruntime::ModuleCode::DATASYSTEM, rc.ToString()); @@ -1294,7 +1466,7 @@ ErrorInfo Libruntime::SetTenantId(const std::string &tenantId, bool isReturnErrW } this->dsClients.dsObjectStore->SetTenantId(tenantId); this->config->tenantId = tenantId; - YRLOG_DEBUG("succeed to set tenant id"); + YRLOG_DEBUG("succeed to set tenant id, tenant id is {}", tenantId); return ErrorInfo(); } @@ -1316,12 +1488,14 @@ std::string Libruntime::GetTenantId() ErrorInfo Libruntime::GenerateKeyByStateStore(std::shared_ptr stateStore, std::string &returnKey) { + SetTraceId(); return stateStore->GenerateKey(returnKey); } ErrorInfo Libruntime::SetByStateStore(std::shared_ptr stateStore, const std::string &key, std::shared_ptr nativeBuffer, SetParam setParam) { + SetTraceId(); return stateStore->Write(key, nativeBuffer, setParam); } @@ -1329,12 +1503,14 @@ ErrorInfo Libruntime::SetValueByStateStore(std::shared_ptr stateStor std::shared_ptr nativeBuffer, SetParam setParam, std::string &returnKey) { + SetTraceId(); return stateStore->Write(nativeBuffer, setParam, returnKey); } SingleReadResult Libruntime::GetByStateStore(std::shared_ptr stateStore, const std::string &key, int timeoutMs) { + SetTraceId(); return stateStore->Read(key, timeoutMs); } @@ -1342,17 +1518,27 @@ MultipleReadResult Libruntime::GetArrayByStateStore(std::shared_ptr const std::vector &keys, int timeoutMs, bool allowPartial) { + SetTraceId(); return stateStore->Read(keys, timeoutMs, allowPartial); } +ErrorInfo Libruntime::QuerySizeByStateStore(std::shared_ptr stateStore, + const std::vector &keys, std::vector &outSizes) +{ + SetTraceId(); + return stateStore->QuerySize(keys, outSizes); +} + ErrorInfo Libruntime::DelByStateStore(std::shared_ptr stateStore, const std::string &key) { + SetTraceId(); return stateStore->Del(key); } MultipleDelResult Libruntime::DelArrayByStateStore(std::shared_ptr stateStore, const std::vector &keys) { + SetTraceId(); return stateStore->Del(keys); } @@ -1419,6 +1605,19 @@ ErrorInfo Libruntime::SetAlarm(const std::string &name, const std::string &descr return err; } +std::pair Libruntime::AcquireInstance(const std::string &stateId, + const FunctionMeta &functionMeta, + InvokeOptions &opts) +{ + return invokeAdaptor->AcquireInstance(stateId, functionMeta, opts); +} + +ErrorInfo Libruntime::ReleaseInstance(const std::string &leaseId, const std::string &stateId, bool abnormal, + InvokeOptions &opts) +{ + return invokeAdaptor->ReleaseInstance(leaseId, stateId, abnormal, opts); +} + ErrorInfo Libruntime::ProcessLog(FunctionLog &functionLog) { functionLog.set_instanceid(Config::Instance().INSTANCE_ID()); @@ -1439,6 +1638,22 @@ void Libruntime::NotifyEvent(FiberEventNotify &event) event.Notify(); } +std::pair Libruntime::PeekObjectRefStream(const std::string &generatorId, bool blocking) +{ + return memStore->GetOutput(generatorId, blocking); +} + +ErrorInfo Libruntime::NotifyGeneratorResult(const std::string &generatorId, int index, + std::shared_ptr resultObj, const ErrorInfo &resultErr) +{ + return generatorNotifier_->NotifyResult(generatorId, index, resultObj, resultErr); +} + +ErrorInfo Libruntime::NotifyGeneratorFinished(const std::string &generatorId, int numResults) +{ + return generatorNotifier_->NotifyFinished(generatorId, numResults); +} + FunctionGroupRunningInfo Libruntime::GetFunctionGroupRunningInfo() { return config->groupRunningInfo; @@ -1468,6 +1683,14 @@ std::pair Libruntime::QueryNamedInstances() return ret; } +Credential Libruntime::GetCredential() +{ + if (!this->security_) { + return Credential{}; + } + return this->security_->GetCredential(); +} + ErrorInfo Libruntime::CheckRGroupName(const std::string &rGroupName) { if (rGroupName == std::string(UNSUPPORTED_RGROUP_NAME) || rGroupName == "") { @@ -1512,7 +1735,8 @@ ErrorInfo Libruntime::CreateResourceGroup(const ResourceGroupSpec &resourceGroup YRLOG_ERROR( "check resource group create options failed, name: {}, bundles size: {}, request id: {}, error code: {}, " "error message: {}.", - resourceGroupSpec.name, resourceGroupSpec.bundles.size(), requestId, err.Code(), err.Msg()); + resourceGroupSpec.name, resourceGroupSpec.bundles.size(), requestId, fmt::underlying(err.Code()), + err.Msg()); return err; } @@ -1550,9 +1774,9 @@ std::pair Libruntime::GetInstance(const const std::string &nameSpace, int timeoutSec) { auto [meta, err] = this->invokeAdaptor->GetInstance(name, nameSpace, timeoutSec); - if (err.OK() && meta.needOrder) { - this->invokeOrderMgr->RegisterInstance( - nameSpace.empty() ? this->config->ns + "-" + name : nameSpace + "-" + name); + if (!err.OK() || !meta.needOrder) { + this->invokeOrderMgr->ClearInsOrderMsg(nameSpace.empty() ? name : nameSpace + "-" + name, + libruntime::Signal::KillInstance); } return std::make_pair<>(meta, err); } @@ -1571,7 +1795,7 @@ bool Libruntime::IsLocalInstances(const std::vector &instanceIds) } auto killErr = this->invokeAdaptor->Kill(instanceId, "", libruntime::Signal::QueryDsAddress); if (!killErr.OK()) { - YRLOG_WARN("kill QueryDsAddress code: {}, msg: {}", err.Code(), err.Msg()); + YRLOG_WARN("kill QueryDsAddress code: {}, msg: {}", fmt::underlying(err.Code()), err.Msg()); promise->set_value(false); return; } @@ -1613,6 +1837,14 @@ bool Libruntime::SetError(const std::string &objId, const ErrorInfo &err) return this->memStore->SetError(objId, err); } +void Libruntime::UpdateSchdulerInfo(const std::string &schedulerName, const std::string &schedulerId, + const std::string &option) +{ + YRLOG_DEBUG("start update scheduler info, scheduler name is {}, scheduler id is {}, options is {}", schedulerName, + schedulerId, option); + return invokeAdaptor->UpdateSchdulerInfo(schedulerName, schedulerId, option); +} + std::string Libruntime::GetInstanceRoute(const std::string &objectId) { return memStore->GetInstanceRoute(objectId); @@ -1623,6 +1855,33 @@ void Libruntime::SaveInstanceRoute(const std::string &objectId, const std::strin memStore->SetInstanceRoute(objectId, instanceRoute); } +bool Libruntime::IsHealth() +{ + if (!invokeAdaptor) { + return false; + } + return invokeAdaptor->IsHealth(); +} + +bool Libruntime::IsDsHealth() +{ + if (!dsClients.dsStreamStore && !dsClients.dsStateStore) { + return true; + } + auto err = dsClients.dsStateStore->HealthCheck(); + if (err.OK()) { + return true; + } + return false; +} + +void Libruntime::KillAsync(const std::string &instanceId, int sigNo, std::function cb) +{ + auto realInsId = memStore->GetInstanceId(instanceId); + invokeAdaptor->EraseFsIntf(realInsId); + this->invokeAdaptor->KillAsyncCB(realInsId, "", sigNo, cb); +} + std::pair Libruntime::GetNodeId() { ErrorInfo err; @@ -1636,5 +1895,6 @@ std::string Libruntime::GetNameSpace() { return this->config->ns; } + } // namespace Libruntime } // namespace YR diff --git a/src/libruntime/libruntime.h b/src/libruntime/libruntime.h index f942e8d..eb0f1c2 100644 --- a/src/libruntime/libruntime.h +++ b/src/libruntime/libruntime.h @@ -25,18 +25,24 @@ #include "src/dto/invoke_arg.h" #include "src/dto/invoke_options.h" #include "src/dto/resource_unit.h" +#include "src/dto/stream_conf.h" #include "src/libruntime/clientsmanager/clients_manager.h" #include "src/libruntime/connect/domain_socket_client.h" #include "src/libruntime/connect/message_coder.h" #include "src/libruntime/dependency_resolver.h" +#include "src/libruntime/driverlog/driverlog_receiver.h" #include "src/libruntime/err_type.h" +#include "src/libruntime/fiber.h" #include "src/libruntime/event_notify.h" #include "src/libruntime/fmclient/fm_client.h" #include "src/libruntime/fsclient/fs_client.h" +#include "src/libruntime/generator/generator_notifier.h" +#include "src/libruntime/generator/generator_receiver.h" #include "src/libruntime/invokeadaptor/invoke_adaptor.h" #include "src/libruntime/libruntime_config.h" #include "src/libruntime/metricsadaptor/metrics_adaptor.h" #include "src/libruntime/objectstore/object_id_pool.h" +#include "src/libruntime/streamstore/stream_producer_consumer.h" #include "src/libruntime/utils/constants.h" #include "src/libruntime/utils/security.h" #include "src/libruntime/waiting_object_manager.h" @@ -214,7 +220,7 @@ public: @param objIds A vector of strings containing the IDs of the objects @return An `ErrorInfo` object indicating the success or failure of the operation */ - virtual ErrorInfo IncreaseReference(const std::vector &objIds); + virtual ErrorInfo IncreaseReference(const std::vector &objIds, bool toDatasystem = true); /*! @brief Increases the reference count of the specified objects with a remote ID @@ -274,6 +280,13 @@ public: virtual std::pair> DecreaseReferenceRaw(const std::vector &objIds, const std::string &remoteId); + /*! + @brief Releases global references associated with a remote ID + @param remoteId The ID of the remote context + @return An `ErrorInfo` object indicating the success or failure of the operation + */ + virtual ErrorInfo ReleaseGRefs(const std::string &remoteId); + /*! @brief Allocates and initializes a return object with specified metadata and data sizes This function allocates a `DataObject` and initializes it with the provided metadata size, data size, @@ -387,6 +400,12 @@ public: */ virtual void Exit(void); + /*! + @brief 退出当前上下文,调用该方法会执行函数系统客户端及数据系统客户端的优雅退出,清理相应的函数实例及数据对象 + @throw Exception if the exit operation fails + */ + virtual void Exit(const int code, const std::string &message); + /*! @brief 向一个函数实例或一组任务发送一个指定的信号 @param instanceId 指定的函数实例 ID 或者任务 ID,当前仅信号2支持传入任务 ID @@ -404,6 +423,14 @@ public: */ virtual ErrorInfo Kill(const std::string &instanceId, int sigNo, std::shared_ptr data); + /*! + @brief 向一个函数实例或一组任务发送一个指定的信号并携带特定的数据,并通过callback返回执行结果 + @param instanceId 指定的函数实例 ID 或者任务 ID + @param sigNo 需要发送的信号,信号1为退出函数实例,详见 @ref Signal + @param cb 执行结果的回调函数 + */ + virtual void KillAsync(const std::string &instanceId, int sigNo, std::function cb); + /*! @brief 结束当前上下文 @param isDriver 如果设置为 True,将会退出当前任务对应的所有函数实例 @@ -538,6 +565,56 @@ public: */ virtual MultipleDelResult KVDel(const std::vector &keys); + /*! + @brief Deletes multiple key-value pairs from the datasystem + @param keys A vector of strings containing the keys to query + @return A `MultipleExistResult` object containing the results of the exists and error information + */ + virtual MultipleExistResult KVExist(const std::vector &keys); + + /*! + @brief Creates a stream producer with the specified configuration + @param streamName The name of the stream + @param producerConf Configuration for the stream producer + @param producer A shared pointer to the `StreamProducer` to be created + @return An `ErrorInfo` object indicating the success or failure of the operation + */ + virtual ErrorInfo CreateStreamProducer(const std::string &streamName, ProducerConf producerConf, + std::shared_ptr &producer); + + /*! + @brief Creates a stream consumer with the specified configuration + @param streamName The name of the stream + @param config Configuration for the stream consumer + @param consumer A shared pointer to the `StreamConsumer` to be created + @param autoAck If true, automatically acknowledges consumed messages (default is false) + @return An `ErrorInfo` object indicating the success or failure of the operation + */ + virtual ErrorInfo CreateStreamConsumer(const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck = false); + + /*! + @brief Deletes a stream by its name + @param streamName The name of the stream to delete + @return An `ErrorInfo` object indicating the success or failure of the operation + */ + virtual ErrorInfo DeleteStream(const std::string &streamName); + + /*! + @brief Query the number of global producers for a specific stream + @param streamName the name of the stream + @param gProducerNum the output parameter to store the number of global producers + @return ErrorInfo indicating the success or failure of the operation + */ + virtual ErrorInfo QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum); + + /*! + @brief Query the number of global consumers for a specific stream + @param streamName the name of the stream + @param gConsumerNum the output parameter to store the number of global consumers + @return ErrorInfo indicating the success or failure of the operation + */ + virtual ErrorInfo QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum); virtual ErrorInfo SaveState(const std::shared_ptr data, const int &timeout); virtual ErrorInfo LoadState(std::shared_ptr &data, const int &timeout); @@ -614,7 +691,7 @@ public: @param traceId the trace ID to set @return ErrorInfo indicating success or failure */ - virtual ErrorInfo SetTraceId(const std::string &traceId); + virtual ErrorInfo SetTraceId(const std::string &traceId = ""); /*! @brief Set the tenant ID for the current context @@ -652,6 +729,8 @@ public: virtual MultipleReadResult GetArrayByStateStore(std::shared_ptr stateStore, const std::vector &keys, int timeoutMs, bool allowPartial = false); + virtual ErrorInfo QuerySizeByStateStore(std::shared_ptr stateStore, + const std::vector &keys, std::vector &outSizes); virtual ErrorInfo DelByStateStore(std::shared_ptr stateStore, const std::string &key); virtual MultipleDelResult DelArrayByStateStore(std::shared_ptr stateStore, const std::vector &keys); @@ -736,6 +815,28 @@ public: virtual ErrorInfo SetAlarm(const std::string &name, const std::string &description, const YR::Libruntime::AlarmInfo &alarmInfo); + /*! + @brief Acquire an instance based on the state ID and function metadata + @param stateId the state ID of the instance + @param functionMeta the metadata of the function + @param opts the invoke options + @return a pair containing the instance allocation and ErrorInfo + */ + virtual std::pair AcquireInstance(const std::string &stateId, + const FunctionMeta &functionMeta, + InvokeOptions &opts); + + /*! + @brief Release an instance based on the lease ID and state ID + @param leaseId the lease ID of the instance + @param stateId the state ID of the instance + @param abnormal whether the release is due to an abnormal condition + @param opts the invoke options + @return ErrorInfo indicating the success or failure of the operation + */ + virtual ErrorInfo ReleaseInstance(const std::string &leaseId, const std::string &stateId, bool abnormal, + InvokeOptions &opts); + /*! @brief Process a function log @param functionLog the function log to process @@ -765,6 +866,34 @@ public: */ virtual void NotifyEvent(FiberEventNotify &event); + /*! + @brief Peek into an ObjectRef stream to get the next value + @param generatorId the ID of the generator + @param blocking whether to block until a value is available + @return a pair containing ErrorInfo and the next value in the stream + */ + virtual std::pair PeekObjectRefStream(const std::string &generatorId, bool blocking); + + /*! + @brief Notify the result of a generator + @param generatorId the ID of the generator + @param index the index of the result + @param resultObj the result object + @param resultErr the error information associated with the result + @return ErrorInfo indicating the success or failure of the operation + */ + virtual ErrorInfo NotifyGeneratorResult(const std::string &generatorId, int index, + std::shared_ptr resultObj, const ErrorInfo &resultErr); + + /*! + @brief Notify that a generator has finished producing results + @param generatorId the ID of the generator + @param numResults the number of results produced by the generator + @return ErrorInfo indicating the success or failure of the operation + @throw Exception if the generator does not exist or the operation fails + */ + virtual ErrorInfo NotifyGeneratorFinished(const std::string &generatorId, int numResults); + /*! @brief Get the running information of a function group @return the running information of the function group @@ -788,22 +917,56 @@ public: */ virtual std::pair GetNodeIpAddress(void); + /*! + @brief Get the credential information + @return the credential information + @throw Exception if the credential cannot be retrieved + */ + virtual Credential GetCredential(); + + /** + * @brief For device object, to async get multiple objects + * @param[in] objIds multiple keys support + * @param[out] devBlobList vector of blobs, only modify the data pointed to by the pointer. + * @param[in] timeoutMs max waiting time of getting data + * @return future of AsyncResult, describe get ErrorInfo and failed list. + */ + + /** + * @brief For device object Async set multiple objects, and return before publish rpc called. + * @param[in] objectIds multiple keys support + * @param[in] devBlobList vector of blobs + * @return future of AsyncResult, describe set ErrorInfo and failed list. + */ + /** * @brief Invoke worker client to delete all the given objectId. * @param[in] objectIds The vector of the objId. * @param[out] failedObjectIds The failed delete objIds. * @return ERR_OK on any key success; the error code otherwise. */ - virtual ErrorInfo Delete(const std::vector &objectIds, std::vector &failedObjectIds); + virtual ErrorInfo DevDelete(const std::vector &objectIds, std::vector &failedObjectIds); /** - * @brief LocalDelete interface. After calling this interface, the data replica stored in the data system by the + * @brief DevLocalDelete interface. After calling this interface, the data replica stored in the data system by the * current client connection will be deleted. * @param[in] objectIds The objectIds of the data expected to be deleted. * @param[out] failedObjectIds Partial failures will be returned through this parameter. * @return ERR_OK on when return success; the error code otherwise. */ - virtual ErrorInfo LocalDelete(const std::vector &objectIds, std::vector &failedObjectIds); + virtual ErrorInfo DevLocalDelete(const std::vector &objectIds, + std::vector &failedObjectIds); + + /** + * @brief Initialize multipath transfer for a given target device (current context) + * @param[in] devices The device ids participating in the multipath transfer + * @return ErrorInfo of the call. + */ + + /** + * @brief Destroys multipath transfer for the current target device (context) + * @return ErrorInfo of the call. + */ /** * @brief Subscribe data from device. @@ -884,6 +1047,9 @@ public: virtual bool SetError(const std::string &objId, const ErrorInfo &err); + virtual void UpdateSchdulerInfo(const std::string &schedulerName, const std::string &schedulerId, + const std::string &option); + /*! @brief Get the instance route of an object @param objectId the ID of the object @@ -899,6 +1065,10 @@ public: */ virtual void SaveInstanceRoute(const std::string &objectId, const std::string &instanceRoute); + virtual bool IsHealth(); + + virtual bool IsDsHealth(); + std::pair GetNodeId(void); std::string GetNameSpace(void); @@ -916,6 +1086,7 @@ private: void ProcessErr(const std::shared_ptr &spec, const ErrorInfo &errInfo); ErrorInfo CheckRGroupName(const std::string &vGroupName); ErrorInfo CheckRGroupSpec(const ResourceGroupSpec &resourceGroupSpec); + void AddGeneratorReceiver(std::shared_ptr spec); ErrorInfo GenerateReturnObjectIds(const std::string &requestId, std::vector &returnObjs); ErrorInfo CreateBuffer(const std::string &objectId, size_t dataSize, std::shared_ptr &dataBuf); @@ -942,10 +1113,14 @@ private: std::shared_ptr security_; std::shared_ptr runtimeContext; std::shared_ptr socketClient_; + std::shared_ptr generatorNotifier_; + std::shared_ptr generatorReceiver_; + std::shared_ptr driverLogReceiver_; std::function checkSignals_ = nullptr; std::shared_ptr messageCoder_; std::shared_ptr rGroupManager_; std::shared_ptr dispatcherThread_; + std::shared_ptr downgrade_; }; } // namespace Libruntime diff --git a/src/libruntime/libruntime_config.h b/src/libruntime/libruntime_config.h index fb8b0b1..78dc6d7 100644 --- a/src/libruntime/libruntime_config.h +++ b/src/libruntime/libruntime_config.h @@ -34,6 +34,7 @@ const int DEFAULT_RECYCLETIME = 2; const int MAX_RECYCLETIME = 3000; const int MIN_RECYCLETIME = 1; const int MAX_PASSWD_LENGTH = 100; +const int HTTP_IDLE_TIME = 120; std::pair GetValidMaxLogFileNum(uint32_t maxLogFileNum); std::pair GetValidMaxLogSizeMb(uint32_t maxLogSizeMb); struct LibruntimeConfig { @@ -46,9 +47,7 @@ struct LibruntimeConfig { this->enableMetrics = config.enablemetrics(); this->threadPoolSize = config.threadpoolsize(); this->localThreadPoolSize = config.localthreadpoolsize(); - if (!config.ns().empty()) { - this->ns = config.ns(); - } + this->ns = config.ns(); this->tenantId = config.tenantid(); for (int i = 0; i < config.functionids_size(); i++) { this->functionIds[config.functionids(i).language()] = config.functionids(i).functionid(); @@ -201,15 +200,15 @@ struct LibruntimeConfig { libruntime::ApiType selfApiType = libruntime::ApiType::Function; std::string logLevel = ""; std::string logDir = "."; - uint32_t logFileSizeMax = 1024; - uint32_t logFileNumMax = 1024; + uint32_t logFileSizeMax = 0; + uint32_t logFileNumMax = 0; int logFlushInterval = 1; bool isLogMerge = false; LibruntimeOptions libruntimeOptions; int recycleTime = 2; int maxTaskInstanceNum = -1; int maxConcurrencyCreateNum = 100; - bool enableMetrics = false; + bool enableMetrics = true; uint32_t threadPoolSize = 0; uint32_t localThreadPoolSize = 0; std::vector loadPaths; @@ -217,10 +216,14 @@ struct LibruntimeConfig { std::string metaConfig = ""; // deprecated? pyx use it. bool enableMTLS = false; + bool enableTLS = false; std::string privateKeyPath = ""; std::string certificateFilePath = ""; std::string verifyFilePath = ""; + char privateKeyPaaswd[MAX_PASSWD_LENGTH] = {0}; + std::shared_ptr tlsContext = nullptr; uint32_t httpIocThreadsNum = 200; + int httpIdleTime = HTTP_IDLE_TIME; std::string serverName = ""; bool inCluster = true; std::string ns = ""; @@ -236,6 +239,9 @@ struct LibruntimeConfig { std::string runtimePublicKey = ""; SensitiveValue runtimePrivateKey = ""; std::string dsPublicKey = ""; + SensitiveValue token = ""; + std::string ak_ = ""; + SensitiveValue sk_ = ""; using SubmitHook = std::function &&)>; SubmitHook funcExecSubmitHook = nullptr; @@ -245,11 +251,19 @@ struct LibruntimeConfig { std::function checkSignals_ = nullptr; std::string workingDir; std::string rtCtx; + bool logToDriver = false; + bool dedupLogs = false; std::string primaryKeyStoreFile; std::string standbyKeyStoreFile; + std::string encryptPrivateKeyPasswd; + std::string encryptDsPublicKeyContext; + std::string encryptRuntimePublicKeyContext; + std::string encryptRuntimePrivateKeyContext; libruntime::FunctionMeta funcMeta; bool needOrder = false; bool enableSigaction = true; + int64_t invokeTimeoutSec = 0; + uint32_t maxConnSize = 10000; std::string nodeId; std::string nodeIp; private: diff --git a/src/libruntime/libruntime_manager.cpp b/src/libruntime/libruntime_manager.cpp index 4cbf800..9768986 100644 --- a/src/libruntime/libruntime_manager.cpp +++ b/src/libruntime/libruntime_manager.cpp @@ -21,6 +21,7 @@ #include "src/libruntime/utils/constants.h" #include "src/utility/logger/log_handler.h" #include "src/utility/timer_worker.h" +#include "src/libruntime/traceadaptor/trace_adapter.h" namespace YR { namespace Libruntime { using YR::utility::CloseGlobalTimer; @@ -86,12 +87,13 @@ ErrorInfo LibruntimeManager::HandleInitialized(const LibruntimeConfig &config, c return ErrorInfo(); } } - YRLOG_INFO("merge config, selfLanguage: {} {}", config.selfLanguage, librtConfig->selfLanguage); + YRLOG_INFO("merge config, selfLanguage: {} {}", fmt::underlying(config.selfLanguage), + fmt::underlying(librtConfig->selfLanguage)); for (auto it = config.functionIds.begin(); it != config.functionIds.end(); it++) { - YRLOG_INFO("merge config, functionId {} : {}", it->first, it->second); + YRLOG_INFO("merge config, functionId {} : {}", fmt::underlying(it->first), it->second); } for (auto it = librtConfig->functionIds.begin(); it != librtConfig->functionIds.end(); it++) { - YRLOG_INFO("merge config, functionId {} : {}", it->first, it->second); + YRLOG_INFO("merge config, functionId {} : {}", fmt::underlying(it->first), it->second); } return librtConfig->MergeConfig(config); } @@ -99,7 +101,7 @@ ErrorInfo LibruntimeManager::HandleInitialized(const LibruntimeConfig &config, c LibruntimeManager::LibruntimeManager() { clientsMgr = std::make_shared(); - metricsAdaptor = std::make_shared(); + metricsAdaptor = MetricsAdaptor::GetInstance(); socketClient_ = std::make_shared(std::string(DEFAULT_SOCKET_PATH)); logManager_ = std::make_shared(); } @@ -108,8 +110,8 @@ ErrorInfo LibruntimeManager::Init(const LibruntimeConfig &config, const std::str { auto err = config.Check(); if (!err.OK()) { - YRLOG_ERROR("config check failed, job id is {}, err code is {}, err msg is {}", config.jobId, err.Code(), - err.Msg()); + YRLOG_ERROR("config check failed, job id is {}, err code is {}, err msg is {}", config.jobId, + fmt::underlying(err.Code()), err.Msg()); return err; } @@ -143,20 +145,24 @@ ErrorInfo LibruntimeManager::Init(const LibruntimeConfig &config, const std::str logParam.logBufSecs = config.logFlushInterval; auto result = GetValidMaxLogSizeMb(config.logFileSizeMax); if (!result.second.OK()) { - YRLOG_ERROR("invalid log file size max: {}, err code is {}, err msg is {}", result.first, result.second.Code(), - result.second.Msg()); + YRLOG_ERROR("invalid log file size max: {}, err code is {}, err msg is {}", result.first, + fmt::underlying(result.second.Code()), result.second.Msg()); return result.second; } logParam.maxSize = result.first; result = GetValidMaxLogFileNum(config.logFileNumMax); if (!result.second.OK()) { - YRLOG_ERROR("invalid log file num: {}, err code is {}, err msg is {}", result.first, result.second.Code(), - result.second.Msg()); + YRLOG_ERROR("invalid log file num: {}, err code is {}, err msg is {}", result.first, + fmt::underlying(result.second.Code()), result.second.Msg()); return result.second; } logParam.maxFiles = result.first; logParam.nodeName = config.jobId; logParam.modelName = config.runtimeId; + logParam.loggerId = YR::Libruntime::Config::Instance().YR_LOG_PREFIX(); + if (!logParam.loggerId.empty()) { + logParam.withLogPrefix = true; + } logParam.isLogMerge = config.isLogMerge; this->isLogMerge = config.isLogMerge; InitLog(logParam); @@ -176,7 +182,9 @@ ErrorInfo LibruntimeManager::Init(const LibruntimeConfig &config, const std::str SetGetLoggerNameFunc(getLoggerNameFunc); YRLOG_INFO("Job ID: {}, runtime ID: {}, log dir: {}, log level is {}, is Driver {}", config.jobId, config.runtimeId, config.logDir, config.logLevel, config.isDriver); - + auto traceServiceName = config.runtimeId == "driver" ? config.jobId : Config::Instance().INSTANCE_ID(); + TraceAdapter::GetInstance().InitTrace(traceServiceName, Config::Instance().ENABLE_TRACE(), + Config::Instance().RUNTIME_TRACE_CONFIG()); if (config.enableSigaction) { InstallSigtermHandler(); } @@ -203,11 +211,11 @@ ErrorInfo LibruntimeManager::Init(const LibruntimeConfig &config, const std::str std::shared_ptr librt; auto initErr = this->CreateLibruntime(librtConfig, librt); if (initErr.OK()) { - YRLOG_INFO("succeed to init libruntime, job ID: {}", config.jobId); + YRLOG_INFO("succeed to init libruntime, job ID: {}, tenant ID: {}", config.jobId, librt->GetTenantId()); librts[rtCtx] = librt; } else { - YRLOG_ERROR("failed to init libruntime, job Id: {}, code: {}, msg: {}", config.jobId, initErr.Code(), - initErr.Msg()); + YRLOG_ERROR("failed to init libruntime, job Id: {}, tenant ID: {}, code: {}, msg: {}", config.jobId, + librt->GetTenantId(), fmt::underlying(initErr.Code()), initErr.Msg()); std::lock_guard rtCfgLK(rtCfgMtx); librtConfigs.erase(rtCtx); } @@ -230,8 +238,8 @@ ErrorInfo LibruntimeManager::CreateLibruntime(std::shared_ptr if (librtConfig->inCluster) { auto err = !librtConfig->isDriver ? security->Init() : security->InitWithDriver(librtConfig); if (!err.OK()) { - YRLOG_ERROR("init security failed, is driver: {}, code is {}, msg is {}", librtConfig->isDriver, err.Code(), - err.Msg()); + YRLOG_ERROR("init security failed, is driver: {}, code is {}, msg is {}", librtConfig->isDriver, + fmt::underlying(err.Code()), err.Msg()); return err; } } @@ -239,19 +247,63 @@ ErrorInfo LibruntimeManager::CreateLibruntime(std::shared_ptr librtConfig->runtimePublicKey, librtConfig->runtimePrivateKey, librtConfig->dsPublicKey); librtConfig->enableAuth = enableDsAuth; librtConfig->encryptEnable = encryptEnable; + security->GetToken(librtConfig->token); + std::string ak = ""; + SensitiveValue sk = ""; + security->GetAKSK(ak, sk); auto finalizeHandler = [this, rtCtx(librtConfig->rtCtx)]() { this->Finalize(rtCtx); }; librt = std::make_shared(librtConfig, clientsMgr, metricsAdaptor, security, socketClient_); if (librtConfig->inCluster) { auto [datasystemClients, err] = - clientsMgr->GetOrNewDsClient(librtConfig, Config::Instance().DS_CONNECT_TIMEOUT_SEC()); + clientsMgr->GetOrNewDsClient(librtConfig, ak, sk, Config::Instance().DS_CONNECT_TIMEOUT_SEC()); if (!err.OK()) { - YRLOG_ERROR("get or new ds client failed, code is {}, msg is {}", err.Code(), err.Msg()); + YRLOG_ERROR("get or new ds client failed, code is {}, msg is {}", fmt::underlying(err.Code()), err.Msg()); return err; } + security->WhenTokenUpdated([datasystemClients](const datasystem::SensitiveValue &token) -> void { + auto errInfo = datasystemClients.dsObjectStore->UpdateToken(token); + if (!errInfo.OK()) { + YRLOG_ERROR("update token failed, code is {}", fmt::underlying(errInfo.Code())); + } + errInfo = datasystemClients.dsStateStore->UpdateToken(token); + if (!errInfo.OK()) { + YRLOG_ERROR("update token failed, code is {}", fmt::underlying(errInfo.Code())); + } + errInfo = datasystemClients.dsStreamStore->UpdateToken(token); + if (!errInfo.OK()) { + YRLOG_ERROR("update token failed, code is {}", fmt::underlying(errInfo.Code())); + } + }); + security->WhenAkSkUpdated( + [datasystemClients](const std::string &ak, const datasystem::SensitiveValue &sk) -> void { + auto errInfo = datasystemClients.dsObjectStore->UpdateAkSk(ak, sk); + if (!errInfo.OK()) { + YRLOG_ERROR("update aksk failed, code is {}", fmt::underlying(errInfo.Code())); + } + errInfo = datasystemClients.dsStateStore->UpdateAkSk(ak, sk); + if (!errInfo.OK()) { + YRLOG_ERROR("update aksk failed, code is {}", fmt::underlying(errInfo.Code())); + } + errInfo = datasystemClients.dsStreamStore->UpdateAkSk(ak, sk); + if (!errInfo.OK()) { + YRLOG_ERROR("update aksk failed, code is {}", fmt::underlying(errInfo.Code())); + } + }); auto fsClient = std::make_shared(); return librt->Init(fsClient, datasystemClients, finalizeHandler); } else { - return ErrorInfo(); + FSIntfHandlers handlers; + auto [httpClient, err] = clientsMgr->GetOrNewHttpClient(librtConfig->functionSystemIpAddr, + librtConfig->functionSystemPort, librtConfig); + if (!err.OK()) { + YRLOG_ERROR("get or new http client failed, code is {}, msg is {}", fmt::underlying(err.Code()), err.Msg()); + return err; + } + auto gwClient = std::make_shared(librtConfig->functionIds[librtConfig->selfLanguage], handlers); + gwClient->Init(httpClient, Config::Instance().DS_CONNECT_TIMEOUT_SEC()); + auto fsClient = std::make_shared(gwClient); + DatasystemClients dsClients{gwClient, gwClient, gwClient, gwClient}; + return librt->Init(fsClient, dsClients); } } @@ -285,6 +337,7 @@ void LibruntimeManager::Finalize(const std::string &rtCtx) librtConfigs.erase(rtCtx); } librt->Finalize(librtConfig->isDriver); + TraceAdapter::GetInstance().ShutDown(); librt.reset(); librtConfig.reset(); { @@ -295,6 +348,7 @@ void LibruntimeManager::Finalize(const std::string &rtCtx) } } YRLOG_INFO("finish to finalize libruntime with context: {}", rtCtx); + YR::utility::SpdLogger::GetInstance().Flush(); } std::shared_ptr LibruntimeManager::GetLibRuntime(const std::string &rtCtx) @@ -345,6 +399,10 @@ void LibruntimeManager::InstallSigtermHandler() YRLOG_ERROR("Failed Install SIGTERM handler"); return; } + if (sigaction(SIGINT, &sa, NULL) == -1) { + YRLOG_ERROR("Failed Install SIGINT handler"); + return; + } YRLOG_INFO("Succeeded to Install SIGTERM handler"); return; } @@ -366,7 +424,7 @@ void LibruntimeManager::ExecShutdownCallbackAsync(int signum) void LibruntimeManager::ExecShutdownCallback(int signum, bool needExit) { uint64_t gracePeriodSec = Config::Instance().GRACEFUL_SHUTDOWN_TIME(); - YRLOG_DEBUG("Start to execute SigtermHandler, graceful shutdown time: {}", gracePeriodSec); + YRLOG_INFO("Start to execute SigtermHandler, graceful shutdown time: {}", gracePeriodSec); std::unordered_map> librtsCopied; { /* The user code executed in the 'ExecShutdownCallback' may invoke the Libruntime @@ -386,9 +444,9 @@ void LibruntimeManager::ExecShutdownCallback(int signum, bool needExit) errInfo.Msg()); continue; } - YRLOG_DEBUG("Succeeded to call ExecShutdownCallback for libruntime with context: {}", iter->first); + YRLOG_INFO("Succeeded to call ExecShutdownCallback for libruntime with context: {}", iter->first); } - YRLOG_DEBUG("End to call SigtermHandler, signum: {}", signum); + YRLOG_INFO("End to call SigtermHandler, signum: {}", signum); if (needExit) { exit(signum); } diff --git a/src/libruntime/metricsadaptor/metrics_adaptor.cpp b/src/libruntime/metricsadaptor/metrics_adaptor.cpp index 49f7a5c..5bb01b7 100644 --- a/src/libruntime/metricsadaptor/metrics_adaptor.cpp +++ b/src/libruntime/metricsadaptor/metrics_adaptor.cpp @@ -23,24 +23,34 @@ #include "metrics/sdk/immediately_export_processor.h" #include "metrics/sdk/meter_provider.h" #include "src/dto/config.h" -#include "src/utility/logger/fileutils.h" #include "src/utility/logger/logger.h" +#include "src/utility/logger/fileutils.h" namespace YR { namespace Libruntime { const char *const IMMEDIATELY_EXPORT = "immediatelyExport"; const char *const FILE_EXPORTER = "fileExporter"; +const char *const PROMETHEUS_PUSH_EXPORTER = "prometheusPushExporter"; +const char *const AOM_ALARM_EXPORTER = "aomAlarmExporter"; +const char *const YR_SSL_PASSPHRASE_KEY = "YR_SSL_PASSPHRASE"; static std::string GetLibraryPath(const std::string &exporterType) { std::string filePath = ""; if (exporterType == FILE_EXPORTER) { - filePath = Config::Instance().SNLIB_PATH() + "/libobservability-metrics-file-exporter.so"; + filePath = Config::Instance().SNUSER_LIB_PATH() + "/libobservability-metrics-file-exporter.so"; + } else if (exporterType == PROMETHEUS_PUSH_EXPORTER) { + filePath = Config::Instance().SNUSER_LIB_PATH() + "/libobservability-prometheus-push-exporter.so"; + } else if (exporterType == AOM_ALARM_EXPORTER) { + filePath = Config::Instance().SNUSER_LIB_PATH() + "/libobservability-aom-alarm-exporter.so"; } YRLOG_INFO("exporter {} get library path: {}", exporterType, filePath); return filePath; } +std::once_flag MetricsAdaptor::initFlag; +std::shared_ptr MetricsAdaptor::instance = nullptr; + MetricsAdaptor::MetricsAdaptor() {} void MetricsAdaptor::Init(const nlohmann::json &json, bool userEnable) @@ -87,7 +97,6 @@ void MetricsAdaptor::InitImmediatelyExport(const std::shared_ptr( std::move(exporter), exportConfigs); mp->AddMetricProcessor(std::move(processor)); + } else if (key == PROMETHEUS_PUSH_EXPORTER || key == AOM_ALARM_EXPORTER) { + auto &&exporter = InitHttpExporter(key, IMMEDIATELY_EXPORT, backendName, value); + if (exporter == nullptr) { + YRLOG_ERROR("Failed to init exporter {}", key); + continue; + } + Initialized_ = true; + auto exportConfigs = BuildExportConfigs(value); + exportConfigs.exporterName = key; + exportConfigs.exportMode = MetricsSdk::ExportMode::IMMEDIATELY; + auto processor = + std::make_shared(std::move(exporter), exportConfigs); + mp->AddMetricProcessor(std::move(processor)); } else { YRLOG_WARN("unknown exporter name: {}", key); } @@ -174,6 +196,14 @@ std::shared_ptr MetricsAdaptor::InitHttpExporter(con initConfigJson["rootCertFile"] = Config::Instance().YR_SSL_ROOT_FILE(); initConfigJson["certFile"] = Config::Instance().YR_SSL_CERT_FILE(); initConfigJson["keyFile"] = Config::Instance().YR_SSL_KEY_FILE(); + auto ret = std::getenv(YR_SSL_PASSPHRASE_KEY); + if (ret != nullptr) { + initConfigJson["passphrase"] = ret; + const int replaceOpt = 1; + setenv(YR_SSL_PASSPHRASE_KEY, "", replaceOpt); + } else { + YRLOG_WARN("can not get metrics passphrase from env."); + } } try { initConfig = initConfigJson.dump(); @@ -210,7 +240,7 @@ const MetricsSdk::ExportConfigs MetricsAdaptor::BuildExportConfigs(const nlohman } if (exporterValue.contains("enabledInstruments")) { for (auto &it : exporterValue.at("enabledInstruments").items()) { - YRLOG_INFO("Enabled instrument: {}", it.value()); + YRLOG_INFO("Enabled instrument: {}", it.value().dump()); exportConfigs.enabledInstruments.insert(it.value().get()); } } @@ -550,6 +580,7 @@ ErrorInfo MetricsAdaptor::ReportAlarm(const std::string &name, const std::string "can not find alarm name"); } MetricsApi::AlarmInfo metricsAlarmInfo; + metricsAlarmInfo.id = alarmInfo.id; metricsAlarmInfo.alarmName = alarmInfo.alarmName; metricsAlarmInfo.alarmSeverity = static_cast(alarmInfo.alarmSeverity); metricsAlarmInfo.locationInfo = alarmInfo.locationInfo; @@ -611,7 +642,7 @@ std::shared_ptr MetricsAdaptor::InitFileExporter( } if (!YR::utility::ExistPath(initConfigJson.at("fileDir")) && !YR::utility::Mkdir(initConfigJson.at("fileDir"))) { - YRLOG_ERROR("failed to mkdir{} for exporter {} for backend {} of {}", initConfigJson.at("fileDir"), + YRLOG_ERROR("failed to mkdir{} for exporter {} for backend {} of {}", initConfigJson.at("fileDir").dump(), FILE_EXPORTER, backendKey, backendName); return nullptr; } @@ -631,6 +662,10 @@ std::shared_ptr MetricsAdaptor::InitFileExporter( std::string MetricsAdaptor::GetMetricsFilesName(const std::string &backendName) { + // file reporting is not supported currently,this function is not implemented. + if (backendName == "ds_alarm") { + return backendName + ".alarm.dat"; + } return backendName + "-metrics.data"; } } // namespace Libruntime diff --git a/src/libruntime/metricsadaptor/metrics_adaptor.h b/src/libruntime/metricsadaptor/metrics_adaptor.h index f98600b..5288cc4 100644 --- a/src/libruntime/metricsadaptor/metrics_adaptor.h +++ b/src/libruntime/metricsadaptor/metrics_adaptor.h @@ -33,6 +33,7 @@ #include "metrics/sdk/metric_processor.h" #include "src/dto/invoke_options.h" #include "src/libruntime/err_type.h" +#include "src/utility/singleton.h" namespace YR { namespace Libruntime { @@ -44,7 +45,13 @@ namespace MetricsPlugin = observability::plugin::metrics; class MetricsAdaptor { public: MetricsAdaptor(); - + MetricsAdaptor(const MetricsAdaptor&) = delete; + MetricsAdaptor& operator=(const MetricsAdaptor&) = delete; + static std::shared_ptr GetInstance() + { + std::call_once(initFlag, []() { instance = std::make_shared(); }); + return instance; + } void Init(const nlohmann::json &json, bool userEnable); void SetContextAttr(const std::string &attr, const std::string &value); std::string GetContextValue(const std::string &attr) const; @@ -110,6 +117,8 @@ private: std::mutex alarm_mutex_{}; std::mutex uint64_counter_mutex_{}; std::mutex double_counter_mutex_{}; + static std::shared_ptr instance; + static std::once_flag initFlag; }; } // namespace Libruntime } // namespace YR diff --git a/src/libruntime/objectstore/datasystem_object_store.cpp b/src/libruntime/objectstore/datasystem_object_store.cpp index 8863513..5e27296 100644 --- a/src/libruntime/objectstore/datasystem_object_store.cpp +++ b/src/libruntime/objectstore/datasystem_object_store.cpp @@ -152,12 +152,14 @@ RetryInfo ObjectStoreGetImplWithoutRetry(const std::vector &ids, in ErrorInfo DSCacheObjectStore::Init(const std::string &addr, int port, std::int32_t connectTimeout) { - return DSCacheObjectStore::Init(addr, port, false, false, "", SensitiveValue{}, "", connectTimeout); + return DSCacheObjectStore::Init(addr, port, false, false, "", SensitiveValue{}, "", "", "", SensitiveValue{}, + connectTimeout); } ErrorInfo DSCacheObjectStore::Init(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, const std::string &runtimePublicKey, const SensitiveValue &runtimePrivateKey, - const std::string &dsPublicKey, const std::int32_t connectTimeout) + const std::string &dsPublicKey, const SensitiveValue &token, const std::string &ak, + const SensitiveValue &sk, const std::int32_t connectTimeout) { ErrorInfo err; YRLOG_DEBUG("Datasystem object store init, ip = {}, port = {}, connectTimeout is {}", ip, port, connectTimeout); @@ -169,6 +171,9 @@ ErrorInfo DSCacheObjectStore::Init(const std::string &ip, int port, bool enableD connectOpts.clientPrivateKey = runtimePrivateKey; connectOpts.serverPublicKey = dsPublicKey; } + if (enableDsAuth) { + GetAuthConnectOpts(connectOpts, ak, sk, token); + } return err; } @@ -214,6 +219,7 @@ ErrorInfo DSCacheObjectStore::CreateBuffer(const std::string &objectId, size_t d ds::CreateParam param; param.writeMode = static_cast(createParam.writeMode); param.consistencyType = static_cast(createParam.consistencyType); + param.cacheType = static_cast(createParam.cacheType); ds::Status status = dsClient->Create(objectId, dataSize, param, dataBuffer); if (!status.IsOk()) { if (status.GetCode() != ds::StatusCode::K_OC_ALREADY_SEALED) { @@ -268,6 +274,7 @@ ErrorInfo DSCacheObjectStore::Put(std::shared_ptr data, const std::strin ds::CreateParam param; param.writeMode = static_cast(createParam.writeMode); param.consistencyType = static_cast(createParam.consistencyType); + param.cacheType = static_cast(createParam.cacheType); ds::Status status = dsClient->Create(objId, static_cast(data->GetSize()), param, dataBuffer); if (!status.IsOk()) { if (status.GetCode() != ds::StatusCode::K_OC_ALREADY_SEALED) { @@ -295,6 +302,19 @@ ErrorInfo DSCacheObjectStore::Put(std::shared_ptr data, const std::strin return ErrorInfo(); } +ErrorInfo DSCacheObjectStore::UpdateToken(datasystem::SensitiveValue token) +{ + if (!isInit) { + return ErrorInfo(); + } + return ErrorInfo(); +} + +ErrorInfo DSCacheObjectStore::UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) +{ + return ErrorInfo(); +} + ErrorInfo DSCacheObjectStore::GetImpl(const std::vector &ids, int timeoutMS, std::vector> &sbufferList) { @@ -393,6 +413,16 @@ std::vector DSCacheObjectStore::QueryGlobalReference(const std::vectorReleaseGRefs(YR::utility::ParseRealJobId(remoteId)); + auto code = + YR::Libruntime::ConvertDatasystemErrorToCore(status.GetCode(), static_cast(status.GetCode())); + auto msg = status.GetMsg(); + return ErrorInfo(code, ModuleCode::DATASYSTEM, msg); +} + ErrorInfo DSCacheObjectStore::GenerateKey(std::string &key, const std::string &prefix, bool isPut) { // if DS-client is not initialized, do not init here, because it may cause memory occupation @@ -403,12 +433,22 @@ ErrorInfo DSCacheObjectStore::GenerateKey(std::string &key, const std::string &p return ErrorInfo(); } std::string msg; - ds::Status status = dsClient->GenerateObjectKey(prefix, key); + ds::Status status = dsClient->GenerateKey(prefix, key); msg = "failed to GenerateKey, errMsg:" + status.ToString(); RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); return ErrorInfo(); } +ErrorInfo DSCacheObjectStore::GetPrefix(const std::string &key, std::string &prefix) +{ + OBJ_STORE_INIT_ONCE(); + std::string msg; + ds::Status status = dsClient->GetPrefix(key, prefix); + msg = "failed to GetPrefix, errMsg:" + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + void DSCacheObjectStore::SetTenantId(const std::string &tenantId) { (void)datasystem::Context::SetTenantId(tenantId); @@ -434,7 +474,8 @@ void DSCacheObjectStore::Shutdown() } ds::Status status = dsClient->ShutDown(); if (!status.IsOk()) { - YRLOG_WARN("DSCacheObjectStore Shutdown fail. Status code: {}, Msg: {}", status.GetCode(), status.ToString()); + YRLOG_WARN("DSCacheObjectStore Shutdown fail. Status code: {}, Msg: {}", fmt::underlying(status.GetCode()), + status.ToString()); } isInit = false; } diff --git a/src/libruntime/objectstore/datasystem_object_store.h b/src/libruntime/objectstore/datasystem_object_store.h index d9c5a3c..ea496fc 100644 --- a/src/libruntime/objectstore/datasystem_object_store.h +++ b/src/libruntime/objectstore/datasystem_object_store.h @@ -35,7 +35,8 @@ public: ErrorInfo Init(const std::string &addr, int port, bool enableDsAuth, bool encryptEnable, const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, - const std::string &dsPublicKey, std::int32_t connectTimeout) override; + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, const std::string &ak, + const datasystem::SensitiveValue &sk, std::int32_t connectTimeout) override; ErrorInfo Init(datasystem::ConnectOptions &inputConnOpt) override; @@ -59,6 +60,8 @@ public: // Get a list of objects from the datasystem. MultipleResult Get(const std::vector &ids, int timeoutMS) override; + ErrorInfo UpdateToken(datasystem::SensitiveValue token) override; + ErrorInfo UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) override; ErrorInfo IncreGlobalReference(const std::vector &objectIds) override; @@ -72,8 +75,12 @@ public: std::vector QueryGlobalReference(const std::vector &objectIds) override; + ErrorInfo ReleaseGRefs(const std::string &remoteId) override; + ErrorInfo GenerateKey(std::string &key, const std::string &prefix, bool isPut) override; + ErrorInfo GetPrefix(const std::string &key, std::string &prefix) override; + void SetTenantId(const std::string &tenantId) override; void Clear(); diff --git a/src/libruntime/objectstore/memory_store.cpp b/src/libruntime/objectstore/memory_store.cpp index 30cdbe9..4ab47ee 100644 --- a/src/libruntime/objectstore/memory_store.cpp +++ b/src/libruntime/objectstore/memory_store.cpp @@ -45,6 +45,11 @@ ErrorInfo MemoryStore::GenerateReturnObjectIds(const std::string &requestId, return ErrorInfo(); } +ErrorInfo MemoryStore::GetPrefix(const std::string &key, std::string &prefix) +{ + return dsObjectStore->GetPrefix(key, prefix); +} + // Default Put to datasystem ErrorInfo MemoryStore::Put(std::shared_ptr data, const std::string &objID, const std::unordered_set &nestedID, const CreateParam &createParam) @@ -54,7 +59,7 @@ ErrorInfo MemoryStore::Put(std::shared_ptr data, const std::string &objI ErrorInfo MemoryStore::Put(std::shared_ptr data, const std::string &objID, const std::unordered_set &nestedID, bool toDataSystem, - const CreateParam &createParam) + const CreateParam &createParam, const ErrorInfo &err) { std::unique_lock lock(mu); std::unique_lock objectDetailLock; @@ -106,6 +111,7 @@ ErrorInfo MemoryStore::Put(std::shared_ptr data, const std::string &objI // save to memory objDetail->data = data; objDetail->storeInMemory = true; + objDetail->err = err; totalInMemBufSize += data->GetSize(); return ErrorInfo(); } else { @@ -119,6 +125,12 @@ ErrorInfo MemoryStore::Put(std::shared_ptr data, const std::string &objI } } +ErrorInfo MemoryStore::Put(std::shared_ptr data, const std::string &objID, + const std::unordered_set &nestedID, bool toDataSystem, const ErrorInfo &err) +{ + return this->Put(data, objID, nestedID, toDataSystem, {}, err); +} + SingleResult MemoryStore::Get(const std::string &objID, int timeoutMS) { std::unique_lock lock(mu); @@ -208,47 +220,37 @@ std::pair> MemoryStore::IncreaseGRefInMemory lock.unlock(); auto result = std::make_pair(ErrorInfo(), std::vector()); - if (shouldIncreInDS.empty()) { - return result; - } - YRLOG_DEBUG("ds increase id {}..., objs size {}", shouldIncreInDS[0], shouldIncreInDS.size()); - if (!remoteId.empty()) { - result = dsObjectStore->IncreGlobalReference(shouldIncreInDS, remoteId); - } else { - auto err = dsObjectStore->IncreGlobalReference(shouldIncreInDS); - if (!err.OK()) { - YRLOG_ERROR("id [{}, ...] datasystem IncreGlobalReference failed. Code: {}, MCode: {}, Msg: {}", - shouldIncreInDS[0], err.Code(), err.MCode(), err.Msg()); - result.first = err; + if (!shouldIncreInDS.empty()) { + YRLOG_DEBUG("ds increase id {}..., objs size {}", shouldIncreInDS[0], shouldIncreInDS.size()); + if (!remoteId.empty()) { + result = dsObjectStore->IncreGlobalReference(shouldIncreInDS, remoteId); + } else { + auto err = dsObjectStore->IncreGlobalReference(shouldIncreInDS); + if (!err.OK()) { + YRLOG_ERROR("id [{}, ...] datasystem IncreGlobalReference failed. Code: {}, MCode: {}, Msg: {}", + shouldIncreInDS[0], fmt::underlying(err.Code()), fmt::underlying(err.MCode()), err.Msg()); + result.first = err; + } + } + for (auto objDetail : increseObjectDetails) { + std::unique_lock objectDetailLock(objDetail->_mu); + if (!result.first.OK()) { + objDetail->increInDataSystemEnum = IncreInDataSystemEnum::NOT_INCREASE_IN_DS; + } else { + objDetail->increInDataSystemEnum = IncreInDataSystemEnum::INCREASE_IN_DS; + } + objDetail->notification.Notify(); + objectDetailLock.unlock(); } - } - - for (auto objDetail : increseObjectDetails) { - std::unique_lock objectDetailLock(objDetail->_mu); if (!result.first.OK()) { - objDetail->increInDataSystemEnum = IncreInDataSystemEnum::NOT_INCREASE_IN_DS; - } else { - objDetail->increInDataSystemEnum = IncreInDataSystemEnum::INCREASE_IN_DS; + YRLOG_WARN("increase global reference failed, ds increase id {}..., objs size is {}, remote id is {}", + shouldIncreInDS[0], shouldIncreInDS.size(), remoteId); + return result; } - objDetail->notification.Notify(); - objectDetailLock.unlock(); } - - if (!result.first.OK()) { - YRLOG_WARN("increase global reference failed, ds increase id {}..., objs size is {}, remote id is {}", - shouldIncreInDS[0], shouldIncreInDS.size(), remoteId); - return result; - } - for (auto objDetail : waitObjectDetails) { - std::unique_lock objectDetailLock(objDetail->_mu); - if (!objDetail->notification.WaitForNotificationWithTimeout( - absl::Seconds(Config::Instance().DS_CONNECT_TIMEOUT_SEC()))) { - objDetail->increInDataSystemEnum = IncreInDataSystemEnum::NOT_INCREASE_IN_DS; - } else { - objDetail->increInDataSystemEnum = IncreInDataSystemEnum::INCREASE_IN_DS; - } - objectDetailLock.unlock(); + objDetail->notification.WaitForNotificationWithTimeout( + absl::Seconds(Config::Instance().DS_CONNECT_TIMEOUT_SEC())); } return result; } @@ -257,7 +259,7 @@ ErrorInfo MemoryStore::IncreDSGlobalReference(const std::vector &ob { std::unique_lock lock(mu); std::vector shouldIncreInDS; - std::vector> increseObjectDetails; + std::vector> increaseObjectDetails; std::vector> waitObjectDetails; for (const std::string &id : objectIds) { auto [it, insertSuccess] = storeMap.emplace(id, std::make_shared()); @@ -271,45 +273,36 @@ ErrorInfo MemoryStore::IncreDSGlobalReference(const std::vector &ob // Not exist before, should Incre in DS shouldIncreInDS.push_back(id); objDetail->increInDataSystemEnum = IncreInDataSystemEnum::INCREASING_IN_DS; - increseObjectDetails.push_back(objDetail); + increaseObjectDetails.push_back(objDetail); } objectDetailLock.unlock(); } lock.unlock(); - if (shouldIncreInDS.empty()) { - return ErrorInfo(); - } + if (!shouldIncreInDS.empty()) { + YRLOG_DEBUG("ds increase id {}..., objs size {}", shouldIncreInDS[0], shouldIncreInDS.size()); + auto err = dsObjectStore->IncreGlobalReference(shouldIncreInDS); - YRLOG_DEBUG("ds increase id {}..., objs size {}", shouldIncreInDS[0], shouldIncreInDS.size()); - auto err = dsObjectStore->IncreGlobalReference(shouldIncreInDS); + for (auto objDetail : increaseObjectDetails) { + std::unique_lock objectDetailLock(objDetail->_mu); + if (!err.OK()) { + objDetail->increInDataSystemEnum = IncreInDataSystemEnum::NOT_INCREASE_IN_DS; + } else { + objDetail->increInDataSystemEnum = IncreInDataSystemEnum::INCREASE_IN_DS; + } + objDetail->notification.Notify(); + objectDetailLock.unlock(); + } - for (auto objDetail : increseObjectDetails) { - std::unique_lock objectDetailLock(objDetail->_mu); if (!err.OK()) { - objDetail->increInDataSystemEnum = IncreInDataSystemEnum::NOT_INCREASE_IN_DS; - } else { - objDetail->increInDataSystemEnum = IncreInDataSystemEnum::INCREASE_IN_DS; + YRLOG_ERROR("id [{}, ...] datasystem IncreGlobalReference failed. Code: {}, MCode: {}, Msg: {}", + shouldIncreInDS[0], fmt::underlying(err.Code()), fmt::underlying(err.MCode()), err.Msg()); + return err; } - objDetail->notification.Notify(); - objectDetailLock.unlock(); } - if (!err.OK()) { - YRLOG_ERROR("id [{}, ...] datasystem IncreGlobalReference failed. Code: {}, MCode: {}, Msg: {}", - shouldIncreInDS[0], err.Code(), err.MCode(), err.Msg()); - return err; - } - - for (auto objDetail : waitObjectDetails) { - std::unique_lock objectDetailLock(objDetail->_mu); - if (!objDetail->notification.WaitForNotificationWithTimeout( - absl::Seconds(Config::Instance().DS_CONNECT_TIMEOUT_SEC()))) { - objDetail->increInDataSystemEnum = IncreInDataSystemEnum::NOT_INCREASE_IN_DS; - } else { - objDetail->increInDataSystemEnum = IncreInDataSystemEnum::INCREASE_IN_DS; - } - objectDetailLock.unlock(); + for (auto detail : waitObjectDetails) { + detail->notification.WaitForNotificationWithTimeout(absl::Seconds(Config::Instance().DS_CONNECT_TIMEOUT_SEC())); } return ErrorInfo(); } @@ -428,6 +421,11 @@ std::vector MemoryStore::QueryGlobalReference(const std::vectorReleaseGRefs(remoteId); +} + void MemoryStore::Clear() { std::lock_guard lock(mu); @@ -463,8 +461,8 @@ ErrorInfo MemoryStore::DoPutToDS(const std::string &id, const CreateParam &creat ErrorInfo dsErr = dsObjectStore->IncreGlobalReference({id}); objDetail->notification.Notify(); if (dsErr.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("id {} datasystem IncreGlobalReference failed. Code: {}, MCode: {}, Msg: {}", id, dsErr.Code(), - dsErr.MCode(), dsErr.Msg()); + YRLOG_ERROR("id {} datasystem IncreGlobalReference failed. Code: {}, MCode: {}, Msg: {}", id, + fmt::underlying(dsErr.Code()), fmt::underlying(dsErr.MCode()), dsErr.Msg()); return dsErr; } objDetail->increInDataSystemEnum = IncreInDataSystemEnum::INCREASE_IN_DS; @@ -480,8 +478,8 @@ ErrorInfo MemoryStore::DoPutToDS(const std::string &id, const CreateParam &creat YRLOG_DEBUG("try put id {} to dsObjectStore", id); ErrorInfo dsErr = dsObjectStore->Put(objDetail->data, id, std::unordered_set(), createParam); if (dsErr.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("id {} datasystem Put failed. Code: {}, MCode: {}, Msg: {}", id, dsErr.Code(), dsErr.MCode(), - dsErr.Msg()); + YRLOG_ERROR("id {} datasystem Put failed. Code: {}, MCode: {}, Msg: {}", id, fmt::underlying(dsErr.Code()), + fmt::underlying(dsErr.MCode()), dsErr.Msg()); dsObjectStore->DecreGlobalReference({id}); return dsErr; } @@ -550,33 +548,26 @@ ErrorInfo MemoryStore::IncreaseObjRef(const std::vector &objectIds) objectDetailLock.unlock(); } lock.unlock(); - if (objectIdsNeedIncre.empty()) { - return ErrorInfo(); - } - ErrorInfo dsErr = dsObjectStore->IncreGlobalReference(objectIdsNeedIncre); - for (auto objectDetail : increseObjectDetails) { - std::unique_lock objectDetailLock(objectDetail->_mu); + if (!objectIdsNeedIncre.empty()) { + ErrorInfo dsErr = dsObjectStore->IncreGlobalReference(objectIdsNeedIncre); + for (auto objectDetail : increseObjectDetails) { + std::unique_lock objectDetailLock(objectDetail->_mu); + if (dsErr.Code() != ErrorCode::ERR_OK) { + objectDetail->increInDataSystemEnum = IncreInDataSystemEnum::NOT_INCREASE_IN_DS; + } + objectDetail->notification.Notify(); + objectDetailLock.unlock(); + } if (dsErr.Code() != ErrorCode::ERR_OK) { - objectDetail->increInDataSystemEnum = IncreInDataSystemEnum::NOT_INCREASE_IN_DS; + YRLOG_ERROR("id [{}, ...] datasystem IncreGlobalReference failed. Code: {}, MCode: {}, Msg: {}", + objectIdsNeedIncre[0], fmt::underlying(dsErr.Code()), fmt::underlying(dsErr.MCode()), + dsErr.Msg()); + return dsErr; } - objectDetail->notification.Notify(); - objectDetailLock.unlock(); } - if (dsErr.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("id [{}, ...] datasystem IncreGlobalReference failed. Code: {}, MCode: {}, Msg: {}", - objectIdsNeedIncre[0], dsErr.Code(), dsErr.MCode(), dsErr.Msg()); - return dsErr; - } - for (auto objectDetail : waitObjectDetails) { - std::unique_lock objectDetailLock(objectDetail->_mu); - if (!objectDetail->notification.WaitForNotificationWithTimeout( - absl::Seconds(Config::Instance().DS_CONNECT_TIMEOUT_SEC()))) { - objectDetail->increInDataSystemEnum = IncreInDataSystemEnum::NOT_INCREASE_IN_DS; - } else { - objectDetail->increInDataSystemEnum = IncreInDataSystemEnum::INCREASE_IN_DS; - } - objectDetailLock.unlock(); + objectDetail->notification.WaitForNotificationWithTimeout( + absl::Seconds(Config::Instance().DS_CONNECT_TIMEOUT_SEC())); } return ErrorInfo(); } @@ -762,7 +753,7 @@ bool MemoryStore::SetError(const std::vector &objs, const ErrorInfo bool MemoryStore::SetError(const std::string &id, const ErrorInfo &err) { - YRLOG_DEBUG("set id {}, error {}", id, err.Msg()); + YRLOG_DEBUG("set id {}, msg: {}", id, err.Msg()); std::list callbacks; std::list callbacksWithData; { @@ -813,7 +804,7 @@ void MemoryStore::AddOutput(const std::string &generatorId, const std::string &o std::unique_lock objectDetailLock(detail->_mu); YRLOG_DEBUG( "start add object id into generator res map, id is {}, index is {}, err code is {}, err msg is {}", - objectId, index, errInfo.Code(), errInfo.Msg()); + objectId, index, fmt::underlying(errInfo.Code()), errInfo.Msg()); auto [iterator, insertSuccess] = detail->generatorResMap.emplace(index, GeneratorRes{.objectId = objectId, .err = errInfo}); if (!insertSuccess) { @@ -824,9 +815,8 @@ void MemoryStore::AddOutput(const std::string &generatorId, const std::string &o } detail->cv.notify_all(); } else { - YRLOG_WARN( - "generator id {} does not exist in store map, object id is {}, index is {}, error code is {}, msg is {}", - generatorId, objectId, index, errInfo.Code(), errInfo.Msg()); + YRLOG_WARN("generator id {} does not exist in store map, object id is {}, index is {}, ec is {}, msg is {}", + generatorId, objectId, index, fmt::underlying(errInfo.Code()), errInfo.Msg()); } } @@ -868,7 +858,7 @@ std::pair MemoryStore::GetOutput(const std::string &gene YRLOG_DEBUG( "succeed to get generator res, res object id is {}, err code is {}, err msg is {}, index is {}, " "generator id is {}", - res.objectId, res.err.Code(), res.err.Msg(), detail->getIndex, generatorId); + res.objectId, fmt::underlying(res.err.Code()), res.err.Msg(), detail->getIndex.load(), generatorId); detail->getIndex++; return std::make_pair(res.err, res.objectId); } @@ -992,12 +982,12 @@ bool MemoryStore::AddReturnObject(const std::string &objId) { std::lock_guard lock(mu); auto [it, insertSuccess] = storeMap.emplace(objId, std::make_shared()); - if (!insertSuccess) { - return false; - } std::shared_ptr objDetail = it->second; std::unique_lock objectDetailLock(objDetail->_mu); objDetail->localRefCount++; + if (!insertSuccess) { + return false; + } objDetail->ready = false; } waitingObjectManager->SetUnready(objId); @@ -1093,7 +1083,9 @@ std::string MemoryStore::GetInstanceRoute(const std::string &objId, int timeoutS f = objDetail->instanceRouteFuture; } if (timeoutSec != NO_TIMEOUT && f.wait_for(std::chrono::seconds(timeoutSec)) != std::future_status::ready) { - YRLOG_WARN("get instance route timeout, return empty string as instanceRoute. objectID is: {}.", objId); + if (timeoutSec != ZERO_TIMEOUT) { + YRLOG_WARN("get instance route timeout, return empty string as instanceRoute. objectID is: {}.", objId); + } return retInstanceRoute; } return f.get(); diff --git a/src/libruntime/objectstore/memory_store.h b/src/libruntime/objectstore/memory_store.h index 6b765d6..53ba53f 100644 --- a/src/libruntime/objectstore/memory_store.h +++ b/src/libruntime/objectstore/memory_store.h @@ -83,7 +83,9 @@ public: const std::unordered_set &nestedID, const CreateParam &createParam = {}); ErrorInfo Put(std::shared_ptr data, const std::string &objID, const std::unordered_set &nestedID, bool toDataSystem, - const CreateParam &createParam = {}); + const CreateParam &createParam = {}, const ErrorInfo &err = {}); + ErrorInfo Put(std::shared_ptr data, const std::string &objID, + const std::unordered_set &nestedID, bool toDataSystem, const ErrorInfo &err); SingleResult Get(const std::string &objID, int timeoutMS); MultipleResult Get(const std::vector &ids, int timeoutMS); ErrorInfo IncreGlobalReference(const std::vector &objectIds); @@ -95,9 +97,11 @@ public: std::pair> DecreGlobalReference(const std::vector &objectIds, const std::string &remoteId); std::vector QueryGlobalReference(const std::vector &objectIds); + ErrorInfo ReleaseGRefs(const std::string &remoteId); ErrorInfo GenerateKey(std::string &key, const std::string &prefix, bool isPut = true); ErrorInfo GenerateReturnObjectIds(const std::string &requestId, std::vector &returnObjs); + ErrorInfo GetPrefix(const std::string &key, std::string &prefix); void Clear(); // check whether ids in DS, if not, put to DS. diff --git a/src/libruntime/objectstore/object_store.h b/src/libruntime/objectstore/object_store.h index 1ac4683..28bb1e4 100644 --- a/src/libruntime/objectstore/object_store.h +++ b/src/libruntime/objectstore/object_store.h @@ -59,7 +59,9 @@ public: virtual ErrorInfo Init(const std::string &addr, int port, std::int32_t connectTimeout) = 0; virtual ErrorInfo Init(const std::string &addr, int port, bool enableDsAuth, bool encryptEnable, const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, - const std::string &dsPublicKey, std::int32_t connectTimeout) = 0; + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, + const std::string &ak, const datasystem::SensitiveValue &sk, + std::int32_t connectTimeout) = 0; virtual ErrorInfo Init(datasystem::ConnectOptions &inputConnOpt) = 0; virtual ErrorInfo CreateBuffer(const std::string &objectId, size_t dataSize, std::shared_ptr &dataBuf, const CreateParam &createParam) = 0; @@ -85,10 +87,14 @@ public: return std::make_pair(ErrorInfo(), std::vector()); } virtual std::vector QueryGlobalReference(const std::vector &objectIds) = 0; + virtual ErrorInfo ReleaseGRefs(const std::string &remoteId) = 0; virtual ErrorInfo GenerateKey(std::string &key, const std::string &prefix, bool isPut) = 0; + virtual ErrorInfo GetPrefix(const std::string &key, std::string &prefix) = 0; virtual void SetTenantId(const std::string &tenantId) = 0; virtual void Clear() = 0; virtual void Shutdown() = 0; + virtual ErrorInfo UpdateToken(datasystem::SensitiveValue token) = 0; + virtual ErrorInfo UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) = 0; }; class MsgpackBuffer : public NativeBuffer { diff --git a/src/libruntime/statestore/datasystem_state_store.cpp b/src/libruntime/statestore/datasystem_state_store.cpp index d8c3160..d947845 100644 --- a/src/libruntime/statestore/datasystem_state_store.cpp +++ b/src/libruntime/statestore/datasystem_state_store.cpp @@ -32,13 +32,15 @@ using datasystem::Status; ErrorInfo DSCacheStateStore::Init(const std::string &ip, int port, std::int32_t connectTimeout) { - return this->Init(ip, port, false, false, "", datasystem::SensitiveValue{}, "", connectTimeout); + return this->Init(ip, port, false, false, "", datasystem::SensitiveValue{}, "", datasystem::SensitiveValue{}, "", + datasystem::SensitiveValue{}, connectTimeout); } ErrorInfo DSCacheStateStore::Init(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, const std::string &dsPublicKey, - std::int32_t connectTimeout) + const datasystem::SensitiveValue &token, const std::string &ak, + const datasystem::SensitiveValue &sk, std::int32_t connectTimeout) { ErrorInfo err; YRLOG_DEBUG("Datasystem State store init, ip = {}, port = {}", ip, port); @@ -50,6 +52,9 @@ ErrorInfo DSCacheStateStore::Init(const std::string &ip, int port, bool enableDs connectOpts.clientPrivateKey = runtimePrivateKey; connectOpts.serverPublicKey = dsPublicKey; } + if (enableDsAuth) { + GetAuthConnectOpts(connectOpts, ak, sk, token); + } return err; } @@ -74,6 +79,9 @@ ErrorInfo DSCacheStateStore::DoInitOnce(void) std::string msg = "failed to init state store, errMsg:" + status.ToString(); RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_CONNECTION_FAILED, msg); isInit = true; + if (!tokenUpdated.Empty()) { + (void)UpdateToken(tokenUpdated); + } return ErrorInfo(); } @@ -104,6 +112,62 @@ ErrorInfo DSCacheStateStore::Init(const DsConnectOptions &options) ErrorInfo DSCacheStateStore::GenerateKey(std::string &returnKey) { STATE_STORE_INIT_ONCE(); + returnKey = dsStateClient->GenerateKey(); + return ErrorInfo(); +} + +ErrorInfo DSCacheStateStore::DoHealthCheck() +{ + if (dsStateClient == nullptr) { + YRLOG_INFO("ds client is nullptr"); + return ErrorInfo(); + } + Status status = dsStateClient->HealthCheck(); + if (!status.IsOk()) { + YRLOG_INFO("health check failed, code:{}, msg:{}", fmt::underlying(status.GetCode()), status.ToString()); + return GenerateSetErrorInfo(status); + } + return ErrorInfo(); +} + + +ErrorInfo DSCacheStateStore::StartHealthCheck() +{ + STATE_STORE_INIT_ONCE(); + if (timerWorker_ == nullptr) { + timerWorker_ = std::make_shared(); + YRLOG_INFO("start ds client health check"); + timer_ = timerWorker_->CreateTimer(DS_HEALTHCHECK_INTERVAL, DS_HEALTHCHECK_TIMES, + [this]() { + auto error = this->DoHealthCheck(); + if (!error.OK()) { + YRLOG_WARN("ds object client health check failed for {} times.", ++lostHealthTimes_); + if (lostHealthTimes_ >= DS_HEALTHCHECK_FAILED_LIMIT) { + YRLOG_ERROR("ds object client health check failed reach max limit 10, start exiting."); + timer_->cancel(); + YR::Libruntime::AlarmInfo dsAlarmInfo; + dsAlarmInfo.id = "YuanrongDsWorkerUnhealthy00001"; + dsAlarmInfo.customOptions["site"] = MetricsAdaptor::GetInstance()->GetContextValue("site"); + dsAlarmInfo.customOptions["application_id"] = + MetricsAdaptor::GetInstance()->GetContextValue("application_id"); + dsAlarmInfo.customOptions["service_id"] = + MetricsAdaptor::GetInstance()->GetContextValue("service_id"); + dsAlarmInfo.customOptions["tenant_id"] = + MetricsAdaptor::GetInstance()->GetContextValue("tenant_id"); + dsAlarmInfo.customOptions["clear_type"] = "ADAC"; + dsAlarmInfo.customOptions["op_type"] = "firing"; + dsAlarmInfo.alarmName = "yr_ds_alarm"; + dsAlarmInfo.alarmSeverity = AlarmSeverity::MAJOR; + dsAlarmInfo.cause = error.Msg(); + MetricsAdaptor::GetInstance()->SetAlarm("yr_ds_alarm", + "ds client health check failed reach max limit 20", + dsAlarmInfo); + } + } else { + lostHealthTimes_ = 0; + } + }); + } return ErrorInfo(); } @@ -115,6 +179,7 @@ ErrorInfo DSCacheStateStore::Write(const std::string &key, std::shared_ptr(setParam.existence); p.writeMode = static_cast(setParam.writeMode); p.ttlSecond = setParam.ttlSecond; + p.cacheType = static_cast(setParam.cacheType); Status status = dsStateClient->Set(key, stringView, p); if (!status.IsOk()) { return GenerateSetErrorInfo(status); @@ -130,6 +195,7 @@ ErrorInfo DSCacheStateStore::Write(std::shared_ptr value, SetParam setPa p.existence = static_cast(setParam.existence); p.writeMode = static_cast(setParam.writeMode); p.ttlSecond = setParam.ttlSecond; + p.cacheType = static_cast(setParam.cacheType); returnKey = dsStateClient->Set(stringView, p); return ErrorInfo(); } @@ -152,6 +218,7 @@ ErrorInfo DSCacheStateStore::MSetTx(const std::vector &keys, p.existence = static_cast(mSetParam.existence); p.writeMode = static_cast(mSetParam.writeMode); p.ttlSecond = mSetParam.ttlSecond; + p.cacheType = static_cast(mSetParam.cacheType); Status status = dsStateClient->MSetTx(keys, valViews, p); if (!status.IsOk()) { return GenerateSetErrorInfo(status); @@ -176,7 +243,8 @@ MultipleReadResult DSCacheStateStore::Read(const std::vector &keys, GetParams params; ErrorInfo err = GetValueWithTimeout(keys, result, timeoutMS, params); if (err.Code() != ErrorCode::ERR_OK) { - YRLOG_ERROR("GetValueWithTimeout error: Code:{}, MCode:{}, Msg:{}.", err.Code(), err.MCode(), err.Msg()); + YRLOG_ERROR("GetValueWithTimeout error: Code:{}, MCode:{}, Msg:{}.", fmt::underlying(err.Code()), + fmt::underlying(err.MCode()), err.Msg()); return std::make_pair(result, err); } if (!allowPartial) { @@ -196,12 +264,26 @@ MultipleReadResult DSCacheStateStore::GetWithParam(const std::vector &keys, std::vector &outSizes) +{ + STATE_STORE_INIT_ONCE(); + ErrorInfo errInfo; + Status status = dsStateClient->QuerySize(keys, outSizes); + if (status.IsError()) { + YRLOG_ERROR("failed to query the value sizes from state store, errMsg:{}", status.ToString()); + ErrorCode errCode = ConvertDatasystemErrorToCore(status.GetCode(), ErrorCode::ERR_DATASYSTEM_FAILED); + errInfo.SetErrCodeAndMsg(errCode, YR::Libruntime::ModuleCode::DATASYSTEM, status.ToString(), status.GetCode()); + } + return errInfo; +} + ErrorInfo DSCacheStateStore::Del(const std::string &key) { STATE_STORE_INIT_ONCE(); @@ -230,6 +312,21 @@ MultipleDelResult DSCacheStateStore::Del(const std::vector &keys) return std::make_pair(failedKeys, errInfo); } +MultipleExistResult DSCacheStateStore::Exist(const std::vector &keys) +{ + std::vector exists; + STATE_STORE_INIT_ONCE_RETURN_PAIR(exists); + exists.resize(keys.size()); + ErrorInfo errInfo; + Status status = dsStateClient->Exist(keys, exists); + if (status.IsError()) { + YRLOG_ERROR("failed to query keys from state store, errMsg:{}", status.ToString()); + ErrorCode errCode = ConvertDatasystemErrorToCore(status.GetCode(), ErrorCode::ERR_DATASYSTEM_FAILED); + errInfo.SetErrCodeAndMsg(errCode, YR::Libruntime::ModuleCode::DATASYSTEM, status.ToString(), status.GetCode()); + } + return std::make_pair(exists, errInfo); +} + void DSCacheStateStore::Shutdown() { if (dsStateClient == nullptr) { @@ -237,9 +334,35 @@ void DSCacheStateStore::Shutdown() } Status status = dsStateClient->ShutDown(); if (!status.IsOk()) { - YRLOG_WARN("DSCacheStateStore Shutdown fail. Status code: {}, Msg: {}", status.GetCode(), status.ToString()); + YRLOG_WARN("DSCacheStateStore Shutdown fail. Status code: {}, Msg: {}", fmt::underlying(status.GetCode()), + status.ToString()); } isInit = false; } + +ErrorInfo DSCacheStateStore::UpdateToken(datasystem::SensitiveValue token) +{ + if (!isInit) { + return ErrorInfo(); + } + ErrorInfo err; + return err; +} + +ErrorInfo DSCacheStateStore::UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) +{ + ErrorInfo err; + return err; +} + +ErrorInfo DSCacheStateStore::HealthCheck() +{ + STATE_STORE_INIT_ONCE(); + auto status = dsStateClient->HealthCheck(); + if (!status.OK()) { + return ErrorInfo(ErrorCode::ERR_CONNECTION_FAILED, "datasystem client is not healthy"); + } + return ErrorInfo(); +} } // namespace Libruntime } // namespace YR \ No newline at end of file diff --git a/src/libruntime/statestore/datasystem_state_store.h b/src/libruntime/statestore/datasystem_state_store.h index 6816980..0db8284 100644 --- a/src/libruntime/statestore/datasystem_state_store.h +++ b/src/libruntime/statestore/datasystem_state_store.h @@ -23,14 +23,20 @@ #include "datasystem/kv_client.h" #include "src/dto/buffer.h" +#include "src/libruntime/metricsadaptor/metrics_adaptor.h" #include "src/libruntime/statestore/state_store.h" #include "src/libruntime/utils/constants.h" #include "src/libruntime/utils/datasystem_utils.h" #include "src/utility/logger/logger.h" +#include "src/utility/timer_worker.h" namespace YR { namespace Libruntime { +const int DS_HEALTHCHECK_INTERVAL = 3000; // 1s +const int DS_HEALTHCHECK_TIMES = -1; // unlimited retry +const int DS_HEALTHCHECK_FAILED_LIMIT = 10; + class DataSystemReadOnlyBuffer : public ReadOnlySharedBuffer { public: DataSystemReadOnlyBuffer(std::shared_ptr buf) @@ -92,7 +98,8 @@ public: ErrorInfo Init(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, - const std::string &dsPublicKey, std::int32_t connectTimeout) override; + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, const std::string &ak, + const datasystem::SensitiveValue &sk, std::int32_t connectTimeout) override; ErrorInfo Init(datasystem::ConnectOptions &inputConnOpt) override; void InitOnce(void); @@ -108,6 +115,8 @@ public: MultipleDelResult Del(const std::vector &keys) override; + MultipleExistResult Exist(const std::vector &keys) override; + SingleReadResult Read(const std::string &key, int timeoutMS) override; MultipleReadResult Read(const std::vector &keys, int timeoutMS, bool allowPartial) override; @@ -115,12 +124,24 @@ public: MultipleReadResult GetWithParam(const std::vector &keys, const GetParams ¶ms, int timeoutMs) override; + ErrorInfo QuerySize(const std::vector &keys, std::vector &outSizes) override; + void Shutdown() override; + ErrorInfo UpdateToken(datasystem::SensitiveValue token) override; + + ErrorInfo UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) override; + ErrorInfo GenerateKey(std::string &returnKey) override; ErrorInfo Write(std::shared_ptr value, SetParam setParam, std::string &returnKey) override; + ErrorInfo StartHealthCheck() override; + + ErrorInfo DoHealthCheck(); + + ErrorInfo HealthCheck() override; + private: ErrorInfo DoInitOnce(void); @@ -233,6 +254,9 @@ private: ErrorInfo initErr; datasystem::SensitiveValue tokenUpdated; datasystem::ConnectOptions connectOpts; + std::shared_ptr timerWorker_; + std::shared_ptr timer_; + int lostHealthTimes_ = 0; }; #define STATE_STORE_INIT_ONCE() \ diff --git a/src/libruntime/statestore/state_store.h b/src/libruntime/statestore/state_store.h index 9d7bb04..a5d36be 100644 --- a/src/libruntime/statestore/state_store.h +++ b/src/libruntime/statestore/state_store.h @@ -29,6 +29,7 @@ namespace YR { namespace Libruntime { typedef std::pair, ErrorInfo> MultipleDelResult; +typedef std::pair, ErrorInfo> MultipleExistResult; typedef std::pair, ErrorInfo> SingleReadResult; typedef std::pair>, ErrorInfo> MultipleReadResult; @@ -102,7 +103,9 @@ public: */ virtual ErrorInfo Init(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, - const std::string &dsPublicKey, std::int32_t connectTimeout) = 0; + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, + const std::string &ak, const datasystem::SensitiveValue &sk, + std::int32_t connectTimeout) = 0; virtual ErrorInfo Init(const DsConnectOptions &options) = 0; @@ -152,6 +155,14 @@ public: */ virtual MultipleReadResult Read(const std::vector &keys, int timeoutMS, bool allowPartial) = 0; + /** + * @brief Query the value sizes of all the given keys by datasystem StateClient. + * @param[in] keys The vector of the keys. + * @param[in] outSizes A vector of value sizes that query from datasystem + * @return return ErrorInfo + */ + virtual ErrorInfo QuerySize(const std::vector &keys, std::vector &outSizes) = 0; + /** * @brief Delete a key by datasystem StateClient. * @param[in] key The key to delete. @@ -166,14 +177,32 @@ public: */ virtual MultipleDelResult Del(const std::vector &keys) = 0; + /** + * @brief Query all the given keys by datasystem StateClient. + * @param[in] keys The vector of the keys. + * @return MultipleExistResult The exists result and ErrorInfo. + */ + virtual MultipleExistResult Exist(const std::vector &keys) = 0; + /** * @brief Shutdown notify datasystem to release resource. */ virtual void Shutdown() = 0; + /** + * @brief update token + */ + virtual ErrorInfo UpdateToken(datasystem::SensitiveValue token) = 0; + virtual ErrorInfo GenerateKey(std::string &returnKey) = 0; + virtual ErrorInfo StartHealthCheck() = 0; + + virtual ErrorInfo UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) = 0; + virtual ErrorInfo Write(std::shared_ptr value, SetParam setParam, std::string &returnKey) = 0; + + virtual ErrorInfo HealthCheck() = 0; }; } // namespace Libruntime diff --git a/src/libruntime/streamstore/datasystem_stream_store.cpp b/src/libruntime/streamstore/datasystem_stream_store.cpp new file mode 100644 index 0000000..01e729a --- /dev/null +++ b/src/libruntime/streamstore/datasystem_stream_store.cpp @@ -0,0 +1,233 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "datasystem_stream_store.h" + +#include +#include "src/libruntime/err_type.h" +#include "src/libruntime/statestore/datasystem_state_store.h" +#include "src/libruntime/utils/datasystem_utils.h" +#include "src/libruntime/utils/exception.h" +#include "src/libruntime/utils/utils.h" + +namespace YR { +namespace Libruntime { +using namespace datasystem; +using datasystem::Status; +using YR::Libruntime::ErrorCode; +using YR::Libruntime::ModuleCode; +ErrorInfo DatasystemStreamStore::Init(const std::string &ip, int port) +{ + return this->Init(ip, port, false, false, "", datasystem::SensitiveValue{}, "", datasystem::SensitiveValue{}, "", + datasystem::SensitiveValue{}); +} + +ErrorInfo DatasystemStreamStore::Init(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, + const std::string &runtimePublicKey, + const datasystem::SensitiveValue &runtimePrivateKey, + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, + const std::string &ak, const datasystem::SensitiveValue &sk) +{ + this->ip = ip; + this->port = port; + this->enableDsAuth = enableDsAuth; + this->encryptEnable = encryptEnable; + this->runtimePublicKey = runtimePublicKey; + this->runtimePrivateKey = runtimePrivateKey; + this->dsPublicKey = dsPublicKey; + this->ak = ak; + this->sk = sk; + this->token = token; + return ErrorInfo(); +} + +ErrorInfo DatasystemStreamStore::Init(datasystem::ConnectOptions &inputConnOpt, std::shared_ptr stateStore) +{ + this->dsStateStore = stateStore; + return Init(inputConnOpt); +} + +ErrorInfo DatasystemStreamStore::Init(datasystem::ConnectOptions &inputConnOpt) +{ + this->connectOpts.host = inputConnOpt.host; + this->connectOpts.port = inputConnOpt.port; + this->connectOpts.clientPublicKey = inputConnOpt.clientPublicKey; + this->connectOpts.clientPrivateKey = inputConnOpt.clientPrivateKey; + this->connectOpts.serverPublicKey = inputConnOpt.serverPublicKey; + this->connectOpts.accessKey = inputConnOpt.accessKey; + this->connectOpts.secretKey = inputConnOpt.secretKey; + this->connectOpts.connectTimeoutMs = inputConnOpt.connectTimeoutMs; + this->connectOpts.tenantId = inputConnOpt.tenantId; + return ErrorInfo(); +} + +ErrorInfo DatasystemStreamStore::DoInitOnce() +{ + streamClient = std::make_shared(this->connectOpts); + Status status = streamClient->Init(); + std::string msg = "failed to init stream client, errMsg:" + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + isInit = true; + if (Config::Instance().ENABLE_DS_HEALTH_CHECK()) { + YRLOG_INFO("ds client heath check is enabled"); + if (dsStateStore == nullptr) { + YRLOG_ERROR("ds state client for stream client healthy check is null, run without healthy check"); + return ErrorInfo(); + } + return dsStateStore->StartHealthCheck(); + } + return ErrorInfo(); +} + +void DatasystemStreamStore::InitOnce(void) +{ + std::call_once(this->initFlag, [this]() { this->initErr = this->DoInitOnce(); }); +} + +std::pair DatasystemStreamStore::CheckAndBuildProducerConf( + const ProducerConf &producerConf) +{ + datasystem::ProducerConf dsConf; + ErrorInfo errInfo; + dsConf.delayFlushTime = producerConf.delayFlushTime; + dsConf.pageSize = producerConf.pageSize; + dsConf.maxStreamSize = producerConf.maxStreamSize; + dsConf.autoCleanup = producerConf.autoCleanup; + dsConf.encryptStream = producerConf.encryptStream; + dsConf.retainForNumConsumers = producerConf.retainForNumConsumers; + dsConf.reserveSize = producerConf.reserveSize; + auto it = producerConf.extendConfig.find(std::string(STREAM_MODE)); + if (it != producerConf.extendConfig.end()) { + auto mode = it->second; + auto modeIt = streamModeMap.find(mode); + if (modeIt != streamModeMap.end()) { + dsConf.streamMode = modeIt->second; + } else { + errInfo = ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, YR::Libruntime::ModuleCode::RUNTIME, + "unsupported stream mode: " + mode + ", only support MPMC、MPSC and SPSC."); + } + } + return {dsConf, errInfo}; +} + +ErrorInfo DatasystemStreamStore::CreateStreamProducer(const std::string &streamName, + std::shared_ptr &producer, + ProducerConf producerConf) +{ + if (producer == nullptr) { + return ErrorInfo(ERR_PARAM_INVALID, RUNTIME, + "check the second param of YR::CreateProducer interface, nullptr is ont supported."); + } + auto [dsConf, err] = CheckAndBuildProducerConf(producerConf); + if (!err.OK()) { + return err; + } + err = EnsureInit(); + if (!err.OK()) { + return err; + } + Status status = streamClient->CreateProducer(streamName, producer->GetProducer(), dsConf); + auto msg = "failed to CreateProducer, errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +ErrorInfo DatasystemStreamStore::CreateStreamConsumer(const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck) +{ + if (consumer == nullptr) { + return ErrorInfo(ERR_PARAM_INVALID, RUNTIME, + "check the third param of YR::Subscribe interface, nullptr is ont supported."); + } + auto err = EnsureInit(); + if (!err.OK()) { + return err; + } + datasystem::SubscriptionConfig dsConfig(config.subscriptionName, typeMap[config.subscriptionType]); + Status status = streamClient->Subscribe(streamName, dsConfig, consumer->GetConsumer(), autoAck); + auto msg = "failed to Subscribe, streamName: " + streamName + " , errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +ErrorInfo DatasystemStreamStore::DeleteStream(const std::string &streamName) +{ + if (auto err = EnsureInit(); !err.OK()) { + return err; + } + Status status = streamClient->DeleteStream(streamName); + auto msg = "failed to DeleteStream, streamName: " + streamName + " , errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +ErrorInfo DatasystemStreamStore::QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum) +{ + if (auto err = EnsureInit(); !err.OK()) { + return err; + } + Status status = streamClient->QueryGlobalProducersNum(streamName, gProducerNum); + auto msg = "failed to QueryGlobalProducersNum, streamName: " + streamName + " , errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +ErrorInfo DatasystemStreamStore::QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum) +{ + if (auto err = EnsureInit(); !err.OK()) { + return err; + } + Status status = streamClient->QueryGlobalConsumersNum(streamName, gConsumerNum); + auto msg = "failed to QueryGlobalConsumersNum, streamName: " + streamName + " , errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +void DatasystemStreamStore::Shutdown() +{ + if (streamClient == nullptr) { + return; + } + Status status = streamClient->ShutDown(); + if (!status.IsOk()) { + YRLOG_WARN("DatasystemStreamStore Shutdown fail. Status code: {}, Msg: {}", fmt::underlying(status.GetCode()), + status.ToString()); + } +} + +ErrorInfo DatasystemStreamStore::UpdateToken(datasystem::SensitiveValue token) +{ + ErrorInfo err; + return err; +} + +ErrorInfo DatasystemStreamStore::UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) +{ + ErrorInfo err; + return err; +} +ErrorInfo DatasystemStreamStore::EnsureInit() +{ + std::unique_lock lock(initMutex); + if (isInit && initErr.OK()) { + return initErr; + } + initErr = DoInitOnce(); + isInit = true; + return initErr; +} +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/streamstore/datasystem_stream_store.h b/src/libruntime/streamstore/datasystem_stream_store.h new file mode 100644 index 0000000..16c66dc --- /dev/null +++ b/src/libruntime/streamstore/datasystem_stream_store.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "datasystem/stream_client.h" +#include "src/dto/stream_conf.h" +#include "src/libruntime/statestore/state_store.h" +#include "src/libruntime/utils/constants.h" +#include "src/utility/logger/logger.h" +#include "stream_store.h" + +namespace YR { +namespace Libruntime { +class DatasystemStreamStore : public StreamStore { +public: + ErrorInfo Init(const std::string &ip, int port) override; + + ErrorInfo Init(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, + const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, const std::string &ak, + const datasystem::SensitiveValue &sk) override; + + ErrorInfo Init(datasystem::ConnectOptions &inputConnOpt) override; + + ErrorInfo Init(datasystem::ConnectOptions &inputConnOpt, std::shared_ptr dsStateStore) override; + + ErrorInfo CreateStreamProducer(const std::string &streamName, std::shared_ptr &producer, + ProducerConf producerConf = {}) override; + + ErrorInfo CreateStreamConsumer(const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck = false) override; + + ErrorInfo DeleteStream(const std::string &streamName) override; + + ErrorInfo QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum) override; + + ErrorInfo QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum) override; + + void Shutdown() override; + + ErrorInfo UpdateToken(datasystem::SensitiveValue token) override; + + ErrorInfo UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) override; + +private: + void InitOnce(void); + ErrorInfo DoInitOnce(void); + + std::pair CheckAndBuildProducerConf(const ProducerConf &producerConf); + + bool isInit = false; + std::mutex initMutex; + std::once_flag initFlag; + ErrorInfo initErr; + bool isReady = false; + std::shared_ptr streamClient; + std::string ip; + int port; + bool enableDsAuth = false; + bool encryptEnable = false; + std::string runtimePublicKey; + datasystem::SensitiveValue runtimePrivateKey; + std::string dsPublicKey; + std::string ak; + datasystem::SensitiveValue sk; + datasystem::SensitiveValue token; + datasystem::ConnectOptions connectOpts; + std::unordered_map typeMap = { + {libruntime::SubscriptionType::STREAM, datasystem::SubscriptionType::STREAM}, + {libruntime::SubscriptionType::KEY_PARTITIONS, datasystem::SubscriptionType::KEY_PARTITIONS}, + {libruntime::SubscriptionType::ROUND_ROBIN, datasystem::SubscriptionType::ROUND_ROBIN}, + {libruntime::SubscriptionType::UNKNOWN, datasystem::SubscriptionType::UNKNOWN}}; + std::unordered_map streamModeMap = { + {std::string(MPMC), datasystem::StreamMode::MPMC}, + {std::string(MPSC), datasystem::StreamMode::MPSC}, + {std::string(SPSC), datasystem::StreamMode::SPSC}}; + std::shared_ptr dsStateStore; + + ErrorInfo EnsureInit(void); +}; + +#define STREAM_STORE_INIT_ONCE() \ + do { \ + InitOnce(); \ + if (!initErr.OK()) { \ + return initErr; \ + } \ + } while (0) + +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/streamstore/stream_producer_consumer.cpp b/src/libruntime/streamstore/stream_producer_consumer.cpp new file mode 100644 index 0000000..b9f7ef4 --- /dev/null +++ b/src/libruntime/streamstore/stream_producer_consumer.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "stream_producer_consumer.h" +#include "datasystem/stream/element.h" +#include "datasystem/stream/stream_config.h" +#include "src/dto/config.h" +#include "src/libruntime/utils/datasystem_utils.h" + +namespace YR { +namespace Libruntime { +ErrorInfo StreamProducer::Send(const Element &element) +{ + datasystem::Status status = dsProducer->Send(datasystem::Element(element.ptr, element.size, element.id)); + auto msg = "failed to send element, errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +ErrorInfo StreamProducer::Send(const Element &element, int64_t timeoutMs) +{ + datasystem::Status status = dsProducer->Send(datasystem::Element(element.ptr, element.size, element.id), timeoutMs); + auto msg = "failed to send element, errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +ErrorInfo StreamProducer::Flush() +{ + datasystem::Status status = dsProducer->Flush(); + auto msg = "producer failed to Flush, errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +ErrorInfo StreamProducer::Close() +{ + datasystem::Status status = dsProducer->Close(); + auto msg = "failed to Close producer, errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +std::shared_ptr &StreamProducer::GetProducer() +{ + return dsProducer; +} + +ErrorInfo StreamConsumer::Receive(uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements) +{ + std::vector receiver; + datasystem::Status status = dsConsumer->Receive(expectNum, timeoutMs, receiver); + uint64_t totalSize = 0; + for (auto element : receiver) { + YRLOG_DEBUG("receive stream with expectNum with element id: {}, size, {}", element.id, element.size); + outElements.push_back(Element(element.ptr, element.size, element.id)); + if (totalSize > std::numeric_limits::max() - element.size) { + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "the total size exceed the uint64_t max value"); + } + totalSize += element.size; + } + auto msg = "failed to Receive element with expectNum, errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + if (Config::Instance().STREAM_RECEIVE_LIMIT() != 0 && totalSize > Config::Instance().STREAM_RECEIVE_LIMIT()) { + YRLOG_ERROR("receive size: {} exceeded the limit: {}", totalSize, Config::Instance().STREAM_RECEIVE_LIMIT()); + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "the total size exceed the limit"); + } + return ErrorInfo(); +} + +ErrorInfo StreamConsumer::Receive(uint32_t timeoutMs, std::vector &outElements) +{ + std::vector receiver; + datasystem::Status status = dsConsumer->Receive(timeoutMs, receiver); + uint64_t totalSize = 0; + for (auto element : receiver) { + YRLOG_DEBUG("receive stream with element id: {}, size, {}", element.id, element.size); + outElements.push_back(Element(element.ptr, element.size, element.id)); + if (totalSize > std::numeric_limits::max() - element.size) { + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "the total size exceed the uint64_t max value"); + } + totalSize += element.size; + } + auto msg = "failed to Receive element, errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + if (Config::Instance().STREAM_RECEIVE_LIMIT() != 0 && totalSize > Config::Instance().STREAM_RECEIVE_LIMIT()) { + YRLOG_ERROR("receive size: {} exceeded the limit: {}", totalSize, Config::Instance().STREAM_RECEIVE_LIMIT()); + return ErrorInfo(YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, YR::Libruntime::ModuleCode::RUNTIME, + "the total size exceed the limit"); + } + return ErrorInfo(); +} + +ErrorInfo StreamConsumer::Ack(uint64_t elementId) +{ + datasystem::Status status = dsConsumer->Ack(elementId); + auto msg = "failed to Ack, errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +ErrorInfo StreamConsumer::Close() +{ + datasystem::Status status = dsConsumer->Close(); + auto msg = "failed to Close consumer, errMsg: " + status.ToString(); + RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); + return ErrorInfo(); +} + +std::shared_ptr &StreamConsumer::GetConsumer() +{ + return dsConsumer; +} + +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/streamstore/stream_producer_consumer.h b/src/libruntime/streamstore/stream_producer_consumer.h new file mode 100644 index 0000000..8a314c8 --- /dev/null +++ b/src/libruntime/streamstore/stream_producer_consumer.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "datasystem/stream/consumer.h" +#include "datasystem/stream/producer.h" +#include "src/dto/stream_conf.h" +#include "src/libruntime/err_type.h" + +namespace YR { +namespace Libruntime { + +class StreamProducer { +public: + virtual ErrorInfo Send(const Element &element); + + virtual ErrorInfo Send(const Element &element, int64_t timeoutMs); + + virtual ErrorInfo Flush(); + + virtual ErrorInfo Close(); + + std::shared_ptr &GetProducer(); + +private: + std::shared_ptr dsProducer; +}; + +class StreamConsumer { +public: + virtual ErrorInfo Receive(uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements); + + virtual ErrorInfo Receive(uint32_t timeoutMs, std::vector &outElements); + + virtual ErrorInfo Ack(uint64_t elementId); + + virtual ErrorInfo Close(); + + std::shared_ptr &GetConsumer(); + +private: + std::shared_ptr dsConsumer; +}; +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/streamstore/stream_store.h b/src/libruntime/streamstore/stream_store.h new file mode 100644 index 0000000..d0293da --- /dev/null +++ b/src/libruntime/streamstore/stream_store.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "datasystem/utils/connection.h" +#include "datasystem/utils/sensitive_value.h" +#include "src/libruntime/err_type.h" +#include "src/libruntime/statestore/state_store.h" +#include "stream_producer_consumer.h" + +namespace YR { +namespace Libruntime { +class StreamStore { +public: + virtual ~StreamStore() = default; + virtual ErrorInfo Init(const std::string &ip, int port) = 0; + virtual ErrorInfo Init(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, + const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, + const std::string &ak, const datasystem::SensitiveValue &sk) = 0; + virtual ErrorInfo Init(datasystem::ConnectOptions &inputConnOpt) = 0; + virtual ErrorInfo Init(datasystem::ConnectOptions &inputConnOpt, std::shared_ptr dsStateStore) = 0; + virtual ErrorInfo CreateStreamProducer(const std::string &streamName, std::shared_ptr &producer, + ProducerConf producerConf = {}) = 0; + virtual ErrorInfo CreateStreamConsumer(const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck = false) = 0; + virtual ErrorInfo DeleteStream(const std::string &streamName) = 0; + virtual ErrorInfo QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum) = 0; + virtual ErrorInfo QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum) = 0; + virtual void Shutdown() = 0; + virtual ErrorInfo UpdateToken(datasystem::SensitiveValue token) = 0; + virtual ErrorInfo UpdateAkSk(std::string ak, datasystem::SensitiveValue sk) = 0; +}; +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/traceadaptor/exporter/log_file_exporter.cpp b/src/libruntime/traceadaptor/exporter/log_file_exporter.cpp new file mode 100644 index 0000000..ecc1ce5 --- /dev/null +++ b/src/libruntime/traceadaptor/exporter/log_file_exporter.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "log_file_exporter.h" +#include +#include +#include "src/utility/logger/logger.h" + +namespace YR { +namespace Libruntime { +constexpr int TRACE_ID_LEN = 32; +constexpr int SPAN_ID_LEN = 16; + +std::string TraceIdToString(const opentelemetry::trace::TraceId& traceId) +{ + char traceIdHex[TRACE_ID_LEN]; + traceId.ToLowerBase16(traceIdHex); + return std::string(traceIdHex, TRACE_ID_LEN); +} + +std::string SpanIdToString(const opentelemetry::trace::SpanId& spanId) +{ + char spanIdHex[SPAN_ID_LEN]; + spanId.ToLowerBase16(spanIdHex); + return std::string(spanIdHex, SPAN_ID_LEN); +} + +common_sdk::ExportResult LogFileExporter::Export( + const nostd::span> &spans) noexcept +{ + if (isShutDown) { + YRLOG_ERROR("[YRLOG File Exporter] Exporting {} log(s) failed, exporter is shutdown", spans.size()); + return common_sdk::ExportResult::kFailure; + } + for (auto &recordable : spans) { + auto span = std::unique_ptr(static_cast(recordable.release())); + if (span == nullptr) { + continue; + } + std::ostringstream oss; + oss << "span_name: " << span->GetName() << ", " + << "trace_id: " << TraceIdToString(span->GetTraceId()) << ", " + << "span_id: " << SpanIdToString(span->GetSpanId()) << ", " + << "start_time: " << span->GetStartTime().time_since_epoch().count() << "ns" << ", " + << "duration: " << span->GetDuration().count() << "ns" << ", "; + auto attributes = span->GetAttributes(); + oss << "attributes: {"; + for (const auto& [key, value] : attributes) { + oss << " " << key << " = "; + opentelemetry::exporter::ostream_common::print_value(value, oss); + } + oss << "}"; + YRLOG_INFO("trace info: {}", oss.str()); + } + return common_sdk::ExportResult::kSuccess; +} + +std::unique_ptr LogFileExporter::MakeRecordable() noexcept +{ + return std::make_unique(); +} + +bool LogFileExporter::Shutdown(std::chrono::microseconds timeout) noexcept +{ + isShutDown = true; + return true; +} + +bool LogFileExporter::ForceFlush(std::chrono::microseconds timeout) noexcept +{ + return true; +} + +} +} + diff --git a/src/libruntime/traceadaptor/exporter/log_file_exporter.h b/src/libruntime/traceadaptor/exporter/log_file_exporter.h new file mode 100644 index 0000000..71bce7f --- /dev/null +++ b/src/libruntime/traceadaptor/exporter/log_file_exporter.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nostd = opentelemetry::nostd; +namespace trace_sdk = opentelemetry::sdk::trace; +namespace common_sdk = opentelemetry::sdk::common; + +namespace YR { +namespace Libruntime { +std::string TraceIdToString(const opentelemetry::trace::TraceId& traceId); +std::string SpanIdToString(const opentelemetry::trace::SpanId& spanId); + +class LogFileExporter final : public trace_sdk::SpanExporter { +public: + common_sdk::ExportResult Export( + const nostd::span> &spans) noexcept override; + + std::unique_ptr MakeRecordable() noexcept override; + + bool Shutdown(std::chrono::microseconds timeout = std::chrono::microseconds::max()) noexcept override; + + bool ForceFlush(std::chrono::microseconds timeout) noexcept override; + +private: + std::atomic isShutDown{false}; +}; +} +} \ No newline at end of file diff --git a/src/libruntime/traceadaptor/exporter/log_file_exporter_factory.cpp b/src/libruntime/traceadaptor/exporter/log_file_exporter_factory.cpp new file mode 100644 index 0000000..f52114c --- /dev/null +++ b/src/libruntime/traceadaptor/exporter/log_file_exporter_factory.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "log_file_exporter_factory.h" +#include "log_file_exporter.h" + +namespace YR { +namespace Libruntime { + +std::unique_ptr LogFileExporterFactory::Create() +{ + return std::make_unique(); +} + +} // namespace Libruntime +} // namespace YR diff --git a/src/libruntime/traceadaptor/exporter/log_file_exporter_factory.h b/src/libruntime/traceadaptor/exporter/log_file_exporter_factory.h new file mode 100644 index 0000000..bfe341d --- /dev/null +++ b/src/libruntime/traceadaptor/exporter/log_file_exporter_factory.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace YR { +namespace Libruntime { +class LogFileExporterFactory { +public: + static std::unique_ptr Create(); +}; +} // namespace Libruntime +} \ No newline at end of file diff --git a/src/libruntime/traceadaptor/trace_adapter.cpp b/src/libruntime/traceadaptor/trace_adapter.cpp new file mode 100644 index 0000000..cc5be8e --- /dev/null +++ b/src/libruntime/traceadaptor/trace_adapter.cpp @@ -0,0 +1,175 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "trace_adapter.h" +#include + +#include "src/dto/config.h" +#include "src/libruntime/traceadaptor/exporter/log_file_exporter_factory.h" +#include "src/utility/logger/logger.h" + +namespace YR { +namespace Libruntime { + +TraceAdapter::~TraceAdapter() noexcept +{ +} + +void TraceAdapter::InitTrace(const std::string &serviceName, const bool &enableTrace, const std::string &traceConfig) +{ + enableTrace_ = enableTrace; + YRLOG_DEBUG("init trace, enableTrace is {}, traceConfig is {}", enableTrace, traceConfig); + if (!enableTrace_) { + return; + } + std::vector> processors; + try { + auto confJson = nlohmann::json::parse(traceConfig); + for (auto &element : confJson.items()) { + if (element.key() == OTLP_GRPC_EXPORTER) { + if (!element.value().contains("enable") || !element.value().at("enable").get()) { + YRLOG_INFO("Trace exporter {} is not enabled", OTLP_GRPC_EXPORTER); + continue; + } + if (!element.value().contains("endpoint") + || element.value().at("endpoint").get().empty()) { + YRLOG_INFO("Trace exporter {} endpoint is not valid", OTLP_GRPC_EXPORTER); + continue; + } + OtelGrpcExporterConfig config; + config.endpoint = element.value().at("endpoint").get(); + opentelemetry::sdk::trace::BatchSpanProcessorOptions batchSpanProcessorOptions; + YRLOG_INFO("OtelGrpcExporter is enable, endpoint is {}", config.endpoint); + processors.push_back( + std::unique_ptr(trace_sdk::BatchSpanProcessorFactory::Create( + std::move(InitOtlpGrpcExporter(config)), batchSpanProcessorOptions))); + } else if (element.key() == LOG_FILE_EXPORTER) { + if (!element.value().contains("enable") + || !element.value().at("enable").get()) { + YRLOG_INFO("Trace exporter {} is not enabled", LOG_FILE_EXPORTER); + continue; + } + opentelemetry::sdk::trace::BatchSpanProcessorOptions batchSpanProcessorOptions; + YRLOG_INFO("logFileExporter is enable"); + processors.push_back( + std::unique_ptr(trace_sdk::BatchSpanProcessorFactory::Create( + std::move(InitLogFileExporter()), batchSpanProcessorOptions))); + } + } + } catch (nlohmann::detail::parse_error &e) { + YRLOG_ERROR("Failed to arse trace config json, error: {}", e.what()); + enableTrace_ = false; + return; + } catch (std::exception &e) { + YRLOG_ERROR("Failed to parse trace config json, error: {}", e.what()); + enableTrace_ = false; + return; + } + if (processors.empty()) { + YRLOG_WARN("There is no supported exporter in config"); + enableTrace_ = false; + return; + } + opentelemetry::sdk::resource::ResourceAttributes attributes = { + { opentelemetry::sdk::resource::SemanticConventions::kTelemetrySdkLanguage, "" }, + { opentelemetry::sdk::resource::SemanticConventions::kTelemetrySdkName, "" }, + { opentelemetry::sdk::resource::SemanticConventions::kTelemetrySdkVersion, "" }, + { opentelemetry::sdk::resource::SemanticConventions::kServiceName, serviceName }, + }; + auto provider = std::shared_ptr(std::make_shared( + std::move(processors), opentelemetry::sdk::resource::Resource::Create(attributes))); + trace_api::Provider::SetTracerProvider(provider); +} + +void TraceAdapter::ShutDown() +{ + if (!enableTrace_) { + return; + } + YRLOG_INFO("enter traceAdapter shutDown"); + enableTrace_ = false; + auto provider = trace_api::Provider::GetTracerProvider(); + auto traceProvider = static_cast(provider.get()); + if (traceProvider != nullptr && !traceProvider->ForceFlush()) { + YRLOG_WARN("traceProvider shutDown failed"); + } + std::shared_ptr none; + trace_api::Provider::SetTracerProvider(none); +} + +void TraceAdapter::SetAttr(const std::string &attr, const std::string &value) +{ + attribute_.insert_or_assign(attr, value); +} + +OtelSpan TraceAdapter::StartSpan(const std::string &name, const opentelemetry::common::KeyValueIterable &attributes, + const opentelemetry::trace::SpanContextKeyValueIterable &links, + const opentelemetry::trace::StartSpanOptions &startSpanOptions) +{ + if (enableTrace_) { + auto tracer = GetTracer(); + if (tracer != nullptr) { + return tracer->StartSpan(name, attributes, links, startSpanOptions); + } + } + std::shared_ptr noopTracer = std::make_shared(); + return opentelemetry::nostd::shared_ptr(new trace_api::NoopSpan(noopTracer)); +} + +OtelSpan TraceAdapter::StartSpan(const std::string &name, + const opentelemetry::trace::StartSpanOptions &startSpanOptions) +{ + return StartSpan(name, opentelemetry::common::NoopKeyValueIterable(), opentelemetry::trace::NullSpanContext(), + startSpanOptions); +} + +OtelSpan TraceAdapter::StartSpan( + const std::string &name, + std::vector> attrs, + const opentelemetry::trace::StartSpanOptions &startSpanOptions) +{ + // preset system attr + for (const auto &it : attribute_) { + attrs.emplace_back(it); + } + return StartSpan(name, opentelemetry::common::KeyValueIterableView(attrs), opentelemetry::trace::NullSpanContext(), + startSpanOptions); +} + +opentelemetry::nostd::shared_ptr TraceAdapter::GetTracer(const std::string &name, + const std::string &version) +{ + auto provider = opentelemetry::trace::Provider::GetTracerProvider(); + return provider->GetTracer(name, version); +} + +std::unique_ptr TraceAdapter::InitLogFileExporter() +{ + return LogFileExporterFactory::Create(); +} + +std::unique_ptr TraceAdapter::InitOtlpGrpcExporter( + const OtelGrpcExporterConfig &conf) +{ + if (conf.endpoint.empty()) { + return nullptr; + } + opentelemetry::exporter::otlp::OtlpGrpcExporterOptions options; + options.endpoint = conf.endpoint; + return opentelemetry::exporter::otlp::OtlpGrpcExporterFactory::Create(options); +} +} +} diff --git a/src/libruntime/traceadaptor/trace_adapter.h b/src/libruntime/traceadaptor/trace_adapter.h new file mode 100644 index 0000000..babf0cb --- /dev/null +++ b/src/libruntime/traceadaptor/trace_adapter.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_TRACE_TRACE_ADAPTER_H +#define COMMON_TRACE_TRACE_ADAPTER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "src/utility/singleton.h" +#include "trace_struct.h" + +namespace trace_api = opentelemetry::trace; +namespace trace_sdk = opentelemetry::sdk::trace; + +namespace YR { +namespace Libruntime { + +using OtelSpan = opentelemetry::nostd::shared_ptr; + +class TraceAdapter : public utility::Singleton { +public: + TraceAdapter() = default; + ~TraceAdapter() noexcept override; + + void InitTrace(const std::string &serviceName, const bool &enableTrace, const std::string &traceConfig); + + void SetAttr(const std::string &attr, const std::string &value); + + OtelSpan StartSpan(const std::string &name, const opentelemetry::common::KeyValueIterable &attributes, + const opentelemetry::trace::SpanContextKeyValueIterable &links, + const opentelemetry::trace::StartSpanOptions &startSpanOptions); + + OtelSpan StartSpan(const std::string &name, const opentelemetry::trace::StartSpanOptions &startSpanOptions = {}); + + OtelSpan StartSpan(const std::string &name, + std::vector> attrs, + const opentelemetry::trace::StartSpanOptions &startSpanOptions = {}); + + void ShutDown(); + +private: + bool enableTrace_{ false }; + + std::map attribute_; + + opentelemetry::nostd::shared_ptr GetTracer(const std::string &name = "yuanrong", + const std::string &version = ""); + std::unique_ptr InitOtlpGrpcExporter(const OtelGrpcExporterConfig &conf); + std::unique_ptr InitLogFileExporter(); +}; +} +} +#endif // COMMON_TRACE_TRACE_ADAPTER_H diff --git a/src/libruntime/traceadaptor/trace_struct.h b/src/libruntime/traceadaptor/trace_struct.h new file mode 100644 index 0000000..9dc08a3 --- /dev/null +++ b/src/libruntime/traceadaptor/trace_struct.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_TRACE_TRACE_STRUCT_H +#define COMMON_TRACE_TRACE_STRUCT_H +#include +#include +#include +#include +#include + +namespace YR { +namespace Libruntime { + +const std::string OTLP_GRPC_EXPORTER = "otlpGrpcExporter"; +const std::string LOG_FILE_EXPORTER = "logFileExporter"; + +enum class TraceExporterType : int { + LOG_FILE, OTEL_GRPC, OTEL_HTTP +}; + +struct OtelGrpcExporterConfig { + std::string endpoint; +}; +} +} +#endif // COMMON_TRACE_TRACE_STRUCT_H diff --git a/src/libruntime/utils/grpc_utils.cpp b/src/libruntime/utils/grpc_utils.cpp new file mode 100644 index 0000000..02ac213 --- /dev/null +++ b/src/libruntime/utils/grpc_utils.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "grpc_utils.h" +#include "src/libruntime/utils/hash_utils.h" +#include "src/libruntime/utils/utils.h" +#include "src/utility/logger/logger.h" + +namespace YR { +namespace Libruntime { +const double TIMESTAMP_EXPIRE_DURATION_SECONDS = 60; +std::pair SerializeBodyToString(const ::runtime_rpc::StreamingMessage &message) +{ + const google::protobuf::Descriptor *descriptor = message.GetDescriptor(); + const google::protobuf::Reflection *reflection = message.GetReflection(); + const google::protobuf::OneofDescriptor *oneofDescriptor = descriptor->FindOneofByName("body"); + const google::protobuf::FieldDescriptor *fieldDescriptor = + reflection->GetOneofFieldDescriptor(message, oneofDescriptor); + if (fieldDescriptor) { + const google::protobuf::Message &bodyField = reflection->GetMessage(message, fieldDescriptor); + std::stringstream ss; + SHA256AndHex(bodyField.DebugString(), ss); + return std::make_pair(ss.str(), true); + } else { + YRLOG_ERROR("failed to get body string of message: {}", message.messageid()); + return std::make_pair("", false); + } +} + +bool SignStreamingMessage(const std::string &accessKey, const SensitiveValue &secretKey, + ::runtime_rpc::StreamingMessage &message) +{ + auto timestamp = GetCurrentUTCTime(); + auto [body, serializeSuccess] = SerializeBodyToString(message); + if (body.empty() && !serializeSuccess) { + YRLOG_ERROR("body is empty and serialize failed, message is {}", message.DebugString()); + return false; + } + std::string signKey = accessKey + ":" + timestamp + ":" + body; + auto signature = GetHMACSha256(secretKey, signKey); + + (*message.mutable_metadata())["access_key"] = accessKey; + (*message.mutable_metadata())["signature"] = signature; + (*message.mutable_metadata())["timestamp"] = timestamp; + return true; +} + +std::string SignStreamingMessage(const std::string &accessKey, const SensitiveValue &secretKey) +{ + auto timestamp = GetCurrentUTCTime(); + std::string signKey = accessKey + ":" + timestamp; + auto signature = GetHMACSha256(secretKey, signKey); + return signature; +} + +bool VerifyStreamingMessage(const std::string &accessKey, const SensitiveValue &secretKey, + const ::runtime_rpc::StreamingMessage &message) +{ + auto timestamp = message.metadata().find("timestamp"); + if (timestamp == message.metadata().end() || timestamp->second.empty()) { + YRLOG_ERROR("failed to verify message: {}, failed to find timestamp in meta-data", message.DebugString()); + return false; + } + + auto currentTimestamp = GetCurrentUTCTime(); + if (IsLaterThan(currentTimestamp, timestamp->second, TIMESTAMP_EXPIRE_DURATION_SECONDS)) { + YRLOG_ERROR("failed to verify message: {}, failed to verify timestamp, difference is more than 1 min {} vs {}", + message.messageid(), currentTimestamp, timestamp->second); + return false; + } + + auto signature = message.metadata().find("signature"); + if (signature == message.metadata().end() || signature->second.empty()) { + YRLOG_ERROR("failed to verify message: {}, failed to find signature in meta-data", message.DebugString()); + return false; + } + + std::string signKey = accessKey + ":" + timestamp->second + ":" + SerializeBodyToString(message).first; + if (GetHMACSha256(secretKey, signKey) != signature->second) { + YRLOG_ERROR("failed to verify message: {},", message.DebugString()); + return false; + } + return true; +} + +std::string SignTimestamp(const std::string &accessKey, const SensitiveValue &secretKey, const std::string ×tamp) +{ + return GetHMACSha256(secretKey, accessKey + ":" + timestamp); +} + +bool VerifyTimestamp(const std::string &accessKey, const SensitiveValue &secretKey, const std::string ×tamp, + const std::string &signature) +{ + auto currentTimestamp = GetCurrentUTCTime(); + if (IsLaterThan(currentTimestamp, timestamp, TIMESTAMP_EXPIRE_DURATION_SECONDS)) { + YRLOG_ERROR("failed to verify timestamp, difference is more than 1 min {} vs {}", currentTimestamp, timestamp); + return false; + } + + std::string signKey = accessKey + ":" + timestamp; + if (GetHMACSha256(secretKey, signKey) != signature) { + YRLOG_ERROR("failed to verify timestamp, signature isn't the same"); + return false; + } + return true; +} + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/utils/grpc_utils.h b/src/libruntime/utils/grpc_utils.h new file mode 100644 index 0000000..845c010 --- /dev/null +++ b/src/libruntime/utils/grpc_utils.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "datasystem/utils/sensitive_value.h" +#include "src/libruntime/fsclient/protobuf/runtime_rpc.grpc.pb.h" + +namespace YR { +namespace Libruntime { +using SensitiveValue = datasystem::SensitiveValue; +std::pair SerializeBodyToString(const ::runtime_rpc::StreamingMessage &message); + +bool SignStreamingMessage(const std::string &accessKey, const SensitiveValue &secretKey, + ::runtime_rpc::StreamingMessage &message); + +std::string SignStreamingMessage(const std::string &accessKey, const SensitiveValue &secretKey); + +bool VerifyStreamingMessage(const std::string &accessKey, const SensitiveValue &secretKey, + const ::runtime_rpc::StreamingMessage &message); + +std::string SignTimestamp(const std::string &accessKey, const SensitiveValue &secretKey, const std::string ×tamp); + +bool VerifyTimestamp(const std::string &accessKey, const SensitiveValue &secretKey, const std::string ×tamp, + const std::string &signature); + +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/utils/hash_utils.cpp b/src/libruntime/utils/hash_utils.cpp new file mode 100644 index 0000000..3972f9d --- /dev/null +++ b/src/libruntime/utils/hash_utils.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hash_utils.h" + +#include +#include +#include +#include + +namespace YR { +namespace Libruntime { +const static unsigned int CHAR_TO_HEX = 2; +const static std::string HEX_STRING_SET = "0123456789abcdef"; +const int32_t FIRST_FOUR_BIT_MOVE = 4; +std::string GetHMACSha256(const SensitiveValue &key, const std::string &data) +{ + HMAC_CTX *ctx = HMAC_CTX_new(); + HMAC_Init_ex(ctx, key.GetData(), static_cast(key.GetSize()), EVP_sha256(), nullptr); + HMAC_Update(ctx, reinterpret_cast(&data[0]), data.length()); + unsigned int mdLength = EVP_MAX_MD_SIZE; + unsigned char md[EVP_MAX_MD_SIZE]; + HMAC_Final(ctx, md, &mdLength); + HMAC_CTX_free(ctx); + std::stringstream ss; + ss << std::hex << std::setfill('0'); + for (unsigned int i = 0; i < mdLength; i++) { + ss << std::setw(CHAR_TO_HEX) << static_cast(md[i]); + } + return ss.str(); +} + +void SHA256AndHex(const std::string &input, std::stringstream &output) +{ + unsigned char sha256Chars[SHA256_DIGEST_LENGTH]; + SHA256(reinterpret_cast(input.c_str()), input.size(), sha256Chars); + for (const auto &c : sha256Chars) { + output << HEX_STRING_SET[c >> FIRST_FOUR_BIT_MOVE] << HEX_STRING_SET[c & 0xf]; + } + output << "\n"; +} +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/utils/hash_utils.h b/src/libruntime/utils/hash_utils.h new file mode 100644 index 0000000..dc89db4 --- /dev/null +++ b/src/libruntime/utils/hash_utils.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include "datasystem/utils/sensitive_value.h" + +namespace YR { +namespace Libruntime { +using SensitiveValue = datasystem::SensitiveValue; +std::string GetHMACSha256(const SensitiveValue &key, const std::string &data); +void SHA256AndHex(const std::string &input, std::stringstream &output); +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/utils/http_utils.cpp b/src/libruntime/utils/http_utils.cpp new file mode 100644 index 0000000..350b165 --- /dev/null +++ b/src/libruntime/utils/http_utils.cpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "http_utils.h" +namespace YR { +namespace Libruntime { +std::string HashToHex(const std::string &message) +{ + std::stringstream ss; + SHA256AndHex(message, ss); + return ss.str(); +} + +std::string GenerateRequestDigest(const std::unordered_map &headers, const std::string &body, + const std::string &url) +{ + std::stringstream ss; + ss << url << "\n"; + if (auto iter = headers.find(TIMESTAMP_KEY); iter != headers.end()) { + ss << iter->first << ": " << iter->second << "\n"; + } + if (auto iter = headers.find(ACCESS_KEY); iter != headers.end()) { + ss << iter->first << ": " << iter->second << "\n"; + } + ss << body; + return ss.str(); +} + +void SignHttpRequest(const std::string &accessKey, const SensitiveValue &secretKey, + std::unordered_map &headers, const std::string &body, + const std::string &url) +{ + auto timestamp = GetCurrentUTCTime(); + headers[TIMESTAMP_KEY] = timestamp; + headers[ACCESS_KEY] = accessKey; + + auto digest = GenerateRequestDigest(headers, body, url); + auto digestHashHex = HashToHex(digest); + auto signature = GetHMACSha256(secretKey, digestHashHex); + std::stringstream rss; + rss << "HMAC-SHA256 timestamp=" << headers.at(TIMESTAMP_KEY) << ",access_key=" << headers.at(ACCESS_KEY) + << ",signature=" << signature; + headers[AUTHORIZATION_KEY] = rss.str(); +} + +bool VerifyHttpRequest(const std::string &accessKey, const SensitiveValue &secretKey, + std::unordered_map &headers, const std::string &body, + const std::string &url) +{ + auto tenantAccessKey = headers.find(ACCESS_KEY); + if (tenantAccessKey == headers.end() || tenantAccessKey->second.empty()) { + YRLOG_ERROR("failed to verify http request: failed to find ACCESS_KEY in headers"); + return false; + } + auto timestamp = headers.find(TIMESTAMP_KEY); + if (timestamp == headers.end() || timestamp->second.empty()) { + YRLOG_ERROR("failed to verify http request: failed to find TIMESTAMP in headers"); + return false; + } + + auto currentTimestamp = GetCurrentUTCTime(); + if (IsLaterThan(currentTimestamp, timestamp->second, TIMESTAMP_EXPIRE_DURATION_SECONDS)) { + YRLOG_ERROR("failed to verify http request: failed to verify timestamp, difference is more than 1 min {} vs {}", + currentTimestamp, timestamp->second); + return false; + } + + auto authorizationValue = headers.find(AUTHORIZATION_KEY); + if (authorizationValue == headers.end() || authorizationValue->second.empty()) { + YRLOG_ERROR("failed to verify http request: failed to find Authorization in headers"); + return false; + } + std::string key = "signature="; + size_t pos = authorizationValue->second.find(key); + if (pos == std::string::npos) { + return false; + } + std::string signature = authorizationValue->second.substr(pos + key.length()); + auto digest = GenerateRequestDigest(headers, body, url); + auto digestHashHex = HashToHex(digest); + if (GetHMACSha256(secretKey, digestHashHex) != signature) { + YRLOG_ERROR("failed to verify http request"); + return false; + } + return true; +} +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/utils/http_utils.h b/src/libruntime/utils/http_utils.h new file mode 100644 index 0000000..15727d3 --- /dev/null +++ b/src/libruntime/utils/http_utils.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include "datasystem/utils/sensitive_value.h" +#include "src/libruntime/utils/hash_utils.h" +#include "src/libruntime/utils/utils.h" +namespace YR { +namespace Libruntime { +const double TIMESTAMP_EXPIRE_DURATION_SECONDS = 60; +const std::string TRACE_ID_KEY_NEW = "X-Trace-Id"; +const std::string AUTHORIZATION_KEY = "Authorization"; +const std::string INSTANCE_CPU_KEY = "X-Instance-Cpu"; +const std::string INSTANCE_MEMORY_KEY = "X-Instance-Memory"; +const std::string TIMESTAMP_KEY = "X-Timestamp"; +const std::string ACCESS_KEY = "X-Access-Key"; +using SensitiveValue = datasystem::SensitiveValue; + +std::string HashToHex(const std::string &message); + +std::string GenerateRequestDigest(const std::unordered_map &headers, const std::string &body, + const std::string &url); + +void SignHttpRequest(const std::string &accessKey, const SensitiveValue &secretKey, + std::unordered_map &headers, const std::string &body, + const std::string &url); + +bool VerifyHttpRequest(const std::string &accessKey, const SensitiveValue &secretKey, + std::unordered_map &headers, const std::string &body, + const std::string &url); +} // namespace Libruntime +} // namespace YR \ No newline at end of file diff --git a/src/libruntime/utils/security.cpp b/src/libruntime/utils/security.cpp index 0bfa257..6231bd5 100644 --- a/src/libruntime/utils/security.cpp +++ b/src/libruntime/utils/security.cpp @@ -44,7 +44,7 @@ ErrorInfo Security::Init() } this->streamReaderThread_ = std::make_unique([this]() { - boost::asio::io_service::work work(this->streamReaderIoContext_); + auto work = boost::asio::make_work_guard(this->streamReaderIoContext_); this->streamReaderIoContext_.run(); }); @@ -73,13 +73,13 @@ ErrorInfo Security::Init() ErrorInfo Security::InitWithDriver(std::shared_ptr librtConfig) { - YRLOG_DEBUG("When init security as driver, enableMTLS is {}, enableAuth is {}", - librtConfig->enableMTLS, librtConfig->enableAuth); + YRLOG_DEBUG("When init security as driver, enableMTLS is {}, enableAuth is {}, ak is empty: {}, sk is empty: {}", + librtConfig->enableMTLS, librtConfig->enableAuth, librtConfig->ak_.empty(), librtConfig->sk_.Empty()); if (librtConfig->enableMTLS) { this->fsConf_.authEnable = librtConfig->enableMTLS; STACK_OF(X509) *ca = GetCAFromFile(librtConfig->verifyFilePath); X509 *cert = GetCertFromFile(librtConfig->certificateFilePath); - EVP_PKEY *pkey = GetPrivateKeyFromFile(librtConfig->privateKeyPath, nullptr); + EVP_PKEY *pkey = GetPrivateKeyFromFile(librtConfig->privateKeyPath, librtConfig->privateKeyPaaswd); this->fsConf_.rootCertData = GetCa(ca); this->fsConf_.certChainData = GetCert(cert); this->fsConf_.privateKeyData = GetPrivateKey(pkey); @@ -88,9 +88,15 @@ ErrorInfo Security::InitWithDriver(std::shared_ptr librtConfig } if (librtConfig->encryptEnable) { this->dsConf_.encryptEnable = librtConfig->encryptEnable; - this->dsConf_.clientPublicKey = GetValueFromFile(librtConfig->runtimePublicKeyPath); - this->dsConf_.clientPrivateKey = GetValueFromFile(librtConfig->runtimePrivateKeyPath); - this->dsConf_.serverPublicKey = GetValueFromFile(librtConfig->dsPublicKeyPath); + this->dsConf_.clientPublicKey = librtConfig->runtimePublicKey; + this->dsConf_.clientPrivateKey = librtConfig->runtimePrivateKey; + this->dsConf_.serverPublicKey = librtConfig->dsPublicKey; + } + if (!librtConfig->ak_.empty() && !librtConfig->sk_.Empty()) { + this->ak_ = librtConfig->ak_; + this->sk_ = SensitiveValue(librtConfig->sk_); + this->isCredential_ = true; + this->dsConf_.authEnable = true; } return ErrorInfo(); } @@ -134,6 +140,16 @@ void Security::StreamReaderWaitHandler(const boost::system::error_code &error) } if (!this->ReadOnce()) { YRLOG_DEBUG_COUNT(LOG_FREQUENT, "Reader read once failed"); + } else { + // in wait handler, token is expected to be updated, so call all token handlers + for (auto &h : this->updateTokenHandlers) { + h(this->token_); + } + YRLOG_INFO("token is updated."); + for (auto &h : this->updateAkSkHandlers) { + h(this->ak_, this->sk_); + } + YRLOG_INFO("ak,sk is updated."); } confStreamDesc_.async_wait(boost::asio::posix::stream_descriptor::wait_read, boost::bind(&Security::StreamReaderWaitHandler, this, boost::asio::placeholders::error)); @@ -205,24 +221,36 @@ bool Security::ReadOnce() this->dsConf_.serverPublicKey = tlsConf.dsserverpublickey(); this->fsConf_.authEnable = tlsConf.serverauthenable(); - if (this->fsConf_.authEnable) { - auto caCertFile = Config::Instance().YR_SSL_ROOT_FILE(); - auto certFile = Config::Instance().YR_SSL_CERT_FILE(); - auto keyFile = Config::Instance().YR_SSL_KEY_FILE(); - STACK_OF(X509) *ca = GetCAFromFile(caCertFile); - X509 *cert = GetCertFromFile(certFile); - EVP_PKEY *privateKey = GetPrivateKeyFromFile(keyFile, nullptr); - this->fsConf_.rootCertData = GetCa(ca); - this->fsConf_.certChainData = GetCert(cert); - this->fsConf_.privateKeyData = GetPrivateKey(privateKey); - ClearPemCerts(privateKey, cert, ca); + this->fsConf_.rootCertData = tlsConf.rootcertdata(); + + if (tlsConf.has_tenantcredentials() && !tlsConf.tenantcredentials().accesskey().empty()) { + this->ak_ = tlsConf.tenantcredentials().accesskey(); + } else { + this->ak_ = tlsConf.accesskey(); + } + + if (tlsConf.has_tenantcredentials() && !tlsConf.tenantcredentials().secretkey().empty()) { + this->sk_ = SensitiveValue(tlsConf.tenantcredentials().secretkey()); + } else { + this->sk_ = SensitiveValue(tlsConf.securitykey()); + } + + if (tlsConf.has_tenantcredentials()) { + this->dk_ = tlsConf.tenantcredentials().datakey(); } + if (tlsConf.has_tenantcredentials()) { + this->isCredential_ = tlsConf.tenantcredentials().iscredential(); + } + + this->token_ = SensitiveValue(tlsConf.token()); + this->fsConnMode_ = tlsConf.enableservermode(); this->serverNameoverride_ = tlsConf.servernameoverride(); - YRLOG_INFO("Read tls config finished, fs auth: {}, ds auth: {}", this->fsConf_.authEnable, - this->dsConf_.authEnable); + YRLOG_INFO("Read tls config finished, fs auth: {}, ds auth: {}, is credential {}, ak {}, sk {}, token {}", + this->fsConf_.authEnable, this->dsConf_.authEnable, isCredential_, !ak_.empty(), !sk_.Empty(), + !token_.Empty()); return true; } @@ -265,10 +293,57 @@ bool Security::GetFunctionSystemConfig(std::string &rootCACert, std::string &cer return this->fsConf_.authEnable; } +void Security::GetToken(SensitiveValue &token) +{ + token = this->token_; +} + +void Security::GetAKSK(std::string &ak, SensitiveValue &sk) +{ + ak = this->ak_; + sk = this->sk_; +} + +void Security::WhenTokenUpdated(std::function updateTokenHandler) +{ + if (ak_.empty() && sk_.Empty()) { + this->updateTokenHandlers.push_back(updateTokenHandler); + } +} + +void Security::WhenAkSkUpdated(std::function updateAkSkHandler) +{ + if (!ak_.empty() && !sk_.Empty()) { + this->updateAkSkHandlers.push_back(updateAkSkHandler); + } +} + void Security::ClearPrivateKey() { this->fsConf_.privateKeyData.Clear(); } +int Security::GetUpdateHandersSize() +{ + return this->updateTokenHandlers.size(); +} + +bool Security::IsFsAuthEnable() +{ + return this->isCredential_ && !this->ak_.empty() && !this->sk_.Empty(); +} + +Credential Security::GetCredential() +{ + return Credential{ak : this->ak_, sk : std::string(this->sk_.GetData(), this->sk_.GetSize()), dk : this->dk_}; +} + +void Security::SetAKSKAndCredential(const std::string &ak, const SensitiveValue &sk) +{ + this->ak_ = ak; + this->sk_ = sk; + this->isCredential_ = true; +} + } // namespace Libruntime } // namespace YR \ No newline at end of file diff --git a/src/libruntime/utils/security.h b/src/libruntime/utils/security.h index 143d6ea..51468c7 100644 --- a/src/libruntime/utils/security.h +++ b/src/libruntime/utils/security.h @@ -67,6 +67,35 @@ public: */ bool GetFunctionSystemConfig(std::string &rootCACert, std::string &certChain, std::string &privateKey); + /** + * @brief Get the Token + * + * @param token runtime's session token for build connections with function system and data system + */ + void GetToken(SensitiveValue &token); + + /** + * @brief Get the AK and SK + * + * @param ak system function runtime's access key for build connections with data system + * @param sk system function runtime's security key for build connections with data system + */ + virtual void GetAKSK(std::string &ak, SensitiveValue &sk); + + /** + * @brief register handler when token updated + * + * @param updateTokenHandler registered token updated event handler, input parameter is the token updated + */ + void WhenTokenUpdated(std::function updateTokenHandler); + + /** + * @brief register handler when token updated + * + * @param updateAkSkHandler registered aksk updated event handler, input parameter is the aksk updated + */ + void WhenAkSkUpdated(std::function updateAkSkHandler); + /** * @brief clear private key */ @@ -74,6 +103,10 @@ public: int GetUpdateHandersSize(); + virtual bool IsFsAuthEnable(); + + Credential GetCredential(); + Security(int confFileNo = STDIN_FILENO, size_t stdinPipeTimeoutMs = DEFAULT_STDIN_PIPE_TIMEOUT_MS); virtual ~Security(); @@ -81,6 +114,9 @@ public: virtual ErrorInfo Init(); virtual ErrorInfo InitWithDriver(std::shared_ptr librtConfig); + + virtual void SetAKSKAndCredential(const std::string &ak, const SensitiveValue &sk); + std::string GetValueFromFile(const std::string &path); private: @@ -115,8 +151,16 @@ private: SensitiveData privateKeyData; }; FunctionSystemSecurityConfig fsConf_; + SensitiveValue token_ = ""; + std::string ak_ = ""; + SensitiveValue sk_ = ""; + std::string dk_ = ""; + bool isCredential_ = false; // true means runtime auth with ds and fs use ak、sk bool fsConnMode_ = false; // false means runtime is server, function system is client std::string serverNameoverride_ = ""; + + std::list> updateTokenHandlers; + std::list> updateAkSkHandlers; size_t stdinPipeTimeoutMs_; }; } // namespace Libruntime diff --git a/src/libruntime/utils/utils.cpp b/src/libruntime/utils/utils.cpp index f107125..fe6d8e5 100644 --- a/src/libruntime/utils/utils.cpp +++ b/src/libruntime/utils/utils.cpp @@ -68,6 +68,13 @@ long long GetCurrentTimestampMs() return std::chrono::duration_cast(duration).count(); } +long long GetCurrentTimestampNs() +{ + auto now = std::chrono::system_clock::now(); + auto duration = now.time_since_epoch(); + return std::chrono::duration_cast(duration).count(); +} + std::string GetCurrentUTCTime() { std::ostringstream oss; @@ -176,6 +183,19 @@ std::shared_ptr GetChannelCreds(std::shared_ptr GetServerCreds(std::shared_ptr security) { bool enable = false; @@ -241,4 +261,26 @@ bool WillSizeOverFlow(size_t a, size_t b) { return b > (std::numeric_limits::max() - a); } + +int unhexlify(std::string input, char *ascii) +{ + auto first = input.c_str(); + auto last = first + input.size(); + while (first != last) { + int top = to_int(*first++); + int bot = to_int(*first++); + if (top == -1 or bot == -1) + return -1; // error + *ascii++ = (top << 4) + bot; + } + return 0; +} + +std::string GetEnvValue(const std::string &key) +{ + if (const char *env = std::getenv(key.c_str())) { + return std::string(env); + } + return std::string(""); +} } // namespace YR diff --git a/src/libruntime/utils/utils.h b/src/libruntime/utils/utils.h index 3810a55..11b6ef1 100644 --- a/src/libruntime/utils/utils.h +++ b/src/libruntime/utils/utils.h @@ -45,6 +45,7 @@ IpAddrInfo ParseIpAddr(const std::string &addr); void ParseIpAddr(const std::string &addr, std::string &ip, int32_t &port); std::string GetIpAddr(const std::string &ip, int port); long long GetCurrentTimestampMs(); +long long GetCurrentTimestampNs(); void SetCallResultWithStackTraceInfo(std::vector &infos, CallResult &callResult); std::vector GetStackTraceInfos(const NotifyRequest &req); void GetServerName(std::shared_ptr security, std::string &serverName); @@ -53,9 +54,13 @@ std::shared_ptr GetServerCreds(std::shared_ptr lock(mu); exceptionIds[id] = err; exceptionNum++; @@ -129,7 +129,7 @@ bool WaitingObjectManager::SetReady(const std::string &id) void WaitingObjectManager::SetError(const std::string &id, const ErrorInfo &err) { - YRLOG_DEBUG("set id {}, error {}", id, err.Msg()); + YRLOG_DEBUG("set id {}, msg: {}", id, err.Msg()); std::lock_guard lock(unreadyObjectMapMutex); if (unreadyObjectMap.count(id) > 0) { for (std::shared_ptr &waitingEntity : unreadyObjectMap[id]) { diff --git a/src/proto/libruntime.proto b/src/proto/libruntime.proto index 485c9f4..9217119 100644 --- a/src/proto/libruntime.proto +++ b/src/proto/libruntime.proto @@ -41,7 +41,9 @@ enum LanguageType { enum ApiType { Function = 0; - Posix = 1; + Faas = 1; + Posix = 2; + Serve = 3; } enum Signal { diff --git a/src/scene/downgrade.cpp b/src/scene/downgrade.cpp new file mode 100644 index 0000000..e11aadb --- /dev/null +++ b/src/scene/downgrade.cpp @@ -0,0 +1,226 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "downgrade.h" +namespace YR { +namespace scene { +using namespace YR::utility; +using namespace YR::Libruntime; +DowngradeController::DowngradeController(const std::string &functionId, std::shared_ptr clientsMgr, + std::shared_ptr security) + : functionId_(functionId), clientsMgr_(clientsMgr), security_(security) +{ + YRLOG_DEBUG("[DEBUG] functionId_: {}", functionId); // delete + isFrontendFunction_ = IsFrontendFunction(functionId); + existIngress_ = + !YR::GetEnvValue(FRONTEND_ADDRESS_ENV).empty() && !YR::GetEnvValue(INVOCATION_URL_PREFIX_ENV).empty(); +} + +DowngradeController::~DowngradeController() +{ + Stop(); +} + +YR::Libruntime::ErrorInfo DowngradeController::Init() +{ + if (init_) { + YR::Libruntime::ErrorInfo(); + } + auto watchFileName = YR::GetEnvValue(DOWNGRADE_FILE_ENV); + if (isFrontendFunction_ || watchFileName.empty()) { + YRLOG_DEBUG("is frontend function: {}, or DOWNGRADE_FILE env empty: {}", isFrontendFunction_, watchFileName); + return YR::Libruntime::ErrorInfo(); + } + ParseDowngrade(watchFileName); + watcher_ = std::make_shared( + watchFileName, std::bind(&DowngradeController::ParseDowngrade, this, std::placeholders::_1)); + watcher_->Start(); + init_ = true; + return YR::Libruntime::ErrorInfo(); +} + +YR::Libruntime::ErrorInfo DowngradeController::InitApiClient() +{ + if (apiClient_ != nullptr) { + return YR::Libruntime::ErrorInfo(); + } + apiClient_ = std::make_shared(functionId_, clientsMgr_, security_); + return apiClient_->Init(); +} + +// content: {"downgrade": true} +void DowngradeController::ParseDowngrade(const std::string &fileName) +{ + std::ifstream file(fileName); + if (!file.is_open()) { + YRLOG_DEBUG("{} not exit", fileName); + return; + } + nlohmann::json j; + try { + file >> j; + isDowngradeEnabled_ = j.value("downgrade", false); + YRLOG_DEBUG("update downgrade enable {}", isDowngradeEnabled_); + } catch (const nlohmann::json::parse_error &e) { + YRLOG_WARN("json parse error: {}", e.what()); + isDowngradeEnabled_ = false; + } +} + +bool DowngradeController::ShouldDowngrade(const std::shared_ptr spec) const +{ + if (!existIngress_ || isFrontendFunction_) { + return false; + } + if (spec->downgradeFlag_ || isDowngradeEnabled_) { + return true; + } + return false; +} + +void DowngradeController::Downgrade(const std::shared_ptr spec, const InvocationCallback &callback) +{ + if (spec == nullptr) { + return; + } + std::call_once(flag_, [&, this]() { err_ = InitApiClient(); }); + if (!err_.OK()) { + YRLOG_DEBUG("init api client err {}", err_.CodeAndMsg()); + callback(spec->requestId, err_.Code(), err_.Msg()); + return; + } + YRLOG_DEBUG("downgrade request id: {}", spec->requestId); + apiClient_->InvocationAsync(spec, callback); +} + +bool DowngradeController::ShouldFaultDowngrade() const +{ + return !isFrontendFunction_ && existIngress_; +} + +bool DowngradeController::IsFrontendFunction(const std::string &function) +{ + if (function.size() < FAAS_FRONTEND_FUNCTION_NAME_PREFIX.size()) { + return false; + } + return function.substr(0, FAAS_FRONTEND_FUNCTION_NAME_PREFIX.size()) == FAAS_FRONTEND_FUNCTION_NAME_PREFIX; +} + +void DowngradeController::Stop() +{ + if (!init_) { + return; + } + if (watcher_) { + watcher_->Stop(); + } + init_ = false; +} + +ApiClient::ApiClient(const std::string &functionId, std::shared_ptr clientsMgr, + std::shared_ptr security) + : functionId_(functionId), clientsMgr_(clientsMgr), security_(security) +{ +} + +YR::Libruntime::ErrorInfo ApiClient::Init() +{ + if (init_) { + return YR::Libruntime::ErrorInfo(); + } + businessId_ = YR::GetEnvValue(YR_BUSINESS_ID_ENV); + auto address = YR::GetEnvValue(FRONTEND_ADDRESS_ENV); + if (address.empty()) { + YRLOG_ERROR("YR_FRONTEND_ADDRESS env unset"); + return Libruntime::ErrorInfo(Libruntime::ErrorCode::ERR_PARAM_INVALID, "YR_FRONTEND_ADDRESS env unset"); + } + if (address.find(HTTP_PROTOCOL_PREFIX) == 0) { + address = address.substr(HTTP_PROTOCOL_PREFIX.size()); + enableTLS_ = false; + } else if (address.find(HTTPS_PROTOCOL_PREFIX) == 0) { + address = address.substr(HTTPS_PROTOCOL_PREFIX.size()); + enableTLS_ = true; + } else { + enableTLS_ = false; + } + YR::ParseIpAddr(address, ip_, port_); + std::string host; + if (ip_.empty()) { + port_ = enableTLS_ ? HTTPS_DEFAULT_PORT : HTTP_DEFAULT_PORT; + host = address; + ip_ = host; + YRLOG_DEBUG("ip host {}, port: {}", host, port_); + } + urlPrefix_ = YR::GetEnvValue(INVOCATION_URL_PREFIX_ENV); + if (urlPrefix_.empty()) { + YRLOG_ERROR("YR_INVOCATION_URL_PREFIX env unset"); + return Libruntime::ErrorInfo(Libruntime::ErrorCode::ERR_PARAM_INVALID, "YR_INVOCATION_URL_PREFIX env unset"); + } + auto config = std::make_shared(); + const uint32_t threadNum = 10; + config->httpIocThreadsNum = threadNum; + config->maxConnSize = config->httpIocThreadsNum; + config->enableTLS = enableTLS_; + config->verifyFilePath = YR::GetEnvValue(CERTIFICATE_FILE_ENV); + config->serverName = host; + config->httpIdleTime = YR::Libruntime::Config::Instance().YR_HTTP_IDLE_TIME(); + auto [httpClient, err] = clientsMgr_->GetOrNewHttpClient(ip_, port_, config); + if (!err.OK()) { + YRLOG_ERROR("get or new http client failed, code is {}, msg is {}", fmt::underlying(err.Code()), err.Msg()); + return err; + } + FSIntfHandlers handlers; + gwClient_ = std::make_shared(functionId_, handlers, security_); + gwClient_->Init(httpClient); + init_ = true; + return YR::Libruntime::ErrorInfo(); +} + +void ApiClient::InvocationAsync(const std::shared_ptr spec, const InvocationCallback &callback) +{ + // 8d86c63b22e24d9ab650878b75408ea6/test-jiuwen-session-004-bj/$latest + auto functionId = spec->functionMeta.functionId; + std::vector result; + YR::utility::Split(functionId, result, '/'); + size_t tenantIndex = 0; + size_t functionNameIndex = 1; + size_t functionVersionIndex = 2; + if (result.size() <= functionVersionIndex) { + callback(spec->requestId, ErrorCode::ERR_PARAM_INVALID, "function id invalid " + functionId); + return; + } + // 8d86c63b22e24d9ab650878b75408ea6 + auto tenantId = result[tenantIndex]; + // test-jiuwen-session-004-bj + auto functionName = result[functionNameIndex]; + std::vector results; + YR::utility::Split(functionName, results, '@'); + if (!results.empty()) { + functionName = results.back(); + } + // $latest + auto functionVersion = result[functionVersionIndex]; + if (functionVersion == "latest") { + functionVersion = "$latest"; + } + // /serverless/v2/functions/wisefunction:cn:iot:8d86c63b22e24d9ab65:function:0@faas@cpp:latest/invocations + std::string url = urlPrefix_ + URN_SEPARATOR + businessId_ + URN_SEPARATOR + tenantId + + ":function:" + functionName + URN_SEPARATOR + functionVersion + "/invocations"; + gwClient_->InvocationAsync(url, spec, callback); +} + +} // namespace scene +} // namespace YR \ No newline at end of file diff --git a/src/scene/downgrade.h b/src/scene/downgrade.h new file mode 100644 index 0000000..e233046 --- /dev/null +++ b/src/scene/downgrade.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "src/libruntime/clientsmanager/clients_manager.h" +#include "src/libruntime/gwclient/gw_client.h" +#include "src/libruntime/gwclient/http/async_http_client.h" +#include "src/libruntime/gwclient/http/client_manager.h" +#include "src/libruntime/invoke_spec.h" +#include "src/utility/file_watcher.h" +namespace YR { +namespace scene { +const std::string FAAS_FRONTEND_FUNCTION_NAME_PREFIX = "0/0-system-faasfrontend/"; +const std::string DOWNGRADE_FILE_ENV = "YR_DOWNGRADE_FILE"; +// 1、http://IP:PORT; 2、https://IP:PORT; 3、IP:PORT; 4、https://www.xxx.com; 5、http://www.xxx.com +const std::string FRONTEND_ADDRESS_ENV = "YR_FRONTEND_ADDRESS"; +// "/serverless/v2/functions/wisefunction:cn" +const std::string INVOCATION_URL_PREFIX_ENV = "YR_INVOCATION_URL_PREFIX"; +const std::string URN_SEPARATOR = ":"; +const std::string HTTP_PROTOCOL_PREFIX = "http://"; +const std::string HTTPS_PROTOCOL_PREFIX = "https://"; +const std::string CERTIFICATE_FILE_ENV = "YR_CERTIFICATE_FILE"; +const std::string YR_BUSINESS_ID_ENV = "YR_BUSINESS_ID"; +const int32_t HTTPS_DEFAULT_PORT = 443; +const int32_t HTTP_DEFAULT_PORT = 80; +class ApiClient { +public: + ApiClient(const std::string &functionId, std::shared_ptr clientsMgr, + std::shared_ptr security); + + ~ApiClient() = default; + + YR::Libruntime::ErrorInfo Init(); + + void InvocationAsync(const std::shared_ptr spec, + const YR::Libruntime::InvocationCallback &callback); + +private: + bool init_ = false; + std::shared_ptr gwClient_; + std::string functionId_; + std::shared_ptr clientsMgr_; + std::string ip_; + int32_t port_; + bool enableTLS_{false}; + std::string urlPrefix_; + std::shared_ptr security_; + std::string businessId_; +}; + +class DowngradeController { +public: + explicit DowngradeController(const std::string &functionId, + std::shared_ptr clientsMgr, + std::shared_ptr security); + + ~DowngradeController(); + + YR::Libruntime::ErrorInfo Init(); + + YR::Libruntime::ErrorInfo InitApiClient(); + + // content: {"downgrade": true} + void ParseDowngrade(const std::string &fileName); + + bool ShouldDowngrade(const std::shared_ptr spec) const; + + void Downgrade(const std::shared_ptr spec, + const YR::Libruntime::InvocationCallback &callback); + + bool ShouldFaultDowngrade() const; + + static bool IsFrontendFunction(const std::string &function); + + void Stop(); + +private: + bool init_{false}; + std::shared_ptr watcher_; + bool isFrontendFunction_{false}; + bool existIngress_{false}; + bool isDowngradeEnabled_{false}; + std::string functionId_; + std::shared_ptr clientsMgr_; + std::string ip_; + std::string port_; + std::shared_ptr apiClient_; + std::once_flag flag_; + std::shared_ptr security_; + YR::Libruntime::ErrorInfo err_; +}; + +} // namespace scene +} // namespace YR diff --git a/src/utility/file_watcher.cpp b/src/utility/file_watcher.cpp new file mode 100644 index 0000000..976c399 --- /dev/null +++ b/src/utility/file_watcher.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "file_watcher.h" +namespace YR { +namespace utility { + +FileWatcher::FileWatcher(const std::string &filename, Callback callback) + : filename_(filename), callback_(std::move(callback)), running_(false), fd_(-1), wd_(-1) +{ +} + +FileWatcher::~FileWatcher() +{ + Stop(); +} + +void FileWatcher::Start() +{ + if (running_) { + return; + } + + running_ = true; + watcherThread_ = std::thread([this]() { + while (true) { + if (!running_) { + break; + } + Watch(); + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + }); +} + +void FileWatcher::Stop() +{ + if (!running_) { + return; + } + running_ = false; + if (wd_ >= 0) { + inotify_rm_watch(fd_, wd_); + wd_ = -1; + } + if (fd_ >= 0) { + close(fd_); + fd_ = -1; + } + if (watcherThread_.joinable()) { + watcherThread_.join(); + } +} + +void FileWatcher::Watch() +{ + fd_ = inotify_init1(IN_NONBLOCK); + if (fd_ < 0) { + YRLOG_DEBUG("inotify_init1 failed {}", strerror(errno)); + return; + } + + wd_ = inotify_add_watch(fd_, filename_.c_str(), WATCH_MASK); + if (wd_ < 0) { + YRLOG_DEBUG("inotify_add_watch failed {}", strerror(errno)); + close(fd_); + fd_ = -1; + return; + } + callback_(filename_); + char buffer[EVENT_BUF_LEN]; + while (running_) { + int length = read(fd_, buffer, EVENT_BUF_LEN); + if (length < 0) { + if (errno == EINTR) { // Signal interrupted + continue; + } + if (errno == EAGAIN) { // No data available + std::this_thread::sleep_for(std::chrono::seconds(1)); + continue; + } + YRLOG_WARN("read error: {}", strerror(errno)); + break; + } + + int i = 0; + while (i < length) { + struct inotify_event *event = static_cast(static_cast(&buffer[i])); + if (event->mask & IN_CLOSE_WRITE) { + callback_(filename_); + } + i += sizeof(struct inotify_event) + event->len; + } + } +} +} // namespace utility +} // namespace YR diff --git a/src/utility/file_watcher.h b/src/utility/file_watcher.h new file mode 100644 index 0000000..9a459a3 --- /dev/null +++ b/src/utility/file_watcher.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "src/utility/logger/logger.h" +namespace YR { +namespace utility { +class FileWatcher { +public: + using Callback = std::function; + + FileWatcher(const std::string &filename, Callback callback); + + ~FileWatcher(); + + void Start(); + + void Stop(); + +private: + static constexpr size_t EVENT_BUF_LEN = 1024 * (sizeof(struct inotify_event) + NAME_MAX + 1); + static constexpr uint32_t WATCH_MASK = IN_CLOSE_WRITE; // monitor write completion + + void Watch(); + + std::string filename_; + Callback callback_; + std::atomic running_{false}; + std::thread watcherThread_; + int fd_{-1}; + int wd_{-1}; +}; +} // namespace utility +} // namespace YR diff --git a/src/utility/logger/common.h b/src/utility/logger/common.h index acf3532..1a36799 100644 --- a/src/utility/logger/common.h +++ b/src/utility/logger/common.h @@ -26,9 +26,10 @@ const int DEFAULT_RETENTION_DAYS = 30; const int64_t SIZE_MEGA_BYTES = 1024 * 1024; // 1 MB const int DEFAULT_LOG_BUF_SECONDS = 10; const std::string LOGGER_NAME = "Logger"; +const std::string REPORT_LOGGER_NAME = "interface"; const std::string DEFAULT_JOB_ID = "job-ffffffff"; const std::string LOG_SUFFIX = ".log"; -const unsigned int DEFAULT_MAX_ASYNC_QUEUE_SIZE = 51200; // 1024*50(every log length) +const unsigned int DEFAULT_MAX_ASYNC_QUEUE_SIZE = 1024; // 1024(every log length) const unsigned int DEFAULT_ASYNC_THREAD_COUNT = 1; struct LogParam { @@ -45,6 +46,8 @@ struct LogParam { unsigned int asyncThreadCount = DEFAULT_ASYNC_THREAD_COUNT; bool alsoLog2Stderr = false; bool isLogMerge = false; + bool withLogPrefix = false; + std::string loggerId; }; } // namespace utility } // namespace YR diff --git a/src/utility/logger/log_handler.cpp b/src/utility/logger/log_handler.cpp index f58201d..748b639 100644 --- a/src/utility/logger/log_handler.cpp +++ b/src/utility/logger/log_handler.cpp @@ -98,8 +98,8 @@ void DoLogFileCompress(const LogParam &logParam) // e.g: xxx-function_agent.1.log -> xxx-function_agent.{TIME}.log -> xxx-function_agent.{TIME}.log.gz std::string basename, ext, idx; - std::tie(basename, ext) = spdlog::details::file_helper::split_by_extension(file); - std::tie(basename, idx) = spdlog::details::file_helper::split_by_extension(basename); + std::tie(basename, ext) = yr_spdlog::details::file_helper::split_by_extension(file); + std::tie(basename, idx) = yr_spdlog::details::file_helper::split_by_extension(basename); std::string targetFile = basename + "." + std::to_string(timestamp) + ext; if (!RenameFile(file, targetFile)) { YRLOG_WARN("failed to rename {} to {}", file, targetFile); diff --git a/src/utility/logger/spd_logger.cpp b/src/utility/logger/spd_logger.cpp index f62eabd..1cb3ef2 100644 --- a/src/utility/logger/spd_logger.cpp +++ b/src/utility/logger/spd_logger.cpp @@ -19,8 +19,8 @@ #include #include #include -#include #include +#include #include "spdlog/async.h" #include "spdlog/sinks/dup_filter_sink.h" @@ -29,16 +29,17 @@ namespace YR { namespace utility { static const uint32_t DUP_FILTER_TIME = 60; const std::string DEFAULT_LOG_NAME = "driver"; +const std::string DEFAULT_LOG_NAME_PREFIX = "libruntime"; static const int LOG_NOT_MERGE_TYPE = 0; static const int LOG_MERGE_TYPE = 1; -spdlog::level::level_enum GetLogLevel(const std::string &level) +yr_spdlog::level::level_enum GetLogLevel(const std::string &level) { - static std::map logLevelMap = { - {"TRACE", spdlog::level::trace}, {"DEBUG", spdlog::level::debug}, {"INFO", spdlog::level::info}, - {"WARN", spdlog::level::warn}, {"ERR", spdlog::level::err}, {"FATAL", spdlog::level::critical}}; + static std::map logLevelMap = { + {"TRACE", yr_spdlog::level::trace}, {"DEBUG", yr_spdlog::level::debug}, {"INFO", yr_spdlog::level::info}, + {"WARN", yr_spdlog::level::warn}, {"ERR", yr_spdlog::level::err}, {"FATAL", yr_spdlog::level::critical}}; auto iter = logLevelMap.find(level); - return iter == logLevelMap.end() ? spdlog::level::info : iter->second; + return iter == logLevelMap.end() ? yr_spdlog::level::info : iter->second; } std::string FormatTimePoint() @@ -71,10 +72,10 @@ std::string SpdLogger::GetModelName(void) const return this->modelName; } -std::pair, std::string> SpdLogger::GetLogger() +std::pair, std::string> SpdLogger::GetLogger() { std::string loggerName = this->getLoggerNameFunc ? getLoggerNameFunc() : LOGGER_NAME; - auto logger = spdlog::get(loggerName); + auto logger = yr_spdlog::get(loggerName); std::string logPrefix = ""; if (logMergeType_.load() == LOG_MERGE_TYPE) { GetLogPrefix(loggerName, logPrefix); @@ -103,7 +104,15 @@ std::string SpdLogger::GetLogFile(const LogParam &logParam) } else if (logParam.logFileWithTime) { logFile += Join({logParam.nodeName, logParam.modelName, FormatTimePoint()}, "-") + LOG_SUFFIX; } else { - logFile += Join({logParam.nodeName, logParam.modelName}, "-") + LOG_SUFFIX; + if (logParam.modelName == DEFAULT_LOG_NAME) { + logFile += Join({logParam.nodeName, logParam.modelName}, "-") + LOG_SUFFIX; + } else { + if (logParam.withLogPrefix) { + logFile += Join({logParam.loggerId, DEFAULT_LOG_NAME_PREFIX}, "_") + LOG_SUFFIX; + } else { + logFile += Join({logParam.nodeName, logParam.modelName}, "-") + LOG_SUFFIX; + } + } } return logFile; @@ -136,7 +145,7 @@ void SpdLogger::CreateLogger(const LogParam &logParam, const std::string &nodeNa InitAsyncThread(logParam); RegisterLogger(logParam, LOGGER_NAME, nodeName, modelName, logFile); } - } catch (const spdlog::spdlog_ex &ex) { + } catch (const yr_spdlog::spdlog_ex &ex) { std::cout << "failed to init logger:" << ex.what() << std::endl << std::flush; } } @@ -145,27 +154,28 @@ void SpdLogger::RegisterLogger(const LogParam &logParam, const std::string &logg const std::string &modelName, const std::string &logFile) { absl::WriterMutexLock lock(&spdLoggerMu_); - auto logger = spdlog::get(loggerName); + auto logger = yr_spdlog::get(loggerName); if (logger) { if (!logParam.isLogMerge) { - spdlog::drop(loggerName); + yr_spdlog::drop(loggerName); } else { return; } } if (sinks.empty()) { - auto rotatingSink = std::make_shared( + auto rotatingSink = std::make_shared( logFile, logParam.maxSize * SIZE_MEGA_BYTES, logParam.maxFiles); - auto dupFilter = std::make_shared(std::chrono::seconds(DUP_FILTER_TIME)); + auto dupFilter = std::make_shared(std::chrono::seconds(DUP_FILTER_TIME)); sinks = {rotatingSink, dupFilter}; if (logParam.alsoLog2Stderr) { - auto consoleSink = std::make_shared(); + auto consoleSink = std::make_shared(); (void)sinks.emplace_back(consoleSink); } } - logger = std::make_shared(loggerName, sinks.begin(), sinks.end(), spdlog::thread_pool(), - spdlog::async_overflow_policy::block); + logger = std::make_shared(loggerName, sinks.begin(), sinks.end(), + yr_spdlog::thread_pool(), + yr_spdlog::async_overflow_policy::block); logLevel = GetLogLevel(logParam.logLevel); logger->set_level(logLevel); @@ -175,21 +185,21 @@ void SpdLogger::RegisterLogger(const LogParam &logParam, const std::string &logg if (logMergeType_.load() == LOG_MERGE_TYPE) { pattern = "%L%m%d %H:%M:%S.%f %t %s:%#] %P,%!]%v"; } - logger->set_pattern(pattern, spdlog::pattern_time_type::utc); // log with international UTC time + logger->set_pattern(pattern, yr_spdlog::pattern_time_type::utc); // log with international UTC time - spdlog::register_logger(logger); + yr_spdlog::register_logger(logger); } void SpdLogger::Flush() { std::string loggerName = this->getLoggerNameFunc ? getLoggerNameFunc() : LOGGER_NAME; - auto logger = spdlog::get(loggerName); + auto logger = yr_spdlog::get(loggerName); if (logger) { logger->flush(); } } -spdlog::level::level_enum SpdLogger::level() +yr_spdlog::level::level_enum SpdLogger::level() { return logLevel; } @@ -217,7 +227,7 @@ void SpdLogger::GetLogPrefix(const std::string &key, std::string &value) void SpdLogger::Clear() { Flush(); - spdlog::drop_all(); + yr_spdlog::drop_all(); sinks.clear(); } @@ -226,12 +236,10 @@ void SpdLogger::InitAsyncThread(const LogParam &logParam) static std::once_flag onceflag; std::call_once(onceflag, [logParam]() { try { - if (!spdlog::thread_pool()) { - spdlog::init_thread_pool(static_cast(logParam.maxAsyncQueueSize), - static_cast(logParam.asyncThreadCount)); - } - spdlog::flush_every(std::chrono::seconds(logParam.logBufSecs)); - } catch (const spdlog::spdlog_ex &ex) { + yr_spdlog::init_thread_pool(static_cast(logParam.maxAsyncQueueSize), + static_cast(logParam.asyncThreadCount)); + yr_spdlog::flush_every(std::chrono::seconds(logParam.logBufSecs)); + } catch (const yr_spdlog::spdlog_ex &ex) { std::cout << "failed to init logger thread pool:" << ex.what() << std::endl << std::flush; } }); diff --git a/src/utility/logger/spd_logger.h b/src/utility/logger/spd_logger.h index 1bc61fa..2132cb1 100644 --- a/src/utility/logger/spd_logger.h +++ b/src/utility/logger/spd_logger.h @@ -15,7 +15,6 @@ */ #pragma once - #include #include #include @@ -33,7 +32,7 @@ namespace YR { namespace utility { extern const std::string DEFAULT_LOG_NAME; -spdlog::level::level_enum GetLogLevel(const std::string &level); +yr_spdlog::level::level_enum GetLogLevel(const std::string &level); using GetLoggerNameFunc = std::function; class SpdLogger : public Singleton { @@ -41,10 +40,10 @@ public: SpdLogger() = default; virtual ~SpdLogger(); - std::pair, std::string> GetLogger(); + std::pair, std::string> GetLogger(); void CreateLogger(const LogParam &logParam, const std::string &nodeName, const std::string &modelName); void Flush(); - spdlog::level::level_enum level(); + yr_spdlog::level::level_enum level(); std::string GetLogDir(void) const; std::string GetNodeName(void) const; std::string GetModelName(void) const; @@ -63,8 +62,8 @@ private: std::string logDir; std::string nodeName; std::string modelName; - spdlog::level::level_enum logLevel; - std::vector sinks; + yr_spdlog::level::level_enum logLevel; + std::vector sinks; std::atomic logMergeType_{-1}; // -1: default value;0: not merge log;1: merge log std::unordered_map logPrefixMap_ ABSL_GUARDED_BY(mu_); absl::Mutex mu_; @@ -92,6 +91,6 @@ private: auto lvl = YR::utility::GetLogLevel(level); \ YRLOG_ASYNC(lvl, logPrefix, logger, format, ##__VA_ARGS__); \ if (strcmp(level, "FATAL") == 0) { \ - (void)raise(SIGINT); \ + (void)raise(SIGINT); \ } \ } while (0) diff --git a/src/utility/memory.cpp b/src/utility/memory.cpp index 90ce0bf..07d589d 100644 --- a/src/utility/memory.cpp +++ b/src/utility/memory.cpp @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "memory.h" #include diff --git a/src/utility/memory.h b/src/utility/memory.h index dc58b09..4519bff 100644 --- a/src/utility/memory.h +++ b/src/utility/memory.h @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/src/utility/timer_worker.cpp b/src/utility/timer_worker.cpp index 3eb6245..305bdbc 100644 --- a/src/utility/timer_worker.cpp +++ b/src/utility/timer_worker.cpp @@ -20,13 +20,14 @@ #include #include "logger/logger.h" +const size_t timerIdLength = 16; namespace YR { namespace utility { using namespace boost::placeholders; static std::shared_ptr timerWorker = nullptr; -Timer::Timer(boost::asio::io_service &io, int timeoutMS, std::weak_ptr tw) : weakTW_(tw) +Timer::Timer(boost::asio::io_context &io, int timeoutMS, std::weak_ptr tw) : weakTW_(tw) { id_ = std::to_string(counter_.fetch_add(1)); timer_ = std::make_shared(io, boost::posix_time::milliseconds(timeoutMS)); @@ -39,9 +40,16 @@ std::shared_ptr &Timer::GetTimer() void Timer::cancel() { - timer_->cancel(); if (auto tw = weakTW_.lock(); tw) { - tw->EarseTimer(shared_from_this()); + timer_->cancel(); + std::shared_ptr self; + try { + self = shared_from_this(); + } catch (const std::exception &e) { + YRLOG_ERROR("Timer has been destructed, {}", e.what()); + return; + } + tw->EarseTimer(self); } } @@ -58,7 +66,9 @@ std::string Timer::ID() TimerWorker::TimerWorker() : isRunning(true) { th = std::thread([&] { - work.reset(new boost::asio::io_service::work(io)); + work = std::make_unique>( + boost::asio::make_work_guard(io)); + ; io.run(); }); pthread_setname_np(th.native_handle(), "TimerWorker"); @@ -201,6 +211,9 @@ void TimerWorker::Stop() { absl::WriterMutexLock lock(&this->mu); isRunning = false; + if (work) { + work->reset(); + } if (!io.stopped()) { io.stop(); } diff --git a/src/utility/timer_worker.h b/src/utility/timer_worker.h index 7983b20..9a110af 100644 --- a/src/utility/timer_worker.h +++ b/src/utility/timer_worker.h @@ -18,6 +18,7 @@ #include #include "absl/synchronization/mutex.h" #include + using BoostTimer = boost::asio::deadline_timer; namespace YR { @@ -27,7 +28,7 @@ class TimerWorker; class Timer : public std::enable_shared_from_this { public: - Timer(boost::asio::io_service &io, int timeoutMS, std::weak_ptr tw); + Timer(boost::asio::io_context &io, int timeoutMS, std::weak_ptr tw); ~Timer() = default; std::shared_ptr &GetTimer(); void cancel(); @@ -61,8 +62,8 @@ private: absl::Mutex timerMu; bool isRunning ABSL_GUARDED_BY(mu); std::thread th; - boost::asio::io_service io ABSL_GUARDED_BY(mu); - std::unique_ptr work; + boost::asio::io_context io ABSL_GUARDED_BY(mu); + std::unique_ptr> work; std::unordered_map> timerStore_ ABSL_GUARDED_BY(timerMu); }; diff --git a/test/BUILD.bazel b/test/BUILD.bazel index e9094f1..2a5e505 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -10,6 +10,7 @@ cc_test( "utility/thread_pool_test.cpp", "utility/string_utility_test.cpp", "utility/memory_test.cpp", + "utility/file_watcher_test.cpp", ], deps = [ "@gtest//:gtest_main", @@ -34,40 +35,113 @@ cc_test( linkstatic = True, ) + cc_test( - name = "libruntime_test", + name = "traceadaptor_test", + srcs = [ + "libruntime/trace_adapter_test.cpp", + ], + deps = [ + "@gtest//:gtest_main", + "//src/utility:yr_utils", + "//:runtime_lib", + ], + dynamic_deps = ["//:grpc_dynamic"], + linkstatic = True, +) + +cc_test( + name = "gwclient_test", + srcs = [ + "libruntime/gw_client_test.cpp", + "libruntime/clients_manager_test.cpp", + "libruntime/httpserver/common_server.h", + "libruntime/httpserver/common_server.cpp", + "libruntime/httpserver/async_http_server.h", + "libruntime/httpserver/async_https_server.h", + ], + deps = [ + "@gtest//:gtest_main", + "//src/utility:yr_utils", + "//test/libruntime/mock:mock_datasystem", + "//:runtime_lib", + ], + copts = [ + "-fno-access-control", + ], + dynamic_deps = ["//:grpc_dynamic"], + linkstatic = True, +) + +cc_test( + name = "fsintf_test", srcs = [ - "libruntime/invoke_adaptor_test.cpp", - "libruntime/libruntime_manager_test.cpp", - "libruntime/libruntime_test.cpp", - "libruntime/fm_client_test.cpp", "libruntime/fs_intf_impl_test.cpp", + "libruntime/fs_intf_manager_test.cpp", + "libruntime/fs_intf_grpc_rw_test.cpp", + "libruntime/connect_test.cpp", + "libruntime/grpc_utils_test.cpp", + ], + deps = [ + "@gtest//:gtest_main", + "//src/utility:yr_utils", + "//test/libruntime/mock:mock_datasystem", + "//:runtime_lib", + ], + copts = [ + "-fno-access-control", + ], + dynamic_deps = ["//:grpc_dynamic"], + linkstatic = True, +) + +cc_test( + name = "memorystore_test", + srcs = [ "libruntime/waiting_object_manager_test.cpp", "libruntime/memory_store_test.cpp", "libruntime/kv_state_store_test.cpp", + "libruntime/stream_store_test.cpp", "libruntime/hetero_future_test.cpp", "libruntime/buffer_test.cpp", + "libruntime/async_decre_ref_test.cpp", + "libruntime/object_store_test.cpp", + "libruntime/hetero_store_test.cpp", + ], + deps = [ + "@gtest//:gtest_main", + "//src/utility:yr_utils", + "//test/libruntime/mock:mock_datasystem", + "//:runtime_lib", + ], + copts = [ + "-fno-access-control", + ], + dynamic_deps = ["//:grpc_dynamic"], + linkstatic = True, +) + +cc_test( + name = "libruntime_test", + srcs = [ + "libruntime/invoke_adaptor_test.cpp", + "libruntime/libruntime_manager_test.cpp", + "libruntime/libruntime_test.cpp", + "libruntime/fm_client_test.cpp", "libruntime/utils_test.cpp", + "libruntime/hash_util_test.cpp", "libruntime/certs_utils_test.cpp", + "libruntime/scheduler_instance_info_test.cpp", "libruntime/security_test.cpp", "libruntime/instance_manager_test.cpp", "libruntime/invoke_spec_test.cpp", "libruntime/execution_manager_test.cpp", - "libruntime/clients_manager_test.cpp", - "libruntime/httpserver/common_server.h", - "libruntime/httpserver/common_server.cpp", - "libruntime/httpserver/async_http_server.h", - "libruntime/httpserver/async_https_server.h", - "libruntime/async_decre_ref_test.cpp", "libruntime/auto_init_test.cpp", "libruntime/fiber_test.cpp", - "libruntime/fs_intf_manager_test.cpp", - "libruntime/fs_intf_grpc_rw_test.cpp", - "libruntime/connect_test.cpp", "libruntime/libruntime_config_test.cpp", "libruntime/resource_group_test.cpp", - "libruntime/object_store_test.cpp", - "libruntime/hetero_store_test.cpp", + "libruntime/request_queue_test.cpp", + "libruntime/http_utils_test.cpp", ], deps = [ "@gtest//:gtest_main", @@ -123,6 +197,7 @@ cc_test( "libruntime/httpserver/async_http_server.h", "libruntime/httpserver/async_https_server.h", "libruntime/http_client_test.cpp", + "libruntime/https_client_test.cpp", ], deps = [ "@gtest//:gtest_main", @@ -130,6 +205,13 @@ cc_test( "//:runtime_lib", ], dynamic_deps = ["//:grpc_dynamic"], + data = [ + "data/cert/ca.crt", + "data/cert/client.crt", + "data/cert/client.key", + "data/cert/server.crt", + "data/cert/server.key" + ], linkstatic = True, ) @@ -137,6 +219,7 @@ cc_test( name = "ins_manager_test", size = "small", srcs = [ + "libruntime/faas_instance_manager_test.cpp", "libruntime/normal_instance_manager_test.cpp", ], dynamic_deps = ["//:grpc_dynamic"], @@ -149,6 +232,37 @@ cc_test( linkstatic = True, ) +cc_test( + name = "alias_routing_test", + size = "small", + srcs = [ + "libruntime/alias_routing_test.cpp", + "libruntime/limiter_consistant_hash_test.cpp", + ], + dynamic_deps = ["//:grpc_dynamic"], + deps = [ + "@gtest//:gtest_main", + "//src/utility:yr_utils", + "//:runtime_lib", + ], + linkstatic = True, +) + +cc_test( + name = "load_balancer_test", + size = "small", + srcs = [ + "libruntime/load_balancer_test.cpp", + ], + deps = [ + "@gtest//:gtest_main", + "//src/utility:yr_utils", + "//:runtime_lib", + ], + dynamic_deps = ["//:grpc_dynamic"], + linkstatic = True, +) + cc_test( name = "object_id_pool_test", size = "small", @@ -188,6 +302,7 @@ cc_test( "libruntime/rt_direct_call_test.cpp" ], dynamic_deps = ["//:grpc_dynamic"], + copts = ["-fno-access-control"], deps = [ "@gtest//:gtest_main", "//src/utility:yr_utils", @@ -227,7 +342,10 @@ cc_test( "api/config_manager_test.cpp", "api/exception_test.cpp", "api/invoke_options_test.cpp", + "api/stream_pub_sub_test.cpp", "api/code_manager_test.cpp", + "api/runtime_env_test.cpp", + "api/runtime_env_parse_test.cpp", "common/common.h", "common/mock_libruntime.h", ], @@ -243,6 +361,24 @@ cc_test( linkstatic = True, ) +cc_test( + name = "faas_test", + size = "small", + srcs = [ + "faas/faas_executor_test.cpp", + "faas/function_test.cpp", + "common/mock_libruntime.h", + ], + dynamic_deps = ["//:grpc_dynamic"], + includes = ["."], + deps = [ + "@gtest//:gtest_main", + "//api/cpp:functionsdk_lib", + ], + linkopts = ["-lstdc++fs", "-ldl"], + linkstatic = True, +) + cc_test( name = "parallel_for_test", size = "small", @@ -258,4 +394,68 @@ cc_test( ], linkopts = ["-lstdc++fs", "-ldl"], linkstatic = True, +) + +cc_test( + name = "clibruntime_test", + size = "small", + srcs = [ + "clibruntime/clibruntime_test.cpp", + "common/mock_libruntime.h", + ], + dynamic_deps = ["//:grpc_dynamic"], + includes = ["."], + deps = [ + "@gtest//:gtest_main", + "//api/go/libruntime/cpplibruntime:cpplibruntime_lib", + ], + linkstatic = True, + tags = ["cgo"], +) + +cc_test( + name = "generator_test", + size = "small", + srcs = [ + "libruntime/generator_test.cpp", + ], + dynamic_deps = ["//:grpc_dynamic"], + deps = [ + "@gtest//:gtest_main", + "//src/utility:yr_utils", + "//test/libruntime/mock:mock_datasystem", + "//:runtime_lib", + ], + linkstatic = True, +) + +cc_test( + name = "driverlog_test", + size = "small", + srcs = [ + "libruntime/driverlog_test.cpp", + ], + dynamic_deps = ["//:grpc_dynamic"], + deps = [ + "@gtest//:gtest_main", + "//src/utility:yr_utils", + "//test/libruntime/mock:mock_datasystem", + "//:runtime_lib", + ], + linkstatic = True, +) + +cc_test( + name = "scene_test", + srcs = [ + "scene/downgrade_test.cpp", + ], + deps = [ + "@gtest//:gtest_main", + "//src/utility:yr_utils", + "//test/libruntime/mock:mock_datasystem", + "//:runtime_lib", + ], + dynamic_deps = ["//:grpc_dynamic"], + linkstatic = True, ) \ No newline at end of file diff --git a/test/api/api_test.cpp b/test/api/api_test.cpp index 9660805..7f55099 100644 --- a/test/api/api_test.cpp +++ b/test/api/api_test.cpp @@ -20,9 +20,12 @@ #define private public #include "api/cpp/src/config_manager.h" #include "api/cpp/src/internal.h" +#include "api/cpp/src/read_only_buffer.h" +#include "api/cpp/src/stream_pubsub.h" #include "src/dto/buffer.h" #include "yr/api/config.h" #include "yr/api/hetero_manager.h" +#include "yr/api/mutable_buffer.h" #include "yr/api/object_store.h" #include "yr/yr.h" @@ -33,8 +36,11 @@ public: MOCK_METHOD0(Init, void()); MOCK_METHOD0(GetServerVersion, std::string()); MOCK_METHOD2(Put, std::string(std::shared_ptr, const std::unordered_set &)); + MOCK_METHOD2(Put, std::string(std::shared_ptr, const std::unordered_set &)); MOCK_METHOD3(Put, std::string(std::shared_ptr, const std::unordered_set &, const YR::CreateParam &)); + MOCK_METHOD3(Put, std::string(std::shared_ptr, const std::unordered_set &, + const YR::CreateParam &)); MOCK_METHOD3(Put, void(const std::string &, std::shared_ptr, const std::unordered_set &)); MOCK_METHOD3(KVMSetTx, void(const std::vector &, @@ -53,7 +59,14 @@ public: const YR::GetParams ¶ms, int timeout)); MOCK_METHOD2(KVDel, void(const std::string &, const YR::DelParam &)); MOCK_METHOD2(KVDel, std::vector(const std::vector &, const YR::DelParam &)); - MOCK_METHOD1(IncreGlobalReference, void(const std::vector &)); + MOCK_METHOD1(KVExist, std::vector(const std::vector &)); + MOCK_METHOD2(CreateStreamProducer, std::shared_ptr(const std::string &, YR::ProducerConf)); + MOCK_METHOD3(CreateStreamConsumer, + std::shared_ptr(const std::string &, const YR::SubscriptionConfig &, bool)); + MOCK_METHOD1(DeleteStream, void(const std::string &)); + MOCK_METHOD2(QueryGlobalProducersNum, void(const std::string &, uint64_t &)); + MOCK_METHOD2(QueryGlobalConsumersNum, void(const std::string &, uint64_t &)); + MOCK_METHOD2(IncreGlobalReference, void(const std::vector &, bool)); MOCK_METHOD1(DecreGlobalReference, void(const std::vector &)); MOCK_METHOD3(InvokeByName, std::string(const YR::internal::FuncMeta &, std::vector &, const YR::InvokeOptions &)); @@ -67,6 +80,7 @@ public: MOCK_METHOD3(SaveGroupInstanceIds, void(const std::string &, const std::string &, const YR::InvokeOptions &)); MOCK_METHOD3(Cancel, void(const std::vector &, bool, bool)); MOCK_METHOD1(TerminateInstance, void(const std::string &)); + MOCK_METHOD2(TerminateInstanceAsync, std::shared_future(const std::string &instanceId, bool isSync)); MOCK_METHOD0(Exit, void()); MOCK_METHOD0(IsOnCloud, bool()); MOCK_METHOD2(GroupCreate, void(const std::string &, YR::GroupOptions &)); @@ -78,9 +92,12 @@ public: MOCK_METHOD1(SaveState, void(const int &)); MOCK_METHOD1(LoadState, void(const int &)); MOCK_METHOD3(WaitBeforeGet, int64_t(const std::vector &ids, int timeoutMs, bool allowPartial)); - MOCK_METHOD2(Delete, void(const std::vector &objectIds, std::vector &failedObjectIds)); - MOCK_METHOD2(LocalDelete, - void(const std::vector &objectIds, std::vector &failedObjectIds)); + MOCK_METHOD(void, DevDelete, + (const std::vector &objectIds, std::vector &failedObjectIds)); + + MOCK_METHOD(void, DevLocalDelete, + (const std::vector &objectIds, std::vector &failedObjectIds)); + MOCK_METHOD3(DevSubscribe, void(const std::vector &keys, const std::vector &blob2dList, std::vector> &futureVec)); @@ -96,6 +113,10 @@ public: MOCK_METHOD1(GetInstanceRoute, std::string(const std::string &objectId)); MOCK_METHOD2(SaveInstanceRoute, void(const std::string &objectId, const std::string &instanceRoute)); MOCK_METHOD1(TerminateInstanceSync, void(const std::string &instanceId)); + MOCK_METHOD0(Nodes, std::vector()); + MOCK_METHOD1(CreateMutableBuffer, std::shared_ptr(uint64_t size)); + MOCK_METHOD2(GetMutableBuffer, + std::vector>(const std::vector &ids, int timeout)); }; class ApiTest : public ::testing::Test { @@ -244,6 +265,13 @@ TEST_F(ApiTest, ExitTest) ASSERT_NO_THROW(YR::Exit()); } +TEST_F(ApiTest, ExistTest) +{ + std::vector exists = {false, false, true}; + EXPECT_CALL(*this->runtime, KVExist(_)).WillOnce(testing::Return(exists)); + ASSERT_EQ(YR::KV().Exist({"key1", "key2", "key3"}), exists); +} + TEST_F(ApiTest, IsOnCloudTest) { ASSERT_TRUE(YR::IsOnCloud()); @@ -257,6 +285,45 @@ TEST_F(ApiTest, IsLocalModeTest) ASSERT_FALSE(YR::IsLocalMode()); } +TEST_F(ApiTest, LocalStreamTest) +{ + YR::Config conf; + conf.mode = YR::Config::Mode::LOCAL_MODE; + int mockArgc = 5; + char *mockArgv[] = {"--logDir=/tmp/log", "--logLevel=DEBUG", "--grpcAddress=12.34.56.78:1234", "--runtimeId=driver", + "jobId=job123"}; + YR::ConfigManager::Singleton().Init(conf, mockArgc, mockArgv); + YR::ProducerConf pConf; + YR::SubscriptionConfig sConf; + ASSERT_THROW(YR::CreateProducer("streamName", pConf), YR::Exception); + ASSERT_THROW(YR::Subscribe("streamName", sConf, true), YR::Exception); + ASSERT_THROW(YR::DeleteStream("streamName"), YR::Exception); +} + +TEST_F(ApiTest, StreamTest) +{ + YR::Config conf; + conf.mode = YR::Config::Mode::CLUSTER_MODE; + int mockArgc = 5; + char *mockArgv[] = {"--logDir=/tmp/log", "--logLevel=DEBUG", "--grpcAddress=12.34.56.78:1234", "--runtimeId=driver", + "jobId=job123"}; + YR::ConfigManager::Singleton().Init(conf, mockArgc, mockArgv); + YR::ProducerConf pConf; + YR::SubscriptionConfig sConf; + auto streamProducer = std::make_shared(); + std::shared_ptr producer = std::make_shared(streamProducer); + EXPECT_CALL(*this->runtime, CreateStreamProducer(_, _)).WillOnce(testing::Return(producer)); + ASSERT_NO_THROW(YR::CreateProducer("streamName", pConf)); + + auto streamConsumer = std::make_shared(); + std::shared_ptr consumer = std::make_shared(streamConsumer); + EXPECT_CALL(*this->runtime, CreateStreamConsumer(_, _, _)).WillOnce(testing::Return(consumer)); + ASSERT_NO_THROW(YR::Subscribe("streamName", sConf, true)); + + EXPECT_CALL(*this->runtime, DeleteStream(_)).WillOnce(testing::Return()); + ASSERT_NO_THROW(YR::DeleteStream("streamName")); +} + TEST_F(ApiTest, SaveLoadStateThrowTest) { YR::Config conf; @@ -303,19 +370,19 @@ TEST_F(ApiTest, HeteroDeleteTest) std::vector objectIds; std::vector failedObjectIds; YR::internal::RuntimeManager::GetInstance().mode_ = YR::Config::Mode::LOCAL_MODE; - ASSERT_THROW(YR::HeteroManager().Delete(objectIds, failedObjectIds), YR::HeteroException); + ASSERT_THROW(YR::HeteroManager().DevDelete(objectIds, failedObjectIds), YR::HeteroException); YR::internal::RuntimeManager::GetInstance().mode_ = YR::Config::Mode::CLUSTER_MODE; - ASSERT_NO_THROW(YR::HeteroManager().Delete(objectIds, failedObjectIds)); + ASSERT_NO_THROW(YR::HeteroManager().DevDelete(objectIds, failedObjectIds)); } -TEST_F(ApiTest, HeteroLocalDeleteTest) +TEST_F(ApiTest, HeteroDevLocalDeleteTest) { std::vector objIds; std::vector failedObjIds; YR::internal::RuntimeManager::GetInstance().mode_ = YR::Config::Mode::LOCAL_MODE; - ASSERT_THROW(YR::HeteroManager().LocalDelete(objIds, failedObjIds), YR::HeteroException); + ASSERT_THROW(YR::HeteroManager().DevLocalDelete(objIds, failedObjIds), YR::HeteroException); YR::internal::RuntimeManager::GetInstance().mode_ = YR::Config::Mode::CLUSTER_MODE; - ASSERT_NO_THROW(YR::HeteroManager().LocalDelete(objIds, failedObjIds)); + ASSERT_NO_THROW(YR::HeteroManager().DevLocalDelete(objIds, failedObjIds)); } TEST_F(ApiTest, HeteroDevSubscribeTest) @@ -381,3 +448,85 @@ TEST_F(ApiTest, APIGetInstanceTest) auto handler = YR::GetInstance(name, ns, 60); ASSERT_EQ(handler.name, "ins-name"); } + +TEST_F(ApiTest, APIPutAndGetSuccessfully) +{ + std::string objId = "abc"; + EXPECT_CALL(*this->runtime, + Put(testing::An>(), testing::An &>())) + .WillOnce(testing::Return(objId)); + std::string str = "success"; + YR::Buffer buf(str.data(), str.size()); + auto obj = YR::Put(buf); + ASSERT_EQ(obj.objId, objId); + + auto bufPtr = std::make_shared(buf); + YR::internal::RetryInfo returnRetryInfo; + returnRetryInfo.needRetry = true; + std::vector> buffers; + buffers.push_back(bufPtr); + EXPECT_CALL(*this->runtime, Get(_, _, _)) + .WillOnce(testing::Return( + std::pair>>(returnRetryInfo, buffers))); + auto value = YR::Get(obj); + std::string result = std::string(static_cast(value->ImmutableData()), value->GetSize()); + ASSERT_EQ(result, str); +} + +TEST_F(ApiTest, TestInvokeByName) +{ + YR::SetInitialized(true); + std::string str = "success"; + YR::Buffer buf(str.data(), str.size()); + + std::string objId = "abc"; + std::vector invokeArgs; + EXPECT_CALL(*this->runtime, InvokeByName(_, _, _)) + .WillOnce(::testing::Invoke([this, objId, &invokeArgs](auto &&arg1, auto &&arg2, auto &&arg3) { + invokeArgs = std::move(arg2); + return objId; + })); + auto ret = YR::PyFunction("common", "echo") + .SetUrn("sn:cn:yrk:12345678901234561234567890123456:function:0-yr-stpython:$latest") + .Invoke(buf); + ASSERT_EQ(ret.ID(), objId); + // 校验mock函数入参 + std::string result = + std::string(static_cast(invokeArgs[1].yrBuf.ImmutableData()), invokeArgs[1].yrBuf.GetSize()); + ASSERT_EQ(result, str); +} + +TEST_F(ApiTest, TestNodes) +{ + // 定义资源单元数据 + std::unordered_map resourceMap = {{"cpu", 500.0f}, {"memory", 500.0f}}; + std::unordered_map> labelsMap = {{"NODE_ID", {"node1", "node2"}}}; + std::vector nodesVector = {{.id = "node1", .alive = true, .resources = resourceMap, .labels = labelsMap}}; + EXPECT_CALL(*this->runtime, Nodes()).WillOnce(testing::Return(nodesVector)); + + std::vector nodes = YR::Nodes(); + EXPECT_EQ(nodes.size(), 1); + EXPECT_EQ(nodes[0].id, "node1"); + EXPECT_TRUE(nodes[0].alive); + EXPECT_EQ(nodes[0].resources["cpu"], 500.0f); + EXPECT_EQ(nodes[0].resources["memory"], 500.0f); + EXPECT_EQ(nodes[0].labels["NODE_ID"][0], "node1"); +} + +TEST_F(ApiTest, TestCreateMutableBuffer) +{ + ASSERT_NO_THROW(YR::CreateBuffer(60)); +} + +TEST_F(ApiTest, TestGetMutableBuffer) +{ + std::vector> objs; + objs.push_back(YR::ObjectRef("id")); + ASSERT_NO_THROW(YR::Get(objs, 60)); +} + +TEST_F(ApiTest, TestSerializeMutableBuffer) +{ + YR::ObjectRef obj("id"); + ASSERT_NO_THROW(YR::Serialize(obj)); +} \ No newline at end of file diff --git a/test/api/cluster_mode_runtime_test.cpp b/test/api/cluster_mode_runtime_test.cpp index aea0cc5..697b1c4 100644 --- a/test/api/cluster_mode_runtime_test.cpp +++ b/test/api/cluster_mode_runtime_test.cpp @@ -107,6 +107,7 @@ TEST_F(ClusterModeRuntimeTest, InitClusterModeRuntimeTest) TEST_F(ClusterModeRuntimeTest, When_In_Cluster_With_Empty_DatasystemAddr_Should_Throw_Exception) { Config conf; + conf.inCluster = true; conf.mode = Config::Mode::CLUSTER_MODE; conf.serverAddr = "127.0.0.1:1234"; conf.functionUrn = "sn:cn:yrk:12345678901234561234567890123456:function:0-test-test:$latest"; @@ -183,7 +184,9 @@ TEST_F(ClusterModeRuntimeTest, BuildOptionsTest) auto af6 = YR::InstanceRequiredAffinity(YR::LabelNotInOperator("key", {"value"})); auto af7 = YR::ResourceRequiredAntiAffinity(YR::LabelInOperator("key", {"value"})); auto af8 = YR::InstanceRequiredAntiAffinity(YR::LabelNotInOperator("key", {"value"})); - invokeOptions.AddAffinity({af2, af3, af4, af5, af6, af7, af8}); + auto af9 = YR::Affinity(INSTANCE, REQUIRED_ANTI, {YR::LabelNotInOperator("key", {"value"})}, "POD"); + af9.SetAffinityScope(YR::AFFINITYSCOPE_NODE); + invokeOptions.AddAffinity({af2, af3, af4, af5, af6, af7, af8, af9}); YR::InstanceRange instanceRange; YR::RangeOptions rangeOpts; @@ -193,15 +196,23 @@ TEST_F(ClusterModeRuntimeTest, BuildOptionsTest) rangeOpts.timeout = 60; instanceRange.rangeOpts = rangeOpts; invokeOptions.instanceRange = instanceRange; + invokeOptions.preemptedAllowed = true; + invokeOptions.instancePriority = 200; + invokeOptions.scheduleTimeoutMs = 50000; YR::Libruntime::InvokeOptions libInvokeOptions = BuildOptions(std::move(invokeOptions)); auto firstAffinity = libInvokeOptions.scheduleAffinities.front(); EXPECT_FALSE(firstAffinity->GetPreferredAntiOtherLabels()); + auto lastAffinity = libInvokeOptions.scheduleAffinities.back(); + ASSERT_EQ(lastAffinity->GetAffinityScope(), YR::AFFINITYSCOPE_NODE); ASSERT_EQ(libInvokeOptions.instanceRange.min, instanceRange.min); ASSERT_EQ(libInvokeOptions.instanceRange.max, instanceRange.max); ASSERT_EQ(libInvokeOptions.instanceRange.step, 2); ASSERT_EQ(libInvokeOptions.instanceRange.sameLifecycle, instanceRange.sameLifecycle); ASSERT_EQ(libInvokeOptions.instanceRange.rangeOpts.timeout, instanceRange.rangeOpts.timeout); + ASSERT_EQ(libInvokeOptions.preemptedAllowed, true); + ASSERT_EQ(libInvokeOptions.instancePriority, 200); + ASSERT_EQ(libInvokeOptions.scheduleTimeoutMs, 50000); YR::InvokeOptions invokeOptions1; invokeOptions1.requiredPriority = false; @@ -385,9 +396,6 @@ TEST_F(ClusterModeRuntimeTest, TestTerminateInstanceSyncFailed) TEST_F(ClusterModeRuntimeTest, TestPutWithoutObjIDFailed) { - EXPECT_CALL(*lr.get(), CreateDataObject(_, _, _, _, _)) - .WillOnce( - Return(std::make_pair(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "111"), ""))); EXPECT_THROW(rt->Put(std::make_shared(), {}), YR::Exception); } @@ -478,13 +486,13 @@ TEST_F(ClusterModeRuntimeTest, TestGenerateGroupNameSuccessfully) TEST_F(ClusterModeRuntimeTest, TestIncreGlobalReferenceSuccessfully) { - EXPECT_CALL(*lr.get(), IncreaseReference(_)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + EXPECT_CALL(*lr.get(), IncreaseReference(_, Matcher(_))).WillOnce(Return(YR::Libruntime::ErrorInfo())); EXPECT_NO_THROW(rt->IncreGlobalReference({"111"})); } TEST_F(ClusterModeRuntimeTest, TestIncreGlobalReferenceFailed) { - EXPECT_CALL(*lr.get(), IncreaseReference(_)) + EXPECT_CALL(*lr.get(), IncreaseReference(_, Matcher(_))) .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_CONNECTION_FAILED, "aaa"))); EXPECT_THROW(rt->IncreGlobalReference({"111"}), YR::Exception); } @@ -603,6 +611,22 @@ TEST_F(ClusterModeRuntimeTest, TestKVDelSuccessfully) EXPECT_NO_THROW(rt->KVDel("111")); } +TEST_F(ClusterModeRuntimeTest, TestKVExistSuccessfully) +{ + std::vector exists = {false, false, true}; + EXPECT_CALL(*lr.get(), KVExist(_)).WillOnce(Return(std::make_pair(exists, YR::Libruntime::ErrorInfo()))); + ASSERT_EQ(rt->KVExist({"key1", "key2", "key3"}), exists); +} + +TEST_F(ClusterModeRuntimeTest, TestKVExistFailed) +{ + std::vector exists = {false, false, false}; + EXPECT_CALL(*lr.get(), KVExist(_)) + .WillOnce(Return( + std::make_pair(exists, YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "aaa")))); + EXPECT_THROW(rt->KVExist({"key1", "key2", "key3"}), YR::Exception); +} + TEST_F(ClusterModeRuntimeTest, TestKVDelFailed) { EXPECT_CALL(*lr.get(), KVDel(Matcher("111"))) @@ -635,6 +659,81 @@ TEST_F(ClusterModeRuntimeTest, TestKVDelMultiKeysFailed) EXPECT_THROW(rt->KVDel(input), YR::Exception); } +TEST_F(ClusterModeRuntimeTest, TestCreateStreamProducerSuccessfully) +{ + EXPECT_CALL(*lr.get(), CreateStreamProducer(_, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(rt->CreateStreamProducer("streamName", YR::ProducerConf())); +} + +TEST_F(ClusterModeRuntimeTest, TestCreateStreamProducerFailed) +{ + EXPECT_CALL(*lr.get(), CreateStreamProducer(_, _, _)) + .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "aaa"))); + EXPECT_THROW(rt->CreateStreamProducer("streamName", YR::ProducerConf()), YR::Exception); + EXPECT_CALL(*lr.get(), SetTraceId(_)) + .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "aaa"))); + EXPECT_THROW(rt->CreateStreamProducer("streamName", YR::ProducerConf()), YR::Exception); +} + +TEST_F(ClusterModeRuntimeTest, TestCreateStreamConsumerSuccessfully) +{ + EXPECT_CALL(*lr.get(), CreateStreamConsumer(_, _, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(rt->CreateStreamConsumer("streamName", YR::SubscriptionConfig(), true)); +} + +TEST_F(ClusterModeRuntimeTest, TestCreateStreamConsumerFailed) +{ + EXPECT_CALL(*lr.get(), CreateStreamConsumer(_, _, _, _)) + .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "aaa"))); + EXPECT_THROW(rt->CreateStreamConsumer("streamName", YR::SubscriptionConfig(), true), YR::Exception); + EXPECT_CALL(*lr.get(), SetTraceId(_)) + .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "aaa"))); + EXPECT_THROW(rt->CreateStreamConsumer("streamName", YR::SubscriptionConfig(), true), YR::Exception); +} + +TEST_F(ClusterModeRuntimeTest, TestDeleteStreamSuccessfully) +{ + EXPECT_CALL(*lr.get(), DeleteStream(_)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(rt->DeleteStream("streamName")); +} + +TEST_F(ClusterModeRuntimeTest, TestDeleteStreamFailed) +{ + EXPECT_CALL(*lr.get(), DeleteStream(_)) + .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "aaa"))); + EXPECT_THROW(rt->DeleteStream("streamName"), YR::Exception); +} + +TEST_F(ClusterModeRuntimeTest, TestQueryGlobalProducersNumSuccessfully) +{ + uint64_t gProducerNum = 1; + EXPECT_CALL(*lr.get(), QueryGlobalProducersNum(_, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(rt->QueryGlobalProducersNum("streamName", gProducerNum)); +} + +TEST_F(ClusterModeRuntimeTest, TestQueryGlobalProducersNumFailed) +{ + uint64_t gProducerNum = 1; + EXPECT_CALL(*lr.get(), QueryGlobalProducersNum(_, _)) + .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_CONNECTION_FAILED, "aaa"))); + EXPECT_THROW(rt->QueryGlobalProducersNum("streamName", gProducerNum), YR::Exception); +} + +TEST_F(ClusterModeRuntimeTest, TestQueryGlobalConsumersNumSuccessfully) +{ + uint64_t gConsumerNum = 1; + EXPECT_CALL(*lr.get(), QueryGlobalConsumersNum(_, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(rt->QueryGlobalConsumersNum("streamName", gConsumerNum)); +} + +TEST_F(ClusterModeRuntimeTest, TestQueryGlobalConsumersNumFailed) +{ + uint64_t gConsumerNum = 1; + EXPECT_CALL(*lr.get(), QueryGlobalConsumersNum(_, _)) + .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_CONNECTION_FAILED, "aaa"))); + EXPECT_THROW(rt->QueryGlobalConsumersNum("streamName", gConsumerNum), YR::Exception); +} + TEST_F(ClusterModeRuntimeTest, TestGetRealInstanceIdSuccessfully) { EXPECT_CALL(*lr.get(), GetRealInstanceId(_, _)).WillOnce(Return("realInstanceID")); @@ -744,20 +843,20 @@ TEST_F(ClusterModeRuntimeTest, TestDelete) { std::vector objectIds; std::vector failedObjectIds; - ASSERT_NO_THROW(rt->Delete(objectIds, failedObjectIds)); - EXPECT_CALL(*lr.get(), Delete(_, _)) + ASSERT_NO_THROW(rt->DevDelete(objectIds, failedObjectIds)); + EXPECT_CALL(*lr.get(), DevDelete(_, _)) .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "111"))); - ASSERT_THROW(rt->Delete(objectIds, failedObjectIds), YR::HeteroException); + ASSERT_THROW(rt->DevDelete(objectIds, failedObjectIds), YR::HeteroException); } -TEST_F(ClusterModeRuntimeTest, TestLocalDelete) +TEST_F(ClusterModeRuntimeTest, TestDevLocalDelete) { std::vector objectIds; std::vector failedObjectIds; - ASSERT_NO_THROW(rt->LocalDelete(objectIds, failedObjectIds)); - EXPECT_CALL(*lr.get(), LocalDelete(_, _)) + ASSERT_NO_THROW(rt->DevLocalDelete(objectIds, failedObjectIds)); + EXPECT_CALL(*lr.get(), DevLocalDelete(_, _)) .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "111"))); - ASSERT_THROW(rt->LocalDelete(objectIds, failedObjectIds), YR::HeteroException); + ASSERT_THROW(rt->DevLocalDelete(objectIds, failedObjectIds), YR::HeteroException); } TEST_F(ClusterModeRuntimeTest, TestDevSubscribe) @@ -834,6 +933,74 @@ TEST_F(ClusterModeRuntimeTest, TestSaveInstanceRouteSuccessfully) EXPECT_NO_THROW(rt->SaveInstanceRoute("objID", "insRoute")); } +TEST_F(ClusterModeRuntimeTest, TestInvokeByName) +{ + YR::SetInitialized(true); + internal::FuncMeta funcMeta; + funcMeta.funcName = "common"; + funcMeta.moduleName = "echo"; + funcMeta.language = YR::internal::FunctionLanguage::FUNC_LANG_PYTHON; + funcMeta.funcUrn = "sn:cn:yrk:12345678901234561234567890123456:function:0-yr-stpython:$latest"; + InvokeOptions opts; + std::string str = "success"; + YR::Buffer buf(str.data(), str.size()); + std::vector invokeArgs; + YR::internal::InvokeArg invokeArg1{}; + invokeArg1.buf = std::move(YR::internal::Serialize(PY_PLACEHOLDER)); + invokeArg1.isRef = false; + invokeArgs.emplace_back(std::move(invokeArg1)); + YR::internal::InvokeArg invokeArg2{}; + invokeArg2.yrBuf = buf; + invokeArg2.isRef = false; + invokeArgs.emplace_back(std::move(invokeArg2)); + + std::vector libArgs; + std::vector returnObjs{{"abc"}}; + EXPECT_CALL(*lr.get(), InvokeByFunctionName(_, _, _, _)) + .WillOnce(DoAll(SetArgReferee<3>(returnObjs), + ::testing::Invoke([this, &libArgs](auto &&arg1, auto &&arg2, auto &&arg3, auto &&arg4) { + libArgs = std::move(arg2); + return YR::Libruntime::ErrorInfo(); + }))); + auto ret = rt->InvokeByName(funcMeta, invokeArgs, opts); + ASSERT_EQ(ret, "abc"); + // 校验mock函数入参 + auto meta = libArgs[1].dataObj->meta; + auto data = libArgs[1].dataObj->data; + auto metaBuf = std::make_shared(const_cast(meta->ImmutableData()), meta->GetSize()); + auto metaResult = YR::internal::Deserialize(metaBuf); + ASSERT_EQ(metaResult, 3); + std::string dataResult = std::string(static_cast(data->ImmutableData()), data->GetSize()); + ASSERT_EQ(dataResult, str); +} + +TEST_F(ClusterModeRuntimeTest, APIPutNullptrFailed) +{ + ASSERT_THROW(rt->Put(std::make_shared(nullptr, 0), {}, {}), YR::Exception); +} + +TEST_F(ClusterModeRuntimeTest, TestCreateBuffer) +{ + ASSERT_NO_THROW(rt->CreateMutableBuffer(60)); + EXPECT_CALL(*lr.get(), CreateBuffer(_, _)) + .WillOnce(Return(std::make_pair(YR::Libruntime::ErrorInfo(YR::Libruntime::ERR_CLIENT_ALREADY_CLOSED, + Libruntime::ModuleCode::RUNTIME, "111"), + "nullptr"))); + EXPECT_THROW(rt->CreateMutableBuffer(60), YR::Exception); +} + +TEST_F(ClusterModeRuntimeTest, TestGetMutableBuffer) +{ + std::vector> buffers{}; + std::vector ids; + ASSERT_NO_THROW(rt->GetMutableBuffer(ids, 60)); + EXPECT_CALL(*lr.get(), GetBuffers(_, _, _)) + .WillOnce(Return(std::make_pair(YR::Libruntime::ErrorInfo(YR::Libruntime::ERR_CLIENT_ALREADY_CLOSED, + Libruntime::ModuleCode::RUNTIME, "111"), + buffers))); + ASSERT_THROW(rt->GetMutableBuffer(ids, 60), YR::Exception); +} + class A { public: int a; @@ -1088,7 +1255,7 @@ TEST_F(ClusterModeTest, TestHybridLocalPassCluster) Exception); } -TEST_F(ClusterModeTest, DISABLED_TestHybridLocalPassMix) +TEST_F(ClusterModeTest, TestHybridLocalPassMix) { YR::InvokeOptions opt; opt.alwaysLocalMode = true; @@ -1178,5 +1345,57 @@ TEST_F(ClusterModeTest, TestHybridClusterWaitGetMix) }, Exception); } + +TEST_F(ClusterModeTest, TestGetNodes) +{ + // 定义资源单元数据 + std::unordered_map capacityMap = {{"cpu", 500.0f}, {"memory", 500.0f}}; + std::unordered_map> nodeLabelsMap = {{"NODE_ID", {"node1", "node2"}}}; + std::vector resourceUnitVector = {{ + .id = "node1", + .capacity = capacityMap, + .allocatable = capacityMap, + .nodeLabels = nodeLabelsMap, + .status = 0, + }}; + EXPECT_CALL(*lr.get(), GetResources()) + .WillOnce(Return(std::make_pair(YR::Libruntime::ErrorInfo(), resourceUnitVector))); + + std::vector nodes = YR::Nodes(); + EXPECT_EQ(nodes.size(), 1); + EXPECT_EQ(nodes[0].id, "node1"); + EXPECT_TRUE(nodes[0].alive); + EXPECT_EQ(nodes[0].resources["cpu"], 500.0f); + EXPECT_EQ(nodes[0].resources["memory"], 500.0f); + EXPECT_EQ(nodes[0].labels["NODE_ID"][0], "node1"); +} + +TEST_F(ClusterModeRuntimeTest, TestTerminateInstanceAsyncSuccessfully) +{ + EXPECT_CALL(*lr.get(), KillAsync(_, _, _)) + .WillOnce(Invoke( + [](const std::string &instanceId, int sigNo, std::function cb) { + cb(YR::Libruntime::ErrorInfo()); + })); + auto f = rt->TerminateInstanceAsync("111", false); + auto status = f.wait_for(std::chrono::milliseconds(100)); + EXPECT_EQ(status, std::future_status::ready); + EXPECT_NO_THROW(f.get()); +} + +TEST_F(ClusterModeRuntimeTest, TestTerminateInstanceAsyncFailed) +{ + EXPECT_CALL(*lr.get(), KillAsync(_, _, _)) + .WillOnce(Invoke( + [](const std::string &instanceId, int sigNo, std::function cb) { + cb(YR::Libruntime::ErrorInfo(YR::Libruntime::ERR_INSTANCE_NOT_FOUND, Libruntime::ModuleCode::RUNTIME, + "111")); + })); + auto f = rt->TerminateInstanceAsync("111", false); + auto status = f.wait_for(std::chrono::milliseconds(100)); + EXPECT_EQ(status, std::future_status::ready); + EXPECT_THROW(f.get(), YR::Exception); +} + } // namespace test } // namespace YR diff --git a/test/api/config_manager_test.cpp b/test/api/config_manager_test.cpp index 2012d22..a817199 100644 --- a/test/api/config_manager_test.cpp +++ b/test/api/config_manager_test.cpp @@ -184,6 +184,21 @@ TEST_F(ConfigManagerTest, ConfigManagerInitTest7) EXPECT_EQ(ConfigManager::Singleton().threadPoolSize, wantSize) << "Test failed"; } +TEST_F(ConfigManagerTest, GetValidLocalThreadPoolSizeTest) +{ + Config conf = GetMockConf(); + conf.localThreadPoolSize = 1000; + int mockArgc = 1; + char *mockArgv[] = {"--logDir=/tmp/log"}; + int wantSize = static_cast(std::thread::hardware_concurrency()); + try { + ConfigManager::Singleton().Init(conf, mockArgc, mockArgv); + } catch (const std::invalid_argument &e) { + EXPECT_EQ(1, 0) << "Test failed"; + } + EXPECT_EQ(ConfigManager::Singleton().localThreadPoolSize, wantSize) << "Test failed"; +} + TEST_F(ConfigManagerTest, ConfigManagerInitTest8) { Config conf; @@ -234,6 +249,46 @@ TEST_F(ConfigManagerTest, GetValidMaxLogSizeMbTest) ASSERT_EQ(conf.maxLogFileNum, ConfigManager::Singleton().maxLogFileNum); } +TEST_F(ConfigManagerTest, ConfigManagerInitEnableMTLSTest) +{ + Config conf = GetMockConf(); + conf.enableMTLS = true; + conf.privateKeyPath = "ddd/module.key"; + conf.certificateFilePath = "ddd/module.crt"; + conf.verifyFilePath = "ddd/ca.crt"; + std::strcpy(conf.privateKeyPaaswd, "paaswd"); + conf.encryptPrivateKeyPasswd = "abcd"; + int mockArgc = 1; + char *mockArgv[] = {"--logDir=/tmp/log"}; + ConfigManager::Singleton().Init(conf, mockArgc, mockArgv); + ASSERT_EQ(conf.privateKeyPath, ConfigManager::Singleton().privateKeyPath); + ASSERT_EQ(conf.certificateFilePath, ConfigManager::Singleton().certificateFilePath); + ASSERT_EQ(conf.verifyFilePath, ConfigManager::Singleton().verifyFilePath); + ASSERT_EQ(std::string(conf.privateKeyPaaswd), std::string(ConfigManager::Singleton().privateKeyPaaswd)); + ASSERT_EQ(conf.encryptPrivateKeyPasswd, ConfigManager::Singleton().encryptPrivateKeyPasswd); +} + +TEST_F(ConfigManagerTest, ConfigManagerInitEnableDsEncryptTest) +{ + Config conf = GetMockConf(); + conf.enableDsEncrypt = true; + conf.dsPublicKeyContext = "aaa"; + conf.runtimePublicKeyContext = "bbb"; + conf.runtimePrivateKeyContext = "ccc"; + conf.encryptDsPublicKeyContext = "ddd"; + conf.encryptRuntimePublicKeyContext = "eee"; + conf.encryptRuntimePrivateKeyContext = "fff"; + int mockArgc = 1; + char *mockArgv[] = {"--logDir=/tmp/log"}; + ConfigManager::Singleton().Init(conf, mockArgc, mockArgv); + ASSERT_EQ(conf.dsPublicKeyContext, ConfigManager::Singleton().dsPublicKeyContext); + ASSERT_EQ(conf.runtimePublicKeyContext, ConfigManager::Singleton().runtimePublicKeyContext); + ASSERT_EQ(conf.runtimePrivateKeyContext, ConfigManager::Singleton().runtimePrivateKeyContext); + ASSERT_EQ(conf.encryptDsPublicKeyContext, ConfigManager::Singleton().encryptDsPublicKeyContext); + ASSERT_EQ(conf.encryptRuntimePublicKeyContext, ConfigManager::Singleton().encryptRuntimePublicKeyContext); + ASSERT_EQ(conf.encryptRuntimePrivateKeyContext, ConfigManager::Singleton().encryptRuntimePrivateKeyContext); +} + TEST_F(ConfigManagerTest, GetValidLogCompressTest) { Config conf = GetMockConf(); diff --git a/test/api/function_manager_test.cpp b/test/api/function_manager_test.cpp index 05ea57d..272a136 100644 --- a/test/api/function_manager_test.cpp +++ b/test/api/function_manager_test.cpp @@ -17,8 +17,8 @@ #include #include -#include "yr/api/function_manager.h" #include "yr/yr.h" +#include "yr/api/function_manager.h" namespace YR { namespace test { @@ -27,7 +27,8 @@ class FunctionManagerTest : public testing::Test { public: FunctionManagerTest(){}; ~FunctionManagerTest(){}; - void SetUp() override {} + void SetUp() override + {} }; class Counter { @@ -38,18 +39,7 @@ public: count = init; } - int A(int x) - { - return x; - } - - int B(int x) - { - return x; - } - - void Shutdown(uint64_t gracePeriodSecond) - { + void Shutdown(uint64_t gracePeriodSecond) { return; } @@ -58,20 +48,14 @@ public: YR_STATE(key, count); }; -int C(int x) -{ - return x; -} -TEST_F(FunctionManagerTest, RegisterShutdownFunctionsTest) -{ +TEST_F(FunctionManagerTest, RegisterShutdownFunctionsTest) { YR_SHUTDOWN(&Counter::Shutdown); auto func = internal::FunctionManager::Singleton().GetShutdownFunction("Counter"); EXPECT_TRUE(func.has_value()); } -TEST_F(FunctionManagerTest, CheckpointRecoverTest) -{ +TEST_F(FunctionManagerTest, CheckpointRecoverTest) { auto clsPtr = new Counter(); clsPtr->key = "1234"; auto instancePtr = YR::internal::Serialize((uint64_t)clsPtr); diff --git a/test/api/local_mode_test.cpp b/test/api/local_mode_test.cpp index d04c9eb..6432ad0 100644 --- a/test/api/local_mode_test.cpp +++ b/test/api/local_mode_test.cpp @@ -18,6 +18,7 @@ #include "yr/api/err_type.h" #include "yr/parallel/parallel_for.h" #include "yr/yr.h" +#include "api/cpp/src/config_manager.h" namespace YR { namespace test { using testing::HasSubstr; @@ -350,6 +351,23 @@ TEST_F(LocalTest, MSetTxTest) // case 5 std::vector keys2, vals2; EXPECT_THROW(YR::KV().MSetTx(keys2, vals2, YR::ExistenceOpt::NX), Exception); + + // case 6 + std::vector keys3, vals3; + int exceedNum = 10; + for (int i = 0; i < exceedNum; i++) { + std::string key = "Key" + std::to_string(i); + std::string value = "Value" + std::to_string(i); + keys3.emplace_back(key); + vals3.emplace_back(value); + } + EXPECT_THROW(YR::KV().MSetTx(keys3, vals3, YR::ExistenceOpt::NX), Exception); + + // case 7 + std::vector keys4, vals4; + keys4.emplace_back(""); + vals4.emplace_back("emptyKey"); + EXPECT_THROW(YR::KV().MSetTx(keys4, vals4, YR::ExistenceOpt::NX), Exception); } TEST_F(LocalTest, Test_When_Actor_Currency_Call_ParallelFor_Should_Not_Be_Stuck) @@ -406,6 +424,33 @@ TEST_F(LocalTest, cpp_local_kv_read_error_keys_allowpatital_true) YR::KV().Del(keys); } +TEST_F(LocalTest, cpp_local_kv_exist) +{ + std::string key; + std::string value; + std::vector keys; + for (int i = 0; i < 8; ++i) { + if (i % 2 == 0) { + key = "cpp_local_kv_exist" + std::to_string(i); + value = "value" + std::to_string(i); + YR::KV().Set(key, value); + keys.push_back(key); + } else { + keys.push_back("noValueKey" + std::to_string(i)); + } + + } + auto exists = YR::KV().Exist(keys); + for (int i = 0; i < 8; ++i) { + std::cout << keys[i] << "-> kv exist is: " << exists[i] << std::endl; + if (i % 2 == 0) { + EXPECT_EQ(exists[i], true); + } else { + EXPECT_EQ(exists[i], false); + } + } +} + int func_throw() { throw std::runtime_error("runtime error"); @@ -530,5 +575,17 @@ TEST_F(LocalTest, StopLocalModeRuntime) runtime.Init(); ASSERT_NO_THROW(runtime.Stop()); } + +TEST_F(LocalTest, LocalModeThreadPoolSize) +{ + YR::Finalize(); + YR::Config conf; + conf.mode = YR::Config::Mode::LOCAL_MODE; + conf.logLevel = "DEBUG"; + conf.logDir = "/tmp/log"; + conf.localThreadPoolSize = 65; + YR::Init(conf); + ASSERT_EQ(ConfigManager::Singleton().localThreadPoolSize, conf.localThreadPoolSize); +} } // namespace test } // namespace YR \ No newline at end of file diff --git a/test/api/object_ref_test.cpp b/test/api/object_ref_test.cpp index 2f41e90..ec02114 100644 --- a/test/api/object_ref_test.cpp +++ b/test/api/object_ref_test.cpp @@ -44,6 +44,10 @@ public: void ExitAsync(const ExitRequest &req, ExitCallBack callback) override {}; void StateSaveAsync(const StateSaveRequest &req, StateSaveCallBack callback) override {}; void StateLoadAsync(const StateLoadRequest &req, StateLoadCallBack callback) override {}; + bool IsHealth() override + { + return true; + }; }; class ObjectRefTest : public testing::Test { @@ -74,6 +78,7 @@ public: TEST_F(ObjectRefTest, PutGetTest) { YR::Config conf; + conf.inCluster = true; conf.functionUrn = "sn:cn:yrk:12345678901234561234567890123456:function:0-x-x:$latest"; conf.serverAddr = "10.1.1.1:12345"; conf.dataSystemAddr = "10.1.1.1:12346"; diff --git a/test/api/runtime_env_parse_test.cpp b/test/api/runtime_env_parse_test.cpp new file mode 100644 index 0000000..0abaf8f --- /dev/null +++ b/test/api/runtime_env_parse_test.cpp @@ -0,0 +1,227 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "api/cpp/src/runtime_env_parse.h" +#include +#include +#include +#include "src/dto/invoke_options.h" + +namespace fs = std::filesystem; +namespace YR { +namespace test { +class RuntimeEnvParseTest : public ::testing::Test { +protected: + void SetUp() override + { + // 创建临时测试用的YAML文件 + validYamlPath_ = fs::temp_directory_path() / "valid_env.yaml"; + std::ofstream yamlFile(validYamlPath_); + yamlFile << "name: test_env\ndependencies:\n - python=3.8\n - pip\n - pip:\n - numpy\n"; + yamlFile.close(); + + // 创建临时测试用的YAML文件 + validNoNameYamlPath_ = fs::temp_directory_path() / "no_name_valid_env.yaml"; + std::ofstream noNameYamlFile(validNoNameYamlPath_); + noNameYamlFile << "dependencies:\n - python=3.8\n - pip\n - pip:\n - numpy\n"; + noNameYamlFile.close(); + + // 设置必要的环境变量 + setenv("YR_CONDA_HOME", "/fake/conda/path", 1); + } + + void TearDown() override + { + // 清理临时文件 + if (fs::exists(validYamlPath_)) { + fs::remove(validYamlPath_); + } + + // 清理临时文件 + if (fs::exists(validNoNameYamlPath_)) { + fs::remove(validNoNameYamlPath_); + } + } + + fs::path validYamlPath_; + fs::path validNoNameYamlPath_; +}; + +TEST_F(RuntimeEnvParseTest, ShouldThrowWhenCondaHomeNotSet) +{ + unsetenv("YR_CONDA_HOME"); + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + env.Set("conda", {{"name", "test"}}); + + EXPECT_THROW(YR::ParseRuntimeEnv(options, env), YR::Exception); +} + +TEST_F(RuntimeEnvParseTest, ShouldProcessPipPackagesCorrectly) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + env.Set>("pip", {"numpy", "pandas"}); + + YR::ParseRuntimeEnv(options, env); + EXPECT_EQ(options.createOptions["POST_START_EXEC"], "pip3 install numpy pandas"); +} + +TEST_F(RuntimeEnvParseTest, ShouldRejectBothPipAndConda) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + env.Set>("pip", {"numpy"}); + env.Set("conda", "env_name"); + + EXPECT_THROW(YR::ParseRuntimeEnv(options, env), YR::Exception); +} + +TEST_F(RuntimeEnvParseTest, ShouldHandleWorkingDirCorrectly) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + env.Set("working_dir", "/tmp/test_dir"); + + YR::ParseRuntimeEnv(options, env); + EXPECT_EQ(options.workingDir, "/tmp/test_dir"); +} + +TEST_F(RuntimeEnvParseTest, ShouldMergeEnvVarsProperly) +{ + YR::Libruntime::InvokeOptions options; + options.envVars["EXISTING"] = "original"; + + YR::RuntimeEnv env; + env.Set>("env_vars", {{"NEW", "value"}, {"EXISTING", "new"}}); + + YR::ParseRuntimeEnv(options, env); + EXPECT_EQ(options.envVars["NEW"], "value"); + EXPECT_EQ(options.envVars["EXISTING"], "original"); // 应保留原有值 +} + +TEST_F(RuntimeEnvParseTest, ShouldProcessCondaYamlFile) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + env.Set("conda", validYamlPath_.string()); + + YR::ParseRuntimeEnv(options, env); + EXPECT_EQ(options.createOptions["CONDA_PREFIX"], "/fake/conda/path"); + EXPECT_EQ(options.createOptions["CONDA_COMMAND"], "conda env create -f env.yaml"); + EXPECT_EQ(options.createOptions["CONDA_CONFIG"], "{\"name\": \"test_env\", \"dependencies\": [\"python=3.8\", \"pip\", {\"pip\": [\"numpy\"]}]}"); +} + +TEST_F(RuntimeEnvParseTest, ShouldProcessCondaNoNameYamlFile) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + env.Set("conda", validNoNameYamlPath_.string()); + + YR::ParseRuntimeEnv(options, env); + EXPECT_EQ(options.createOptions["CONDA_PREFIX"], "/fake/conda/path"); + EXPECT_EQ(options.createOptions["CONDA_COMMAND"], "conda env create -f env.yaml"); + EXPECT_FALSE(options.createOptions["CONDA_DEFAULT_ENV"].empty()); +} + +TEST_F(RuntimeEnvParseTest, ShouldProcessCondaEnvName) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + env.Set("conda", "existing_env"); + + YR::ParseRuntimeEnv(options, env); + EXPECT_EQ(options.createOptions["CONDA_COMMAND"], "conda activate existing_env"); + EXPECT_EQ(options.createOptions["CONDA_DEFAULT_ENV"], "existing_env"); +} + +TEST_F(RuntimeEnvParseTest, ShouldProcessCondaJsonConfig) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + nlohmann::json config = {{"name", "json_env"}, {"dependencies", {"python=3.8", "numpy"}}}; + env.Set("conda", config); + + YR::ParseRuntimeEnv(options, env); + EXPECT_EQ(options.createOptions["CONDA_COMMAND"], "conda env create -f env.yaml"); + EXPECT_TRUE(options.createOptions["CONDA_DEFAULT_ENV"].find("json_env") != std::string::npos); +} + +TEST_F(RuntimeEnvParseTest, ShouldRejectInvalidYamlFile) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + env.Set("conda", "/nonexistent/path.yaml"); + + EXPECT_THROW(YR::ParseRuntimeEnv(options, env), YR::Exception); +} + +TEST_F(RuntimeEnvParseTest, ShouldRejectInvalidPipType) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + env.Set("pip", "this_should_be_array"); + + EXPECT_THROW(YR::ParseRuntimeEnv(options, env), YR::Exception); +} + +TEST_F(RuntimeEnvParseTest, ShouldGenerateRandomNameForEmptyCondaName) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + nlohmann::json config = {{"name", ""}, {"dependencies", {"python"}}}; + env.Set("conda", config); + + YR::ParseRuntimeEnv(options, env); + EXPECT_FALSE(options.createOptions["CONDA_DEFAULT_ENV"].empty()); +} + +TEST_F(RuntimeEnvParseTest, ShouldAddSharedDirInCreateOpt) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + nlohmann::json config = {{"name", "abc"}, {"TTL", 5}}; + env.Set("shared_dir", config); + + YR::ParseRuntimeEnv(options, env); + EXPECT_EQ(options.createOptions["DELEGATE_SHARED_DIRECTORY"], "abc"); + EXPECT_EQ(options.createOptions["DELEGATE_SHARED_DIRECTORY_TTL"], "5"); + + YR::Libruntime::InvokeOptions options2; + YR::RuntimeEnv env2; + nlohmann::json config2 = {{"name", "abc"}}; + env2.Set("shared_dir", config2); + + YR::ParseRuntimeEnv(options2, env2); + EXPECT_EQ(options2.createOptions["DELEGATE_SHARED_DIRECTORY"], "abc"); + EXPECT_EQ(options2.createOptions["DELEGATE_SHARED_DIRECTORY_TTL"], "0"); +} + +TEST_F(RuntimeEnvParseTest, ShouldThrowExceptionWhenSharedDirConfigInvaild) +{ + YR::Libruntime::InvokeOptions options; + YR::RuntimeEnv env; + nlohmann::json config = {{"name", ""}}; + env.Set("shared_dir", config); + EXPECT_THROW(YR::ParseRuntimeEnv(options, env), YR::Exception); + + YR::Libruntime::InvokeOptions options2; + YR::RuntimeEnv env2; + env2.Set("shared_dir", "str"); + EXPECT_THROW(YR::ParseRuntimeEnv(options2, env2), YR::Exception); +} +} +} \ No newline at end of file diff --git a/test/api/runtime_env_test.cpp b/test/api/runtime_env_test.cpp new file mode 100644 index 0000000..fd7bee9 --- /dev/null +++ b/test/api/runtime_env_test.cpp @@ -0,0 +1,169 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "yr/api/runtime_env.h" +#include +#include +#include "yr/api/constant.h" + +namespace YR { +namespace internal { + +class RuntimeEnvTest : public ::testing::Test { +protected: + void SetUp() override { + env.Set("test_int", 100); + env.Set("test_str", "测试字符串"); + env.SetJsonStr("test_json", R"({"key":"value","num":123})"); + } + + RuntimeEnv env; +}; + +// 测试Set/Get基本数据类型 +TEST_F(RuntimeEnvTest, ShouldHandlePrimitiveTypes) { + env.Set("int_val", 42); + EXPECT_EQ(env.Get("int_val"), 42); + + env.Set("double_val", 3.14); + EXPECT_DOUBLE_EQ(env.Get("double_val"), 3.14); + + env.Set("bool_val", true); + EXPECT_TRUE(env.Get("bool_val")); +} + +// 测试Set/Get字符串 +TEST_F(RuntimeEnvTest, ShouldHandleString) { + env.Set("str_val", "测试字符串"); + EXPECT_EQ(env.Get("str_val"), "测试字符串"); +} + +// 测试Set/Get容器类型 +TEST_F(RuntimeEnvTest, ShouldHandleContainers) { + // 测试vector + std::vector vec{1, 2, 3}; + env.Set>("vec_val", vec); + auto vec_ret = env.Get>("vec_val"); + EXPECT_EQ(vec_ret, vec); + + // 测试map + std::map map_val{{"a", 1}, {"b", 2}}; + env.Set>("map_val", map_val); + auto map_ret = env.Get>("map_val"); + EXPECT_EQ(map_ret, map_val); +} + +// 测试异常情况 - Get不存在的字段 +TEST_F(RuntimeEnvTest, ShouldThrowWhenFieldNotExist) { + EXPECT_THROW(env.Get("non_exist"), YR::Exception); +} + +// 测试异常情况 - 类型不匹配 +TEST_F(RuntimeEnvTest, ShouldThrowWhenTypeMismatch) { + env.Set("int_field", 100); + + // 尝试用错误类型获取 + EXPECT_THROW(env.Get("int_field"), YR::Exception); + + // 测试错误消息包含类型信息 + try { + env.Get("int_field"); + } catch (const YR::Exception& e) { + EXPECT_NE(std::string(e.what()).find("Failed to get the field"), std::string::npos); + } +} + +// 测试边界值 +TEST_F(RuntimeEnvTest, ShouldHandleEmptyValues) { + env.Set("empty_str", ""); + EXPECT_TRUE(env.Get("empty_str").empty()); + + std::vector empty_vec; + env.Set>("empty_vec", empty_vec); + EXPECT_TRUE(env.Get>("empty_vec").empty()); +} + +// 测试特殊字符字段名 +TEST_F(RuntimeEnvTest, ShouldHandleSpecialCharNames) { + std::string special_name = "field@name#123"; + env.Set(special_name, 100); + EXPECT_EQ(env.Get(special_name), 100); +} + +// 测试基础Set/Get功能 +TEST_F(RuntimeEnvTest, ShouldCorrectlySetAndGetValues) { + EXPECT_EQ(env.Get("test_int"), 100); + EXPECT_EQ(env.Get("test_str"), "测试字符串"); + + auto jsonStr = env.GetJsonStr("test_json"); + nlohmann::json j = nlohmann::json::parse(jsonStr); + EXPECT_EQ(j["key"], "value"); + EXPECT_EQ(j["num"], 123); +} + +// 测试异常场景 +TEST_F(RuntimeEnvTest, ShouldThrowWhenGettingNonexistentField) { + EXPECT_THROW(env.Get("non_exist"), YR::Exception); + EXPECT_THROW(env.GetJsonStr("non_exist"), YR::Exception); +} + +// 测试JSON字符串处理 +TEST_F(RuntimeEnvTest, ShouldHandleJsonStringsProperly) { + // 测试无效JSON + EXPECT_THROW(env.SetJsonStr("bad_json", "{invalid}"), YR::Exception); + + // 测试嵌套JSON + env.SetJsonStr("nested_json", R"({"nested":{"key":"value"}})"); + auto jsonStr = env.GetJsonStr("nested_json"); + nlohmann::json j = nlohmann::json::parse(jsonStr); + EXPECT_EQ(j["nested"]["key"], "value"); +} + +// 测试Contains方法 +TEST_F(RuntimeEnvTest, ShouldCorrectlyCheckFieldExistence) { + EXPECT_TRUE(env.Contains("test_int")); + EXPECT_FALSE(env.Contains("non_exist")); +} + +// 测试Remove方法 +TEST_F(RuntimeEnvTest, ShouldRemoveFieldsCorrectly) { + EXPECT_TRUE(env.Remove("test_int")); + EXPECT_FALSE(env.Contains("test_int")); + EXPECT_FALSE(env.Remove("non_exist")); // 删除不存在的字段应返回false +} + +// 测试Empty方法 +TEST_F(RuntimeEnvTest, ShouldCheckEmptyState) { + RuntimeEnv emptyEnv; + EXPECT_TRUE(emptyEnv.Empty()); + + emptyEnv.Set("temp", 1); + EXPECT_FALSE(emptyEnv.Empty()); + + emptyEnv.Remove("temp"); + EXPECT_TRUE(emptyEnv.Empty()); +} + +// 测试模板特化 +TEST_F(RuntimeEnvTest, ShouldSupportVariousDataTypes) { + env.Set("double_val", 3.14); + env.Set("bool_val", true); + + EXPECT_DOUBLE_EQ(env.Get("double_val"), 3.14); + EXPECT_TRUE(env.Get("bool_val")); +} +} +} \ No newline at end of file diff --git a/test/api/stream_pub_sub_test.cpp b/test/api/stream_pub_sub_test.cpp new file mode 100644 index 0000000..5c1a557 --- /dev/null +++ b/test/api/stream_pub_sub_test.cpp @@ -0,0 +1,242 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "api/cpp/include/yr/api/err_type.h" +#include "api/cpp/include/yr/api/exception.h" +#include "api/cpp/include/yr/api/stream.h" +#include "api/cpp/src/stream_pubsub.h" + +namespace YR { +namespace Libruntime { +class MockStreamProducer : public YR::Libruntime::StreamProducer { +public: + MOCK_METHOD1(Send, YR::Libruntime::ErrorInfo(const YR::Libruntime::Element &element)); + MOCK_METHOD2(Send, YR::Libruntime::ErrorInfo(const YR::Libruntime::Element &element, int64_t timeoutMs)); + MOCK_METHOD0(Flush, YR::Libruntime::ErrorInfo()); + MOCK_METHOD0(Close, YR::Libruntime::ErrorInfo()); +}; + +class MockStreamConsumer : public YR::Libruntime::StreamConsumer { +public: + MOCK_METHOD3(Receive, YR::Libruntime::ErrorInfo(uint32_t expectNum, uint32_t timeoutMs, + std::vector &outElements)); + MOCK_METHOD2(Receive, + YR::Libruntime::ErrorInfo(uint32_t timeoutMs, std::vector &outElements)); + MOCK_METHOD1(Ack, YR::Libruntime::ErrorInfo(uint64_t elementId)); + MOCK_METHOD0(Close, YR::Libruntime::ErrorInfo()); +}; +} // namespace Libruntime + +namespace test { +using namespace YR::utility; +using namespace testing; + +class StreamPubSubTest : public testing::Test { +public: + StreamPubSubTest(){}; + ~StreamPubSubTest(){}; + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + this->producer = std::make_shared(); + this->consumer = std::make_shared(); + this->streamProducer = std::make_shared(this->producer); + this->streamConsumer = std::make_shared(this->consumer); + } + std::shared_ptr producer; + std::shared_ptr consumer; + std::shared_ptr streamProducer; + std::shared_ptr streamConsumer; +}; + +TEST_F(StreamPubSubTest, SendFailedTest) +{ + Element ele(nullptr, 10, 10); + YR::Libruntime::ErrorInfo err(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, + YR::Libruntime::ModuleCode::DATASYSTEM, "Send failed."); + EXPECT_CALL(*this->producer, Send(_)).WillOnce(testing::Return(err)); + bool isThrow = false; + try { + streamProducer->Send(ele); + } catch (YR::Exception &err) { + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + EXPECT_THAT(err.Msg(), testing::HasSubstr("Send failed.")); + isThrow = true; + } + EXPECT_TRUE(isThrow); + + EXPECT_CALL(*this->producer, Send(_, _)).WillOnce(testing::Return(err)); + isThrow = false; + try { + streamProducer->Send(ele, 1000); + } catch (YR::Exception &err) { + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + EXPECT_THAT(err.Msg(), testing::HasSubstr("Send failed.")); + isThrow = true; + } + EXPECT_TRUE(isThrow); +} + +TEST_F(StreamPubSubTest, SendSuccessfullyTest) +{ + Element ele(nullptr, 10, 10); + EXPECT_CALL(*this->producer, Send(_)).WillOnce(testing::Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(streamProducer->Send(ele)); + + EXPECT_CALL(*this->producer, Send(_, _)).WillOnce(testing::Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(streamProducer->Send(ele, 1000)); +} + +TEST_F(StreamPubSubTest, FlushFailedTest) +{ + YR::Libruntime::ErrorInfo err(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, + YR::Libruntime::ModuleCode::DATASYSTEM, "Flush failed."); + EXPECT_CALL(*(this->producer), Flush()).WillOnce(testing::Return(err)); + bool isThrow = false; + try { + this->streamProducer->Flush(); + } catch (YR::Exception &err) { + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + EXPECT_THAT(err.Msg(), testing::HasSubstr("Flush failed.")); + isThrow = true; + } + EXPECT_TRUE(isThrow); +} + +TEST_F(StreamPubSubTest, FlushSuccessfullyTest) +{ + EXPECT_CALL(*(this->producer), Flush()).WillOnce(testing::Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(this->streamProducer->Flush()); +} + +TEST_F(StreamPubSubTest, ProducerCloseFailedTest) +{ + YR::Libruntime::ErrorInfo err(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, + YR::Libruntime::ModuleCode::DATASYSTEM, "Close failed."); + EXPECT_CALL(*(this->producer), Close()).WillOnce(testing::Return(err)); + bool isThrow = false; + try { + this->streamProducer->Close(); + } catch (YR::Exception &err) { + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + EXPECT_THAT(err.Msg(), testing::HasSubstr("Close failed.")); + isThrow = true; + } + EXPECT_TRUE(isThrow); +} + +TEST_F(StreamPubSubTest, ProducerCloseSuccessfullyTest) +{ + EXPECT_CALL(*(this->producer), Close()).WillOnce(testing::Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(this->streamProducer->Close()); +} + +TEST_F(StreamPubSubTest, ReceiveTest) +{ + YR::Libruntime::ErrorInfo err(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, + YR::Libruntime::ModuleCode::DATASYSTEM, "Receive failed."); + EXPECT_CALL(*(this->consumer), Receive(_, _)) + .WillOnce(testing::Return(err)) + .WillOnce(testing::Return(YR::Libruntime::ErrorInfo())); + std::vector outElements; + bool isThrow = false; + try { + this->streamConsumer->Receive(1000, outElements); + } catch (YR::Exception &err) { + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + EXPECT_THAT(err.Msg(), testing::HasSubstr("Receive failed.")); + isThrow = true; + } + EXPECT_TRUE(isThrow); + EXPECT_NO_THROW(this->streamConsumer->Receive(1000, outElements)); + + EXPECT_CALL(*(this->consumer), Receive(_, _, _)) + .WillOnce(testing::Return(err)) + .WillOnce(testing::Return(YR::Libruntime::ErrorInfo())); + isThrow = false; + try { + this->streamConsumer->Receive(1, 1000, outElements); + } catch (YR::Exception &err) { + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + EXPECT_THAT(err.Msg(), testing::HasSubstr("Receive failed.")); + isThrow = true; + } + EXPECT_TRUE(isThrow); + EXPECT_NO_THROW(this->streamConsumer->Receive(1, 1000, outElements)); +} + +TEST_F(StreamPubSubTest, AckFailedTest) +{ + YR::Libruntime::ErrorInfo err(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, + YR::Libruntime::ModuleCode::DATASYSTEM, "Ack failed."); + EXPECT_CALL(*(this->consumer), Ack(_)).WillOnce(testing::Return(err)); + bool isThrow = false; + try { + this->streamConsumer->Ack(111); + } catch (YR::Exception &err) { + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + EXPECT_THAT(err.Msg(), testing::HasSubstr("Ack failed.")); + isThrow = true; + } + EXPECT_TRUE(isThrow); +} + +TEST_F(StreamPubSubTest, AckSuccessfullyTest) +{ + EXPECT_CALL(*(this->consumer), Ack(_)).WillOnce(testing::Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(this->streamConsumer->Ack(111)); +} + +TEST_F(StreamPubSubTest, ConsumerCloseFailedTest) +{ + YR::Libruntime::ErrorInfo err(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, + YR::Libruntime::ModuleCode::DATASYSTEM, "Close failed."); + EXPECT_CALL(*(this->consumer), Close()).WillOnce(testing::Return(err)); + bool isThrow = false; + try { + this->streamConsumer->Close(); + } catch (YR::Exception &err) { + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + EXPECT_THAT(err.Msg(), testing::HasSubstr("Close failed.")); + isThrow = true; + } + EXPECT_TRUE(isThrow); +} + +TEST_F(StreamPubSubTest, ConsumerCloseSuccessfullyTest) +{ + EXPECT_CALL(*(this->consumer), Close()).WillOnce(testing::Return(YR::Libruntime::ErrorInfo())); + EXPECT_NO_THROW(this->streamConsumer->Close()); +} +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/clibruntime/clibruntime_test.cpp b/test/clibruntime/clibruntime_test.cpp new file mode 100644 index 0000000..ec3f0ca --- /dev/null +++ b/test/clibruntime/clibruntime_test.cpp @@ -0,0 +1,1034 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include "common/mock_libruntime.h" +#define private public +#include "datasystem/kv_client.h" +#include "api/go/libruntime/cpplibruntime/clibruntime.h" +#include "src/libruntime/libruntime_manager.h" + +using namespace YR::utility; + +class FakeStreamProducer : public YR::Libruntime::StreamProducer { +public: + MOCK_METHOD1(Send, YR::Libruntime::ErrorInfo(const YR::Libruntime::Element &element)); + MOCK_METHOD2(Send, YR::Libruntime::ErrorInfo(const YR::Libruntime::Element &element, int64_t timeoutMs)); + MOCK_METHOD0(Flush, YR::Libruntime::ErrorInfo()); + MOCK_METHOD0(Close, YR::Libruntime::ErrorInfo()); +}; + +class FakeStreamConsumer : public YR::Libruntime::StreamConsumer { +public: + MOCK_METHOD3(Receive, YR::Libruntime::ErrorInfo(uint32_t expectNum, uint32_t timeoutMs, + std::vector &outElements)); + MOCK_METHOD2(Receive, + YR::Libruntime::ErrorInfo(uint32_t timeoutMs, std::vector &outElements)); + MOCK_METHOD1(Ack, YR::Libruntime::ErrorInfo(uint64_t elementId)); + MOCK_METHOD0(Close, YR::Libruntime::ErrorInfo()); +}; + +void SafeFreeCErr(CErrorInfo cErr) +{ + if (cErr.message != nullptr) { + free(cErr.message); + } +} + +CErrorInfo *GoLoadFunctions(char **codePaths, int size_codePaths) +{ + return nullptr; +} + +CErrorInfo *GoFunctionExecution(CFunctionMeta *, CInvokeType, CArg *, int, CDataObject *, int) +{ + return nullptr; +} + +CErrorInfo *GoCheckpoint(char *checkpointId, CBuffer *buffer) +{ + return nullptr; +} + +CErrorInfo *GoRecover(CBuffer *buffer) +{ + return nullptr; +} + +CErrorInfo *GoShutdown(uint64_t gracePeriodSeconds) +{ + return nullptr; +} + +CErrorInfo *GoSignal(int sigNo, CBuffer *payload) +{ + return nullptr; +} + +CHealthCheckCode GoHealthCheck(void) +{ + return CHealthCheckCode::HEALTHY; +} + +char GoHasHealthCheck(void) +{ + return 0; +} + +void GoFunctionExecutionPoolSubmit(void *ptr) {} + +void GoRawCallback(char *cKey, CErrorInfo cErr, CBuffer cResultRaw) {} + +void GoGetAsyncCallback(char *cObjectID, CBuffer cBuf, CErrorInfo *cErr, void *userData) {} + +void GoWaitAsyncCallback(char *cObjectID, CErrorInfo *cErr, void *userData) {} + +void freeCStrings(char **cStrings, size_t cStringsLen) +{ + for (size_t i = 0; i < cStringsLen; i++) { + free(cStrings[i]); + } + free(cStrings); +} + +void freeCErrorIds(CErrorObject **errorIds, int size_errorIds) +{ + for (size_t i = 0; i < size_errorIds; i++) { + free(errorIds[i]->objectId); + SafeFreeCErr(*(errorIds[i]->errorInfo)); + } + free(errorIds); +} + +class CLibruntimeTest : public testing::Test { +public: + CLibruntimeTest(){}; + ~CLibruntimeTest(){}; + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + auto lc = std::make_shared(); + lc->jobId = "111"; + auto clientsMgr = std::make_shared(); + auto metricsAdaptor = std::make_shared(); + auto sec = std::make_shared(); + auto socketClient = std::make_shared("/home/snuser/socket/runtime.sock"); + lr = std::make_shared(lc, clientsMgr, metricsAdaptor, sec, socketClient); + YR::Libruntime::LibruntimeManager::Instance().SetLibRuntime(lr); + } + + char *GetStr(std::string str) + { + char *cStr = (char *)malloc(str.size() + 1); + (void)memcpy_s(cStr, str.size(), str.data(), str.size()); + cStr[str.size()] = 0; + tmpStrs.push_back(cStr); + return cStr; + } + + void TearDown() override + { + lr.reset(); + YR::Libruntime::LibruntimeManager::Instance().Finalize(); + for (auto ptr : tmpStrs) { + if (ptr != nullptr) { + free(ptr); + ptr = nullptr; + } + } + tmpStrs.clear(); + } + + std::shared_ptr lr; + std::vector tmpStrs; +}; + +TEST_F(CLibruntimeTest, CCreateStateStoreTest) +{ + std::shared_ptr stateStore; + stateStore = std::make_shared(); + EXPECT_CALL(*lr.get(), CreateStateStore(_, _)) + .WillOnce(DoAll(SetArgReferee<1>(stateStore), Return(YR::Libruntime::ErrorInfo()))); + CConnectArguments arguments{"127.0.0.1", 0, 100, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, + 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, 0}; + CStateStorePtr stateStore2 = nullptr; + auto cErr = CCreateStateStore(&arguments, &stateStore2); + ASSERT_TRUE(cErr.code == 0); + ASSERT_TRUE(stateStore2 != nullptr); + CDestroyStateStore(stateStore2); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CSetTraceIdTest) +{ + std::string traceId = "traceId"; + EXPECT_CALL(*lr.get(), SetTraceId(_)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + auto cErr = CSetTraceId(traceId.c_str(), traceId.size()); + ASSERT_TRUE(cErr.code == 0); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CGenerateKeyTest) +{ + char *key = nullptr; + int cKeyLen = 0; + CStateStorePtr nullStateStorePtr = nullptr; + auto cErr = CGenerateKey(nullStateStorePtr, &key, &cKeyLen); + ASSERT_TRUE(cErr.code == 3003); + ASSERT_TRUE(std::string(cErr.message) == "the state store is empty"); + SafeFreeCErr(cErr); + + auto stateStore = YR::Libruntime::DSCacheStateStore(); + auto stateStorePtr = reinterpret_cast(&stateStore); + EXPECT_CALL(*lr.get(), GenerateKeyByStateStore(_, _)) + .WillOnce(DoAll(SetArgReferee<1>("genKey"), Return(YR::Libruntime::ErrorInfo()))); + cErr = CGenerateKey(stateStorePtr, &key, &cKeyLen); + ASSERT_TRUE(cErr.code == 0); + ASSERT_TRUE(strcmp(key, "genKey") == 0); + ASSERT_TRUE(cKeyLen == strlen("genKey")); + SafeFreeCErr(cErr); + free(key); + + YR::Libruntime::LibruntimeManager::Instance().SetLibRuntime(nullptr); + cErr = CGenerateKey(stateStorePtr, &key, &cKeyLen); + ASSERT_TRUE(cErr.code == 9000); + ASSERT_TRUE(std::string(cErr.message) == "libRuntime empty"); +} + +TEST_F(CLibruntimeTest, CSetByStateStoreTest) +{ + std::string rightKey = "rightKey"; + std::string value = "value"; + CBuffer buffer{static_cast(const_cast(value.c_str())), static_cast(value.size())}; + CSetParam param; + + CStateStorePtr nullStateStorePtr = nullptr; + auto cErr = CSetByStateStore(nullStateStorePtr, const_cast(rightKey.c_str()), buffer, param); + ASSERT_EQ(cErr.code, 3003); + ASSERT_TRUE(std::string(cErr.message) == "the state store is empty"); + SafeFreeCErr(cErr); + + auto stateStore = YR::Libruntime::DSCacheStateStore(); + auto stateStorePtr = reinterpret_cast(&stateStore); + EXPECT_CALL(*lr.get(), SetByStateStore(_, _, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + cErr = CSetByStateStore(stateStorePtr, const_cast(rightKey.c_str()), buffer, param); + ASSERT_TRUE(cErr.code == 0); + SafeFreeCErr(cErr); + + auto err = YR::Libruntime::ErrorInfo(); + err.SetErrorCode(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + err.SetDsStatusCode(datasystem::StatusCode::K_RUNTIME_ERROR); + EXPECT_CALL(*lr.get(), SetByStateStore(_, _, _, _)).WillOnce(Return(err)); + std::string wrongKey = "wrongKey"; + cErr = CSetByStateStore(stateStorePtr, const_cast(wrongKey.c_str()), buffer, param); + ASSERT_TRUE(cErr.code != 0); + ASSERT_TRUE(cErr.dsStatusCode == datasystem::StatusCode::K_RUNTIME_ERROR); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CSetValueByStateStoreTest) +{ + std::string value = "value"; + CSetParam param; + CBuffer buffer{static_cast(const_cast(value.c_str())), static_cast(value.size())}; + char *key = nullptr; + int keyLen = 0; + + CStateStorePtr nullStateStorePtr = nullptr; + auto cErr = CSetValueByStateStore(nullStateStorePtr, buffer, param, &key, &keyLen); + ASSERT_TRUE(cErr.code == 3003); + ASSERT_TRUE(std::string(cErr.message) == "the state store is empty"); + SafeFreeCErr(cErr); + + auto stateStore = YR::Libruntime::DSCacheStateStore(); + auto stateStorePtr = reinterpret_cast(&stateStore); + EXPECT_CALL(*lr.get(), SetValueByStateStore(_, _, _, _)) + .WillOnce(DoAll(SetArgReferee<3>("returnKey"), Return(YR::Libruntime::ErrorInfo()))); + cErr = CSetValueByStateStore(stateStorePtr, buffer, param, &key, &keyLen); + ASSERT_TRUE(cErr.code == 0); + ASSERT_TRUE(strcmp(key, "returnKey") == 0); + ASSERT_TRUE(keyLen == strlen("returnKey")); + free(key); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, GetByStateStoreTest) +{ + std::string rightKey = "rightKey"; + CBuffer buffer; + + CStateStorePtr nullStateStorePtr = nullptr; + auto cErr = CGetByStateStore(nullStateStorePtr, const_cast(rightKey.c_str()), &buffer, 0); + ASSERT_TRUE(cErr.code == 3003); + ASSERT_TRUE(std::string(cErr.message) == "the state store is empty"); + SafeFreeCErr(cErr); + + auto stateStore = YR::Libruntime::DSCacheStateStore(); + auto stateStorePtr = reinterpret_cast(&stateStore); + EXPECT_CALL(*lr.get(), GetByStateStore(_, _, _)) + .WillOnce(Return(std::make_pair, YR::Libruntime::ErrorInfo>( + std::make_shared(1), YR::Libruntime::ErrorInfo()))); + cErr = CGetByStateStore(stateStorePtr, const_cast(rightKey.c_str()), &buffer, 0); + ASSERT_TRUE(cErr.code == 0); + SafeFreeCErr(cErr); + free(buffer.buffer); + + std::string wrongKey = "wrongKey"; + CBuffer bufferTwo{nullptr, 0}; + auto err = YR::Libruntime::ErrorInfo(); + err.SetErrorCode(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + err.SetDsStatusCode(datasystem::StatusCode::K_OUT_OF_MEMORY); + EXPECT_CALL(*lr.get(), GetByStateStore(_, _, _)) + .WillOnce(Return(std::make_pair, YR::Libruntime::ErrorInfo>( + std::make_shared(0), std::move(err)))); + cErr = CGetByStateStore(stateStorePtr, const_cast(wrongKey.c_str()), &bufferTwo, 0); + ASSERT_TRUE(cErr.code != 0); + ASSERT_TRUE(cErr.dsStatusCode == datasystem::StatusCode::K_OUT_OF_MEMORY); + SafeFreeCErr(cErr); + free(bufferTwo.buffer); +} + +TEST_F(CLibruntimeTest, GetArrayByStateStoreTest) +{ + char *keys[2]; + keys[0] = GetStr("key1"); + keys[1] = GetStr("key2"); + CBuffer buffer[2]; + for (size_t i = 0; i < 2; i++) { + buffer[i].buffer = nullptr; + buffer[i].size_buffer = 0; + } + + CStateStorePtr nullStateStorePtr = nullptr; + auto cErr = CGetArrayByStateStore(nullStateStorePtr, keys, 2, buffer, 0); + ASSERT_TRUE(cErr.code == 3003); + ASSERT_TRUE(std::string(cErr.message) == "the state store is empty"); + SafeFreeCErr(cErr); + + auto stateStore = YR::Libruntime::DSCacheStateStore(); + auto stateStorePtr = reinterpret_cast(&stateStore); + auto err = YR::Libruntime::ErrorInfo(); + err.SetErrorCode(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + err.SetDsStatusCode(datasystem::StatusCode::K_OUT_OF_MEMORY); + EXPECT_CALL(*lr.get(), GetArrayByStateStore(_, _, _, _)) + .WillOnce( + Return(std::make_pair>, YR::Libruntime::ErrorInfo>( + {}, std::move(err)))); + cErr = CGetArrayByStateStore(stateStorePtr, keys, 2, buffer, 0); + ASSERT_TRUE(cErr.code != 0); + ASSERT_EQ(cErr.dsStatusCode, datasystem::StatusCode::K_OUT_OF_MEMORY); + SafeFreeCErr(cErr); + for (size_t i = 0; i < 2; i++) { + free(buffer[i].buffer); + } +} + +TEST_F(CLibruntimeTest, QuerySizeByStateStoreTest) +{ + char *fake[1]; + fake[0] = GetStr("query_1"); + uint64_t fakeSizes[1]; + + CStateStorePtr nullStateStorePtr = nullptr; + auto cErr = CQuerySizeByStateStore(nullStateStorePtr, fake, 1, fakeSizes); + ASSERT_TRUE(cErr.code == 3003); + ASSERT_TRUE(std::string(cErr.message) == "the state store is empty"); + SafeFreeCErr(cErr); + + auto stateStore = YR::Libruntime::DSCacheStateStore(); + auto stateStorePtr = reinterpret_cast(&stateStore); + EXPECT_CALL(*lr.get(), QuerySizeByStateStore(_, _, _)) + .WillOnce([=](std::shared_ptr, const std::vector &, + std::vector ¶m) { + param = {10}; + return YR::Libruntime::ErrorInfo(); + }); + cErr = CQuerySizeByStateStore(stateStorePtr, fake, 1, fakeSizes); + ASSERT_EQ(0, cErr.code); + ASSERT_EQ(fakeSizes[0], 10); +} + +TEST_F(CLibruntimeTest, DelByStateStoreTest) +{ + CStateStorePtr nullStateStorePtr = nullptr; + auto stateStore = YR::Libruntime::DSCacheStateStore(); + auto stateStorePtr = reinterpret_cast(&stateStore); + EXPECT_CALL(*lr.get(), DelByStateStore(_, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + std::string rightKey = "rightKey"; + + auto cErr = CDelByStateStore(nullStateStorePtr, const_cast(rightKey.c_str())); + ASSERT_TRUE(cErr.code == 3003); + ASSERT_TRUE(std::string(cErr.message) == "the state store is empty"); + SafeFreeCErr(cErr); + + cErr = CDelByStateStore(stateStorePtr, const_cast(rightKey.c_str())); + ASSERT_TRUE(cErr.code == 0); + SafeFreeCErr(cErr); + + auto err = YR::Libruntime::ErrorInfo(); + err.SetErrorCode(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + err.SetDsStatusCode(datasystem::StatusCode::K_RUNTIME_ERROR); + EXPECT_CALL(*lr.get(), DelByStateStore(_, _)).WillOnce(Return(err)); + std::string wrongKey = "wrongKey"; + cErr = CDelByStateStore(stateStorePtr, const_cast(wrongKey.c_str())); + ASSERT_TRUE(cErr.code != 0); + ASSERT_EQ(cErr.dsStatusCode, datasystem::StatusCode::K_RUNTIME_ERROR); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, DelArrayByStateStoreTest) +{ + char *keys[2]; + keys[0] = GetStr("wrongKey"); + keys[1] = GetStr("rightKey"); + char **failedKeys = nullptr; + int failedKeysLen = 0; + + CStateStorePtr nullStateStorePtr = nullptr; + auto cErr = CDelArrayByStateStore(nullStateStorePtr, keys, 2, &failedKeys, &failedKeysLen); + ASSERT_TRUE(cErr.code == 3003); + ASSERT_TRUE(std::string(cErr.message) == "the state store is empty"); + SafeFreeCErr(cErr); + + auto stateStore = YR::Libruntime::DSCacheStateStore(); + auto stateStorePtr = reinterpret_cast(&stateStore); + auto err = YR::Libruntime::ErrorInfo(); + err.SetErrorCode(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + err.SetDsStatusCode(datasystem::StatusCode::K_RUNTIME_ERROR); + EXPECT_CALL(*lr.get(), DelArrayByStateStore(_, _)) + .WillOnce( + Return(std::make_pair, YR::Libruntime::ErrorInfo>({"wrongKey"}, std::move(err)))); + cErr = CDelArrayByStateStore(stateStorePtr, keys, 2, &failedKeys, &failedKeysLen); + ASSERT_TRUE(cErr.code != 0); + ASSERT_TRUE(cErr.dsStatusCode == datasystem::StatusCode::K_RUNTIME_ERROR); + ASSERT_TRUE(failedKeysLen == 1); + ASSERT_TRUE(strcmp(failedKeys[0], "wrongKey") == 0); + SafeFreeCErr(cErr); + freeCStrings(failedKeys, failedKeysLen); +} + +TEST_F(CLibruntimeTest, CGetCredentialTest) +{ + YR::Libruntime::Credential credential = {.ak = "ak", .sk = "sk", .dk = "dk"}; + EXPECT_CALL(*lr.get(), GetCredential()).WillOnce(Return(credential)); + auto cCredential = CGetCredential(); + ASSERT_TRUE(std::string(cCredential.ak) == "ak"); + ASSERT_TRUE(std::string(cCredential.sk) == "sk"); + ASSERT_TRUE(std::string(cCredential.dk) == "dk"); + free(cCredential.ak); + free(cCredential.sk); + free(cCredential.dk); +} + +TEST_F(CLibruntimeTest, CInitTest) +{ + CLibruntimeConfig config{ + "127.0.0.1:11111", + "127.0.0.1:11112", + "127.0.0.1:11113", + "jobId", + "runtimeId", + "instanceId", + "functionName", + "DEBUG", + "./", + CApiType::ACTOR, + 1, + 0, + 0, + "privateKeyPath", + "certificateFilePath", + "verifyFilePath", + "privateKeyPaaswd", + GetStr("functionId"), + "systemAuthAccessKey", + "systemAuthSecretKey", + 2, + "EncryptPrivateKeyPasswd", + "PrimaryKeyStoreFile", + "StandbyKeyStoreFile", + 0, + "RuntimePublicKeyContext", + "RuntimePrivateKeyContext", + "DsPublicKeyContext", + "EncryptRuntimePublicKeyContext", + "EncryptRuntimePrivateKeyContext", + "EncryptDsPublicKeyContext", + }; + auto cErr = CInit(&config); + ASSERT_TRUE(cErr.code == 0); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CWaitTest) +{ + CWaitResult result{nullptr, 0, nullptr, 0, nullptr, 0}; + char *objectIds[3]; + objectIds[0] = GetStr("obj1"); + objectIds[1] = GetStr("obj2"); + objectIds[2] = GetStr("obj3"); + char *readyIds[1]; + readyIds[0] = GetStr("obj1"); + char *unreadyIds[1]; + unreadyIds[0] = GetStr("obj2"); + std::vector stackTraceInfos; + YR::Libruntime::StackTraceInfo stackTraceInfo("type", "err msg"); + stackTraceInfos.push_back(stackTraceInfo); + auto err = YR::Libruntime::ErrorInfo(); + err.SetErrorCode(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); + err.SetDsStatusCode(datasystem::StatusCode::K_RUNTIME_ERROR); + err.SetErrorMsg("failed to wait"); + err.SetStackTraceInfos(stackTraceInfos); + auto ret = std::make_shared(); + ret->readyIds.push_back(readyIds[0]); + ret->unreadyIds.push_back(unreadyIds[0]); + ret->exceptionIds["obj3"] = err; + EXPECT_CALL(*lr.get(), Wait(_, _, _)).WillOnce(Return(ret)); + ASSERT_NO_THROW(CWait(objectIds, 3, 3, 1, &result)); + ASSERT_EQ(result.size_readyIds, 1); + ASSERT_EQ(result.size_unreadyIds, 1); + ASSERT_EQ(result.size_errorIds, 1); + freeCStrings(result.unreadyIds, result.size_unreadyIds); + freeCStrings(result.readyIds, result.size_readyIds); + freeCErrorIds(result.errorIds, result.size_errorIds); +} + +TEST_F(CLibruntimeTest, CCreateStreamProducerTest) +{ + CProducerConfig config{0, 0, 0}; + config.traceId = GetStr("trace_id"); + std::string streamName = "stream_001"; + Producer_p producer = nullptr; + EXPECT_CALL(*lr.get(), CreateStreamProducer(_, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + auto cErr = CCreateStreamProducer(const_cast(streamName.c_str()), &config, &producer); + ASSERT_EQ(cErr.code, 0); + ASSERT_TRUE(producer != nullptr); + SafeFreeCErr(cErr); + cErr = CProducerClose(producer); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CCreateInstanceTest) +{ + EXPECT_CALL(*lr.get(), CreateInstance(_, _, _)) + .WillOnce( + Return(std::pair(YR::Libruntime::ErrorInfo(), "instance_id"))); + CFunctionMeta meta{GetStr("app_name"), + GetStr("module_name"), + GetStr("func_name"), + GetStr("class_name"), + 1, + GetStr("code_id"), + GetStr("signature"), + GetStr("pool_label"), + CApiType::ACTOR, + GetStr("function_id"), + '1', + GetStr("name"), + '1', + GetStr("namespace")}; + + char *fake[2]; + fake[0] = GetStr("obj_1"); + fake[1] = GetStr("obj_2"); + CInvokeArg arg{(void *)GetStr("buf"), 3, '0', GetStr("obj_id"), GetStr("tenant_id"), fake, 2}; + CCustomResource res{GetStr("name"), 1.0}; + CCustomExtension extension{GetStr("key"), GetStr("value")}; + CCreateOpt opt{GetStr("key"), GetStr("value")}; + CLabelOperator labelOperator{CLabelOpType::EXISTS, GetStr("label_key"), fake, 2}; + CAffinity affinity{CAffinityKind::INSTANCE, CAffinityType::PREFERRED, '1', '1', &labelOperator, 1}; + CInvokeOptions option{500, + 500, + &res, + 1, + &extension, + 1, + &opt, + 1, + fake, + 2, + &affinity, + 0, + 0, + 1, + fake, + 2, + GetStr("scheduler_id"), + fake, + 2, + GetStr("trace_id"), + 1}; + char *instanceId; + int invokeArgsSize = 1; + auto cErr = CCreateInstance(&meta, &arg, invokeArgsSize, &option, &instanceId); + ASSERT_EQ(0, cErr.code); + free(instanceId); +} + +TEST_F(CLibruntimeTest, CInvokeByInstanceIdTest) +{ + EXPECT_CALL(*lr.get(), InvokeByInstanceId(_, _, _, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + CFunctionMeta meta{GetStr("app_name"), + GetStr("module_name"), + GetStr("func_name"), + GetStr("class_name"), + 1, + GetStr("code_id"), + GetStr("signature"), + GetStr("pool_label"), + CApiType::ACTOR, + GetStr("function_id"), + '1', + GetStr("name"), + '1', + GetStr("namespace")}; + + char *fake[2]; + fake[0] = GetStr("obj_1"); + fake[1] = GetStr("obj_2"); + CInvokeArg arg{(void *)GetStr("buf"), 3, '0', GetStr("obj_id"), GetStr("tenant_id"), fake, 2}; + CCustomResource res{GetStr("name"), 1.0}; + CCustomExtension extension{GetStr("key"), GetStr("value")}; + CCreateOpt opt{GetStr("key"), GetStr("value")}; + CLabelOperator labelOperator{CLabelOpType::EXISTS, GetStr("label_key"), fake, 2}; + CAffinity affinity{CAffinityKind::INSTANCE, CAffinityType::PREFERRED, '1', '1', &labelOperator, 1}; + CInvokeOptions option{500, + 500, + &res, + 1, + &extension, + 1, + &opt, + 1, + fake, + 2, + &affinity, + 0, + 0, + 1, + fake, + 2, + GetStr("scheduler_id"), + fake, + 2, + GetStr("trace_id"), + 1}; + char *returnObjId; + auto cErr = CInvokeByInstanceId(&meta, GetStr("instance_id"), &arg, 1, &option, &returnObjId); + ASSERT_EQ(0, cErr.code); + free(returnObjId); +} + +TEST_F(CLibruntimeTest, CInvokeByFunctionNameTest) +{ + EXPECT_CALL(*lr.get(), InvokeByFunctionName(_, _, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + CFunctionMeta meta{GetStr("app_name"), + GetStr("module_name"), + GetStr("func_name"), + GetStr("class_name"), + 1, + GetStr("code_id"), + GetStr("signature"), + GetStr("pool_label"), + CApiType::ACTOR, + GetStr("function_id"), + '1', + GetStr("name"), + '1', + GetStr("namespace")}; + + char *fake[2]; + fake[0] = GetStr("obj_1"); + fake[1] = GetStr("obj_2"); + CInvokeArg arg{(void *)GetStr("buf"), 3, '0', GetStr("obj_id"), GetStr("tenant_id"), fake, 2}; + CCustomResource res{GetStr("name"), 1.0}; + CCustomExtension extension{GetStr("key"), GetStr("value")}; + CCreateOpt opt{GetStr("key"), GetStr("value")}; + CLabelOperator labelOperator{CLabelOpType::EXISTS, GetStr("label_key"), fake, 2}; + CAffinity affinity{CAffinityKind::INSTANCE, CAffinityType::PREFERRED, '1', '1', &labelOperator, 1}; + CInvokeOptions option{500, + 500, + &res, + 1, + &extension, + 1, + &opt, + 1, + fake, + 2, + &affinity, + 0, + 0, + 1, + fake, + 2, + GetStr("scheduler_id"), + fake, + 2, + GetStr("trace_id"), + 1}; + char *returnObjId; + auto cErr = CInvokeByFunctionName(&meta, &arg, 1, &option, &returnObjId); + ASSERT_EQ(0, cErr.code); + free(returnObjId); +} + +TEST_F(CLibruntimeTest, CIncreaseReferenceTest) +{ + EXPECT_CALL(*lr.get(), IncreaseReference(_, Matcher(_))) + .WillOnce(Return(std::pair>(YR::Libruntime::ErrorInfo(), + {"failed_id1"}))); + char *fake[2]; + fake[0] = GetStr("obj_1"); + fake[1] = GetStr("obj_2"); + char **failedIds; + int failedIdSize; + auto cErr = CIncreaseReferenceCommon(fake, 2, GetStr("remote_id"), &failedIds, &failedIdSize, 0); + ASSERT_EQ(0, cErr.code); + ASSERT_EQ(std::string(failedIds[0]), "failed_id1"); + free(failedIds[0]); + free(failedIds); +} + +TEST_F(CLibruntimeTest, CDecreaseReferenceTest) +{ + EXPECT_CALL(*lr.get(), DecreaseReference(_, _)) + .WillOnce(Return(std::pair>(YR::Libruntime::ErrorInfo(), + {"failed_id1"}))); + char *fake[2]; + fake[0] = GetStr("obj_1"); + fake[1] = GetStr("obj_2"); + char **failedIds; + int failedIdSize; + auto cErr = CDecreaseReferenceCommon(fake, 2, GetStr("remote_id"), &failedIds, &failedIdSize, 0); + ASSERT_EQ(0, cErr.code); + ASSERT_EQ(std::string(failedIds[0]), "failed_id1"); + free(failedIds[0]); + free(failedIds); +} + +TEST_F(CLibruntimeTest, CKVWriteTest) +{ + EXPECT_CALL(*lr.get(), KVWrite(_, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + CBuffer data{(void *)"123", 3, nullptr}; + CSetParam param; + auto cErr = CKVWrite(GetStr("key"), data, param); + ASSERT_EQ(0, cErr.code); +} + +TEST_F(CLibruntimeTest, CKVMSetTxTest) +{ + EXPECT_CALL(*lr.get(), KVMSetTx(_, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + char *fake[1]; + fake[0] = GetStr("obj_1"); + CBuffer data{(void *)"123", 3, nullptr}; + CMSetParam param; + auto cErr = CKVMSetTx(fake, 1, &data, param); + ASSERT_EQ(0, cErr.code); +} + +TEST_F(CLibruntimeTest, CKVReadTest) +{ + auto returnBuffer = std::make_shared(3); + returnBuffer->MemoryCopy("123", 3); + EXPECT_CALL(*lr.get(), KVRead(_, _)) + .WillOnce(Return(std::pair, YR::Libruntime::ErrorInfo>( + returnBuffer, YR::Libruntime::ErrorInfo()))); + CBuffer data; + auto cErr = CKVRead(GetStr("key"), 1000, &data); + ASSERT_EQ(0, cErr.code); + free(data.buffer); +} + +TEST_F(CLibruntimeTest, CKVMultiReadTest) +{ + auto returnBuffer = std::make_shared(3); + returnBuffer->MemoryCopy("123", 3); + EXPECT_CALL(*lr.get(), KVRead(_, _, _)) + .WillOnce(Return(std::pair>, YR::Libruntime::ErrorInfo>( + {returnBuffer}, YR::Libruntime::ErrorInfo()))); + CBuffer data; + char *fake[1]; + fake[0] = GetStr("obj_1"); + auto cErr = CKVMultiRead(fake, 1, 1000, 1, &data); + ASSERT_EQ(0, cErr.code); + free(data.buffer); +} + +TEST_F(CLibruntimeTest, CGetTest) +{ + auto returnDataObj = std::make_shared(0, 3); + returnDataObj->data->MemoryCopy("123", 3); + EXPECT_CALL(*lr.get(), Get(_, _, _)) + .WillOnce(Return(std::pair>>( + YR::Libruntime::ErrorInfo(), {returnDataObj}))); + CBuffer data; + auto cErr = CGet(GetStr("obj_id"), 1000, &data); + ASSERT_EQ(0, cErr.code); + free(data.buffer); +} + +TEST_F(CLibruntimeTest, CKillTest) +{ + EXPECT_CALL(*lr.get(), Kill(_, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + CBuffer data; + auto cErr = CKill(GetStr("instance_id"), 1, data); + ASSERT_EQ(0, cErr.code); +} + +TEST_F(CLibruntimeTest, CExitTest) +{ + EXPECT_CALL(*lr.get(), Exit(_, _)).WillOnce(Return()); + EXPECT_NO_THROW(CExit(0, "")); +} + +TEST_F(CLibruntimeTest, CPutTest) +{ + EXPECT_CALL(*lr.get(), Put(_, _, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + CBuffer data{(void *)"123", 3, nullptr}; + char *fake[2]; + fake[0] = GetStr("obj_1"); + fake[1] = GetStr("obj_2"); + CCreateParam param; + auto cErr = CPutCommon(GetStr("obj_id"), data, fake, 2, 0, param); + ASSERT_EQ(0, cErr.code); +} + +TEST_F(CLibruntimeTest, CGetMultiTest) +{ + auto returnDataObj = std::make_shared(0, 3); + returnDataObj->data->MemoryCopy("123", 3); + EXPECT_CALL(*lr.get(), Get(_, _, _)) + .WillOnce(Return(std::pair>>( + YR::Libruntime::ErrorInfo(), {returnDataObj}))); + CBuffer data; + char *fake[1]; + fake[0] = GetStr("obj_1"); + auto cErr = CGetMultiCommon(fake, 1, 1000, false, &data, 0); + ASSERT_EQ(0, cErr.code); + free(data.buffer); +} + +TEST_F(CLibruntimeTest, CKVDelTest) +{ + EXPECT_CALL(*lr.get(), KVDel(Matcher(_))).WillOnce(Return(YR::Libruntime::ErrorInfo())); + auto cErr = CKVDel(GetStr("obj_id")); + ASSERT_EQ(0, cErr.code); +} + +TEST_F(CLibruntimeTest, CKVMultiDelTest) +{ + EXPECT_CALL(*lr.get(), KVDel(Matcher &>(_))) + .WillOnce(Return( + std::pair, YR::Libruntime::ErrorInfo>({"key"}, YR::Libruntime::ErrorInfo()))); + char *fake[1]; + fake[0] = GetStr("obj_1"); + char **failedKeys; + int failedKeysSize; + auto cErr = CKMultiVDel(fake, 1, &failedKeys, &failedKeysSize); + ASSERT_EQ(0, cErr.code); + freeCStrings(failedKeys, failedKeysSize); +} + +TEST_F(CLibruntimeTest, CCreateStreamConsumerTest) +{ + CSubscriptionConfig config{GetStr("name"), CSubscriptionType::KEY_PARTITIONS}; + config.traceId = GetStr("trace_id"); + std::string streamName = "stream_001"; + Consumer_p consumer = nullptr; + EXPECT_CALL(*lr.get(), CreateStreamConsumer(_, _, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + auto cErr = CCreateStreamConsumer(const_cast(streamName.c_str()), &config, &consumer); + ASSERT_EQ(cErr.code, 0); + ASSERT_TRUE(consumer != nullptr); + SafeFreeCErr(cErr); + cErr = CConsumerClose(consumer); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CDeleteStreamTest) +{ + EXPECT_CALL(*lr.get(), DeleteStream(_)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + auto cErr = CDeleteStream(GetStr("stream_001")); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CQueryGlobalProducersNumTest) +{ + EXPECT_CALL(*lr.get(), QueryGlobalProducersNum(_, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + uint64_t num; + auto cErr = CQueryGlobalProducersNum(GetStr("stream_001"), &num); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CQueryGlobalConsumersNumTest) +{ + EXPECT_CALL(*lr.get(), QueryGlobalConsumersNum(_, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + uint64_t num; + auto cErr = CQueryGlobalConsumersNum(GetStr("stream_001"), &num); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CProducerTest) +{ + std::shared_ptr fkProducer = std::make_shared(); + std::unique_ptr> pProducer = + std::make_unique>(std::move(fkProducer)); + Producer_p producer = reinterpret_cast(pProducer.release()); + uint8_t array[10] = {0}; + uint8_t *ptr = array; + uint64_t size = 10; + uint64_t id = 111; + auto cErr = CProducerSend(producer, ptr, size, id); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); + + cErr = CProducerSendWithTimeout(producer, ptr, size, id, 1000); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); + + cErr = CProducerFlush(producer); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); + + cErr = CProducerClose(producer); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CConsumerTest) +{ + std::shared_ptr fkConsumer = std::make_shared(); + std::unique_ptr> pConsumer = + std::make_unique>(fkConsumer); + Consumer_p consumer = reinterpret_cast(pConsumer.release()); + + uint64_t elementId = 111; + auto cErr = CConsumerAck(consumer, elementId); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); + + cErr = CConsumerClose(consumer); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CAllocReturnObjectTest) +{ + EXPECT_CALL(*lr.get(), AllocReturnObject(Matcher(_), _, _, _, _)) + .WillOnce(Return(YR::Libruntime::ErrorInfo())); + CDataObject obj; + auto dataObj = std::make_shared(); + dataObj->data = std::make_shared(1); + obj.selfSharedPtr = dataObj.get(); + char *fake[1]; + fake[0] = GetStr("obj_1"); + uint64_t totalNativeBufferSize; + auto cErr = CAllocReturnObject(&obj, 1, fake, 1, &totalNativeBufferSize); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CSetReturnObjectTest) +{ + CDataObject obj; + auto dataObj = std::make_shared(); + obj.selfSharedPtr = dataObj.get(); + EXPECT_NO_THROW(CSetReturnObject(&obj, 10)); + auto cBuffer = obj.buffer; + auto cErr = CWriterLatch(&cBuffer); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); + char *cSrc = GetStr("abc"); + cErr = CMemoryCopy(&cBuffer, (void *)cSrc, 3); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); + cErr = CSeal(&cBuffer); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); + cErr = CWriterUnlatch(&cBuffer); + ASSERT_EQ(cErr.code, 0); + SafeFreeCErr(cErr); +} + +TEST_F(CLibruntimeTest, CAcquireInstanceTest) +{ + YR::Libruntime::InstanceAllocation insAlloc{"functionId", "funcSig", "instanceId", "leaseId", 0}; + EXPECT_CALL(*lr.get(), AcquireInstance(_, _, _)) + .WillOnce(Return(std::make_pair( + std::move(insAlloc), YR::Libruntime::ErrorInfo()))); + char *stateId = GetStr("aaa"); + + CFunctionMeta meta{GetStr("app_name"), + GetStr("module_name"), + GetStr("func_name"), + GetStr("class_name"), + 1, + GetStr("code_id"), + GetStr("signature"), + GetStr("pool_label"), + CApiType::ACTOR, + GetStr("function_id"), + '1', + GetStr("name"), + '1', + GetStr("namespace")}; + + char *fake[2]; + fake[0] = GetStr("obj_1"); + fake[1] = GetStr("obj_2"); + CCustomResource res{GetStr("name"), 1.0}; + CCustomExtension extension{GetStr("key"), GetStr("value")}; + CCreateOpt opt{GetStr("key"), GetStr("value")}; + CLabelOperator labelOperator{CLabelOpType::EXISTS, GetStr("label_key"), fake, 2}; + CAffinity affinity{CAffinityKind::INSTANCE, CAffinityType::PREFERRED, '1', '1', &labelOperator, 1}; + char *schdulerId = GetStr("scheduler_id"); + char *traceId = GetStr("trace_id"); + CInvokeOptions option{500, 500, &res, 1, &extension, 1, &opt, 1, fake, 2, &affinity, + 0, 0, 1, fake, 2, schdulerId, fake, 2, traceId, 1}; + std::cout << option.customExtensions[0].value << std::endl; + CInstanceAllocation cInsAlloc; + auto cErr = CAcquireInstance(stateId, &meta, &option, &cInsAlloc); + ASSERT_EQ(cErr.code, 0); + auto err = YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "aaa"); + YR::Libruntime::InstanceAllocation insAlloc2{"functionId", "funcSig", "instanceId", "leaseId", 0}; + EXPECT_CALL(*lr.get(), AcquireInstance(_, _, _)) + .WillOnce(Return(std::make_pair( + std::move(insAlloc2), std::move(err)))); + cErr = CAcquireInstance(stateId, &meta, &option, &cInsAlloc); + ASSERT_EQ(cErr.code, 1001); + ASSERT_EQ(std::string(cErr.message), "aaa"); + free(cErr.message); +} diff --git a/test/common/mock_libruntime.h b/test/common/mock_libruntime.h index 430fd9a..9d03215 100644 --- a/test/common/mock_libruntime.h +++ b/test/common/mock_libruntime.h @@ -66,7 +66,7 @@ public: std::pair>>(const std::vector &ids, int timeoutMs, bool allowPartial)); - MOCK_METHOD1(IncreaseReference, ErrorInfo(const std::vector &objIds)); + MOCK_METHOD2(IncreaseReference, ErrorInfo(const std::vector &objIds, bool toDataSystem)); MOCK_METHOD2(IncreaseReference, std::pair>(const std::vector &objIds, @@ -86,8 +86,10 @@ public: ErrorInfo(DataObject *returnObj, size_t metaSize, size_t dataSize, const std::vector &nestedObjIds, uint64_t &totalNativeBufferSize)); - MOCK_METHOD3(CreateBuffer, std::pair(size_t dataSize, std::shared_ptr &dataBuf, - const std::vector &nestedObjIds)); + MOCK_METHOD3(CreateBuffer, ErrorInfo(size_t dataSize, std::shared_ptr &dataBuf, + const std::vector &nestedObjIds)); + + MOCK_METHOD2(CreateBuffer, std::pair(size_t dataSize, std::shared_ptr &dataBuf)); MOCK_METHOD3(GetBuffers, std::pair>>(const std::vector &ids, @@ -117,10 +119,15 @@ public: MOCK_METHOD0(Exit, void(void)); + MOCK_METHOD2(Exit, void(const int code, const std::string &message)); + MOCK_METHOD2(Kill, ErrorInfo(const std::string &instanceId, int sigNo)); MOCK_METHOD3(Kill, ErrorInfo(const std::string &instanceId, int sigNo, std::shared_ptr data)); + MOCK_METHOD3(KillAsync, + void(const std::string &instanceId, int sigNo, std::function cb)); + MOCK_METHOD1(Finalize, void(bool isDriver)); MOCK_METHOD3(WaitAsync, void(const std::string &objectId, WaitAsyncCallback callback, void *userData)); @@ -154,6 +161,16 @@ public: MultipleReadResult(const std::vector &keys, const GetParams ¶ms, int timeoutMs)); MOCK_METHOD1(KVDel, ErrorInfo(const std::string &key)); MOCK_METHOD1(KVDel, MultipleDelResult(const std::vector &keys)); + + MOCK_METHOD1(KVExist, MultipleExistResult(const std::vector &keys)); + + MOCK_METHOD3(CreateStreamProducer, ErrorInfo(const std::string &streamName, ProducerConf producerConf, + std::shared_ptr &producer)); + MOCK_METHOD4(CreateStreamConsumer, ErrorInfo(const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck)); + MOCK_METHOD1(DeleteStream, ErrorInfo(const std::string &streamName)); + MOCK_METHOD2(QueryGlobalProducersNum, ErrorInfo(const std::string &streamName, uint64_t &gProducerNum)); + MOCK_METHOD2(QueryGlobalConsumersNum, ErrorInfo(const std::string &streamName, uint64_t &gConsumerNum)); MOCK_METHOD2(SaveState, ErrorInfo(const std::shared_ptr data, const int &timeout)); MOCK_METHOD2(LoadState, ErrorInfo(std::shared_ptr &data, const int &timeout)); MOCK_METHOD0(GetInvokingRequestId, std::string(void)); @@ -181,6 +198,9 @@ public: MOCK_METHOD4(GetArrayByStateStore, MultipleReadResult(std::shared_ptr stateStore, const std::vector &keys, int timeoutMs, bool allowPartial)); + MOCK_METHOD3(QuerySizeByStateStore, + ErrorInfo(std::shared_ptr stateStore, const std::vector &keys, + std::vector &outSizes)); MOCK_METHOD2(DelByStateStore, ErrorInfo(std::shared_ptr stateStore, const std::string &key)); MOCK_METHOD2(DelArrayByStateStore, MultipleDelResult(std::shared_ptr stateStore, const std::vector &keys)); @@ -201,12 +221,20 @@ public: MOCK_METHOD3(WaitBeforeGet, std::pair(const std::vector &ids, int timeoutMs, bool allowPartial)); MOCK_METHOD0(GetServerVersion, std::string()); + MOCK_METHOD2(PeekObjectRefStream, std::pair(const std::string &generatorId, bool blocking)); MOCK_METHOD0(GetFunctionGroupRunningInfo, FunctionGroupRunningInfo()); + MOCK_METHOD3(AcquireInstance, + std::pair(const std::string &stateId, const FunctionMeta &functionMeta, + InvokeOptions &opts)); + MOCK_METHOD4(ReleaseInstance, + ErrorInfo(const std::string &leaseId, const std::string &stateId, bool abnormal, InvokeOptions &opts)); + MOCK_METHOD0(GetCredential, Credential()); + // heteroclient - MOCK_METHOD2(Delete, + MOCK_METHOD2(DevDelete, ErrorInfo(const std::vector &objectIds, std::vector &failedObjectIds)); - MOCK_METHOD2(LocalDelete, + MOCK_METHOD2(DevLocalDelete, ErrorInfo(const std::vector &objectIds, std::vector &failedObjectIds)); MOCK_METHOD3(DevSubscribe, ErrorInfo(const std::vector &keys, const std::vector &blob2dList, @@ -220,8 +248,13 @@ public: std::vector &failedKeys, int32_t timeoutMs)); MOCK_METHOD3(GetInstance, std::pair( const std::string &name, const std::string &nameSpace, int timeoutSec)); + MOCK_METHOD3(UpdateSchdulerInfo, + void(const std::string &schedulerName, const std::string &schedulerId, const std::string &option)); MOCK_METHOD1(GetInstanceRoute, std::string(const std::string &objectId)); MOCK_METHOD2(SaveInstanceRoute, void(const std::string &objectId, const std::string &instanceRoute)); + MOCK_METHOD0(GetResources, std::pair>()); + MOCK_METHOD0(IsHealth, bool()); + MOCK_METHOD0(IsDsHealth, bool()); }; } // namespace Libruntime } // namespace YR diff --git a/test/data/cert/ca.crt b/test/data/cert/ca.crt new file mode 100644 index 0000000..77b0c15 --- /dev/null +++ b/test/data/cert/ca.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBzCCAe+gAwIBAgIUN6+1QymRCVlpPopmNFIbYMjca4IwDQYJKoZIhvcNAQEL +BQAwEjEQMA4GA1UEAwwHZXRjZC1jYTAgFw0yNDExMjAwNzM0MDhaGA8yMDUyMDQw +NzA3MzQwOFowEjEQMA4GA1UEAwwHZXRjZC1jYTCCASIwDQYJKoZIhvcNAQEBBQAD +ggEPADCCAQoCggEBAL/gGC4JbRQermjR7C+DZw1VTu0d/x4gK7/1aAleRv1TEP8u +dAQ7VJD4a8WkIfk2AxvP2dl1vAdxfLfjp23BQU8rWk72g3pox+4oTJRPTYshFIZt +SDW4vFHxkV7uABysQiq8SFh5Q+eRS6tWYuiifZO+nJtobQjsUZ0/Zl0CJAaiUAkm +C4pd4PD9DrakEZznGwCoTgxu5dULa5bXSC1WfR/Zy31SY+90kiTyXPlrH+elHro+ +zzx1Yzj6ejRDEyUp3vSxBAUDZChx8F/1ZcHlHoNypI3FLeH6ecFvrnCFz2uDzCEu +3Ipp6WEIPo0h4t2uH6SqtXpgSTu+LLavlPvEbHUCAwEAAaNTMFEwHQYDVR0OBBYE +FGPmOcSBBNeD/0bo5ZHpKnwDUBiWMB8GA1UdIwQYMBaAFGPmOcSBBNeD/0bo5ZHp +KnwDUBiWMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAITMBXo1 +BP1hCf2CM75/5fSCLEURDZysZciieFYL7mT8MQcZhusKUFVNMGY9TUgJt9QJRDWv +caq8YWwyUvjIqojlTxni4D94AY+rzBieW+tZWA5smVwfBukOo1RkcaDxn557XGu+ +Korwf6pnF+ojePBeQgBVUOex3NeKnLO3n+ZeHL5qh6YYx2x9mA8K0Kh0naODbqAn +y2Sh+X+KWcUfsKATvCvFaNodQTtSaF2xenpR3BuvXCH3aEOQsZpzR9VdEXWEo2+p +3yuxcKH9MlYz8of2LrJA8ZDHGfg04hiaebIFiMPnJlPfEC6rG2yLJq1WLdlbnM+8 +PzJosKDyjnCir8Y= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/test/data/cert/client.crt b/test/data/cert/client.crt new file mode 100644 index 0000000..ab21c5f --- /dev/null +++ b/test/data/cert/client.crt @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE----- +MIICsTCCAZkCFFSPk8DQXHkvJAy+yX/rpEYH6BagMA0GCSqGSIb3DQEBCwUAMBIx +EDAOBgNVBAMMB2V0Y2QtY2EwIBcNMjQxMTIwMDkxMzU4WhgPMjA1MjA0MDcwOTEz +NThaMBYxFDASBgNVBAMMC2V0Y2QtY2xpZW50MIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEA3dqEGmTD1w4XnpV1HLR0RbR/IXmoIuqhZ4njUnK/M3HF2+9s +IKFs2vZtuyz/KMIoOrfBqUPpnmy8RQwKwGJvZZmjaUcjMN9zmhen43gVQL5erFiR +VFGUXvahjubZ+/h5FAAIp6UtP/2e+KIwh+lWoOGARMkZcJCxGa+prhAUCLmmTmvO +PBOysAEAWuYv/Y8pKOgXQ3LMtzQ+WD7yOZ4LGyCveKcXkvUDR1XCgS6V7CeVqLZ5 +di0qQCcA1BDqIIkXeR11v9JAKNmrCdhgnTfDzQL3WEMdJGlBhiWdsjAw3FDswI3r +yFi/keYLJqev4hT/V6snXpjNdayrNlbKSHXpYwIDAQABMA0GCSqGSIb3DQEBCwUA +A4IBAQCHMvb59PY7DxERO+ztOLaMO95IGQTRJnySzxM/GKMe+hPgrmAQAdz/1xFH +2pYPpspaEHnjIfoRu/55M3V4EGsOQSyofOVhGlMShVnE7Iu4wgeql8lGBSflEFNo +aydA12q72W2crE5zSnxi7g0xvrDuJnoEakYXCzV3+06hPjBoG9BihBwPbr04iKwk +WFhF9wBoz/i/EOxYUbUhxN9aNtaQal5ZG0gQe9vRqp1pUZTNqEOU0TcurnAzC0eX +Fx7ppI997PCz9zG1b4rHqux/j+AKo5CGvX86pHmDqTbCC5CPaNeIhY4gp1xZSkMQ +9uCUFbibE02A2YmbCFVymKQUGTMw +-----END CERTIFICATE----- \ No newline at end of file diff --git a/test/data/cert/client.key b/test/data/cert/client.key new file mode 100644 index 0000000..c7e8d27 --- /dev/null +++ b/test/data/cert/client.key @@ -0,0 +1,30 @@ +-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-256-CBC,03CBEE81FC3A10510A141198DFE9E611 + ++3XW0g8SAV1XFRofm1TSCG0RHk23PFQCrjslDCaihzwwauLjHh+PT1UUykgtbh1J +kBt71hy3TE+uebG9wYbZ0qAbGU3Yvn1RPPfg/123Ezg/574CH/lEJPUeEjh5Eg4i +y8biObh5PNhAM1u3/gmbDf2DfxMAhU+X81VSbup8/y8LaxFpJAPUqisWiH1Qgz3Z +EkzXYe33zkJB/GqNjPBoee/1xbwXlFcP/7uHCWttz58bI0IH3M49VVLguvJbfLn6 +N7xt2CU/c4TNUbM+oxL+CrdoLKwuR/O32OHCahLnUyVwTrG2J8vvKOtKE3at6ybx +eBdWcy1/XR7GE0l5M2XIl2IOE9q+LAj+RFSz6wi9j3eWwF24KfBJw41bBte+cCur +/oh/8K969GC6Fctkvk47Xz9FFOl66qRXp1sjK9LDP91eHeE1SVrztL8wWnejw+SP +5xuCz2hoxqkGLLS1mqBgiRZii2fzgtW3NP1h8tBBC5zz+Voqt607Y2+N/tJO0Fft +q+eK92zT3NnsqWZSpW4EW1sEeoWI3NeXEB6sw9ekNZbicTiqCuHKeiFn9+oPaoVO +tuZwGGYil92m7Vt3hVdL9IOJ5xASnwyUfHTFAhpou2/RCv46iT0Zo0LWEs9g2VfA +Co2otPVnxgVznHbHZLZtCYKbxvthJoLxd00NY7wrcF7OBvRjLQugxKfa6NEmRRfv +6ciaC3VmU376RJH8xpu0s1Mw/YYbISexWJzB/ypeHyWXO1UIPzm/N4mKi4CAZJlW +KCSGr7heHtwISd895MB9flBrf8Lxp6KXVtgUgO9V0sKBCuI7AK1W95D42yNCDMyK +tVS3xLKGMLfZcCh1a92R0gaLELG08obpvOsg4wO72+jgelSk+LpnJG3zPq7O3Ad1 +NT4SKdZTJoRcJvM/IFpT3LsrLx9hxVJlZxdjPXDuUZEzTegbZJJBWpWHGUvffe6k +sSibrQapDZKvKuUHmonQgKya55xdl34fEBi5IDrHfjTcgkS14Mb/nmO28c6wLHcu +ufoM/AGcPQSWlFhj7Q95ugagj/ZybfchcoQ21VJaOiNwf+wvzf1l4Py4gOtg2sYy +2NhenlbWc3OI+S1ou6Qdh0BpmqhwmjHPBO7ZGvXg/rwpMcpBAIJ+lhnyGLw0aDXV +BYC8a6Ggr5wkf+BKgnlQxBnvPM5S888kXbzFY9ydj6utYIR8PSwNiqxDbPTb5sdL +jJHQ6iPHTq240YFgBqhf3yfdSuGFAaavRV+wldrKOyqanxe+xbmBfO0w3Z/M3K2z +YxgPnmEY+ICRjBARTEffl54t7CB7Aw/XejuylOfrhiKRtynJSJoit22UMXEyCcVO +cGBBYU5RmXaA3QGrpeTEHSpXaxvIA90Ha/zPSMKkh8nrk68B+HiE6DaR+pbLogWT +YelkTRB2QM2vD4Lm3RfOul5ut99Lvd0nTqwvrRhTLO2pvYhDrJhKMHTUVIcFPA6I +muiE+YKJCg+msZbmYO5QY+a/qSfMquXCl4OgwEbZu4M3GzRSnJiHEGAqC5w0ty5n +MoA4uxOdHIDyuwRHbj4qTIanQTSz6BOCHvPFX0BgGVX82VLI2GgLOqVX93na1Ml4 +-----END RSA PRIVATE KEY----- \ No newline at end of file diff --git a/test/data/cert/server.crt b/test/data/cert/server.crt new file mode 100644 index 0000000..959739c --- /dev/null +++ b/test/data/cert/server.crt @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC9zCCAd+gAwIBAgIUVI+TwNBceS8kDL7Jf+ukRgfoFp8wDQYJKoZIhvcNAQEL +BQAwEjEQMA4GA1UEAwwHZXRjZC1jYTAgFw0yNDExMjAwOTEzNTBaGA8yMDUyMDQw +NzA5MTM1MFowFjEUMBIGA1UEAwwLZXRjZC1zZXJ2ZXIwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDE2XShX+jtn2Pl6WEewDg3mFlpYeExiQrvCZfvjXW9 +3GHf3itrNyoTHkC9LsCVr51CAHVYKA/2Rf+WDFKC1SJSkXoJseLl5Cm8JBmfbCR5 +0YiWDDaWRjo+ilfWr0S6zntbYO42wA3xmN4yPzn3N8ZFIow5P4M1Z6X2vHPYCa0O +z9y34T6pERstAUBjsDW2HjC8lki88i6ybVW8a16sZQp9jzXjji6qDtqPw9Ld3MWA +ajgk7qHPe0f0oueDg7lLTZyrtGqoyan5YDol1WcqFxwP9k6iB419qInNSo3SFPKR +iboM2EKKIyQeb+v2MTW65j80XHWZoVOe8Pjt8nYQJTxvAgMBAAGjPzA9MAsGA1Ud +DwQEAwIEMDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0RBAgw +BocEfwAAATANBgkqhkiG9w0BAQsFAAOCAQEAJS8OcfzjlJSNIfV6QmSRSGPw6eWz +M+xIEmwaCCwdeegd49X4aPjy91cq2JF0iDLVBZ2VWr2gZ7p7xQCXGnTVQ1ZVCMAA +naOblu3EeEHT6linain/llT4ZBDdKSHxTugugKNpQzgRUi9FbeZGLyEVSZfDHxdL +EMWYwsnn/F/9ao2gF6uZVsNbQM2VjUFtGPRr2HB0sjwPgZMcFMTTdQqFG8H4YlNm +mkCB0UnHdwooTZR6uFpk79xQcpA4rkoC7wAUGquuocrZryRDy2EfURS5oBDeo5pU +JjBCw48pvqPMi7JRSPrRVTyKRyikZ/s6390ZR0IBYWPzVFu1vSQPclfZjw== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/test/data/cert/server.key b/test/data/cert/server.key new file mode 100644 index 0000000..6ac27f1 --- /dev/null +++ b/test/data/cert/server.key @@ -0,0 +1,30 @@ +-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-256-CBC,1F98F2BD02FB1BCCBCD983AEE0DDBD8E + +IGYmXghk8e4fFME07Z4UxDkh5876NjT2tOYpw0hCKLcV5iSFsDrR8zSdTbxobDPy +oXWMjWcbJHl9FAz46l3ZHs6Ta1qm/N/heiy4NmXxthiK71i+rAFUUA4WLk7ziXxP +iZ+g4PTKEdDTHxrMysa0YV27zZEeC3xyx/YOTAPXqTB4R1gJ8VSoeqgo+XJlI/Jn +65yA/17n2YIhZeT4lXq261wfcesSRb0RubL8vqLYLm4rmyzpOxwXi9fADVpdwra4 +jOcj+FSNiPMoSgfvJmegWBa1CZcnd8XbplpY/CeMkuCo1X7HIa95FZMSct81gCLQ +set6SU4DMdDG62VSUHCkx8RhtT+lnV1/J3z+QN2G/ryMTG+eeJbtrjqgFHUUGA1H +T1FzfA3guSCVI5oXS9WRRiGrYsfKffkZ0PbgbNqJAWR5O1ZfZHaS0Up2bZ2+Eaa8 +JjlY8bc0jVXWWzGJRA4AfckaMxNzpNTKvrLqykUdUeCojjp2ahIy/vbsX87XghQP +pTmF9nWIQHJRQBStdYKK/cT8iFBAxnYkC0R2usR7/QnSKq4E4180irln3f1Cwrye +T15c23R7tv+ETx2CojaTTgundqTPLQj1rlVbMa5pUXbADO3jstclhnLpmDwYNUnl +k/FH2vAreUpd3si9p/9rN5Poa3eNlU8ax4H5/B4rhEmeNcj//fgF8tP1LsuxxZLO +w92fn+8a204ZKIvdrnQqHcv9OpvRb++LDlCL1P4DTNZrReCz6ggMiqMgrK8As+yL +R4/hAaJ8mAzAoSqXLgLRj4vv+s8Pe883Kc5idgEfbauWDqJULVCt83o96FYnuDof +IkyvpGmHABpcvgnnBNp2y8OMj7KGksBT6cKXEzwhguQ/H70QYFVLsmfDKxwC9YGZ +L1suBn4LSuHJahMuEVr8h4l3P/0SJnqoTkod3m+M1Or4nYE7iJaJ40WeZsg8lPL3 +LDdprywfrcsXJLXF1fnMDrniNks860tj/VdSOOWmLa8V4xaEXLOiwnWtteDMUPp6 +ZB1mYWoz5iZvqFpyqtI5L4SFbh/wvOeTee9kPqQeSGUK7qpM+YzUMwVpGTnKyYMB +MUuSqrXaneWhUWHS7fpN3shfFjirnED53IJuZRhFauhmRnhZmgzDOuNQtaKnkNeU +izdQMF8VxqBIEeJEA9b+iYzzfN6IIJ8RaK5OwyfXnSE/gnVGL5CS6qmNjhxACN2a +/Oh+zXbnGKFu75nV1RV5w8DL8kaY/b/xCrRurGJsINxHiPQA7z1wz8hnsiOjj/bn +FfPIqnz2KsCI7WUVb0lnMrih8+5JvZmNbi3B9BydxD5z7y8YztvTjLm1yETakXAD +jLwMwnwymOxic7PtL+f9qfNp3dSNXcgqT+RuQwo2acDj2j/ln5BRT9K3tbnlYATY +SViWeBMVl2+UQ6+PtRhcwS/E1D+Zf5spECPVt7deyLrm6Bz4g6PpJJf59MtDz7za +3K4k8oOhqygy3ycHoulhtXVBq6pmeUkCJpJrk6zcL4s7Bbxlwag9UPU94eGmjQfH +lPHWykcDEmkZoHkh0V7jE1pULBjsflKMr6AwfGo3VLcnPfxRUZ/TaYC+fPAcH6Wh +-----END RSA PRIVATE KEY----- diff --git a/test/dto/config_test.cpp b/test/dto/config_test.cpp index e299129..fbefdc3 100644 --- a/test/dto/config_test.cpp +++ b/test/dto/config_test.cpp @@ -34,16 +34,17 @@ public: TEST_F(DTOConfigTest, TestConfig) { ASSERT_EQ(Config::Instance().REQUEST_ACK_ACC_MAX_SEC(), 1800); + ASSERT_EQ(Config::Instance().DS_CONNECT_TIMEOUT_SEC(), 60); setenv("MOCK_ENV1", "5", 0); size_t mockVal1 = Config::Instance().ParseFromEnv( "MOCK_ENV1", 1800, [](const size_t &val) -> bool { return val >= REQUEST_ACK_TIMEOUT_SEC; }); - ASSERT_EQ(mockVal1, 1800); + ASSERT_EQ(mockVal1, 5); - setenv("MOCK_ENV2", "10", 0); + setenv("MOCK_ENV2", "1", 0); size_t mockVal2 = Config::Instance().ParseFromEnv( "MOCK_ENV2", 1800, [](const size_t &val) -> bool { return val >= REQUEST_ACK_TIMEOUT_SEC; }); - ASSERT_EQ(mockVal2, 10); + ASSERT_EQ(mockVal2, 1800); } TEST_F(DTOConfigTest, TestGetMaxArgsInMsgBytes) @@ -77,6 +78,16 @@ TEST_F(DTOConfigTest, TestGetMaxArgsInMsgBytes) Config::c = Config(); ASSERT_EQ(Config::Instance().FASS_SCHEDULE_TIMEOUT(), 100); unsetenv("FASS_SCHEDULE_TIMEOUT"); + + ASSERT_EQ(Config::Instance().YR_MAX_LOG_SIZE_MB(), 500); + ASSERT_EQ(Config::Instance().YR_MAX_LOG_FILE_NUM(), 10); + setenv("YR_MAX_LOG_SIZE_MB", "100", 1); + setenv("YR_MAX_LOG_FILE_NUM", "5", 1); + Config::c = Config(); + ASSERT_EQ(Config::Instance().YR_MAX_LOG_SIZE_MB(), 100); + ASSERT_EQ(Config::Instance().YR_MAX_LOG_FILE_NUM(), 5); + unsetenv("YR_MAX_LOG_SIZE_MB"); + unsetenv("YR_MAX_LOG_FILE_NUM"); } TEST_F(DTOConfigTest, TestSetenv) diff --git a/test/faas/faas_executor_test.cpp b/test/faas/faas_executor_test.cpp new file mode 100644 index 0000000..681672b --- /dev/null +++ b/test/faas/faas_executor_test.cpp @@ -0,0 +1,435 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include +#include +#include "Constant.h" +#include "FunctionError.h" +#include "common/mock_libruntime.h" +#include "api/cpp/src/utils/utils.h" +#include "src/libruntime/libruntime_manager.h" +#include "src/utility/id_generator.h" +#define private public +#include "api/cpp/src/faas/context_impl.h" +#include "api/cpp/src/faas/faas_executor.h" +#include "api/cpp/src/faas/register_runtime_handler.h" + +namespace YR { +namespace test { +using namespace testing; +using namespace YR::internal; +using namespace Function; +using namespace YR::utility; +namespace fs = std::filesystem; +class FaasExecutorTest : public testing::Test { +public: + FaasExecutorTest(){}; + ~FaasExecutorTest(){}; + + void SetUp() override + { + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + auto lc = std::make_shared(); + lc->jobId = YR::utility::IDGenerator::GenApplicationId(); + auto clientsMgr = std::make_shared(); + auto metricsAdaptor = std::make_shared(); + auto sec = std::make_shared(); + auto socketClient = std::make_shared("/home/snuser/socket/runtime.sock"); + lr = std::make_shared(lc, clientsMgr, metricsAdaptor, sec, socketClient); + YR::Libruntime::LibruntimeManager::Instance().SetLibRuntime(lr); + exec_ = std::make_shared(); + } + + void TearDown() override + { + YR::Libruntime::LibruntimeManager::Instance().Finalize(); + lr.reset(); + exec_.reset(); + } + +private: + std::shared_ptr lr; + std::shared_ptr exec_; +}; + +TEST_F(FaasExecutorTest, LoadFunctionsTest) +{ + auto err = exec_->LoadFunctions({}); + ASSERT_TRUE(err.OK()) << err.Msg(); +} + +TEST_F(FaasExecutorTest, SignalTest) +{ + auto payload = std::make_shared(1); + auto err = exec_->Signal(65, payload); + ASSERT_TRUE(err.OK()) << err.Msg(); +} + +std::string contextMetaStr = R"( +{"funcMetaData": + { + "layers":[],"name":"0@test@cpp","description":"this is my app", + "functionUrn":"sn:cn:yrk:12345678901234561234567890123456:function:0@test@cpp", + "tenantId":"12345678901234561234567890123456","tags":null,"functionUpdateTime":"", + "functionVersionUrn":"sn:cn:yrk:12345678901234561234567890123456:function:0@test@cpp:latest", + "revisionId":"20241017135254074","codeSize":0,"codeSha512":"","handler":"bin/start.sh", + "runtime":"posix-custom-runtime","timeout":600,"version":"latest","deadLetterConfig":"", + "businessId":"yrk","functionType":"","func_id":"","func_name":"cpp","domain_id":"", + "project_name":"","service":"test","poolLabel":"","dependencies":"", + "enable_cloud_debug":"","isStatefulFunction":false,"isBridgeFunction":false, + "isStreamEnable":false,"type":"","enable_auth_in_header":false,"dns_domain_cfg":null,"vpcTriggerImage":"", + "stateConfig":{"lifeCycle":""}}, + "s3MetaData":{"appId":"","bucketId":"","objectId":"","bucketUrl":"","code_type":"","code_url":"","code_filename":"", + "func_code":{"file":"","link":""}}, + "codeMetaData":{"sha512":"","storage_type":"s3","code_path":"", + "appId":"","bucketId":"test","objectId":"cpp.zip","bucketUrl":"https://127.0.0,1:30110","code_type":"", + "code_url":"","code_filename":"","func_code":{"file":"","link":""}}, + "envMetaData":{"environment":"fa3f8de0d7ac57ed34babf52:db67e7193a6d9ee35a15f92dc8de08f16a58d7810806921ad404bdccde", + "encrypted_user_data":"", + "envKey":"d79a80e56bd11a37c35ea5e7:d31ee89d52ab7f36094faa015a599149f1d27e295010bd328b319d3316c80bcfdeeefeffe1a3b622fb3c1ea8ffa86511dc7682086c711dc33bdb67d3c21dc93dbfdf487e90010b1905c1d168c1fe57c3", + "cryptoAlgorithm":"GCM"}, + "stsMetaData":{"enableSts":false}, + "resourceMetaData":{"cpu":1000,"memory":1024,"gpu_memory":0,"enable_dynamic_memory":false,"customResources":"", + "enable_tmp_expansion":false,"ephemeral_storage":0,"CustomResourcesSpec":""}, + "instanceMetaData":{"maxInstance":100,"minInstance":0,"concurrentNum":100,"diskLimit":0,"instanceType":"", + "schedulePolicy":"concurrency","scalePolicy":"","idleMode":false}, + "extendedMetaData":{"image_name":"","role":{"xrole":"","app_xrole":""}, + "func_vpc":null,"endpoint_tenant_vpc":null,"mount_config":null, + "strategy_config":{"concurrency":0},"extend_config":"", + "initializer":{"initializer_handler":"","initializer_timeout":0}, + "heartbeat":{"heartbeat_handler":""},"enterprise_project_id":"", + "log_tank_service":{"logGroupId":"","logStreamId":""}, + "tracing_config":{"tracing_ak":"","tracing_sk":"","project_name":""}, + "custom_container_config":{"control_path":"","image":"","command":null,"args":null, + "working_dir":"","uid":0,"gid":0},"async_config_loaded":false,"restore_hook":{}, + "network_controller":{"disable_public_network":false,"trigger_access_vpcs":null}, + "user_agency":{"accessKey":"","secretKey":"","token":"","securityAk":"","securitySk":"", + "securityToken":""},"custom_filebeat_config":{"sidecarConfigInfo":null,"cpu":0,"memory":0, + "version":"","imageAddress":""},"custom_health_check":{"timeoutSeconds":0,"periodSeconds":0, + "failureThreshold":0},"dynamic_config":{"enabled":false,"update_time":"","config_content":null}, + "runtime_graceful_shutdown":{"maxShutdownTimeout":0},"pre_stop":{"pre_stop_handler":"", + "pre_stop_timeout":0},"rasp_config":{"init-image":"","rasp-image":"","rasp-server-ip":"","rasp-server-port":""}}} +)"; + +std::string createParamStr = R"( +{"instanceLabel": "aaaaa"} +)"; + +std::string deleteDecryptStr = R"( + { + "environment": "{\"key\":\"value\"}", + "encrypted_user_data": "{\"aaa\":\"bbb\"}" + } +)"; + +TEST_F(FaasExecutorTest, ExecuteFunctionSuccessFullyTest) +{ + auto handlerPtr = std::make_unique(); + handlerPtr->RegisterHandler([](const std::string &event, Context &context) -> std::string { + Function::FunctionLogger logger = context.GetLogger(); + logger.setLevel("INFO"); + logger.Info("hello cpp %s ", "user info log"); + logger.Error("hello cpp %s ", "user error log"); + logger.Warn("hello cpp %s ", "user warn log"); + logger.Debug("hello cpp %s ", "user debug log"); + + EXPECT_EQ(context.GetAccessKey(), ""); + EXPECT_EQ(context.GetSecretKey(), ""); + EXPECT_EQ(context.GetSecurityAccessKey(), ""); + EXPECT_EQ(context.GetSecuritySecretKey(), ""); + EXPECT_EQ(context.GetToken(), ""); + EXPECT_EQ(context.GetAlias(), ""); + EXPECT_EQ(context.GetTraceId(), "traceid"); + EXPECT_EQ(context.GetInvokeId(), "initializer"); + EXPECT_EQ(context.GetState(), ""); + EXPECT_EQ(context.GetInstanceId(), ""); + EXPECT_EQ(context.GetInvokeProperty(), ""); + EXPECT_EQ(context.GetRequestID(), "traceid"); + EXPECT_EQ(context.GetUserData("aaa"), "bbb"); + EXPECT_EQ(context.GetFunctionName(), "cpp"); + EXPECT_EQ(context.GetInstanceLabel(), "aaaaa"); + EXPECT_EQ(context.GetRemainingTimeInMilliSeconds(), 0); + EXPECT_EQ(context.GetRunningTimeInSeconds(), 0); + EXPECT_EQ(context.GetVersion(), "latest"); + EXPECT_EQ(context.GetMemorySize(), 1024); + EXPECT_EQ(context.GetCPUNumber(), 1000); + EXPECT_EQ(context.GetProjectID(), "12345678901234561234567890123456"); + EXPECT_EQ(context.GetPackage(), "test"); + + return event; + }); + std::string traceId = "traceid"; + handlerPtr->RegisterInitializerFunction([](Context &context) {}); + SetRuntimeHandler(std::move(handlerPtr)); + YR::SetEnv("key", "1"); + YR::SetEnv("ENV_DELEGATE_DECRYPT", deleteDecryptStr); + + YR::Libruntime::ErrorInfo err; + YR::Libruntime::FunctionMeta function; + std::vector> rawArgs; + std::vector> returnObjects; + returnObjects.push_back(std::make_shared()); + + auto contextMetaObj = std::make_shared(0, contextMetaStr.size()); + contextMetaObj->data->MemoryCopy(contextMetaStr.data(), contextMetaStr.size()); + rawArgs.push_back(contextMetaObj); + auto createParamObj = std::make_shared(0, createParamStr.size()); + createParamObj->data->MemoryCopy(createParamStr.data(), createParamStr.size()); + rawArgs.push_back(createParamObj); + + err = exec_->ExecuteFunction(function, libruntime::InvokeType::CreateInstance, rawArgs, returnObjects); + ASSERT_TRUE(err.OK()) << err.Msg(); + rawArgs.pop_back(); + + std::string eventStr = "{\"body\":\"event\", \"header\": {\"X-Trace-Id\":\"traceid\"}}"; + auto eventObj = std::make_shared(0, eventStr.size()); + eventObj->data->MemoryCopy(eventStr.data(), eventStr.size()); + rawArgs.push_back(eventObj); + + auto traceIdObj = std::make_shared(0, traceId.size()); + traceIdObj->data->MemoryCopy(traceId.data(), traceId.size()); + rawArgs.push_back(traceIdObj); + + auto returnObj2 = std::make_shared(0, 200); + EXPECT_CALL(*lr.get(), AllocReturnObject(Matcher &>(_), _, _, _, _)) + .WillOnce(DoAll(SetArgReferee<0>(returnObj2), Return(YR::Libruntime::ErrorInfo()))); + err = exec_->ExecuteFunction(function, libruntime::InvokeType::InvokeFunction, rawArgs, returnObjects); + ASSERT_TRUE(err.OK()) << err.Msg(); + auto result = std::string(static_cast(returnObjects[0]->data->ImmutableData()), + returnObjects[0]->data->GetSize()); + ASSERT_TRUE(result.find("event")) << result; + + ASSERT_EQ(YR::GetEnv("key"), "value"); +} + +TEST_F(FaasExecutorTest, ExecuteFunctionFailedTest) +{ + YR::Libruntime::ErrorInfo err; + YR::Libruntime::FunctionMeta function; + std::vector> rawArgs; + std::vector> returnObjects; + returnObjects.push_back(std::make_shared()); + auto contextMetaObj = std::make_shared(0, contextMetaStr.size()); + contextMetaObj->data->MemoryCopy(contextMetaStr.data(), contextMetaStr.size()); + rawArgs.push_back(contextMetaObj); + auto returnObj = std::make_shared(0, 200); + EXPECT_CALL(*lr.get(), AllocReturnObject(Matcher &>(_), _, _, _, _)) + .WillRepeatedly(DoAll(SetArgReferee<0>(returnObj), Return(YR::Libruntime::ErrorInfo()))); + + auto handlerPtr = std::make_unique(); + SetRuntimeHandler(std::move(handlerPtr)); + + auto eventObj = std::make_shared(0, 0); + rawArgs.push_back(eventObj); + err = exec_->ExecuteFunction(function, libruntime::InvokeType::InvokeFunction, rawArgs, returnObjects); + ASSERT_TRUE(err.OK()) << err.Msg(); + auto result = std::string(static_cast(returnObjects[0]->data->ImmutableData()), + returnObjects[0]->data->GetSize()); + ASSERT_TRUE(result.find("call req is empty")); + + std::string eventStr = "{}"; + eventObj = std::make_shared(0, eventStr.size()); + eventObj->data->MemoryCopy(eventStr.data(), eventStr.size()); + rawArgs[1] = eventObj; + err = exec_->ExecuteFunction(function, libruntime::InvokeType::InvokeFunction, rawArgs, returnObjects); + ASSERT_TRUE(err.OK()) << err.Msg(); + result = std::string(static_cast(returnObjects[0]->data->ImmutableData()), + returnObjects[0]->data->GetSize()); + ASSERT_TRUE(result.find("can not find body")); + + eventStr = "{\"body\": 1}"; + eventObj = std::make_shared(0, eventStr.size()); + eventObj->data->MemoryCopy(eventStr.data(), eventStr.size()); + rawArgs[1] = eventObj; + err = exec_->ExecuteFunction(function, libruntime::InvokeType::InvokeFunction, rawArgs, returnObjects); + ASSERT_TRUE(err.OK()) << err.Msg(); + result = std::string(static_cast(returnObjects[0]->data->ImmutableData()), + returnObjects[0]->data->GetSize()); + ASSERT_TRUE(result.find("event type is not string")); + + eventStr = "{\"body\": \"event\"}"; + eventObj = std::make_shared(0, eventStr.size()); + eventObj->data->MemoryCopy(eventStr.data(), eventStr.size()); + rawArgs[1] = eventObj; + err = exec_->ExecuteFunction(function, libruntime::InvokeType::InvokeFunction, rawArgs, returnObjects); + ASSERT_TRUE(err.OK()) << err.Msg(); + result = std::string(static_cast(returnObjects[0]->data->ImmutableData()), + returnObjects[0]->data->GetSize()); + ASSERT_TRUE(result.find("can not call handlerequest before initialize")); + + contextMetaObj = std::make_shared(0, 0); + rawArgs[0] = contextMetaObj; + err = exec_->ExecuteFunction(function, libruntime::InvokeType::CreateInstance, rawArgs, returnObjects); + ASSERT_FALSE(err.OK()) << err.Msg(); + + contextMetaObj = std::make_shared(0, contextMetaStr.size()); + contextMetaObj->data->MemoryCopy(contextMetaStr.data(), contextMetaStr.size()); + rawArgs[0] = contextMetaObj; + handlerPtr = std::make_unique(); + handlerPtr->RegisterInitializerFunction( + [](Context &context) { throw FunctionError(ErrorCode::FUNCTION_EXCEPTION, "function error"); }); + SetRuntimeHandler(std::move(handlerPtr)); + err = exec_->ExecuteFunction(function, libruntime::InvokeType::CreateInstance, rawArgs, returnObjects); + ASSERT_FALSE(err.OK()) << err.Msg(); + ASSERT_TRUE(result.find("function error")); + + handlerPtr = std::make_unique(); + handlerPtr->RegisterInitializerFunction([](Context &context) { throw std::runtime_error("runtime error"); }); + SetRuntimeHandler(std::move(handlerPtr)); + err = exec_->ExecuteFunction(function, libruntime::InvokeType::CreateInstance, rawArgs, returnObjects); + ASSERT_FALSE(err.OK()) << err.Msg(); + + handlerPtr = std::make_unique(); + SetRuntimeHandler(std::move(handlerPtr)); + err = exec_->ExecuteFunction(function, libruntime::InvokeType::InvokeFunction, rawArgs, returnObjects); + ASSERT_TRUE(err.OK()) << err.Msg(); + result = std::string(static_cast(returnObjects[0]->data->ImmutableData()), + returnObjects[0]->data->GetSize()); + ASSERT_TRUE(result.find("undefined HandleRequest")); + + handlerPtr = std::make_unique(); + handlerPtr->RegisterHandler([](const std::string &event, Context &context) -> std::string { + throw FunctionError(ErrorCode::FUNCTION_EXCEPTION, "function error"); + }); + SetRuntimeHandler(std::move(handlerPtr)); + err = exec_->ExecuteFunction(function, libruntime::InvokeType::InvokeFunction, rawArgs, returnObjects); + ASSERT_TRUE(err.OK()) << err.Msg(); + result = std::string(static_cast(returnObjects[0]->data->ImmutableData()), + returnObjects[0]->data->GetSize()); + ASSERT_TRUE(result.find("function error")); + + handlerPtr = std::make_unique(); + handlerPtr->RegisterHandler( + [](const std::string &event, Context &context) -> std::string { throw std::runtime_error("runtime error"); }); + SetRuntimeHandler(std::move(handlerPtr)); + err = exec_->ExecuteFunction(function, libruntime::InvokeType::InvokeFunction, rawArgs, returnObjects); + ASSERT_TRUE(err.OK()) << err.Msg(); + result = std::string(static_cast(returnObjects[0]->data->ImmutableData()), + returnObjects[0]->data->GetSize()); + ASSERT_TRUE(result.find("runtime error")); + + handlerPtr = std::make_unique(); + handlerPtr->RegisterHandler([](const std::string &event, Context &context) -> std::string { + std::string a; + a.resize(6 * 1024 * 1024 + 1, '0'); + return a; + }); + SetRuntimeHandler(std::move(handlerPtr)); + std::string expectResult = + "{\"body\":\"function result size: 6291457, exceed limit(6291456)\",\"innerCode\":\"4004\"}"; + err = exec_->ExecuteFunction(function, libruntime::InvokeType::InvokeFunction, rawArgs, returnObjects); + ASSERT_TRUE(err.OK()) << err.Msg(); + result = std::string(static_cast(returnObjects[0]->data->ImmutableData()), expectResult.size()); + auto resultJson = nlohmann::json::parse(result); + std::string value = resultJson["body"]; + ASSERT_TRUE(value.find("exceed limit(")); +} + +TEST_F(FaasExecutorTest, ExecuteShutdownFunctionTest) +{ + YR::Libruntime::ErrorInfo err; + auto handlerPtr = std::make_unique(); + SetRuntimeHandler(std::move(handlerPtr)); + err = exec_->ExecuteShutdownFunction(1); + ASSERT_FALSE(err.OK()); + ASSERT_TRUE(err.Msg().find("can not call prestop before initialize")); + + exec_->contextEnv_ = std::make_shared(); + std::unordered_map params; + exec_->contextInvokeParams_ = std::make_shared(params); + err = exec_->ExecuteShutdownFunction(1); + ASSERT_TRUE(err.OK()); + + handlerPtr = std::make_unique(); + handlerPtr->RegisterPreStopFunction( + [](Context &context) { throw FunctionError(ErrorCode::FUNCTION_EXCEPTION, "function error"); }); + SetRuntimeHandler(std::move(handlerPtr)); + err = exec_->ExecuteShutdownFunction(1); + ASSERT_FALSE(err.OK()); + ASSERT_TRUE(err.Msg().find("function error")); + + handlerPtr = std::make_unique(); + handlerPtr->RegisterPreStopFunction([](Context &context) { throw std::runtime_error("runtime error"); }); + SetRuntimeHandler(std::move(handlerPtr)); + err = exec_->ExecuteShutdownFunction(1); + ASSERT_FALSE(err.OK()); + ASSERT_TRUE(err.Msg().find("runtime error")); + + handlerPtr = std::make_unique(); + handlerPtr->RegisterPreStopFunction([](Context &context) { return; }); + SetRuntimeHandler(std::move(handlerPtr)); + err = exec_->ExecuteShutdownFunction(1); + ASSERT_TRUE(err.OK()) << err.Msg(); +} + +TEST_F(FaasExecutorTest, CheckpointRecoverSuccessfullyTest) +{ + YR::Libruntime::ErrorInfo err; + + std::shared_ptr data; + err = exec_->Checkpoint("instanceid", data); + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_OK) << err.Msg(); + + err = exec_->Recover(data); + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_OK) << err.Msg(); +} + +TEST_F(FaasExecutorTest, ContextImplTest) +{ + std::unordered_map params; + auto contextInvokeParams = std::make_shared(params); + auto contextEnv = std::make_shared(); + ContextImpl ctx(contextInvokeParams, contextEnv); + ContextImpl ctx2 = ctx; + ctx.SetFuncStartTime(1000000); + ctx.SetStateId("stateId"); + ASSERT_EQ(ctx.GetInstanceId(), "stateId"); + ctx.SetState("state"); + ASSERT_EQ(ctx.GetState(), "state"); + ctx.SetInvokeProperty("prop"); + ASSERT_EQ(ctx.GetInvokeProperty(), "prop"); +} + +TEST_F(FaasExecutorTest, RuntimeTest) +{ + std::shared_ptr rt = std::make_shared(); + rt->InitRuntimeLogger(); + rt->RegisterHandler([](const std::string &request, Function::Context &context) -> std::string { return ""; }); + rt->RegisterInitializerFunction([](Function::Context &context) { return; }); + rt->RegisterPreStopFunction([](Function::Context &context) { return; }); + rt->InitState([](const std::string &, Function::Context &) { return; }); + ASSERT_NO_THROW(rt->BuildRegisterRuntimeHandler()); + rt->ReleaseRuntimeLogger(); +} + +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/faas/function_test.cpp b/test/faas/function_test.cpp new file mode 100644 index 0000000..59a4e78 --- /dev/null +++ b/test/faas/function_test.cpp @@ -0,0 +1,149 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "Function.h" +#include "FunctionError.h" +#include "common/mock_libruntime.h" +#include "api/cpp/src/faas/context_impl.h" +#include "src/libruntime/libruntime_manager.h" +#include "src/utility/id_generator.h" + +namespace YR { +namespace test { +using namespace YR::utility; +class FunctionTest : public testing::Test { +public: + FunctionTest(){}; + ~FunctionTest(){}; + + void SetUp() override + { + auto lc = std::make_shared(); + lc->jobId = YR::utility::IDGenerator::GenApplicationId(); + auto clientsMgr = std::make_shared(); + auto metricsAdaptor = std::make_shared(); + auto sec = std::make_shared(); + auto socketClient = std::make_shared("/home/snuser/socket/runtime.sock"); + lr = std::make_shared(lc, clientsMgr, metricsAdaptor, sec, socketClient); + YR::Libruntime::LibruntimeManager::Instance().SetLibRuntime(lr); + } + + void TearDown() override + { + YR::Libruntime::LibruntimeManager::Instance().Finalize(); + lr.reset(); + } + + std::shared_ptr lr; +}; + +TEST_F(FunctionTest, InvokeTest) +{ + auto contextEnv = std::make_shared(); + contextEnv->SetFunctionName("func"); + contextEnv->SetFuncPackage("service"); + contextEnv->SetProjectID("123"); + std::unordered_map params; + auto contextInvokeParams = std::make_shared(params); + auto context = Function::ContextImpl(contextInvokeParams, contextEnv); + std::string result = "{\"innerCode\":\"0\", \"body\":\"result\"}"; + auto returnObj = std::make_shared(0, result.size()); + returnObj->data->MemoryCopy(static_cast(result.data()), result.size()); + + std::vector> funcs; + funcs.emplace_back(std::make_shared(context)); + funcs.emplace_back(std::make_shared(context, "func")); + funcs.emplace_back(std::make_shared(context, "func:latest")); + for (auto &func : funcs) { + std::vector> rets; + rets.push_back(returnObj); + EXPECT_CALL(*lr.get(), InvokeByFunctionName(_, _, _, _)).WillOnce(Return(YR::Libruntime::ErrorInfo())); + EXPECT_CALL(*lr.get(), Get(_, _, _)) + .WillOnce(Return( + std::make_pair>>( + YR::Libruntime::ErrorInfo(), std::move(rets)))); + auto ref = func->Invoke("aaa"); + auto ret = func->GetObjectRef(ref); + ASSERT_EQ(ret, "result"); + ASSERT_EQ(func->GetContext() != nullptr, true); + } +} + +TEST_F(FunctionTest, InvokeFailedTest) +{ + auto contextEnv = std::make_shared(); + contextEnv->SetFunctionName("func"); + contextEnv->SetFuncPackage("service"); + contextEnv->SetProjectID("123"); + std::unordered_map params; + auto contextInvokeParams = std::make_shared(params); + auto context = Function::ContextImpl(contextInvokeParams, contextEnv); + std::string result = "result"; + auto returnObj = std::make_shared(0, result.size()); + returnObj->data->MemoryCopy(static_cast(result.data()), result.size()); + + auto func = Function::Function(context); + std::make_shared(context, "func:latest"); + std::vector> rets; + rets.push_back(returnObj); + EXPECT_CALL(*lr.get(), InvokeByFunctionName(_, _, _, _)) + .WillOnce(Return(YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "aaa"))); + EXPECT_THROW(func.Invoke("aaa"), Function::FunctionError); + + auto func2 = Function::Function(context, ""); + EXPECT_THROW(func2.Invoke("aaa"), Function::FunctionError); + + auto func3 = Function::Function(context, "func:func:lateset"); + EXPECT_THROW(func3.Invoke("aaa"), Function::FunctionError); + + auto func4 = Function::Function(context, "func:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); + EXPECT_THROW(func4.Invoke("aaa"), Function::FunctionError); + + auto func5 = Function::Function(context, "func&^%$:latest"); + EXPECT_THROW(func5.Invoke("aaa"), Function::FunctionError); + + auto func6 = Function::Function(context, "func&^%$:latest"); + EXPECT_THROW(func6.Invoke("aaa"), Function::FunctionError); + + auto func7 = Function::Function(context, "func&^%$"); + EXPECT_THROW(func7.Invoke("aaa"), Function::FunctionError); + + auto func8 = Function::Function(context, "!func&^%$:latest", "instanceName"); + EXPECT_THROW(func8.Invoke("aaa"), Function::FunctionError); +} + +TEST_F(FunctionTest, MemberFunctionTest) +{ + auto contextEnv = std::make_shared(); + contextEnv->SetFunctionName("func"); + contextEnv->SetFuncPackage("service"); + contextEnv->SetProjectID("123"); + std::unordered_map params; + auto contextInvokeParams = std::make_shared(params); + auto context = Function::ContextImpl(contextInvokeParams, contextEnv); + auto func = Function::Function(context, "func:latest", "instanceName"); + EXPECT_THROW(func.GetInstance("func:latest", "instanceName"), Function::FunctionError); + EXPECT_THROW(func.GetLocalInstance("func:latest", "instanceName"), Function::FunctionError); + EXPECT_THROW(func.Terminate(), Function::FunctionError); + EXPECT_THROW(func.SaveState(), Function::FunctionError); + EXPECT_THROW(func.GetInstanceId(), Function::FunctionError); +} + +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/libruntime/alias_routing_test.cpp b/test/libruntime/alias_routing_test.cpp new file mode 100644 index 0000000..dc3b28a --- /dev/null +++ b/test/libruntime/alias_routing_test.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "src/libruntime/invokeadaptor/alias_routing.h" + +using namespace YR::Libruntime; + +namespace YR { +namespace test { +class AliasRoutingTest : public ::testing::Test { + void SetUp() override {} + + void TearDown() override {} +}; + +static std::vector g_aes = { + { + .aliasUrn = "fake_alias_urn", + .functionUrn = "fake_function_urn", + .functionVersionUrn = "fake_function_version_urn", + .name = "fake_name", + .functionVersion = "fake_function_version", + .revisionId = "fake_revision_id", + .description = "fake_description", + .routingType = "", + .routingConfig = + { + { + .functionVersionUrn = "fake_function_version_urn_1", + .weight = 50.0, + }, + { + .functionVersionUrn = "fake_function_version_urn_2", + .weight = 50.0, + }, + }, + }, +}; + +TEST_F(AliasRoutingTest, CheckAliasTest) +{ + AliasRouting ar; + auto ok = ar.CheckAlias("tenantId"); + ASSERT_EQ(ok, false); + ok = ar.CheckAlias("tenantId/fullName/0"); + ASSERT_EQ(ok, false); + ok = ar.CheckAlias("tenantId/fullName/latest"); + ASSERT_EQ(ok, false); + ok = ar.CheckAlias("tenantId/fullName/alias"); + ASSERT_EQ(ok, true); + ok = ar.CheckAlias("tenantId/fullName/alia-_s0"); + ASSERT_EQ(ok, true); +} + +TEST_F(AliasRoutingTest, ParseAliasTest) +{ + AliasRouting ar; + + std::unordered_map m; + auto functionId = ar.ParseAlias("tenantId/function/version", m); + ASSERT_EQ(functionId, "tenantId/function/version"); + functionId = ar.ParseAlias("tenantId/function", m); + ASSERT_EQ(functionId, "tenantId/function"); + + std::vector g_aes_rule1 = { + { + .aliasUrn = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasv1", + .functionUrn = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld", + .functionVersionUrn = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest", + .name = "myaliasv1", + .functionVersion = "$latest", + .revisionId = "20210617023315921", + .description = "fake_description", + .routingType = "rule", + .routingRules = + { + .ruleLogic = "and", + .rules = + { + "userType:=:VIP", + "age:<=:20", + "devType:in:P40,P50,MATE40", + }, + .grayVersion = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:1", + }, + }, + }; + ar.UpdateAliasInfo(g_aes_rule1); + + + std::unordered_map params; + params["userType"] = "VIP"; + params["age"] = "10"; + params["devType"] = "MATE40"; + functionId = ar.ParseAlias("12345678901234561234567890123456/helloworld/myaliasv1", params); + ASSERT_EQ(functionId, "12345678901234561234567890123456/helloworld/1"); + + g_aes_rule1[0].routingRules.grayVersion = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:2"; + ar.UpdateAliasInfo(g_aes_rule1); + functionId = ar.ParseAlias("12345678901234561234567890123456/helloworld/myaliasv1", params); + ASSERT_EQ(functionId, "12345678901234561234567890123456/helloworld/2"); +} + +TEST_F(AliasRoutingTest, UpdateAliasInfoTest) +{ + AliasRouting ar; + ar.UpdateAliasInfo(g_aes); + std::unordered_map params; + auto urn1 = ar.GetFuncVersionUrnWithParams("fake_alias_urn", params); + auto urn2 = ar.GetFuncVersionUrnWithParams("fake_alias_urn", params); + ASSERT_EQ(urn1, "fake_function_version_urn_1"); + ASSERT_EQ(urn2, "fake_function_version_urn_2"); +} + +TEST_F(AliasRoutingTest, GetFuncVersionUrnWithParamsNotFoundTest) +{ + AliasRouting ar; + ar.UpdateAliasInfo(g_aes); + std::unordered_map params; + auto urn1 = ar.GetFuncVersionUrnWithParams("xxx", params); + ASSERT_EQ(urn1, "xxx"); +} + +std::vector g_aes_rule = { + { + .aliasUrn = "fake_alias_urn", + .functionUrn = "fake_function_urn", + .functionVersionUrn = "fake_function_version_urn", + .name = "fake_name", + .functionVersion = "fake_function_version", + .revisionId = "fake_revision_id", + .description = "fake_description", + .routingType = "rule", + .routingRules = + { + .ruleLogic = "and", + .rules = + { + "userType:=:VIP", + "age:<=:20", + "devType:in:P40,P50,MATE40", + }, + .grayVersion = "fake_gray_version", + }, + }, +}; + +TEST_F(AliasRoutingTest, GetFuncVersionUrnWithParamsMatchTest) +{ + AliasRouting ar; + ar.UpdateAliasInfo(g_aes_rule); + std::unordered_map params; + params["userType"] = "VIP"; + params["age"] = "10"; + params["devType"] = "MATE40"; + auto urn1 = ar.GetFuncVersionUrnWithParams("fake_alias_urn", params); + ASSERT_EQ(urn1, "fake_gray_version"); +} + +TEST_F(AliasRoutingTest, GetFuncVersionUrnWithParamsIntNoMatchTest) +{ + AliasRouting ar; + ar.UpdateAliasInfo(g_aes_rule); + std::unordered_map params; + params["userType"] = "VIP"; + params["age"] = "50"; + params["devType"] = "P40"; + auto urn1 = ar.GetFuncVersionUrnWithParams("fake_alias_urn", params); + ASSERT_EQ(urn1, "fake_function_version_urn"); +} + +TEST_F(AliasRoutingTest, GetFuncVersionUrnWithParamsInNoMatchTest) +{ + AliasRouting ar; + ar.UpdateAliasInfo(g_aes_rule); + std::unordered_map params; + params["userType"] = "VIP"; + params["age"] = "10"; + params["devType"] = "P40X"; + auto urn1 = ar.GetFuncVersionUrnWithParams("fake_alias_urn", params); + ASSERT_EQ(urn1, "fake_function_version_urn"); +} + +TEST_F(AliasRoutingTest, GetFuncVersionUrnWithParamsStringEqNoMatchTest) +{ + AliasRouting ar; + ar.UpdateAliasInfo(g_aes_rule); + std::unordered_map params; + params["userType"] = "VVIP"; + params["age"] = "10"; + params["devType"] = "P40"; + auto urn1 = ar.GetFuncVersionUrnWithParams("fake_alias_urn", params); + ASSERT_EQ(urn1, "fake_function_version_urn"); +} +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/libruntime/auto_init_test.cpp b/test/libruntime/auto_init_test.cpp index 2a83511..2b8461d 100644 --- a/test/libruntime/auto_init_test.cpp +++ b/test/libruntime/auto_init_test.cpp @@ -73,6 +73,22 @@ static void MakeMasterInfoFile(const std::string &filepath, const std::string &c ofs.close(); } +static void RemoveMasterInfoFile(const std::string &filepath) +{ + std::remove(filepath.c_str()); +} + +TEST_F(AutoInitTest, AutoCreateYuanRongClusterFailed) +{ + RemoveMasterInfoFile(YR::Libruntime::kDefaultDeployPathCurrMasterInfo); + YR::Libruntime::ClusterAccessInfo info; + auto info2 = YR::Libruntime::AutoGetClusterAccessInfo(info); + + ASSERT_EQ(info2.serverAddr, ""); + ASSERT_EQ(info2.dsAddr, ""); + ASSERT_EQ(info2.inCluster, false); +} + TEST_F(AutoInitTest, AutoInitWithClusterAccessInfo) { MakeMasterInfoFile(YR::Libruntime::kDefaultDeployPathCurrMasterInfo, masterInfoString); diff --git a/test/libruntime/clients_manager_test.cpp b/test/libruntime/clients_manager_test.cpp index 8c9b0fb..50dff20 100644 --- a/test/libruntime/clients_manager_test.cpp +++ b/test/libruntime/clients_manager_test.cpp @@ -68,15 +68,16 @@ TEST_F(ClientsManagerTest, DISABLED_DsClientsTest) { auto clientsMgr = std::make_shared(); datasystem::SensitiveValue runtimePrivateKey; + datasystem::SensitiveValue token; auto librtCfg = std::make_shared(); librtCfg->dataSystemIpAddr = "127.0.0.1"; librtCfg->dataSystemPort = 22222; librtCfg->runtimePrivateKey = runtimePrivateKey; - auto res = clientsMgr->GetOrNewDsClient(librtCfg, 30); + auto res = clientsMgr->GetOrNewDsClient(librtCfg, "", datasystem::SensitiveValue{}, 30); EXPECT_EQ(res.second.Code(), ErrorCode::ERR_OK); EXPECT_EQ(clientsMgr->dsClientsReferCounter["127.0.0.1:22222"], 1); - res = clientsMgr->GetOrNewDsClient(librtCfg, 30); + res = clientsMgr->GetOrNewDsClient(librtCfg, "", datasystem::SensitiveValue{}, 30); EXPECT_EQ(res.second.Code(), ErrorCode::ERR_OK); EXPECT_EQ(clientsMgr->dsClientsReferCounter["127.0.0.1:22222"], 2); auto err = clientsMgr->ReleaseDsClient("127.0.0.1", 22222); @@ -110,7 +111,7 @@ TEST_F(ClientsManagerTest, HttpClientsTest) TEST_F(ClientsManagerTest, GetFsConnTest) { auto clientsMgr = std::make_shared(); - auto res = clientsMgr->GetFsConn("127.0.0.1", 8080); + auto res = clientsMgr->GetFsConn("127.0.0.1", 8080, ""); ASSERT_TRUE(res.first == nullptr); ASSERT_TRUE(res.second.OK()); } @@ -121,6 +122,7 @@ TEST_F(ClientsManagerTest, ReleaseDsClientTest) clientsMgr->dsClientsReferCounter["127.0.0.1:80"] = 1; DatasystemClients dsClients{.dsObjectStore = std::make_shared(), .dsStateStore = std::make_shared(), + .dsStreamStore = std::make_shared(), .dsHeteroStore = std::make_shared()}; clientsMgr->dsClients["127.0.0.1:80"] = dsClients; ASSERT_EQ(clientsMgr->ReleaseDsClient("127.0.0.1", 80).OK(), true); diff --git a/test/libruntime/driverlog_test.cpp b/test/libruntime/driverlog_test.cpp new file mode 100644 index 0000000..5239086 --- /dev/null +++ b/test/libruntime/driverlog_test.cpp @@ -0,0 +1,169 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include "mock/mock_datasystem.h" + +#define private public + +#include "src/libruntime/driverlog/driverlog_receiver.h" +#include "src/utility/logger/logger.h" + +namespace YR { +namespace test { +using namespace testing; +using namespace YR::Libruntime; +using namespace YR::utility; +using namespace std::chrono_literals; +class DriverLogTest : public testing::Test { +public: + std::stringstream buffer; + std::streambuf* coutbuf; + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-driver-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + coutbuf = std::cout.rdbuf(); + std::cout.rdbuf(buffer.rdbuf()); + } + + void TearDown() override + { + std::cout.rdbuf(coutbuf); + } +}; + +TEST_F(DriverLogTest, DriverLogTestReceiver) +{ + std::shared_ptr streamStore = std::make_shared(); + auto c = std::make_shared(); + std::string jobId = "job-8d638c95"; + EXPECT_CALL(*streamStore, CreateStreamConsumer("/log/runtime/std/job-8d638c95", _, _, true)) + .WillOnce([&c](const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck) { + consumer = c; + return ErrorInfo(); + }); + + std::string logInfo = "this is driver log"; + EXPECT_CALL(*c, Receive(_, _)) + .WillOnce([&logInfo](uint32_t timeoutMs, std::vector &outElements) { + outElements.resize(1); + outElements[0].ptr = reinterpret_cast(const_cast(logInfo.data())); + outElements[0].size = logInfo.size(); + return ErrorInfo(); + }) + .WillRepeatedly([](uint32_t timeoutMs, std::vector &outElements) { + std::this_thread::sleep_for(500ms); + return ErrorInfo(); + }); + + { + auto r = std::make_shared(); + r->Init(streamStore, jobId, false); + std::this_thread::sleep_for(50ms); + size_t found = buffer.str().find(logInfo); + ASSERT_FALSE(found == std::string::npos); + } +} + +TEST_F(DriverLogTest, DriverLogTestDedup) +{ + std::shared_ptr streamStore = std::make_shared(); + auto c = std::make_shared(); + std::string jobId = "job-8d638c94"; + EXPECT_CALL(*streamStore, CreateStreamConsumer("/log/runtime/std/job-8d638c94", _, _, true)) + .WillOnce([&c](const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck) { + consumer = c; + return ErrorInfo(); + }); + + + std::vector logInfos; + for (auto i = 0; i < 5; i++) { + auto msg = "(runtime-" + std::to_string(i) + ") " + "this is driver log"; + logInfos.push_back(msg); + } + + EXPECT_CALL(*c, Receive(_, _)) + .WillOnce([&logInfos](uint32_t timeoutMs, std::vector &outElements) { + outElements.resize(5); + for (auto i = 0; i < 5; i++) { + outElements[i].ptr = reinterpret_cast(const_cast(logInfos[i].data())); + outElements[i].size = logInfos[i].size(); + } + return ErrorInfo(); + }) + .WillRepeatedly([](uint32_t timeoutMs, std::vector &outElements) { + std::this_thread::sleep_for(500ms); + return ErrorInfo(); + }); + + { + auto r = std::make_shared(); + r->Init(streamStore, jobId, true); + std::this_thread::sleep_for(50ms); + r.reset(); + size_t found = buffer.str().find("this is driver log"); + ASSERT_FALSE(found == std::string::npos); + auto dedueInfo = "across cluster"; + YRLOG_ERROR("receive info is : {}", buffer.str()); + std::cout << buffer.str() << std::endl; + found = buffer.str().find(dedueInfo); + + ASSERT_FALSE(found == std::string::npos); + } +} + +TEST_F(DriverLogTest, TestParse) +{ + std::string largeContent(40000, 'x'); + auto r = std::make_shared(); + struct Case { + std::string input; + std::string rtId; + std::string logContent; + } tests[] = {{"(rtid1)content1", "rtid1", "content1"}, + {"(rt(i)d)content1", "rt(i)d", "content1"}, + {"(rtId)" + largeContent, "rtId", largeContent}}; + + for (auto &t : tests) { + auto [id, cont] = r->ParseLine(t.input); + ASSERT_EQ(id, t.rtId); + ASSERT_EQ(cont, t.logContent); + } +} +} +} \ No newline at end of file diff --git a/test/libruntime/execution_manager_test.cpp b/test/libruntime/execution_manager_test.cpp index 92aaeac..6c8114f 100644 --- a/test/libruntime/execution_manager_test.cpp +++ b/test/libruntime/execution_manager_test.cpp @@ -111,5 +111,26 @@ TEST_F(ExecutionManagerTest, HandleNormalRequestTest) ASSERT_TRUE(handled); } +TEST_F(ExecutionManagerTest, isMultipleConcurrencyTest) +{ + execMgr = std::make_shared(1, nullptr); + ASSERT_TRUE(!(execMgr->isMultipleConcurrency())); +} + +TEST_F(ExecutionManagerTest, DoInitWithUserHookTest) +{ + std::function func = [](std::function userFuc)->void{}; + execMgr = std::make_shared(1, func); + ASSERT_TRUE(execMgr->DoInit(1).OK()); + ASSERT_NO_THROW(execMgr->ErasePendingThread("reqID")); +} + +TEST_F(ExecutionManagerTest, ErasePendingThreadTest) +{ + execMgr = std::make_shared(2, nullptr); + execMgr->DoInit(2); + ASSERT_NO_THROW(execMgr->ErasePendingThread("reqID")); +} + } // namespace test } // namespace YR \ No newline at end of file diff --git a/test/libruntime/faas_instance_manager_test.cpp b/test/libruntime/faas_instance_manager_test.cpp new file mode 100644 index 0000000..fcd21a2 --- /dev/null +++ b/test/libruntime/faas_instance_manager_test.cpp @@ -0,0 +1,547 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mock/mock_fs_intf.h" +#include "src/libruntime/objectstore/datasystem_object_store.h" +#include "src/utility/id_generator.h" +#include "src/utility/logger/logger.h" + +#define protected public +#include "mock/mock_fs_intf_with_callback.h" +#include "src/libruntime/invokeadaptor/faas_instance_manager.h" + +namespace YR { +namespace Libruntime { +std::pair GetFaasInstanceRsp(const NotifyRequest ¬ifyReq); +std::pair GetFaasBatchInstanceRsp(const NotifyRequest ¬ifyReq); +std::vector BuildReacquireInstanceData(const RequestResource &resource); +} +} // namespace YR + +using namespace YR::Libruntime; +using namespace YR::utility; +using YR::utility::CloseGlobalTimer; +using YR::utility::InitGlobalTimer; +namespace YR { +namespace test { + +class FaasInstanceManagerTest : public testing::Test { +public: + FaasInstanceManagerTest(){}; + ~FaasInstanceManagerTest(){}; + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + InitGlobalTimer(); + std::function cb = + [](const RequestResource &resource, const ErrorInfo &err, bool isRemainIns) {}; + auto reqMgr = std::make_shared(); + auto librtCfg = std::make_shared(); + mockFsIntf = std::make_shared(); + auto fsClient = std::make_shared(mockFsIntf); + std::shared_ptr memoryStore = std::make_shared(); + auto dsObjectStore = std::make_shared(); + dsObjectStore->Init("127.0.0.1", 8080); + auto wom = std::make_shared(); + memoryStore->Init(dsObjectStore, wom); + insManager = std::make_shared(cb, fsClient, memoryStore, reqMgr, librtCfg); + insManager->UpdateSchdulerInfo("scheduler1", "scheduler1", "ADD"); + } + void TearDown() override + { + if (insManager->leaseTimer) { + insManager->leaseTimer->cancel(); + insManager->leaseTimer.reset(); + } + insManager.reset(); + mockFsIntf.reset(); + CloseGlobalTimer(); + } + +protected: + std::shared_ptr insManager; + std::shared_ptr mockFsIntf; +}; + +TEST_F(FaasInstanceManagerTest, BuildAcquireRequestTest) +{ + auto spec = std::make_shared(); + std::vector returnObjs{DataObject("returnID")}; + spec->returnIds = returnObjs; + spec->jobId = "jobId"; + spec->requestId = "requestId"; + spec->traceId = "traceId"; + spec->instanceId = "instanceId"; + spec->invokeLeaseId = "leaseId"; + spec->invokeInstanceId = "insId"; + spec->functionMeta = {"", + "", + "funcname", + "classname", + libruntime::LanguageType::Cpp, + "", + "", + "poollabel", + libruntime::ApiType::Function}; + InvokeOptions opts; + opts.schedulerInstanceIds.push_back("shcedulerInstanceId"); + std::unordered_map invokelabels; + invokelabels["xxx"] = "xxx"; + opts.invokeLabels = invokelabels; + spec->opts = opts; + std::dynamic_pointer_cast(mockFsIntf)->isAcquireResponse = false; + std::dynamic_pointer_cast(mockFsIntf)->isReqNormal = false; + std::dynamic_pointer_cast(mockFsIntf)->needCheckArgs = true; + auto [instanceAllocation, err] = insManager->AcquireInstance("", spec); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); +} + +TEST_F(FaasInstanceManagerTest, RecordRequestTest) +{ + auto reqInsInfo = std::make_shared(); + reqInsInfo->instanceInfos["leaseId"] = std::make_shared(); + reqInsInfo->instanceInfos["leaseId"]->instanceId = "insId"; + reqInsInfo->instanceInfos["leaseId"]->leaseId = "leaseId"; + reqInsInfo->instanceInfos["leaseId"]->idleTime = 0; + reqInsInfo->instanceInfos["leaseId"]->unfinishReqNum = 0; + reqInsInfo->instanceInfos["leaseId"]->available = true; + reqInsInfo->instanceInfos["leaseId"]->traceId = "traceId"; + reqInsInfo->instanceInfos["leaseId"]->faasInfo = FaasAllocationInfo{"", "", 100}; + reqInsInfo->instanceInfos["leaseId"]->reporter = std::make_shared(); + reqInsInfo->instanceInfos["leaseId"]->stateId = ""; + reqInsInfo->instanceInfos["leaseId"]->claimTime = 10000LL; + auto spec = std::make_shared(); + spec->invokeLeaseId = "leaseId"; + spec->functionMeta = { + "", "", "funcname", "classname", libruntime::LanguageType::Cpp, "", "", "", libruntime::ApiType::Function}; + auto resource = GetRequestResource(spec); + insManager->requestResourceInfoMap[resource] = reqInsInfo; + + insManager->RecordRequest(resource, spec, true); + ASSERT_EQ(insManager->requestResourceInfoMap[resource]->instanceInfos["leaseId"]->reporter->GetTotalDuration() > 0, + true); + + insManager->RecordRequest(resource, spec, false); + ASSERT_EQ(insManager->requestResourceInfoMap[resource]->instanceInfos["leaseId"]->reporter->IsAbnormal(), true); +} + +TEST_F(FaasInstanceManagerTest, ScaleDownTest) +{ + auto reqInsInfo = std::make_shared(); + reqInsInfo->instanceInfos["leaseId1"] = std::make_shared(); + reqInsInfo->instanceInfos["leaseId1"]->instanceId = "insId"; + reqInsInfo->instanceInfos["leaseId1"]->leaseId = "leaseId1"; + reqInsInfo->instanceInfos["leaseId1"]->idleTime = 0; + reqInsInfo->instanceInfos["leaseId1"]->unfinishReqNum = 0; + reqInsInfo->instanceInfos["leaseId1"]->available = true; + reqInsInfo->instanceInfos["leaseId1"]->traceId = "traceId"; + reqInsInfo->instanceInfos["leaseId1"]->faasInfo = FaasAllocationInfo{"", "", 100}; + reqInsInfo->instanceInfos["leaseId1"]->reporter = std::make_shared(); + reqInsInfo->instanceInfos["leaseId2"] = std::make_shared(); + reqInsInfo->instanceInfos["leaseId2"]->instanceId = "insId"; + reqInsInfo->instanceInfos["leaseId2"]->leaseId = "leaseId2"; + reqInsInfo->instanceInfos["leaseId2"]->idleTime = 0; + reqInsInfo->instanceInfos["leaseId2"]->unfinishReqNum = 0; + reqInsInfo->instanceInfos["leaseId2"]->available = true; + reqInsInfo->instanceInfos["leaseId2"]->traceId = "traceId"; + reqInsInfo->instanceInfos["leaseId2"]->faasInfo = FaasAllocationInfo{"", "", 100}; + reqInsInfo->instanceInfos["leaseId2"]->reporter = std::make_shared(); + auto spec = std::make_shared(); + spec->invokeLeaseId = "leaseId1"; + spec->invokeInstanceId = "insId"; + spec->functionMeta = { + "", "", "funcname", "classname", libruntime::LanguageType::Cpp, "", "", "", libruntime::ApiType::Function}; + auto resource = GetRequestResource(spec); + insManager->requestResourceInfoMap[resource] = reqInsInfo; + + insManager->ScaleDown(spec, false); + ASSERT_EQ(insManager->requestResourceInfoMap[resource]->instanceInfos.find("leaseId1") == + insManager->requestResourceInfoMap[resource]->instanceInfos.end(), + true); + ASSERT_EQ(insManager->requestResourceInfoMap[resource]->instanceInfos.find("leaseId2") == + insManager->requestResourceInfoMap[resource]->instanceInfos.end(), + true); +} + +TEST_F(FaasInstanceManagerTest, ScaleUpSuccessTest) +{ + auto queue = std::make_shared(); + auto spec = std::make_shared(); + std::vector returnObjs{DataObject("returnID")}; + spec->returnIds = returnObjs; + spec->jobId = "jobId"; + spec->requestId = "requestId"; + spec->traceId = "traceId"; + spec->instanceId = "instanceId"; + spec->invokeLeaseId = "leaseId"; + spec->invokeInstanceId = "insId"; + spec->functionMeta = { + "", "", "funcname", "classname", libruntime::LanguageType::Cpp, "", "", "", libruntime::ApiType::Function}; + InvokeOptions opts; + opts.schedulerInstanceIds.push_back("shcedulerInstanceId"); + spec->opts = opts; + insManager->ScaleUp(spec, queue->Size()); + auto resource = GetRequestResource(spec); + ASSERT_EQ(insManager->requestResourceInfoMap[resource]->creatingIns.size(), 0); + + queue->Push(spec); + std::dynamic_pointer_cast(mockFsIntf)->isAcquireResponse = true; + insManager->ScaleUp(spec, queue->Size()); + sleep(2); + ASSERT_EQ(insManager->requestResourceInfoMap.find(resource), insManager->requestResourceInfoMap.end()); + ASSERT_EQ(insManager->totalCreatedInstanceNum_, 0); +} + +TEST_F(FaasInstanceManagerTest, ScaleUpFailTest) +{ + auto queue = std::make_shared(); + auto spec = std::make_shared(); + std::vector returnObjs{DataObject("returnID")}; + spec->returnIds = returnObjs; + spec->jobId = "jobId"; + spec->requestId = "requestId"; + spec->traceId = "traceId"; + spec->instanceId = "instanceId"; + spec->invokeLeaseId = "leaseId"; + spec->invokeInstanceId = "insId"; + spec->functionMeta = {"", + "", + "funcname", + "classname", + libruntime::LanguageType::Cpp, + "", + "", + "poollabel", + libruntime::ApiType::Function}; + InvokeOptions opts; + opts.schedulerInstanceIds.push_back("shcedulerInstanceId"); + spec->opts = opts; + queue->Push(spec); + auto resource = GetRequestResource(spec); + std::dynamic_pointer_cast(mockFsIntf)->isReqNormal = false; + std::dynamic_pointer_cast(mockFsIntf)->isAcquireResponse = true; + insManager->ScaleUp(spec, queue->Size()); + sleep(2); + std::shared_ptr info; + { + absl::ReaderMutexLock lock(&insManager->insMtx); + ASSERT_EQ(insManager->requestResourceInfoMap.find(resource) == insManager->requestResourceInfoMap.end(), true); + } +} + +TEST_F(FaasInstanceManagerTest, StartBatchRenewTimer) +{ + auto spec = std::make_shared(); + auto resource = GetRequestResource(spec); + insManager->StartBatchRenewTimer(); + ASSERT_EQ(insManager->requestResourceInfoMap.find(resource) != insManager->requestResourceInfoMap.end(), false); + + auto reqInsInfo = std::make_shared(); + insManager->tLeaseInterval = 1000; + reqInsInfo->instanceInfos["leaseId"] = std::make_shared(); + reqInsInfo->instanceInfos["leaseId"]->instanceId = "insId"; + reqInsInfo->instanceInfos["leaseId"]->leaseId = "leaseId"; + reqInsInfo->instanceInfos["leaseId"]->idleTime = 0; + reqInsInfo->instanceInfos["leaseId"]->unfinishReqNum = 0; + reqInsInfo->instanceInfos["leaseId"]->available = true; + reqInsInfo->instanceInfos["leaseId"]->traceId = "traceId"; + reqInsInfo->instanceInfos["leaseId"]->faasInfo = FaasAllocationInfo{"", "", 100}; + reqInsInfo->instanceInfos["leaseId"]->reporter = std::make_shared(); + insManager->requestResourceInfoMap[resource] = reqInsInfo; + insManager->globalLeases["leaseId"] = resource; + std::dynamic_pointer_cast(mockFsIntf)->isAcquireResponse = true; + insManager->StartBatchRenewTimer(); + ASSERT_EQ(insManager->leaseTimer != nullptr, true); + if (insManager->leaseTimer) { + insManager->leaseTimer->cancel(); + } +} + +TEST_F(FaasInstanceManagerTest, BatchRenewHandlerReleased) +{ + auto reqInsInfo = std::make_shared(); + insManager->tLeaseInterval = 10; + reqInsInfo->instanceInfos["leaseId"] = std::make_shared(); + reqInsInfo->instanceInfos["leaseId"]->instanceId = "insId"; + reqInsInfo->instanceInfos["leaseId"]->leaseId = "leaseId"; + reqInsInfo->instanceInfos["leaseId"]->idleTime = 0; + reqInsInfo->instanceInfos["leaseId"]->unfinishReqNum = 0; + reqInsInfo->instanceInfos["leaseId"]->available = true; + reqInsInfo->instanceInfos["leaseId"]->traceId = "traceId"; + reqInsInfo->instanceInfos["leaseId"]->faasInfo = FaasAllocationInfo{"functionid", "sig111", 10000}; + reqInsInfo->instanceInfos["leaseId"]->reporter = std::make_shared(); + ASSERT_EQ(reqInsInfo->instanceInfos.size(), 1); + + auto spec = std::make_shared(); + auto resource = GetRequestResource(spec); + insManager->requestResourceInfoMap[resource] = reqInsInfo; + insManager->globalLeases["leaseId"] = resource; + + std::dynamic_pointer_cast(mockFsIntf)->isBatchRenew = true; + std::dynamic_pointer_cast(mockFsIntf)->isReqNormal = true; + insManager->StartBatchRenewTimer(); + sleep(1); + ASSERT_EQ(reqInsInfo->instanceInfos.size(), 0); + if (insManager->leaseTimer) { + insManager->leaseTimer->cancel(); + } + std::dynamic_pointer_cast(mockFsIntf)->isReqNormal = false; + std::dynamic_pointer_cast(mockFsIntf)->isBatchRenew = false; +} + +TEST_F(FaasInstanceManagerTest, BatchRenewHandlerRetained) +{ + auto reqInsInfo = std::make_shared(); + insManager->tLeaseInterval = 10; + int i; + auto spec = std::make_shared(); + auto resource = GetRequestResource(spec); + for (i = 0; i < 1001; i++) { + std::string leaseId = "leaseId" + std::to_string(i); + reqInsInfo->instanceInfos[leaseId] = std::make_shared(); + reqInsInfo->instanceInfos[leaseId]->instanceId = "insId"; + reqInsInfo->instanceInfos[leaseId]->leaseId = leaseId; + reqInsInfo->instanceInfos[leaseId]->idleTime = 0; + reqInsInfo->instanceInfos[leaseId]->unfinishReqNum = 0; + reqInsInfo->instanceInfos[leaseId]->available = true; + reqInsInfo->instanceInfos[leaseId]->traceId = leaseId; + reqInsInfo->instanceInfos[leaseId]->faasInfo = FaasAllocationInfo{"functionid", "sig111", 10000}; + reqInsInfo->instanceInfos[leaseId]->reporter = std::make_shared(); + insManager->globalLeases[leaseId] = resource; + } + ASSERT_EQ(reqInsInfo->instanceInfos.size(), 1001); + + insManager->requestResourceInfoMap[resource] = reqInsInfo; + + std::dynamic_pointer_cast(mockFsIntf)->isBatchRenew = true; + std::dynamic_pointer_cast(mockFsIntf)->isReqNormal = true; + insManager->StartBatchRenewTimer(); + sleep(1); + ASSERT_EQ(reqInsInfo->instanceInfos.size(), 1001); + if (insManager->leaseTimer) { + insManager->leaseTimer->cancel(); + } + for (i = 0; i < 1001; i++) { + std::string leaseId = "leaseId" + std::to_string(i); + reqInsInfo->instanceInfos.erase(leaseId); + insManager->globalLeases.erase(leaseId); + } + std::dynamic_pointer_cast(mockFsIntf)->isReqNormal = false; + std::dynamic_pointer_cast(mockFsIntf)->isBatchRenew = false; +} + +TEST_F(FaasInstanceManagerTest, BatchRenewHandlerFailed) +{ + auto reqInsInfo = std::make_shared(); + insManager->tLeaseInterval = 10; + reqInsInfo->instanceInfos["leaseId"] = std::make_shared(); + reqInsInfo->instanceInfos["leaseId"]->instanceId = "insId"; + reqInsInfo->instanceInfos["leaseId"]->leaseId = "leaseId"; + reqInsInfo->instanceInfos["leaseId"]->idleTime = 0; + reqInsInfo->instanceInfos["leaseId"]->unfinishReqNum = 0; + reqInsInfo->instanceInfos["leaseId"]->available = true; + reqInsInfo->instanceInfos["leaseId"]->traceId = "traceId"; + reqInsInfo->instanceInfos["leaseId"]->faasInfo = FaasAllocationInfo{"functionid", "sig111", 10000}; + reqInsInfo->instanceInfos["leaseId"]->reporter = std::make_shared(); + reqInsInfo->instanceInfos["leaseId2"] = std::make_shared(); + reqInsInfo->instanceInfos["leaseId2"]->instanceId = "insId"; + reqInsInfo->instanceInfos["leaseId2"]->leaseId = "leaseId2"; + reqInsInfo->instanceInfos["leaseId2"]->idleTime = 0; + reqInsInfo->instanceInfos["leaseId2"]->unfinishReqNum = 0; + reqInsInfo->instanceInfos["leaseId2"]->available = true; + reqInsInfo->instanceInfos["leaseId2"]->traceId = "traceId"; + reqInsInfo->instanceInfos["leaseId2"]->faasInfo = FaasAllocationInfo{"functionid", "sig111", 10000}; + reqInsInfo->instanceInfos["leaseId2"]->reporter = std::make_shared(); + ASSERT_EQ(reqInsInfo->instanceInfos.size(), 2); + + auto spec = std::make_shared(); + auto resource = GetRequestResource(spec); + insManager->requestResourceInfoMap[resource] = reqInsInfo; + insManager->globalLeases["leaseId"] = resource; + insManager->globalLeases["leaseId2"] = resource; + + std::dynamic_pointer_cast(mockFsIntf)->isReqNormal = false; + std::dynamic_pointer_cast(mockFsIntf)->isBatchRenew = true; + insManager->StartBatchRenewTimer(); + sleep(1); + ASSERT_EQ(reqInsInfo->instanceInfos.size(), 0); + if (insManager->leaseTimer) { + insManager->leaseTimer->cancel(); + } + std::dynamic_pointer_cast(mockFsIntf)->isBatchRenew = false; +} + +TEST_F(FaasInstanceManagerTest, BuildReacquireInstanceData) +{ + RequestResource r1; + r1.opts.instanceSession = std::make_shared(); + r1.opts.invokeLabels["label1"] = "1"; + r1.opts.podLabels["label1"] = "2"; + r1.opts.customResources["ccc"] = 1; + std::vector vec = BuildReacquireInstanceData(r1); + ASSERT_EQ(vec.size(), 384); +} + +TEST_F(FaasInstanceManagerTest, AcquireAndReleaseInstanceTest) +{ + auto spec = std::make_shared(); + std::vector returnObjs{DataObject("returnID")}; + spec->returnIds = returnObjs; + spec->jobId = "jobId"; + spec->requestId = "requestId"; + spec->traceId = "traceId"; + spec->instanceId = "instanceId"; + spec->invokeLeaseId = "leaseId"; + spec->invokeInstanceId = "insId"; + spec->functionMeta = {"", + "", + "funcname", + "classname", + libruntime::LanguageType::Cpp, + "", + "", + "poollabel", + libruntime::ApiType::Function}; + InvokeOptions opts; + opts.schedulerInstanceIds.push_back("shcedulerInstanceId"); + spec->opts = opts; + auto resource = GetRequestResource(spec); + // std::string stateId = ; + std::dynamic_pointer_cast(mockFsIntf)->isAcquireResponse = true; + auto [instanceAllocation1, err1] = insManager->AcquireInstance("", spec); + ASSERT_EQ(instanceAllocation1.leaseId, "leaseId"); + ASSERT_EQ(insManager->requestResourceInfoMap[resource]->instanceInfos.find("leaseId") != + insManager->requestResourceInfoMap[resource]->instanceInfos.end(), + true); + std::dynamic_pointer_cast(mockFsIntf)->isAcquireResponse = false; + auto err = insManager->ReleaseInstance("leaseId", "", false, spec); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + ASSERT_EQ(insManager->requestResourceInfoMap[resource]->instanceInfos.find("leaseId") == + insManager->requestResourceInfoMap[resource]->instanceInfos.end(), + true); + + std::dynamic_pointer_cast(mockFsIntf)->isAcquireResponse = true; + std::dynamic_pointer_cast(mockFsIntf)->isReqNormal = false; + auto [instanceAllocation3, err3] = insManager->AcquireInstance("", spec); + ASSERT_EQ(err3.Code(), ErrorCode::ERR_PARAM_INVALID); +} + +TEST_F(FaasInstanceManagerTest, GetFaasInstanceRspTest) +{ + NotifyRequest notifyReq; + auto object = notifyReq.add_smallobjects(); + object->set_id("123"); + + std::string json_str = + R"({"funcKey":"","funcSig":"","instanceID":"","threadID":"","leaseInterval":0,"errorCode":150428,"errorMessage":"","schedulerTime":30})"; + std::string value = "0000000000000000" + json_str; + object->set_value(value); + auto [resp, errorInfo] = GetFaasInstanceRsp(notifyReq); + ASSERT_TRUE(errorInfo.OK()); + ASSERT_EQ(resp.errorCode, 150428); +} + +TEST_F(FaasInstanceManagerTest, AddInsInfoBareTest) +{ + auto info = std::make_shared(); + auto faasInfo = std::make_shared(); + faasInfo->leaseId = "leaseId"; + + insManager->AddInsInfoBare(info, faasInfo); + ASSERT_EQ(info->instanceInfos.size(), 1); + ASSERT_EQ(info->avaliableInstanceInfos.size(), 1); + ASSERT_EQ(insManager->totalCreatedInstanceNum_, 1); + + auto faasInfo1 = std::make_shared(); + faasInfo1->leaseId = "leaseId"; + faasInfo1->instanceId = "instanceId"; + insManager->AddInsInfoBare(info, faasInfo1); + ASSERT_EQ(info->instanceInfos.size(), 1); + ASSERT_EQ(info->instanceInfos[faasInfo->leaseId]->instanceId, "instanceId"); + ASSERT_EQ(insManager->totalCreatedInstanceNum_, 1); +} + +TEST_F(FaasInstanceManagerTest, GetFaasBatchInstanceRsp) +{ + NotifyRequest notifyReq; + auto object = notifyReq.add_smallobjects(); + object->set_id("123"); + + std::string json_str = + R"({"instanceAllocSucceed":{"f1a00e58-f2a1-4000-8000-000000f8e9e3-thread26":{"funcKey":"12345678901234561234567890123456/0@functest@functest/latest","funcSig":"4243308021","instanceID":"f1a00e58-f2a1-4000-8000-000000f8e9e3","threadID":"f1a00e58-f2a1-4000-8000-000000f8e9e3-thread26","instanceIP":"10.42.1.119","instancePort":"22771","nodeIP":"","nodePort":"","leaseInterval":0,"cpu":600,"memory":512}},"instanceAllocFailed":{},"leaseInterval":1000,"schedulerTime":0.000118108})"; + std::string value = "0000000000000000" + json_str; + object->set_value(value); + auto [resp, errorInfo] = GetFaasBatchInstanceRsp(notifyReq); + ASSERT_TRUE(errorInfo.OK()); + ASSERT_EQ(resp.instanceAllocSucceed.size(), 1); +} + + +TEST_F(FaasInstanceManagerTest, UpdateSpecSchedulerIdsTest) +{ + auto spec = std::make_shared(); + insManager->UpdateSpecSchedulerIds(spec, "schedulerId"); + ASSERT_TRUE(spec->schedulerInfos->schedulerInstanceList[0]->InstanceID == "schedulerId"); + ASSERT_TRUE(!spec->schedulerInfos->schedulerInstanceList[0]->isAvailable); + auto updateTime = spec->schedulerInfos->schedulerInstanceList[0]->updateTime; + insManager->UpdateSpecSchedulerIds(spec, "schedulerId"); + ASSERT_TRUE(spec->schedulerInfos->schedulerInstanceList[0]->updateTime > updateTime); +} + +TEST_F(FaasInstanceManagerTest, AcquireCallbackTest) +{ + auto acquireSpec = std::make_shared(); + auto invokeSpec = std::make_shared(); + + insManager->UpdateSchedulerInfo("schedulerkey", + {SchedulerInstance{.InstanceName = "instanceName1", .InstanceID = "instanceId1"}, + SchedulerInstance{.InstanceName = "instanceName2", .InstanceID = "instanceId2"}}); + + ErrorInfo err(YR::Libruntime::ErrorCode::ERR_INSTANCE_EXITED, "err msg"); + ErrorInfo outputErr; + insManager->scheduleInsCb = [&outputErr](const RequestResource &resource, const ErrorInfo &err, bool isRemainIs) -> void { + outputErr = err; + }; + insManager->AcquireCallback(acquireSpec, err, InstanceResponse{}, invokeSpec); + std::cout << "size is :" << invokeSpec->schedulerInfos->schedulerInstanceList.size() << std::endl; + ASSERT_NE(outputErr.Code(), YR::Libruntime::ErrorCode::ERR_ALL_SCHEDULER_UNAVALIABLE); + + acquireSpec->invokeInstanceId = invokeSpec->schedulerInfos->schedulerInstanceList[0]->InstanceID; + insManager->AcquireCallback(acquireSpec, err, InstanceResponse{}, invokeSpec); + ASSERT_NE(outputErr.Code(), YR::Libruntime::ErrorCode::ERR_ALL_SCHEDULER_UNAVALIABLE); + + acquireSpec->invokeInstanceId = invokeSpec->schedulerInfos->schedulerInstanceList[1]->InstanceID; + insManager->AcquireCallback(acquireSpec, err, InstanceResponse{}, invokeSpec); + ASSERT_EQ(outputErr.Code(), YR::Libruntime::ErrorCode::ERR_ALL_SCHEDULER_UNAVALIABLE); +} +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/libruntime/fm_client_test.cpp b/test/libruntime/fm_client_test.cpp index 95598c8..af0252d 100644 --- a/test/libruntime/fm_client_test.cpp +++ b/test/libruntime/fm_client_test.cpp @@ -14,12 +14,14 @@ * limitations under the License. */ +#include #include #include #include #define private public #include "src/libruntime/err_type.h" #include "src/libruntime/fmclient/fm_client.h" +#include "src/libruntime/gwclient/gw_client.h" using namespace testing; using namespace YR::Libruntime; @@ -31,6 +33,7 @@ public: ~MockHttpClient() = default; ErrorInfo Init(const ConnectionParam ¶m) override { + ResetConnActive(); return ErrorInfo(); } @@ -43,27 +46,16 @@ public: auto rsp = BuildQueryResponse(); std::string rspBody; rsp.SerializeToString(&rspBody); - receiver(rspBody, boost::beast::error_code(), 200); + if (isSuccess_) { + receiver(rspBody, boost::beast::error_code(), 200); + } else { + receiver(rspBody, boost::beast::error_code(), 400); + } } else { std::cout << "unknown target:" << target; } } - void RegisterHeartbeat(const std::string &jobID, int timeout) override {} - bool Available() const override - { - return true; - }; - - bool IsActive() const override - { - return true; - }; - - bool IsConnActive() const override - { - return true; - }; ErrorInfo ReInit() override { return ErrorInfo(); @@ -95,9 +87,18 @@ public: r2.mutable_vectors()->mutable_values()->insert({"ids", c2}); ru.mutable_capacity()->mutable_resources()->insert({"NPU", r2}); + resources::Value::Counter cnter1; + cnter1.mutable_items()->insert( { "value1_1", 1 } ); + cnter1.mutable_items()->insert( { "value1_2", 1 } ); + ru.mutable_nodelabels()->insert({ "key1", cnter1 }); + resources::Value::Counter cnter2; + cnter2.mutable_items()->insert( { "value2", 3 } ); + ru.mutable_nodelabels()->insert({ "key2", cnter2 }); + resource->mutable_fragment()->insert({"resource", ru}); return rsp; } + bool isSuccess_ = true; }; class FmClientTest : public testing::Test { @@ -121,33 +122,89 @@ public: .alsoLog2Stderr = true, }; InitLog(g_logParam); - - auto httpClient = std::make_shared(); - httpClient->Init(ConnectionParam{"127.0.0.1", "8888"}); - std::shared_ptr config = std::make_shared(); - std::vector functionMasters; - functionMasters.push_back("192.168.0.1"); - functionMasters.push_back("127.0.0.1"); - config->functionMasters = functionMasters; - fmClient_ = std::make_shared(config); - fmClient_->httpClients_["127.0.0.1"] = httpClient; - } - void TearDown() override - { - fmClient_.reset(); } - std::shared_ptr fmClient_; + void TearDown() override {} }; -TEST_F(FmClientTest, DISABLED_TestGetResourcesSuccessfully) +TEST_F(FmClientTest, TestGetResourcesSuccessfully) { - auto client = fmClient_->GetNextHttpClient(); - EXPECT_TRUE(client != nullptr); - auto [err, res] = fmClient_->GetResources(); - EXPECT_TRUE(err.OK()); - auto unit = res.at(0); + std::shared_ptr fmClient_ = std::make_shared(); + auto httpClient = std::make_shared(); + httpClient->Init(ConnectionParam{"127.0.0.1", "8888"}); + httpClient->SetAvailable(); + fmClient_->UpdateActiveMaster("127.0.0.1:8888"); + fmClient_->activeMasterHttpClient_ = httpClient; + auto res = fmClient_->GetResources(); + EXPECT_TRUE(res.first.OK()); + auto unit = res.second.at(0); EXPECT_TRUE(unit.capacity["CPU"] == 400.00); EXPECT_TRUE(unit.capacity["NPU"] == 2.00); - client = fmClient_->GetNextHttpClient(); - EXPECT_TRUE(client != nullptr); + EXPECT_TRUE(unit.nodeLabels.size() == 2); + EXPECT_TRUE(unit.nodeLabels.count("key1")); + EXPECT_TRUE(unit.nodeLabels["key1"].size() == 2); + EXPECT_TRUE(std::find(unit.nodeLabels["key1"].begin(), unit.nodeLabels["key1"].end(), "value1_1") != + unit.nodeLabels["key1"].end()); + EXPECT_TRUE(std::find(unit.nodeLabels["key1"].begin(), unit.nodeLabels["key1"].end(), "value1_2") != + unit.nodeLabels["key1"].end()); + EXPECT_TRUE(unit.nodeLabels.count("key2")); + EXPECT_TRUE(std::find(unit.nodeLabels["key2"].begin(), unit.nodeLabels["key2"].end(), "value2") != + unit.nodeLabels["key2"].end()); + + httpClient->isSuccess_ = false; + fmClient_->cb_ = [](){}; + fmClient_->maxWaitTimeSec = 1; + res = fmClient_->GetResources(); + EXPECT_TRUE(!res.first.OK()); } + +TEST_F(FmClientTest, TestGetResourcesWithRetryFailed) +{ + std::shared_ptr fmClient_ = std::make_shared(); + fmClient_->UpdateActiveMaster("127.0.0.1"); + auto res = fmClient_->GetResourcesWithRetry(); + EXPECT_TRUE(!res.first.OK()); +} + +TEST_F(FmClientTest, TestUpdateActiveMasterWithStopSuccessfully) +{ + std::shared_ptr fmClient_ = std::make_shared(); + fmClient_->Stop(); + EXPECT_NO_THROW(fmClient_->UpdateActiveMaster("127.0.0.1")); +} + +TEST_F(FmClientTest, TestCheckResponseCode) +{ + // Test case 1: errorCode is set + boost::beast::error_code errorCode(1, boost::system::system_category()); + uint statusCode = 200; + std::string result = "Success"; + std::string requestId = "12345"; + ErrorInfo errorInfo = YR::Libruntime::CheckResponseCode(errorCode, statusCode, result, requestId); + EXPECT_TRUE(errorInfo.Code() == YR::Libruntime::ErrorCode::ERR_INNER_COMMUNICATION); + EXPECT_TRUE(errorInfo.Msg() == "network error between runtime and function master, error_code: Operation not permitted, requestId: 12345"); + + // Test case 2: statusCode is not successful + errorCode = boost::beast::error_code(); + statusCode = 400; + result = "Bad Request"; + requestId = "67890"; + errorInfo = YR::Libruntime::CheckResponseCode(errorCode, statusCode, result, requestId); + EXPECT_TRUE(errorInfo.Code() == YR::Libruntime::ErrorCode::ERR_PARAM_INVALID); + EXPECT_TRUE(errorInfo.Msg() == "response is error, status_code: 400, result: Bad Request, requestId: 67890"); + + // Test case 3: no error + errorCode = boost::beast::error_code(); + statusCode = 200; + result = "Success"; + requestId = "12345"; + errorInfo = YR::Libruntime::CheckResponseCode(errorCode, statusCode, result, requestId); + EXPECT_TRUE(errorInfo.OK()); +} + +TEST_F(FmClientTest, TestRemoveActiveMaster) +{ + std::shared_ptr fmClient_ = std::make_shared(); + fmClient_->activeMasterAddr_ = "activeMasterAddr_"; + fmClient_->RemoveActiveMaster(); + EXPECT_TRUE(fmClient_->activeMasterAddr_.empty()); +} \ No newline at end of file diff --git a/test/libruntime/fs_client_test.cpp b/test/libruntime/fs_client_test.cpp index 83997e5..db992f6 100644 --- a/test/libruntime/fs_client_test.cpp +++ b/test/libruntime/fs_client_test.cpp @@ -58,7 +58,6 @@ public: .modelName = "test", .maxSize = 100, .maxFiles = 1, - .retentionDays = DEFAULT_RETENTION_DAYS, .logFileWithTime = false, .logBufSecs = 30, .maxAsyncQueueSize = 1048510, @@ -311,7 +310,7 @@ TEST_F(FSClientGrpcTest, GrpcClientTest_KillAsync) DoStartGrpcClient(); NotificationUtility notify; - auto cb = [¬ify](const KillResponse &req, ErrorInfo err) -> void { notify.Notify(); }; + auto cb = [¬ify](const KillResponse &req, const ErrorInfo &err) -> void { notify.Notify(); }; KillRequest req; fsClient_->KillAsync(req, cb, -1); @@ -328,12 +327,34 @@ TEST_F(FSClientGrpcTest, GrpcClientTest_KillAsync) } } +TEST_F(FSClientGrpcTest, GrpcClientTest_CreateRGroupAsync) +{ + DoStartGrpcClient(); + + NotificationUtility notify; + auto cb = [¬ify](const CreateResourceGroupResponse &resp) -> void { notify.Notify(); }; + auto reqId = YR::utility::IDGenerator::GenRequestId(); + CreateResourceGroupRequest req; + req.set_requestid(reqId); + fsClient_->CreateRGroupAsync(req, cb, 1000); + StreamingMessage msg; + auto ret = grpcServer->Read(msg); + EXPECT_TRUE(ret); + + CreateResourceGroupResponse resp; + grpcServer->Send(*GenStreamMsg(msg.messageid(), resp)); + { + auto err = notify.WaitForNotification(); + EXPECT_TRUE(err.OK()); + } +} + TEST_F(FSClientGrpcTest, GrpcClientTest_KillAsyncTimeout) { DoStartGrpcClient(); NotificationUtility notify; - auto cb = [¬ify](const KillResponse &req, ErrorInfo e) -> void { + auto cb = [¬ify](const KillResponse &req, const ErrorInfo &fakeErr) -> void { auto err = ErrorInfo(static_cast(req.code()), req.message()); notify.Notify(err); }; @@ -681,7 +702,7 @@ TEST_F(FSClientGrpcTest, ReconnectTest) grpcServer = std::make_shared(Config::Instance().HOST_IP()); grpcServer->StartWithPort(port); NotificationUtility notify; - auto cb = [¬ify](const KillResponse &req, ErrorInfo err) -> void { notify.Notify(); }; + auto cb = [¬ify](const KillResponse &req, const ErrorInfo &err) -> void { notify.Notify(); }; KillRequest req; fsClient_->KillAsync(req, cb, -1); diff --git a/test/libruntime/fs_intf_grpc_rw_test.cpp b/test/libruntime/fs_intf_grpc_rw_test.cpp index 2f15ec8..5469f2e 100644 --- a/test/libruntime/fs_intf_grpc_rw_test.cpp +++ b/test/libruntime/fs_intf_grpc_rw_test.cpp @@ -18,6 +18,7 @@ #include #include #include "mock/mock_security.h" +#include "src/libruntime/utils/grpc_utils.h" #include "src/libruntime/clientsmanager/clients_manager.h" #include "src/libruntime/fsclient/grpc/fs_intf_grpc_client_reader_writer.h" #include "src/utility/logger/logger.h" @@ -144,6 +145,35 @@ protected: } }; +TEST_F(FSIntfGrpcRWTest, TestSignRequest) +{ + std::string ak = "ak"; + SensitiveValue sk = std::string("sk"); + this->security_->SetAKSKAndCredential(ak, sk); + StartService(); + auto clientRw = StartClient("client"); + auto recvPromise = std::make_shared>(); + std::string msgID = "invokereq"; + RegisterMessagePromise(msgID, recvPromise); + auto msg = StreamingMessage(); + msg.set_messageid(msgID); + auto rsp = msg.mutable_invokereq(); + rsp->set_requestid("request"); + rsp->set_function("function"); + rsp->set_traceid("traceid"); + rsp->set_instanceid("server"); + (*rsp->mutable_invokeoptions()->mutable_customtag())["custom"] = "value"; + auto arg = rsp->add_args(); + arg->set_value("args_value"); + YR::Libruntime::SignStreamingMessage(ak, sk, msg); + std::promise writecb; + auto ptr = std::make_shared(std::move(msg)); + clientRw->Write(ptr, [&](bool, ErrorInfo err) { writecb.set_value(err); }); + auto err = writecb.get_future().get(); + ASSERT_EQ(err.OK(), true); + StopService(); +} + TEST_F(FSIntfGrpcRWTest, SendInvokeMsg) { StartService(); diff --git a/test/libruntime/fs_intf_impl_test.cpp b/test/libruntime/fs_intf_impl_test.cpp index b131896..b54a6cb 100644 --- a/test/libruntime/fs_intf_impl_test.cpp +++ b/test/libruntime/fs_intf_impl_test.cpp @@ -83,6 +83,7 @@ public: void TearDown() override { UnsetEnv("REQUEST_ACK_ACC_MAX_SEC"); + fsIntfImpl_->Stop(); fsIntfImpl_.reset(); } std::shared_ptr fsIntfImpl_; @@ -106,7 +107,7 @@ TEST_F(FSIntfImplTest, Test_when_retry_timeout_should_execute_callback) auto retry = [&retryTimes]() { std::cout << "retryTimes: " << retryTimes++ << std::endl; }; wr->SetupRetry(retry, std::bind(&FSIntfImpl::NeedRepeat, fsIntfImpl_.get(), requestId)); EXPECT_EQ(future.get().Code(), ERR_REQUEST_BETWEEN_RUNTIME_BUS); - EXPECT_EQ(retryTimes, 1); + EXPECT_EQ(retryTimes, 5); } TEST_F(FSIntfImplTest, Test_resend) @@ -261,6 +262,49 @@ TEST_F(FSIntfImplTest, Test_when_receive_repeated_call_request_should_return_cal ASSERT_TRUE(fOne.get()); } +TEST_F(FSIntfImplTest, ResendRequestWithRetryTest) +{ + auto promise = std::make_shared>(); + auto anotherPromise = std::make_shared>(); + auto future = promise->get_future(); + auto anotherFuture = anotherPromise->get_future(); + std::atomic num{0}; + std::atomic count{2}; + auto wiredReq = std::make_shared(); + std::function retry = [&num, &anotherPromise]() -> void { + num++; + if (num == 4) { + anotherPromise->set_value(0); + } + YRLOG_DEBUG("current num is {}", num.load()); + }; + wiredReq->retryHdlr = std::move(retry); + std::function needRetryHdlr = [&count, &promise, &anotherPromise]() -> bool { + if (count > 0) { + count--; + return true; + } + promise->set_value(0); + return false; + }; + wiredReq->needRetryHdlr = std::move(needRetryHdlr); + auto timer = std::make_shared(); + wiredReq->timerWorkerWeak = timer->weak_from_this(); + wiredReq->ResendRequestWithRetry(); + wiredReq->ResendRequestWithRetry(); + wiredReq->retryIntervalSec = 1; + + ASSERT_EQ(num, 2); + future.get(); + anotherFuture.get(); + ASSERT_EQ(num, 4); + ASSERT_EQ(count, 0); + if (wiredReq->timer_) { + wiredReq->timer_->cancel(); + wiredReq->timer_.reset(); + } +} + TEST_F(FSIntfImplTest, DirectlyCallWithRetry) { Config::Instance().RUNTIME_DIRECT_CONNECTION_ENABLE() = true; @@ -270,6 +314,7 @@ TEST_F(FSIntfImplTest, DirectlyCallWithRetry) fsIntfImpl_ = std::make_shared(Config::Instance().HOST_IP(), 0, handlers, true, nullptr, std::make_shared(), false); fsIntfImpl_->fsInrfMgr = mockFsIntfMgr; + fsIntfImpl_->noitfyExecutor.Init(1); auto mockFsIntfRW = std::make_shared(); EXPECT_CALL(*mockFsIntfMgr, Get).WillRepeatedly(Return(mockFsIntfRW)); // mock first directly return communicate err @@ -322,11 +367,6 @@ TEST_F(FSIntfImplTest, DirectlyCallResultWithRetry) auto wr = ptr->GetWiredRequest(reqId, false); EXPECT_NE(wr, nullptr); } - })) - .WillOnce(Invoke([&](const std::shared_ptr &msg, - std::function callback, std::function preWrite) { - // retry to failure - callback(true, ErrorInfo(ERR_INSTANCE_EXITED, "posix stream is closed")); })); auto promise = std::make_shared>(); auto ackHandler = [promise](const CallResultAck &req) -> void { promise->set_value(req); }; @@ -337,10 +377,97 @@ TEST_F(FSIntfImplTest, DirectlyCallResultWithRetry) auto messageSpec = std::make_shared(); messageSpec->Mutable() = std::move(req); fsIntfImpl_->CallResultAsync(messageSpec, ackHandler); - auto rsp = promise->get_future().get(); - EXPECT_EQ(rsp.code(), common::ErrorCode(ERR_INSTANCE_EXITED)); auto wr = fsIntfImpl_->GetWiredRequest(reqId, false); - EXPECT_EQ(wr, nullptr); + EXPECT_NE(wr, nullptr); +} + +TEST_F(FSIntfImplTest, TestIsHealth) +{ + EXPECT_NE(fsIntfImpl_->IsHealth(), true); + fsIntfImpl_->fsInrfMgr = std::make_shared(fsIntfImpl_->clientsMgr); + EXPECT_NE(fsIntfImpl_->IsHealth(), true); + auto mockFsIntf = std::make_shared(); + fsIntfImpl_->fsInrfMgr->systemIntf = mockFsIntf; + EXPECT_CALL(*mockFsIntf, IsHealth).WillRepeatedly(Return(true)); + EXPECT_TRUE(fsIntfImpl_->IsHealth()); +} + +TEST_F(FSIntfImplTest, TestNotifyDisconnected) +{ + EXPECT_NO_THROW(fsIntfImpl_->NotifyDisconnected("no_function-proxy")); + auto wr = std::make_shared(); + wr->dstInstanceID = "instanceId"; + wr->reqId_ = "reqId"; + wr->notifyCallback = [](const NotifyRequest &req, const ErrorInfo &err) {std::cout << "hello, world" << std::endl;}; + fsIntfImpl_->wiredRequests.emplace(wr->reqId_, wr); + auto mockFsIntf = std::make_shared(); + fsIntfImpl_->fsInrfMgr->Emplace("instanceId", mockFsIntf); + EXPECT_CALL(*mockFsIntf, Available).WillRepeatedly(Return(false)); + fsIntfImpl_->NotifyDisconnected("function-proxy"); + EXPECT_TRUE(fsIntfImpl_->wiredRequests.find("reqId") == fsIntfImpl_->wiredRequests.end()); +} + +TEST_F(FSIntfImplTest, TestNeedRepeat) +{ + auto wr = std::make_shared(); + fsIntfImpl_->wiredRequests["reqId"] = wr; + wr->reqId_ = "reqId"; + wr->retryCount = 0; + wr->remainTimeoutSec = 1; + wr->retryIntervalSec = 10; + wr->callback = [](const StreamingMessage &createResp, ErrorInfo status, std::function){}; + EXPECT_FALSE(fsIntfImpl_->NeedRepeat("reqId")); + wr->ackReceived = true; + EXPECT_FALSE(fsIntfImpl_->NeedRepeat("reqId")); +} + +TEST_F(FSIntfImplTest, TestUpdateRetryInterval) +{ + std::string reqId = "non_existing_request_id"; + auto result = fsIntfImpl_->UpdateRetryInterval(reqId); + EXPECT_EQ(result.first, nullptr); + EXPECT_TRUE(result.second); + + reqId = "existing_request_id"; + auto wr1 = std::make_shared(); + wr1->retryCount = 0; + wr1->remainTimeoutSec = 0; + wr1->retryIntervalSec = 1; + fsIntfImpl_->wiredRequests[reqId] = wr1; + result = fsIntfImpl_->UpdateRetryInterval(reqId); + EXPECT_EQ(result.first, wr1); + EXPECT_TRUE(result.second); + EXPECT_EQ(fsIntfImpl_->wiredRequests.find(reqId), fsIntfImpl_->wiredRequests.end()); + + reqId = "existing_request_id_1"; + auto wr2 = std::make_shared(); + wr2->retryCount = 0; + wr2->remainTimeoutSec = 10; + wr2->retryIntervalSec = 1; + wr2->exponentialBackoff = true; + fsIntfImpl_->wiredRequests[reqId] = wr2; + + result = fsIntfImpl_->UpdateRetryInterval(reqId); + EXPECT_EQ(result.first, wr2); + EXPECT_FALSE(result.second); + EXPECT_EQ(result.first->retryCount, 1); + EXPECT_EQ(result.first->remainTimeoutSec, 9); + EXPECT_EQ(result.first->retryIntervalSec, 2); + + reqId = "existing_request_id_2"; + auto wr3 = std::make_shared(); + wr3->retryCount = 0; + wr3->remainTimeoutSec = 1; + wr3->retryIntervalSec = 10; + wr3->exponentialBackoff = false; + fsIntfImpl_->wiredRequests[reqId] = wr3; + + result = fsIntfImpl_->UpdateRetryInterval(reqId); + EXPECT_EQ(result.first, wr3); + EXPECT_TRUE(result.second); + EXPECT_EQ(result.first->retryCount, 1); + EXPECT_EQ(result.first->remainTimeoutSec, -9); + EXPECT_EQ(result.first->retryIntervalSec, 10); } } // namespace test } // namespace YR \ No newline at end of file diff --git a/test/libruntime/fs_intf_manager_test.cpp b/test/libruntime/fs_intf_manager_test.cpp index 0430fd5..9adedf7 100644 --- a/test/libruntime/fs_intf_manager_test.cpp +++ b/test/libruntime/fs_intf_manager_test.cpp @@ -28,7 +28,8 @@ protected: std::shared_ptr clientsMgr; std::shared_ptr fsIntfManager; - void SetUp() override { + void SetUp() override + { security = std::make_shared(); clientsMgr = std::make_shared(); fsIntfManager = std::make_shared(clientsMgr); diff --git a/test/libruntime/function_group_test.cpp b/test/libruntime/function_group_test.cpp index f7c7187..555ec53 100644 --- a/test/libruntime/function_group_test.cpp +++ b/test/libruntime/function_group_test.cpp @@ -77,8 +77,8 @@ public: dsObjectStore->Init("127.0.0.1", 8080); waitManager = std::make_shared(); this->memoryStore->Init(dsObjectStore, waitManager); - fsIntf_ = std::make_shared(); - this->fsClient = std::make_shared(fsIntf_); + gwClient_ = std::make_shared(); + this->fsClient = std::make_shared(gwClient_); this->fnGroup = std::make_shared("groupName", "tenantId", opt, this->fsClient, this->waitManager, this->memoryStore, this->invokeOrderMgr, nullptr, nullptr); spec = std::make_shared(); @@ -100,7 +100,7 @@ public: std::shared_ptr waitManager; std::shared_ptr invokeOrderMgr; std::shared_ptr spec; - std::shared_ptr fsIntf_; + std::shared_ptr gwClient_; }; TEST_F(FunctionGroupTest, CreateRespHandlerTest) diff --git a/test/libruntime/generator_test.cpp b/test/libruntime/generator_test.cpp new file mode 100644 index 0000000..993ca6d --- /dev/null +++ b/test/libruntime/generator_test.cpp @@ -0,0 +1,620 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include "mock/mock_datasystem.h" + +#define private public +#include "src/libruntime/generator/generator_id_map.h" +#include "src/libruntime/generator/stream_generator_notifier.h" +#include "src/libruntime/generator/stream_generator_receiver.h" +#include "src/utility/logger/logger.h" + +namespace YR { +namespace test { +using namespace testing; +using namespace YR::Libruntime; +using namespace YR::utility; +using namespace std::chrono_literals; +class GeneratorTest : public testing::Test { +public: + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + consumer = std::make_shared(); + streamStore = std::make_shared(); + auto dsStore = std::make_shared(); + auto waitManager = std::make_shared(); + memStore = std::make_shared(); + memStore->Init(dsStore, waitManager); + } + + void TearDown() override + { + map.reset(); + consumer.reset(); + streamStore.reset(); + memStore.reset(); + if (notifier) { + notifier.reset(); + } + if (receiver) { + receiver.reset(); + } + } + + std::shared_ptr receiver = nullptr; + std::shared_ptr notifier = nullptr; + std::shared_ptr streamStore = nullptr; + std::shared_ptr memStore = nullptr; + std::shared_ptr consumer = nullptr; + std::shared_ptr config = std::make_shared(); + std::shared_ptr map = std::make_shared(); +}; + +TEST_F(GeneratorTest, DetectHeartbeatTest) +{ + receiver = std::make_shared(config, streamStore, memStore); + auto fakeTimeStamp = GetCurrentTimestampMs() - 40000; + receiver->map_->AddRecord("objId", std::to_string(fakeTimeStamp)); + ASSERT_NO_THROW(receiver->DetectHeartbeat()); +} + +TEST_F(GeneratorTest, HandleGeneratorResultTest) +{ + receiver = std::make_shared(config, streamStore, memStore); + libruntime::NotifyGeneratorResult result; + result.set_genid("genid"); + result.set_errorcode(0); + result.set_data("result"); + receiver->numGeneratorResults_["genid"] = 1; + receiver->generatorResultsCounter_["genid"] = 1; + ASSERT_NO_THROW(receiver->HandleGeneratorResult(result)); +} + +TEST_F(GeneratorTest, AddRecordTest) +{ + receiver = std::make_shared(config, streamStore, memStore); + receiver->AddRecord("genId"); + std::vector keys; + receiver->map_->GetRecordKeys(keys); + ASSERT_EQ(keys.size() != 0, true); +} + +TEST_F(GeneratorTest, HandleGeneratorHeartbeatTest) +{ + receiver = std::make_shared(config, streamStore, memStore); + receiver->HandleGeneratorHeartbeat("genId"); + std::vector keys; + receiver->map_->GetRecordKeys(keys); + ASSERT_EQ(keys.size() != 0, true); +} + +TEST_F(GeneratorTest, NotifyHeartbeatTest) +{ + notifier = std::make_shared(streamStore, map); + notifier->map_ = nullptr; + ASSERT_EQ(notifier->NotifyHeartbeat().Code(), ErrorCode::ERR_INNER_SYSTEM_ERROR); + notifier->map_ = map; + notifier->map_->AddRecord("genId", std::to_string(GetCurrentTimestampMs())); + ASSERT_EQ(notifier->NotifyHeartbeat().OK(), true); +} + +TEST_F(GeneratorTest, NotifyHeartbeatByStreamTest) +{ + notifier = std::make_shared(streamStore, map); + notifier->map_ = nullptr; + ASSERT_EQ(notifier->NotifyHeartbeatByStream("genId").Code(), ErrorCode::ERR_INNER_SYSTEM_ERROR); + + notifier->map_ = map; + ASSERT_EQ(notifier->NotifyHeartbeatByStream("genId").Code(), ErrorCode::ERR_INNER_SYSTEM_ERROR); + + notifier->map_->AddRecord("genId", std::to_string(GetCurrentTimestampMs())); + ASSERT_EQ(notifier->NotifyHeartbeatByStream("genId").OK(), true); +} + +TEST_F(GeneratorTest, NotifyFinishedByStreamTest) +{ + notifier = std::make_shared(streamStore, map); + notifier->map_->AddRecord("genId", std::to_string(GetCurrentTimestampMs())); + notifier->dsStreamStore_ = nullptr; + ASSERT_EQ(notifier->NotifyFinishedByStream("genId", 1).Code(), ErrorCode::ERR_INNER_SYSTEM_ERROR); + std::shared_ptr dataObj = std::make_shared("objId"); + ASSERT_EQ(notifier->NotifyResultByStream("genId", 1, dataObj, ErrorInfo()).Code(), + ErrorCode::ERR_INNER_SYSTEM_ERROR); +} + +TEST_F(GeneratorTest, IncreaseProducerReferenceTest) +{ + notifier = std::make_shared(streamStore, map); + notifier->IncreaseProducerReference("topic"); + ASSERT_EQ(notifier->producerReferences_["topic"], 1); + notifier->IncreaseProducerReference("topic"); + ASSERT_EQ(notifier->producerReferences_["topic"], 2); +} + +TEST_F(GeneratorTest, InitializeTest) +{ + notifier = std::make_shared(streamStore, map); + notifier->map_ = nullptr; + notifier->stopped = true; + ASSERT_NO_THROW(notifier->Initialize()); + notifier->timer_->cancel(); + notifier->timer_.reset(); +} + +TEST_F(GeneratorTest, PopBatchTest) +{ + notifier = std::make_shared(streamStore, map); + std::vector> datas; + notifier->genQueue_.push_back(std::make_shared()); + notifier->PopBatch(datas); + ASSERT_EQ(datas.size(), 1); +} + +TEST_F(GeneratorTest, GeneratorIdMapTest) +{ + auto map = std::make_shared(); + std::string genId("fake_genid"); + std::string rtId("fake_rtid"); + { + GeneratorIdRecorder r(genId, rtId, map); + std::string v; + map->GetRecord(genId, v); + EXPECT_EQ(v, rtId); + } + std::string v1; + map->GetRecord(genId, v1); + EXPECT_TRUE(v1 == rtId); + map->RemoveRecord(genId); + std::string v2; + map->GetRecord(genId, v2); + EXPECT_TRUE(v2.empty()); +} + +TEST_F(GeneratorTest, GeneratorIdNotifyTest) +{ + std::shared_ptr streamStore = std::make_shared(); + auto map = std::make_shared(); + { + auto n = std::make_shared(streamStore, map); + std::string genId("fake_genid"); + std::string rtId("fake_rtid"); + std::string objId("fake_objid"); + int idx = 0; + for (int i = 0; i < 2; i++) { + GeneratorIdRecorder r(genId, rtId, map); + auto resultObj = std::make_shared(); + resultObj->id = objId; + resultObj->buffer = std::make_shared(nullptr, 0); + auto resultErr = ErrorInfo(); + auto err = n->NotifyResult(genId, idx + i, resultObj, resultErr); + EXPECT_TRUE(err.OK()); + } + auto err = n->NotifyFinished(genId, 3); + EXPECT_TRUE(err.OK()); + } +} + +TEST_F(GeneratorTest, GeneratorNotifierTest_NotifyResult) +{ + std::shared_ptr streamStore = std::make_shared(); + auto map = std::make_shared(); + + std::string genId("fake_genid"); + std::string rtId("fake_rtid"); + std::string objId("fake_objid"); + int idx = 0; + + auto p = std::make_shared(); + + EXPECT_CALL(*streamStore, CreateStreamProducer(rtId, _, _)) + .WillOnce( + [&p](const std::string &streamName, std::shared_ptr &producer, ProducerConf producerConf) { + producer = p; + return ErrorInfo(); + }); + + EXPECT_CALL(*p, Send(_)) + .WillOnce([&genId, &objId, &idx](const Element &element) { + libruntime::NotifyGeneratorResult res; + auto success = res.ParseFromArray(element.ptr, element.size); + EXPECT_TRUE(success); + EXPECT_EQ(res.genid(), genId); + EXPECT_EQ(res.objectid(), objId); + EXPECT_EQ(res.index(), idx); + EXPECT_EQ(res.errorcode(), 0); + EXPECT_EQ(res.finished(), false); + return ErrorInfo(); + }) + .WillOnce([&genId, &objId, &idx](const Element &element) { + libruntime::NotifyGeneratorResult res; + auto success = res.ParseFromArray(element.ptr, element.size); + EXPECT_TRUE(success); + EXPECT_EQ(res.genid(), genId); + EXPECT_EQ(res.objectid(), objId); + EXPECT_EQ(res.index(), (idx + 1)); + EXPECT_EQ(res.errorcode(), 0); + EXPECT_EQ(res.finished(), false); + return ErrorInfo(); + }); + + EXPECT_CALL(*p, Flush()).Times(2).WillRepeatedly([]() { return ErrorInfo(); }); + + { + auto n = std::make_shared(streamStore, map); + for (int i = 0; i < 2; i++) { + GeneratorIdRecorder r(genId, rtId, map); + auto resultObj = std::make_shared(); + resultObj->id = objId; + resultObj->buffer = std::make_shared(nullptr, 0); + auto resultErr = ErrorInfo(); + n->NotifyResultByStream(genId, idx + i, resultObj, resultErr); + } + } +} + +TEST_F(GeneratorTest, GeneratorNotifierTest_NotifyError) +{ + std::shared_ptr streamStore = std::make_shared(); + auto map = std::make_shared(); + + std::string genId("fake_genid"); + std::string rtId("fake_rtid"); + std::string objId("fake_objid"); + int idx = 0; + ErrorCode errCode = ErrorCode::ERR_USER_FUNCTION_EXCEPTION; + std::string errMsg("failed to execute user func"); + + auto p = std::make_shared(); + + EXPECT_CALL(*streamStore, CreateStreamProducer(rtId, _, _)) + .WillOnce( + [&p](const std::string &streamName, std::shared_ptr &producer, ProducerConf producerConf) { + producer = p; + return ErrorInfo(); + }); + + EXPECT_CALL(*p, Send(_)).WillOnce([&genId, &objId, &idx, &errCode, &errMsg](const Element &element) { + libruntime::NotifyGeneratorResult res; + auto success = res.ParseFromArray(element.ptr, element.size); + EXPECT_TRUE(success); + EXPECT_EQ(res.genid(), genId); + EXPECT_EQ(res.objectid(), objId); + EXPECT_EQ(res.index(), idx); + EXPECT_EQ(res.errorcode(), int64_t(errCode)); + EXPECT_EQ(res.errormessage(), errMsg); + EXPECT_EQ(res.finished(), false); + return ErrorInfo(); + }); + + EXPECT_CALL(*p, Flush()).WillOnce([]() { return ErrorInfo(); }); + + { + auto n = std::make_shared(streamStore, map); + GeneratorIdRecorder r(genId, rtId, map); + auto resultObj = std::make_shared(); + resultObj->id = objId; + resultObj->buffer = std::make_shared(nullptr, 0); + auto resultErr = ErrorInfo(errCode, errMsg); + n->NotifyResultByStream(genId, idx, resultObj, resultErr); + EXPECT_TRUE(n->producerReferences_.find(rtId) == n->producerReferences_.end()); + } +} + +TEST_F(GeneratorTest, GeneratorNotifierTest_NotifyFinished) +{ + std::shared_ptr streamStore = std::make_shared(); + auto map = std::make_shared(); + + std::string genId("fake_genid"); + std::string rtId("fake_rtid"); + int numResults = 10; + + auto p = std::make_shared(); + + EXPECT_CALL(*streamStore, CreateStreamProducer(rtId, _, _)) + .WillOnce( + [&p](const std::string &streamName, std::shared_ptr &producer, ProducerConf producerConf) { + producer = p; + return ErrorInfo(); + }); + + EXPECT_CALL(*p, Send(_)).WillOnce([&genId, &numResults](const Element &element) { + libruntime::NotifyGeneratorResult res; + auto success = res.ParseFromArray(element.ptr, element.size); + EXPECT_TRUE(success); + EXPECT_EQ(res.genid(), genId); + EXPECT_EQ(res.finished(), true); + EXPECT_EQ(res.numresults(), numResults); + return ErrorInfo(); + }); + + EXPECT_CALL(*p, Flush()).WillOnce([]() { return ErrorInfo(); }); + + { + auto n = std::make_shared(streamStore, map); + GeneratorIdRecorder r(genId, rtId, map); + n->NotifyFinishedByStream(genId, numResults); + EXPECT_TRUE(n->producerReferences_.find(rtId) == n->producerReferences_.end()); + } +} + +TEST_F(GeneratorTest, GeneratorReceiverTest_MarkEndOfStream) +{ + auto librtCfg = std::make_shared(); + librtCfg->runtimeId = "driver"; + librtCfg->jobId = "56781234"; + std::shared_ptr streamStore = std::make_shared(); + auto memoryStore = std::make_shared(); + auto dsObjectStore = std::make_shared(); + auto wom = std::make_shared(); + memoryStore->Init(dsObjectStore, wom); + wom->SetMemoryStore(memoryStore); + + std::string genId("fake_genid"); + std::string rtId("fake_rtid"); + std::string objId("fake_objid"); + int idx = 0; + std::string errMsg("failed to execute user func"); + + auto c = std::make_shared(); + + EXPECT_CALL(*streamStore, CreateStreamConsumer("driver_56781234", _, _, true)) + .WillOnce([&c](const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck) { + consumer = c; + return ErrorInfo(); + }); + + void *buffer1; + + EXPECT_CALL(*c, Receive(_, _, _)) + .WillOnce([&genId, &idx, &objId, &buffer1](uint32_t expectNum, uint32_t timeoutMs, + std::vector &outElements) { + libruntime::NotifyGeneratorResult res; + res.set_genid(genId); + res.set_index(idx); + res.set_objectid(objId); + auto rs = res.ByteSizeLong(); + buffer1 = malloc(rs); + EXPECT_TRUE(buffer1 != nullptr); + res.SerializeToArray(buffer1, rs); + outElements.resize(1); + outElements[0].ptr = static_cast(buffer1); + outElements[0].size = rs; + return ErrorInfo(); + }) + .WillRepeatedly([](uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements) { + std::this_thread::sleep_for(500ms); + return ErrorInfo(); + }); + + { + auto r = std::make_shared(librtCfg, streamStore, memoryStore); + r->Initialize(); + std::this_thread::sleep_for(50ms); + EXPECT_EQ(r->generatorResultsCounter_[genId], 1); + auto err = ErrorInfo(ErrorCode::ERR_INNER_SYSTEM_ERROR, "some invoke error"); + r->MarkEndOfStream(genId, err); + std::this_thread::sleep_for(100ms); + EXPECT_EQ(r->generatorResultsCounter_.find(genId), r->generatorResultsCounter_.end()); + r->Stop(); + } + free(buffer1); +} + +TEST_F(GeneratorTest, GeneratorReceiverTest_ReceiveError) +{ + auto librtCfg = std::make_shared(); + librtCfg->runtimeId = "driver"; + librtCfg->jobId = "56781234"; + std::shared_ptr streamStore = std::make_shared(); + auto memoryStore = std::make_shared(); + auto dsObjectStore = std::make_shared(); + auto wom = std::make_shared(); + memoryStore->Init(dsObjectStore, wom); + wom->SetMemoryStore(memoryStore); + + std::string genId("fake_genid"); + std::string rtId("fake_rtid"); + std::string objId("fake_objid"); + int idx = 0; + ErrorCode errCode = ErrorCode::ERR_USER_FUNCTION_EXCEPTION; + std::string errMsg("failed to execute user func"); + + auto c = std::make_shared(); + + EXPECT_CALL(*streamStore, CreateStreamConsumer("driver_56781234", _, _, true)) + .WillOnce([&c](const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck) { + consumer = c; + return ErrorInfo(); + }); + + void *buffer1; + void *buffer2; + + EXPECT_CALL(*c, Receive(_, _, _)) + .WillOnce([&genId, &idx, &objId, &buffer1](uint32_t expectNum, uint32_t timeoutMs, + std::vector &outElements) { + libruntime::NotifyGeneratorResult res; + res.set_genid(genId); + res.set_index(idx); + res.set_objectid(objId); + auto rs = res.ByteSizeLong(); + buffer1 = malloc(rs); + EXPECT_TRUE(buffer1 != nullptr); + res.SerializeToArray(buffer1, rs); + outElements.resize(1); + outElements[0].ptr = static_cast(buffer1); + outElements[0].size = rs; + return ErrorInfo(); + }) + .WillOnce([&genId, &idx, &objId, &errCode, &errMsg, &buffer2](uint32_t expectNum, uint32_t timeoutMs, + std::vector &outElements) { + std::this_thread::sleep_for(100ms); + libruntime::NotifyGeneratorResult res; + res.set_genid(genId); + res.set_index(idx + 1); + res.set_objectid(objId); + res.set_errorcode(int64_t(errCode)); + res.set_errormessage(errMsg); + auto rs = res.ByteSizeLong(); + buffer2 = malloc(rs); + EXPECT_TRUE(buffer2 != nullptr); + res.SerializeToArray(buffer2, rs); + outElements.resize(1); + outElements[0].ptr = static_cast(buffer2); + outElements[0].size = rs; + return ErrorInfo(); + }) + .WillRepeatedly([](uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements) { + std::this_thread::sleep_for(500ms); + return ErrorInfo(); + }); + + { + auto r = std::make_shared(librtCfg, streamStore, memoryStore); + r->Initialize(); + std::this_thread::sleep_for(50ms); + EXPECT_EQ(r->generatorResultsCounter_[genId], 1); + std::this_thread::sleep_for(200ms); + EXPECT_EQ(r->generatorResultsCounter_.find(genId), r->generatorResultsCounter_.end()); + r->Stop(); + } + + free(buffer1); + free(buffer2); +} + +TEST_F(GeneratorTest, GeneratorReceiverTest_ReceiveFinished) +{ + auto librtCfg = std::make_shared(); + librtCfg->runtimeId = "driver"; + librtCfg->jobId = "56781234"; + std::shared_ptr streamStore = std::make_shared(); + auto memoryStore = std::make_shared(); + auto dsObjectStore = std::make_shared(); + auto wom = std::make_shared(); + memoryStore->Init(dsObjectStore, wom); + wom->SetMemoryStore(memoryStore); + std::string genId("fake_genid"); + std::string rtId("fake_rtid"); + std::string objId("fake_objid"); + int idx = 0; + + auto c = std::make_shared(); + + EXPECT_CALL(*streamStore, CreateStreamConsumer("driver_56781234", _, _, true)) + .WillOnce([&c](const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck) { + consumer = c; + return ErrorInfo(); + }); + + void *buffer1; + void *buffer2; + + EXPECT_CALL(*c, Receive(_, _, _)) + .WillOnce([&genId, &idx, &objId, &buffer1](uint32_t expectNum, uint32_t timeoutMs, + std::vector &outElements) { + libruntime::NotifyGeneratorResult res; + res.set_genid(genId); + res.set_index(idx); + res.set_objectid(objId); + auto rs = res.ByteSizeLong(); + buffer1 = malloc(rs); + EXPECT_TRUE(buffer1 != nullptr); + res.SerializeToArray(buffer1, rs); + outElements.resize(1); + outElements[0].ptr = static_cast(buffer1); + outElements[0].size = rs; + return ErrorInfo(); + }) + .WillOnce([&genId, &buffer2](uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements) { + std::this_thread::sleep_for(100ms); + libruntime::NotifyGeneratorResult res; + res.set_genid(genId); + res.set_finished(true); + res.set_numresults(1); + auto rs = res.ByteSizeLong(); + buffer2 = malloc(rs); + EXPECT_TRUE(buffer2 != nullptr); + res.SerializeToArray(buffer2, rs); + outElements.resize(1); + outElements[0].ptr = static_cast(buffer2); + outElements[0].size = rs; + return ErrorInfo(); + }) + .WillRepeatedly([](uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements) { + std::this_thread::sleep_for(500ms); + return ErrorInfo(); + }); + + { + auto r = std::make_shared(librtCfg, streamStore, memoryStore); + r->Initialize(); + std::this_thread::sleep_for(50ms); + EXPECT_EQ(r->generatorResultsCounter_[genId], 1); + for (int i = 0; i < 100; i++) { + std::this_thread::sleep_for(50ms); + if (r->generatorResultsCounter_.find(genId) == r->generatorResultsCounter_.end()) { + break; + } + } + EXPECT_EQ(r->generatorResultsCounter_.find(genId), r->generatorResultsCounter_.end()); + r->Stop(); + } + + free(buffer1); + free(buffer2); +} + +TEST_F(GeneratorTest, UpdateRecordAndGetRecordTest) +{ + auto map = std::make_shared(); + std::string key = "key"; + std::string value = "value"; + map->UpdateRecord(key, value); + ASSERT_EQ(map->records_[key], value); + + std::vector vec; + map->GetRecordKeys(vec); + ASSERT_EQ(vec[0], key); +} + +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/libruntime/grpc_utils_test.cpp b/test/libruntime/grpc_utils_test.cpp new file mode 100644 index 0000000..39ede5e --- /dev/null +++ b/test/libruntime/grpc_utils_test.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/libruntime/utils/grpc_utils.h" +#include +#include +#include "datasystem/utils/sensitive_value.h" +#include "src/libruntime/utils/utils.h" + +namespace YR { +namespace test { +using SensitiveValue = datasystem::SensitiveValue; +class GrpcUtilTest : public ::testing::Test { +protected: + void SetUp() override {} + + void TearDown() override {} +}; + +TEST_F(GrpcUtilTest, SignAndVerifyStreamingMessageTest) +{ + std::string accessKey = "access_key"; + SensitiveValue secretKey = std::string("secret_key"); + runtime_rpc::StreamingMessage message; + message.set_messageid("message_id"); + + // fail to sign empty StreamingMessage + ASSERT_FALSE(YR::Libruntime::SignStreamingMessage(accessKey, secretKey, message)); + ASSERT_FALSE(message.metadata().contains("access_key")); + ASSERT_FALSE(message.metadata().contains("signature")); + ASSERT_FALSE(message.metadata().contains("timestamp")); + + message.mutable_callreq()->set_requestid("123"); + ASSERT_TRUE(YR::Libruntime::SignStreamingMessage(accessKey, secretKey, message)); + ASSERT_TRUE(message.metadata().contains("access_key")); + ASSERT_EQ(message.metadata().at("access_key"), accessKey); + ASSERT_TRUE(message.metadata().contains("signature")); + ASSERT_TRUE(message.metadata().contains("timestamp")); + + ASSERT_TRUE(YR::Libruntime::VerifyStreamingMessage(accessKey, secretKey, message)); + + ASSERT_FALSE(YR::Libruntime::VerifyStreamingMessage("fake_access_key", secretKey, message)); + + SensitiveValue fakeSecretKey = std::string("fake_secret_key"); + ASSERT_FALSE(YR::Libruntime::VerifyStreamingMessage(accessKey, fakeSecretKey, message)); + + message.mutable_callreq()->set_requestid("1234"); + ASSERT_FALSE(YR::Libruntime::VerifyStreamingMessage(accessKey, secretKey, message)); + + message.clear_callreq(); + message.mutable_callresultack()->set_message("123"); + ASSERT_FALSE(YR::Libruntime::VerifyStreamingMessage(accessKey, secretKey, message)); + + message.Clear(); + ASSERT_FALSE(YR::Libruntime::VerifyStreamingMessage(accessKey, secretKey, message)); +} + +TEST_F(GrpcUtilTest, SignAndVerifyTimestampTest) +{ + std::string accessKey = "access_key"; + SensitiveValue secretKey = std::string("secret_key"); + std::cout << std::string(secretKey.GetData(), secretKey.GetSize()) << "<>=======" << std::endl; + auto timestamp = YR::GetCurrentUTCTime(); + auto signature = YR::Libruntime::SignTimestamp(accessKey, secretKey, timestamp); + ASSERT_FALSE(signature.empty()); + + ASSERT_TRUE(YR::Libruntime::VerifyTimestamp(accessKey, secretKey, timestamp, signature)); + + ASSERT_FALSE(YR::Libruntime::VerifyTimestamp("fake_access_key", secretKey, timestamp, signature)); + + SensitiveValue fakeSecretKey = std::string("fake_secret_key"); + ASSERT_FALSE(YR::Libruntime::VerifyTimestamp(accessKey, fakeSecretKey, timestamp, signature)); + + auto fakeTimestamp = YR::GetCurrentUTCTime(); + ASSERT_FALSE(YR::Libruntime::VerifyTimestamp(accessKey, fakeSecretKey, fakeTimestamp, signature)); + + ASSERT_FALSE(YR::Libruntime::VerifyTimestamp(accessKey, fakeSecretKey, timestamp, "fake_signature")); +} +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/libruntime/gw_client_test.cpp b/test/libruntime/gw_client_test.cpp new file mode 100644 index 0000000..b48cbf0 --- /dev/null +++ b/test/libruntime/gw_client_test.cpp @@ -0,0 +1,1019 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/libruntime/utils/utils.h" +#define private public +#define protected public +#include "src/dto/data_object.h" +#include "src/libruntime/gwclient/gw_client.h" +using namespace testing; +using namespace YR::Libruntime; +using namespace YR::utility; + +namespace YR { +namespace test { +std::vector> BuildBuffers(); +std::vector> BuildEmptyBuffer(); +bool ValidBuffers(std::shared_ptr>> buffers); + +enum ResponseType { + HTTP_ERROR_COMMUNICATION = 0, + HTTP_BAD_REQUEST = 1, + HTTP_OK_AND_SUCCESS = 2, + HTTP_OK_BUT_FAILED = 3, + HTTP_TIMEOUT = 4, +}; +const size_t getSize = 1; +const uint HTTP_BAD_REQUEST_CODE = 400; +const uint HTTP_OK_CODE = 200; +static std::unordered_map responseTypeMap = { + {ResponseType::HTTP_ERROR_COMMUNICATION, "TEST_HTTP_ERROR_COMMUNICATION"}, + {ResponseType::HTTP_BAD_REQUEST, "TEST_HTTP_BAD_REQUEST"}, + {ResponseType::HTTP_OK_AND_SUCCESS, "TEST_HTTP_OK_AND_SUCCESS"}, + {ResponseType::HTTP_OK_BUT_FAILED, "TEST_HTTP_OK_BUT_FAILED"}, + {ResponseType::HTTP_TIMEOUT, "TEST_HTTP_TIMEOUT"}}; + +const int32_t testTimeoutMs = 50; +const int32_t connectTimeoutS = 1; +const std::string errorCommunicationMsg = "error_code: "; +const std::string badRequestMsg = "failed response status_code: "; +const std::string timeoutErr = "http request timeout s: " + std::to_string(connectTimeoutS); +const std::string failedMsg = "system error"; +const std::string g_requestId = "cae7c30c8d63f5ed00"; +const std::string g_instanceId = "cae7c30c8d63f5ed"; +void CommonCallback(const std::string &rspType, const HttpCallbackFunction &receiver, const std::string &body) +{ + if (responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION] == rspType) { + receiver(body, boost::asio::error::make_error_code(boost::asio::error::fault), 0); + } else if (responseTypeMap[ResponseType::HTTP_BAD_REQUEST] == rspType) { + receiver(body, boost::beast::error_code(), HTTP_BAD_REQUEST_CODE); + } else if (responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS] == rspType) { + receiver(body, boost::beast::error_code(), HTTP_OK_CODE); + } else if (responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED] == rspType) { + receiver(body, boost::beast::error_code(), HTTP_OK_CODE); + } else { + std::cout << "unknown rspType:" << rspType; + } +} +class MockHttpClient : public HttpClient { +public: + MockHttpClient() = default; + ~MockHttpClient() = default; + ErrorInfo Init(const ConnectionParam ¶m) override + { + return ErrorInfo(); + } + + void SubmitInvokeRequest(const http::verb &method, const std::string &target, + const std::unordered_map &headers, const std::string &body, + const std::shared_ptr requestId, + const HttpCallbackFunction &receiver) override + { + const auto rspType = headers.at(REMOTE_CLIENT_ID_KEY); + auto code = common::ERR_NONE; + std::string msg = ""; + if (responseTypeMap[ResponseType::HTTP_TIMEOUT] == rspType) { + return; + } + if (!(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS] == rspType)) { + code = common::ERR_INNER_SYSTEM_ERROR; + msg = failedMsg; + } + if (POSIX_LEASE == target || POSIX_LEASE_KEEPALIVE == target) { + auto rsp = BuildLeaseResponse(code, msg); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_CREATE == target) { + std::vector> returnObjects = {}; + std::vector infos = {}; + auto rsp = BuildNotifyRequest(*requestId, code, msg, returnObjects, infos); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_INVOKE == target) { + std::vector> returnObjects = {}; + std::vector infos = {}; + auto rsp = BuildNotifyRequest(*requestId, code, msg, returnObjects, infos); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (target.find("invocation") != std::string::npos) { + std::string rspBody = "{\"code\":150444,\"message\":\"instance label not found\"}"; + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_KILL == target) { + auto rsp = BuildKillResponse(code, msg); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_OBJ_GET == target) { + auto rsp = BuildObjGetResponse(code, msg, BuildBuffers()); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_OBJ_PUT == target) { + auto rsp = BuildObjPutResponse(code, msg); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_OBJ_DECREASE == target) { + std::vector failedObjectIds = {}; + auto rsp = BuildDecreaseRefResponse(code, msg, failedObjectIds); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_OBJ_INCREASE == target) { + std::vector failedObjectIds = {"failed_obj-1"}; + auto rsp = BuildIncreaseRefResponse(code, msg, failedObjectIds); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_KV_GET == target) { + auto rsp = BuildKvGetResponse(code, msg, BuildBuffers()); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_KV_SET == target) { + auto rsp = BuildKvSetResponse(code, msg); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_KV_MSET_TX == target) { + auto rsp = BuildKvMSetTxResponse(code, msg); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else if (POSIX_KV_DEL == target) { + std::vector failedKeys = {}; + auto rsp = BuildKvDelResponse(code, msg, failedKeys); + std::string rspBody; + rsp.SerializeToString(&rspBody); + CommonCallback(rspType, receiver, rspBody); + } else { + std::cout << "unknown target:" << target; + } + } + + ErrorInfo ReInit() override + { + return ErrorInfo(); + } + + LeaseResponse BuildLeaseResponse(const common::ErrorCode code, const std::string &msg) + { + LeaseResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + return rsp; + } + + CreateResponse BuildCreateResponse(const common::ErrorCode code, const std::string &msg, + const std::string &instanceID) + { + CreateResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + rsp.set_instanceid(instanceID); + return rsp; + } + + NotifyRequest BuildNotifyRequest(const std::string &requestID, const common::ErrorCode code, const std::string &msg, + const std::vector> &returnObjects, + const std::vector &infos) + { + NotifyRequest req; + req.set_requestid(requestID); + req.set_code(code); + req.set_message(msg); + for (size_t i = 0; i < returnObjects.size(); i++) { + if (returnObjects[i]->data != nullptr && returnObjects[i]->data->IsNative() && + returnObjects[i]->putDone == false) { + auto smallObject = req.add_smallobjects(); + smallObject->set_id(returnObjects[i]->id); + smallObject->set_value(returnObjects[i]->data->ImmutableData(), returnObjects[i]->data->GetSize()); + } + } + for (size_t i = 0; i < infos.size(); i++) { + auto setInfo = req.add_stacktraceinfos(); + auto eles = infos[i].StackTraceElements(); + for (size_t j = 0; j < eles.size(); j++) { + auto stackTraceEle = setInfo->add_stacktraceelements(); + stackTraceEle->set_classname(eles[j].className); + stackTraceEle->set_methodname(eles[j].methodName); + stackTraceEle->set_filename(eles[j].fileName); + stackTraceEle->set_linenumber(eles[j].lineNumber); + for (auto &it : eles[j].extensions) { + stackTraceEle->mutable_extensions()->insert({it.first, it.second}); + } + } + setInfo->set_type(infos[i].Type()); + setInfo->set_message(infos[i].Message()); + } + return req; + } + + KillResponse BuildKillResponse(const common::ErrorCode code, const std::string &msg) + { + KillResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + return rsp; + } + + GetResponse BuildObjGetResponse(const common::ErrorCode code, const std::string &msg, + const std::vector> &values) + { + GetResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + for (const auto &value : values) { + rsp.add_buffers(value->ImmutableData(), value->GetSize()); + } + return rsp; + } + + PutResponse BuildObjPutResponse(const common::ErrorCode code, const std::string &msg) + { + PutResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + return rsp; + } + + KvSetResponse BuildKvSetResponse(const common::ErrorCode code, const std::string &msg) + { + KvSetResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + return rsp; + } + + KvMSetTxResponse BuildKvMSetTxResponse(const common::ErrorCode code, const std::string &msg) + { + KvMSetTxResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + return rsp; + } + + KvGetResponse BuildKvGetResponse(const common::ErrorCode code, const std::string &msg, + const std::vector> &values) + { + KvGetResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + for (const auto &value : values) { + rsp.add_values(value->ImmutableData(), value->GetSize()); + } + return rsp; + } + + KvDelResponse BuildKvDelResponse(const common::ErrorCode code, const std::string &msg, + const std::vector &failedKeys) + { + KvDelResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + for (const auto &key : failedKeys) { + rsp.add_failedkeys(key); + } + return rsp; + } + + IncreaseRefResponse BuildIncreaseRefResponse(const common::ErrorCode code, const std::string &msg, + const std::vector &failedObjectIds) + { + IncreaseRefResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + for (const auto &objId : failedObjectIds) { + rsp.add_failedobjectids(objId); + } + return rsp; + } + + DecreaseRefResponse BuildDecreaseRefResponse(const common::ErrorCode code, const std::string &msg, + const std::vector &failedObjectIds) + { + DecreaseRefResponse rsp; + rsp.set_code(code); + rsp.set_message(msg); + for (const auto &objId : failedObjectIds) { + rsp.add_failedobjectids(objId); + } + return rsp; + } +}; + +std::vector> BuildBuffers() +{ + std::vector> buffers; + const char *str = "value"; + size_t len = std::strlen(str); + auto buf = std::make_shared(len); + buf->MemoryCopy(str, len); + buffers.emplace_back(std::move(buf)); + return buffers; +} + +std::vector> BuildEmptyBuffer() +{ + size_t number = 1; + std::vector> buffers(number); + buffers[0] = std::make_shared(0); + return buffers; +} + +bool ValidBuffers(std::shared_ptr>> buffers) +{ + EXPECT_TRUE(buffers->size() == getSize); + const size_t firstIndex = 0; + auto buffersOne = BuildBuffers(); + for (const auto &sbuf : *buffers) { + if (buffersOne[firstIndex]->GetSize() != sbuf->GetSize()) { + return false; + } + if (!std::equal( + static_cast(buffersOne[firstIndex]->ImmutableData()), + static_cast(buffersOne[firstIndex]->ImmutableData()) + buffersOne[firstIndex]->GetSize(), + static_cast(sbuf->ImmutableData()))) { + return false; + } + } + return true; +} + +class GwClientTest : public testing::Test { +public: + GwClientTest(){ + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + }; + ~GwClientTest(){}; + void SetUp() override + { + auto librtCfg = std::make_shared(); + librtCfg->functionSystemIpAddr = ""; + librtCfg->functionSystemPort = 0; + librtCfg->functionIds[libruntime::LanguageType::Cpp] = + "12345678901234561234567890123456/0-function-function/$latest"; + FSIntfHandlers handlers; + auto httpClient = std::make_unique(); + httpClient->Init(ConnectionParam{librtCfg->functionSystemIpAddr, std::to_string(librtCfg->functionSystemPort)}); + auto security = std::make_shared(); + std::string accessKey = "access_key"; + SensitiveValue secretKey = std::string("secret_key"); + security->SetAKSKAndCredential(accessKey, secretKey); + gwClient_ = + std::make_shared(librtCfg->functionIds[libruntime::LanguageType::Cpp], handlers, security); + gwClient_->Init(std::move(httpClient), connectTimeoutS); + } + void TearDown() override + { + gwClient_.reset(); + } + std::shared_ptr gwClient_; +}; + +void ErrorMsgCheck(const ErrorInfo &err, ErrorCode code, const std::string &msg) +{ + EXPECT_EQ(err.Code(), code); + if (msg.empty()) { + return; + } + std::string::size_type code_idx = err.Msg().find(msg); + EXPECT_TRUE(code_idx != std::string::npos); +} + +TEST_F(GwClientTest, TestLease) +{ + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->Lease(), ErrorCode::ERR_INNER_COMMUNICATION, errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->Lease(), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->Lease(), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->Lease(), ErrorCode::ERR_OK, ""); + + ErrorMsgCheck(gwClient_->KeepLease(), ErrorCode::ERR_OK, ""); + + ErrorMsgCheck(gwClient_->Release(), ErrorCode::ERR_OK, ""); + + gwClient_->Stop(); +} + +TEST_F(GwClientTest, TestCreate) +{ + CreateRequest req; + req.set_requestid(g_requestId); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + gwClient_->CreateAsync( + req, [](const CreateResponse &rsp) { EXPECT_TRUE(rsp.code() == common::ERR_INNER_COMMUNICATION); }, + [](const NotifyRequest &req) { EXPECT_TRUE(0 == 1); }); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + gwClient_->CreateAsync( + req, [](const CreateResponse &rsp) { EXPECT_TRUE(rsp.code() == common::ERR_PARAM_INVALID); }, + [](const NotifyRequest &req) { EXPECT_TRUE(0 == 1); }); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + gwClient_->CreateAsync( + req, [](const CreateResponse &rsp) { EXPECT_TRUE(rsp.code() == common::ERR_INNER_SYSTEM_ERROR); }, + [](const NotifyRequest &req) { EXPECT_TRUE(0 == 1); }); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + gwClient_->CreateAsync( + req, [](const CreateResponse &rsp) { EXPECT_TRUE(rsp.code() == common::ERR_NONE); }, + [](const NotifyRequest &req) { EXPECT_TRUE(req.code() == common::ERR_NONE); }); +} + +TEST_F(GwClientTest, TestInvoke) +{ + InvokeRequest req; + req.set_requestid(g_requestId); + auto messageSpec = std::make_shared(std::move(req)); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + gwClient_->InvokeAsync(messageSpec, [](const NotifyRequest &req, const ErrorInfo &err) { + EXPECT_TRUE(req.code() == common::ERR_INNER_COMMUNICATION); + }); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + gwClient_->InvokeAsync(messageSpec, [](const NotifyRequest &req, const ErrorInfo &err) { + EXPECT_TRUE(req.code() == common::ERR_PARAM_INVALID); + }); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + gwClient_->InvokeAsync(messageSpec, [](const NotifyRequest &req, const ErrorInfo &err) { + EXPECT_TRUE(req.code() == common::ERR_INNER_SYSTEM_ERROR); + }); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + gwClient_->InvokeAsync(messageSpec, [](const NotifyRequest &req, const ErrorInfo &err) { + EXPECT_TRUE(req.code() == common::ERR_NONE); + }); +} + +TEST_F(GwClientTest, TestInvocation) +{ + const std::string url = + "/serverless/v2/functions/" + "wisefunction:cn:iot:8d86c63b22e24d9ab650878b75408ea6:function:test-jiuwen-session-004-bj:$latest/invocations"; + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + auto spec = std::make_shared(); + spec->traceId = "traceId"; + spec->requestId = "requestId"; + std::string arg = "{\"key\":\"invoke\"}"; + InvokeArg libArg; + libArg.dataObj = std::make_shared(0, arg.size()); + libArg.dataObj->data->MemoryCopy(arg.data(), arg.size()); + spec->invokeArgs.emplace_back(std::move(libArg)); + gwClient_->InvocationAsync(url, spec, + [](const std::string &requestId, Libruntime::ErrorCode code, const std::string &result) { + EXPECT_TRUE(code == Libruntime::ErrorCode::ERR_INNER_COMMUNICATION); + }); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + gwClient_->InvocationAsync(url, spec, + [](const std::string &requestId, Libruntime::ErrorCode code, const std::string &result) { + EXPECT_TRUE(code == Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR); + }); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + gwClient_->InvocationAsync(url, spec, + [](const std::string &requestId, Libruntime::ErrorCode code, const std::string &result) { + EXPECT_TRUE(code == Libruntime::ErrorCode::ERR_OK); + }); +} + +TEST_F(GwClientTest, TestKill) +{ + KillRequest req; + req.set_instanceid(g_instanceId); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + gwClient_->KillAsync(req, [](const KillResponse &rsp, const ErrorInfo &err) { + EXPECT_TRUE(rsp.code() == common::ERR_INNER_COMMUNICATION); + }); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + gwClient_->KillAsync(req, [](const KillResponse &rsp, const ErrorInfo &err) { + EXPECT_TRUE(rsp.code() == common::ERR_PARAM_INVALID); + }); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + gwClient_->KillAsync(req, [](const KillResponse &rsp, const ErrorInfo &err) { + EXPECT_TRUE(rsp.code() == common::ERR_INNER_SYSTEM_ERROR); + }); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + gwClient_->KillAsync( + req, [](const KillResponse &rsp, const ErrorInfo &err) { EXPECT_TRUE(rsp.code() == common::ERR_NONE); }); +} + +TEST_F(GwClientTest, TestPosixObjGet) +{ + std::vector objIds = {"obj"}; + auto result = std::make_shared>>(); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->PosixObjGet(objIds, result, testTimeoutMs), ErrorCode::ERR_INNER_COMMUNICATION, + errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->PosixObjGet(objIds, result, testTimeoutMs), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + result->resize(getSize); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->PosixObjGet(objIds, result, testTimeoutMs), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + result->resize(getSize); + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->PosixObjGet(objIds, result, testTimeoutMs), ErrorCode::ERR_OK, ""); + EXPECT_TRUE(ValidBuffers(result)); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_TIMEOUT]).OK()); + ErrorMsgCheck(gwClient_->PosixObjGet(objIds, result, testTimeoutMs), ErrorCode::ERR_INNER_COMMUNICATION, + timeoutErr); +} + +TEST_F(GwClientTest, TestObjGet) +{ + std::string obj = "obj"; + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->Get(obj, testTimeoutMs).first, ErrorCode::ERR_INNER_COMMUNICATION, errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->Get(obj, testTimeoutMs).first, ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->Get(obj, testTimeoutMs).first, ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->Get(obj, testTimeoutMs).first, ErrorCode::ERR_OK, ""); + ErrorMsgCheck(gwClient_->GetBuffers({obj}, testTimeoutMs).first, ErrorCode::ERR_OK, ""); + ErrorMsgCheck(gwClient_->GetBuffersWithoutRetry({obj}, testTimeoutMs).first.errorInfo, ErrorCode::ERR_OK, ""); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_TIMEOUT]).OK()); + ErrorMsgCheck(gwClient_->Get(obj, testTimeoutMs).first, ErrorCode::ERR_INNER_COMMUNICATION, timeoutErr); +} + +TEST_F(GwClientTest, TestPosixObjPut) +{ + PutRequest req; + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->PosixObjPut(req), ErrorCode::ERR_INNER_COMMUNICATION, errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->PosixObjPut(req), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->PosixObjPut(req), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->PosixObjPut(req), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestObjPut) +{ + auto result = std::make_shared(1); + std::string objID = "obj_id"; + std::unordered_set nestedID = {"nested_id"}; + CreateParam param; + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->Put(result, objID, nestedID, param), ErrorCode::ERR_INNER_COMMUNICATION, + errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->Put(result, objID, nestedID, param), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->Put(result, objID, nestedID, param), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->Put(result, objID, nestedID, param), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestPosixKvGet) +{ + std::vector keys = {"key"}; + auto result = std::make_shared>>(); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->PosixKvGet(keys, result, testTimeoutMs), ErrorCode::ERR_INNER_COMMUNICATION, + errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->PosixKvGet(keys, result, testTimeoutMs), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + result->resize(getSize); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->PosixKvGet(keys, result, testTimeoutMs), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + result->resize(getSize); + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->PosixKvGet(keys, result, testTimeoutMs), ErrorCode::ERR_OK, ""); + EXPECT_TRUE(ValidBuffers(result)); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_TIMEOUT]).OK()); + ErrorMsgCheck(gwClient_->PosixKvGet(keys, result, testTimeoutMs), ErrorCode::ERR_INNER_COMMUNICATION, timeoutErr); +} + +TEST_F(GwClientTest, TestKVGetRead) +{ + std::vector keys = {"key"}; + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->Read(keys, testTimeoutMs, true).second, ErrorCode::ERR_INNER_COMMUNICATION, + errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->Read(keys, testTimeoutMs, true).second, ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->Read(keys, testTimeoutMs, true).second, ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->Read(keys, testTimeoutMs, false).second, ErrorCode::ERR_OK, ""); + ErrorMsgCheck(gwClient_->Read("key", testTimeoutMs).second, ErrorCode::ERR_OK, ""); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_TIMEOUT]).OK()); + ErrorMsgCheck(gwClient_->Read(keys, testTimeoutMs, true).second, ErrorCode::ERR_INNER_COMMUNICATION, timeoutErr); +} + +TEST_F(GwClientTest, TestKvSet) +{ + std::string key = ""; + auto value = std::make_shared(0); + SetParam setParam; + auto result = std::make_shared>>(); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->PosixKvSet(key, value, setParam), ErrorCode::ERR_INNER_COMMUNICATION, + errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->PosixKvSet(key, value, setParam), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->PosixKvSet(key, value, setParam), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->PosixKvSet(key, value, setParam), ErrorCode::ERR_OK, ""); + ErrorMsgCheck(gwClient_->Write(key, value, setParam), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestMKvSetTx) +{ + std::vector keys = {"", ""}; + auto value1 = std::make_shared(0); + auto value2 = std::make_shared(1); + std::vector> vals = {value1, value2}; + MSetParam mSetParam; + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->PosixKvMSetTx(keys, vals, mSetParam), ErrorCode::ERR_INNER_COMMUNICATION, + errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->PosixKvMSetTx(keys, vals, mSetParam), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->PosixKvMSetTx(keys, vals, mSetParam), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->PosixKvMSetTx(keys, vals, mSetParam), ErrorCode::ERR_OK, ""); + ErrorMsgCheck(gwClient_->MSetTx(keys, vals, mSetParam), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestPosixKvDel) +{ + std::vector keys = {}; + auto failedKeys = std::make_shared>(); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->PosixKvDel(keys, failedKeys), ErrorCode::ERR_INNER_COMMUNICATION, errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->PosixKvDel(keys, failedKeys), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->PosixKvDel(keys, failedKeys), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->PosixKvDel(keys, failedKeys), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestDel) +{ + std::string key = "key"; + auto failedKeys = std::make_shared>(); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->Del(key), ErrorCode::ERR_INNER_COMMUNICATION, errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->Del(key), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->Del(key), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->Del(key), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestPosixIncreaseRef) +{ + std::vector objIds = {}; + auto failedObjIds = std::make_shared>(); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->PosixGInCreaseRef(objIds, failedObjIds), ErrorCode::ERR_INNER_COMMUNICATION, + errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->PosixGInCreaseRef(objIds, failedObjIds), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->PosixGInCreaseRef(objIds, failedObjIds), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->PosixGInCreaseRef(objIds, failedObjIds), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestIncreaseRef) +{ + std::vector objIds = {}; + auto failedObjIds = std::make_shared>(); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->IncreGlobalReference(objIds), ErrorCode::ERR_INNER_COMMUNICATION, errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->IncreGlobalReference(objIds), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->IncreGlobalReference(objIds), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->IncreGlobalReference(objIds), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestPosixDecreaseRef) +{ + std::vector objIds = {}; + auto failedObjIds = std::make_shared>(); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->PosixGDecreaseRef(objIds, failedObjIds), ErrorCode::ERR_INNER_COMMUNICATION, + errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(gwClient_->PosixGDecreaseRef(objIds, failedObjIds), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(gwClient_->PosixGDecreaseRef(objIds, failedObjIds), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(gwClient_->PosixGDecreaseRef(objIds, failedObjIds), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestDecreaseRef) +{ + std::vector objIds = {"objId"}; + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + gwClient_->SetTenantId("tenantId"); + ErrorMsgCheck(gwClient_->DecreGlobalReference(objIds), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestUpdateToken) +{ + datasystem::SensitiveValue token("token"); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->UpdateToken(token), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestUpdateAkSk) +{ + std::string ak = "ak"; + datasystem::SensitiveValue sk("sk"); + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->UpdateAkSk(ak, sk), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestGenerateKey) +{ + std::string key; + std::string prefix = "591672113dc36b6a0000"; + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->GenerateKey(key, prefix, false), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestGetPrefix) +{ + std::string key = "591672113dc36b6a0000"; + std::string prefix; + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(gwClient_->GetPrefix(key, prefix), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestParseObjGetResponse) +{ + size_t number = 1; + auto result = std::make_shared>>(); + result->resize(number); + auto buffers = BuildEmptyBuffer(); + auto client = MockHttpClient(); + auto code = common::ERR_NONE; + std::string msg = ""; + auto rsp = client.BuildObjGetResponse(code, msg, buffers); + std::string rspBody; + rsp.SerializeToString(&rspBody); + gwClient_->ParseObjGetResponse(rspBody, result); + EXPECT_TRUE((*result)[0] == nullptr); +} + +TEST_F(GwClientTest, TestParseKvGetResponse) +{ + size_t number = 1; + auto result = std::make_shared>>(); + result->resize(number); + auto buffers = BuildEmptyBuffer(); + auto client = MockHttpClient(); + auto code = common::ERR_NONE; + std::string msg = ""; + auto rsp = client.BuildKvGetResponse(code, msg, buffers); + std::string rspBody; + rsp.SerializeToString(&rspBody); + gwClient_->ParseKvGetResponse(rspBody, result); + EXPECT_TRUE((*result)[0] == nullptr); +} + +TEST_F(GwClientTest, When_HttpClient_Connecting_Do_CreateBuffer_Should_Return_OK_Test) +{ + std::string objectId = "111"; + std::shared_ptr dataBuf; + CreateParam param; + param.cacheType = CacheType::DISK; + param.consistencyType = ConsistencyType::PRAM; + param.writeMode = WriteMode::NONE_L2_CACHE_EVICT; + EXPECT_TRUE(gwClient_->CreateBuffer(objectId, 10, dataBuf, param).OK()); + std::unordered_set nestedIds{"222"}; + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_ERROR_COMMUNICATION]).OK()); + ErrorMsgCheck(dataBuf->Seal(nestedIds), ErrorCode::ERR_INNER_COMMUNICATION, errorCommunicationMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_BAD_REQUEST]).OK()); + ErrorMsgCheck(dataBuf->Seal(nestedIds), ErrorCode::ERR_PARAM_INVALID, badRequestMsg); + + EXPECT_FALSE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_BUT_FAILED]).OK()); + ErrorMsgCheck(dataBuf->Seal(nestedIds), ErrorCode::ERR_INNER_SYSTEM_ERROR, failedMsg); + + EXPECT_TRUE(gwClient_->Start(responseTypeMap[ResponseType::HTTP_OK_AND_SUCCESS]).OK()); + ErrorMsgCheck(dataBuf->Seal(nestedIds), ErrorCode::ERR_OK, ""); +} + +TEST_F(GwClientTest, TestBuildHeader) +{ + auto header = gwClient_->BuildHeaders("1", "2", "3"); + auto c = header.at("tenantId"); + EXPECT_TRUE(c == "3"); +} + +TEST_F(GwClientTest, TestCreateStreamConsumerWillReturnError) +{ + auto c = std::make_shared(); + auto err = gwClient_->CreateStreamConsumer("stream", SubscriptionConfig{}, c); + ASSERT_EQ(err.Code(), ERR_INNER_SYSTEM_ERROR); +} + +TEST_F(GwClientTest, TestQueryGlobalReferenceWillThrowLibRuntimeException) +{ + try { + gwClient_->QueryGlobalReference(std::vector()); + ASSERT_TRUE(false); + } catch (YR::Libruntime::Exception &e) { + ASSERT_EQ(e.Code(), ERR_INNER_SYSTEM_ERROR); + ASSERT_EQ(e.Msg(), "QueryGlobalReference method is not supported when inCluster is false"); + } +} + +TEST_F(GwClientTest, TestUnsupportedReq1) +{ + ErrorInfo err; + datasystem::ConnectOptions options; + DsConnectOptions connOptions; + CreateResourceGroupRequest req; + std::vector objIds; + std::vector outSizes; + std::shared_ptr producer; + std::string returnKey; + uint64_t gNum; + try { + gwClient_->CallResultAsync(nullptr, nullptr); + ASSERT_TRUE(false); + } catch (YR::Libruntime::Exception &e) { + ASSERT_EQ(e.Code(), ERR_INNER_SYSTEM_ERROR); + } + try { + gwClient_->CreateRGroupAsync(req, nullptr); + ASSERT_TRUE(false); + } catch (YR::Libruntime::Exception &e) { + ASSERT_EQ(e.Code(), ERR_INNER_SYSTEM_ERROR); + } + + err = gwClient_->Init("127.0.0.1", 11111, false, false, "runtimePublicKey", + datasystem::SensitiveValue("runtimePrivateKey"), "dsPublicKey", + datasystem::SensitiveValue("token"), "ak", datasystem::SensitiveValue("sk")); + ASSERT_EQ(err.Code(), ERR_INNER_SYSTEM_ERROR); + + err = gwClient_->Init(options); + ASSERT_EQ(err.Code(), ERR_INNER_SYSTEM_ERROR); + + try { + gwClient_->QuerySize(objIds, outSizes); + ASSERT_TRUE(false); + } catch (YR::Libruntime::Exception &e) { + ASSERT_EQ(e.Code(), ERR_INNER_SYSTEM_ERROR); + } + + gwClient_->Shutdown(); + err = gwClient_->Init(connOptions); + ASSERT_EQ(err.Code(), ERR_INNER_SYSTEM_ERROR); + + err = gwClient_->GenerateKey(returnKey); + ASSERT_EQ(err.Code(), ERR_INNER_SYSTEM_ERROR); + + err = gwClient_->CreateStreamProducer("streamName", producer); + ASSERT_EQ(err.Code(), ERR_INNER_SYSTEM_ERROR); + + err = gwClient_->DeleteStream("streamName"); + ASSERT_EQ(err.Code(), ERR_INNER_SYSTEM_ERROR); + + err = gwClient_->QueryGlobalProducersNum("streamName", gNum); + ASSERT_EQ(err.Code(), ERR_INNER_SYSTEM_ERROR); + + err = gwClient_->QueryGlobalConsumersNum("streamName", gNum); + ASSERT_EQ(err.Code(), ERR_INNER_SYSTEM_ERROR); +} + +TEST_F(GwClientTest, TestUnsupportedReq3) +{ + ErrorInfo err; + ExitRequest exitReq; + StateSaveRequest saveStateReq; + StateLoadRequest loadStateReq; + try { + gwClient_->ExitAsync(exitReq, [](const ExitResponse &) {}); + ASSERT_TRUE(false); + } catch (Exception &e) { + std::string errMsg = e.what(); + ASSERT_EQ(errMsg, + "ErrCode: 3003, ModuleCode: 20, ErrMsg: ExitAsync method not implemented when inCluster is false"); + } + try { + gwClient_->StateSaveAsync(saveStateReq, [](const StateSaveResponse) {}); + ASSERT_TRUE(false); + } catch (Exception &e) { + std::string errMsg = e.what(); + ASSERT_EQ( + errMsg, + "ErrCode: 3003, ModuleCode: 20, ErrMsg: StateSaveAsync method not implemented when inCluster is false"); + } + try { + gwClient_->StateLoadAsync(loadStateReq, [](const StateLoadResponse &) {}); + ASSERT_TRUE(false); + } catch (Exception &e) { + std::string errMsg = e.what(); + ASSERT_EQ(errMsg, "ErrCode: 3003, ModuleCode: 20, ErrMsg: StateLoadAsync is not supported with gateway client"); + } + + err = gwClient_->ReleaseGRefs("remoteId"); + ASSERT_EQ(err.Code(), ERR_PARAM_INVALID); +} + +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/libruntime/hash_util_test.cpp b/test/libruntime/hash_util_test.cpp new file mode 100644 index 0000000..f2da722 --- /dev/null +++ b/test/libruntime/hash_util_test.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/libruntime/utils/hash_utils.h" + +namespace YR { +namespace test { +using namespace YR::Libruntime; +class HashUtilTest : public ::testing::Test { +protected: + void SetUp() override {} + + void TearDown() override {} +}; + +TEST_F(HashUtilTest, CorrectHMACSHA256Test2) +{ + SensitiveValue key = std::string("secret"); + std::string worldSha256 = GetHMACSha256(key, "Hello, World!"); + EXPECT_EQ(worldSha256, "fcfaffa7fef86515c7beb6b62d779fa4ccf092f2e61c164376054271252821ff"); +} + +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/libruntime/hetero_future_test.cpp b/test/libruntime/hetero_future_test.cpp index b0ac215..ead0fab 100644 --- a/test/libruntime/hetero_future_test.cpp +++ b/test/libruntime/hetero_future_test.cpp @@ -30,6 +30,7 @@ using namespace YR::utility; namespace YR { namespace Libruntime { YR::Libruntime::AsyncResult ConverDsStatusToAsyncRes(datasystem::Status dsStatus); +YR::Libruntime::AsyncResult ConverDsAsyncResultToLib(datasystem::AsyncResult dsResult); } // namespace Libruntime } // namespace YR @@ -65,6 +66,28 @@ public: std::shared_ptr heteroFuture_; }; +TEST_F(HeteroFutureTest, TestSharedFutureGet) +{ + std::promise promise; + auto future = promise.get_future().share(); + heteroFuture_ = std::make_shared(std::make_shared>(future)); + ASSERT_EQ(heteroFuture_->IsDsFuture(), false); + promise.set_value(datasystem::AsyncResult()); + YR::Libruntime::AsyncResult result = heteroFuture_->Get(); + ASSERT_EQ(result.error.OK(), true); +} + +TEST_F(HeteroFutureTest, TestConverDsAsyncResultToLib) +{ + datasystem::AsyncResult dsResult; + auto result_1 = ConverDsAsyncResultToLib(dsResult); + ASSERT_EQ(result_1.error.OK(), true); + datasystem::Status status(datasystem::StatusCode::K_DUPLICATED, "err"); + dsResult.status = status; + auto result_2 = ConverDsAsyncResultToLib(dsResult); + ASSERT_EQ(result_2.error.Code(), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID); +} + TEST_F(HeteroFutureTest, TestConverDsStatusToAsyncRes) { datasystem::Status dsStatus1; diff --git a/test/libruntime/hetero_store_test.cpp b/test/libruntime/hetero_store_test.cpp index d50d9ed..267cf41 100644 --- a/test/libruntime/hetero_store_test.cpp +++ b/test/libruntime/hetero_store_test.cpp @@ -80,19 +80,19 @@ TEST_F(HeteroStoreTest, ShutdownTest) EXPECT_NO_THROW(heteroStore_->Shutdown()); } -TEST_F(HeteroStoreTest, DeleteTest) +TEST_F(HeteroStoreTest, DevDeleteTest) { std::vector objIds = {"obj1", "obj2"}; std::vector failedObjectIds; - auto err = heteroStore_->Delete(objIds, failedObjectIds); + auto err = heteroStore_->DevDelete(objIds, failedObjectIds); ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); } -TEST_F(HeteroStoreTest, LocalDeleteTest) +TEST_F(HeteroStoreTest, DevLocalDeleteTest) { std::vector objIds = {"obj1", "obj2"}; std::vector failedObjectIds; - auto err = heteroStore_->LocalDelete(objIds, failedObjectIds); + auto err = heteroStore_->DevLocalDelete(objIds, failedObjectIds); ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); } diff --git a/test/libruntime/http_utils_test.cpp b/test/libruntime/http_utils_test.cpp new file mode 100644 index 0000000..0f0d022 --- /dev/null +++ b/test/libruntime/http_utils_test.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/libruntime/utils/http_utils.h" +#include +#include + +namespace YR { +namespace test { +using namespace YR::utility; +using namespace YR::Libruntime; +class HttpUtilTest : public ::testing::Test { +protected: + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + } + + void TearDown() override {} +}; + +TEST_F(HttpUtilTest, SignHttpRequest) +{ + std::string url = + "/serverless/v2/functions/" + "wisefunction:cn:iot:8d86c63b22e24d9ab650878b75408ea6:function:0@faas@python:latest/invocations"; + std::string accessKey = "access_key"; + SensitiveValue secretKey = std::string("secret_key"); + std::unordered_map headers; + headers[TRACE_ID_KEY_NEW] = "traceId"; + headers[INSTANCE_CPU_KEY] = "500"; + headers[INSTANCE_MEMORY_KEY] = "300"; + std::string body = "123"; + SignHttpRequest(accessKey, secretKey, headers, body, url); + YRLOG_DEBUG(headers[AUTHORIZATION_KEY]); + ASSERT_FALSE(headers[AUTHORIZATION_KEY].empty()); + bool ok = VerifyHttpRequest(accessKey, secretKey, headers, body, url); + ASSERT_TRUE(ok); + auto digest = GenerateRequestDigest(headers, body, url); + YRLOG_DEBUG(digest); + ASSERT_TRUE(!digest.empty()); +} +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/libruntime/https_client_test.cpp b/test/libruntime/https_client_test.cpp new file mode 100644 index 0000000..d190a28 --- /dev/null +++ b/test/libruntime/https_client_test.cpp @@ -0,0 +1,374 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#define private public +#include "httpserver/async_https_server.h" +#include "src/libruntime/gwclient/http/client_manager.h" +#include "src/libruntime/gwclient/http/http_client.h" +namespace YR { +namespace test { +using namespace YR::Libruntime; + +class HttpsClientTest : public ::testing::Test { +public: + HttpsClientTest() {} + ~HttpsClientTest() {} + void SetUp() override + { + httpsServer_ = std::make_shared(); + YR::utility::LogParam logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-https", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + YR::utility::InitLog(logParam); + } + void TearDown() override + { + if (httpsServer_) { + httpsServer_.reset(); + } + } + +private: + std::shared_ptr httpsServer_; + std::string ip_ = "127.0.0.1"; + unsigned short port_ = 12346; + int threadNum = 8; +}; + +std::shared_ptr ConstructLibruntimeConfig() +{ + std::shared_ptr librtCfg = std::make_shared(); + librtCfg->enableMTLS = true; + librtCfg->verifyFilePath = "./test/data/cert/ca.crt"; + librtCfg->certificateFilePath = "./test/data/cert/client.crt"; + std::strcpy(librtCfg->privateKeyPaaswd, "test"); + librtCfg->privateKeyPath = "./test/data/cert/client.key"; + // The serverName is not verified. + librtCfg->serverName = "test"; + return librtCfg; +} + +std::shared_ptr ConstructSslContext() +{ + try { + auto ctx = std::make_shared(ssl::context::tlsv12); + ctx->set_options(boost::asio::ssl::context::default_workarounds | boost::asio::ssl::context::no_sslv2); + ctx->load_verify_file("./test/data/cert/ca.crt"); + ctx->use_certificate_chain_file("./test/data/cert/server.crt"); + ctx->set_password_callback( + [](std::size_t max_length, ssl::context_base::password_purpose purpose) { return "test"; }); + ctx->use_private_key_file("./test/data/cert/server.key", ssl::context::pem); + return ctx; + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return nullptr; + } +} + +TEST_F(HttpsClientTest, InitFailed) +{ + auto librtCfg = ConstructLibruntimeConfig(); + auto httpClient = std::make_unique(librtCfg); + auto err = httpClient->Init({"127.0.0.1.0", "12346"}); + ASSERT_EQ(err.OK(), false); +} + +TEST_F(HttpsClientTest, SubmitTask) +{ + auto ctx = ConstructSslContext(); + ASSERT_TRUE(ctx != nullptr); + if (httpsServer_->StartServer(ip_, port_, threadNum, ctx)) { + std::cout << "start https server success" << std::endl; + } else { + std::cout << "start https server failed" << std::endl; + } + auto librtCfg = ConstructLibruntimeConfig(); + librtCfg->httpIocThreadsNum = 5; + auto httpClient = std::make_unique(librtCfg); + auto err = httpClient->Init({"127.0.0.1", "12346"}); + ASSERT_EQ(err.OK(), true); + + std::unordered_map headers; + headers.emplace("type", "test"); + std::string urn = "/test"; + auto retPromise = std::make_shared>(); + auto future = retPromise->get_future(); + auto requestId = std::make_shared("requestID"); + httpClient->SubmitInvokeRequest( + GET, urn, headers, "", requestId, + [retPromise](const std::string &result, const boost::beast::error_code &errorCode, const uint statusCode) { + if (errorCode) { + std::cerr << "network error, error_code: " << errorCode.message() << std::endl; + } else { + retPromise->set_value(result); + } + }); + ASSERT_EQ("ok", future.get()); + httpsServer_->StopServer(); +} + +/*case + * @title: Server故障恢复后发送请求成功 + * @precondition: + * @step: 1. 启动HttpsServer + * @step: 2. 创建客户端连接 + * @step: 3. 停止HttpsServer + * @step: 4. 恢复HttpsServer + * @step: 5. 发送https请求 + * @expect: 1.用例不卡住 + */ +TEST_F(HttpsClientTest, after_https_server_recover_request_should_return) +{ + auto ctx = ConstructSslContext(); + ASSERT_TRUE(ctx != nullptr); + if (httpsServer_->StartServer(ip_, port_, threadNum, ctx)) { + std::cout << "start https server success" << std::endl; + } else { + std::cout << "start https server failed" << std::endl; + } + auto librtCfg = ConstructLibruntimeConfig(); + librtCfg->httpIocThreadsNum = 5; + auto httpClient = std::make_unique(librtCfg); + auto err = httpClient->Init({"127.0.0.1", "12346"}); + ASSERT_EQ(err.OK(), true); + httpsServer_->StopServer(); + if (httpsServer_->StartServer(ip_, port_, threadNum, ctx)) { + std::cout << "start https server success" << std::endl; + } else { + std::cout << "start https server failed" << std::endl; + } + std::unordered_map headers; + headers.emplace("type", "test"); + std::string urn = "/test"; + int num = 0; + auto promise = std::make_shared>(); + auto future = promise->get_future(); + + auto requestId = std::make_shared("requestID"); + std::mutex mtx; + int sendTimes = 10; + auto sendMsgHandler = [&]() { + httpClient->SubmitInvokeRequest( + GET, urn, headers, "", requestId, + [&](const std::string &result, const boost::beast::error_code &errorCode, const uint statusCode) { + if (errorCode) { + std::cerr << "network error, error_code: " << errorCode.message() << std::endl; + } else { + std::cout << "request success" << std::endl; + } + std::lock_guard lockGuard(mtx); + std::cout << "num: " << num << std::endl; + num++; + if (num == sendTimes) { + promise->set_value(num); + } + }); + }; + for (int i = 0; i < sendTimes; i++) { + sendMsgHandler(); + } + ASSERT_EQ(sendTimes, future.get()); + httpsServer_->StopServer(); +} + +TEST_F(HttpsClientTest, after_reinit_should_not_coredump) +{ + auto ctx = ConstructSslContext(); + ASSERT_TRUE(ctx != nullptr); + if (httpsServer_->StartServer(ip_, port_, threadNum, ctx)) { + std::cout << "start https server success" << std::endl; + } else { + std::cout << "start https server failed" << std::endl; + } + auto librtCfg = ConstructLibruntimeConfig(); + librtCfg->httpIocThreadsNum = 5; + // Change the value of idleTime to a smaller value to trigger reinit. + librtCfg->httpIdleTime = 1; + auto httpClient = std::make_unique(librtCfg); + auto err = httpClient->Init({"127.0.0.1", "12346"}); + std::unordered_map headers; + headers.emplace("type", "test"); + std::string urn = "/test"; + for (int i = 0; i < 3; i++) { + int num = 0; + auto promise = std::make_shared>(); + auto future = promise->get_future(); + auto requestId = std::make_shared("requestID"); + std::mutex mtx; + int sendTimes = 10; + auto sendMsgHandler = [&]() { + httpClient->SubmitInvokeRequest( + GET, urn, headers, "", requestId, + [&](const std::string &result, const boost::beast::error_code &errorCode, const uint statusCode) { + std::lock_guard lockGuard(mtx); + std::cout << "num: " << num << std::endl; + num++; + if (num == sendTimes) { + promise->set_value(num); + } + }); + }; + for (int j = 0; j < sendTimes; j++) { + sendMsgHandler(); + } + ASSERT_EQ(sendTimes, future.get()); + sleep(1); + } + httpsServer_->StopServer(); +} + +TEST_F(HttpsClientTest, trigger_init_more_client_should_not_coredump) +{ + auto ctx = ConstructSslContext(); + ASSERT_TRUE(ctx != nullptr); + if (httpsServer_->StartServer(ip_, port_, threadNum, ctx)) { + std::cout << "start https server success" << std::endl; + } else { + std::cout << "start https server failed" << std::endl; + } + auto librtCfg = ConstructLibruntimeConfig(); + librtCfg->httpIocThreadsNum = 20; + auto httpClient = std::make_unique(librtCfg); + auto err = httpClient->Init({"127.0.0.1", "12346"}); + std::unordered_map headers; + headers.emplace("type", "test"); + std::string urn = "/test"; + auto promise = std::make_shared>(); + auto future = promise->get_future(); + auto sharedFuture = future.share(); + auto requestId = std::make_shared("requestID"); + // default http init is 10, sendTimes = 15 trigger init more client + int sendTimes = 15; + std::atomic get{0}; + auto sendMsgHandler = [&]() { + httpClient->SubmitInvokeRequest( + GET, urn, headers, "", requestId, + [&](const std::string &result, const boost::beast::error_code &errorCode, const uint statusCode) { + sharedFuture.get(); + get++; + }); + }; + for (int j = 0; j < sendTimes; j++) { + sendMsgHandler(); + } + promise->set_value(1); + for (;;) { + if (get == sendTimes) { + ASSERT_EQ(get, sendTimes); + break; + } + std::this_thread::yield(); + } + EXPECT_NO_THROW(httpsServer_->StopServer()); +} + +// Manual test case requires using tc to construct network latency. +TEST_F(HttpsClientTest, DISABLED_Reinit_Retry_When_Network_latency) +{ + auto ctx = ConstructSslContext(); + ASSERT_TRUE(ctx != nullptr); + if (httpsServer_->StartServer(ip_, port_, threadNum, ctx)) { + std::cout << "start https server success" << std::endl; + } else { + std::cout << "start https server failed" << std::endl; + } + auto librtCfg = ConstructLibruntimeConfig(); + librtCfg->httpIocThreadsNum = 5; + // Change the value of idleTime to a smaller value to trigger reinit. + librtCfg->httpIdleTime = 1; + auto httpClient = std::make_unique(librtCfg); + auto err = httpClient->Init({"127.0.0.1", "12346", 1}); + std::unordered_map headers; + headers.emplace("type", "test"); + std::string urn = "/test"; + + std::promise waitPromise; + auto waitFuture = waitPromise.get_future(); + waitFuture.wait_for(std::chrono::seconds(10)); + // tc qdisc add dev eth0 root netem delay 3s + + auto requestId = std::make_shared("requestID"); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + httpClient->SubmitInvokeRequest( + GET, urn, headers, "", requestId, + [&](const std::string &result, const boost::beast::error_code &errorCode, const uint statusCode) { + if (errorCode) { + YRLOG_DEBUG("error {}", errorCode.to_string()); + promise->set_value(0); + } else { + YRLOG_DEBUG("response ok"); + promise->set_value(1); + } + }); + ASSERT_EQ(1, future.get()); + httpsServer_->StopServer(); +} + +TEST_F(HttpsClientTest, DISABLED_TLSVerify) +{ + auto ctx = ConstructSslContext(); + ASSERT_TRUE(ctx != nullptr); + if (httpsServer_->StartServer(ip_, port_, threadNum, ctx)) { + std::cout << "start https server success" << std::endl; + } else { + std::cout << "start https server failed" << std::endl; + } + auto librtCfg = std::make_shared(); + librtCfg->httpIocThreadsNum = 5; + librtCfg->enableTLS = true; + librtCfg->serverName = "test"; + librtCfg->verifyFilePath = "./test/data/cert/ca.crt"; + auto httpClient = std::make_unique(librtCfg); + auto err = httpClient->Init({"127.0.0.1", "12346"}); + ASSERT_EQ(err.OK(), true); + + std::unordered_map headers; + headers.emplace("type", "test"); + std::string urn = "/test"; + auto retPromise = std::make_shared>(); + auto future = retPromise->get_future(); + auto requestId = std::make_shared("requestID"); + httpClient->SubmitInvokeRequest( + GET, urn, headers, "", requestId, + [retPromise](const std::string &result, const boost::beast::error_code &errorCode, const uint statusCode) { + if (errorCode) { + std::cerr << "network error, error_code: " << errorCode.message() << std::endl; + } else { + retPromise->set_value(result); + } + }); + ASSERT_EQ("ok", future.get()); + httpsServer_->StopServer(); +} +} // namespace test +} // namespace YR diff --git a/test/libruntime/instance_manager_test.cpp b/test/libruntime/instance_manager_test.cpp index 5fb52f7..5636f0d 100644 --- a/test/libruntime/instance_manager_test.cpp +++ b/test/libruntime/instance_manager_test.cpp @@ -58,7 +58,7 @@ public: cb = {}; auto reqMgr = std::make_shared(); auto librtCfg = std::make_shared(); - auto mockFsIntf = std::make_unique(); + auto mockFsIntf = std::make_unique(); auto fsClient = std::make_shared(std::move(mockFsIntf)); std::shared_ptr memoryStore = std::make_shared(); auto dsObjectStore = std::make_shared(); @@ -87,7 +87,7 @@ TEST_F(NormalInstanceManagerTest, ScheduleInsTest) "", "", "funcname", "classname", libruntime::LanguageType::Cpp, "", "", "", libruntime::ApiType::Function}; spec->opts = {}; auto resource = GetRequestResource(spec); - auto [insId, leaseId] = insManager->ScheduleIns(resource); + auto [insId, leaseId] = insManager->GetAvailableIns(resource); ASSERT_EQ(insId.empty(), true); std::unordered_map> instanceInfos; @@ -101,7 +101,7 @@ TEST_F(NormalInstanceManagerTest, ScheduleInsTest) requestResourceInfo->instanceInfos = instanceInfos; requestResourceInfo->avaliableInstanceInfos = instanceInfos; insManager->requestResourceInfoMap[resource] = requestResourceInfo; - auto [id, lId] = insManager->ScheduleIns(resource); + auto [id, lId] = insManager->GetAvailableIns(resource); ASSERT_EQ(id.empty(), false); insManager->Stop(); } diff --git a/test/libruntime/invoke_adaptor_test.cpp b/test/libruntime/invoke_adaptor_test.cpp index 269992f..aaab6a4 100644 --- a/test/libruntime/invoke_adaptor_test.cpp +++ b/test/libruntime/invoke_adaptor_test.cpp @@ -32,6 +32,8 @@ #define private public #include "src/libruntime/err_type.h" +#include "src/libruntime/generator/stream_generator_notifier.h" +#include "src/libruntime/generator/stream_generator_receiver.h" #include "src/libruntime/groupmanager/group.h" #include "src/libruntime/invokeadaptor/invoke_adaptor.h" #include "src/libruntime/objectstore/datasystem_object_store.h" @@ -50,11 +52,10 @@ using json = nlohmann::json; namespace YR { namespace Libruntime { -bool ParseRequest(const CallRequest &request, std::vector> &rawArgs, - std::shared_ptr memStore, bool isPosix); bool ParseMetaData(const CallRequest &request, bool isPosix, libruntime::MetaData &metaData); bool ParseFunctionGroupRunningInfo(const CallRequest &request, bool isPosix, common::FunctionGroupRunningInfo &runningInfo); +bool ParseFaasController(const CallRequest &request, std::vector> &rawArgs, bool isPosix); } // namespace Libruntime namespace test { @@ -95,16 +96,22 @@ public: libConfig->inCluster = false; this->memoryStore->Init(dsObjectStore, wom); auto dependencyResolver = std::make_shared(this->memoryStore); - this->fsIntf = std::make_shared(); + this->gwClient = std::make_shared(); + this->taskSubmitter = std::make_shared(); auto cb = []() { return; }; auto clientsMgr = std::make_shared(); auto metricsAdaptor = std::make_shared(); - auto fsClient = std::make_shared(this->fsIntf); + auto fsClient = std::make_shared(this->gwClient); + auto mapper = std::make_shared(); + auto dsStreamStore = std::shared_ptr(); + auto generatorReceiver = std::make_shared(libConfig, dsStreamStore, this->memoryStore); + auto generatorNotifier = std::make_shared(dsStreamStore, mapper); auto rGroupManager = std::make_shared(); + auto security = std::make_shared(); this->invokeAdaptor = std::make_shared(libConfig, dependencyResolver, fsClient, this->memoryStore, runtimeContext, cb, nullptr, std::make_shared(), clientsMgr, - metricsAdaptor); + metricsAdaptor, mapper, generatorReceiver, generatorNotifier); invokeAdaptor->SetRGroupManager(rGroupManager); invokeAdaptor->SetCallbackOfSetTenantId([]() {}); this->invokeAdaptor->Init(*runtimeContext, nullptr); @@ -113,15 +120,16 @@ public: void TearDown() override { CloseGlobalTimer(); - this->fsIntf.reset(); + this->gwClient.reset(); this->libConfig.reset(); this->memoryStore.reset(); this->invokeAdaptor.reset(); } - std::shared_ptr fsIntf; + std::shared_ptr gwClient; std::shared_ptr memoryStore; std::shared_ptr invokeAdaptor; std::shared_ptr libConfig; + std::shared_ptr taskSubmitter; }; TEST_F(InvokeAdaptorTest, ParseInvokeRequestTest) @@ -139,7 +147,7 @@ TEST_F(InvokeAdaptorTest, ParseInvokeRequestTest) pbArg2->set_value(objId.c_str(), objId.size()); std::vector> rawArgs; - bool ok = YR::Libruntime::ParseRequest(request, rawArgs, memoryStore, false); + bool ok = invokeAdaptor->ParseRequest(request, rawArgs, false); ASSERT_EQ(ok, true); libruntime::MetaData metaData; ok = YR::Libruntime::ParseMetaData(request, false, metaData); @@ -149,7 +157,7 @@ TEST_F(InvokeAdaptorTest, ParseInvokeRequestTest) auto pbArg3 = request.add_args(); pbArg3->set_type(common::Arg::VALUE); pbArg3->set_value(objId.c_str(), objId.size()); - ok = YR::Libruntime::ParseRequest(request, rawArgs, memoryStore, false); + ok = invokeAdaptor->ParseRequest(request, rawArgs, false); ASSERT_EQ(ok, true); } @@ -168,7 +176,7 @@ TEST_F(InvokeAdaptorTest, ParseCreateRequestTest) pbArg2->set_value(objId.c_str(), objId.size()); std::vector> rawArgs; - bool ok = YR::Libruntime::ParseRequest(request, rawArgs, memoryStore, false); + bool ok = invokeAdaptor->ParseRequest(request, rawArgs, false); ASSERT_EQ(ok, true); libruntime::MetaData metaData; ok = YR::Libruntime::ParseMetaData(request, false, metaData); @@ -181,7 +189,7 @@ TEST_F(InvokeAdaptorTest, ParseYrCreateRequestTest) CallRequest request; request.set_iscreate(true); std::vector> rawArgs; - bool ok = YR::Libruntime::ParseRequest(request, rawArgs, memoryStore, true); + bool ok = invokeAdaptor->ParseRequest(request, rawArgs, true); ASSERT_EQ(ok, true); ASSERT_EQ(rawArgs.empty(), true); libruntime::MetaData metaData; @@ -195,7 +203,7 @@ TEST_F(InvokeAdaptorTest, ParseYrInvokeRequestTest) CallRequest request; request.set_iscreate(false); std::vector> rawArgs; - bool ok = YR::Libruntime::ParseRequest(request, rawArgs, memoryStore, true); + bool ok = invokeAdaptor->ParseRequest(request, rawArgs, true); ASSERT_EQ(ok, true); ASSERT_EQ(rawArgs.empty(), true); libruntime::MetaData metaData; @@ -206,6 +214,30 @@ TEST_F(InvokeAdaptorTest, ParseYrInvokeRequestTest) TEST_F(InvokeAdaptorTest, PrepareCallExecutorTest) { + auto memStore = std::make_shared(); + auto dependencyResolver = std::make_shared(memStore); + auto librtCfg = std::make_shared(); + auto runtimeContext = std::make_shared(YR::utility::IDGenerator::GenApplicationId()); + auto execMgr = std::make_shared(); + FSIntfHandlers handlers; + + auto gwClient = std::make_shared(librtCfg->functionIds[libruntime::LanguageType::Cpp], handlers); + auto cb = []() { return; }; + auto clientsMgr = std::make_shared(); + auto metricsAdaptor = std::make_shared(); + auto fsClient = std::make_shared(gwClient); + auto mapper = std::make_shared(); + auto dsStreamStore = std::shared_ptr(); + auto generatorReceiver = std::make_shared(librtCfg, dsStreamStore, memStore); + auto generatorNotifier = std::make_shared(dsStreamStore, mapper); + auto rGroupManager = std::make_shared(); + auto invokeAdaptor = std::make_shared(librtCfg, dependencyResolver, fsClient, memStore, + runtimeContext, cb, nullptr, execMgr, clientsMgr, + metricsAdaptor, mapper, generatorReceiver, generatorNotifier); + invokeAdaptor->SetRGroupManager(rGroupManager); + invokeAdaptor->SetCallbackOfSetTenantId([]() {}); + invokeAdaptor->Init(*runtimeContext, nullptr); + struct MyTests { int concurrency; common::ErrorCode errCode; @@ -244,6 +276,7 @@ TEST_F(InvokeAdaptorTest, CallTest) req.set_senderid("instance_id"); req.set_iscreate(true); + libruntime::MetaData metaData; auto result = invokeAdaptor->Call(req, metaData, options, objectsInDs); ASSERT_EQ(result.code(), ::common::ERR_NONE); @@ -266,6 +299,31 @@ TEST_F(InvokeAdaptorTest, CallTest) }; auto result2 = invokeAdaptor->Call(req, metaData, options, objectsInDs); ASSERT_EQ(result2.code(), ::common::ErrorCode::ERR_INNER_SYSTEM_ERROR); + + req.add_args(); + auto arg2 = req.add_args(); + arg2->set_type(Arg_ArgType::Arg_ArgType_VALUE); + arg2->set_value("{\"schedulerFuncK{\"schedulerFuncKey\":\"0/0-system-faasscheduler/$latest\",\"schedulerInstanceList\":[{\"instanceName\":\"abfe9e68-9221-4b97-8e85-87b5b5faf69c\",\"instanceId\":\"abfe9e68-9221-4b97-8e85-87b5b5faf69c\"},{\"instanceName\":\"2db4a71b-157c-4ec2-95d7-c70fccc85dfa\",\"instanceId\":\"abfe9e68-9221-4b97-8e85-87b5b5faf69c\"}]}"); + + options.functionExecuteCallback = [](const FunctionMeta &function, const libruntime::InvokeType invokeType, + const std::vector> &rawArgs, + std::vector> &returnValues) -> ErrorInfo { + return ErrorInfo(ErrorCode::ERR_OK, ModuleCode::RUNTIME, "test"); + }; + + std::vector schedulerInstanceList; + this->taskSubmitter = std::make_shared(); + EXPECT_CALL(*(this->taskSubmitter), UpdateFaaSSchedulerInfo(_, _)) + .WillOnce([=](std::string schedulerFuncKey, const std::vector &sInstanceList) { + ASSERT_EQ(sInstanceList.size(), 2); + return; + }); + invokeAdaptor->taskSubmitter = this->taskSubmitter; + libruntime::FunctionMeta functionMeta; + functionMeta.set_apitype(libruntime::ApiType::Faas); + metaData.set_allocated_functionmeta(&functionMeta); + invokeAdaptor->Call(req, metaData, options, objectsInDs); + metaData.release_functionmeta(); } TEST_F(InvokeAdaptorTest, InitCallTest) @@ -306,7 +364,7 @@ TEST_F(InvokeAdaptorTest, CreateInstanceTest) invokeSpec->returnIds = returnObjs; invokeSpec->BuildInstanceCreateRequest(cfg); invokeAdaptor->CreateInstance(invokeSpec); - auto [rawRequestId, seq] = YR::utility::IDGenerator::DecodeRawRequestId(invokeSpec->requestCreate.requestid()); + auto [rawRequestId, seq] =YR::utility::IDGenerator::DecodeRawRequestId(invokeSpec->requestCreate.requestid()); EXPECT_EQ(rawRequestId, invokeSpec->requestId); EXPECT_EQ(seq, 1); } @@ -361,6 +419,59 @@ TEST_F(InvokeAdaptorTest, SubmitFunctionWithFunctionGroupTest) ASSERT_TRUE(invokeAdaptor->groupManager->IsGroupExist("groupName")); } +TEST_F(InvokeAdaptorTest, SubmitFunctionWithAliasTest) +{ + auto cfg = LibruntimeConfig(); + auto invokeSpec = std::make_shared(); + invokeSpec->requestId = "cae7c30c8d63f5ed00"; + std::vector returnObjs{DataObject("returnID")}; + invokeSpec->returnIds = returnObjs; + invokeSpec->jobId = YR::utility::IDGenerator::GenApplicationId(); + auto opts = InvokeOptions(); + opts.groupName = "groupName"; + FunctionGroupOptions opt; + opt.functionGroupSize = 8; + opt.bundleSize = 2; + opts.functionGroupOpts = opt; + + std::unordered_map params; + params["userType"] = "VIP"; + params["age"] = "10"; + params["devType"] = "MATE40"; + + opts.aliasParams = params; + invokeSpec->functionMeta.functionId = "12345678901234561234567890123456/helloworld/myaliasv1"; + invokeSpec->opts = opts; + invokeSpec->BuildInstanceCreateRequest(cfg); + + std::vector g_aes_rule = { + { + .aliasUrn = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasv1", + .functionUrn = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld", + .functionVersionUrn = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest", + .name = "myaliasv1", + .functionVersion = "$latest", + .revisionId = "20210617023315921", + .description = "fake_description", + .routingType = "rule", + .routingRules = + { + .ruleLogic = "and", + .rules = + { + "userType:=:VIP", + "age:<=:20", + "devType:in:P40,P50,MATE40", + }, + .grayVersion = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:1", + }, + }, + }; + invokeAdaptor->ar->UpdateAliasInfo(g_aes_rule); + invokeAdaptor->SubmitFunction(invokeSpec); + ASSERT_EQ(invokeSpec->functionMeta.functionId, "12345678901234561234567890123456/helloworld/1"); +} + TEST_F(InvokeAdaptorTest, CreateResponseHandlerTest) { CreateResponse resp; @@ -493,6 +604,12 @@ TEST_F(InvokeAdaptorTest, RangeCreateTest) ASSERT_EQ(invokeAdaptor->RangeCreate("groupName", range).Code(), ErrorCode::ERR_PARAM_INVALID); } +TEST_F(InvokeAdaptorTest, ReleaseInstanceTest) +{ + InvokeOptions opts; + ASSERT_EQ(invokeAdaptor->ReleaseInstance("leaseId", "", true, opts).OK(), true); +} + TEST_F(InvokeAdaptorTest, SubscribeAllTest) { libruntime::FunctionMeta meta; @@ -538,6 +655,7 @@ TEST_F(InvokeAdaptorTest, CreateNotifyHandlerTest) TEST_F(InvokeAdaptorTest, TestFinalize) { EXPECT_NO_THROW(invokeAdaptor->Finalize(false)); + EXPECT_NE(invokeAdaptor->functionMasterClient_->work_, nullptr); } std::string g_alias = R"( @@ -568,6 +686,31 @@ std::string g_alias = R"( }] )"; +TEST_F(InvokeAdaptorTest, ParseAliasInfoTest) +{ + SignalRequest req; + req.set_payload(g_alias); + std::vector aliasInfo; + invokeAdaptor->ParseAliasInfo(req, aliasInfo); + ASSERT_EQ(aliasInfo.size(), 1); + + auto &alias = aliasInfo[0]; + ASSERT_EQ(alias.aliasUrn, "fake_alias_urn"); + ASSERT_EQ(alias.functionUrn, "fake_function_urn"); + ASSERT_EQ(alias.functionVersionUrn, "fake_function_version_urn"); + ASSERT_EQ(alias.name, "fake_name"); + ASSERT_EQ(alias.functionVersion, "fake_function_version"); + ASSERT_EQ(alias.revisionId, "fake_revision_id"); + ASSERT_EQ(alias.description, "fake_description"); + ASSERT_EQ(alias.routingType, "rule"); + ASSERT_EQ(alias.routingRules.ruleLogic, "and"); + ASSERT_EQ(alias.routingRules.rules.size(), 3); + ASSERT_EQ(alias.routingRules.rules[0], "userType:=:VIP"); + ASSERT_EQ(alias.routingRules.rules[1], "age:<=:20"); + ASSERT_EQ(alias.routingRules.rules[2], "devType:in:P40,P50,MATE40"); + ASSERT_EQ(alias.routingRules.grayVersion, "fake_gray_version"); +} + TEST_F(InvokeAdaptorTest, ExecSignalCallbackNullptrTest) { libConfig->libruntimeOptions.signalCallback = nullptr; @@ -630,6 +773,13 @@ TEST_F(InvokeAdaptorTest, SignalHandlerTest) req.set_signal(libruntime::Signal::ErasePendingThread); ASSERT_NO_THROW(invokeAdaptor->SignalHandler(req)); + req.set_signal(libruntime::Signal::UpdateAlias); + std::vector aliasInfo = {YR::Libruntime::AliasElement()}; + json j = aliasInfo; + req.set_payload(j.dump()); + auto response = invokeAdaptor->SignalHandler(req); + ASSERT_EQ(response.code(), ::common::ErrorCode::ERR_NONE); + req.set_signal(libruntime::Signal::Update); NotificationPayload notifyscription; InstanceTermination *termination = notifyscription.mutable_instancetermination(); @@ -639,14 +789,11 @@ TEST_F(InvokeAdaptorTest, SignalHandlerTest) req.set_payload(serializedPayload); libruntime::FunctionMeta funcMeta; invokeAdaptor->metaMap["insId"] = funcMeta; - auto response = invokeAdaptor->SignalHandler(req); + response = invokeAdaptor->SignalHandler(req); ASSERT_EQ(invokeAdaptor->metaMap.size() == 0, true); - req.set_signal(libruntime::Signal::UpdateScheduler); - req.set_payload( - "{\"schedulerFuncKey\":\"0/0-system-faasscheduler/" - "$latest\",\"schedulerIDList\":[\"abfe9e68-9221-4b97-8e85-87b5b5faf69c\",\"2db4a71b-157c-4ec2-95d7-" - "c70fccc85dfa\"]}"); + req.set_signal(libruntime::Signal::UpdateSchedulerHash); + req.set_payload("{\"schedulerFuncKey\":\"0/0-system-faasscheduler/$latest\",\"schedulerIDList\":null,\"schedulerInstanceList\":[{\"instanceName\":\"abfe9e68-9221-4b97-8e85-87b5b5faf69c\",\"instanceId\":\"abfe9e68-9221-4b97-8e85-87b5b5faf69c\"},{\"instanceName\":\"2db4a71b-157c-4ec2-95d7-c70fccc85dfa\",\"instanceId\":\"abfe9e68-9221-4b97-8e85-87b5b5faf69c\"}]}"); response = invokeAdaptor->SignalHandler(req); ASSERT_EQ(response.code(), ::common::ErrorCode::ERR_NONE); @@ -665,6 +812,21 @@ TEST_F(InvokeAdaptorTest, SignalHandlerTest) req.set_signal(libruntime::Signal::UpdateSchedulerHash); response = invokeAdaptor->SignalHandler(req); ASSERT_EQ(response.code(), ::common::ErrorCode::ERR_NONE); + + core_service::NotificationPayload notifyPayload; + auto *functionmasterevent = notifyPayload.mutable_functionmasterevent(); + functionmasterevent->set_address("127.0.0.1:8080"); + std::string payload; + notifyPayload.SerializeToString(&payload); + req.set_signal(libruntime::Signal::Update); + req.set_payload(payload); + response = invokeAdaptor->SignalHandler(req); + ASSERT_EQ(response.code(), ::common::ErrorCode::ERR_NONE); + + invokeAdaptor->isRunning = false; + response = invokeAdaptor->SignalHandler(req); + ASSERT_EQ(response.code(), ::common::ErrorCode::ERR_NONE); + invokeAdaptor->isRunning = true; } TEST_F(InvokeAdaptorTest, CreateInstanceRawTest) @@ -698,15 +860,15 @@ TEST_F(InvokeAdaptorTest, CreateInstanceRawTest) ASSERT_EQ(info2.OK(), true); ASSERT_EQ(instanceId, "58f32000-0000-4000-8000-0ecfe00dd5e5"); - this->fsIntf->isReqNormal = false; - this->fsIntf->callbackPromise = std::promise(); - this->fsIntf->callbackFuture = this->fsIntf->callbackPromise.get_future(); + this->gwClient->isReqNormal = false; + this->gwClient->callbackPromise = std::promise(); + this->gwClient->callbackFuture = this->gwClient->callbackPromise.get_future(); invokeAdaptor->CreateInstanceRaw(reqRaw2, cb); - this->fsIntf->callbackFuture.get(); + this->gwClient->callbackFuture.get(); ASSERT_EQ(info.OK(), true); ASSERT_EQ(info2.OK(), false); ASSERT_EQ(instanceId, "58f32000-0000-4000-8000-0ecfe00dd5e5"); - this->fsIntf->isReqNormal = true; + this->gwClient->isReqNormal = true; } TEST_F(InvokeAdaptorTest, InvokeByInstanceIdRawTest) @@ -742,18 +904,17 @@ TEST_F(InvokeAdaptorTest, KillRawTest) RawCallback cb = [&info](const ErrorInfo &err, std::shared_ptr resultRaw) { info.SetErrorCode(err.Code()); }; - invokeAdaptor->KillRaw(reqRaw, cb); ASSERT_EQ(info.OK(), true); } -TEST_F(InvokeAdaptorTest, ExecShutdownCallbackWithZeroDurationTest) +TEST_F(InvokeAdaptorTest, ExecShutdownCallbackWithMinusDurationTest) { libConfig->libruntimeOptions.shutdownCallback = [](uint64_t gracePeriodSeconds) -> ErrorInfo { return ErrorInfo(ErrorCode::ERR_OK, ModuleCode::RUNTIME, std::to_string(gracePeriodSeconds)); }; - int gracePeriodSec = 0; + int gracePeriodSec = -10; auto err = invokeAdaptor->ExecShutdownCallback(gracePeriodSec); ASSERT_EQ(err.Msg(), "Execute user shutdown callback timeout"); @@ -779,6 +940,18 @@ TEST_F(InvokeAdaptorTest, ParseFunctionGroupRunningInfoTest) ASSERT_EQ(res4, true); } +TEST_F(InvokeAdaptorTest, ParseFaasControllerTest) +{ + CallRequest request; + auto pbArg = request.add_args(); + pbArg->set_type(common::Arg::OBJECT_REF); + std::string objId("mock-123"); + pbArg->set_value(objId.c_str(), objId.size()); + std::vector> rawArgs; + auto res = ParseFaasController(request, rawArgs, true); + ASSERT_EQ(res, true); +} + TEST_F(InvokeAdaptorTest, InitHandlerTest) { std::shared_ptr req = std::make_shared(); @@ -806,6 +979,41 @@ TEST_F(InvokeAdaptorTest, InitHandlerTest) ASSERT_EQ(res, 2); } +TEST_F(InvokeAdaptorTest, PausedInitHandlerTest) +{ + std::shared_ptr req = std::make_shared(); + req->Mutable().set_requestid("fff87cc506e547d9"); + req->Mutable().set_senderid("fff87cc506e547d9"); + req->Mutable().set_iscreate(true); + int res = 0; + libConfig->libruntimeOptions.loadFunctionCallback = [&res](const std::vector &codePaths) -> ErrorInfo { + res++; + return ErrorInfo(); + }; + libConfig->libruntimeOptions.functionExecuteCallback = + [&res](const FunctionMeta &function, const libruntime::InvokeType invokeType, + const std::vector> &rawArgs, + std::vector> &returnValues) -> ErrorInfo { + res++; + return ErrorInfo(); + }; + auto pbArg = req->Mutable().add_args(); + pbArg->set_type(Arg_ArgType::Arg_ArgType_VALUE); + InvokeSpec invokeSpec; + invokeSpec.invokeType = libruntime::InvokeType::InvokeFunction; + pbArg->set_value(invokeSpec.BuildInvokeMetaData(*invokeAdaptor->librtConfig)); + + auto createOpt = req->Mutable().mutable_createoptions(); + nlohmann::json debugJson; + debugJson["enable"] = "true"; + createOpt->insert({"debug_config", debugJson.dump()}); + + bool triggered = false; + invokeAdaptor->setDebugBreakpoint_ = [&triggered]() { triggered = true; }; + invokeAdaptor->InitHandler(req); + ASSERT_TRUE(triggered); +} + TEST_F(InvokeAdaptorTest, CallHandlerTest) { std::shared_ptr req = std::make_shared(); @@ -969,12 +1177,34 @@ TEST_F(InvokeAdaptorTest, GetInstanceIdsTest) ASSERT_EQ(vec2.size() == 1, true); } +TEST_F(InvokeAdaptorTest, AcquireInstanceTest) +{ + auto invokeAdaptor = std::make_shared(); + invokeAdaptor->taskSubmitter = std::make_shared(); + invokeAdaptor->requestManager = std::make_shared(); + FunctionMeta meta; + meta.name = "name"; + InvokeOptions opts; + opts.schedulerInstanceIds.push_back("id"); + opts.traceId = "traceId"; + auto [res, err] = invokeAdaptor->AcquireInstance("stateId", meta, opts); + ASSERT_EQ(err.OK(), true); +} + +TEST_F(InvokeAdaptorTest, UpdateSchdulerInfoTest) +{ + auto invokeAdaptor = std::make_shared(); + invokeAdaptor->taskSubmitter = std::make_shared(); + ASSERT_NO_THROW(invokeAdaptor->UpdateSchdulerInfo("schedulerName", "schedulerId", "ADD")); +} + TEST_F(InvokeAdaptorTest, AdaptorGetInsTest) { + this->gwClient->isGetInstance = true; std::string name = "name"; std::string ns = "ns"; auto [res, err] = invokeAdaptor->GetInstance(name, ns, 60); - ASSERT_EQ(res.className, ""); + ASSERT_EQ(res.className, "classname"); auto [res1, err1] = invokeAdaptor->GetInstance(name, ns, 60); ASSERT_EQ(err1.OK(), true); ASSERT_EQ(invokeAdaptor->metaMap.size() == 1, true); @@ -989,17 +1219,18 @@ TEST_F(InvokeAdaptorTest, AdaptorGetInsTest) auto [res3, err3] = invokeAdaptor->GetInstance(name, ns, 60); ASSERT_EQ(err3.OK(), false); ASSERT_EQ(err3.Code(), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID); + this->gwClient->isGetInstance = false; } TEST_F(InvokeAdaptorTest, UpdateAndSubcribeInsStatusTest) { libruntime::FunctionMeta funcMeta; funcMeta.set_classname("class_name"); - this->fsIntf->isReqNormal = false; + this->gwClient->isReqNormal = false; invokeAdaptor->UpdateAndSubcribeInsStatus("insId", funcMeta); - this->fsIntf->killCallbackFuture.get(); + this->gwClient->killCallbackFuture.get(); ASSERT_EQ(invokeAdaptor->metaMap.size() == 0, true); - this->fsIntf->isReqNormal = true; + this->gwClient->isReqNormal = true; } TEST_F(InvokeAdaptorTest, RemoveInsMetaInfoTest) @@ -1009,5 +1240,82 @@ TEST_F(InvokeAdaptorTest, RemoveInsMetaInfoTest) invokeAdaptor->RemoveInsMetaInfo("insId"); ASSERT_EQ(invokeAdaptor->metaMap.size() == 0, true); } + +TEST_F(InvokeAdaptorTest, CallTimerTest) +{ + libruntime::FunctionMeta funcMeta; + std::string reqId = "reqId"; + std::string insId = "insId"; + invokeAdaptor->CreateCallTimer(reqId, insId, 0); + ASSERT_EQ(invokeAdaptor->callTimeoutTimerMap_.size() == 0, true); + + invokeAdaptor->CreateCallTimer(reqId, insId, 100); + ASSERT_EQ(invokeAdaptor->callTimeoutTimerMap_.size() == 1, true); + invokeAdaptor->EraseCallTimer(reqId); + ASSERT_EQ(invokeAdaptor->callTimeoutTimerMap_.size() == 0, true); + + invokeAdaptor->CreateCallTimer(reqId, insId, 1); + auto called = std::make_shared>(); + EXPECT_CALL(*this->gwClient, ReturnCallResult) + .WillOnce(::testing::Invoke( + [called](const std::shared_ptr result, bool isCreate, CallResultCallBack callback) { + called->set_value(true); + })); + ASSERT_EQ(called->get_future().get(), true); +} + +TEST_F(InvokeAdaptorTest, MetricsTest) +{ + const std::string FileExporterJsonStr = R"( +{ + "backends": [ + { + "immediatelyExport": { + "name": "EDA", + "enable": true, + "exporters": [ + { + "fileExporter": { + "enable": true, + "fileDir": "/tmp/", + "rolling": { + "enable": true, + "maxFiles": 3, + "maxSize": 10000 + }, + "contentType": "STANDARD" + } + } + ] + } + } + ] +} + )"; + // empty metric -> expect init not to be called + invokeAdaptor->InitMetricsAdaptor(true); + ASSERT_EQ(MetricsAdaptor::GetInstance()->userEnable_, false); + Config::Instance().METRICS_CONFIG_ = "invalid json"; + invokeAdaptor->InitMetricsAdaptor(true); + ASSERT_EQ(MetricsAdaptor::GetInstance()->userEnable_, false); + // 创建输出文件流 + std::string file = "./metric.json"; + std::ofstream outFile(file); + if (!outFile) { + return; + } + outFile << FileExporterJsonStr; + outFile.close(); + Config::Instance().METRICS_CONFIG_FILE_ = file; + Config::Instance().METRICS_CONFIG_ = ""; + // valid metric file -> expected successful called + invokeAdaptor->InitMetricsAdaptor(true); + ASSERT_EQ(MetricsAdaptor::GetInstance()->userEnable_, true); + Config::Instance().ENABLE_METRICS_ = true; + invokeAdaptor->ReportMetrics("request", "trace", 1); + std::remove(file.c_str()); + Config::Instance().METRICS_CONFIG_FILE_ = ""; + Config::Instance().ENABLE_METRICS_ = false; +} } // namespace test } // namespace YR diff --git a/test/libruntime/invoke_order_manager_test.cpp b/test/libruntime/invoke_order_manager_test.cpp index 2c7f5e0..b0ef8b9 100644 --- a/test/libruntime/invoke_order_manager_test.cpp +++ b/test/libruntime/invoke_order_manager_test.cpp @@ -113,5 +113,23 @@ TEST_F(InvokeOrderManagerTest, RemoveInstanceTest) invokeOrderMgr->RemoveInstance(spec); ASSERT_EQ(invokeOrderMgr->instances.find("id") == invokeOrderMgr->instances.end(), true); } + +TEST_F(InvokeOrderManagerTest, RegisterInstanceAndUpdateOrderTest) +{ + auto invokeOrderMgr = std::make_shared(); + auto insId = "instanceId"; + invokeOrderMgr->RegisterInstanceAndUpdateOrder(insId); + ASSERT_EQ(invokeOrderMgr->instances.find("instanceId") != invokeOrderMgr->instances.end(), true); + ASSERT_EQ(invokeOrderMgr->instances["instanceId"]->orderingCounter, 1); +} + +TEST_F(InvokeOrderManagerTest, RegisterInstance) +{ + auto invokeOrderMgr = std::make_shared(); + auto insId = "instanceId"; + invokeOrderMgr->RegisterInstance(insId); + ASSERT_EQ(invokeOrderMgr->instances.find("instanceId") != invokeOrderMgr->instances.end(), true); + ASSERT_EQ(invokeOrderMgr->instances["instanceId"]->orderingCounter, 0); +} } // namespace test } // namespace YR \ No newline at end of file diff --git a/test/libruntime/invoke_spec_test.cpp b/test/libruntime/invoke_spec_test.cpp index 2b73afa..89e12ca 100644 --- a/test/libruntime/invoke_spec_test.cpp +++ b/test/libruntime/invoke_spec_test.cpp @@ -82,11 +82,16 @@ TEST_F(InvokeSpecTest, BuildRequestPbOptions) spec->opts.customExtensions["DELEGATE_DIRECTORY_QUOTA"] = "/tmp1"; spec->opts.customExtensions["DELEGATE_DIRECTORY_INFO"] = "1024"; spec->opts.recoverRetryTimes = 3; + spec->opts.debug.enable = true; spec->invokeType = libruntime::InvokeType::CreateInstanceStateless; spec->BuildRequestPbOptions(spec->opts, conf, req); ASSERT_EQ(req.createoptions().at("DELEGATE_DIRECTORY_QUOTA"), "/tmp1"); ASSERT_EQ(req.createoptions().at("DELEGATE_DIRECTORY_INFO"), "1024"); ASSERT_EQ(req.createoptions().at(RECOVER_RETRY_TIMES), "3"); + nlohmann::json debugJson = nlohmann::json::parse(req.createoptions().at("debug_config")); + auto debugMap = debugJson.get>(); + ASSERT_EQ(debugMap["enable"], "true"); + const auto &jsonString = req.createoptions().at(DELEGATE_ENV_VAR); nlohmann::json jsonObj = nlohmann::json::parse(jsonString); auto envsMap = jsonObj.get>(); @@ -129,6 +134,40 @@ TEST_F(InvokeSpecTest, BuildRequestPbArgsStringBuffer) EXPECT_TRUE(buf->GetSize() == 0); // use move semantic } +TEST_F(InvokeSpecTest, GetSchedulerInstanceIds) +{ + spec->opts.schedulerInstanceIds.push_back("00"); + auto vec = spec->GetSchedulerInstanceIds(); + EXPECT_TRUE(vec.size() == 1); + EXPECT_TRUE(vec[0] == "00"); + spec->opts.schedulerInstanceIds.push_back("11"); + EXPECT_TRUE(vec.size() == 1); + EXPECT_TRUE(vec[0] == "00"); + auto vecNew = spec->GetSchedulerInstanceIds(); + EXPECT_TRUE(vecNew.size() == 2); + EXPECT_TRUE(vecNew[1] == "11"); +} + +TEST_F(InvokeSpecTest, GetSchedulerInstanceId) +{ + std::string schedulerId = spec->GetSchedulerInstanceId(); + EXPECT_TRUE(schedulerId.empty()); + spec->opts.schedulerInstanceIds.push_back("00"); + schedulerId = spec->GetSchedulerInstanceId(); + EXPECT_TRUE(schedulerId == "00"); +} + +TEST_F(InvokeSpecTest, SetSchedulerInstanceId) +{ + spec->SetSchedulerInstanceId("00"); + std::string schedulerId = spec->GetSchedulerInstanceId(); + EXPECT_TRUE(schedulerId == "00"); + + spec->SetSchedulerInstanceId("11"); + schedulerId = spec->GetSchedulerInstanceId(); + EXPECT_TRUE(schedulerId == "11"); +} + TEST_F(InvokeSpecTest, BuildInvokeRequestPbOptionsTest) { LibruntimeConfig config; @@ -171,6 +210,23 @@ TEST_F(InvokeSpecTest, RequestResourceEqualTest) r5.opts.invokeLabels = invokelabels2; ASSERT_EQ(r4 == r5, false); ASSERT_EQ(r4 == r6, false); + + RequestResource r7; + RequestResource r8; + r7.opts.debug.enable = true; + ASSERT_EQ(r7 == r8, false); + + RequestResource r9; + RequestResource r10; + r9.functionMeta.languageType = libruntime::LanguageType::Cpp; + r10.functionMeta.languageType = libruntime::LanguageType::Cpp; + r9.opts.instanceSession = + std::make_shared(InstanceSession{.sessionID = "sessionID", .sessionTTL = 1, .concurrency = 1}); + r10.opts.instanceSession = + std::make_shared(InstanceSession{.sessionID = "sessionID", .sessionTTL = 1, .concurrency = 1}); + + ASSERT_EQ(r9 == r10, true); + ASSERT_EQ(r9.opts.instanceSession == r10.opts.instanceSession, false); } TEST_F(InvokeSpecTest, GetInstanceIdTest) diff --git a/test/libruntime/kv_state_store_test.cpp b/test/libruntime/kv_state_store_test.cpp index 02f0aed..25de58b 100644 --- a/test/libruntime/kv_state_store_test.cpp +++ b/test/libruntime/kv_state_store_test.cpp @@ -62,7 +62,7 @@ public: std::shared_ptr stateStore_; }; -TEST_F(KVStateStoreTest, KVWriteReadDel) +TEST_F(KVStateStoreTest, KVWriteReadDelExist) { std::string key = "123"; std::string key2 = "456"; @@ -84,17 +84,20 @@ TEST_F(KVStateStoreTest, KVWriteReadDel) // Read SingleReadResult readResult = stateStore_->Read(key, -1); ASSERT_EQ(readResult.second.Code(), ErrorCode::ERR_OK); - MultipleReadResult multiReadResult = stateStore_->Read({key, key}, 100, false); ASSERT_EQ(multiReadResult.second.Code(), ErrorCode::ERR_GET_OPERATION_FAILED); + // Exist + MultipleExistResult emptyExistResult = stateStore_->Exist({}); + ASSERT_EQ(emptyExistResult.second.Code(), ErrorCode::ERR_PARAM_INVALID); + MultipleExistResult existResult = stateStore_->Exist({key, key2}); + ASSERT_EQ(existResult.second.Code(), ErrorCode::ERR_OK); + // Del err = stateStore_->Del(key); ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); - MultipleDelResult mdResult = stateStore_->Del({key, key2}); ASSERT_EQ(mdResult.second.Code(), ErrorCode::ERR_OK); } - } // namespace test } // namespace YR diff --git a/test/libruntime/libruntime_config_test.cpp b/test/libruntime/libruntime_config_test.cpp index d1205ef..7baa037 100644 --- a/test/libruntime/libruntime_config_test.cpp +++ b/test/libruntime/libruntime_config_test.cpp @@ -79,5 +79,19 @@ TEST_F(LibruntimeConfigTest, InitFunctionGroupRunningInfoTest) ASSERT_EQ(config.groupRunningInfo.serverList.size(), 1); } +TEST_F(LibruntimeConfigTest, GetInstanceIdTest) +{ + LibruntimeConfig config; + libruntime::FunctionMeta meta; + meta.set_name("name"); + config.funcMeta = meta; + auto insId = config.GetInstanceId(); + ASSERT_EQ(insId, "yr_defalut_namespace-name"); + meta.set_ns("ns"); + config.funcMeta = meta; + insId = config.GetInstanceId(); + ASSERT_EQ(insId, "ns-name"); +} + } // namespace test } // namespace YR \ No newline at end of file diff --git a/test/libruntime/libruntime_test.cpp b/test/libruntime/libruntime_test.cpp index 3e8c318..db67dc0 100644 --- a/test/libruntime/libruntime_test.cpp +++ b/test/libruntime/libruntime_test.cpp @@ -73,18 +73,26 @@ public: sec_ = std::make_shared(); auto socketClient = std::make_shared("/home/snuser/socket/runtime.sock"); lr = std::make_shared(lc, clientsMgr, metricsAdaptor, sec_, socketClient); - fsIntf_ = std::make_shared(); - auto fsClient = std::make_shared(fsIntf_); + gwClient_ = std::make_shared(); + auto fsClient = std::make_shared(gwClient_); objectStore_ = std::make_shared(); stateStore_ = std::make_shared(); + streamStore_ = std::make_shared(); heteroStore_ = std::make_shared(); auto finalizeHandler = []() { return; }; - DatasystemClients dsclients{objectStore_, stateStore_, heteroStore_}; + DatasystemClients dsclients{objectStore_, stateStore_, streamStore_, heteroStore_}; lr->Init(fsClient, dsclients, finalizeHandler); } void TearDown() override { + EXPECT_CALL(*this->gwClient_, KillAsync(_, _, _)) + .WillRepeatedly([=](const YR::Libruntime::KillRequest &, YR::Libruntime::KillCallBack cb, int) { + if (cb != nullptr) { + YR::Libruntime::KillResponse resp; + cb(resp, ErrorInfo()); + } + }); CloseGlobalTimer(); if (lr) { lr->Finalize(true); @@ -94,9 +102,10 @@ public: } std::shared_ptr lc; - std::shared_ptr fsIntf_; + std::shared_ptr gwClient_; std::shared_ptr objectStore_; std::shared_ptr stateStore_; + std::shared_ptr streamStore_; std::shared_ptr heteroStore_; std::shared_ptr sec_; std::shared_ptr lr; @@ -137,8 +146,8 @@ TEST_F(LibruntimeTest, PutTest) auto metricsAdaptor = std::make_shared(); auto socketClient = std::make_shared("/home/snuser/socket/runtime.sock"); lr = std::make_shared(lc, clientsMgr, metricsAdaptor, sec_, socketClient); - auto fsClient = std::make_shared(fsIntf_); - DatasystemClients dsclients{objectStore_, stateStore_, heteroStore_}; + auto fsClient = std::make_shared(gwClient_); + DatasystemClients dsclients{objectStore_, stateStore_, streamStore_, heteroStore_}; lr->Init(fsClient, dsclients); std::string str = "Hello, world!"; auto dataObj = std::make_shared(0, str.size()); @@ -197,8 +206,8 @@ TEST_F(LibruntimeTest, When_Not_Driver_Finalize_Should_Kill_Instances) auto metricsAdaptor = std::make_shared(); auto socketClient = std::make_shared("/home/snuser/socket/runtime.sock"); lr = std::make_shared(lc, clientsMgr, metricsAdaptor, sec_, socketClient); - auto fsClient = std::make_shared(fsIntf_); - DatasystemClients dsclients{objectStore_, stateStore_, heteroStore_}; + auto fsClient = std::make_shared(gwClient_); + DatasystemClients dsclients{objectStore_, stateStore_, streamStore_, heteroStore_}; lr->Init(fsClient, dsclients); EXPECT_NO_THROW(lr->Finalize(false)); } @@ -327,8 +336,8 @@ TEST_F(LibruntimeTest, AllocReturnObjectBigTest) auto metricsAdaptor = std::make_shared(); auto socketClient = std::make_shared("/home/snuser/socket/runtime.sock"); lr = std::make_shared(lc, clientsMgr, metricsAdaptor, sec_, socketClient); - auto fsClient = std::make_shared(fsIntf_); - DatasystemClients dsclients{objectStore_, stateStore_, heteroStore_}; + auto fsClient = std::make_shared(gwClient_); + DatasystemClients dsclients{objectStore_, stateStore_, streamStore_, heteroStore_}; lr->Init(fsClient, dsclients); std::string testObjId("fake_id"); @@ -382,9 +391,9 @@ TEST_F(LibruntimeTest, NonDriverSecurityInitTest) auto sec = std::make_shared(); auto socketClient = std::make_shared("/home/snuser/socket/runtime.sock"); lr = std::make_shared(lc, clientsMgr, metricsAdaptor, sec, socketClient); - auto fsClient = std::make_shared(fsIntf_); - DatasystemClients dsclients{objectStore_, stateStore_, heteroStore_}; - ASSERT_NO_THROW(lr->Init(fsClient, dsclients)); + auto fsClient = std::make_shared(gwClient_); + DatasystemClients dsclients{objectStore_, stateStore_, streamStore_, heteroStore_}; + lr->Init(fsClient, dsclients); } TEST_F(LibruntimeTest, DriverSecurityInitTest) @@ -398,8 +407,8 @@ TEST_F(LibruntimeTest, DriverSecurityInitTest) auto sec = std::make_shared(); auto socketClient = std::make_shared("/home/snuser/socket/runtime.sock"); lr = std::make_shared(lc, clientsMgr, metricsAdaptor, sec, socketClient); - auto fsClient = std::make_shared(fsIntf_); - DatasystemClients dsClients{objectStore_, stateStore_, heteroStore_}; + auto fsClient = std::make_shared(gwClient_); + DatasystemClients dsClients{objectStore_, stateStore_, streamStore_, heteroStore_}; ASSERT_NO_THROW(lr->Init(fsClient, dsClients)); } @@ -514,6 +523,9 @@ TEST_F(LibruntimeTest, SetTraceIdTest) std::string traceId = "traceId"; auto err = lr->SetTraceId(traceId); ASSERT_TRUE(err.OK()); + threadLocalTraceId = "threadLocalTraceId"; + err = lr->SetTraceId(); + ASSERT_TRUE(err.OK()); } TEST_F(LibruntimeTest, GenerateKeyByStateStoreTest) @@ -585,6 +597,20 @@ TEST_F(LibruntimeTest, GetArrayByStateStoreTest) ASSERT_TRUE(result.first[1] == nullptr); } +TEST_F(LibruntimeTest, QuerySizeByStateStoreTest) +{ + std::vector keys{"123", "456"}; + std::vector outSizes; + EXPECT_CALL(*this->stateStore_, QuerySize(_, _)) + .WillOnce([=](const std::vector &, std::vector ¶m) { + param = {10, 10}; + return YR::Libruntime::ErrorInfo(); + }); + auto result = lr->QuerySizeByStateStore(stateStore_, keys, outSizes); + ASSERT_TRUE(result.OK()); + ASSERT_TRUE(outSizes.size() == 2); +} + TEST_F(LibruntimeTest, DelByStateStoreTest) { std::string key; @@ -704,7 +730,7 @@ TEST_F(LibruntimeTest, TestCreateInstanceRaw) auto buffer = std::make_shared(body.size()); buffer->MemoryCopy(body.c_str(), body.size()); auto callback = [](const ErrorInfo &err, std::shared_ptr resultRaw) {}; - EXPECT_CALL(*this->fsIntf_, CreateAsync(_, _, _, _)).WillOnce(testing::Return()); + EXPECT_CALL(*this->gwClient_, CreateAsync(_, _, _, _)).WillOnce(testing::Return()); lr->CreateInstanceRaw(buffer, callback); } @@ -717,7 +743,7 @@ TEST_F(LibruntimeTest, TestInvokeByInstanceIdRaw) auto buffer = std::make_shared(body.size()); buffer->MemoryCopy(body.c_str(), body.size()); auto callback = [](const ErrorInfo &err, std::shared_ptr resultRaw) {}; - EXPECT_CALL(*this->fsIntf_, InvokeAsync(_, _, _)).WillOnce(testing::Return()); + EXPECT_CALL(*this->gwClient_, InvokeAsync(_, _, _)).WillOnce(testing::Return()); lr->InvokeByInstanceIdRaw(buffer, callback); } @@ -729,7 +755,7 @@ TEST_F(LibruntimeTest, TestKillRaw) auto buffer = std::make_shared(body.size()); buffer->MemoryCopy(body.c_str(), body.size()); auto callback = [](const ErrorInfo &err, std::shared_ptr resultRaw) {}; - EXPECT_CALL(*this->fsIntf_, KillAsync(_, _, _)).WillRepeatedly(testing::Return()); + EXPECT_CALL(*this->gwClient_, KillAsync(_, _, _)).WillOnce(testing::Return()); lr->KillRaw(buffer, callback); } @@ -799,20 +825,13 @@ TEST_F(LibruntimeTest, TestGetRaw) ASSERT_EQ(lr->GetRaw({"aaa"}, 30, true).first.OK(), false); } -TEST_F(LibruntimeTest, DISABLED_GetResourcesTest) +TEST_F(LibruntimeTest, GetCredentialTest) { - auto result = lr->GetResources(); - ASSERT_FALSE(result.first.OK()); - lc->functionMasters = {"127.0.0.1"}; - result = lr->GetResources(); - result.first.SetIsTimeout(true); - std::vector vec{StackTraceInfo{}}; - result.first.SetStackTraceInfos(vec); - result.first.SetErrorMsg("errmsg"); - auto msg = result.first.CodeAndMsg(); - ASSERT_FALSE(result.first.OK()); - ASSERT_EQ(msg.empty(), false); - ASSERT_EQ(result.first.Finalized(), false); + datasystem::SensitiveValue sk = std::string("sk"); + lr->security_->SetAKSKAndCredential("ak", sk); + auto result = lr->GetCredential(); + ASSERT_EQ(result.ak, "ak"); + ASSERT_EQ(result.sk, "sk"); } TEST_F(LibruntimeTest, FiberEventTest) @@ -835,14 +854,14 @@ TEST_F(LibruntimeTest, HeteroDeleteTest) { std::vector objectIds; std::vector failedObjectIds; - ASSERT_EQ(lr->Delete(objectIds, failedObjectIds).OK(), true); + ASSERT_EQ(lr->DevDelete(objectIds, failedObjectIds).OK(), true); } -TEST_F(LibruntimeTest, HeteroLocalDeleteTest) +TEST_F(LibruntimeTest, HeteroDevLocalDeleteTest) { std::vector objectIds; std::vector failedObjectIds; - ASSERT_EQ(lr->LocalDelete(objectIds, failedObjectIds).OK(), true); + ASSERT_EQ(lr->DevLocalDelete(objectIds, failedObjectIds).OK(), true); } TEST_F(LibruntimeTest, HeteroDevSubscribeTest) @@ -877,6 +896,31 @@ TEST_F(LibruntimeTest, HeteroDevMGetTest) ASSERT_EQ(lr->DevMGet(keys, blob2dList, failedKeys, 1000).OK(), true); } +TEST_F(LibruntimeTest, StreamProducerAndConsumerTest) +{ + std::string streamName = "streamname"; + uint64_t value = 1000; + ProducerConf producerConf; + std::shared_ptr producer = std::make_shared(); + ASSERT_EQ(lr->CreateStreamProducer(streamName, producerConf, producer).OK(), true); + + SubscriptionConfig config; + std::shared_ptr consumer = std::make_shared(); + ASSERT_EQ(lr->CreateStreamConsumer(streamName, config, consumer, true).OK(), true); + ASSERT_EQ(lr->DeleteStream(streamName).OK(), true); + ASSERT_EQ(lr->QueryGlobalProducersNum(streamName, value).OK(), true); + ASSERT_EQ(lr->QueryGlobalConsumersNum(streamName, value).OK(), true); + + lr->dsClients.dsStreamStore = nullptr; + ASSERT_EQ(lr->CreateStreamProducer(streamName, producerConf, producer).Code(), + YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR); + ASSERT_EQ(lr->CreateStreamConsumer(streamName, config, consumer, true).Code(), + YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR); + ASSERT_EQ(lr->DeleteStream(streamName).Code(), YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR); + ASSERT_EQ(lr->QueryGlobalProducersNum(streamName, value).Code(), YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR); + ASSERT_EQ(lr->QueryGlobalConsumersNum(streamName, value).Code(), YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR); +} + TEST_F(LibruntimeTest, SetTenatIdTest) { lr->config->enableAuth = true; @@ -915,6 +959,15 @@ TEST_F(LibruntimeTest, GetInstanceTest) ASSERT_EQ(res.first.name, "name"); } +TEST_F(LibruntimeTest, AcquireAndReleaseInstanceTest) +{ + lr->invokeAdaptor = std::make_shared(); + FunctionMeta meta; + InvokeOptions opts; + ASSERT_EQ(lr->AcquireInstance("stateId", meta, opts).second.OK(), true); + ASSERT_EQ(lr->ReleaseInstance("leaseId", "stateId", true, opts).OK(), true); +} + TEST_F(LibruntimeTest, ExecShutdownCallbackTest) { lr->invokeAdaptor = nullptr; @@ -952,6 +1005,12 @@ TEST_F(LibruntimeTest, ReceiveRequestLoopTest) ASSERT_NO_THROW(lr->ReceiveRequestLoop()); } +TEST_F(LibruntimeTest, UpdateSchdulerInfoTest) +{ + lr->invokeAdaptor = std::make_shared(); + ASSERT_NO_THROW(lr->UpdateSchdulerInfo("scheduleName", "schedulerId", "ADD")); +} + TEST_F(LibruntimeTest, SaveGroupInstanceIdsTest) { const std::string groupInsIds = "aa;bb;cc"; @@ -1001,6 +1060,11 @@ TEST_F(LibruntimeTest, DecreaseReferenceTest) EXPECT_NO_THROW(lr->DecreaseReference(objIds)); } +TEST_F(LibruntimeTest, ReleaseGRefsTest) +{ + ASSERT_EQ(lr->ReleaseGRefs("remoteId").OK(), true); +} + TEST_F(LibruntimeTest, WaitTest) { ASSERT_EQ(lr->Wait({"objId"}, 1, 0)->readyIds.size(), 1); @@ -1014,6 +1078,11 @@ TEST_F(LibruntimeTest, GetBuffersTest) ASSERT_EQ(lr->GetBuffers(ids, 300, false).first.OK(), false); } +TEST_F(LibruntimeTest, PeekObjectRefStreamTest) +{ + ASSERT_EQ(lr->PeekObjectRefStream("generatorId", false).first.OK(), false); +} + TEST_F(LibruntimeTest, GetFunctionGroupRunningInfoTest) { ASSERT_EQ(lr->GetFunctionGroupRunningInfo().instanceRankId, 0); @@ -1036,8 +1105,8 @@ TEST_F(LibruntimeTest, WaitAndGetAsyncTest) YR::Libruntime::WaitAsyncCallback cbWait = [&waitPromise](const std::string &id, const ErrorInfo &err, void *data) { waitPromise.set_value(err); }; - - lr->WaitAsync("objId", cbWait, nullptr); + char *str = "hello"; + lr->WaitAsync("objId", cbWait, str); ASSERT_EQ(waitFut.get().OK(), true); auto getPromise = std::promise(); @@ -1045,7 +1114,7 @@ TEST_F(LibruntimeTest, WaitAndGetAsyncTest) YR::Libruntime::GetAsyncCallback cbGet = [&getPromise](const std::shared_ptr &dataObj, const ErrorInfo &err, void *data) { getPromise.set_value(err); }; - lr->GetAsync("objId", cbGet, nullptr); + lr->GetAsync("objId", cbGet, str); ASSERT_EQ(getFut.get().OK(), true); } @@ -1056,7 +1125,7 @@ TEST_F(LibruntimeTest, GetGroupInstanceIdsTest) TEST_F(LibruntimeTest, ExitTest) { - EXPECT_CALL(*this->fsIntf_, ExitAsync(_, _)) + EXPECT_CALL(*this->gwClient_, ExitAsync(_, _)) .WillOnce([=](const YR::Libruntime::ExitRequest &, YR::Libruntime::ExitCallBack cb) { if (cb != nullptr) { YR::Libruntime::ExitResponse resp; @@ -1115,9 +1184,9 @@ TEST_F(LibruntimeTest, GetThreadPoolSizeTest) ASSERT_EQ(lr->GetLocalThreadPoolSize(), 0); } -TEST_F(LibruntimeTest, DISABLED_resourcegroupTest) +TEST_F(LibruntimeTest, ResourceGroupTest) { - EXPECT_CALL(*this->fsIntf_, CreateRGroupAsync(_, _, _)) + EXPECT_CALL(*this->gwClient_, CreateRGroupAsync(_, _, _)) .WillOnce([=](const YR::Libruntime::CreateResourceGroupRequest &, YR::Libruntime::CreateResourceGroupCallBack cb, int) { if (cb != nullptr) { @@ -1127,7 +1196,7 @@ TEST_F(LibruntimeTest, DISABLED_resourcegroupTest) cb(resp); } }); - EXPECT_CALL(*this->fsIntf_, KillAsync(_, _, _)) + EXPECT_CALL(*this->gwClient_, KillAsync(_, _, _)) .WillOnce([=](const YR::Libruntime::KillRequest &, YR::Libruntime::KillCallBack cb, int) { if (cb != nullptr) { YR::Libruntime::KillResponse resp; @@ -1190,6 +1259,8 @@ TEST_F(LibruntimeTest, KVTest) ASSERT_EQ(err.Code(), 0); auto [res4, err4] = lr->KVDel(keys); ASSERT_EQ(err4.Code(), 0); + auto [res5, err5] = lr->KVExist(keys); + ASSERT_EQ(err5.Code(), 0); } TEST_F(LibruntimeTest, TestAccelerate) @@ -1209,5 +1280,43 @@ TEST_F(LibruntimeTest, TestIsLocalInstances) auto ret = lr->IsLocalInstances(instanceIds); ASSERT_FALSE(ret); } + +TEST_F(LibruntimeTest, TestIsDsHealth) +{ + lr->dsClients.dsStreamStore = nullptr; + lr->dsClients.dsStateStore = nullptr; + ASSERT_TRUE(lr->IsDsHealth()); + auto mockstateStore = std::make_shared(); + lr->dsClients.dsStateStore = mockstateStore; + auto mockStreamStore = std::make_shared(); + lr->dsClients.dsStreamStore = mockStreamStore; + EXPECT_CALL(*mockstateStore, HealthCheck()).WillOnce(Return(ErrorInfo())); + ASSERT_TRUE(lr->IsDsHealth()); +} + +TEST_F(LibruntimeTest, TestIsHealth) +{ + lr->invokeAdaptor = nullptr; + ASSERT_FALSE(lr->IsHealth()); + auto mockInvokeAdaptor = std::make_shared(); + lr->invokeAdaptor = mockInvokeAdaptor; + EXPECT_CALL(*mockInvokeAdaptor, IsHealth()).WillOnce(Return(true)); + ASSERT_TRUE(lr->IsHealth()); +} + +TEST_F(LibruntimeTest, KillAsyncTest) +{ + auto mock_adaptor = std::make_shared(); + lr->invokeAdaptor = mock_adaptor; + EXPECT_CALL(*mock_adaptor, KillAsyncCB(_, _, _, _)) + .WillOnce(Invoke([](const std::string &instanceId, const std::string &payload, int signal, + std::function cb) { cb(YR::Libruntime::ErrorInfo()); })); + auto promise = std::make_shared>(); + auto f = promise->get_future(); + lr->KillAsync("instanceId", 1, [promise](const ErrorInfo &err) { promise->set_value(err); }); + auto status = f.wait_for(std::chrono::milliseconds(100)); + EXPECT_EQ(status, std::future_status::ready); + EXPECT_TRUE(f.get().OK()); +} } // namespace test } // namespace YR diff --git a/test/libruntime/limiter_consistant_hash_test.cpp b/test/libruntime/limiter_consistant_hash_test.cpp new file mode 100644 index 0000000..3f885e3 --- /dev/null +++ b/test/libruntime/limiter_consistant_hash_test.cpp @@ -0,0 +1,252 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "src/libruntime/invoke_spec.h" +#include "src/libruntime/utils/utils.h" +#include "src/utility/logger/logger.h" +#define private public +#include "src/libruntime/invokeadaptor/limiter_consistant_hash.h" + +namespace YR { +namespace test { +using namespace YR::Libruntime; +using namespace YR::utility; + +class LimiterCsHashTest : public ::testing::Test { +public: + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + std::shared_ptr lb(LoadBalancer::Factory(LoadBalancerType::ConsistantRoundRobin)); + limiterHash = std::make_shared(lb); + } + + void TearDown() override + { + limiterHash.reset(); + } + std::shared_ptr limiterHash; +}; + +TEST_F(LimiterCsHashTest, AddTest) +{ + limiterHash->Add("schedulerName", "schedulerId"); + EXPECT_FALSE(limiterHash->Next("funcId", false).empty()); + EXPECT_EQ(limiterHash->Next("funcId"), "schedulerId"); + + limiterHash->Add("schedulerName", "schedulerId1"); + EXPECT_EQ(limiterHash->Next("funcId", false), "schedulerId1"); +} + +TEST_F(LimiterCsHashTest, RemoveTest) +{ + limiterHash->Add("schedulerName", "schedulerId"); + limiterHash->Remove("schedulerName"); + // remove after add, hash has no scheduler, then res is empty + EXPECT_TRUE(limiterHash->Next("funcId").empty()); + + limiterHash->Add("schedulerName", "schedulerId"); + // add schedulername <-> schedulerid, then res is schedulerid + EXPECT_EQ(limiterHash->Next("funcId", false), "schedulerId"); + limiterHash->Add("schedulerName", "anotherSchedulerId"); + // update schedulername <-> anotherSchedulerId, then res is anotherSchedulerId + EXPECT_EQ(limiterHash->Next("funcId", false), "anotherSchedulerId"); + + limiterHash->Add("schedulername_1", "anotherSchedulerId_1"); + // add new schedulername_1 <-> anotherSchedulerId_1, if move flag is false, then res is anotherSchedulerId, + // otherwise need move to next anchorpoint and res is anotherSchedulerId_1 + EXPECT_EQ(limiterHash->Next("funcId", false), "anotherSchedulerId"); + EXPECT_EQ(limiterHash->Next("funcId", true), "anotherSchedulerId_1"); + + limiterHash->Remove("schedulerName"); + // remove schedulername <-> schedulerid, hash has only schedulername_1 <-> anotherSchedulerId_1, then Whether move + // is false or true, the result is always anotherSchedulerId_1 + EXPECT_EQ(limiterHash->Next("funcId", false), "anotherSchedulerId_1"); + EXPECT_EQ(limiterHash->Next("funcId", true), "anotherSchedulerId_1"); +} + +TEST_F(LimiterCsHashTest, RemoveTest2) +{ + limiterHash->Add("schedulerName", "schedulerId"); + limiterHash->Add("schedulerName1", "schedulerId1"); + limiterHash->Add("schedulerName2", "schedulerId2"); + auto schedulerID = limiterHash->Next("funcId", true); + std::cout << "schedulerid is : " << schedulerID << std::endl; + EXPECT_EQ(schedulerID.empty(), false); + std::vector> vec{ + {std::make_shared(SchedulerInstance{ + .InstanceID = "schedulerId", .updateTime = YR::GetCurrentTimestampNs(), .isAvailable = false}), + std::make_shared(SchedulerInstance{ + .InstanceID = "schedulerId1", .updateTime = YR::GetCurrentTimestampNs(), .isAvailable = false}), + std::make_shared(SchedulerInstance{ + .InstanceID = "schedulerId2", .updateTime = YR::GetCurrentTimestampNs(), .isAvailable = false})}}; + auto schedulerInfo = std::make_shared(); + schedulerInfo->schedulerInstanceList = vec; + EXPECT_EQ(limiterHash->Next("funcId", schedulerInfo, true), ALL_SCHEDULER_UNAVAILABLE); +} + +TEST_F(LimiterCsHashTest, RemoveAllTest) +{ + limiterHash->Add("schedulerName0", "schedulerId0"); + limiterHash->Add("schedulerName1", "schedulerId1"); + limiterHash->Add("schedulerName2", "schedulerId2"); + limiterHash->Add("schedulerName3", "schedulerId3"); + limiterHash->RemoveAll(); + EXPECT_TRUE(limiterHash->Next("scheduleId").empty()); +} + +TEST_F(LimiterCsHashTest, NextRetryTest) +{ + limiterHash->Add("schedulerName", ""); + limiterHash->Add("schedulerName1", "schedulerId1"); + auto schedulerId = limiterHash->NextRetry("scheduleId"); + std::cout << "scheduler id is " << schedulerId << std::endl; + EXPECT_TRUE(!schedulerId.empty()); +} + +TEST_F(LimiterCsHashTest, NextTest1) +{ + limiterHash->Add("schedulerName", ""); + auto schedulerId = limiterHash->NextRetry("scheduleId"); + std::cout << "scheduler id1 is : " << schedulerId << std::endl; + EXPECT_TRUE(schedulerId.empty()); + // 添加 "schedulerName1" <-> "schedulerId1", 返回schedulerId1 + limiterHash->Add("schedulerName1", "schedulerId1"); + schedulerId = limiterHash->NextRetry("func1", true); + std::cout << "scheduler id2 is : " << schedulerId << std::endl; + EXPECT_TRUE(!schedulerId.empty()); + // 入参中增加"schedulerId1", 更新时间晚于add时间,返回"AllSchedulerUnavailable" + auto updateTime = YR::GetCurrentTimestampNs(); + auto schedulerInfo = std::make_shared(); + std::vector> vec1{std::make_shared( + SchedulerInstance{.InstanceID = "schedulerId1", .updateTime = updateTime, .isAvailable = false})}; + schedulerInfo->schedulerInstanceList = vec1; + schedulerId = limiterHash->NextRetry("func1", schedulerInfo, true); + std::cout << "scheduler id2 is : " << schedulerId << std::endl; + EXPECT_EQ(schedulerId, "AllSchedulerUnavailable"); + // 重新添加"schedulerName1" <-> "schedulerId1", hash环中add时间更新,返回"schedulerId1" + std::vector> vec2{std::make_shared( + SchedulerInstance{.InstanceID = "schedulerId1", .updateTime = updateTime, .isAvailable = false})}; + limiterHash->Add("schedulerName1", "schedulerId1"); + schedulerInfo->schedulerInstanceList = vec2; + schedulerId = limiterHash->NextRetry("func1", schedulerInfo, true); + std::cout << "scheduler id3 is : " << schedulerId << std::endl; + EXPECT_TRUE(schedulerId != "AllSchedulerUnavailable"); + for (auto scheduler : schedulerInfo->schedulerInstanceList) { + EXPECT_TRUE(scheduler->updateTime > updateTime); + EXPECT_TRUE(scheduler->isAvailable); + } + // 再次Next获取,更新入参vec中的update time,返回"AllSchedulerUnavailable" + std::vector> vec{std::make_shared(SchedulerInstance{ + .InstanceID = "schedulerId1", .updateTime = YR::GetCurrentTimestampNs(), .isAvailable = false})}; + schedulerInfo->schedulerInstanceList = vec; + schedulerId = limiterHash->NextRetry("func1", schedulerInfo, true); + std::cout << "scheduler id3 is : " << schedulerId << std::endl; + EXPECT_TRUE(schedulerId == "AllSchedulerUnavailable"); + limiterHash->RemoveAll(); +} + +TEST_F(LimiterCsHashTest, NextTest2) +{ + std::unordered_map idMap = { + {"schedulerId1", "schedulerName1"}, + {"schedulerId2", "schedulerName2"}, + {"schedulerId3", "schedulerName3"}, + }; + limiterHash->Add(idMap["schedulerId1"], "schedulerId1"); + limiterHash->Add(idMap["schedulerId2"], "schedulerId2"); + limiterHash->Add(idMap["schedulerId3"], "schedulerId3"); + auto spec = std::make_shared(); + auto schedulerId4 = limiterHash->NextRetry("func1", true); + std::cout << "scheduler id4 is : " << schedulerId4 << std::endl; + EXPECT_TRUE(!schedulerId4.empty()); + EXPECT_TRUE(schedulerId4 != "AllSchedulerUnavailable"); + spec->schedulerInfos->schedulerInstanceList.push_back(std::make_shared(SchedulerInstance{ + .InstanceID = schedulerId4, .updateTime = YR::GetCurrentTimestampNs(), .isAvailable = false})); + auto schedulerId5 = limiterHash->NextRetry("func1", spec->schedulerInfos, true); + std::cout << "scheduler id5 is : " << schedulerId5 << std::endl; + EXPECT_TRUE(!schedulerId5.empty()); + EXPECT_TRUE(schedulerId5 != "AllSchedulerUnavailable"); + EXPECT_TRUE(schedulerId5 != schedulerId4); + EXPECT_EQ(spec->schedulerInfos->schedulerInstanceList.size(), 2); + EXPECT_TRUE(spec->schedulerInfos->schedulerInstanceList[1]->isAvailable); + + spec->schedulerInfos->schedulerInstanceList[1]->isAvailable = false; + auto schedulerId6 = limiterHash->NextRetry("func1", spec->schedulerInfos, true); + std::cout << "scheduler id6 is : " << schedulerId6 << std::endl; + EXPECT_TRUE(!schedulerId6.empty()); + EXPECT_TRUE(schedulerId6 != "AllSchedulerUnavailable"); + EXPECT_TRUE(schedulerId5 != schedulerId6); + EXPECT_EQ(spec->schedulerInfos->schedulerInstanceList.size(), 3); + for (auto &scheduler : spec->schedulerInfos->schedulerInstanceList) { + if (scheduler->InstanceID == schedulerId6) { + EXPECT_TRUE(scheduler->isAvailable); + scheduler->isAvailable = false; + } + } + auto schedulerI7 = limiterHash->NextRetry("func1", spec->schedulerInfos, true); + std::cout << "scheduler id7 is : " << schedulerI7 << std::endl; + EXPECT_TRUE(!schedulerI7.empty()); + EXPECT_TRUE(schedulerI7 == "AllSchedulerUnavailable"); + for (auto scheduler : spec->schedulerInfos->schedulerInstanceList) { + if (scheduler->InstanceID == schedulerId5) { + limiterHash->Add(idMap[schedulerId5], schedulerId5); + } + } + auto schedulerI8 = limiterHash->NextRetry("func1", spec->schedulerInfos, true); + std::cout << "scheduler id8 is : " << schedulerI8 << std::endl; + EXPECT_TRUE(!schedulerI8.empty()); + EXPECT_TRUE(schedulerI8 != "AllSchedulerUnavailable"); +} + +TEST_F(LimiterCsHashTest, ResetAllTest) +{ + limiterHash->Add("schedulerName1", "schedulerId1"); + limiterHash->Add("schedulerName2", "schedulerId2"); + limiterHash->Add("schedulerName3", "schedulerId3"); + + auto vec = std::vector{ + SchedulerInstance{.InstanceName = "schedulerName1", .InstanceID = "schedulerId1"}, + SchedulerInstance{.InstanceName = "schedulerName2", .InstanceID = "schedulerId2"}, + SchedulerInstance{.InstanceName = "schedulerName3", .InstanceID = "schedulerId3"}}; + limiterHash->ResetAll(vec, 0); + EXPECT_TRUE(limiterHash->schedulerInfoMap.find("schedulerName3") != limiterHash->schedulerInfoMap.end()); + + auto vec1 = std::vector{ + SchedulerInstance{.InstanceName = "schedulerName1", .InstanceID = "schedulerId1"}, + SchedulerInstance{.InstanceName = "schedulerName2", .InstanceID = "schedulerId2"}}; + limiterHash->ResetAll(vec1, 0); + EXPECT_TRUE(limiterHash->schedulerInfoMap.find("schedulerName3") == limiterHash->schedulerInfoMap.end()); +} +} // namespace test +} // namespace YR diff --git a/test/libruntime/load_balancer_test.cpp b/test/libruntime/load_balancer_test.cpp new file mode 100644 index 0000000..3faeb24 --- /dev/null +++ b/test/libruntime/load_balancer_test.cpp @@ -0,0 +1,177 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "iostream" + +#include +#include + +#include "src/libruntime/invokeadaptor/load_balancer.h" +#include "src/libruntime/utils/utils.h" +#include "src/utility/logger/logger.h" + +using namespace testing; +using namespace YR::Libruntime; +using namespace YR::utility; + +namespace YR { +namespace test { +class LoadBalancerTest : public ::testing::Test { + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + } + + void TearDown() override {} +}; // namespace test + +TEST_F(LoadBalancerTest, WRRTest) +{ + LoadBalancer *lb = LoadBalancer::Factory(LoadBalancerType::WeightedRoundRobin); + EXPECT_TRUE(lb != nullptr); + { + std::vector> nodes = { + {"a", 40}, {"b", 10}, {"c", 10}, {"d", 10}, {"e", 10}, + }; + std::unordered_map counter; + + int total = 0; + for (auto &n : nodes) { + lb->Add(n.first, n.second); + total += n.second; + counter[n.first] = 0; + } + + for (int i = 0; i < total; ++i) { + auto n = lb->Next(""); + counter[n]++; + } + + for (auto &n : nodes) { + ASSERT_EQ(n.second, counter[n.first]); + } + } + + lb->RemoveAll(); + + { + std::vector> nodes = { + {"ax", 40}, + {"bx", 10}, + }; + std::unordered_map counter; + + int total = 0; + for (auto &n : nodes) { + lb->Add(n.first, n.second); + total += n.second; + counter[n.first] = 0; + } + + for (int i = 0; i < total; ++i) { + auto n = lb->Next(""); + counter[n]++; + } + + for (auto &n : nodes) { + ASSERT_EQ(n.second, counter[n.first]); + } + } + + delete lb; +} + +TEST_F(LoadBalancerTest, CsHashRoundRobinTest) +{ + LoadBalancer *lb = LoadBalancer::Factory(LoadBalancerType::ConsistantRoundRobin); + lb->Add("scheduler1", 0); + lb->Add("scheduler2", 0); + lb->Add("scheduler3", 0); + + auto res1 = lb->Next("function1", false); + auto res2 = lb->Next("function1", false); + ASSERT_EQ(res1, res2); + std::cout << "res 1 is " << res1 << std::endl; + + auto res3 = lb->Next("function1", true); + std::cout << "res 3 is " << res3 << std::endl; + EXPECT_NE(res3, res2); + + auto res4 = lb->Next("function1", true); + std::cout << "res 4 is " << res4 << std::endl; + EXPECT_NE(res3, res4); + + lb->RemoveAll(); + auto res6 = lb->Next("function1", false); + std::cout << "res 6 is " << res6 << std::endl; + ASSERT_EQ(res6.find("scheduler3") != std::string::npos, false); + + lb->Add("scheduler1", 0); + lb->Add("scheduler2", 0); + lb->Add("scheduler3", 0); + + auto res7 = lb->Next("function1", false); + std::cout << "res 7 is " << res7 << std::endl; + + auto res8 = lb->Next("function1", false); + std::cout << "res 8 is " << res8 << std::endl; + ASSERT_EQ(res7, res8); + + auto res9 = lb->Next("function1", true); + std::cout << "res 9 is " << res9 << std::endl; + EXPECT_NE(res9, res8); + + auto res10 = lb->Next("function1", true); + std::cout << "res 10 is " << res10 << std::endl; + EXPECT_NE(res9, res10); + + delete lb; +} + +TEST_F(LoadBalancerTest, CsHashRoundRobinRemoveTest) +{ + LoadBalancer *lb = LoadBalancer::Factory(LoadBalancerType::ConsistantRoundRobin); + lb->Add("scheduler1", 0); + auto res1 = lb->Next("function1", true); + ASSERT_EQ(res1.find("scheduler1") != std::string::npos, true); + + lb->Remove("scheduler1"); + lb->Add("scheduler2", 0); + auto res3 = lb->Next("function1", true); + ASSERT_EQ(res3.find("scheduler2") != std::string::npos, true); + + auto res4 = lb->Next("function1", false); + ASSERT_EQ(res4.find("scheduler2") != std::string::npos, true); + + delete lb; +} +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/libruntime/metrics_adaptor_test.cpp b/test/libruntime/metrics_adaptor_test.cpp index 7f8ec46..9008769 100644 --- a/test/libruntime/metrics_adaptor_test.cpp +++ b/test/libruntime/metrics_adaptor_test.cpp @@ -41,7 +41,7 @@ nlohmann::json GetValidConfig() "backends": [ { "immediatelyExport": { - "name": "Alarm", + "name": "CAAS_Alarm", "enable": true, "custom": { "labels": { @@ -55,14 +55,59 @@ nlohmann::json GetValidConfig() { "fileExporter": { "enable": true, + "fileDir": "/home/sn/metrics/", + "rolling": { + "enable": true, + "maxFiles": 3, + "maxSize": 10000 + }, + "contentType": "LABELS" + } + } + ] + } + }, + { + "immediatelyExport": { + "name": "LingYun", + "enable": true, + "exporters": [ + { + "prometheusPushExporter": { + "enable": true, + "enable": true, + "batchSize": 2, + "batchIntervalSec": 10, + "failureQueueMaxSize": 2, + "failureDataDir": "/home/sn/metrics/failure", + "failureDataFileMaxCapacity": 1, "initConfig": { - "fileDir": "./metrics", - "rolling": { - "enable": true, - "maxFiles": 3, - "maxSize": 10000 - }, - "contentType": "STANDARD" + "ip": "127.0.0.1", + "port": 31061 + } + } + } + ] + } + }, + { + "immediatelyExport": { + "name": "LakeHouse", + "enable": true, + "exporters": [ + { + "aomAlarmExporter": { + "enable": true, + "enable": true, + "batchSize": 2, + "batchIntervalSec": 10, + "failureQueueMaxSize": 2, + "failureDataDir": "/home/sn/metrics/failure", + "failureDataFileMaxCapacity": 1, + "enabledInstruments": ["name"], + "initConfig": { + "ip": "127.0.0.1", + "port": 31061 } } } @@ -81,7 +126,7 @@ nlohmann::json GetUnsupportedConfig() { "backends": [ { - "batchExport": {"name": "Alarm"} + "batchExport": {"name": "CAAS_Alarm"} } ] } @@ -106,7 +151,7 @@ nlohmann::json GetImmedExportNotEnableConfig() "backends": [ { "immediatelyExport": { - "name": "Alarm", + "name": "LingYun", "enable": false, "exporters": [ { @@ -125,6 +170,46 @@ nlohmann::json GetImmedExportNotEnableConfig() return nlohmann::json::parse(str); } +nlohmann::json GetPromExportNotEnableConfig() +{ + const std::string str = R"( +{ + "backends": [ + { + "immediatelyExport": { + "name": "LingYun", + "enable": true, + "exporters": [ + { + "prometheusPushExporter": { + "enable": false, + "ip": "127.0.0.1", + "port": 9091 + } + } + ] + } + } + ] +} + )"; + return nlohmann::json::parse(str); +} + +nlohmann::json GetPrometheusPushExporterConfig() +{ + std::string jsonStr = R"( +{ + "enable": true, + "initConfig": { + "ip": "x", + "port": 0 + } +} + )"; + return nlohmann::json::parse(jsonStr); +} + nlohmann::json GetExportConfigs() { std::string jsonStr = R"( @@ -189,8 +274,9 @@ TEST_F(MetricsAdaptorTest, InitSuccessullyTest) if (path == "") { ASSERT_EQ(1, 2); } - setenv("SNLIB_PATH", path.c_str(), 1); + setenv("SNUSER_LIB_PATH", path.c_str(), 1); setenv("YR_SSL_ENABLE", "true", 1); + setenv("YR_SSL_PASSPHRASE", "YR_SSL_PASSPHRASE", 1); Config::c = Config(); auto nullMeterProvider = MetricsApi::Provider::GetMeterProvider(); auto jsonStr = GetValidConfig(); @@ -199,8 +285,9 @@ TEST_F(MetricsAdaptorTest, InitSuccessullyTest) EXPECT_NE(MetricsApi::Provider::GetMeterProvider(), nullMeterProvider); metricsAdaptor->CleanMetrics(); EXPECT_EQ(MetricsApi::Provider::GetMeterProvider(), nullptr); - unsetenv("SNLIB_PATH"); + unsetenv("SNUSER_LIB_PATH"); unsetenv("YR_SSL_ENABLE"); + unsetenv("YR_SSL_PASSPHRASE"); } TEST_F(MetricsAdaptorTest, UnsupportedInitTest) @@ -209,7 +296,7 @@ TEST_F(MetricsAdaptorTest, UnsupportedInitTest) if (path == "") { ASSERT_EQ(1, 2); } - setenv("SNLIB_PATH", path.c_str(), 1); + setenv("SNUSER_LIB_PATH", path.c_str(), 1); Config::c = Config(); auto nullMeterProvider = MetricsApi::Provider::GetMeterProvider(); auto jsonStr = GetUnsupportedConfig(); @@ -218,7 +305,7 @@ TEST_F(MetricsAdaptorTest, UnsupportedInitTest) EXPECT_NE(MetricsApi::Provider::GetMeterProvider(), nullMeterProvider); metricsAdaptor->CleanMetrics(); EXPECT_EQ(MetricsApi::Provider::GetMeterProvider(), nullptr); - unsetenv("SNLIB_PATH"); + unsetenv("SNUSER_LIB_PATH"); } TEST_F(MetricsAdaptorTest, InvalidInitTest) @@ -227,14 +314,14 @@ TEST_F(MetricsAdaptorTest, InvalidInitTest) if (path == "") { ASSERT_EQ(1, 2); } - setenv("SNLIB_PATH", path.c_str(), 1); + setenv("SNUSER_LIB_PATH", path.c_str(), 1); Config::c = Config(); auto jsonStr = GetInvalidConfig(); auto metricsAdaptor = std::make_shared(); metricsAdaptor->Init(jsonStr, true); metricsAdaptor->CleanMetrics(); EXPECT_EQ(MetricsApi::Provider::GetMeterProvider(), nullptr); - unsetenv("SNLIB_PATH"); + unsetenv("SNUSER_LIB_PATH"); } TEST_F(MetricsAdaptorTest, InitNotEnableTest) @@ -243,7 +330,7 @@ TEST_F(MetricsAdaptorTest, InitNotEnableTest) if (path == "") { ASSERT_EQ(1, 2); } - setenv("SNLIB_PATH", path.c_str(), 1); + setenv("SNUSER_LIB_PATH", path.c_str(), 1); Config::c = Config(); auto nullMeterProvider = MetricsApi::Provider::GetMeterProvider(); auto jsonStr = GetImmedExportNotEnableConfig(); @@ -252,7 +339,14 @@ TEST_F(MetricsAdaptorTest, InitNotEnableTest) EXPECT_NE(MetricsApi::Provider::GetMeterProvider(), nullMeterProvider); metricsAdaptor->CleanMetrics(); EXPECT_EQ(MetricsApi::Provider::GetMeterProvider(), nullptr); - unsetenv("SNLIB_PATH"); + + auto jsonStr2 = GetPromExportNotEnableConfig(); + auto metricsAdaptor2 = std::make_shared(); + metricsAdaptor2->Init(jsonStr2, true); + EXPECT_NE(MetricsApi::Provider::GetMeterProvider(), nullMeterProvider); + metricsAdaptor2->CleanMetrics(); + EXPECT_EQ(MetricsApi::Provider::GetMeterProvider(), nullptr); + unsetenv("SNUSER_LIB_PATH"); } TEST_F(MetricsAdaptorTest, DoubleGaugeTest) @@ -261,7 +355,7 @@ TEST_F(MetricsAdaptorTest, DoubleGaugeTest) if (path == "") { ASSERT_EQ(1, 2); } - setenv("SNLIB_PATH", path.c_str(), 1); + setenv("SNUSER_LIB_PATH", path.c_str(), 1); Config::c = Config(); auto jsonStr = GetValidConfig(); auto metricsAdaptor = std::make_shared(); @@ -277,7 +371,7 @@ TEST_F(MetricsAdaptorTest, DoubleGaugeTest) EXPECT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_OK); metricsAdaptor->CleanMetrics(); EXPECT_EQ(MetricsApi::Provider::GetMeterProvider(), nullptr); - unsetenv("SNLIB_PATH"); + unsetenv("SNUSER_LIB_PATH"); } TEST_F(MetricsAdaptorTest, SetAlarmTest) @@ -286,7 +380,7 @@ TEST_F(MetricsAdaptorTest, SetAlarmTest) if (path == "") { ASSERT_EQ(1, 2); } - setenv("SNLIB_PATH", path.c_str(), 1); + setenv("SNUSER_LIB_PATH", path.c_str(), 1); Config::c = Config(); auto jsonStr = GetValidConfig(); auto metricsAdaptor = std::make_shared(); @@ -301,7 +395,7 @@ TEST_F(MetricsAdaptorTest, SetAlarmTest) EXPECT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_OK); metricsAdaptor->CleanMetrics(); EXPECT_EQ(MetricsApi::Provider::GetMeterProvider(), nullptr); - unsetenv("SNLIB_PATH"); + unsetenv("SNUSER_LIB_PATH"); } TEST_F(MetricsAdaptorTest, DoubleCounterTest) @@ -310,7 +404,7 @@ TEST_F(MetricsAdaptorTest, DoubleCounterTest) if (path == "") { ASSERT_EQ(1, 2); } - setenv("SNLIB_PATH", path.c_str(), 1); + setenv("SNUSER_LIB_PATH", path.c_str(), 1); Config::c = Config(); auto jsonStr = GetValidConfig(); auto metricsAdaptor = std::make_shared(); @@ -334,7 +428,7 @@ TEST_F(MetricsAdaptorTest, DoubleCounterTest) EXPECT_EQ(res.second, 0); metricsAdaptor->CleanMetrics(); EXPECT_EQ(MetricsApi::Provider::GetMeterProvider(), nullptr); - unsetenv("SNLIB_PATH"); + unsetenv("SNUSER_LIB_PATH"); } TEST_F(MetricsAdaptorTest, UInt64CounterTest) @@ -343,7 +437,7 @@ TEST_F(MetricsAdaptorTest, UInt64CounterTest) if (path == "") { ASSERT_EQ(1, 2); } - setenv("SNLIB_PATH", path.c_str(), 1); + setenv("SNUSER_LIB_PATH", path.c_str(), 1); Config::c = Config(); auto jsonStr = GetValidConfig(); auto metricsAdaptor = std::make_shared(); @@ -367,7 +461,7 @@ TEST_F(MetricsAdaptorTest, UInt64CounterTest) EXPECT_EQ(res.second, 0); metricsAdaptor->CleanMetrics(); EXPECT_EQ(MetricsApi::Provider::GetMeterProvider(), nullptr); - unsetenv("SNLIB_PATH"); + unsetenv("SNUSER_LIB_PATH"); } TEST_F(MetricsAdaptorTest, MetricsFailedTest) @@ -417,6 +511,26 @@ TEST_F(MetricsAdaptorTest, contextTest) ASSERT_EQ(result, ""); } +TEST_F(MetricsAdaptorTest, InitHttpExporterWithTLS) +{ + auto metricsAdaptor = std::make_shared(); + + setenv("YR_SSL_ENABLE", "true", 1); + setenv("YR_SSL_ROOT_FILE", "root", 1); + setenv("YR_SSL_CERT_FILE", "cert", 1); + setenv("YR_SSL_KEY_FILE", "key", 1); + setenv("YR_SSL_PASSPHRASE", "123", 1); + + YR::Libruntime::Config::c = YR::Libruntime::Config(); + + auto config = GetPrometheusPushExporterConfig(); + auto ret = metricsAdaptor->InitHttpExporter("prometheusPushExporter", "key", "name", config); + ASSERT_TRUE(ret == nullptr); + auto value = std::getenv("YR_SSL_PASSPHRASE"); + ASSERT_TRUE(value != nullptr); + ASSERT_EQ(std::string(value), ""); +} + TEST_F(MetricsAdaptorTest, BuildExportConfigsTest) { auto metricsAdaptor = std::make_shared(); diff --git a/test/libruntime/mock/mock_datasystem.h b/test/libruntime/mock/mock_datasystem.h index 1600fba..baaef48 100644 --- a/test/libruntime/mock/mock_datasystem.h +++ b/test/libruntime/mock/mock_datasystem.h @@ -23,6 +23,7 @@ #include "src/libruntime/heterostore/hetero_store.h" #include "src/libruntime/objectstore/object_store.h" #include "src/libruntime/statestore/state_store.h" +#include "src/libruntime/streamstore/stream_store.h" namespace YR { namespace Libruntime { @@ -32,7 +33,9 @@ public: MOCK_METHOD(ErrorInfo, Init, (const std::string &addr, int port, bool enableDsAuth, bool encryptEnable, const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, - const std::string &dsPublicKey, std::int32_t connectTimeout),(override)); + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, const std::string &ak, + const datasystem::SensitiveValue &sk, std::int32_t connectTimeout), + (override)); MOCK_METHOD1(Init, ErrorInfo(datasystem::ConnectOptions &options)); MOCK_METHOD(ErrorInfo, Put, (std::shared_ptr data, const std::string &objID, @@ -47,7 +50,9 @@ public: MOCK_METHOD((std::pair>), DecreGlobalReference, (const std::vector &objectIds, const std::string &remoteId), (override)); MOCK_METHOD(std::vector, QueryGlobalReference, (const std::vector &objectIds), (override)); + MOCK_METHOD(ErrorInfo, ReleaseGRefs, (const std::string &remoteId), (override)); MOCK_METHOD(ErrorInfo, GenerateKey, (std::string & key, const std::string &prefix, bool isPut), (override)); + MOCK_METHOD(ErrorInfo, GetPrefix, (const std::string &key, std::string &prefix), (override)); MOCK_METHOD(ErrorInfo, CreateBuffer, (const std::string &objectId, size_t dataSize, std::shared_ptr &dataBuf, const CreateParam &createParam), @@ -59,6 +64,8 @@ public: MOCK_METHOD(void, SetTenantId, (const std::string &tenantId), (override)); MOCK_METHOD(void, Clear, (), (override)); MOCK_METHOD(void, Shutdown, (), (override)); + MOCK_METHOD(ErrorInfo, UpdateToken, (datasystem::SensitiveValue token), (override)); + MOCK_METHOD(ErrorInfo, UpdateAkSk, (std::string ak, datasystem::SensitiveValue sk), (override)); }; class MockStateStore : public StateStore { @@ -68,7 +75,8 @@ public: MOCK_METHOD(ErrorInfo, Init, (const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, - const std::string &dsPublicKey, std::int32_t connectTimeout), + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, const std::string &ak, + const datasystem::SensitiveValue &sk, std::int32_t connectTimeout), (override)); MOCK_METHOD1(Init, ErrorInfo(datasystem::ConnectOptions &options)); MOCK_METHOD1(Init, ErrorInfo(const DsConnectOptions &options)); @@ -80,19 +88,47 @@ public: MOCK_METHOD3(Read, MultipleReadResult(const std::vector &keys, int timeoutMS, bool allowPartial)); MOCK_METHOD3(GetWithParam, MultipleReadResult(const std::vector &keys, const GetParams ¶ms, int timeout)); + MOCK_METHOD2(QuerySize, ErrorInfo(const std::vector &keys, std::vector &outSizes)); MOCK_METHOD1(Del, ErrorInfo(const std::string &key)); MOCK_METHOD1(Del, MultipleDelResult(const std::vector &keys)); + MOCK_METHOD1(Exist, MultipleExistResult(const std::vector &keys)); MOCK_METHOD0(Shutdown, void()); + MOCK_METHOD1(UpdateToken, ErrorInfo(datasystem::SensitiveValue token)); + MOCK_METHOD2(UpdateAkSk, ErrorInfo(std::string ak, datasystem::SensitiveValue sk)); MOCK_METHOD1(GenerateKey, ErrorInfo(std::string &returnKey)); + MOCK_METHOD0(StartHealthCheck, ErrorInfo()); + MOCK_METHOD0(HealthCheck, ErrorInfo()); +}; + +class MockStreamStore : public StreamStore { +public: + MOCK_METHOD2(Init, ErrorInfo(const std::string &ip, int port)); + MOCK_METHOD10(Init, + ErrorInfo(const std::string &ip, int port, bool enableDsAuth, bool encryptEnable, + const std::string &runtimePublicKey, const datasystem::SensitiveValue &runtimePrivateKey, + const std::string &dsPublicKey, const datasystem::SensitiveValue &token, + const std::string &ak, const datasystem::SensitiveValue &sk)); + MOCK_METHOD1(Init, ErrorInfo(datasystem::ConnectOptions &options)); + MOCK_METHOD2(Init, ErrorInfo(datasystem::ConnectOptions &options, std::shared_ptr dsStateStore)); + MOCK_METHOD3(CreateStreamProducer, ErrorInfo(const std::string &streamName, + std::shared_ptr &producer, ProducerConf producerConf)); + MOCK_METHOD4(CreateStreamConsumer, ErrorInfo(const std::string &streamName, const SubscriptionConfig &config, + std::shared_ptr &consumer, bool autoAck)); + MOCK_METHOD1(DeleteStream, ErrorInfo(const std::string &streamName)); + MOCK_METHOD2(QueryGlobalProducersNum, ErrorInfo(const std::string &streamName, uint64_t &gProducerNum)); + MOCK_METHOD2(QueryGlobalConsumersNum, ErrorInfo(const std::string &streamName, uint64_t &gProducerNum)); + MOCK_METHOD0(Shutdown, void()); + MOCK_METHOD1(UpdateToken, ErrorInfo(datasystem::SensitiveValue token)); + MOCK_METHOD2(UpdateAkSk, ErrorInfo(std::string ak, datasystem::SensitiveValue sk)); }; class MockHeretoStore : public HeteroStore { public: MOCK_METHOD1(Init, ErrorInfo(datasystem::ConnectOptions & options)); MOCK_METHOD0(Shutdown, void()); - MOCK_METHOD2(Delete, + MOCK_METHOD2(DevDelete, ErrorInfo(const std::vector &objectIds, std::vector &failedObjectIds)); - MOCK_METHOD2(LocalDelete, + MOCK_METHOD2(DevLocalDelete, ErrorInfo(const std::vector &objectIds, std::vector &failedObjectIds)); MOCK_METHOD3(DevSubscribe, ErrorInfo(const std::vector &keys, const std::vector &blob2dList, @@ -105,5 +141,21 @@ public: MOCK_METHOD4(DevMGet, ErrorInfo(const std::vector &keys, const std::vector &blob2dList, std::vector &failedKeys, int32_t timeoutMs)); }; + +class MockStreamProducer : public StreamProducer { +public: + MOCK_METHOD1(Send, ErrorInfo(const Element &element)); + MOCK_METHOD2(Send, ErrorInfo(const Element &element, int64_t timeoutMs)); + MOCK_METHOD0(Flush, ErrorInfo()); + MOCK_METHOD0(Close, ErrorInfo()); +}; + +class MockStreamConsumer : public StreamConsumer { +public: + MOCK_METHOD3(Receive, ErrorInfo(uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements)); + MOCK_METHOD2(Receive, ErrorInfo(uint32_t timeoutMs, std::vector &outElements)); + MOCK_METHOD1(Ack, ErrorInfo(uint64_t elementId)); + MOCK_METHOD0(Close, ErrorInfo()); +}; } // namespace Libruntime } // namespace YR diff --git a/test/libruntime/mock/mock_datasystem_client.cpp b/test/libruntime/mock/mock_datasystem_client.cpp index 1b953b2..f3a4651 100644 --- a/test/libruntime/mock/mock_datasystem_client.cpp +++ b/test/libruntime/mock/mock_datasystem_client.cpp @@ -24,11 +24,108 @@ #include "datasystem/hetero_client.h" #include "datasystem/object_client.h" #include "datasystem/kv_client.h" +#include "datasystem/stream_client.h" namespace datasystem { class ThreadPool { }; +class StreamClientImpl { +}; +StreamClient::StreamClient(std::string ip, int port, const std::string &clientPublicKey, + const SensitiveValue &clientPrivateKey, const std::string &serverPublicKey, + const std::string &accessKey, const SensitiveValue &secretKey) +{ +} + +Status StreamClient::Init(bool reportWorkerLost) +{ + return Status::OK(); +} + +Status StreamClient::CreateProducer(const std::string &streamName, std::shared_ptr &outProducer, + ProducerConf producerConf) +{ + return Status::OK(); +} + +Status StreamClient::Subscribe(const std::string &streamName, const struct SubscriptionConfig &config, + std::shared_ptr &outConsumer, bool autoAck) +{ + // autoAck default is false + return Status::OK(); +} + +Status StreamClient::DeleteStream(const std::string &streamName) +{ + return Status::OK(); +} + +Status StreamClient::QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum) +{ + return Status::OK(); +} + +Status StreamClient::QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum) +{ + return Status::OK(); +} + +Status StreamClient::ShutDown() +{ + return Status::OK(); +} + +Status Producer::Send(const Element &element) +{ + return Status::OK(); +} + +Status Producer::Send(const Element &element, int64_t timeoutMs) +{ + return Status::OK(); +} + +Status Producer::Flush() +{ + return Status::OK(); +} + +Status Producer::Close() +{ + return Status::OK(); +} + +Status Consumer::Receive(uint32_t expectNum, uint32_t timeoutMs, std::vector &outElements) +{ + Element element = {.ptr = nullptr, .size = sizeof(int), .id = ULONG_MAX}; + outElements.emplace_back(element); + if (expectNum == 999) { + for (int i = 1; i < 999; ++i) { + Element element = {.ptr = nullptr, .size = 100, .id = ULONG_MAX}; + outElements.emplace_back(element); + } + } + return Status::OK(); +} + +Status Consumer::Receive(uint32_t timeoutMs, std::vector &outElements) +{ + Element element = {.ptr = nullptr, .size = sizeof(int), .id = ULONG_MAX}; + outElements.emplace_back(element); + return Status::OK(); +} + +Status Consumer::Close() +{ + return Status::OK(); +} + +Status Consumer::Ack(uint64_t elementId) +{ + return Status::OK(); +} + class ObjectClientImpl { }; @@ -92,7 +189,12 @@ int ObjectClient::QueryGlobalRefNum(const std::string &id) return 1; } -Status ObjectClient::GenerateObjectKey(const std::string &prefix, std::string &key) +Status ObjectClient::ReleaseGRefs(const std::string &remoteClientId) +{ + return Status::OK(); +} + +Status ObjectClient::GenerateKey(const std::string &prefix, std::string &key) { key = prefix; return Status::OK(); @@ -163,7 +265,7 @@ Status Buffer::Publish(const std::unordered_set &nestedIds) return Status::OK(); } -class StateClientImpl { +class KVClientImpl { }; KVClient::KVClient(const ConnectOptions &connectOptions){}; @@ -192,9 +294,9 @@ std::string KVClient::Set(const StringView &val, const SetParam &setParam) return "returnKey"; } -Status KVClient::GenerateKey(const std::string &prefixKey, std::string &key) +std::string KVClient::GenerateKey(const std::string &prefixKey) { - return Status::OK(); + return "genKey"; } Status KVClient::Get(const std::string &key, std::string &val, int32_t timeoutMs) @@ -258,11 +360,24 @@ Status KVClient::Del(const std::vector &keys, std::vector &keys, std::vector &exists) +{ + if (keys.empty()) { + return Status(StatusCode::K_INVALID, "The keys are empty"); + } + return Status::OK(); +} + Status KVClient::ShutDown() { return Status::OK(); } +Status KVClient::HealthCheck() +{ + return Status::OK(); +} + Status::Status() noexcept : code_(StatusCode::K_OK) {} Status::Status(StatusCode code, std::string msg) @@ -337,13 +452,13 @@ Status HeteroClient::ShutDown() return Status::OK(); } -Status HeteroClient::MGetH2D(const std::vector &keys, const std::vector &devBlobList, - std::vector &failedKeys, int32_t subTimeoutMs) +Status HeteroClient::MGetH2D(const std::vector &objectIds, const std::vector &devBlobList, + std::vector &failList, int32_t timeoutMs) { return Status::OK(); } -Status HeteroClient::Delete(const std::vector &objectIds, std::vector &failedObjectIds) +Status HeteroClient::DevDelete(const std::vector &objectIds, std::vector &failedObjectIds) { return Status::OK(); } @@ -381,8 +496,8 @@ Status HeteroClient::DevMSet(const std::vector &keys, const std::ve return Status::OK(); } -Status HeteroClient::DevMGet(const std::vector &keys, std::vector &devBlobList, - std::vector &failedKeys, int32_t subTimeoutMs) +Status HeteroClient::DevMGet(const std::vector &keys, std::vector &blob2dList, + std::vector &failedKeys, int32_t timeoutMs) { return Status::OK(); } diff --git a/test/libruntime/mock/mock_fs_intf.h b/test/libruntime/mock/mock_fs_intf.h index 091e724..6cc52a8 100644 --- a/test/libruntime/mock/mock_fs_intf.h +++ b/test/libruntime/mock/mock_fs_intf.h @@ -20,7 +20,7 @@ namespace YR { namespace test { -class MockFSIntfClient : public YR::Libruntime::FSIntf { +class MockGwClient : public YR::Libruntime::FSIntf { public: MOCK_METHOD4(CreateAsync, void(const YR::Libruntime::CreateRequest &, YR::Libruntime::CreateRespCallback, YR::Libruntime::CreateCallBack, int)); @@ -29,6 +29,7 @@ public: MOCK_METHOD5(Start, YR::Libruntime::ErrorInfo(const std::string &, const std::string &, const std::string &, const std::string &, const YR::Libruntime::SubscribeFunc &)); MOCK_METHOD0(Stop, void(void)); + MOCK_METHOD0(IsHealth, bool(void)); MOCK_METHOD3(InvokeAsync, void(const std::shared_ptr &, YR::Libruntime::InvokeCallBack, int)); MOCK_METHOD2(CallResultAsync, void(const std::shared_ptr, diff --git a/test/libruntime/mock/mock_fs_intf_rw.h b/test/libruntime/mock/mock_fs_intf_rw.h index e698a82..9c79629 100644 --- a/test/libruntime/mock/mock_fs_intf_rw.h +++ b/test/libruntime/mock/mock_fs_intf_rw.h @@ -23,6 +23,7 @@ public: MOCK_METHOD(void, Stop, (), (override)); MOCK_METHOD(bool, Available, (), (override)); MOCK_METHOD(bool, Abnormal, (), (override)); + MOCK_METHOD(bool, IsHealth, (), (override)); }; } // namespace test } // namespace YR \ No newline at end of file diff --git a/test/libruntime/mock/mock_fs_intf_with_callback.h b/test/libruntime/mock/mock_fs_intf_with_callback.h index 8a50792..355e672 100644 --- a/test/libruntime/mock/mock_fs_intf_with_callback.h +++ b/test/libruntime/mock/mock_fs_intf_with_callback.h @@ -18,8 +18,8 @@ #include #include -#include "src/libruntime/fsclient/fs_intf.h" #include "src/dto/accelerate.h" +#include "src/libruntime/fsclient/fs_intf.h" #include "src/utility/logger/logger.h" #include "gmock/gmock.h" @@ -84,6 +84,7 @@ public: return ErrorInfo(); }; void Stop(void) override{}; + bool IsHealth(void) override { return true; }; void InvokeAsync(const std::shared_ptr &req, InvokeCallBack callback, int timeout) override { NotifyRequest notifyReq; @@ -105,6 +106,13 @@ public: instanceResp["schedulerTime"] = 1.0; auto smallObject = notifyReq.add_smallobjects(); smallObject->set_value(META_PREFIX + instanceResp.dump()); + } else if (isBatchRenew) { + auto smallObject = notifyReq.add_smallobjects(); + if (isReqNormal) { + smallObject->set_value(META_PREFIX + "{\"errorCode\": 6030,\"instanceAllocSucceed\": {},\"instanceAllocFailed\": {\"leaseId\": {\"errorCode\": 111111,\"errorMessage\": \"sssss\"}},\"leaseInterval\": 1000,\"schedulerTime\": 0.000123337}"); + } else { + smallObject->set_value(""); + } } else { if (isReqNormal) { notifyReq.set_code(::common::ErrorCode::ERR_NONE); @@ -151,13 +159,21 @@ public: } else { resp.set_code(::common::ErrorCode::ERR_SCHEDULE_PLUGIN_CONFIG); } - AccelerateMsgQueueHandle handler{.name = "name"}; - resp.set_message(handler.ToJson()); + if (isGetInstance) { + std::string serializedMeta; + libruntime::FunctionMeta meta; + meta.set_classname("classname"); + meta.SerializeToString(&serializedMeta); + resp.set_message(serializedMeta); + } else { + AccelerateMsgQueueHandle handler{.name = "name"}; + resp.set_message(handler.ToJson()); + } callback(resp, ErrorInfo()); try { killCallbackPromise.set_value(1); - } catch (const std::exception &e) { - std::cout << "Promise already satisfied" << std::endl; + } catch (const std::future_error &e) { + YRLOG_DEBUG("killCallbackPromise has already set value"); } }; void ExitAsync(const ExitRequest &req, ExitCallBack callback) override{}; @@ -165,13 +181,15 @@ public: void StateLoadAsync(const StateLoadRequest &req, StateLoadCallBack callback) override{}; void CreateRGroupAsync(const CreateResourceGroupRequest &req, CreateResourceGroupCallBack callback, int timeout) override{}; - + MOCK_METHOD(void, ReturnCallResult, (const std::shared_ptr result, bool isCreate, CallResultCallBack callback), (override)); bool isReqNormal = true; + bool isGetInstance = false; bool isAcquireResponse = false; + bool isBatchRenew = false; bool needCheckArgs = false; std::promise callbackPromise = std::promise(); std::promise killCallbackPromise = std::promise(); diff --git a/test/libruntime/mock/mock_invoke_adaptor.h b/test/libruntime/mock/mock_invoke_adaptor.h index 1065e53..56b7d03 100644 --- a/test/libruntime/mock/mock_invoke_adaptor.h +++ b/test/libruntime/mock/mock_invoke_adaptor.h @@ -39,6 +39,9 @@ public: MOCK_METHOD3(KillAsync, void(const std::string &instanceId, const std::string &payload, int sigNo)); + MOCK_METHOD4(KillAsyncCB, void(const std::string &instanceId, const std::string &payload, int signal, + std::function cb)); + MOCK_METHOD2(GroupCreate, ErrorInfo(const std::string &groupName, GroupOpts &opts)); MOCK_METHOD2(RangeCreate, ErrorInfo(const std::string &groupName, InstanceRange &range)); @@ -56,9 +59,21 @@ public: MOCK_METHOD1(ExecShutdownCallback, ErrorInfo(uint64_t gracePeriodSec)); + MOCK_METHOD3(AcquireInstance, + std::pair(const std::string &stateId, const FunctionMeta &functionMeta, + InvokeOptions &opts)); + + MOCK_METHOD4(ReleaseInstance, + ErrorInfo(const std::string &leaseId, const std::string &stateId, bool abnormal, InvokeOptions &opts)); + MOCK_METHOD3(GetInstance, std::pair(const std::string &name, const std::string &nameSpace, int timeoutSec)); + MOCK_METHOD3(UpdateSchdulerInfo, + void(const std::string &scheduleName, const std::string &schedulerId, const std::string &option)); + + MOCK_METHOD1(EraseFsIntf, void(const std::string &id)); + MOCK_METHOD0(IsHealth, bool()); }; } // namespace Libruntime } // namespace YR diff --git a/test/libruntime/mock/mock_task_submitter.h b/test/libruntime/mock/mock_task_submitter.h index fdc0177..92b0b07 100644 --- a/test/libruntime/mock/mock_task_submitter.h +++ b/test/libruntime/mock/mock_task_submitter.h @@ -23,6 +23,14 @@ namespace test { class MockTaskSubmitter : public YR::Libruntime::TaskSubmitter { public: MOCK_METHOD0(Init, void(void)); + MOCK_METHOD2(AcquireInstance, std::pair(const std::string &stateId, + std::shared_ptr spec)); + MOCK_METHOD4(ReleaseInstance, ErrorInfo(const std::string &leaseId, const std::string &stateId, bool abnormal, + std::shared_ptr spec)); + MOCK_METHOD3(UpdateSchdulerInfo, + void(const std::string &scheduleName, const std::string &schedulerId, const std::string &option)); + MOCK_METHOD2(UpdateFaaSSchedulerInfo, void(std::string schedulerFuncKey, + const std::vector &schedulerInstanceList)); }; } // namespace test diff --git a/test/libruntime/normal_instance_manager_test.cpp b/test/libruntime/normal_instance_manager_test.cpp index fc72893..476141a 100644 --- a/test/libruntime/normal_instance_manager_test.cpp +++ b/test/libruntime/normal_instance_manager_test.cpp @@ -92,6 +92,7 @@ public: } void TearDown() override { + insManager->Stop(); spec.reset(); mockFsIntf.reset(); insManager.reset(); @@ -196,7 +197,7 @@ TEST_F(NormalInsManagerTest, ScaleCancelAll) insManager->ScaleCancel(resource, newQueue->Size(), true); std::dynamic_pointer_cast(mockFsIntf)->killCallbackFuture.get(); - ASSERT_TRUE(insManager->requestResourceInfoMap[resource]->creatingIns.size() == 0); + ASSERT_TRUE(insManager->requestResourceInfoMap.find(resource) == insManager->requestResourceInfoMap.end()); } TEST_F(NormalInsManagerTest, When_StartNormalInsScaleDownTimer_Twice_Should_Be_Ok) diff --git a/test/libruntime/object_store_test.cpp b/test/libruntime/object_store_test.cpp index 0167ff2..33c0cd6 100644 --- a/test/libruntime/object_store_test.cpp +++ b/test/libruntime/object_store_test.cpp @@ -101,6 +101,24 @@ TEST_F(ObjectStoreTest, GetTest) ASSERT_EQ(res.first.Code(), ErrorCode::ERR_OK); } +TEST_F(ObjectStoreTest, UpdateTokenAndAKSKTest) +{ + datasystem::SensitiveValue token("token"); + ErrorInfo err = objectStore_->UpdateToken(token); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + datasystem::SensitiveValue sk("sk"); + err = objectStore_->UpdateAkSk("ak", sk); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + std::shared_ptr dataBuf = std::make_shared(1000); + CreateParam createParam; + err = objectStore_->CreateBuffer("objID", 1000, dataBuf, createParam); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + err = objectStore_->UpdateToken(token); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + err = objectStore_->UpdateAkSk("ak", sk); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); +} + TEST_F(ObjectStoreTest, IncreGlobalReferenceTest) { std::vector objectIds = {"objID"}; @@ -129,6 +147,12 @@ TEST_F(ObjectStoreTest, QueryGlobalReferenceTest) ASSERT_EQ(res.at(0), 1); } +TEST_F(ObjectStoreTest, ReleaseGRefsTest) +{ + auto err = objectStore_->ReleaseGRefs("remoteID"); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); +} + TEST_F(ObjectStoreTest, GenerateKeyTest) { std::string key; diff --git a/test/libruntime/request_queue_test.cpp b/test/libruntime/request_queue_test.cpp new file mode 100644 index 0000000..de146fc --- /dev/null +++ b/test/libruntime/request_queue_test.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "src/libruntime/invokeadaptor/request_queue.h" + +using namespace testing; +using namespace YR::Libruntime; + +namespace YR { +namespace test { +class RequestQueueTest : public ::testing::Test { +public: + RequestQueueTest() {} + ~RequestQueueTest() {} +}; + +TEST_F(RequestQueueTest, TestPriorityQueue) +{ + auto queue = PriorityQueue(); + ASSERT_TRUE(queue.Empty()); + ASSERT_EQ(queue.Size(), 0); + auto spec = std::make_shared(); + spec->requestId = "1"; + spec->opts.priority = 1; + queue.Push(spec); + auto spec2 = std::make_shared(); + spec2->requestId = "2"; + spec2->opts.priority = 2; + queue.Push(spec2); + ASSERT_FALSE(queue.Empty()); + ASSERT_EQ(queue.Size(), 2); + ASSERT_EQ(queue.Top()->requestId, "2"); + queue.Pop(); + queue.Pop(); + ASSERT_TRUE(queue.Empty()); + ASSERT_EQ(queue.Size(), 0); +} + +TEST_F(RequestQueueTest, TestQueue) +{ + auto queue = Queue(); + ASSERT_TRUE(queue.Empty()); + ASSERT_EQ(queue.Size(), 0); + auto spec = std::make_shared(); + spec->requestId = "1"; + spec->opts.priority = 1; + queue.Push(spec); + auto spec2 = std::make_shared(); + spec2->requestId = "2"; + spec2->opts.priority = 2; + queue.Push(spec2); + ASSERT_FALSE(queue.Empty()); + ASSERT_EQ(queue.Size(), 2); + ASSERT_EQ(queue.Top()->requestId, "1"); + queue.Pop(); + queue.Pop(); + ASSERT_TRUE(queue.Empty()); + ASSERT_EQ(queue.Size(), 0); +} + +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/libruntime/resource_group_test.cpp b/test/libruntime/resource_group_test.cpp index d8a235a..2d13f5e 100644 --- a/test/libruntime/resource_group_test.cpp +++ b/test/libruntime/resource_group_test.cpp @@ -90,7 +90,7 @@ TEST_F(ResourceGroupTest, ResourceGroupManagerTest) ASSERT_TRUE(!isExist); } -TEST_F(ResourceGroupTest, DISABLED_ResourceGroupTest) +TEST_F(ResourceGroupTest, ResourceGroupTest) { std::vector> bundles = { {{"CPU", 500.0}, {"Memory", 200.0}}, {{"CPU", 300.0}}, {}}; diff --git a/test/libruntime/rt_direct_call_test.cpp b/test/libruntime/rt_direct_call_test.cpp index 7ffd10d..b4db57d 100644 --- a/test/libruntime/rt_direct_call_test.cpp +++ b/test/libruntime/rt_direct_call_test.cpp @@ -31,7 +31,7 @@ namespace YR { namespace test { class FakeCallee { public: - FakeCallee(int port) : port(port) + FakeCallee() { handlers_.call = std::bind(&FakeCallee::EmptyCallHandler, this, _1); handlers_.init = [](const std::shared_ptr &){ return CallResponse(); }; @@ -46,7 +46,6 @@ public: void Start(const std::shared_ptr &functionProxy) { Config::Instance().POD_IP() = Config::Instance().HOST_IP(); - Config::Instance().DERICT_RUNTIME_SERVER_PORT() = port; fsClient_ = std::make_shared(Config::Instance().HOST_IP(), functionProxy->GetPort(), handlers_, false, security_, clientsMgr, true); auto err = fsClient_->Start("12345678", "callee", "callee"); @@ -58,6 +57,11 @@ public: fsClient_->SetInitialized(); } + int GetPort() + { + return fsClient_->selfPort; + } + void Stop() { if (fsClient_ != nullptr) { @@ -78,16 +82,16 @@ public: CallResult res; res.set_requestid(req->Immutable().requestid()); res.set_instanceid(req->Immutable().senderid()); - auto result = std::make_shared(); - result->Mutable() = std::move(res); - fsClient_->ReturnCallResult(result, false, [this](const CallResultAck &resp) { - if (resp.code() != common::ERR_NONE) { - YRLOG_WARN("failed to send CallResult, code: {}, message: {}", static_cast(resp.code()), - resp.message()); - } - return; - }); - }, ""); + auto result = std::make_shared(); + result->Mutable() = std::move(res); + fsClient_->ReturnCallResult(result, false, [this](const CallResultAck &resp) { + if (resp.code() != common::ERR_NONE) { + YRLOG_WARN("failed to send CallResult, code: {}, message: {}", fmt::underlying(resp.code()), + resp.message()); + } + return; + }); + }, ""); return resp; } @@ -96,7 +100,7 @@ public: std::shared_ptr fsClient_; std::shared_ptr clientsMgr; std::thread t; - std::shared_ptr security_ = std::make_shared(); + std::shared_ptr security_ = std::make_shared();; int port; }; @@ -124,7 +128,6 @@ public: .modelName = "test", .maxSize = 100, .maxFiles = 1, - .retentionDays = DEFAULT_RETENTION_DAYS, .logFileWithTime = false, .logBufSecs = 30, .maxAsyncQueueSize = 1048510, @@ -135,9 +138,10 @@ public: Config::Instance().RUNTIME_DIRECT_CONNECTION_ENABLE() = true; clientsMgr = std::make_shared(); functionProxy = std::make_shared(Config::Instance().HOST_IP()); - fakeCallee = std::make_shared(calleePort); + fakeCallee = std::make_shared(); functionProxy->Start(); fakeCallee->Start(functionProxy); + calleePort = fakeCallee->GetPort(); } void TearDown() override @@ -213,7 +217,7 @@ public: auto err = notified.WaitForNotification(); EXPECT_TRUE(err.OK()); } - auto [channel, err] = clientsMgr->GetFsConn(Config::Instance().HOST_IP(), calleePort); + auto [channel, err] = clientsMgr->GetFsConn(Config::Instance().HOST_IP(), calleePort, "callee"); EXPECT_TRUE(channel != nullptr); EXPECT_TRUE(err.OK()); } @@ -273,7 +277,7 @@ TEST_F(RTDirectCallTest, FunctionProxyDisconnectedTest) auto err = notified.WaitForNotification(); EXPECT_TRUE(err.OK()); } - auto [channel, err] = clientsMgr->GetFsConn(Config::Instance().HOST_IP(), calleePort); + auto [channel, err] = clientsMgr->GetFsConn(Config::Instance().HOST_IP(), calleePort, "callee"); EXPECT_TRUE(channel != nullptr); EXPECT_TRUE(err.OK()); } diff --git a/test/libruntime/scheduler_instance_info_test.cpp b/test/libruntime/scheduler_instance_info_test.cpp new file mode 100644 index 0000000..96780bf --- /dev/null +++ b/test/libruntime/scheduler_instance_info_test.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include +#include +#include +#include + +#include "src/libruntime/invokeadaptor/scheduler_instance_info.h" + +namespace YR { +namespace test { +class SchedulerInstanceInfoTest : public ::testing::Test { + void SetUp() override {} + + void TearDown() override {} +}; + +std::string g_scheduler_info = R"( +{ + "schedulerFuncKey": "0/faasscheduler/0", + "schedulerInstanceList": [ + { + "instanceName": "faasscheduler0", + "instanceId": "faasscheduler0-id" + }, + { + "instanceName": "faasscheduler1", + "instanceId": "faasscheduler1-id" + }, + { + "instanceName": "faasscheduler2", + "instanceId": "faasscheduler2-id" + } + ] +} +)"; + +TEST_F(SchedulerInstanceInfoTest, ParseSchedulerInfoTest) +{ + Libruntime::SchedulerInfo schedulerInfo; + ParseSchedulerInfo(g_scheduler_info, schedulerInfo); + + ASSERT_EQ(schedulerInfo.schedulerFuncKey, "0/faasscheduler/0"); + ASSERT_EQ(schedulerInfo.schedulerInstanceList.size(), 3); + ASSERT_EQ(schedulerInfo.schedulerInstanceList[0].InstanceName, "faasscheduler0"); + ASSERT_EQ(schedulerInfo.schedulerInstanceList[0].InstanceID, "faasscheduler0-id"); + + ASSERT_EQ(schedulerInfo.schedulerInstanceList[1].InstanceName, "faasscheduler1"); + ASSERT_EQ(schedulerInfo.schedulerInstanceList[1].InstanceID, "faasscheduler1-id"); + + ASSERT_EQ(schedulerInfo.schedulerInstanceList[2].InstanceName, "faasscheduler2"); + ASSERT_EQ(schedulerInfo.schedulerInstanceList[2].InstanceID, "faasscheduler2-id"); + + auto err = ParseSchedulerInfo("err json", schedulerInfo); + ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID); +} +} +} \ No newline at end of file diff --git a/test/libruntime/security_test.cpp b/test/libruntime/security_test.cpp index 45c3a00..5024793 100644 --- a/test/libruntime/security_test.cpp +++ b/test/libruntime/security_test.cpp @@ -40,6 +40,7 @@ struct TestParam { std::string dssPubKey; bool fsEnable; std::string rootCACert; + std::string token; bool fsServerMode; std::string serverNameOverride; }; @@ -54,6 +55,7 @@ TestParam BuildOneCommonTestParam() .dssPubKey = "ds-ser-pub-key", .fsEnable = true, .rootCACert = "root-ca-cert", + .token = "token0", .fsServerMode = true, .serverNameOverride = "server-name-override", }; @@ -68,6 +70,7 @@ std::string BuildOneCommonTlsConfigStr(TestParam tp) tlsConfig.set_dsserverpublickey(tp.dssPubKey); tlsConfig.set_serverauthenable(tp.fsEnable); tlsConfig.set_rootcertdata(tp.rootCACert); + tlsConfig.set_token(tp.token); tlsConfig.set_enableservermode(tp.fsServerMode); tlsConfig.set_servernameoverride(tp.serverNameOverride); auto m = tlsConfig.SerializeAsString(); @@ -115,6 +118,7 @@ TEST_F(SecurityTest, ParseNormalConfigTest) std::string dssPubKey; bool fsEnable; std::string rootCACert; + std::string token; bool fsServerMode; std::string serverNameOverride; } tps[] = { @@ -130,6 +134,7 @@ TEST_F(SecurityTest, ParseNormalConfigTest) .dssPubKey = "ds-ser-pub-key", .fsEnable = true, .rootCACert = "root-ca-cert", + .token = "token0", .fsServerMode = true, .serverNameOverride = "server-name-override", }, @@ -140,6 +145,7 @@ TEST_F(SecurityTest, ParseNormalConfigTest) .dssPubKey = "ds-ser-pub-keyx", .fsEnable = false, .rootCACert = "root-ca-certx", + .token = "tokenx", .fsServerMode = true, .serverNameOverride = "server-name-overridex", }, @@ -159,6 +165,7 @@ TEST_F(SecurityTest, ParseNormalConfigTest) tlsConfig.set_dsserverpublickey(tp.dssPubKey); tlsConfig.set_serverauthenable(tp.fsEnable); tlsConfig.set_rootcertdata(tp.rootCACert); + tlsConfig.set_token(tp.token); tlsConfig.set_enableservermode(tp.fsServerMode); tlsConfig.set_servernameoverride(tp.serverNameOverride); auto m = tlsConfig.SerializeAsString(); @@ -177,6 +184,8 @@ TEST_F(SecurityTest, ParseNormalConfigTest) std::string certChainData; std::string privateKey; bool fsEnable = s.GetFunctionSystemConfig(rootCACert, certChainData, privateKey); + SensitiveValue token; + s.GetToken(token); std::string sernameOverride; bool connMode = s.GetFunctionSystemConnectionMode(sernameOverride); ASSERT_EQ(dsEnable, tp.dsEnable); @@ -184,6 +193,8 @@ TEST_F(SecurityTest, ParseNormalConfigTest) ASSERT_EQ(dscPriKey, tp.dscPriKey); ASSERT_EQ(dssPubKey, tp.dssPubKey); ASSERT_EQ(fsEnable, tp.fsEnable); + ASSERT_EQ(rootCACert, tp.rootCACert); + ASSERT_EQ(token, tp.token); ASSERT_EQ(connMode, tp.fsServerMode); ASSERT_EQ(sernameOverride, tp.serverNameOverride); } @@ -207,10 +218,132 @@ TEST_F(SecurityTest, UpdateHandlerSizeTest) ssize_t nbytes = write(fildes[1], m.data(), m.size()); ASSERT_EQ(nbytes, m.size()); } + auto s = Security(); + auto err = s.Init(); + s.WhenTokenUpdated([](const SensitiveValue &token) {}); + ASSERT_EQ(s.GetUpdateHandersSize(), 0); + close(fildes[0]); + close(fildes[1]); +} + +TEST_F(SecurityTest, UpdateTokenTest) +{ + struct TestParam { + bool dsEnable; + std::string dscPubKey; + std::string dscPriKey; + std::string dssPubKey; + bool fsEnable; + std::string rootCACert; + std::string token; + bool fsServerMode; + std::string serverNameOverride; + } tps[] = { + { + .dsEnable = true, + .dscPubKey = "ds-cli-pub-key2", + .dscPriKey = "ds-cli-pri-key2", + .dssPubKey = "ds-ser-pub-key2", + .fsEnable = true, + .rootCACert = "root-ca-cert2", + .token = "token2", + .fsServerMode = true, + .serverNameOverride = "server-name-override2", + }, + { + .dsEnable = true, + .dscPubKey = "ds-cli-pub-key3", + .dscPriKey = "ds-cli-pri-key3", + .dssPubKey = "ds-ser-pub-key3", + .fsEnable = true, + .rootCACert = "root-ca-cert3", + .token = "token3", + .fsServerMode = true, + .serverNameOverride = "server-name-override3", + }, + }; + + int fildes[2]; + int status = pipe(fildes); + ASSERT_NE(status, -1); + status = dup2(fildes[0], STDIN_FILENO); + ASSERT_NE(status, -1); + + auto s = Security(); + + auto &tp = tps[0]; + + TLSConfig tlsConfig; + tlsConfig.set_dsauthenable(tp.dsEnable); + tlsConfig.set_dsclientpublickey(tp.dscPubKey); + tlsConfig.set_dsclientprivatekey(tp.dscPriKey); + tlsConfig.set_dsserverpublickey(tp.dssPubKey); + tlsConfig.set_serverauthenable(tp.fsEnable); + tlsConfig.set_rootcertdata(tp.rootCACert); + tlsConfig.set_token(tp.token); + tlsConfig.set_enableservermode(tp.fsServerMode); + tlsConfig.set_servernameoverride(tp.serverNameOverride); + auto m = tlsConfig.SerializeAsString(); + YRLOG_ERROR("1 message size: {}", m.size()); + if (m.size()) { + ssize_t nbytes = write(fildes[1], m.data(), m.size()); + ASSERT_EQ(nbytes, m.size()); + } + auto err = s.Init(); + ASSERT_TRUE(err.OK()); + std::string dscPubKey, dssPubKey; + SensitiveValue dscPriKey; + auto [dsEnable, encryptEnable] = s.GetDataSystemConfig(dscPubKey, dscPriKey, dssPubKey); + std::string rootCACert; + std::string certChainData; + std::string privateKey; + bool fsEnable = s.GetFunctionSystemConfig(rootCACert, certChainData, privateKey); + SensitiveValue token; + s.GetToken(token); + std::string sernameOverride; + bool connMode = s.GetFunctionSystemConnectionMode(sernameOverride); + ASSERT_EQ(dsEnable, tp.dsEnable); + ASSERT_EQ(dscPubKey, tp.dscPubKey); + ASSERT_EQ(dscPriKey, tp.dscPriKey); + ASSERT_EQ(dssPubKey, tp.dssPubKey); + ASSERT_EQ(fsEnable, tp.fsEnable); + ASSERT_EQ(rootCACert, tp.rootCACert); + ASSERT_EQ(token, tp.token); + ASSERT_EQ(connMode, tp.fsServerMode); + ASSERT_EQ(sernameOverride, tp.serverNameOverride); + + SensitiveValue tokenTest; + YR::utility::NotificationUtility n; + s.WhenTokenUpdated([&tokenTest, &n](const SensitiveValue &token) { + tokenTest = token; + n.Notify(); + }); + + tp = tps[1]; + + tlsConfig.set_dsauthenable(tp.dsEnable); + tlsConfig.set_dsclientpublickey(tp.dscPubKey); + tlsConfig.set_dsclientprivatekey(tp.dscPriKey); + tlsConfig.set_dsserverpublickey(tp.dssPubKey); + tlsConfig.set_serverauthenable(tp.fsEnable); + tlsConfig.set_rootcertdata(tp.rootCACert); + tlsConfig.set_token(tp.token); + tlsConfig.set_enableservermode(tp.fsServerMode); + tlsConfig.set_servernameoverride(tp.serverNameOverride); + m = tlsConfig.SerializeAsString(); + YRLOG_ERROR("2 message size: {}", m.size()); + if (m.size()) { + ssize_t nbytes = write(fildes[1], m.data(), m.size()); + ASSERT_EQ(nbytes, m.size()); + } + n.WaitForNotification(); + s.GetToken(token); + ASSERT_EQ(token, tp.token); close(fildes[0]); close(fildes[1]); } +// DTS2024071820905 // ENABLE_DS_AUTH stdin Delay Write // there may be a problem if the file descriptor (fd) hasn't been created after start runtime. TEST_F(SecurityTest, DelayStdinShouldSuccessTest) @@ -235,10 +368,12 @@ TEST_F(SecurityTest, DelayStdinShouldSuccessTest) }); auto s = Security(STDIN_FILENO, 1000); auto err = s.Init(); - YRLOG_INFO("security init: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); promise->set_value(true); t.join(); ASSERT_TRUE(err.OK()); + SensitiveValue token; + s.GetToken(token); + ASSERT_EQ(token, tp.token); unsetenv("ENABLE_DS_AUTH"); Libruntime::Config::Instance() = Libruntime::Config(); close(fds[0]); @@ -254,7 +389,6 @@ TEST_F(SecurityTest, NoStdinShouldTimeoutFailedTest) auto m = BuildOneCommonTlsConfigStr(tp); auto s = Security(STDIN_FILENO, 1000); auto err = s.Init(); - YRLOG_INFO("security init: Code:{}, MCode:{}, Msg:{}", err.Code(), err.MCode(), err.Msg()); ASSERT_FALSE(err.OK()); unsetenv("ENABLE_DS_AUTH"); Libruntime::Config::Instance() = Libruntime::Config(); diff --git a/test/libruntime/stream_store_test.cpp b/test/libruntime/stream_store_test.cpp new file mode 100644 index 0000000..7f1d00d --- /dev/null +++ b/test/libruntime/stream_store_test.cpp @@ -0,0 +1,165 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "mock/mock_datasystem_client.h" +#include "src/dto/stream_conf.h" +#include "src/libruntime/libruntime.h" +#include "src/libruntime/streamstore/datasystem_stream_store.h" +#include "src/libruntime/streamstore/stream_producer_consumer.h" +#include "src/proto/libruntime.pb.h" +#include "src/utility/logger/logger.h" + +using namespace testing; +using namespace YR::Libruntime; +using namespace YR::utility; + +namespace YR { +namespace test { +class StreamStoreTest : public testing::Test { +public: + StreamStoreTest(){}; + ~StreamStoreTest(){}; + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + streamStore_ = std::make_shared(); + streamStore_->Init("127,0,0,1", 11111); + InitGlobalTimer(); + } + void TearDown() override + { + CloseGlobalTimer(); + streamStore_->Shutdown(); + } + + std::shared_ptr streamStore_; +}; + +TEST_F(StreamStoreTest, CreateStreamProducer) +{ + std::string streamName = "streamName"; + Libruntime::ProducerConf libProducerConf{}; + libProducerConf.extendConfig.insert({"STREAM_MODE", "MPMC"}); + std::shared_ptr emptyProducer = nullptr; + ErrorInfo err = streamStore_->CreateStreamProducer(streamName, emptyProducer, libProducerConf); + ASSERT_EQ(err.Code(), ErrorCode::ERR_PARAM_INVALID); + auto producer = std::make_shared(); + err = streamStore_->CreateStreamProducer(streamName, producer, libProducerConf); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + libProducerConf.extendConfig.insert({"XXXXX", "XXXXX"}); + err = streamStore_->CreateStreamProducer(streamName, producer, libProducerConf); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + libProducerConf.extendConfig["STREAM_MODE"] = "XXXXX"; + err = streamStore_->CreateStreamProducer(streamName, producer, libProducerConf); + ASSERT_EQ(err.Code(), ErrorCode::ERR_PARAM_INVALID); +} + +TEST_F(StreamStoreTest, CreateStreamConsumer) +{ + std::string streamName = "streamName"; + Libruntime::SubscriptionConfig subscriptionConfig("subName", libruntime::SubscriptionType::STREAM); + std::shared_ptr emptyConsumer = nullptr; + ErrorInfo err = streamStore_->CreateStreamConsumer(streamName, subscriptionConfig, emptyConsumer); + ASSERT_EQ(err.Code(), ErrorCode::ERR_PARAM_INVALID); + auto consumer = std::make_shared(); + err = streamStore_->CreateStreamConsumer(streamName, subscriptionConfig, consumer); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); +} + +TEST_F(StreamStoreTest, DeleteStream) +{ + std::string streamName = "streamName"; + ErrorInfo err = streamStore_->DeleteStream(streamName); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); +} + +TEST_F(StreamStoreTest, QueryGlobalProducersNum) +{ + std::string streamName = "streamName"; + uint64_t gProducerNum = 0; + ErrorInfo err = streamStore_->QueryGlobalProducersNum(streamName, gProducerNum); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); +} + +TEST_F(StreamStoreTest, QueryGlobalConsumersNum) +{ + std::string streamName = "streamName"; + uint64_t gConsumerNum = 0; + ErrorInfo err = streamStore_->QueryGlobalConsumersNum(streamName, gConsumerNum); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); +} + +TEST_F(StreamStoreTest, TestProducer) +{ + auto streamProducer = std::make_shared(); + std::string str = "hello"; + Libruntime::Element element((uint8_t *)(str.c_str()), str.size()); + ErrorInfo err = streamProducer->Send(element); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + err = streamProducer->Send(element, 1000); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + err = streamProducer->Flush(); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + err = streamProducer->Close(); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); +} + +TEST_F(StreamStoreTest, TestConsumer) +{ + auto streamConsumer = std::make_shared(); + std::vector elements; + ErrorInfo err = streamConsumer->Receive(1, 1000, elements); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + err = streamConsumer->Receive(1000, elements); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + err = streamConsumer->Ack(11); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + err = streamConsumer->Close(); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + setenv("STREAM_RECEIVE_LIMIT", "1000", 1); + Libruntime::Config::Instance() = Libruntime::Config(); + err = streamConsumer->Receive(999, 1, elements); + ASSERT_EQ(err.Code(), ErrorCode::ERR_INNER_SYSTEM_ERROR); + unsetenv("STREAM_RECEIVE_LIMIT"); +} + +TEST_F(StreamStoreTest, TestUpdateTokenAndAksk) +{ + ErrorInfo err = streamStore_->UpdateAkSk("ak", datasystem::SensitiveValue("sk")); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); + err = streamStore_->UpdateToken(datasystem::SensitiveValue("token")); + ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); +} + +} // namespace test +} // namespace YR diff --git a/test/libruntime/task_submitter_test.cpp b/test/libruntime/task_submitter_test.cpp index bdc64b0..752147b 100644 --- a/test/libruntime/task_submitter_test.cpp +++ b/test/libruntime/task_submitter_test.cpp @@ -32,6 +32,7 @@ #define protected public #include "src/libruntime/invokeadaptor/instance_manager.h" #define private public +#include "src/libruntime/invokeadaptor/faas_instance_manager.h" #include "src/libruntime/invokeadaptor/normal_instance_manager.h" #include "src/libruntime/invokeadaptor/task_submitter.h" @@ -48,8 +49,8 @@ class TaskSubmitterTest : public testing::Test { public: TaskSubmitterTest(){}; ~TaskSubmitterTest(){}; - static void SetUpTestCase(){}; - static void TearDownTestCase(){}; + static void SetUpTestCase() {}; + static void TearDownTestCase() {}; void SetUp() override { Mkdir("/tmp/log"); @@ -79,7 +80,13 @@ public: KillFunc f = [](const std::string &instanceId, const std::string &payload, int signal) -> ErrorInfo { return ErrorInfo(); }; - taskSubmitter = std::make_shared(librtCfg, memoryStore, fsClient, reqMgr, f); + std::string functionId = ""; + auto clientsMgr = std::make_shared(); + auto security = std::make_shared(); + auto downgrade = std::make_shared(functionId, clientsMgr, security); + downgrade->Init(); + taskSubmitter = + std::make_shared(librtCfg, memoryStore, fsClient, reqMgr, f, nullptr, nullptr, downgrade); } void TearDown() override { @@ -93,11 +100,12 @@ public: protected: bool NeedRetryWrapper(std::shared_ptr &spec, ErrorCode errcode, bool &consume); - std::shared_ptr SetMaxConcurrencyInstanceNum(int concurrencyCreateNum = 100000); + std::shared_ptr SetMaxConcurrencyInstanceNum(int concurrencyCreateNum = 100000); void SubmitFunction(int total, bool differentResource = false); void CommonAssert(std::shared_ptr timerWorker, std::mutex &timersMtx, - std::shared_ptr mockFsIntf, std::vector> &timers, - bool differentResource = false); + std::shared_ptr mockFsIntf, + std::vector> &timers, bool differentResource = false, + int total = 10000); std::shared_ptr taskSubmitter; std::shared_ptr mockFsIntf; }; @@ -115,7 +123,7 @@ TEST_F(TaskSubmitterTest, ScheduleFunction) auto resource = GetRequestResource(spec); taskSubmitter->SubmitFunction(spec); absl::ReaderMutexLock lock(&taskSubmitter->reqMtx_); - ASSERT_EQ(taskSubmitter->waitScheduleReqMap_[resource]->Empty(), false); + ASSERT_EQ(taskSubmitter->taskSchedulerMap_[resource]->queue->Empty(), false); sleep(3); } @@ -142,7 +150,7 @@ TEST_F(TaskSubmitterTest, HandleInvokeNotify) taskSubmitter->HandleInvokeNotify(req, ErrorInfo()); auto resource = GetRequestResource(spec); absl::ReaderMutexLock lock(&taskSubmitter->reqMtx_); - ASSERT_TRUE(taskSubmitter->waitScheduleReqMap_[resource]->Size() <= 1); + ASSERT_TRUE(taskSubmitter->taskSchedulerMap_[resource]->queue->Size() <= 1); } TEST_F(TaskSubmitterTest, HandleFailInvokeNotify) @@ -160,7 +168,7 @@ TEST_F(TaskSubmitterTest, HandleFailInvokeNotify) spec->invokeInstanceId = "insId"; spec->invokeLeaseId = "leaseId"; spec->functionMeta = { - "", "", "funcname", "classname", libruntime::LanguageType::Cpp, "", "", "", libruntime::ApiType::Function}; + "", "", "funcname", "classname", libruntime::LanguageType::Cpp, "", "", "", libruntime::ApiType::Faas}; spec->invokeType = libruntime::InvokeType::InvokeFunctionStateless; auto resource = GetRequestResource(spec); taskSubmitter->HandleFailInvokeNotify(req, spec, resource, ErrorInfo()); @@ -170,10 +178,7 @@ TEST_F(TaskSubmitterTest, HandleFailInvokeNotify) req.set_code(common::ErrorCode::ERR_INSTANCE_EVICTED); { absl::WriterMutexLock lock(&taskSubmitter->reqMtx_); - taskSubmitter->waitScheduleReqMap_[resource] = std::make_shared(); - auto cb = []() {}; - auto taskScheduler = std::make_shared(cb); - taskSubmitter->taskSchedulerMap_[resource] = taskScheduler; + taskSubmitter->taskSchedulerMap_[resource] = std::make_shared(false); } taskSubmitter->HandleFailInvokeNotify(req, spec, resource, ErrorInfo()); auto [rawRequestId, seq] = YR::utility::IDGenerator::DecodeRawRequestId(spec->requestInvoke->Mutable().requestid()); @@ -303,7 +308,7 @@ TEST_F(TaskSubmitterTest, CancelStatelessRequest) auto spec = std::make_shared(); spec->jobId = "jobId"; spec->requestId = requestId; - spec->functionMeta = {.apiType = libruntime::ApiType::Function}; + spec->functionMeta = {.apiType = libruntime::ApiType::Faas}; taskSubmitter->requestManager->PushRequest(spec); taskSubmitter->CancelStatelessRequest(objids, f, true, true); @@ -358,6 +363,7 @@ std::list> GetMockAffinity() resourceRequiredAnti->SetLabelOperators(GetMockLabelOperators()); std::shared_ptr instanceRequiredAnti = std::make_shared(); instanceRequiredAnti->SetLabelOperators(GetMockLabelOperators()); + instanceRequiredAnti->SetAffinityScope(AFFINITYSCOPE_NODE); affinities.push_back(resourcePreferred); affinities.push_back(instancePreferred); affinities.push_back(resourcePreferredAnti); @@ -387,6 +393,8 @@ TEST_F(TaskSubmitterTest, TestAffinity) ASSERT_EQ(inOperator->GetOperatorType(), "LabelIn"); ASSERT_EQ(inOperator->GetKey(), "k1"); ASSERT_EQ(inOperator->GetValues().size(), 2); + std::shared_ptr last = affinities.back(); + ASSERT_EQ(last->GetAffinityScope(), "NODE"); opts.scheduleAffinities = affinities; spec->opts = opts; spec->returnIds = std::vector({DataObject{"obj-id"}}); @@ -394,7 +402,7 @@ TEST_F(TaskSubmitterTest, TestAffinity) auto resource = GetRequestResource(spec); taskSubmitter->SubmitFunction(spec); absl::ReaderMutexLock lock(&taskSubmitter->reqMtx_); - ASSERT_EQ(taskSubmitter->waitScheduleReqMap_[resource]->Empty(), false); + ASSERT_EQ(taskSubmitter->taskSchedulerMap_[resource]->queue->Empty(), false); sleep(3); } @@ -433,18 +441,15 @@ TEST_F(TaskSubmitterTest, ScheduleInsTest) auto resource = GetRequestResource(spec); { absl::WriterMutexLock lock(&taskSubmitter->reqMtx_); - taskSubmitter->waitScheduleReqMap_[resource] = std::make_shared(); - taskSubmitter->waitScheduleReqMap_[resource]->Push(spec); - auto cb = []() {}; - auto taskScheduler = std::make_shared(cb); - taskSubmitter->taskSchedulerMap_[resource] = taskScheduler; + taskSubmitter->taskSchedulerMap_[resource] = std::make_shared(false); + taskSubmitter->taskSchedulerMap_[resource]->queue->Push(spec); } taskSubmitter->ScheduleIns(resource, err, false); absl::ReaderMutexLock lock(&taskSubmitter->reqMtx_); - EXPECT_EQ(taskSubmitter->waitScheduleReqMap_[resource]->Empty(), true); + EXPECT_EQ(taskSubmitter->taskSchedulerMap_[resource]->queue->Empty(), true); } -std::shared_ptr TaskSubmitterTest::SetMaxConcurrencyInstanceNum(int concurrencyCreateNum) +std::shared_ptr TaskSubmitterTest::SetMaxConcurrencyInstanceNum(int concurrencyCreateNum) { // construct taskSubmitter KillFunc f = [](const std::string &instanceId, const std::string &payload, int signal) -> ErrorInfo { @@ -453,28 +458,36 @@ std::shared_ptr TaskSubmitterTest::SetMaxConcurrencyInstanceNu auto reqMgr = std::make_shared(); auto librtCfg = std::make_shared(); librtCfg->maxConcurrencyCreateNum = concurrencyCreateNum; - auto mockFsIntf = std::make_shared(); + auto mockFsIntf = std::make_shared(); auto fsClient = std::make_shared(mockFsIntf); std::shared_ptr memoryStore = std::make_shared(); auto dsObjectStore = std::make_shared(); dsObjectStore->Init("127.0.0.1", 8080); auto wom = std::make_shared(); memoryStore->Init(dsObjectStore, wom); - taskSubmitter = std::make_shared(librtCfg, memoryStore, fsClient, reqMgr, f); + std::string functionId = ""; + auto clientsMgr = std::make_shared(); + auto security = std::make_shared(); + auto downgrade = std::make_shared(functionId, clientsMgr, security); + downgrade->Init(); + taskSubmitter = + std::make_shared(librtCfg, memoryStore, fsClient, reqMgr, f, nullptr, nullptr, downgrade); return mockFsIntf; } void TaskSubmitterTest::SubmitFunction(int total, bool differentResource) { - if (differentResource) { - total = total / 2; - } for (int i = 0; i < total; i++) { auto spec = std::make_shared(); spec->jobId = "job-7c8e6fab"; spec->functionMeta = { "", "", "funcname", "classname", libruntime::LanguageType::Cpp, "", "", "", libruntime::ApiType::Function}; - spec->opts = {}; + InvokeOptions options = {}; + if (differentResource) { + options.instanceSession = std::make_shared(); + options.instanceSession->sessionID = YR::utility::IDGenerator::GenRequestId(); + } + spec->opts = options; spec->returnIds = std::vector({DataObject{"obj-id"}}); spec->invokeArgs = std::vector(); spec->requestId = YR::utility::IDGenerator::GenRequestId(); @@ -483,32 +496,13 @@ void TaskSubmitterTest::SubmitFunction(int total, bool differentResource) taskSubmitter->requestManager->PushRequest(spec); taskSubmitter->SubmitFunction(spec); } - if (!differentResource) { - return; - } - for (int i = 0; i < total; i++) { - auto spec = std::make_shared(); - spec->jobId = "job-7c8e6fab"; - spec->functionMeta = { - "", "", "funcname", "classname", libruntime::LanguageType::Cpp, "", "", "", libruntime::ApiType::Function}; - spec->opts = {}; - spec->opts.cpu = 1000; - spec->opts.memory = 2000; - spec->opts.customExtensions["Concurrency"] = "3"; - spec->returnIds = std::vector({DataObject{"obj-id"}}); - spec->invokeArgs = std::vector(); - spec->requestId = YR::utility::IDGenerator::GenRequestId(); - LibruntimeConfig config; - spec->BuildInstanceInvokeRequest(config); - taskSubmitter->SubmitFunction(spec); - } } void TaskSubmitterTest::CommonAssert(std::shared_ptr timerWorker, std::mutex &timersMtx, - std::shared_ptr mockFsIntf, - std::vector> &timers, bool differentResource) + std::shared_ptr mockFsIntf, + std::vector> &timers, bool differentResource, + int total) { - int total = 10000; std::mutex mtx; int get = 0; auto promise = std::promise(); @@ -532,7 +526,7 @@ void TaskSubmitterTest::CommonAssert(std::shared_ptr timerWorker, s timers.push_back(timer); }); EXPECT_CALL(*mockFsIntf, KillAsync(_, _, _)).WillRepeatedly(Return()); - SubmitFunction(total); + SubmitFunction(total, differentResource); ASSERT_TRUE(future.get()); auto instanceIds = taskSubmitter->GetInstanceIds(); std::cout << " create " << instanceIds.size() << std::endl; @@ -576,12 +570,14 @@ TEST_F(TaskSubmitterTest, ScheduleFunction_Benchmark) EXPECT_CALL(*mockFsIntf, InvokeAsync(_, _, _)) .WillRepeatedly( [](const std::shared_ptr &req, InvokeCallBack callback, int timeoutSec) { return; }); + EXPECT_CALL(*mockFsIntf, KillAsync(_, _, _)).WillRepeatedly(Return()); // benchmark auto start = std::chrono::high_resolution_clock::now(); size_t total = 50000; if (const char *env = std::getenv("YR_BENCHMARK_SCALE")) { total = std::stoi(env); } + SubmitFunction(total); auto submitEnd = std::chrono::high_resolution_clock::now(); std::chrono::duration submitCostMs = submitEnd - start; @@ -669,7 +665,7 @@ TEST_F(TaskSubmitterTest, ScheduleFunction_CreateRandomAbnormal) CommonAssert(timerWorker, timersMtx, mockFsIntf, timers); } -TEST_F(TaskSubmitterTest, ScheduleFunction_DifferentResource) +TEST_F(TaskSubmitterTest, ScheduleFunction_DifferentSessionResource) { auto mockFsIntf = SetMaxConcurrencyInstanceNum(10000); auto timerWorker = std::make_shared(); @@ -694,7 +690,151 @@ TEST_F(TaskSubmitterTest, ScheduleFunction_DifferentResource) std::lock_guard lockGuard(timersMtx); timers.push_back(timer); }); - CommonAssert(timerWorker, timersMtx, mockFsIntf, timers, true); + CommonAssert(timerWorker, timersMtx, mockFsIntf, timers, true, 3000); +} + +TEST_F(TaskSubmitterTest, CancelFaasScheduleTimeoutReqTest) +{ + ASSERT_NO_THROW(taskSubmitter->CancelFaasScheduleTimeoutReq("reqId", 100 * 1000)); + + auto spec = std::make_shared(); + spec->requestId = "reqId"; + spec->returnIds = {DataObject("objId")}; + auto resource = GetRequestResource(spec); + taskSubmitter->taskSchedulerMap_[resource] = std::make_shared(false); + taskSubmitter->taskSchedulerMap_[resource]->queue->Push(spec); + taskSubmitter->requestManager->PushRequest(spec); + taskSubmitter->CancelFaasScheduleTimeoutReq("reqId", 100 * 1000); + ASSERT_EQ(taskSubmitter->requestManager->GetRequest("reqId"), nullptr); + + taskSubmitter->requestManager->PushRequest(spec); + resource = GetRequestResource(spec); + taskSubmitter->taskSchedulerMap_[resource] = std::make_shared(false); + taskSubmitter->taskSchedulerMap_[resource]->SetLastError( + ErrorInfo(ErrorCode::ERR_PARAM_INVALID, ModuleCode::RUNTIME, "error")); + taskSubmitter->CancelFaasScheduleTimeoutReq("reqId", 100 * 1000); + ASSERT_EQ(taskSubmitter->requestManager->GetRequest("reqId"), nullptr); +} + +TEST_F(TaskSubmitterTest, EraseFaasCancelTimerTest) +{ + taskSubmitter->faasCancelTimerWorkers["reqId"] = YR::utility::ExecuteByGlobalTimer([]() -> void {}, 1, 1); + ASSERT_EQ(taskSubmitter->faasCancelTimerWorkers.size(), 1); + taskSubmitter->EraseFaasCancelTimer("reqId"); + ASSERT_EQ(taskSubmitter->faasCancelTimerWorkers.size(), 0); +} + +TEST_F(TaskSubmitterTest, UpdateSchdulerInfoTest) +{ + ASSERT_NO_THROW(taskSubmitter->UpdateSchdulerInfo("schedulerName", "schedulerId", "ADD")); + ASSERT_NO_THROW(taskSubmitter->UpdateSchdulerInfo("schedulerName", "schedulerId", "REMOVE")); + ASSERT_NO_THROW(taskSubmitter->UpdateSchdulerInfo("schedulerName", "schedulerId", "UNKOWN")); +} + +TEST_F(TaskSubmitterTest, RecordFaasInvokeDataAndUpdateTest) +{ + taskSubmitter->metricsEnable_ = true; + ASSERT_NO_THROW(taskSubmitter->UpdateFaasInvokeLog("reqId", ErrorInfo())); + auto spec = std::make_shared(); + spec->requestId = "reqId"; + spec->functionMeta = YR::Libruntime::FunctionMeta{.funcName = "funcName"}; + taskSubmitter->RecordFaasInvokeData(spec); + ASSERT_TRUE(!taskSubmitter->faasInvokeDataMap_.empty()); + ASSERT_NO_THROW(taskSubmitter->UpdateFaasInvokeLog("reqId", ErrorInfo())); + taskSubmitter->faasInvokeDataMap_.erase("reqId"); + ASSERT_NO_THROW(taskSubmitter->UpdateFaasInvokeLog( + "reqId", ErrorInfo(ErrorCode::ERR_PARAM_INVALID, ModuleCode::RUNTIME, "err msg"))); +} + +TEST_F(TaskSubmitterTest, InvokeExceedMaxRetryTimeShouldNotStuck) +{ + auto mockFsIntf = SetMaxConcurrencyInstanceNum(100000); + // mock function + EXPECT_CALL(*mockFsIntf, CreateAsync(_, _, _, _)) + .WillRepeatedly( + [](const CreateRequest &req, CreateRespCallback respCallback, CreateCallBack callback, int timeoutSec) { + CreateResponse response; + auto instanceId = YR::utility::IDGenerator::GenRequestId(); + response.set_instanceid(instanceId); + response.set_code(::common::ErrorCode::ERR_NONE); + respCallback(response); + NotifyRequest notifyReq; + notifyReq.set_requestid(req.requestid()); + notifyReq.set_code(::common::ErrorCode::ERR_NONE); + callback(notifyReq); + }); + std::mutex mtx; + int get = 0; + int retry = 5; + int want = retry + 1; + auto promise = std::promise(); + auto future = promise.get_future(); + EXPECT_CALL(*mockFsIntf, InvokeAsync(_, _, _)) + .WillRepeatedly([&](const std::shared_ptr &req, InvokeCallBack callback, int timeoutSec) { + NotifyRequest notifyReq; + notifyReq.set_requestid(req->Immutable().requestid()); + notifyReq.set_code(::common::ErrorCode::ERR_INSTANCE_EVICTED); + callback(notifyReq, ErrorInfo()); + std::lock_guard lockGuard(mtx); + get++; + if (get == want) { + promise.set_value(true); + } + }); + EXPECT_CALL(*mockFsIntf, KillAsync(_, _, _)).WillRepeatedly(Return()); + auto spec = std::make_shared(); + spec->jobId = "job-7c8e6fab"; + spec->functionMeta = { + "", "", "funcname", "classname", libruntime::LanguageType::Cpp, "", "", "", libruntime::ApiType::Function}; + InvokeOptions options = {}; + options.maxRetryTime = retry; + spec->opts = options; + spec->invokeType = libruntime::InvokeType::InvokeFunctionStateless; + spec->returnIds = std::vector({DataObject{"obj-id"}}); + spec->invokeArgs = std::vector(); + spec->requestId = YR::utility::IDGenerator::GenRequestId(); + LibruntimeConfig config; + spec->BuildInstanceInvokeRequest(config); + taskSubmitter->requestManager->PushRequest(spec); + taskSubmitter->SubmitFunction(spec); + ASSERT_TRUE(future.get()); + ASSERT_EQ(get, want); +} + +TEST_F(TaskSubmitterTest, ConvertToGaugeDataTest) +{ + auto data = std::make_shared( + YR::Libruntime::FaasInvokeData("tenantId", "functionName", "inputAlias", "traceId", 0)); + auto res = taskSubmitter->ConvertToGaugeData(data, "reqId"); + ASSERT_EQ(res.name, "report_faas_invoke_data"); +} + +TEST_F(TaskSubmitterTest, AddFaasCancelTimerTest) +{ + auto spec = std::make_shared(); + spec->opts.acquireTimeout = 300; + spec->functionMeta.apiType = libruntime::ApiType::Faas; + ASSERT_NO_THROW(taskSubmitter->AddFaasCancelTimer(spec)); +} + +TEST_F(TaskSubmitterTest, SendInvokeReqTest) +{ + auto spec = std::make_shared(); + spec->opts.device = YR::Libruntime::Device{.name = "deviceName", .batch_size = 1}; + spec->functionMeta.apiType = libruntime::ApiType::Faas; + auto resource = GetRequestResource(spec); + ASSERT_NO_THROW(taskSubmitter->SendInvokeReq(resource, spec)); +} + +TEST_F(TaskSubmitterTest, UpdateFaasInvokeSendTimeTest) +{ + taskSubmitter->metricsEnable_ = true; + auto currentTime = GetCurrentTimestampNs(); + auto data = std::make_shared(); + data->sendTime = currentTime; + taskSubmitter->faasInvokeDataMap_["reqId"] = data; + taskSubmitter->UpdateFaasInvokeSendTime("reqId"); + ASSERT_TRUE(taskSubmitter->faasInvokeDataMap_["reqId"]->sendTime > currentTime); } } // namespace test } // namespace YR diff --git a/test/libruntime/trace_adapter_test.cpp b/test/libruntime/trace_adapter_test.cpp new file mode 100644 index 0000000..41852b6 --- /dev/null +++ b/test/libruntime/trace_adapter_test.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "src/utility/logger/fileutils.h" +#include "src/utility/logger/logger.h" +#define private public + +#include "src/libruntime/traceadaptor/trace_adapter.h" + +#include "src/libruntime/traceadaptor/exporter/log_file_exporter_factory.h" + +#include + +using namespace testing; +using namespace YR::utility; +using namespace YR::Libruntime; +namespace trace_api = opentelemetry::trace; +namespace trace_sdk = opentelemetry::sdk::trace; +namespace nostd = opentelemetry::nostd; +namespace common_sdk = opentelemetry::sdk::common; + +namespace YR { +namespace test { + +class TraceAdapterTest : public ::testing::Test { +protected: + static void SetUpTestCase(){} + + static void TearDownTestCase(){} + + void SetUp(){} + + void TearDown(){} + +}; + +TEST_F(TraceAdapterTest, InitTrace) +{ + const std::string configStr = "{\"otlpGrpcExporter\":{\"enable\":true,\"endpoint\":\"127.0.0.1:4317\"},\"logFileExporter\":{\"enable\":true}}"; + const std::string traceServiceName = "testService"; + // enable: false + TraceAdapter::GetInstance().InitTrace(traceServiceName, false, configStr); + ASSERT_FALSE(TraceAdapter::GetInstance().enableTrace_); + // empty config + TraceAdapter::GetInstance().InitTrace(traceServiceName, true, ""); + ASSERT_FALSE(TraceAdapter::GetInstance().enableTrace_); + // invalid json string + const std::string invalidConfigStr1 = "\"otlpGrpcExporter\":{\"enable\":false,\"endpoint\":\"\"},\"logFileExporter\":{\"enable\":false}}"; + TraceAdapter::GetInstance().InitTrace(traceServiceName, true, invalidConfigStr1); + ASSERT_FALSE(TraceAdapter::GetInstance().enableTrace_); + // invalid exporter config + const std::string invalidConfigStr2 = "{\"otlpGrpcExporter\":{\"enable\":true,\"endpoint\":\"\"},\"logFileExporter\":{\"enable\":false}}"; + TraceAdapter::GetInstance().InitTrace(traceServiceName, true, invalidConfigStr2); + ASSERT_FALSE(TraceAdapter::GetInstance().enableTrace_); + // valid exporter config + TraceAdapter::GetInstance().InitTrace(traceServiceName, true, configStr); + ASSERT_TRUE(TraceAdapter::GetInstance().enableTrace_); + + // set attribute + TraceAdapter::GetInstance().SetAttr("component", "proxy"); + ASSERT_TRUE(TraceAdapter::GetInstance().attribute_.find("component") != TraceAdapter::GetInstance().attribute_.end()); + ASSERT_EQ(TraceAdapter::GetInstance().attribute_.find("component")->second, "proxy"); +} + +TEST_F(TraceAdapterTest, StartSpan) +{ + const std::string configStr = "{\"otlpGrpcExporter\":{\"enable\":true,\"endpoint\":\"127.0.0.1:4317\"}}"; + const std::string traceServiceName = "testService"; + TraceAdapter::GetInstance().InitTrace(traceServiceName, false, configStr); + EXPECT_FALSE(TraceAdapter::GetInstance().enableTrace_); + auto disableSpan = TraceAdapter::GetInstance().StartSpan("span"); + EXPECT_FALSE(disableSpan->GetContext().trace_id().IsValid()); + + TraceAdapter::GetInstance().InitTrace(traceServiceName, true, configStr); + EXPECT_TRUE(TraceAdapter::GetInstance().enableTrace_); + + auto span1 = TraceAdapter::GetInstance().StartSpan("span1"); + EXPECT_TRUE(span1->GetContext().trace_id().IsValid()); + + auto span2 = TraceAdapter::GetInstance().StartSpan("span2",{{"attr1",123},{"attr2", "value2"}}); + EXPECT_TRUE(span2->GetContext().trace_id().IsValid()); + +} + +TEST_F(TraceAdapterTest, TestLogFileExporter) +{ + auto logFileExporter = std::move(LogFileExporterFactory::Create()); + auto record = logFileExporter->MakeRecordable(); + static_cast(record.get())->SetAttribute("requestID", "abc"); + static_cast(record.get())->SetName("Create"); + ASSERT_EQ(logFileExporter->Export(nostd::span>(&record, 1)), common_sdk::ExportResult::kSuccess); + + EXPECT_TRUE(logFileExporter->Shutdown()); + ASSERT_EQ(logFileExporter->Export(nostd::span>(&record, 1)), common_sdk::ExportResult::kFailure); +} + +}} \ No newline at end of file diff --git a/test/libruntime/utils_test.cpp b/test/libruntime/utils_test.cpp index 788eff9..f1d3bc6 100644 --- a/test/libruntime/utils_test.cpp +++ b/test/libruntime/utils_test.cpp @@ -140,14 +140,50 @@ TEST_F(UtilsTest, InitWithDriverTest) librtConfig->verifyFilePath = "test"; librtConfig->certificateFilePath = "test"; librtConfig->privateKeyPath = "test"; + std::strcpy(librtConfig->privateKeyPaaswd, "paaswd"); librtConfig->serverName = "test"; librtConfig->encryptEnable = "test"; librtConfig->encryptEnable = "test"; librtConfig->runtimePublicKey = "test"; librtConfig->runtimePrivateKey = "test"; + librtConfig->ak_ = "ak"; + librtConfig->sk_ = "sk"; auto err = security->InitWithDriver(librtConfig); ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); } +TEST_F(UtilsTest, GetAuthConnectOptsTest) +{ + ConnectOptions connnections; + std::string ak; + datasystem::SensitiveValue sk; + datasystem::SensitiveValue token = "token"; + GetAuthConnectOpts(connnections, ak, sk, token); + ASSERT_EQ(connnections.accessKey, ak); +} + +TEST_F(UtilsTest, unhexlifyTest) +{ + char ascii[2]; + auto res = unhexlify("00", ascii); + ASSERT_EQ(res, 0); + ASSERT_EQ(ascii[0], '\0'); + res = unhexlify("FF", ascii); + ASSERT_EQ(res, 0); + ASSERT_EQ(ascii[0], '\xFF'); + res = unhexlify("1a", ascii); + ASSERT_EQ(res, 0); + ASSERT_EQ(ascii[0], '\x1a'); + res = unhexlify("AB", ascii); + ASSERT_EQ(res, 0); + ASSERT_EQ(ascii[0], '\xAB'); + res = unhexlify("G1", ascii); + ASSERT_EQ(res, -1); + res = unhexlify("1G", ascii); + ASSERT_EQ(res, -1); + res = unhexlify("1", ascii); + ASSERT_EQ(res, -1); +} + } // namespace test } // namespace YR \ No newline at end of file diff --git a/test/scene/downgrade_test.cpp b/test/scene/downgrade_test.cpp new file mode 100644 index 0000000..69fce83 --- /dev/null +++ b/test/scene/downgrade_test.cpp @@ -0,0 +1,120 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#define private public +#include "src/scene/downgrade.h" +namespace YR { +namespace test { +using namespace YR::utility; +using namespace YR::scene; +using namespace YR::Libruntime; +using namespace testing; +class MockClientManager : public ClientManager { +public: + explicit MockClientManager(const std::shared_ptr &librtCfg) : ClientManager(librtCfg) {} + MOCK_METHOD1(Init, ErrorInfo(const ConnectionParam ¶m)); + MOCK_METHOD6(SubmitInvokeRequest, + void(const http::verb &method, const std::string &target, + const std::unordered_map &headers, const std::string &body, + const std::shared_ptr requestId, const HttpCallbackFunction &receiver)); +}; + +class DowngradeControllerTest : public ::testing::Test { +public: + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + std::string address = "127.0.0.1:11111"; + std::string urlPrefix = "/serverless/v2/functions/wisefunction:cn"; + clientsMgr = std::make_shared(); + auto config = std::make_shared(); + clientMgr = std::make_shared(config); + clientsMgr->httpClients[address] = clientMgr; + + security = std::make_shared(); + std::string accessKey = "access_key"; + SensitiveValue secretKey = std::string("secret_key"); + security->SetAKSKAndCredential(accessKey, secretKey); + setenv(FRONTEND_ADDRESS_ENV.c_str(), address.c_str(), 1); + setenv(INVOCATION_URL_PREFIX_ENV.c_str(), urlPrefix.c_str(), 1); + spec = std::make_shared(); + spec->traceId = "traceId"; + spec->requestId = "requestId"; + spec->functionMeta.functionId = "tenantId/hello/version"; + std::string arg = "{\"key\":\"invoke\"}"; + InvokeArg libArg; + libArg.dataObj = std::make_shared(0, arg.size()); + libArg.dataObj->data->MemoryCopy(arg.data(), arg.size()); + spec->invokeArgs.emplace_back(std::move(libArg)); + } + + void TearDown() override {} + std::shared_ptr clientMgr; + std::shared_ptr clientsMgr; + std::shared_ptr security; + std::shared_ptr spec; +}; + +TEST_F(DowngradeControllerTest, ApiClient) +{ + EXPECT_CALL(*this->clientMgr, SubmitInvokeRequest(_, _, _, _, _, _)).WillOnce(testing::Return()); + std::string functionId = "tenantId/hello/version"; + auto apiClient = std::make_shared(functionId, this->clientsMgr, this->security); + ASSERT_TRUE(apiClient->Init().OK()); + ASSERT_NO_THROW(apiClient->InvocationAsync(this->spec, nullptr)); +} + +TEST_F(DowngradeControllerTest, DowngradeControllerTest) +{ + std::string functionId = "tenantId/hello/version"; + EXPECT_CALL(*this->clientMgr, SubmitInvokeRequest(_, _, _, _, _, _)).WillOnce(testing::Return()); + auto controller = std::make_shared(functionId, this->clientsMgr, this->security); + ASSERT_TRUE(controller->Init().OK()); + ASSERT_FALSE(controller->ShouldDowngrade(this->spec)); + controller->ParseDowngrade("/tmp/log/a.json"); // test not exit + controller->Downgrade(this->spec, nullptr); + controller->Stop(); +} + +TEST_F(DowngradeControllerTest, DowngradeControllerInitTest) +{ + setenv(DOWNGRADE_FILE_ENV.c_str(), "/tmp/log/downgrade.json", 1); + std::string functionId = "tenantId/hello/version"; + auto controller = std::make_shared(functionId, this->clientsMgr, this->security); + ASSERT_TRUE(controller->Init().OK()); + unsetenv(DOWNGRADE_FILE_ENV.c_str()); + controller->Stop(); +} +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/st/cpp/src/base/actor_test.cpp b/test/st/cpp/src/base/actor_test.cpp index 0ff7f7e..fab9a88 100644 --- a/test/st/cpp/src/base/actor_test.cpp +++ b/test/st/cpp/src/base/actor_test.cpp @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include #include "base/utils.h" @@ -347,7 +359,8 @@ TEST_F(ActorTest, InvalidResource) } catch (YR::Exception &e) { printf("Exception:%s,\n", e.what()); std::string errorCode = "ErrCode: 1006"; - std::string errorMsg = "Required CPU resource size 299 millicores is invalid. Valid value range is [300,16000] millicores"; + std::string errorMsg = + "Required CPU resource size 299 millicores is invalid. Valid value range is [0,16000] millicores"; std::string excepMsg = e.what(); ErrorMsgCheck(errorCode, errorMsg, excepMsg); } @@ -361,7 +374,8 @@ TEST_F(ActorTest, InvalidResource) } catch (YR::Exception &e) { printf("Exception:%s,\n", e.what()); std::string errorCode = "ErrCode: 1006"; - std::string errorMsg = "Required CPU resource size 16001 millicores is invalid. Valid value range is [300,16000] millicores"; + std::string errorMsg = + "Required CPU resource size 16001 millicores is invalid. Valid value range is [0,16000] millicores"; std::string excepMsg = e.what(); ErrorMsgCheck(errorCode, errorMsg, excepMsg); } diff --git a/test/st/cpp/src/base/always_local_mode.cpp b/test/st/cpp/src/base/always_local_mode.cpp index 95b1b58..ccebfe4 100644 --- a/test/st/cpp/src/base/always_local_mode.cpp +++ b/test/st/cpp/src/base/always_local_mode.cpp @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include #include "base/utils.h" diff --git a/test/st/cpp/src/base/ds_test.cpp b/test/st/cpp/src/base/ds_test.cpp index cfa8a87..b09b385 100644 --- a/test/st/cpp/src/base/ds_test.cpp +++ b/test/st/cpp/src/base/ds_test.cpp @@ -1,6 +1,18 @@ /* -* Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. -*/ + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include diff --git a/test/st/cpp/src/base/init_test.cpp b/test/st/cpp/src/base/init_test.cpp index 2443028..f31a511 100644 --- a/test/st/cpp/src/base/init_test.cpp +++ b/test/st/cpp/src/base/init_test.cpp @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "gtest/gtest.h" diff --git a/test/st/cpp/src/base/task_test.cpp b/test/st/cpp/src/base/task_test.cpp index 7f53332..89247e8 100644 --- a/test/st/cpp/src/base/task_test.cpp +++ b/test/st/cpp/src/base/task_test.cpp @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include #include diff --git a/test/st/cpp/src/base/utils.h b/test/st/cpp/src/base/utils.h index 4b6eac0..34dceb2 100644 --- a/test/st/cpp/src/base/utils.h +++ b/test/st/cpp/src/base/utils.h @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #ifndef TEST_ST_CASES_SRC_BASE_UTILS_H diff --git a/test/st/cpp/src/main.cpp b/test/st/cpp/src/main.cpp index 3f3e72b..0c8edc8 100644 --- a/test/st/cpp/src/main.cpp +++ b/test/st/cpp/src/main.cpp @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "gtest/gtest.h" diff --git a/test/st/cpp/src/user_common_func.cpp b/test/st/cpp/src/user_common_func.cpp index 2d534b2..209526a 100644 --- a/test/st/cpp/src/user_common_func.cpp +++ b/test/st/cpp/src/user_common_func.cpp @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "user_common_func.h" diff --git a/test/st/cpp/src/user_common_func.h b/test/st/cpp/src/user_common_func.h index 12140b4..1cb04d1 100644 --- a/test/st/cpp/src/user_common_func.h +++ b/test/st/cpp/src/user_common_func.h @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/test/st/cpp/src/utils.cpp b/test/st/cpp/src/utils.cpp index efc2492..9b71860 100644 --- a/test/st/cpp/src/utils.cpp +++ b/test/st/cpp/src/utils.cpp @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "utils.h" diff --git a/test/st/cpp/src/utils.h b/test/st/cpp/src/utils.h index dec2dd4..e1ff152 100644 --- a/test/st/cpp/src/utils.h +++ b/test/st/cpp/src/utils.h @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/test/st/others/rpc_retry_test/src/main.cpp b/test/st/others/rpc_retry_test/src/main.cpp index 5f0069a..12ecd83 100644 --- a/test/st/others/rpc_retry_test/src/main.cpp +++ b/test/st/others/rpc_retry_test/src/main.cpp @@ -1,5 +1,17 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "yr/yr.h" diff --git a/test/st/python/test_yr_api.py b/test/st/python/test_yr_api.py index 8d70222..ac2402e 100644 --- a/test/st/python/test_yr_api.py +++ b/test/st/python/test_yr_api.py @@ -810,7 +810,8 @@ def test_get_async_instance_in_actor_proxy(init_yr_config): assert ret1 == ret2 -@pytest.mark.smoke +@pytest.mark.skip( + reason="This use case will now fail; let's temporarily disable it and resolve it later through an issue ID64WF.") def test_get_order_instance_after_teminate(init_yr_config): conf = init_yr_config yr.init(conf) diff --git a/test/test_goruntime_start.sh b/test/test_goruntime_start.sh new file mode 100644 index 0000000..c4ea577 --- /dev/null +++ b/test/test_goruntime_start.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -e +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + +BASE_DIR=$( + cd "$(dirname "$0")" + pwd +) + +SO_NAME="faasmanager" + +[ -n "$1" ] && SO_NAME=$1 + +export DATASYSTEM_ADDR=127.0.0.1:31501 +export POSIX_LISTEN_ADDR=127.0.0.1:55555 +export YR_FUNCTION_LIB_PATH=$(readlink -f $BASE_DIR/../../functionsystem/build/_output/$SO_NAME) +export INIT_HANDLER=$SO_NAME.InitHandler +export CALL_HANDLER=$SO_NAME.CallHandler +export CHECKPOINT_HANDLER=$SO_NAME.CheckpointHandler +export RECOVER_HANDLER=$SO_NAME.RecoverHandler +export SHUTDOWN_HANDLER=$SO_NAME.ShutdownHandler +export SIGNAL_HANDLER=$SO_NAME.SignalHandler + +export LD_LIBRARY_PATH=$BASE_DIR/../build/output/runtime/service/go/bin/ +$BASE_DIR/../build/output/runtime/service/go/bin/goruntime \ + -runtimeId=12345678 \ + -instanceId=23456789 \ + -logLevel=DEBUG \ + -grpcAddress=127.0.0.1:55555 + diff --git a/test/utility/file_watcher_test.cpp b/test/utility/file_watcher_test.cpp new file mode 100644 index 0000000..47675f2 --- /dev/null +++ b/test/utility/file_watcher_test.cpp @@ -0,0 +1,121 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/utility/file_watcher.h" +#include +#include +#include +#include "json.hpp" +namespace YR { +namespace test { +using namespace YR::utility; +class FileWatcherTest : public ::testing::Test { + void SetUp() override + { + Mkdir("/tmp/log"); + LogParam g_logParam = { + .logLevel = "DEBUG", + .logDir = "/tmp/log", + .nodeName = "test-runtime", + .modelName = "test", + .maxSize = 100, + .maxFiles = 1, + .logFileWithTime = false, + .logBufSecs = 30, + .maxAsyncQueueSize = 1048510, + .asyncThreadCount = 1, + .alsoLog2Stderr = true, + }; + InitLog(g_logParam); + } + + void TearDown() override {} +}; + +bool CreateFile(const std::string &path) +{ + std::ofstream ofs(path, std::ios::out); + if (!ofs.is_open()) { + std::cerr << "Failed to create file: " << path << std::endl; + return false; + } + ofs.close(); + return true; +} + +bool DeleteFile(const std::string &path) +{ + std::error_code ec; // 避免抛异常 + bool removed = std::filesystem::remove(path, ec); + if (!removed || ec) { + std::cerr << "Failed to delete file: " << path << " , error: " << ec.message() << std::endl; + return false; + } + return true; +} + +bool WriteStringToFile(const std::string &path, const std::string &content) +{ + std::ofstream ofs(path, std::ios::out); + if (!ofs.is_open()) { + std::cerr << "Failed to open file: " << path << std::endl; + return false; + } + ofs << content; + ofs.close(); + return true; +} + +TEST_F(FileWatcherTest, FileWatcher) +{ + bool isDowngradeEnabled = false; + std::string fileName = "/tmp/log/a.json"; + FileWatcher f(fileName, [&isDowngradeEnabled](const std::string &fileName) { + std::ifstream file(fileName); + nlohmann::json j; + try { + file >> j; + isDowngradeEnabled = j.value("downgrade", false); + } catch (const nlohmann::json::parse_error &e) { + std::stringstream buffer; + buffer << file.rdbuf(); + YRLOG_WARN("{} json parse error: {}", buffer.str(), e.what()); + isDowngradeEnabled = false; + } + }); + f.Start(); + CreateFile(fileName); + WriteStringToFile(fileName, "{\"downgrade\": true}"); + while (true) { + if (isDowngradeEnabled) { + break; + } + std::this_thread::yield(); + } + ASSERT_TRUE(isDowngradeEnabled); + WriteStringToFile(fileName, "{\"downgrade\": false}"); + while (true) { + if (!isDowngradeEnabled) { + break; + } + std::this_thread::yield(); + } + DeleteFile(fileName); + ASSERT_FALSE(isDowngradeEnabled); + f.Stop(); +} +} // namespace test +} // namespace YR \ No newline at end of file diff --git a/test/utility/logger/logger_test.cpp b/test/utility/logger/logger_test.cpp index d67f802..81b6ad7 100644 --- a/test/utility/logger/logger_test.cpp +++ b/test/utility/logger/logger_test.cpp @@ -108,15 +108,15 @@ TEST_F(Logger, GetLogLevel) { struct testArgs { std::string lvStr; - spdlog::level::level_enum spdLv; + yr_spdlog::level::level_enum spdLv; }; std::vector ta = { - {"DEBUG", spdlog::level::debug}, {"INFO", spdlog::level::info}, {"WARN", spdlog::level::warn}, - {"ERR", spdlog::level::err}, {"FATAL", spdlog::level::critical}, + {"DEBUG", yr_spdlog::level::debug}, {"INFO", yr_spdlog::level::info}, {"WARN", yr_spdlog::level::warn}, + {"ERR", yr_spdlog::level::err}, {"FATAL", yr_spdlog::level::critical}, }; for (auto &a : ta) { - spdlog::level::level_enum level = GetLogLevel(a.lvStr); + yr_spdlog::level::level_enum level = GetLogLevel(a.lvStr); EXPECT_EQ(level, a.spdLv); } } diff --git a/tools/download_dependency.sh b/tools/download_dependency.sh index 750865b..a42ffb6 100644 --- a/tools/download_dependency.sh +++ b/tools/download_dependency.sh @@ -127,6 +127,9 @@ function compile_all(){ chmod -R 700 "${THIRD_PARTY_DIR}/openssl/" ./config enable-ssl3 enable-ssl3-method --prefix="${THIRD_PARTY_DIR}/openssl/install" make -j build_libs install_dev + if [[ -d ${THIRD_PARTY_DIR}/openssl/install/lib64 && ! -d ${THIRD_PARTY_DIR}/openssl/install/lib ]];then + cp -fr ${THIRD_PARTY_DIR}/openssl/install/lib64 ${THIRD_PARTY_DIR}/openssl/install/lib + fi popd fi } diff --git a/tools/openSource.txt b/tools/openSource.txt index 79deccd..41c17b8 100644 --- a/tools/openSource.txt +++ b/tools/openSource.txt @@ -1,5 +1,5 @@ abseil-cpp,20240722.0,functionsystem,https://gitee.com/mirrors/abseil-cpp/repository/archive/20240722.0.zip,104dead3edd7b67ddeb70c37578245130d6118efad5dad4b618d7e26a5331f55, -boost,1.82.0,runtime,https://sourceforge.net/projects/boost/files/boost/1.82.0/boost_1_82_0.tar.gz,66a469b6e608a51f8347236f4912e27dc5c60c60d7d53ae9bfe4683316c6f04c, +boost,1.87.0,runtime,https://sourceforge.net/projects/boost/files/boost/1.87.0/boost_1_87_0.tar.gz,f55c340aa49763b1925ccf02b2e83f35fdcf634c9d5164a2acb87540173c741d, c-ares,1.19.1,functionsystem,https://gitee.com/mirrors/c-ares/repository/archive/cares-1_19_1.zip,edcaac184aff0e6b6eb7b9ede7a55f36c7fc04085d67fecff2434779155dd8ce, etcd,v3.5.11,functionsystem,https://gitee.com/mirrors/etcd/repository/archive/v3.5.11.zip,7af745a2bf75c7b04f9b168b85deb35fb73d30d60f751046564b54dbd483277a, gogo-protobuf,v1.3.2,functionsystem,https://gitee.com/mirrors_gogo/protobuf/repository/archive/v1.3.2.zip,11e863712289bf1e6ba86239b5b896704bb63815bee8710e5032daf63fd9efe7, diff --git a/yuanrong/build/build.sh b/yuanrong/build/build.sh new file mode 100644 index 0000000..6bf2269 --- /dev/null +++ b/yuanrong/build/build.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +BASE_DIR=$(cd "$(dirname "$0")"; pwd) +PROJECT_DIR=$(cd "$(dirname "$0")"/..; pwd) +OUTPUT_DIR="${PROJECT_DIR}/output" +RUNTIME_OUTPUT_DIR="${PROJECT_DIR}/../output" +POSIX_DIR="${PROJECT_DIR}/proto/posix" +BUILD_TAGS="" +FLAGS='-extldflags "-fPIC -fstack-protector-strong -Wl,-z,now,-z,relro,-z,noexecstack,-s -Wall -Werror"' + +# go module prepare +export GO111MODULE=on +export GONOSUMDB=* +export CGO_ENABLED=1 + +# resolve missing go.sum entry +go env -w "GOFLAGS"="-mod=mod" + +echo "generating fs proto pb objects" +protoc --proto_path=${POSIX_DIR} --go_out=${PROJECT_DIR}/../ --go-grpc_out=${PROJECT_DIR}/../ ${POSIX_DIR}/*.proto + +echo "start to compile dashboard -s ${SCC_BUILD_ENABLED}" +mkdir -p "${OUTPUT_DIR}/bin/" +rm -rf "${OUTPUT_DIR}/bin/dashboard" + +CC='gcc -fstack-protector-strong -D_FORTIFY_SOURCE=2 -O2' go build -tags="${BUILD_TAGS}" -buildmode=pie -ldflags "${FLAGS}" -o \ +"${OUTPUT_DIR}"/bin/dashboard "${PROJECT_DIR}"/cmd/dashboard/main.go + +mkdir -p "${OUTPUT_DIR}/config/" +rm -rf "${OUTPUT_DIR}/config/dashboard*" +cp -ar "${BASE_DIR}/dashboard/config/" "${OUTPUT_DIR}/" + +npm config set strict-ssl false +echo "start to compile dashboard client" +mkdir -p "${OUTPUT_DIR}/bin/" +rm -rf "${OUTPUT_DIR}/bin/client" +cd "${PROJECT_DIR}/pkg/dashboard/client" +npm install || die "dashboard client install failed" +npm run build || die "dashboard client build failed" +mkdir -p "${OUTPUT_DIR}/bin/client" +cp -ar ./dist "${OUTPUT_DIR}/bin/client/" +cd "${BASE_DIR}" + +echo "start to compile collector" +mkdir -p "${OUTPUT_DIR}/bin/" +rm -rf "${OUTPUT_DIR}/bin/collector" +echo LD_LIBRARY_PATH=$LD_LIBRARY_PATH +CC='gcc -fstack-protector-strong -D_FORTIFY_SOURCE=2 -O2' go build -tags="${BUILD_TAGS}" -buildmode=pie -ldflags "${FLAGS}" -o \ +"${OUTPUT_DIR}"/bin/collector "${PROJECT_DIR}"/cmd/collector/main.go + +cd "${OUTPUT_DIR}" +tar -czvf yr-dashboard-v0.0.1.tar.gz ./* +mkdir -p "${RUNTIME_OUTPUT_DIR}" +rm -rf "${RUNTIME_OUTPUT_DIR}/yr-dashboard-v0.0.1.tar.gz" +cp yr-dashboard-v0.0.1.tar.gz "${RUNTIME_OUTPUT_DIR}" +cd "${BASE_DIR}" \ No newline at end of file diff --git a/yuanrong/build/build_function.sh b/yuanrong/build/build_function.sh new file mode 100644 index 0000000..0238937 --- /dev/null +++ b/yuanrong/build/build_function.sh @@ -0,0 +1,124 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +# ---------------------------------------------------------------------- +# funcname: log_info. +# description: Print build info log. +# parameters: NA +# return value: NA +# ---------------------------------------------------------------------- +log_info() +{ + echo "[BUILD_INFO][$(date -u +%b\ %d\ %H:%M:%S)]$*" +} + +# ---------------------------------------------------------------------- +# funcname: log_warning. +# description: Print build warning log. +# parameters: NA +# return value: NA +# ---------------------------------------------------------------------- +log_warning() +{ + echo "[BUILD_WARNING][$(date -u +%b\ %d\ %H:%M:%S)]$*" +} + +# ---------------------------------------------------------------------- +# funcname: log_error. +# description: Print build error log. +# parameters: NA +# return value: NA +# ---------------------------------------------------------------------- +log_error() +{ + echo "[BUILD_ERROR][$(date -u +%b\ %d\ %H:%M:%S)]$*" +} + +# ---------------------------------------------------------------------- +# funcname: die. +# description: Print build error log. +# parameters: NA +# return value: NA +# ---------------------------------------------------------------------- +die() +{ + log_error "$*" + stty echo + exit 1 +} + +# ---------------------------------------------------------------------- +# funcname: usage +# description: the build Instructions +# parameters: void +# return value: void +# ---------------------------------------------------------------------- +usage() +{ + echo -e "Usage: ./build.sh [-m subsystem_name] [-h help]" + echo -e "Usage: ./docker_build.sh [-m subsystem_name -v version] [-h help]" + echo -e "Options:" + echo -e " -m subsystem_name, such as nodemanager workermanager crontrigger frontend worker" + echo -e " functionrepo adminservice cli" + echo -e " functionstate functiontask" + echo -e " -m functioncore build functioncore containing all functioncore modules" + echo -e " -m all contains runtime and worker-flow in addition to functioncore" + echo -e "Notes: If the parameter of -m is all, you must download the runtime project first," + echo -e " and keep the the projects' path at the same level as the functioncore." + echo -e " -h usage help" + echo -e " -l compile with local runtime code" + echo -e " " + echo -e "Example:" + echo -e " sh build.sh -m \"worker\"" + echo -e " sh build.sh -m \"functioncore\"" + echo -e " sh build.sh -m \"all\"" + echo -e " sh build.sh -l -m \"runtimemanager\"" + echo -e " sh docker_build.sh -v \"version\" -m \"worker\"" + echo -e " sh docker_build.sh -v \"version\" -m \"functioncore\"" + echo -e " sh docker_build.sh -v \"version\" -m \"all\"" + echo -e " sh docker_build.sh -v \"version\" -m \"runtimemanager\"" + echo -e "" +} + +# ---------------------------------------------------------------------- +# funcname: get_base_image. +# description: check the system +# parameters: base images path +# return value: NA +# ---------------------------------------------------------------------- +function get_base_image() { + local local_os=$(head -1 /etc/os-release | tail -1 | awk -F "\"" '{print $2}')_$(uname -m) + log_info "The operating system is ${local_os}." + + case "${local_os}" in + Ubuntu_x86_64) + BASE_IMAGE="ubuntu:18.04" + ;; + EulerOS_x86_64) + BASE_IMAGE="euleros:v2r9" + ;; + EulerOS_aarch64) + BASE_IMAGE="euleros:v2r8" + ;; + openEuler_x86_64) + BASE_IMAGE="openeuler:20.03" + ;; + "CentOS Linux_x86_64") + BASE_IMAGE="centos:7.9" + ;; + esac +} diff --git a/yuanrong/build/compile_functions.sh b/yuanrong/build/compile_functions.sh new file mode 100644 index 0000000..ae69258 --- /dev/null +++ b/yuanrong/build/compile_functions.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +BASE_DIR=$(cd "$(dirname "$0")"; pwd) +PROJECT_DIR=$(cd "$(dirname "$0")"/..; pwd) +POSIX_DIR="${PROJECT_DIR}/proto/posix" + +function generate_pb() { + # generate pb files + if [ -z "${GOPATH}" ] || [ ! -d "${GOPATH}" ]; then + log_error "GOPATH ${GOPATH} not exist!" + return 1 + fi + cd "${PROJECT_DIR}"/pkg + [ -d "${GOPATH}/src/dashboard" ] && rm -rf "${GOPATH}/src/dashboard" + mkdir -p "${GOPATH}"/src/ + ln -s "${PROJECT_DIR}" "${GOPATH}"/src/dashboard + protoc --proto_path=${POSIX_DIR} --go_out=${PROJECT_DIR}/../ --go-grpc_out=${PROJECT_DIR}/../ ${POSIX_DIR}/*.proto +} + +go mod tidy +generate_pb \ No newline at end of file diff --git a/yuanrong/build/dashboard/config/dashboard_config.json b/yuanrong/build/dashboard/config/dashboard_config.json new file mode 100644 index 0000000..7a2b098 --- /dev/null +++ b/yuanrong/build/dashboard/config/dashboard_config.json @@ -0,0 +1,22 @@ +{ + "ip": "{ip}", + "port": {port}, + "grpcIP": "{grpcIP}", + "grpcPort": {grpcPort}, + "staticPath": "{staticPath}", + "functionMasterAddr": "{functionMasterAddr}", + "frontendAddr": "{frontendAddr}", + "prometheusAddr": "{prometheusAddr}", + "routerEtcdConfig": { + "servers": ["{etcdAddr}"], + "sslEnable": {etcdSsl}, + "authType": "{etcdAuthType}", + "azPrefix": "{azPrefix}" + }, + "metaEtcdConfig": { + "servers": ["{etcdAddr}"], + "sslEnable": {etcdSsl}, + "authType": "{etcdAuthType}", + "azPrefix": "{azPrefix}" + } +} \ No newline at end of file diff --git a/yuanrong/build/dashboard/config/dashboard_log.json b/yuanrong/build/dashboard/config/dashboard_log.json new file mode 100644 index 0000000..abdece3 --- /dev/null +++ b/yuanrong/build/dashboard/config/dashboard_log.json @@ -0,0 +1,13 @@ +{ + "filepath": "{logConfigPath}", + "level": "{logLevel}", + "rolling": { + "maxsize": 400, + "maxbackups": 1, + "maxage": 1, + "compress": true + }, + "tick": 10, + "first": 10, + "thereafter": 5 +} \ No newline at end of file diff --git a/yuanrong/cmd/collector/main.go b/yuanrong/cmd/collector/main.go new file mode 100644 index 0000000..7af9f42 --- /dev/null +++ b/yuanrong/cmd/collector/main.go @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package main +package main + +import ( + "yuanrong/cmd/collector/process" +) + +func main() { + process.StartCollector() +} diff --git a/yuanrong/cmd/collector/process/process.go b/yuanrong/cmd/collector/process/process.go new file mode 100644 index 0000000..c54beed --- /dev/null +++ b/yuanrong/cmd/collector/process/process.go @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package process to run collector +package process + +import ( + "yuanrong/pkg/collector/common" + "yuanrong/pkg/collector/logcollector" + "yuanrong/pkg/common/faas_common/logger/log" +) + +// StartCollector - +func StartCollector() { + common.InitCmd() + if err := common.InitEtcdClient(); err != nil { + log.GetLogger().Errorf("failed to init etcd client %s", err.Error()) + return + } + ready := make(chan bool) + go func() { + <-ready + err := logcollector.Register() + if err != nil { + log.GetLogger().Errorf("failed to register %s", err.Error()) + return + } + logcollector.StartLogReporter() + }() + err := logcollector.StartReadLogService(ready) + if err != nil { + log.GetLogger().Errorf("failed to start log service %s", err.Error()) + return + } +} diff --git a/yuanrong/cmd/dashboard/main.go b/yuanrong/cmd/dashboard/main.go new file mode 100644 index 0000000..80aa323 --- /dev/null +++ b/yuanrong/cmd/dashboard/main.go @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "yuanrong/cmd/dashboard/process" +) + +// @title Dashboard API +// @version 1.0 +func main() { + process.StartDashboard() +} diff --git a/yuanrong/cmd/dashboard/process/process.go b/yuanrong/cmd/dashboard/process/process.go new file mode 100644 index 0000000..533fd7f --- /dev/null +++ b/yuanrong/cmd/dashboard/process/process.go @@ -0,0 +1,85 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package process for run dashboard server +package process + +import ( + "fmt" + "net" + "net/http" + + "github.com/gin-gonic/gin" + "google.golang.org/grpc" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/dashboard/etcdcache" + "yuanrong/pkg/dashboard/flags" + "yuanrong/pkg/dashboard/logmanager" + "yuanrong/pkg/dashboard/routers" +) + +// StartGrpcServices of dashboard +func StartGrpcServices() { + // start grpc service + lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", flags.DashboardConfig.Ip, flags.DashboardConfig.GrpcPort)) + if err != nil { + log.GetLogger().Fatalf("failed to listen: %v", err) + } + + grpcServer := grpc.NewServer() + logservice.RegisterLogManagerServiceServer(grpcServer, &logmanager.Server{}) + + if err := grpcServer.Serve(lis); err != nil { + log.GetLogger().Fatalf("failed to serve grpc: %s", err.Error()) + } +} + +// StartDashboard - function for run dashboard server +func StartDashboard() { + gin.SetMode(gin.ReleaseMode) + stopCh := make(chan struct{}) + + // init the etcd config first + err := flags.InitEtcdClient() + if err != nil { + log.GetLogger().Fatalf("failed to init etcd, err: %s", err) + } + + // register self + err = flags.RegisterSelfToEtcd(stopCh) + if err != nil { + log.GetLogger().Fatalf("failed to register self to etcd, err: %s", err) + } + + // start watcher + etcdcache.StartWatchInstance(stopCh) + + // start grpc at background + go StartGrpcServices() + + // start http, and use http as main thread + r := routers.SetRouter() + srv := &http.Server{ + Addr: flags.DashboardConfig.ServerAddr, + Handler: r, + } + log.GetLogger().Debugf("http://%s is running...", flags.DashboardConfig.ServerAddr) + if err := srv.ListenAndServe(); err != nil { + log.GetLogger().Fatalf("srv.ListenAndServe: %s", err.Error()) + } +} diff --git a/yuanrong/cmd/faas/faascontroller/main.go b/yuanrong/cmd/faas/faascontroller/main.go new file mode 100644 index 0000000..9e543b9 --- /dev/null +++ b/yuanrong/cmd/faas/faascontroller/main.go @@ -0,0 +1,163 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package main - +package main + +import ( + "errors" + "fmt" + "sync" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/faascontroller" + "yuanrong/pkg/system_function_controller/state" +) + +var ( + // faasController handles instance management for faasscheduler + faasController *faascontroller.FaaSController + stopCh = make(chan struct{}) + shutdownOnce sync.Once + + signalHandlerMap = map[int32]func([]byte) error{} + frontendUpdateSignal = int32(65) + schedulerUpdateSignal = int32(66) + managerUpdateSignal = int32(68) +) + +// InitHandlerLibruntime is the init handler called by runtime based on multi libruntime +func InitHandlerLibruntime(args []api.Arg, libruntimeAPI api.LibruntimeAPI) ([]byte, error) { + log.SetupLoggerLibruntime(libruntimeAPI.GetFormatLogger()) + var err error + defer func() { + if err != nil { + fmt.Printf("panic, module: faascontroller, err: %s\n", err.Error()) + log.GetLogger().Errorf("panic, module: faascontroller, err: %s", err.Error()) + } + log.GetLogger().Sync() + }() + log.GetLogger().Infof("trigger: faascontroller.InitHandler") + if err = checkArgsLibruntime(args); err != nil { + return []byte(""), err + } + if err = config.InitConfig(args[0].Data); err != nil { + log.GetLogger().Errorf("failed to init config, err:%s", err.Error()) + return []byte(""), err + } + if err = config.InitEtcd(stopCh); err != nil { + log.GetLogger().Errorf("failed to init etcd ,err:%s", err.Error()) + return []byte(""), err + } + state.InitState(config.GetFaaSControllerConfig().SchedulerExclusivity) + if err = setupFaaSControllerLibruntime(libruntimeAPI); err != nil { + return []byte(""), err + } + log.GetLogger().Infof("exit: faascontroller.InitHandler") + return []byte(""), nil +} + +func checkArgsLibruntime(args []api.Arg) error { + if len(args) == 0 { + log.GetLogger().Errorf("init args empty") + return errors.New("init args empty") + } + if args[0].Type != api.Value { + log.GetLogger().Errorf("arg type error") + return errors.New("arg type error") + } + return nil +} + +func setupFaaSControllerLibruntime(libruntimeAPI api.LibruntimeAPI) error { + var err error + // create Faas controller instance + faasController, err = faascontroller.NewFaaSControllerLibruntime(libruntimeAPI, stopCh) + return PrepareFaasController(err) +} + +// PrepareFaasController - +func PrepareFaasController(err error) error { + if err != nil { + log.GetLogger().Errorf("failed to create faas controller instance") + return err + } + signalHandlerMap[frontendUpdateSignal] = faasController.FrontendSignalHandler + signalHandlerMap[schedulerUpdateSignal] = faasController.SchedulerSignalHandler + signalHandlerMap[managerUpdateSignal] = faasController.ManagerSignalHandler + return nil +} + +// CallHandlerLibruntime is the call handler called by runtime +func CallHandlerLibruntime(args []api.Arg) ([]byte, error) { + return nil, nil +} + +// CheckpointHandlerLibruntime is the checkpoint handler called by runtime +func CheckpointHandlerLibruntime(checkpointID string) ([]byte, error) { + return state.GetStateByte() +} + +// RecoverHandlerLibruntime is the recover handler called by runtime based on multi libruntime +func RecoverHandlerLibruntime(stateData []byte, libruntimeAPI api.LibruntimeAPI) error { + var err error + log.SetupLoggerLibruntime(libruntimeAPI.GetFormatLogger()) + log.GetLogger().Infof("trigger: faascontroller.RecoverHandler") + if err = state.SetState(stateData); err != nil { + return fmt.Errorf("recover faaS controller error:%s", err.Error()) + } + state.InitState(config.GetFaaSControllerConfig().SchedulerExclusivity) + if err = config.RecoverConfig(); err != nil { + return fmt.Errorf("recover config error:%s", err.Error()) + } + if err = config.InitEtcd(stopCh); err != nil { + log.GetLogger().Errorf("failed to init etcd ,err:%s", err.Error()) + return err + } + if err = setupFaaSControllerLibruntime(libruntimeAPI); err != nil { + return fmt.Errorf("restart faaS controller error:%s", err.Error()) + } + state.Update(config.GetFaaSControllerConfig()) + return nil +} + +// ShutdownHandlerLibruntime is the shutdown handler called by runtime based on multi libruntime +func ShutdownHandlerLibruntime(gracePeriodSecond uint64) error { + log.GetLogger().Infof("trigger: faascontroller.ShutdownHandlerLibruntime") + utils.SafeCloseChannel(stopCh) + log.GetLogger().Infof("faascontrolerLibruntime exit") + log.GetLogger().Sync() + return nil +} + +// SignalHandlerLibruntime is the signal handler called by runtime +func SignalHandlerLibruntime(signal int, payload []byte) error { + log.GetLogger().Infof("trigger: faascontroller.SignalHandlerLibruntime signal:%d", signal) + handler, ok := signalHandlerMap[int32(signal)] + if !ok { + log.GetLogger().Errorf("signal: %d, not found handler", signal) + return fmt.Errorf("not found signal handler libruntime") + } + err := handler(payload) + if err != nil { + return err + } + return nil +} diff --git a/yuanrong/cmd/faas/faascontroller/main_test.go b/yuanrong/cmd/faas/faascontroller/main_test.go new file mode 100644 index 0000000..b1195d8 --- /dev/null +++ b/yuanrong/cmd/faas/faascontroller/main_test.go @@ -0,0 +1,323 @@ +//go:build function + +package main + +import ( + "encoding/json" + "errors" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/faascontroller" + "yuanrong/pkg/system_function_controller/state" +) + +func TestCallHandler(t *testing.T) { + type args struct { + args []*api.Arg + createOpt map[string]string + } + tests := []struct { + name string + args args + want interface{} + wantErr bool + }{ + { + name: "nil response", + args: args{ + args: nil, + createOpt: make(map[string]string), + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CallHandler(tt.args.args, tt.args.createOpt) + if (err != nil) != tt.wantErr { + t.Errorf("CallHandler() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CallHandler() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCheckpointHandler(t *testing.T) { + type args struct { + checkpointID string + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + name: "nil response", + args: args{checkpointID: "123456"}, + wantErr: false, + want: []byte(`{"FaaSControllerConfig":null,"FaasInstance":null}`), + }, + } + defer gomonkey.ApplyFunc(state.GetStateByte, func() ([]byte, error) { + return json.Marshal(&state.ControllerState{}) + }).Reset() + state.InitState() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CheckpointHandler(tt.args.checkpointID) + if (err != nil) != tt.wantErr { + t.Errorf("CheckpointHandler() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CheckpointHandler() got = %v, want %v", got, tt.want) + } + }) + } +} + +var configString = `{ + "frontendInstanceNum": 100, + "schedulerInstanceNum": 100, + "faasschedulerConfig": { + "cpu": 777, + "memory": 777, + "slaQuota": 1000, + "scaleDownTime": 60000, + "burstScaleNum": 1000, + "leaseSpan": 600000 + }, + "faasfrontendConfig": { + "cpu": 777, + "memory": 777, + "slaQuota": 1000, + "functionCapability": 1, + "authenticationEnable": false, + "trafficLimitDisable": true, + "http": { + "resptimeout": 5, + "workerInstanceReadTimeOut": 5, + "maxRequestBodySize": 6 + } + }, + "routerEtcd": { + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd": { + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + } + } + ` + +func TestInitHandler(t *testing.T) { + type args struct { + args []*api.Arg + fsClient api.FunctionSystemClient + dsClient api.DataSystemClient + formatLogger api.FormatLogger + } + tests := []struct { + name string + args args + want interface{} + wantErr bool + }{ + { + name: "init_args_error", + args: args{ + args: []*api.Arg{}, + fsClient: nil, + dsClient: nil, + formatLogger: nil, + }, + want: "", + wantErr: true, + }, + { + name: "init_args_type_error", + args: args{ + args: []*api.Arg{&api.Arg{ + ArgType: 1, + Data: nil, + }}, + fsClient: nil, + dsClient: nil, + }, + want: "", + wantErr: true, + }, + { + name: "init_config_error", + args: args{ + args: []*api.Arg{&api.Arg{ + ArgType: api.Value, + Data: nil, + }}, + fsClient: nil, + dsClient: nil, + }, + want: "", + wantErr: true, + }, + { + name: "init_etcd_error", + args: args{ + args: []*api.Arg{&api.Arg{ + ArgType: api.Value, + Data: []byte(configString), + }}, + fsClient: nil, + dsClient: nil, + }, + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(utils.DelayExit, func(module string, err error) { return }), + } + got, err := InitHandler(tt.args.args, tt.args.fsClient, tt.args.dsClient, tt.args.formatLogger) + if (err != nil) != tt.wantErr { + t.Errorf("InitHandler() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("InitHandler() got = %v, want %v", got, tt.want) + } + for _, patch := range patches { + patch.Reset() + } + }) + } + + convey.Convey("init success", t, func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(faascontroller.NewFaaSController, + func(sdkClient api.FunctionSystemClient, stopCh chan struct{}) (*faascontroller.FaaSController, error) { + return &faascontroller.FaaSController{}, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&faascontroller.FaaSController{}), "RegistryList", + func(_ *faascontroller.FaaSController, stopCh <-chan struct{}) error { + return nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&faascontroller.FaaSController{}), "CreateExpectedAllInstanceCount", + func(_ *faascontroller.FaaSController) error { + return nil + }), + gomonkey.ApplyFunc(state.InitState, func() { + return + }), + gomonkey.ApplyFunc(config.InitEtcd, func(stopCh <-chan struct{}) error { + return nil + }), + } + defer func() { + close(stopCh) + for _, patch := range patches { + patch.Reset() + } + }() + initArgs := []*api.Arg{ + { + ArgType: 0, + Data: []byte(configString), + }, + } + resp, err := InitHandler(initArgs, nil, nil, nil) + convey.So(resp, convey.ShouldEqual, "") + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestRecoverHandler(t *testing.T) { + convey.Convey("recover success", t, func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(faascontroller.NewFaaSController, + func(sdkClient api.FunctionSystemClient, stopCh chan struct{}) (*faascontroller.FaaSController, error) { + return &faascontroller.FaaSController{}, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&faascontroller.FaaSController{}), "RegistryList", + func(_ *faascontroller.FaaSController, stopCh <-chan struct{}) error { + return nil + }), + gomonkey.ApplyFunc(config.RecoverConfig, func() error { + return nil + }), + gomonkey.ApplyFunc(state.InitState, func() { + return + }), + gomonkey.ApplyFunc(config.InitEtcd, func(stopCh <-chan struct{}) error { + return nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + convey.Convey("success", func() { + patches = append(patches, gomonkey.ApplyMethod(reflect.TypeOf(&faascontroller.FaaSController{}), "CreateExpectedAllInstanceCount", + func(_ *faascontroller.FaaSController) error { + return nil + })) + err := RecoverHandler([]byte(configString), nil, nil, nil) + convey.So(err, convey.ShouldBeNil) + }) + + convey.Convey("failed", func() { + patches = append(patches, gomonkey.ApplyMethod(reflect.TypeOf(&faascontroller.FaaSController{}), "CreateExpectedAllInstanceCount", + func(_ *faascontroller.FaaSController) error { + return errors.New("fail to CreateExpectedAllInstanceCount") + })) + err := RecoverHandler([]byte(configString), nil, nil, nil) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestShutdownHandler(t *testing.T) { + stopCh = make(chan struct{}) + convey.Convey("ShutdownHandler", t, func() { + faasController = &faascontroller.FaaSController{} + convey.Convey("success", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(faasController), "KillAllInstances", + func(_ *faascontroller.FaaSController) { + + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + err := ShutdownHandler(30) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestSignalHandler(t *testing.T) { + convey.Convey("SignalHandler", t, func() { + err := SignalHandler(0, []byte{}) + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/cmd/faas/faasmanager/main.go b/yuanrong/cmd/faas/faasmanager/main.go new file mode 100644 index 0000000..2050a7d --- /dev/null +++ b/yuanrong/cmd/faas/faasmanager/main.go @@ -0,0 +1,139 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package main - +package main + +import ( + "encoding/json" + "errors" + "fmt" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionmanager" + "yuanrong/pkg/functionmanager/config" + "yuanrong/pkg/functionmanager/state" +) + +var ( + // faasManager handles functions management for faas pattern + faasManager *functionmanager.Manager + stopCh = make(chan struct{}) +) + +// InitHandlerLibruntime is the init handler called by runtime +func InitHandlerLibruntime(args []api.Arg, libruntimeAPI api.LibruntimeAPI) ([]byte, error) { + log.SetupLoggerLibruntime(libruntimeAPI.GetFormatLogger()) + log.GetLogger().Infof("trigger: faasmanager.InitHandler") + if len(args) == 0 || libruntimeAPI == nil { + return []byte(""), errors.New("init args empty") + } + if args[0].Type != api.Value { + return []byte(""), errors.New("arg type error") + } + err := config.InitConfig(args[0].Data) + if err != nil { + return []byte(""), err + } + if err = config.InitEtcd(stopCh); err != nil { + log.GetLogger().Errorf("failed to init etcd ,err:%s", err.Error()) + return []byte(""), err + } + state.InitState() + stateByte, err := state.GetStateByte() + if err == nil && len(stateByte) != 0 { + return []byte(""), RecoverHandlerLibruntime(stateByte, libruntimeAPI) + } + if _, err = setupFaaSManagerLibruntime(libruntimeAPI); err != nil { + return []byte(""), err + } + cfg := config.GetConfig() + state.Update(&cfg) + if faasManager != nil { + go faasManager.WatchLeaseEvent() + } + return []byte(""), nil +} + +// CallHandlerLibruntime is the call handler called by runtime +func CallHandlerLibruntime(args []api.Arg) ([]byte, error) { + traceID := string(args[len(args)-1].Data) + if faasManager == nil { + return nil, fmt.Errorf("faas manager is not initialized, traceID: %s", traceID) + } + response := faasManager.ProcessSchedulerRequestLibruntime(args, traceID) + if response == nil { + return nil, fmt.Errorf("failed to process scheduler request, traceID: %s", traceID) + } + rspData, err := json.Marshal(response) + if err != nil { + return nil, err + } + return rspData, nil +} + +// CheckpointHandlerLibruntime is the checkpoint handler called by libruntime +func CheckpointHandlerLibruntime(checkpointID string) ([]byte, error) { + return state.GetStateByte() +} + +// RecoverHandlerLibruntime is the recover handler called by runtime +func RecoverHandlerLibruntime(stateData []byte, libruntimeAPI api.LibruntimeAPI) error { + var err error + log.SetupLoggerLibruntime(libruntimeAPI.GetFormatLogger()) + log.GetLogger().Infof("trigger: faasmanager.RecoverHandler") + state.InitState() + if err = state.SetState(stateData); err != nil { + return fmt.Errorf("faaS manager recover error:%s", err.Error()) + } + if _, err = setupFaaSManagerLibruntime(libruntimeAPI); err != nil { + return fmt.Errorf("restart faaS controller libruntime error:%s", err.Error()) + } + if faasManager != nil { + faasManager.RecoverData() + } + cfg := config.GetConfig() + state.Update(&cfg) + if faasManager != nil { + go faasManager.WatchLeaseEvent() + } + return nil +} + +// ShutdownHandlerLibruntime is the shutdown handler called by libruntime +func ShutdownHandlerLibruntime(gracePeriodSecond uint64) error { + log.GetLogger().Infof("trigger: faasmanager.ShutdownHandler") + utils.SafeCloseChannel(stopCh) + log.GetLogger().Sync() + return nil +} + +// SignalHandlerLibruntime is the signal handler called by libruntime +func SignalHandlerLibruntime(signal int, payload []byte) error { + return nil +} + +func setupFaaSManagerLibruntime(libruntimeAPI api.LibruntimeAPI) (interface{}, error) { + var err error + faasManager, err = functionmanager.NewFaaSManagerLibruntime(libruntimeAPI, stopCh) + if err != nil { + return "", err + } + return "", nil +} diff --git a/yuanrong/cmd/faas/faasscheduler/function_main.go b/yuanrong/cmd/faas/faasscheduler/function_main.go new file mode 100644 index 0000000..02d395c --- /dev/null +++ b/yuanrong/cmd/faas/faasscheduler/function_main.go @@ -0,0 +1,162 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package main - +package main + +import ( + "errors" + "fmt" + "time" + + _ "go.uber.org/automaxprocs" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/trafficlimit" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/healthcheck" + "yuanrong/pkg/functionscaler/instancepool" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/rollout" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/signalmanager" + "yuanrong/pkg/functionscaler/state" +) + +const ( + shutdownWaitTime = 5 * time.Second +) + +var ( + stopCh = make(chan struct{}) + errCh = make(chan error) +) + +// InitHandlerLibruntime is the init handler called by runtime based on multi libruntime +func InitHandlerLibruntime(args []api.Arg, libruntimeAPI api.LibruntimeAPI) ([]byte, error) { + log.SetupLoggerLibruntime(libruntimeAPI.GetFormatLogger()) + var err error + defer func() { + if err != nil { + fmt.Printf("panic, module: faasscheduler, err: %s\n", err.Error()) + log.GetLogger().Errorf("panic, module: faasscheduler, err: %s", err.Error()) + } + log.GetLogger().Sync() + }() + if len(args) == 0 || libruntimeAPI == nil { + return []byte(""), errors.New("init args empty") + } + if args[0].Type != api.Value { + return []byte(""), errors.New("arg type error") + } + if err = config.InitConfig(args[0].Data); err != nil { + return []byte(""), err + } + if err = config.InitEtcd(stopCh); err != nil { + log.GetLogger().Errorf("failed to init etcd ,err:%s", err.Error()) + return nil, err + } + state.InitState() + if err = setupFunctionSchedulerLibruntime(libruntimeAPI); err != nil { + return []byte(""), err + } + registry.StartRegistry() + if err = healthcheck.StartHealthCheck(errCh); err != nil { + return []byte(""), err + } + config.ClearSensitiveInfo() + return []byte(""), nil +} + +// CallHandlerLibruntime is the call handler called by runtime based on multi libruntime +func CallHandlerLibruntime(args []api.Arg) ([]byte, error) { + traceID := string(args[len(args)-1].Data) + + if functionscaler.GetGlobalScheduler() == nil { + return nil, fmt.Errorf("faas scheduler is not initialized, traceID: %s", traceID) + } + return functionscaler.GetGlobalScheduler().ProcessInstanceRequestLibruntime(args, traceID) +} + +// CheckpointHandlerLibruntime is the checkpoint handler called by runtime based on multi libruntime +func CheckpointHandlerLibruntime(checkpointID string) ([]byte, error) { + return state.GetStateByte() +} + +// RecoverHandlerLibruntime is the recover handler called by runtime based on multi libruntime +func RecoverHandlerLibruntime(stateData []byte, libruntimeAPI api.LibruntimeAPI) error { + var err error + log.SetupLoggerLibruntime(libruntimeAPI.GetFormatLogger()) + log.GetLogger().Infof("trigger: libruntime faasscheduler.RecoverHandler") + if err = state.SetState(stateData); err != nil { + return fmt.Errorf("libruntime faaS scheduler recover error is :%s", err.Error()) + } + state.InitState() + if err = state.RecoverConfig(); err != nil { + return fmt.Errorf("libruntime recover config error:%s", err.Error()) + } + if err = config.InitEtcd(stopCh); err != nil { + log.GetLogger().Errorf("failed to init etcd ,err:%s", err.Error()) + return err + } + if err = setupFunctionSchedulerLibruntime(libruntimeAPI); err != nil { + log.GetLogger().Errorf("libruntime recover initHandler error:%s", err.Error()) + return fmt.Errorf("faaS frontend recover initHandler error of libruntime is :%s", err.Error()) + } + if functionscaler.GetGlobalScheduler() != nil { + functionscaler.GetGlobalScheduler().Recover() + } + state.Update(config.GlobalConfig) + registry.StartRegistry() + config.ClearSensitiveInfo() + return nil +} + +// ShutdownHandlerLibruntime is the shutdown handler called by runtime based on multi libruntime +func ShutdownHandlerLibruntime(gracePeriodSecond uint64) error { + log.GetLogger().Infof("trigger: faasscheduler.ShutdownHandler") + utils.SafeCloseChannel(stopCh) + time.Sleep(shutdownWaitTime) + log.GetLogger().Infof("faasschedulerLibruntime exit") + log.GetLogger().Sync() + return nil +} + +// SignalHandlerLibruntime is the signal handler called by runtime based on multi libruntime +func SignalHandlerLibruntime(signal int, payload []byte) error { + return nil +} + +func setupFunctionSchedulerLibruntime(fsClient api.LibruntimeAPI) error { + rollout.SetRolloutSdkClient(fsClient) + if err := registry.InitRegistry(stopCh); err != nil { + return err + } + if err := selfregister.RegisterToEtcd(stopCh); err != nil { + return err + } + + signalmanager.GetSignalManager().SetKillFunc(fsClient.Kill) + + instancepool.SetGlobalSdkClient(fsClient) + functionscaler.InitGlobalScheduler(stopCh) + registry.ProcessETCDList() + trafficlimit.SetFunctionLimitRate(config.GlobalConfig.FunctionLimitRate) + return nil +} diff --git a/yuanrong/cmd/faas/faasscheduler/function_main_test.go b/yuanrong/cmd/faas/faasscheduler/function_main_test.go new file mode 100644 index 0000000..f256151 --- /dev/null +++ b/yuanrong/cmd/faas/faasscheduler/function_main_test.go @@ -0,0 +1,286 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package main - +package main + +import ( + "context" + "errors" + "reflect" + "syscall" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/etcd3" + commonTypes "yuanrong/pkg/common/faas_common/types" + mockUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/state" + "yuanrong/pkg/functionscaler/types" +) + +func TestMain(m *testing.M) { + patches := []*Patches{ + ApplyFunc((*etcd3.EtcdWatcher).StartList, func(ew *etcd3.EtcdWatcher) { + ew.ResultChan <- &etcd3.Event{ + Type: etcd3.SYNCED, + Key: "", + Value: nil, + PrevValue: nil, + Rev: 0, + ETCDType: "", + } + }), + ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + registry.InitRegistry(stopCh) + m.Run() +} + +func TestInitHandler(t *testing.T) { + var err error + configJson := `{ + "cpu": 999, + "memory": 999, + "autoScaleConfig": { + "slaQuota": 1000, + "scaleDownTime": 20000, + "burstScaleNum": 1000 + }, + "leaseSpan": 1000, + "routerEtcd": { + "servers": ["1.1.1.1:32379"], + "username": "root", + "password": "" + }, + "metaEtcd": { + "servers": ["1.1.1.1:32380"], + "username": "root", + "password": "" + } + }` + testArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(configJson), + }, + } + patches := []*Patches{ + ApplyFunc(registry.InitRegistry, func(stopCh <-chan struct{}) error { + return nil + }), + ApplyFunc(registry.StartRegistry, func() { + return + }), + ApplyFunc((*registry.FunctionRegistry).WaitForETCDList, func() {}), + ApplyFunc(functionscaler.NewFaaSScheduler, + func(stopCh <-chan struct{}) *functionscaler.FaaSScheduler { + return &functionscaler.FaaSScheduler{} + }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "EtcdHeatBeat", func(e *etcd3.EtcdClient) error { + return nil + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + _, err = InitHandlerLibruntime(nil, &mockUtils.FakeLibruntimeSdkClient{}) + assert.Equal(t, "init args empty", err.Error()) + _, err = InitHandlerLibruntime(testArgs, &mockUtils.FakeLibruntimeSdkClient{}) + assert.Equal(t, true, err == nil) + + defer ApplyFunc(config.InitEtcd, func(stopCh <-chan struct{}) error { + return errors.New("init etcd error") + }).Reset() + _, err = InitHandlerLibruntime(testArgs, &mockUtils.FakeLibruntimeSdkClient{}) + assert.Equal(t, true, err != nil) +} + +func TestCallHandler(t *testing.T) { + // _, err := CallHandlerLibruntime(nil) + configJson := `{ + "slaQuota": 1000, + "scaleDownTime": 20000, + "burstScaleNum": 1000, + "leaseSpan": 1000, + "routerEtcd": { + "servers": ["1.1.1.1:32379"], + "username": "root", + "password": "" + }, + "metaEtcd": { + "servers": ["1.1.1.1:32380"], + "username": "root", + "password": "" + } + }` + testFuncSpec := &types.FunctionSpecification{ + FuncCtx: context.TODO(), + FuncKey: "TestFuncKey", + FuncMetaData: commonTypes.FuncMetaData{}, + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 500, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + MinInstance: 0, + MaxInstance: 1000, + ConcurrentNum: 100, + }, + } + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "GetFuncSpec", func(_ *registry.Registry, + funcKey string) *types.FunctionSpecification { + return testFuncSpec + }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", func(_ *etcd3.EtcdWatcher) { + return + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + + config.InitConfig([]byte(configJson)) + functionscaler.InitGlobalScheduler(stopCh) + functionscaler.GetGlobalScheduler().PoolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, testFuncSpec) + testArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte("acquire#TestFuncKey"), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + rsp, err := CallHandlerLibruntime(testArgs) + assert.Equal(t, true, err == nil) + assert.Equal(t, true, rsp != nil) +} + +func BenchmarkCallHandler(b *testing.B) { + configJson := `{ + "slaQuota": 1000, + "scaleDownTime": 20000, + "burstScaleNum": 1000, + "leaseSpan": 1000, + "etcd": { + "url": ["1.1.1.1:32379"], + "username": "root", + "password": "" + } + }` + testFuncSpec := &types.FunctionSpecification{ + FuncCtx: context.TODO(), + FuncKey: "TestFuncKey", + FuncMetaData: commonTypes.FuncMetaData{}, + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 500, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + MinInstance: 0, + MaxInstance: 1000, + ConcurrentNum: 100, + }, + } + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "GetFuncSpec", func(_ *registry.Registry, + funcKey string) *types.FunctionSpecification { + return testFuncSpec + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + config.InitConfig([]byte(configJson)) + functionscaler.InitGlobalScheduler(make(chan struct{})) + testArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte("acquire#TestFuncKey"), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + for i := 0; i < b.N; i++ { + _, err := CallHandlerLibruntime(testArgs) + if err != nil { + b.Errorf("acquire instance thread error %s", err.Error()) + } + } +} + +func TestCheckpointHandler(t *testing.T) { + convey.Convey("CheckpointHandler", t, func() { + defer ApplyFunc(state.GetStateByte, func() ([]byte, error) { + return []byte{}, nil + }).Reset() + handler, err := CheckpointHandlerLibruntime("checkpointID") + convey.So(err, convey.ShouldBeNil) + convey.So(len(handler), convey.ShouldEqual, 0) + }) +} + +func TestShutdownHandler(t *testing.T) { + convey.Convey("ShutdownHandler", t, func() { + stopCh = make(chan struct{}) + err := ShutdownHandlerLibruntime(uint64(30)) + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestShutdownHandlerLibruntime(t *testing.T) { + convey.Convey("ShutdownHandlerLibruntime", t, func() { + stopCh = make(chan struct{}) + err := ShutdownHandlerLibruntime(uint64(30)) + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestSignalHandlerLibruntime(t *testing.T) { + convey.Convey("SignalHandlerLibruntime", t, func() { + err := SignalHandlerLibruntime(int(syscall.SIGTERM), nil) + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/cmd/faas/faasscheduler/module_main.go b/yuanrong/cmd/faas/faasscheduler/module_main.go new file mode 100644 index 0000000..729ff09 --- /dev/null +++ b/yuanrong/cmd/faas/faasscheduler/module_main.go @@ -0,0 +1,210 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "fmt" + + "github.com/valyala/fasthttp" + + "yuanrong/pkg/common/faas_common/autogc" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/signals" + "yuanrong/pkg/common/faas_common/trafficlimit" + "yuanrong/pkg/functionscaler" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/healthcheck" + "yuanrong/pkg/functionscaler/httpserver" + "yuanrong/pkg/functionscaler/instancequeue" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/state" + "yuanrong/pkg/functionscaler/workermanager" +) + +const ( + logFileName = "faas-scheduler" + filePath = "/home/sn/config/config.json" + errChanSize = 2 +) + +func main() { + defer func() { + log.GetLogger().Sync() + }() + // init logger config + err := log.InitRunLog(logFileName, true) + if err != nil { + fmt.Print("init logger error: " + err.Error()) + return + } + + defer func() { + if err != nil { + fmt.Printf("panic, module: faasscheduler, err: %s\n", err.Error()) + log.GetLogger().Errorf("panic, module: scheduler, err: %s", err.Error()) + } + log.GetLogger().Sync() + }() + err = config.InitModuleConfig() + if err != nil { + errMessage := fmt.Sprintf("init module config error: %s", err.Error()) + logAndPrintError(errMessage) + return + } + autogc.InitAutoGOGC() + stopCh := signals.WaitForSignal() + if stopCh == nil { + errMessage := "stopCh is nil" + logAndPrintError(errMessage) + return + } + if err = config.InitEtcd(stopCh); err != nil { + errMessage := fmt.Sprintf("init etcd error: %s", err.Error()) + logAndPrintError(errMessage) + return + } + if err = workermanager.InitLeaseInformer(stopCh); err != nil { + errMessage := fmt.Sprintf("init lease informer error: %s", err.Error()) + logAndPrintError(errMessage) + return + } + instancequeue.DisableCreateRetry() + state.InitState() + var stateByte []byte + stateByte, err = state.GetStateByte() + if err == nil && len(stateByte) != 0 { + err = RecoverModuleScheduler(stateByte, stopCh) + if err != nil { + errMessage := fmt.Sprintf("failed to recover module scheduler ,err:%s", err.Error()) + logAndPrintError(errMessage) + return + } + } + if err = setupModuleScheduler(stopCh); err != nil { + errMessage := fmt.Sprintf("failed to setup module scheduler,err:%s", err.Error()) + logAndPrintError(errMessage) + return + } + registry.StartRegistry() + config.ClearSensitiveInfo() + errChan := make(chan error, errChanSize) + httpServer, err := httpserver.StartHTTPServer(errChan) + if err != nil { + errMessage := fmt.Sprintf("failed to start http server, err: %s", err.Error()) + logAndPrintError(errMessage) + return + } + err = healthcheck.StartHealthCheck(errChan) + if err != nil { + errMessage := fmt.Sprintf("failed to start health check, err: %s", err.Error()) + logAndPrintError(errMessage) + return + } + if err = selfregister.RegisterToEtcd(stopCh); err != nil { + errMessage := fmt.Sprintf("register to etcd error: %s", err.Error()) + logAndPrintError(errMessage) + } + waitShutdown(httpServer, stopCh, errChan) +} + +func logAndPrintError(errMessage string) { + log.GetLogger().Errorf(errMessage) + fmt.Println(errMessage) +} + +func setupModuleScheduler(stopCh <-chan struct{}) error { + err := registry.InitRegistry(stopCh) + if err != nil { + return err + } + // WatchConfig failed do not return, just config hot load not enable + if err := config.WatchConfig(filePath, stopCh, nil); err != nil { + log.GetLogger().Warnf("WatchConfig %s failed, err %s", filePath, err.Error()) + } + functionscaler.InitGlobalScheduler(stopCh) + registry.ProcessETCDList() + trafficlimit.SetFunctionLimitRate(config.GlobalConfig.FunctionLimitRate) + return nil +} + +// RecoverModuleScheduler - +func RecoverModuleScheduler(stateData []byte, stopCh <-chan struct{}) error { + var err error + log.GetLogger().Infof("trigger: RecoverModuleScheduler") + if err = state.SetState(stateData); err != nil { + return fmt.Errorf("module scheduler recover error:%s", err.Error()) + } + state.RecoverStateRev() + state.InitState() + if err = setupModuleScheduler(stopCh); err != nil { + log.GetLogger().Errorf("recover module scheduler error:%s", err.Error()) + return fmt.Errorf("module scheduler recover error:%s", err.Error()) + } + if functionscaler.GetGlobalScheduler() != nil { + functionscaler.GetGlobalScheduler().Recover() + } + registry.StartRegistry() + config.ClearSensitiveInfo() + errChan := make(chan error, errChanSize) + httpServer, err := startServer(errChan) + if err != nil { + return err + } + if err = selfregister.RegisterToEtcd(stopCh); err != nil { + return err + } + waitShutdown(httpServer, stopCh, errChan) + return nil +} + +func startServer(errChan chan error) (*fasthttp.Server, error) { + var httpServer *fasthttp.Server + var err error + if config.GlobalConfig.InstanceOperationBackend == constant.BackendTypeFG { + httpServer, err = httpserver.StartHTTPServer(errChan) + if err != nil { + return nil, fmt.Errorf("start fast http server error:%s", err.Error()) + } + } + err = healthcheck.StartHealthCheck(errChan) + if err != nil { + return nil, fmt.Errorf("failed to start health check, err: %s", err.Error()) + } + return httpServer, nil +} + +func waitShutdown(server *fasthttp.Server, stopCh <-chan struct{}, errChan <-chan error) { + if stopCh == nil || errChan == nil { + errMessage := "input channel is nil" + logAndPrintError(errMessage) + } + select { + case <-stopCh: + log.GetLogger().Infof("received termination signal") + if server != nil { + if err := server.Shutdown(); err != nil { + errMessage := fmt.Sprintf("http server shutdowm error:%s", err.Error()) + logAndPrintError(errMessage) + } + } + case err := <-errChan: + errMessage := fmt.Sprintf("http server error:%s", err.Error()) + logAndPrintError(errMessage) + } +} diff --git a/yuanrong/go.mod b/yuanrong/go.mod new file mode 100644 index 0000000..f5d179f --- /dev/null +++ b/yuanrong/go.mod @@ -0,0 +1,56 @@ +module yuanrong + +go 1.24.1 + +require ( + yuanrong/pkg/common v1.0.0 + github.com/agiledragon/gomonkey/v2 v2.11.0 + github.com/stretchr/testify v1.10.0 + go.uber.org/automaxprocs v1.6.0 + yuanrong.org/kernel/runtime v1.0.0 +) + +replace ( + yuanrong/pkg/common => ./pkg/common + github.com/agiledragon/gomonkey => github.com/agiledragon/gomonkey v2.0.1+incompatible + github.com/asaskevich/govalidator/v11 => github.com/asaskevich/govalidator/v11 v11.0.1-0.20250122183457-e11347878e23 + github.com/fsnotify/fsnotify => github.com/fsnotify/fsnotify v1.7.0 + // for test or internal use + github.com/gin-gonic/gin => github.com/gin-gonic/gin v1.10.0 + github.com/golang/mock => github.com/golang/mock v1.3.1 + github.com/google/uuid => github.com/google/uuid v1.6.0 + github.com/olekukonko/tablewriter => github.com/olekukonko/tablewriter v0.0.5 + github.com/operator-framework/operator-lib => github.com/operator-framework/operator-lib v0.4.0 + github.com/prashantv/gostub => github.com/prashantv/gostub v1.0.0 + github.com/robfig/cron/v3 => github.com/robfig/cron/v3 v3.0.1 + github.com/smartystreets/goconvey => github.com/smartystreets/goconvey v1.6.4 + github.com/spf13/cobra => github.com/spf13/cobra v1.8.1 + github.com/stretchr/testify => github.com/stretchr/testify v1.5.1 + github.com/valyala/fasthttp => github.com/valyala/fasthttp v1.58.0 + go.etcd.io/etcd/api/v3 => go.etcd.io/etcd/api/v3 v3.5.11 + go.etcd.io/etcd/client/v3 => go.etcd.io/etcd/client/v3 v3.5.11 + go.opentelemetry.io/otel => go.opentelemetry.io/otel v1.24.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace => go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.24.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc => go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0 + go.opentelemetry.io/otel/metric => go.opentelemetry.io/otel/metric v1.24.0 + go.opentelemetry.io/otel/sdk => go.opentelemetry.io/otel/sdk v1.24.0 + go.opentelemetry.io/otel/trace => go.opentelemetry.io/otel/trace v1.24.0 + go.uber.org/automaxprocs => go.uber.org/automaxprocs v1.6.0 + go.uber.org/zap => go.uber.org/zap v1.27.0 + golang.org/x/crypto => golang.org/x/crypto v0.24.0 + // affects VPC plugin building, will cause error if not pinned + golang.org/x/net => golang.org/x/net v0.26.0 + golang.org/x/sync => golang.org/x/sync v0.0.0-20190423024810-112230192c58 + golang.org/x/sys => golang.org/x/sys v0.21.0 + golang.org/x/text => golang.org/x/text v0.16.0 + golang.org/x/time => golang.org/x/time v0.10.0 + google.golang.org/genproto => google.golang.org/genproto v0.0.0-20230526203410-71b5a4ffd15e + google.golang.org/genproto/googleapis/rpc => google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d + google.golang.org/grpc => google.golang.org/grpc v1.67.0 + google.golang.org/protobuf => google.golang.org/protobuf v1.36.6 + gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1 + yuanrong.org/kernel/runtime => ../api/go + k8s.io/api => k8s.io/api v0.31.2 + k8s.io/apimachinery => k8s.io/apimachinery v0.31.2 + k8s.io/client-go => k8s.io/client-go v0.31.2 +) diff --git a/yuanrong/pkg/collector/common/connection.go b/yuanrong/pkg/collector/common/connection.go new file mode 100644 index 0000000..994321f --- /dev/null +++ b/yuanrong/pkg/collector/common/connection.go @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package common prepares common constants, utils and structs for collector +package common + +import ( + "context" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" +) + +var ( + connection *grpc.ClientConn + once sync.Once +) + +const ( + // DefaultGrpcTimeoutS - + DefaultGrpcTimeoutS = 5 * time.Second +) + +// GetConnection get grpc connection +func GetConnection() *grpc.ClientConn { + once.Do(func() { + ctx, cancel := context.WithTimeout(context.Background(), DefaultGrpcTimeoutS) + defer cancel() + log.GetLogger().Infof("start connect to log manager grpc server: %s", CollectorConfigs.ManagerAddress) + conn, err := grpc.DialContext(ctx, CollectorConfigs.ManagerAddress, + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + log.GetLogger().Errorf("failed to connect to log manager grpc server, error: %v", err) + return + } + log.GetLogger().Infof("success to connect to log manager: %s", CollectorConfigs.ManagerAddress) + connection = conn + }) + return connection +} + +// InitEtcdClient will +func InitEtcdClient() error { + return etcd3.InitParam(). + WithRouteEtcdConfig(CollectorConfigs.EtcdConfig). + WithStopCh(make(chan struct{})). + InitClient() +} diff --git a/yuanrong/pkg/collector/common/connection_test.go b/yuanrong/pkg/collector/common/connection_test.go new file mode 100644 index 0000000..0c2a133 --- /dev/null +++ b/yuanrong/pkg/collector/common/connection_test.go @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package common + +import ( + "net" + "sync" + "testing" + + "google.golang.org/grpc" +) + +func TestGetConnection(t *testing.T) { + once = sync.Once{} + connection = nil + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer lis.Close() + + go func() { + grpc.NewServer().Serve(lis) + }() + + CollectorConfigs.ManagerAddress = lis.Addr().String() + + conn := GetConnection() + + if conn == nil { + t.Errorf("Expected a valid connection, but got nil") + } + + conn.Close() +} + +func TestFailedGetConnection(t *testing.T) { + once = sync.Once{} + connection = nil + + CollectorConfigs.ManagerAddress = "" + + conn := GetConnection() + + if conn != nil { + t.Errorf("Expected a failed connection, but got: %#v", conn) + } +} diff --git a/yuanrong/pkg/collector/common/flags.go b/yuanrong/pkg/collector/common/flags.go new file mode 100644 index 0000000..27fdc97 --- /dev/null +++ b/yuanrong/pkg/collector/common/flags.go @@ -0,0 +1,177 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package common + +import ( + "fmt" + "net" + "os" + "path/filepath" + "regexp" + "strconv" + + "github.com/spf13/cobra" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" +) + +// CollectorConfig - +type CollectorConfig struct { + CollectorID string + IP string + Port string + Address string + ManagerAddress string + DatasystemPort int + LogRoot string + UserLogPath string + EtcdConfig etcd3.EtcdConfig +} + +var ( + // CollectorConfigs - + CollectorConfigs CollectorConfig +) + +const ( + collectorIDPattern string = `^[a-zA-Z0-9~\.\-\/_!@#\%\^&\*\(\)\+\=\:;]{1,256}$` + portLowerBound int = 0 + portUpperBound int = 65536 +) + +var rootCmd = &cobra.Command{ + Short: "Collector is to collect user logs.", + Long: "Collector is to collect user-level logs by interactively streaming or specified range of lines.", + Run: func(cmd *cobra.Command, args []string) {}, +} + +func init() { + registerCmdArgs(rootCmd) +} + +// InitCmd init commandline +func InitCmd() { + if err := rootCmd.Execute(); err != nil { + log.GetLogger().Fatal(err.Error()) + } + if err := validateCmdArgs(); err != nil { + log.GetLogger().Fatal(err.Error()) + } + log.GetLogger().Infof("collector start args: %+v", CollectorConfigs) +} + +func registerCmdArgs(rootCmd *cobra.Command) { + rootCmd.Flags().StringVarP(&CollectorConfigs.CollectorID, "collect_id", "", "", + "the identifier; of length less than 256") + rootCmd.Flags().StringVarP(&CollectorConfigs.IP, "ip", "", "", "the ip of collector for manager to access") + rootCmd.Flags().StringVarP(&CollectorConfigs.Port, "port", "", "", "the port of collector for manager to access") + rootCmd.Flags().StringVarP(&CollectorConfigs.ManagerAddress, "manager_address", "", "", + "manager address to register collector") + rootCmd.Flags().IntVarP(&CollectorConfigs.DatasystemPort, "datasystem_port", "", 0, + "datasystem port to publish stream logs") + rootCmd.Flags().StringVarP(&CollectorConfigs.LogRoot, "log_root", "", "", "the default root path of all logs") + rootCmd.Flags().StringVarP(&CollectorConfigs.UserLogPath, "user_log_path", "", "", + "optional; specified only if user log is in other directory than log root path") + + rootCmd.Flags().StringSliceVarP(&CollectorConfigs.EtcdConfig.Servers, "etcd_config_servers", "", []string{}, + "etcd server addresses") + rootCmd.Flags().StringVarP(&CollectorConfigs.EtcdConfig.User, "etcd_config_user", "", "", + "etcd config about user") + rootCmd.Flags().StringVarP(&CollectorConfigs.EtcdConfig.Password, "etcd_config_password", "", "", + "etcd config about password") + rootCmd.Flags().BoolVarP(&CollectorConfigs.EtcdConfig.SslEnable, "etcd_config_ssl_enable", "", false, + "etcd config about ssl_enable") + rootCmd.Flags().StringVarP(&CollectorConfigs.EtcdConfig.AuthType, "etcd_config_auth_type", "", "", + "etcd config about auth_type") + rootCmd.Flags().BoolVarP(&CollectorConfigs.EtcdConfig.UseSecret, "etcd_config_use_secret", "", false, + "etcd config about use_secret") + rootCmd.Flags().StringVarP(&CollectorConfigs.EtcdConfig.SecretName, "etcd_config_secret_name", "", "", + "etcd config about secret_name") + rootCmd.Flags().IntVarP(&CollectorConfigs.EtcdConfig.LimitRate, "etcd_config_limit_rate", "", 0, + "etcd config about limit_rate") + rootCmd.Flags().IntVarP(&CollectorConfigs.EtcdConfig.LimitBurst, "etcd_config_limit_burst", "", 0, + "etcd config about limit_burst") + rootCmd.Flags().IntVarP(&CollectorConfigs.EtcdConfig.LimitTimeout, "etcd_config_limit_timeout", "", 0, + "etcd config about limit_timeout") + rootCmd.Flags().StringVarP(&CollectorConfigs.EtcdConfig.CaFile, "etcd_config_ca_file", "", "", + "etcd config about ca_file") + rootCmd.Flags().StringVarP(&CollectorConfigs.EtcdConfig.CertFile, "etcd_config_cert_file", "", "", + "etcd config about cert_file") + rootCmd.Flags().StringVarP(&CollectorConfigs.EtcdConfig.KeyFile, "etcd_config_key_file", "", "", + "etcd config about key_file") + rootCmd.Flags().StringVarP(&CollectorConfigs.EtcdConfig.PassphraseFile, "etcd_config_passphrase_file", "", "", + "etcd config about passphrase_file") +} + +func checkPath(path string) error { + if !filepath.IsAbs(path) { + return fmt.Errorf("%s should be absolute path", path) + } + if _, err := os.Stat(path); err != nil { + return fmt.Errorf("no such directory: %s", path) + } + return nil +} + +func validateCmdArgs() error { + matched, merr := regexp.MatchString(collectorIDPattern, CollectorConfigs.CollectorID) + if merr != nil { + return fmt.Errorf("collector ID failed to match regex: %s, collector id: %s", collectorIDPattern, + CollectorConfigs.CollectorID) + } + if !matched { + return fmt.Errorf( + "collector ID %s does not match %s; please check characters and length. The valid length range is [1, 256)", + CollectorConfigs.CollectorID, collectorIDPattern) + } + if net.ParseIP(CollectorConfigs.IP) == nil { + return fmt.Errorf("collector IP %s is invalid", CollectorConfigs.IP) + } + if p, err := strconv.Atoi(CollectorConfigs.Port); err != nil || p < portLowerBound || p > portUpperBound { + return fmt.Errorf("collector port %s is invalid", CollectorConfigs.Port) + } + CollectorConfigs.Address = CollectorConfigs.IP + ":" + CollectorConfigs.Port + + host, port, err := net.SplitHostPort(CollectorConfigs.ManagerAddress) + if err != nil { + return fmt.Errorf("manager address %s is invalid, error: %s", CollectorConfigs.ManagerAddress, err) + } + if net.ParseIP(host) == nil { + return fmt.Errorf("manager address %s has invalid IP", CollectorConfigs.ManagerAddress) + } + if p, err := strconv.Atoi(port); err != nil || p < portLowerBound || p > portUpperBound { + return fmt.Errorf("manager address %s has invalid port", CollectorConfigs.ManagerAddress) + } + + if CollectorConfigs.DatasystemPort < portLowerBound || CollectorConfigs.DatasystemPort > portUpperBound { + return fmt.Errorf("datasystem has invalid port %d", CollectorConfigs.DatasystemPort) + } + + if err := checkPath(CollectorConfigs.LogRoot); err != nil { + return err + } + if CollectorConfigs.UserLogPath == "" { + CollectorConfigs.UserLogPath = CollectorConfigs.LogRoot + log.GetLogger().Infof("user log path inherits log root %s", CollectorConfigs.UserLogPath) + } else { + if err := checkPath(CollectorConfigs.UserLogPath); err != nil { + return err + } + } + return nil +} diff --git a/yuanrong/pkg/collector/common/flags_test.go b/yuanrong/pkg/collector/common/flags_test.go new file mode 100644 index 0000000..b18aea0 --- /dev/null +++ b/yuanrong/pkg/collector/common/flags_test.go @@ -0,0 +1,141 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCommandLineArguments(t *testing.T) { + testCases := []struct { + args []string + expectedID string + expectedIP string + expectedPort string + expectedManagerAddress string + expectedDatasystemPort int + expectedLogRoot string + expectedUserLogPath string + }{ + { + args: []string{"--collect_id", "12345", "--ip", "192.168.1.1", "--port", "8080", "--manager_address", "10.10.10.10:5678", "--datasystem_port", "9090", "--log_root", "/var/log", "--user_log_path", "/tmp"}, + expectedID: "12345", + expectedIP: "192.168.1.1", + expectedPort: "8080", + expectedManagerAddress: "10.10.10.10:5678", + expectedDatasystemPort: 9090, + expectedLogRoot: "/var/log", + expectedUserLogPath: "/tmp", + }, + { + args: []string{"--collect_id", "abcde12345", "--ip", "192.168.1.1", "--port", "8080", "--manager_address", "10.10.10.10:5678", "--datasystem_port", "9090", "--log_root", "/var/log"}, + expectedID: "abcde12345", + expectedIP: "192.168.1.1", + expectedPort: "8080", + expectedManagerAddress: "10.10.10.10:5678", + expectedDatasystemPort: 9090, + expectedLogRoot: "/var/log", + expectedUserLogPath: "/var/log", + }, + } + + for _, tc := range testCases { + CollectorConfigs.UserLogPath = "" + rootCmd.SetArgs(tc.args) + + if err := rootCmd.Execute(); err != nil { + t.Errorf("Failed to execute command: %v", err) + } + assert.Equal(t, validateCmdArgs(), nil, "failed to validate cmd args") + + assert.Equal(t, tc.expectedID, CollectorConfigs.CollectorID, "Collector ID mismatch") + assert.Equal(t, tc.expectedIP, CollectorConfigs.IP, "IP mismatch") + assert.Equal(t, tc.expectedPort, CollectorConfigs.Port, "Port mismatch") + assert.Equal(t, tc.expectedManagerAddress, CollectorConfigs.ManagerAddress, "Manager Address mismatch") + assert.Equal(t, tc.expectedDatasystemPort, CollectorConfigs.DatasystemPort, "Datasystem Port mismatch") + assert.Equal(t, tc.expectedLogRoot, CollectorConfigs.LogRoot, "Log Root mismatch") + assert.Equal(t, tc.expectedUserLogPath, CollectorConfigs.UserLogPath, "User Log Path mismatch") + } +} + +func TestInvalidCommandLineArguments(t *testing.T) { + testCases := []struct { + args []string + errorMsg string + }{ + { + args: []string{"--collect_id", ""}, + errorMsg: "please check characters and length. The valid length range is", + }, + { + args: []string{"--collect_id", "[]"}, + errorMsg: "please check characters and length. The valid length range is", + }, + { + args: []string{"--collect_id", "01234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"}, + errorMsg: "please check characters and length. The valid length range is", + }, + { + args: []string{"--collect_id", "1", "--ip", "255.255.255.256"}, + errorMsg: "collector IP 255.255.255.256 is invalid", + }, + { + args: []string{"--collect_id", "1", "--ip", "10.10.10.10", "--port", "65537"}, + errorMsg: "collector port 65537 is invalid", + }, + { + args: []string{"--collect_id", "1", "--ip", "10.10.10.10", "--port", "8080", "--manager_address", "abcde"}, + errorMsg: "manager address abcde is invalid", + }, + { + args: []string{"--collect_id", "1", "--ip", "10.10.10.10", "--port", "8080", "--manager_address", "abc:def"}, + errorMsg: "manager address abc:def has invalid IP", + }, + { + args: []string{"--collect_id", "1", "--ip", "10.10.10.10", "--port", "8080", "--manager_address", "10.10.10.10:-3"}, + errorMsg: "manager address 10.10.10.10:-3 has invalid port", + }, + { + args: []string{"--collect_id", "1", "--ip", "10.10.10.10", "--port", "8080", "--manager_address", "10.10.10.10:90", "--datasystem_port", "-2"}, + errorMsg: "datasystem has invalid port -2", + }, + { + args: []string{"--collect_id", "1", "--ip", "10.10.10.10", "--port", "8080", "--manager_address", "10.10.10.10:90", "--datasystem_port", "91", "--log_root", "./here"}, + errorMsg: "./here should be absolute path", + }, + { + args: []string{"--collect_id", "1", "--ip", "10.10.10.10", "--port", "8080", "--manager_address", "10.10.10.10:90", "--datasystem_port", "91", "--log_root", "/var/log", "--user_log_path", "./here"}, + errorMsg: "./here should be absolute path", + }, + { + args: []string{"--collect_id", "1", "--ip", "10.10.10.10", "--port", "8080", "--manager_address", "10.10.10.10:90", "--datasystem_port", "91", "--log_root", "/var/log", "--user_log_path", "/here"}, + errorMsg: "no such directory: /here", + }, + } + + for _, tc := range testCases { + rootCmd.SetArgs(tc.args) + + if err := rootCmd.Execute(); err != nil { + t.Errorf("Failed to execute command: %v", err) + } + + assert.Contains(t, validateCmdArgs().Error(), tc.errorMsg, "unexpected error message") + } +} diff --git a/yuanrong/pkg/collector/logcollector/common.go b/yuanrong/pkg/collector/logcollector/common.go new file mode 100644 index 0000000..1a31fbd --- /dev/null +++ b/yuanrong/pkg/collector/logcollector/common.go @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logcollector collects and publish logs +package logcollector + +import ( + "fmt" + "path/filepath" + "sync" + "time" + + "yuanrong/pkg/collector/common" + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" +) + +const ( + realRetryInterval = 100 * time.Millisecond + realMaxRetryTimes = 30 + + maxTimeForUnfinishedLine time.Duration = 1 * time.Second +) + +type constantInterface interface { + GetRetryInterval() time.Duration + GetMaxRetryTimes() int + GetMaxTimeForUnfinishedLine() time.Duration +} + +type constImpl struct{} + +func (c *constImpl) GetRetryInterval() time.Duration { + return realRetryInterval +} + +func (c *constImpl) GetMaxRetryTimes() int { + return realMaxRetryTimes +} + +func (c *constImpl) GetMaxTimeForUnfinishedLine() time.Duration { + return maxTimeForUnfinishedLine +} + +var constant constantInterface = &constImpl{} + +var ( + // LogServiceClient is grpc log service interface + LogServiceClient logservice.LogManagerServiceClient + logServiceClientOnce sync.Once +) + +// streamControlChans maps streamName to stream done channel for each file +var streamControlChans = struct { + sync.Mutex + hashmap map[string]map[string]chan bool +}{ + hashmap: make(map[string]map[string]chan bool), +} + +// GetAbsoluteFilePath returns absolute path based on the target type +func GetAbsoluteFilePath(item *logservice.LogItem) (string, error) { + switch item.Target { + case logservice.LogTarget_USER_STD: + if filepath.IsAbs(item.Filename) { + return item.Filename, nil + } + return filepath.Join(common.CollectorConfigs.UserLogPath, item.Filename), nil + default: + break + } + log.GetLogger().Warnf("undefined log item target: %v", item.Target) + return "", fmt.Errorf("undefined log item target: %v", item.Target) +} + +// GetLogServiceClient returns log service client +func GetLogServiceClient() logservice.LogManagerServiceClient { + logServiceClientOnce.Do(func() { + conn := common.GetConnection() + if conn == nil { + log.GetLogger().Errorf("failed to get connection to log manager") + return + } + LogServiceClient = logservice.NewLogManagerServiceClient(conn) + }) + + return LogServiceClient +} diff --git a/yuanrong/pkg/collector/logcollector/common_test.go b/yuanrong/pkg/collector/logcollector/common_test.go new file mode 100644 index 0000000..638bfef --- /dev/null +++ b/yuanrong/pkg/collector/logcollector/common_test.go @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logcollector + +import ( + "net" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + + "yuanrong/pkg/collector/common" + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" +) + +type constTestImpl struct{} + +func (c *constTestImpl) GetRetryInterval() time.Duration { + return 100 * time.Microsecond +} + +func (c *constTestImpl) GetMaxRetryTimes() int { + return 3 +} + +func (c *constTestImpl) GetMaxTimeForUnfinishedLine() time.Duration { + return 1 * time.Millisecond +} + +func TestGetAbsoluteFilePath(t *testing.T) { + { + item := &logservice.LogItem{ + Target: logservice.LogTarget_USER_STD, + Filename: "example.log", + } + expectedPath := filepath.Join(common.CollectorConfigs.UserLogPath, item.Filename) + path, err := GetAbsoluteFilePath(item) + assert.NoError(t, err) + assert.Equal(t, expectedPath, path) + } + + { + item := &logservice.LogItem{ + Target: logservice.LogTarget_LIB_RUNTIME, + Filename: "example.log", + } + path, err := GetAbsoluteFilePath(item) + assert.Error(t, err) + assert.Empty(t, path) + } +} + +func TestGetLogServiceClient(t *testing.T) { + logServiceClientOnce = sync.Once{} + LogServiceClient = nil + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer lis.Close() + + go func() { + grpc.NewServer().Serve(lis) + }() + + common.CollectorConfigs.ManagerAddress = lis.Addr().String() + + client := GetLogServiceClient() + + if client == nil { + t.Errorf("Expected a valid log manager grpc client, but got nil") + } +} diff --git a/yuanrong/pkg/collector/logcollector/log_reporter.go b/yuanrong/pkg/collector/logcollector/log_reporter.go new file mode 100644 index 0000000..5379f24 --- /dev/null +++ b/yuanrong/pkg/collector/logcollector/log_reporter.go @@ -0,0 +1,225 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logcollector + +import ( + "context" + "fmt" + "io/fs" + "path/filepath" + "regexp" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + + "yuanrong/pkg/collector/common" + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" +) + +const ( + maxNewFileChanSize = 100 +) + +// reportedLogFiles will not remove elements even if the target process exits +var reportedLogFiles = struct { + sync.Mutex + hashmap map[string]struct{} +}{ + hashmap: make(map[string]struct{}), +} + +var logFileRegexMap = map[logservice.LogTarget]*regexp.Regexp{ + logservice.LogTarget_USER_STD: regexp.MustCompile(`^runtime(-[a-z0-9]+)+.(out|err)`), +} + +var runtimeRegex = regexp.MustCompile(`runtime(-[a-z0-9]+)+`) + +func tryReportLog(name string) bool { + if item, ok := parseLogFileName(name); ok { + reportedLogFiles.Lock() + if _, exists := reportedLogFiles.hashmap[item.Filename]; exists { + log.GetLogger().Debugf("%s is already reported", item.Filename) + reportedLogFiles.Unlock() + return false + } + + log.GetLogger().Infof("find log file to report: %s", item.Filename) + log.GetLogger().Debugf("log item details to report: %+v", item) + reportedLogFiles.hashmap[item.Filename] = struct{}{} + reportedLogFiles.Unlock() + if err := reportLog(item); err != nil { + log.GetLogger().Errorf("failed to report log file %s, error: %v", item.Filename, err) + return false + } + return true + } + return false +} + +func getRuntimeID(filename string) string { + return runtimeRegex.FindString(filename) +} + +func parseLogFileName(filePath string) (*logservice.LogItem, bool) { + filename := filepath.Base(filePath) + for target, regex := range logFileRegexMap { + if regex.MatchString(filename) { + runtimeID := getRuntimeID(filename) + log.GetLogger().Debugf("%s matches %v, runtimeID: %s", filePath, target, runtimeID) + return &logservice.LogItem{ + Filename: filePath, + CollectorID: common.CollectorConfigs.CollectorID, + Target: target, + RuntimeID: runtimeID, + }, true + } + } + return nil, false +} + +func reportLog(item *logservice.LogItem) error { + retryInterval := constant.GetRetryInterval() + maxRetryTimes := constant.GetMaxRetryTimes() + client := GetLogServiceClient() + if client == nil { + log.GetLogger().Errorf("failed to get log service client") + return fmt.Errorf("failed to get log service client") + } + ctx, cancel := context.WithTimeout(context.Background(), common.DefaultGrpcTimeoutS) + defer cancel() + + for i := 0; i < maxRetryTimes; i++ { + log.GetLogger().Infof("start to report log %s, attempt: %d", item.Filename, i) + response, err := client.ReportLog(ctx, &logservice.ReportLogRequest{ + Items: []*logservice.LogItem{item}, + }) + if err != nil { + log.GetLogger().Errorf("failed to report log %s, error: %v", item.Filename, err) + time.Sleep(retryInterval) + continue + } + if response.Code != 0 { + log.GetLogger().Errorf("failed to report log %s, error: %d, message: %s", item.Filename, response.Code, + response.Message) + time.Sleep(retryInterval) + continue + } + log.GetLogger().Infof("success to report log %s", item.Filename) + return nil + } + return fmt.Errorf("failed to report log: exceeds max retry time: %d", maxRetryTimes) +} + +func handleNewFile(watcher *fsnotify.Watcher, newFileChan chan string, directory string) { + for { + select { + case file, ok := <-newFileChan: + if !ok { + log.GetLogger().Warnf("new file event chan is closed") + return + } + log.GetLogger().Debugf("find new file %s", file) + if relPath, err := filepath.Rel(directory, file); err == nil { + tryReportLog(relPath) + } + case err, ok := <-watcher.Errors: + if !ok { + log.GetLogger().Warnf("new file event chan is closed") + return + } + log.GetLogger().Warnf("new file event chan error: %v", err) + } + } +} + +func monitorNewFile(watcher *fsnotify.Watcher, newFileChan chan string, directory string) { + defer close(newFileChan) + defer watcher.Close() + for { + select { + case event, ok := <-watcher.Events: + if !ok { + log.GetLogger().Warnf("watch event chan for %s is closed ", directory) + return + } + if event.Op&fsnotify.Create == fsnotify.Create { + log.GetLogger().Debugf("find a new file is created: %s", event.Name) + newFileChan <- event.Name + } + case err, ok := <-watcher.Errors: + if !ok { + log.GetLogger().Warnf("watch event chan for %s is closed ", directory) + return + } + log.GetLogger().Warnf("watch event for %s error: %v", directory, err) + } + } +} + +// createLogReporter starts two go routines that never end +func createLogReporter(directory string) error { + log.GetLogger().Infof("create log report for %s", directory) + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.GetLogger().Errorf("failed to create file watcher, error: %v", err) + return err + } + + err = watcher.Add(directory) + if err != nil { + log.GetLogger().Errorf("failed to create file watcher for %s, error: %v", directory, err) + return err + } + + newFileChan := make(chan string, maxNewFileChanSize) + go handleNewFile(watcher, newFileChan, directory) + go monitorNewFile(watcher, newFileChan, directory) + return nil +} + +func scanUserLog(directory string) { + err := filepath.WalkDir(directory, func(path string, d fs.DirEntry, err error) error { + if err != nil { + log.GetLogger().Warnf("failed to access file %s under %s", path, directory) + return err + } + relPath, err := filepath.Rel(directory, path) + if err != nil { + return err + } + log.GetLogger().Debugf("find file %s under %s", relPath, directory) + tryReportLog(relPath) + return nil + }) + + if err != nil { + log.GetLogger().Errorf("failed to iterate files under %s, error: %v", directory, err) + } +} + +// StartLogReporter starts file watcher to report logs +func StartLogReporter() { + err := createLogReporter(common.CollectorConfigs.UserLogPath) + if err != nil { + log.GetLogger().Errorf("failed to create log reporter for %s, error: %s", common.CollectorConfigs.UserLogPath, + err) + return + } + scanUserLog(common.CollectorConfigs.UserLogPath) +} diff --git a/yuanrong/pkg/collector/logcollector/log_reporter_test.go b/yuanrong/pkg/collector/logcollector/log_reporter_test.go new file mode 100644 index 0000000..6d0f876 --- /dev/null +++ b/yuanrong/pkg/collector/logcollector/log_reporter_test.go @@ -0,0 +1,300 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logcollector + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc" + + "yuanrong/pkg/collector/common" + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" +) + +func TestGetRuntimeID(t *testing.T) { + tests := []struct { + filename string + want string + }{ + {"runtime-12345.out", "runtime-12345"}, + {"runtime-123e4567-e89b-12d3-a456-426614174000.err", "runtime-123e4567-e89b-12d3-a456-426614174000"}, + {"no-runtime.log", ""}, + } + + for _, tt := range tests { + t.Run(tt.filename, func(t *testing.T) { + got := getRuntimeID(tt.filename) + if got != tt.want { + t.Errorf("getRuntimeID(%v) = %v, want %v", tt.filename, got, tt.want) + } + }) + } +} + +func TestParseLogFileName(t *testing.T) { + tests := []struct { + filePath string + want *logservice.LogItem + wantOk bool + }{ + { + filePath: "path/to/runtime-abc123-456.out", + want: &logservice.LogItem{ + Filename: "path/to/runtime-abc123-456.out", + CollectorID: common.CollectorConfigs.CollectorID, + Target: logservice.LogTarget_USER_STD, + RuntimeID: "runtime-abc123-456", + }, + wantOk: true, + }, + { + filePath: "/path/to/runtime-1c1.err", + want: &logservice.LogItem{ + Filename: "/path/to/runtime-1c1.err", + CollectorID: common.CollectorConfigs.CollectorID, + Target: logservice.LogTarget_USER_STD, + RuntimeID: "runtime-1c1", + }, + wantOk: true, + }, + { + filePath: "path/to/unknown.log", + want: nil, + wantOk: false, + }, + { + filePath: "/path/to/function-master.log", + want: nil, + wantOk: false, + }, + } + + for _, tt := range tests { + t.Run(tt.filePath, func(t *testing.T) { + got, ok := parseLogFileName(tt.filePath) + assert.Equal(t, tt.wantOk, ok) + if ok { + assert.Equal(t, tt.want.Filename, got.Filename) + assert.Equal(t, tt.want.CollectorID, got.CollectorID) + assert.Equal(t, tt.want.Target, got.Target) + assert.Equal(t, tt.want.RuntimeID, got.RuntimeID) + } + }) + } +} + +type MockLogServiceClient struct { + mock.Mock +} + +func (m *MockLogServiceClient) ReportLog(ctx context.Context, req *logservice.ReportLogRequest, opts ...grpc.CallOption) (*logservice.ReportLogResponse, error) { + args := m.Called(ctx, req, opts) + return args.Get(0).(*logservice.ReportLogResponse), args.Error(1) +} + +func (m *MockLogServiceClient) Register(ctx context.Context, req *logservice.RegisterRequest, opts ...grpc.CallOption) (*logservice.RegisterResponse, error) { + args := m.Called(ctx, req, opts) + return args.Get(0).(*logservice.RegisterResponse), args.Error(1) +} + +func TestReportLog(t *testing.T) { + constant = &constTestImpl{} + tests := []struct { + mockResponse *logservice.ReportLogResponse + mockError error + expectedError error + }{ + { + mockResponse: &logservice.ReportLogResponse{Code: 0, Message: "success"}, + mockError: nil, + expectedError: nil, + }, + { + mockResponse: nil, + mockError: fmt.Errorf("network error"), + expectedError: fmt.Errorf("failed to report log: exceeds max retry time: %d", constant.GetMaxRetryTimes()), + }, + { + mockResponse: &logservice.ReportLogResponse{Code: -1, Message: "failure"}, + mockError: nil, + expectedError: fmt.Errorf("failed to report log: exceeds max retry time: %d", constant.GetMaxRetryTimes()), + }, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + mockClient := new(MockLogServiceClient) + LogServiceClient = mockClient + mockClient.On("ReportLog", mock.Anything, mock.Anything, mock.Anything).Return(tt.mockResponse, tt.mockError) + err := reportLog(&logservice.LogItem{}) + mockClient.AssertExpectations(t) + assert.Equal(t, tt.expectedError, err) + }) + } +} + +func TestFailedReportLog(t *testing.T) { + LogServiceClient = nil + + tests := []struct { + mockResponse *logservice.ReportLogResponse + mockError error + expectedError error + }{ + { + mockResponse: nil, + mockError: nil, + expectedError: fmt.Errorf("failed to get log service client"), + }, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + item := &logservice.LogItem{} + err := reportLog(item) + assert.Equal(t, tt.expectedError, err) + }) + } +} + +func TestTryReportLog(t *testing.T) { + constant = &constTestImpl{} + reportedLogFiles.Lock() + reportedLogFiles.hashmap["runtime-123.out"] = struct{}{} + reportedLogFiles.Unlock() + tests := []struct { + name string + mockResponse *logservice.ReportLogResponse + mockError error + result bool + }{ + { + name: "runtime-123e4567-e89b-12d3-a456-426614174000.err", + mockResponse: &logservice.ReportLogResponse{Code: 0, Message: "success"}, + mockError: nil, + result: true, + }, + { + name: "runtime-456.out", + mockResponse: &logservice.ReportLogResponse{Code: -1, Message: "failure"}, + mockError: nil, + result: false, + }, + { + name: "runtime-123.out", + mockResponse: &logservice.ReportLogResponse{Code: 0, Message: "success"}, + mockError: nil, + result: false, + }, + { + name: "function-master.out", + mockResponse: &logservice.ReportLogResponse{Code: 0, Message: "success"}, + mockError: nil, + result: false, + }, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + mockClient := new(MockLogServiceClient) + LogServiceClient = mockClient + mockClient.On("ReportLog", mock.Anything, mock.Anything, mock.Anything).Return(tt.mockResponse, tt.mockError) + res := tryReportLog(tt.name) + assert.Equal(t, tt.result, res) + }) + } +} + +func prepareDir(t *testing.T) string { + tempDir := t.TempDir() + subDir := filepath.Join(tempDir, "subdir") + os.Mkdir(subDir, 0755) + return subDir +} + +func prepareFiles(subDir string) { + reportedLogFiles.Lock() + reportedLogFiles.hashmap = make(map[string]struct{}) + reportedLogFiles.Unlock() + file1 := filepath.Join(subDir, "runtime-456-abc.err") + file2 := filepath.Join(subDir, "runtime-abc.out") + os.WriteFile(file1, []byte("log content"), 0644) + os.WriteFile(file2, []byte("log content"), 0644) +} + +func prepareClient() *MockLogServiceClient { + mockClient := new(MockLogServiceClient) + LogServiceClient = mockClient + mockResponse := &logservice.ReportLogResponse{Code: 0, Message: "success"} + mockClient.On("ReportLog", mock.Anything, mock.Anything, mock.Anything).Return(mockResponse, nil) + return mockClient +} + +func checkReportMsg(t *testing.T, mockClient *MockLogServiceClient) { + { + r := mockClient.Calls[0].Arguments[1].(*logservice.ReportLogRequest) + assert.Equal(t, len(r.Items), 1) + assert.Equal(t, r.Items[0].Filename, "runtime-456-abc.err") + assert.Equal(t, r.Items[0].CollectorID, common.CollectorConfigs.CollectorID) + assert.Equal(t, r.Items[0].Target, logservice.LogTarget_USER_STD) + assert.Equal(t, r.Items[0].RuntimeID, "runtime-456-abc") + } + { + r := mockClient.Calls[1].Arguments[1].(*logservice.ReportLogRequest) + assert.Equal(t, len(r.Items), 1) + assert.Equal(t, r.Items[0].Filename, "runtime-abc.out") + assert.Equal(t, r.Items[0].CollectorID, common.CollectorConfigs.CollectorID) + assert.Equal(t, r.Items[0].Target, logservice.LogTarget_USER_STD) + assert.Equal(t, r.Items[0].RuntimeID, "runtime-abc") + } + + mockClient.AssertExpectations(t) +} + +func TestCreateLogReporter(t *testing.T) { + mockClient := prepareClient() + subDir := prepareDir(t) + createLogReporter(subDir) + time.Sleep(10 * time.Millisecond) + prepareFiles(subDir) + time.Sleep(10 * time.Millisecond) + checkReportMsg(t, mockClient) +} + +func TestScanUserLog(t *testing.T) { + mockClient := prepareClient() + subDir := prepareDir(t) + prepareFiles(subDir) + scanUserLog(subDir) + checkReportMsg(t, mockClient) +} + +func TestStartLogReporter(t *testing.T) { + mockClient := prepareClient() + subDir := prepareDir(t) + prepareFiles(subDir) + common.CollectorConfigs.UserLogPath = subDir + StartLogReporter() + checkReportMsg(t, mockClient) +} diff --git a/yuanrong/pkg/collector/logcollector/register.go b/yuanrong/pkg/collector/logcollector/register.go new file mode 100644 index 0000000..b2892f0 --- /dev/null +++ b/yuanrong/pkg/collector/logcollector/register.go @@ -0,0 +1,102 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logcollector + +import ( + "context" + "errors" + "fmt" + "time" + + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/collector/common" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" +) + +const ( + // wait in total 60s + waitDashboardRegisterMaxTimes = 60 + waitDashboardRegisterInterval = 1 * time.Second + + dashboardRegisterKey = "/yr/dashboard" + etcdGetTimeout = 30 * time.Second +) + +// GetManagerAddress from etcd +func GetManagerAddress() (string, error) { + rsp, err := etcd3.GetRouterEtcdClient().Get(etcd3.CreateEtcdCtxInfoWithTimeout(context.TODO(), etcdGetTimeout), + dashboardRegisterKey, clientv3.WithPrefix()) + if err != nil { + return "", err + } + if len(rsp.Kvs) == 0 { + return "", errors.New("failed to get dashboard address, get 0 responses") + } + return string(rsp.Kvs[0].Value), nil +} + +// Register itself to manager +func Register() error { + getDashboardAddrRetryLeftCnt := waitDashboardRegisterMaxTimes + for ; getDashboardAddrRetryLeftCnt > 0; getDashboardAddrRetryLeftCnt -= 1 { + addr, err := GetManagerAddress() + if err != nil { + log.GetLogger().Warnf("failed to get master address: %s, try again later", err.Error()) + time.Sleep(waitDashboardRegisterInterval) + continue + } + common.CollectorConfigs.ManagerAddress = addr + break + } + if getDashboardAddrRetryLeftCnt == 0 { + return errors.New("failed to get dashboard address from etcd") + } + + retryInterval := constant.GetRetryInterval() + maxRetryTimes := constant.GetMaxRetryTimes() + client := GetLogServiceClient() + if client == nil { + log.GetLogger().Errorf("failed to get log service client") + return fmt.Errorf("failed to get log service client") + } + ctx, cancel := context.WithTimeout(context.Background(), common.DefaultGrpcTimeoutS) + defer cancel() + + for i := 0; i < maxRetryTimes; i++ { + log.GetLogger().Infof("start to register %s to %s", common.CollectorConfigs.CollectorID, + common.CollectorConfigs.ManagerAddress) + response, err := client.Register(ctx, &logservice.RegisterRequest{ + CollectorID: common.CollectorConfigs.CollectorID, + Address: common.CollectorConfigs.Address, + }) + if err != nil { + log.GetLogger().Errorf("failed to send register, error: %v", err) + time.Sleep(retryInterval) + continue + } + if response.Code != 0 { + log.GetLogger().Errorf("failed to register, error: %d, message: %s", response.Code, response.Message) + time.Sleep(retryInterval) + continue + } + return nil + } + return fmt.Errorf("failed to register: exceeds max retry time: %d", maxRetryTimes) +} diff --git a/yuanrong/pkg/collector/logcollector/register_test.go b/yuanrong/pkg/collector/logcollector/register_test.go new file mode 100644 index 0000000..2283e17 --- /dev/null +++ b/yuanrong/pkg/collector/logcollector/register_test.go @@ -0,0 +1,70 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logcollector + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" +) + +func TestRegister(t *testing.T) { + constant = &constTestImpl{} + tests := []struct { + mockResponse *logservice.RegisterResponse + mockError error + expectedError error + }{ + { + mockResponse: &logservice.RegisterResponse{Code: 0, Message: "success"}, + mockError: nil, + expectedError: nil, + }, + { + mockResponse: &logservice.RegisterResponse{Code: -1, Message: "failure"}, + mockError: nil, + expectedError: fmt.Errorf("failed to register: exceeds max retry time: %d", constant.GetMaxRetryTimes()), + }, + { + mockResponse: &logservice.RegisterResponse{Code: -1, Message: "failure"}, + mockError: fmt.Errorf("failed"), + expectedError: fmt.Errorf("failed to register: exceeds max retry time: %d", constant.GetMaxRetryTimes()), + }, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + mockClient := new(MockLogServiceClient) + LogServiceClient = mockClient + mockClient.On("Register", mock.Anything, mock.Anything, mock.Anything).Return(tt.mockResponse, tt.mockError) + err := Register() + mockClient.AssertExpectations(t) + assert.Equal(t, tt.expectedError, err) + }) + } +} + +func TestFailedClientRegister(t *testing.T) { + LogServiceClient = nil + err := Register() + expectedError := fmt.Errorf("failed to get log service client") + assert.Equal(t, expectedError, err) +} diff --git a/yuanrong/pkg/collector/logcollector/service.go b/yuanrong/pkg/collector/logcollector/service.go new file mode 100644 index 0000000..7090b92 --- /dev/null +++ b/yuanrong/pkg/collector/logcollector/service.go @@ -0,0 +1,174 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logcollector + +import ( + "bufio" + "context" + "net" + "os" + "strings" + + "google.golang.org/grpc" + + "yuanrong/pkg/collector/common" + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" +) + +const ( + mb uint32 = 1024 * 1024 +) + +var ( + readLogChunkThreshold uint32 = 1024 * 1024 + redundantBytes uint32 = 1024 +) + +type server struct { + logservice.UnimplementedLogCollectorServiceServer +} + +// ReadLog deal with grpc read log request +func (s *server) ReadLog(request *logservice.ReadLogRequest, + stream logservice.LogCollectorService_ReadLogServer) error { + filename, err := GetAbsoluteFilePath(request.Item) + if err != nil { + stream.Send(&logservice.ReadLogResponse{ + Code: -1, + Message: err.Error(), + }) + return err + } + if _, err := os.Stat(filename); err != nil { + stream.Send(&logservice.ReadLogResponse{ + Code: -1, + Message: err.Error(), + }) + return err + } + return s.sendStreamResponse(request, stream, filename) +} + +func appendToStringBuilder(builder *strings.Builder, line string) error { + if _, err := builder.WriteString(line); err != nil { + return err + } + if _, err := builder.WriteString("\n"); err != nil { + return err + } + return nil +} + +func (s *server) sendStreamResponse(request *logservice.ReadLogRequest, + stream logservice.LogCollectorService_ReadLogServer, absoluteFilename string) error { + file, err := os.OpenFile(absoluteFilename, os.O_RDONLY, 0) + if err != nil { + log.GetLogger().Errorf("failed to open file: %s, error: %v", absoluteFilename, err) + return err + } + defer file.Close() + + scanner := bufio.NewScanner(file) + builder := &strings.Builder{} + builder.Grow(int(readLogChunkThreshold + redundantBytes)) + var totalBytes uint32 = 0 + var lineCount uint32 = 0 + for scanner.Scan() { + if lineCount < request.StartLine { + lineCount++ + continue + } + if lineCount >= request.EndLine { + break + } + lineCount++ + line := scanner.Text() + if err := appendToStringBuilder(builder, line); err != nil { + log.GetLogger().Errorf("failed to write line to buffer, error: %v", err) + return err + } + totalBytes += uint32(len(line) + 1) + if totalBytes >= readLogChunkThreshold { + log.GetLogger().Infof("total bytes %d reaches threshold %d", totalBytes, readLogChunkThreshold) + if err := s.send(stream, request.Item.Filename, totalBytes, builder); err != nil { + return err + } + builder.Reset() + builder.Grow(int(readLogChunkThreshold + redundantBytes)) + totalBytes = 0 + } + } + if err := scanner.Err(); err != nil { + stream.Send(&logservice.ReadLogResponse{Code: -1, Message: err.Error()}) + return err + } + if builder.Len() > 0 { + if err := s.send(stream, request.Item.Filename, totalBytes, builder); err != nil { + return err + } + } + return nil +} + +func (s *server) send(stream logservice.LogCollectorService_ReadLogServer, filename string, totalBytes uint32, + builder *strings.Builder) error { + log.GetLogger().Infof("send read log response for %s, size: %f MB", filename, float64(totalBytes)/float64(mb)) + err := stream.Send(&logservice.ReadLogResponse{ + Code: 0, + Content: []byte(builder.String()), + }) + if err != nil { + log.GetLogger().Errorf("failed to send read log response for %s, size: %f MB", filename, + float64(totalBytes)/float64(mb)) + return err + } + return nil +} + +// QueryLogStream - +func (s *server) QueryLogStream(ctx context.Context, request *logservice.QueryLogStreamRequest) ( + *logservice.QueryLogStreamResponse, error) { + streamControlChans.Lock() + defer streamControlChans.Unlock() + streams := make([]string, 0, len(streamControlChans.hashmap)) + for streamName := range streamControlChans.hashmap { + streams = append(streams, streamName) + } + return &logservice.QueryLogStreamResponse{Code: 0, Streams: streams}, nil +} + +// StartReadLogService starts grpc server and then set ready channel +func StartReadLogService(ready chan<- bool) error { + grpcServer := grpc.NewServer() + logservice.RegisterLogCollectorServiceServer(grpcServer, &server{}) + + lis, err := net.Listen("tcp", common.CollectorConfigs.Address) + if err != nil { + log.GetLogger().Errorf("failed to listen to address %s, error: %v", common.CollectorConfigs.Address, err) + return err + } + + ready <- true + log.GetLogger().Infof("start serve log service on address %s", common.CollectorConfigs.Address) + if err = grpcServer.Serve(lis); err != nil { + log.GetLogger().Errorf("failed to serve on address %s, error: %v", common.CollectorConfigs.Address, err) + return err + } + log.GetLogger().Infof("stop serve log service on address %s", common.CollectorConfigs.Address) + return nil +} diff --git a/yuanrong/pkg/collector/logcollector/service_test.go b/yuanrong/pkg/collector/logcollector/service_test.go new file mode 100644 index 0000000..06885d2 --- /dev/null +++ b/yuanrong/pkg/collector/logcollector/service_test.go @@ -0,0 +1,152 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logcollector + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + + "yuanrong/pkg/collector/common" + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" +) + +type MockLogCollectorService_ReadLogServer struct { + grpc.ServerStream + SentResponses []*logservice.ReadLogResponse +} + +func (m *MockLogCollectorService_ReadLogServer) Send(response *logservice.ReadLogResponse) error { + m.SentResponses = append(m.SentResponses, response) + return nil +} + +func TestFailedReadLog(t *testing.T) { + s := &server{} + stream := &MockLogCollectorService_ReadLogServer{} + + testCases := []struct { + request *logservice.ReadLogRequest + expectedCode int32 + errorMsg string + }{ + { + request: &logservice.ReadLogRequest{Item: &logservice.LogItem{Filename: "", Target: logservice.LogTarget_LIB_RUNTIME}}, + expectedCode: -1, + errorMsg: fmt.Sprintf("undefined log item target"), + }, + { + request: &logservice.ReadLogRequest{Item: &logservice.LogItem{Filename: "123.txt", Target: logservice.LogTarget_USER_STD}}, + expectedCode: -1, + errorMsg: fmt.Sprintf("no such file or directory"), + }, + } + + for _, tc := range testCases { + t.Run("", func(t *testing.T) { + stream.SentResponses = nil + err := s.ReadLog(tc.request, stream) + assert.Contains(t, err.Error(), tc.errorMsg, "unexpected error message") + + if len(stream.SentResponses) > 0 { + actualResponse := stream.SentResponses[0] + assert.Equal(t, actualResponse.Code, tc.expectedCode, "unexpected error code") + } + }) + } +} + +func TestSuccessReadLog(t *testing.T) { + s := &server{} + stream := &MockLogCollectorService_ReadLogServer{} + stream.SentResponses = nil + tempDir := t.TempDir() + file := filepath.Join(tempDir, "runtime-111.err") + f, _ := os.OpenFile(file, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + for i := 0; i < 50; i++ { + content := fmt.Sprintf("timestamp: %d\n", i) + f.WriteString(content) + } + + request := &logservice.ReadLogRequest{ + Item: &logservice.LogItem{Filename: file, Target: logservice.LogTarget_USER_STD}, + StartLine: 1, + EndLine: 5, + } + s.ReadLog(request, stream) + assert.Equal(t, len(stream.SentResponses), 1, "wrong count of responses") + response := stream.SentResponses[0] + assert.Equal(t, response.Code, int32(0), "wrong code") + assert.Equal(t, response.Content, []byte("timestamp: 1\ntimestamp: 2\ntimestamp: 3\ntimestamp: 4\n"), "wrong content") +} + +func TestMultipleSuccessReadLog(t *testing.T) { + readLogChunkThreshold = 30 + redundantBytes = 20 + s := &server{} + stream := &MockLogCollectorService_ReadLogServer{} + stream.SentResponses = nil + tempDir := t.TempDir() + file := filepath.Join(tempDir, "runtime-112.err") + f, _ := os.OpenFile(file, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + for i := 0; i < 50; i++ { + content := fmt.Sprintf("0123456789\n") + f.WriteString(content) + } + + request := &logservice.ReadLogRequest{ + Item: &logservice.LogItem{Filename: file, Target: logservice.LogTarget_USER_STD}, + StartLine: 10, + EndLine: 40, + } + s.ReadLog(request, stream) + assert.Equal(t, 10, len(stream.SentResponses), "wrong count of responses") + for _, response := range stream.SentResponses { + assert.Equal(t, int32(0), response.Code, "wrong code") + assert.Equal(t, []byte("0123456789\n0123456789\n0123456789\n"), response.Content, "wrong content") + } +} + +func TestQueryLogStream(t *testing.T) { + for k := range streamControlChans.hashmap { + delete(streamControlChans.hashmap, k) + } + s := &server{} + AddDone("123", "456") + AddDone("567", "89") + ctx := context.Background() + request := &logservice.QueryLogStreamRequest{} + response, err := s.QueryLogStream(ctx, request) + assert.Equal(t, response.Code, int32(0)) + assert.Equal(t, 2, len(response.Streams)) + assert.Equal(t, err, nil) +} + +func TestFailedStartReadLogService(t *testing.T) { + common.CollectorConfigs.Address = "xxx" + ready := make(chan bool) + go func() { + <-ready + }() + err := StartReadLogService(ready) + assert.NotEqual(t, nil, err) +} diff --git a/yuanrong/pkg/common/constants/constant_test.go b/yuanrong/pkg/common/constants/constant_test.go new file mode 100644 index 0000000..1255fc8 --- /dev/null +++ b/yuanrong/pkg/common/constants/constant_test.go @@ -0,0 +1 @@ +package constants diff --git a/yuanrong/pkg/common/constants/constants.go b/yuanrong/pkg/common/constants/constants.go new file mode 100644 index 0000000..cd49c74 --- /dev/null +++ b/yuanrong/pkg/common/constants/constants.go @@ -0,0 +1,395 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constants implements vars of all +package constants + +import ( + "os" + "strconv" + "time" +) + +const ( + // ZoneKey zone key + ZoneKey = "KUBERNETES_IO_AVAILABLEZONE" + // ZoneNameLen define zone length + ZoneNameLen = 255 + // DefaultAZ default az + DefaultAZ = "defaultaz" + + // PodIPEnvKey define pod ip env key + PodIPEnvKey = "POD_IP" + + // HostNameEnvKey defines the hostname env key + HostNameEnvKey = "HOSTNAME" + + // NodeID defines the node name env key + NodeID = "NODE_ID" + + // HostIPEnvKey defines the host ip env key + HostIPEnvKey = "HOST_IP" + + // PodNamespaceEnvKey define pod namespace env key + PodNamespaceEnvKey = "POD_NAMESPACE" + + // ResourceLimitsMemory Memory limit, in bytes + ResourceLimitsMemory = "MEMORY_LIMIT_BYTES" + + // ResourceLimitsCPU CPU limit, in m(1/1000) + ResourceLimitsCPU = "CPU_LIMIT" + + // FuncBranchEnvKey is branch env key + FuncBranchEnvKey = "FUNC_BRANCH" + + // DataSystemBranchEnvKey is branch env key + DataSystemBranchEnvKey = "DATASYSTEM_CAPABILITY" + + // HTTPort busproxy httpserver listen port + HTTPort = "22423" + // GRPCPort busproxy gRPCserver listen port + GRPCPort = "22769" + // WorkerAgentPort is the listen port of worker agent grpc server + WorkerAgentPort = "22888" + // DataSystemPort is the port of data system + DataSystemPort = "31501" + // LocalSchedulerPort is the listen port string of local scheduler grpc server + LocalSchedulerPort = GRPCPort + // DomainSchedulerPort is the listen port of domain scheduler grpc server + DomainSchedulerPort = 22771 + // MaxPort maximum number of ports + MaxPort = 65535 + // SchedulerAddressSeparator is the separator of domain scheduler address + SchedulerAddressSeparator = ":" + // PlatformTenantID is tenant ID of platform function + PlatformTenantID = "0" + + // RuntimeLogOptTail - + RuntimeLogOptTail = "Tail" + // RuntimeLayerDirName - + RuntimeLayerDirName = "layer" + // RuntimeFuncDirName - + RuntimeFuncDirName = "func" + + // FunctionTaskAppID - + FunctionTaskAppID = "function-task" + + // TenantID config from function task + TenantID = "0" + + // BackpressureCode indicate that frontend should choose another proxy/worker and retry + BackpressureCode = 211429 + // HeaderBackpressure indicate that proxy can backpressure this request + HeaderBackpressure = "X-Backpressure" + + // SrcInstanceID gRPC context of metadata + SrcInstanceID = "src_instance_id" + // ReturnObjID gRPC context of metadata + ReturnObjID = "return_obj_id" + + // DelWorkerAgentEvent delete workerAgent + DelWorkerAgentEvent = "WorkerAgent-Del" + // UpdWorkerAgentEvent update workerAgent + UpdWorkerAgentEvent = "WorkerAgent-Upd" + + // DefaultLatestVersion is default function name + DefaultLatestVersion = "$latest" + // DefaultLatestFaaSVersion is default faas function name + DefaultLatestFaaSVersion = "latest" + // DefaultJavaRuntimeName is default java runtime name + DefaultJavaRuntimeName = "java1.8" + // DefaultJavaRuntimeNameForFaas is defualt + DefaultJavaRuntimeNameForFaas = "java8" +) + +// grpc parameters +const ( + // MaxMsgSize grpc client max message size(bit) + MaxMsgSize = 1024 * 1024 * 2 + // MaxWindowSize grpc flow control window size(bit) + MaxWindowSize = 1024 * 1024 * 2 + // MaxBufferSize grpc read/write buffer size(bit) + MaxBufferSize = 1024 * 1024 * 2 +) + +// functionBus userData key flag +const ( + // FrontendCallFlag invoke from task + FrontendCallFlag = "FrontendCallFlag" +) + +const ( + // DynamicRouterParamPrefix 动态路由参数前缀 + DynamicRouterParamPrefix = "/:" +) + +// HTTP invoke request header key +const ( + // HeaderExecutedDuration - + HeaderExecutedDuration = "X-Executed-Duration" + // HeaderTraceID - + HeaderTraceID = "X-Trace-Id" + // HeaderEventSourceID - + HeaderEventSourceID = "X-Event-Source-Id" + // HeaderBusinessID - + HeaderBusinessID = "X-Business-ID" + // HeaderTenantID - + HeaderTenantID = "X-Tenant-ID" + // HeaderTenantId - + HeaderTenantId = "X-Tenant-Id" + // HeaderPoolLabel - + HeaderPoolLabel = "X-Pool-Label" + // HeaderLogType - + HeaderLogType = "X-Log-Type" + // HeaderLogResult - + HeaderLogResult = "X-Log-Result" + // HeaderTriggerFlag - + HeaderTriggerFlag = "X-Trigger-Flag" + // HeaderInnerCode - + HeaderInnerCode = "X-Inner-Code" + // HeaderInvokeURN - + HeaderInvokeURN = "X-Tag-VersionUrn" + // HeaderStateKey - + HeaderStateKey = "X-State-Key" + // HeaderCallType is the request type + HeaderCallType = "X-Call-Type" + // HeaderLoadDuration duration of loading function + HeaderLoadDuration = "X-Load-Duration" + // HeaderNodeLabel is node label + HeaderNodeLabel = "X-Node-Label" + // HeaderForceDeploy is Force Deploy + HeaderForceDeploy = "X-Force-Deploy" + // HeaderAuthorization is authorization + HeaderAuthorization = "authorization" + // HeaderFutureID is futureID of invocation + HeaderFutureID = "X-Future-ID" + // HeaderAsync indicate whether it is an async request + HeaderAsync = "X-ASYNC" + // HeaderRuntimeID represents runtime instance identification + HeaderRuntimeID = "X-Runtime-ID" + // HeaderRuntimePort represents runtime rpc port + HeaderRuntimePort = "X-Runtime-Port" + // HeaderCPUSize is cpu size specified by invoke + HeaderCPUSize = "X-Instance-CPU" + // HeaderMemorySize is cpu memory specified by invoke + HeaderMemorySize = "X-Instance-Memory" + HeaderFileDigest = "X-File-Digest" + HeaderProductID = "X-Product-Id" + HeaderPrivilege = "X-Privilege" + HeaderUserID = "X-User-Id" + HeaderVersion = "X-Version" + HeaderKind = "X-Kind" + // HeaderCompatibleRuntimes - + HeaderCompatibleRuntimes = "X-Header-Compatible-Runtimes" + // HeaderDescription - + HeaderDescription = "X-Description" + // HeaderLicenseInfo - + HeaderLicenseInfo = "X-License-Info" + // HeaderGroupID is group id + HeaderGroupID = "X-Group-ID" + // ApplicationJSON - + ApplicationJSON = "application/json" + // ContentType - + ContentType = "Content-Type" + // PriorityHeader - + PriorityHeader = "priority" + // HeaderDataContentType - + HeaderDataContentType = "X-Content-Type" + // ErrorDuration duration when error happened, + // used with key $LoadDuration + ErrorDuration = -1 +) + +// Extra Request Header +const ( + // HeaderRequestID - + HeaderRequestID = "x-request-id" + // HeaderAccessKey - + HeaderAccessKey = "x-access-key" + // HeaderSecretKey - + HeaderSecretKey = "x-secret-key" + // HeaderAuthToken - + HeaderAuthToken = "x-auth-token" + // HeaderSecurityToken - + HeaderSecurityToken = "x-security-token" + // HeaderStorageType code storage type + HeaderStorageType = "x-storage-type" +) + +const ( + // FunctionStatusUnavailable function status is unavailable + FunctionStatusUnavailable = "unavailable" + + // FunctionStatusAvailable function status is available + FunctionStatusAvailable = "available" +) + +const ( + // OndemandKey is used in ondemand scenario + OndemandKey = "ondemand" +) + +// stage +const ( + InitializeStage = "initialize" +) + +// default UIDs and GIDs +const ( + DefaultWorkerGID = 1002 + DefaultRuntimeUID = 1003 + DefaultRuntimeUName = "snuser" + DefaultRuntimeGID = 1003 +) + +const ( + // WorkerManagerApplier mark the instance is created by minInstance + WorkerManagerApplier = "worker-manager" +) + +const ( + DialBaseDelay = 300 * time.Millisecond + DialMultiplier = 1.2 + DialJitter = 0.1 + DialMaxDelay = 15 * time.Second + RuntimeDialMaxDelay = 100 * time.Second +) + +// constants of network connection +const ( + // DefaultConnectInterval is the default connect interval + DefaultConnectInterval = 3 * time.Second + // DefaultDialInterval is the default grpc dial request interval + DefaultDialInterval = 3 * time.Second + // DefaultRetryTimes is the default request retry times + DefaultRetryTimes = 3 + ConnectIntervalTime = 1 * time.Second +) + +// request message +const ( + // RequestCPU - + RequestCPU = "CPU" + // RequestMemory - + RequestMemory = "Memory" + // MinCustomResourcesSize is min gpu size of invoke + MinCustomResourcesSize = 0 + + // CpuUnitConvert - + CpuUnitConvert = 1000 + // MemoryUnitConvert - + MemoryUnitConvert = 1024 + + // minInvokeCPUSize is default min cpu size of invoke (One CPU core corresponds to 1000) + minInvokeCPUSize = 300 + // MaxInvokeCPUSize is max cpu size of invoke (One CPU core corresponds to 1000) + MaxInvokeCPUSize = 16000 + // minInvokeMemorySize is default min memory size of invoke (MB) + minInvokeMemorySize = 128 + // MaxInvokeMemorySize is max memory size of invoke (MB) + MaxInvokeMemorySize = 1024 * 1024 * 1024 + // InstanceConcurrency - + InstanceConcurrency = "Concurrency" + // DefaultMapSize default map size + DefaultMapSize = 2 + // DefaultSliceSize default slice size + DefaultSliceSize = 16 + // MaxUploadMemorySize is max memory size of upload (MB) + MaxUploadMemorySize = 10 * 1024 * 1024 + // S3StorageType the code is stored in the minio + S3StorageType = "s3" + // LocalStorageType the code is stored in the disk + LocalStorageType = "local" + // CopyStorageType the code is stored in the disk and need to copy to container path + CopyStorageType = "copy" + // Faas kind of function creation + Faas = "faas" +) + +// prefixes of ETCD keys +const ( + WorkerETCDKeyPrefix = "/sn/workeragent" + NodeETCDKeyPrefix = "/sn/node" + // InstanceETCDKeyPrefix is the prefix of etcd key for instance + InstanceETCDKeyPrefix = "/sn/instance" + // ResourceGroupETCDKeyPrefix is the prefix of etcd key for resource group + ResourceGroupETCDKeyPrefix = "/sn/resourcegroup" + // WorkersEtcdKeyPrefix is the prefix of etcd key for workers + WorkersEtcdKeyPrefix = "/sn/workers" + // AliasEtcdKeyPrefix is the key prefix of aliases in etcd + AliasEtcdKeyPrefix = "/sn/aliases" +) + +// constants of posix custom runtime +const ( + PosixCustomRuntime = "posix-custom-runtime" + GORuntime = "go" + JavaRuntime = "java" + _ +) + +const ( + // OriginSchedulePolicy use origin scheduler policy + OriginSchedulePolicy = 0 + // NewSchedulePolicy use new scheduler policy + NewSchedulePolicy = 1 +) + +const ( + // LocalSchedulerLevel local scheduler level is 0 + LocalSchedulerLevel = iota + // LowDomainSchedulerLevel low domain scheduler level is 0 + LowDomainSchedulerLevel +) + +const ( + // Base10 is the decimal base number when use FormatInt + Base10 = 10 +) + +// MinInvokeCPUSize is min cpu size of invoke (One CPU core corresponds to 1000) +// Return default minInvokeCPUSize or system env[MinInvokeCPUSize] +var MinInvokeCPUSize = func() float64 { + minInvokeCPUSizeStr := os.Getenv("MinInvokeCPUSize") + if minInvokeCPUSizeStr != "" { + value, err := strconv.Atoi(minInvokeCPUSizeStr) + if err != nil { + return minInvokeCPUSize + } + return float64(value) + } + return minInvokeCPUSize +}() + +// MinInvokeMemorySize is min memory size of invoke (MB) +// Return default minInvokeMemorySize or system env[MinInvokeMemorySize] +var MinInvokeMemorySize = func() float64 { + minInvokeMemorySizeStr := os.Getenv("MinInvokeMemorySize") + if minInvokeMemorySizeStr != "" { + value, err := strconv.Atoi(minInvokeMemorySizeStr) + if err != nil { + return minInvokeMemorySize + } + return float64(value) + } + return minInvokeMemorySize +}() + +// SelfNodeIP - node IP +var SelfNodeIP = os.Getenv(HostIPEnvKey) + +// SelfNodeID - node ID +var SelfNodeID = os.Getenv(NodeID) diff --git a/yuanrong/pkg/common/crypto/crypto.go b/yuanrong/pkg/common/crypto/crypto.go new file mode 100644 index 0000000..343eb31 --- /dev/null +++ b/yuanrong/pkg/common/crypto/crypto.go @@ -0,0 +1,267 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crypto for auth +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "math" + "strings" + "sync" + + "golang.org/x/crypto/pbkdf2" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" +) + +const ( + randomNumberMaxLength = 16 + randomNumberMinLength = 12 + defaultSliceLen = 1024 + cipherTextsLen = 2 +) + +var ( + decryptAlgorithm string = "GCM" + decryptAlgorithmMutex sync.RWMutex +) + +// SetDecryptAlgorithm - +func SetDecryptAlgorithm(algorithm string) { + decryptAlgorithmMutex.Lock() + defer decryptAlgorithmMutex.Unlock() + decryptAlgorithm = algorithm +} + +// GetDecryptAlgorithm returns global decryptAlgorithm +func GetDecryptAlgorithm() string { + decryptAlgorithmMutex.RLock() + defer decryptAlgorithmMutex.RUnlock() + return decryptAlgorithm +} + +// Encrypt encrypts data by GCM algorithm +func Encrypt(content string, secret []byte) ([]byte, error) { + if GetDecryptAlgorithm() == "NO_CRYPTO" { + log.GetLogger().Debug("decrypt algorithm is NO_CRYPTO, return plain text directly") + return []byte(content), nil + } + textByte := []byte(content) + cipherByte, salt, err := encryptGcmDataFromBody(textByte, secret) + if err != nil { + return nil, err + } + ciperText := fmt.Sprintf("%s:%s", salt, hex.EncodeToString(cipherByte)) + return []byte(ciperText), nil +} + +func encryptPBKDF2WithSHA256(f *RootKeyFactor) *RootKey { + minLen := math.Min(float64(len(f.k1Data)), math.Min(float64(len(f.k2Data)), float64(len(f.component3byte)))) + bytePsd := make([]byte, int(minLen), int(minLen)) + + for i := 0; i < int(minLen); i++ { + bytePsd[i] = f.k1Data[i] ^ f.k2Data[i] ^ f.component3[i] + } + + rootKeyByte := pbkdf2.Key(bytePsd, f.saltData, f.iterCount, byteSize, sha256.New) + sliceLen := len(rootKeyByte) + if sliceLen <= 0 || sliceLen > defaultSliceLen { + sliceLen = defaultSliceLen + } + + byteMac := make([]byte, sliceLen) + macSecretKeyByte := pbkdf2.Key(byteMac, f.macData, f.iterCount, byteSize, sha256.New) + + rootKey := &RootKey{} + rootKey.RootKey = rootKeyByte + rootKey.MacSecretKey = macSecretKeyByte + + return rootKey +} + +func hmacHash(data []byte, key []byte) string { + hm := hmac.New(sha256.New, key) + _, err := hm.Write(data) + if err != nil { + log.GetLogger().Errorf("failed to hmacHash write data: %s ", err.Error()) + return "" + } + return hex.EncodeToString(hm.Sum([]byte{})) +} + +// encryptGcmDataFromBody encrypts data +func encryptGcmDataFromBody(body []byte, secret []byte) ([]byte, string, error) { + if len(body) == 0 { + return nil, "", fmt.Errorf("body is empty") + } + secretBytes, err := hex.DecodeString(string(secret)) + if err != nil { + return nil, "", err + } + + aesBlock, err := aes.NewCipher(secretBytes) + if err != nil { + return nil, "", err + } + aesgcm, err := cipher.NewGCM(aesBlock) + if err != nil { + return nil, "", err + } + + // generate salt value + nonceSize := aesgcm.NonceSize() + if nonceSize > randomNumberMaxLength || nonceSize < randomNumberMinLength { + err = errors.New("nonceSize out of bound") + return nil, "", err + } + salt := make([]byte, nonceSize, nonceSize) + if _, err := io.ReadFull(rand.Reader, salt); err != nil { + return nil, "", err + } + + cipherByte := aesgcm.Seal(nil, salt, body, nil) + return cipherByte, hex.EncodeToString(salt), nil +} + +// Decrypt returns string cipher bytes by AES and GCM algorithms +func Decrypt(cipherText []byte, secret []byte) (string, error) { + if GetDecryptAlgorithm() == "NO_CRYPTO" { + log.GetLogger().Debug("decrypt algorithm is NO_CRYPTO, return plain text directly") + return string(cipherText), nil + } + cipherTexts := strings.Split(string(cipherText), ":") + if len(cipherTexts) != cipherTextsLen { + return "", fmt.Errorf("wrong cipher text") + } + + saltStr := cipherTexts[0] + encryptStr := cipherTexts[1] + + salt, err := hex.DecodeString(saltStr) + if err != nil { + return "", err + } + + encrypt, err := hex.DecodeString(encryptStr) + if err != nil { + return "", err + } + + secretData := secret + if utils.IsHexString(string(secret)) { + var err error + secretData, err = hex.DecodeString(string(secret)) + if err != nil { + return "", err + } + } + + cipherBytes, err := decryptGcmData(encrypt, secretData, salt) + if err != nil { + return "", err + } + + if cipherBytes == nil { + return "", fmt.Errorf("decrypt error") + } + + return string(cipherBytes), nil +} + +// DecryptByte returns string cipher bytes by AES and GCM algorithms +func DecryptByte(cipherText []byte, secret []byte) ([]byte, error) { + cipherTexts := strings.Split(string(cipherText), ":") + if len(cipherTexts) != cipherTextsLen { + return nil, fmt.Errorf("wrong cipher text") + } + + saltStr := cipherTexts[0] + encryptStr := cipherTexts[1] + salt, err := hex.DecodeString(saltStr) + if err != nil { + return nil, err + } + + encryptByte, err := hex.DecodeString(encryptStr) + if err != nil { + return nil, err + } + + secretData := secret + if utils.IsHexString(string(secret)) { + var err error + secretData, err = hex.DecodeString(string(secret)) + if err != nil { + return nil, err + } + } + + cipherBytes, err := decryptGcmData(encryptByte, secretData, salt) + if err != nil { + return nil, err + } + + if cipherBytes == nil { + return nil, fmt.Errorf("decrypt error") + } + + return cipherBytes, nil +} + +// decryptGcmData decrypt data with aes gcm mode +func decryptGcmData(encrypt []byte, secret []byte, salt []byte) ([]byte, error) { + block, err := aes.NewCipher(secret) + if err != nil { + return nil, err + } + + aesGcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + decrypted, err := aesGcm.Open(nil, salt, encrypt, nil) + if err != nil { + return nil, err + } + + return decrypted, nil +} + +// decryptWorkKey Decrypt Work Key +func decryptWorkKey(workKey string, workMac string, rootKey *RootKey) (string, error) { + workKeyDecrypt, err := Decrypt([]byte(workKey), rootKey.RootKey) + if err != nil { + return "", err + } + + workKeyMac := hmacHash([]byte(workKeyDecrypt), rootKey.MacSecretKey) + if workKeyMac == workMac { + return workKeyDecrypt, nil + } + + return "", fmt.Errorf("workKey is changed") +} diff --git a/yuanrong/pkg/common/crypto/crypto_test.go b/yuanrong/pkg/common/crypto/crypto_test.go new file mode 100644 index 0000000..81bdc77 --- /dev/null +++ b/yuanrong/pkg/common/crypto/crypto_test.go @@ -0,0 +1,143 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This test file can also be used as a tool to create, encrypt and decrypt our secrets and cipher texts +package crypto + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestAll tests all processes, including creating random numbers, encryption and decryption +func TestAll(t *testing.T) { + rootKey := RootKey{} + randNum := hex.EncodeToString(createRandNum()) + fmt.Println(randNum) + rootKey.RootKey = []byte(randNum) + content := "abcd" + secret := hex.EncodeToString(createRandNum()) + fmt.Println(secret) + cipherText, err := Encrypt(content, []byte(secret)) + if err != nil { + t.Fatal(err) + } + fmt.Println(cipherText) + result, err := Decrypt(cipherText, []byte(secret)) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, content, result) +} + +// TestRandNum is also a tool to create the random number for encryption +func TestRandNum(t *testing.T) { + randNum := hex.EncodeToString(createRandNum()) + fmt.Println("randNum: " + randNum) +} + +// TestEncrypt is also a tool to generate a cipher text from a plain text and a secret +func TestEncrypt(t *testing.T) { + content := "7b83a1e330ccb177048671182f5ce1fde59c4c1c8167e8cf56190c4a5dd2c434" + secret := "f7de29fa800605cd7f490ff1d1607fffc1387f05ad8ca059868ab605d6bb6b6b" + cipherText, err := Encrypt(content, []byte(secret)) + if err != nil { + t.Fatal(err) + } + fmt.Println(string(cipherText)) +} + +// TestDecrypt is also a tool to decrypt a cipher text with a secret and get the plain text +func TestDecrypt(t *testing.T) { + cipherText := "b53df10229eead59476ae034:1feb0793e5b021511f064681827dbb8660594b31dfd90e665fa9664fdf02f1aa64304b1db66328e0b87f19c188d9e0d6487049b19a3b3aab25e3c3dcdcd22d390e020dce27af51b94ac154d137a9ce19" + secret := "f7de29fa800605cd7f490ff1d1607fffc1387f05ad8ca059868ab605d6bb6b6b" + content, err := Decrypt([]byte(cipherText), []byte(secret)) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, content, "7b83a1e330ccb177048671182f5ce1fde59c4c1c8167e8cf56190c4a5dd2c434") + fmt.Println(content) +} + +func TestDecryptError(t *testing.T) { + _, err := Decrypt([]byte("1A"), []byte("1C")) + assert.NotNil(t, err) + + _, err = Decrypt([]byte("1A:1B"), []byte("1C")) + assert.NotNil(t, err) + + _, err = Decrypt([]byte("1Z:1B"), []byte("1C")) + assert.NotNil(t, err) + + _, err = Decrypt([]byte("1A:1Z"), []byte("1C")) + assert.NotNil(t, err) + + _, err = Decrypt([]byte("1A:1B"), []byte("1Z")) + assert.NotNil(t, err) +} + +func createRandNum() []byte { + var keyLengthAES256 = 32 + initNum := make([]byte, keyLengthAES256) + _, err := rand.Read(initNum) + if err != nil { + return nil + } + return initNum +} + +func TestDecryptByte(t *testing.T) { + _, err := DecryptByte([]byte("1A"), []byte("1C")) + assert.NotNil(t, err) + + _, err = DecryptByte([]byte("1A:1B"), []byte("1C")) + assert.NotNil(t, err) +} + +func Test_encryptPBKDF2WithSHA256(t *testing.T) { + data3 := "0B6AA66FADD74F59F019109582E1AAED1EEEEA14CEDFAFCA6DB384D8C3360D5E34087FD513B16929A2567E5E184" + + "AE2B49A71B9E25E6371C91227D8CE114957D3D383EBC4899DBA7C43F6D80273E57F60B8FC918C2474CA687F1C5DBD7A71" + + "B1DC0A1EA455C7F2304A4846FD05FFD9FDD96B606546C51241A190EF8B70382ABE55" + + f := &RootKeyFactor{ + iterCount: IterKeyFactoryIter, + component3: data3, + component3byte: []byte(data3), + } + rKey := encryptPBKDF2WithSHA256(f) + assert.NotNil(t, rKey.RootKey) + + rootKey = rKey + + s := &SecretWorkKey{ + Key: "123", + Mac: "abc", + } + + data, err := s.MarshalJSON() + assert.Nil(t, err) + + err = s.UnmarshalJSON(data) + assert.Nil(t, err) +} + +func Test_EncryptGcmDataFromBody(t *testing.T) { + encryptGcmDataFromBody([]byte{}, []byte{}) +} diff --git a/yuanrong/pkg/common/crypto/pem_crypto.go b/yuanrong/pkg/common/crypto/pem_crypto.go new file mode 100644 index 0000000..df6c180 --- /dev/null +++ b/yuanrong/pkg/common/crypto/pem_crypto.go @@ -0,0 +1,180 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crypto for auth +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "encoding/hex" + "encoding/pem" + "errors" + "strings" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// PEMCipher - +type PEMCipher int + +// Possible values for the EncryptPEMBlock encryption algorithm. +const ( + _ PEMCipher = iota + PEMCipherAES128 + PEMCipherAES192 + PEMCipherAES256 +) + +const ( + saltLength = 8 + aes128Cbc = "AES-128-CBC" + aes192Cbc = "AES-192-CBC" + aes256Cbc = "AES-256-CBC" +) + +// cipherUnit holds a method for enciphering a PEM block. +type cipherUnit struct { + cipher PEMCipher + name string + cipherFunc func(key []byte) (cipher.Block, error) + keySize int + blockSize int +} + +// cipherUnits holds a slice of cipherUnit. +var cipherUnits = []cipherUnit{{ + name: aes256Cbc, + cipher: PEMCipherAES256, + cipherFunc: aes.NewCipher, + keySize: 32, + blockSize: aes.BlockSize, +}, { + name: aes192Cbc, + cipher: PEMCipherAES192, + cipherFunc: aes.NewCipher, + keySize: 24, + blockSize: aes.BlockSize, +}, { + name: aes128Cbc, + cipher: PEMCipherAES128, + cipherFunc: aes.NewCipher, + keySize: 16, + blockSize: aes.BlockSize, +}, +} + +// deriveKey uses a key derivation function to stretch the password into a key with +// the number of bits our cipher requires. +func (c cipherUnit) deriveKey(password, salt []byte) []byte { + hash := md5.New() + out := make([]byte, c.keySize) + var digest []byte + + for i := 0; i < len(out); i += len(digest) { + hash.Reset() + _, err := hash.Write(digest) + if err != nil { + log.GetLogger().Warnf("write digest failed, err: %s", err) + } + _, err = hash.Write(password) + if err != nil { + log.GetLogger().Warnf("write password failed, err: %s", err) + } + _, err = hash.Write(salt) + if err != nil { + log.GetLogger().Warnf("write salt failed, err: %s", err) + } + digest = hash.Sum(digest[:0]) + copy(out[i:], digest) + } + return out +} + +func cipherByName(name string) *cipherUnit { + for i := range cipherUnits { + alg := &cipherUnits[i] + if alg.name == name { + return alg + } + } + return nil +} + +// IsEncryptedPEMBlock returns whether the PEM block is password encrypted according to RFC 1423. +func IsEncryptedPEMBlock(b *pem.Block) bool { + _, ok := b.Headers["DEK-Info"] + return ok +} + +// DecryptPEMBlock takes a PEM block encrypted according to RFC 1423 and the password used to encrypt +// it and returns a slice of decrypted DER encoded bytes. +func DecryptPEMBlock(b *pem.Block, pwd []byte) ([]byte, error) { + dekInfo, ok := b.Headers["DEK-Info"] + if !ok { + return nil, errors.New("crypto: no DEK-Info header in block") + } + + mode, hexIV, ok := strings.Cut(dekInfo, ",") + if !ok { + return nil, errors.New("crypto: malformed DEK-Info header") + } + + ciph := cipherByName(mode) + if ciph == nil { + return nil, errors.New("crypto: unknown encryption mode") + } + iv, err := hex.DecodeString(hexIV) + if err != nil { + return nil, err + } + if len(iv) != ciph.blockSize { + return nil, errors.New("crypto: incorrect IV size") + } + + key := ciph.deriveKey(pwd, iv[:saltLength]) + block, err := ciph.cipherFunc(key) + if err != nil { + return nil, err + } + + if len(b.Bytes)%block.BlockSize() != 0 { + return nil, errors.New("crypto: encrypted PEM data is not a multiple of the block size") + } + + data := make([]byte, len(b.Bytes)) + dec := cipher.NewCBCDecrypter(block, iv) + dec.CryptBlocks(data, b.Bytes) + + dataLen := len(data) + if dataLen == 0 || dataLen%ciph.blockSize != 0 { + return nil, errors.New("crypto: invalid padding") + } + last := int(data[dataLen-1]) + if dataLen < last { + return nil, errors.New("crypto: decryption password incorrect") + } + if last == 0 || last > ciph.blockSize { + return nil, errors.New("crypto: decryption password incorrect") + } + for _, val := range data[dataLen-last:] { + if int(val) != last { + return nil, errors.New("crypto: decryption password incorrect") + } + } + return data[:dataLen-last], nil +} diff --git a/yuanrong/pkg/common/crypto/scc_constants.go b/yuanrong/pkg/common/crypto/scc_constants.go new file mode 100644 index 0000000..d5e5586 --- /dev/null +++ b/yuanrong/pkg/common/crypto/scc_constants.go @@ -0,0 +1,47 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crypto for auth +package crypto + +import ( + "sync" +) + +var ( + sccInitialized bool = false + m sync.RWMutex +) + +const ( + // Aes128Gcm - + Aes128Gcm = "AES128_GCM" + // Aes256Gcm - + Aes256Gcm = "AES256_GCM" + // Aes256Cbc - + Aes256Cbc = "AES256_CBC" + // Sm4Cbc - + Sm4Cbc = "SM4_CBC" + // Sm4Ctr - + Sm4Ctr = "SM4_CTR" +) + +// SccConfig - +type SccConfig struct { + Enable bool `json:"enable" valid:"optional"` + Algorithm string `json:"algorithm" valid:"optional"` + SccConfigPath string `json:"sccConfigPath" valid:"optional"` +} diff --git a/yuanrong/pkg/common/crypto/scc_crypto.go b/yuanrong/pkg/common/crypto/scc_crypto.go new file mode 100644 index 0000000..417e238 --- /dev/null +++ b/yuanrong/pkg/common/crypto/scc_crypto.go @@ -0,0 +1,115 @@ +//go:build cryptoapi +// +build cryptoapi + +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crypto for auth +package crypto + +import ( + "cryptoapi" + "fmt" + "path" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// SCCInitialized - +func SCCInitialized() bool { + m.RLock() + defer m.RUnlock() + return sccInitialized +} + +// GetSCCAlgorithm - +func GetSCCAlgorithm(algorithm string) int { + switch algorithm { + case Aes128Gcm: + return cryptoapi.ALG_AES128_GCM + case Aes256Gcm: + return cryptoapi.ALG_AES256_GCM + case Aes256Cbc: + return cryptoapi.ALG_AES256_CBC + case Sm4Cbc: + return cryptoapi.ALG_SM4_CBC + case Sm4Ctr: + return cryptoapi.ALG_SM4_CTR + default: + return cryptoapi.ALG_AES256_GCM + } +} + +// InitializeSCC - +func InitializeSCC(config SccConfig) bool { + m.Lock() + defer m.Unlock() + if !config.Enable { + return true + } + options := cryptoapi.NewSccOptions() + const configPath = "/home/sn/resource/scc" + sccConfigPath := config.SccConfigPath + if sccConfigPath == "" { + sccConfigPath = configPath + } + options.PrimaryKeyFile = path.Join(sccConfigPath, "primary.ks") + options.StandbyKeyFile = path.Join(sccConfigPath, "standby.ks") + options.LogPath = "/tmp/log/" + options.LogFile = "scc" + options.DefaultAlgorithm = GetSCCAlgorithm(config.Algorithm) + options.RandomDevice = "/dev/random" + options.EnableChangeFilePermission = 0 + cryptoapi.Finalize() + err := cryptoapi.InitializeWithConfig(options) + if err != nil { + fmt.Printf("failed to initialize crypto, Error = [%s]\n", err.Error()) + log.GetLogger().Errorf("Initialize SCC Error = [%s]", err.Error()) + return false + } + sccInitialized = true + return true +} + +// FinalizeSCC - +func FinalizeSCC() { + m.Lock() + defer m.Unlock() + sccInitialized = false + cryptoapi.Finalize() +} + +// SCCDecrypt - +func SCCDecrypt(cipher []byte) (string, error) { + plain, err := cryptoapi.Decrypt(string(cipher)) + if err != nil { + log.GetLogger().Errorf("SCC Decrypt Error = [%s]", err.Error()) + return "", err + } + + return plain, nil +} + +// SCCEncrypt - +func SCCEncrypt(plainInput string) ([]byte, error) { + cipher, err := cryptoapi.Encrypt(plainInput) + if err != nil { + log.GetLogger().Errorf("SCC Encrypt Error = [%s]", err.Error()) + return nil, err + } + + return []byte(cipher), nil +} diff --git a/yuanrong/pkg/common/crypto/scc_crypto_fake.go b/yuanrong/pkg/common/crypto/scc_crypto_fake.go new file mode 100644 index 0000000..c334713 --- /dev/null +++ b/yuanrong/pkg/common/crypto/scc_crypto_fake.go @@ -0,0 +1,50 @@ +//go:build !cryptoapi +// +build !cryptoapi + +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crypto for auth +package crypto + +// SCCInitialized - +func SCCInitialized() bool { + return false +} + +// GetSCCAlgorithm - +func GetSCCAlgorithm(algorithm string) int { + return 0 +} + +// InitializeSCC - +func InitializeSCC(config SccConfig) bool { + return false +} + +// FinalizeSCC - +func FinalizeSCC() { +} + +// SCCDecrypt - +func SCCDecrypt(cipher []byte) (string, error) { + return "", nil +} + +// SCCEncrypt - +func SCCEncrypt(plainInput string) ([]byte, error) { + return []byte{}, nil +} diff --git a/yuanrong/pkg/common/crypto/scc_crypto_test.go b/yuanrong/pkg/common/crypto/scc_crypto_test.go new file mode 100644 index 0000000..d37d693 --- /dev/null +++ b/yuanrong/pkg/common/crypto/scc_crypto_test.go @@ -0,0 +1,83 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This test file can also be used as a tool to create, encrypt and decrypt our secrets and cipher texts +package crypto + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSCCEncryptDecryptInitialized(t *testing.T) { + var c = SccConfig{ + Enable: true, + Algorithm: "AES256_GCM", + } + ret := InitializeSCC(c) + assert.True(t, ret) + input := "text to encrypt" + encrypted, err := SCCEncrypt(input) + fmt.Printf("encrypted : %s\n", string(encrypted)) + assert.Nil(t, err) + decrypt, err := SCCDecrypt(encrypted) + fmt.Printf("decrypt : %s\n", decrypt) + assert.Nil(t, err) + assert.Equal(t, input, decrypt) + assert.NotEqual(t, encrypted, input) + FinalizeSCC() +} + +func TestSCCEncryptDecryptNotInitialized(t *testing.T) { + var c = SccConfig{ + Enable: false, + Algorithm: "AES256_GCM", + } + ret := InitializeSCC(c) + assert.True(t, ret) + input := "text to encrypt" + encrypted, _ := SCCEncrypt(input) + fmt.Printf("encrypted : %s\n", string(encrypted)) + decrypt, _ := SCCDecrypt(encrypted) + fmt.Printf("decrypt : %s\n", decrypt) + FinalizeSCC() +} + +func TestSCCEncryptDecryptAlgorithms(t *testing.T) { + var c = SccConfig{ + Enable: true, + Algorithm: "AES256_GCM", + } + + algorithms := []string{"AES256_CBC", "AES128_GCM", "AES256_GCM", "SM4_CBC", "SM4_CTR", "DEFAULT"} + for _, algo := range algorithms { + FinalizeSCC() + c.Algorithm = algo + ret := InitializeSCC(c) + assert.True(t, ret) + input := "text to encrypt" + encrypted, err := SCCEncrypt(input) + fmt.Printf("encrypted : %s\n", string(encrypted)) + assert.Nil(t, err) + decrypt, err := SCCDecrypt(encrypted) + fmt.Printf("decrypt : %s\n", decrypt) + assert.Nil(t, err) + assert.Equal(t, input, decrypt) + assert.NotEqual(t, encrypted, input) + } +} diff --git a/yuanrong/pkg/common/crypto/types.go b/yuanrong/pkg/common/crypto/types.go new file mode 100644 index 0000000..36471f3 --- /dev/null +++ b/yuanrong/pkg/common/crypto/types.go @@ -0,0 +1,300 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crypto for auth +package crypto + +import ( + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "path" + "sync" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/reader" +) + +const ( + // byteSize defines the key length + byteSize = 32 + queryElementLem = 2 + saltDataMinLength = 16 + // IterKeyFactoryIter is the iter Count of Root Key Factor + IterKeyFactoryIter = 10000 + apple = "apple" + boy = "boy" + cat = "cat" + dog = "dog" + egg = "egg" + fish = "fish" + wdo = "wdo" + KeyFactorNums = 5 +) + +var ( + rootKeyOnce sync.Once + workKeyOnce sync.Once + + workKey []byte + rootKey *RootKey +) + +// set root key factor +func buildRootKeyFactor(f *RootKeyFactor) error { + resourcePath := utils.GetResourcePath() + k1Path := path.Join(resourcePath, "rdo", "v1", apple, "a.txt") + k2Path := path.Join(resourcePath, "rdo", "v1", boy, "b.txt") + macPath := path.Join(resourcePath, "rdo", "v1", cat, "c.txt") + saltPath := path.Join(resourcePath, "rdo", "v1", dog, "d.txt") + + // k1Data + k1Data, err := reader.ReadFileWithTimeout(k1Path) + if err != nil { + return err + } + f.k1Data, err = hex.DecodeString(string(k1Data)) + if err != nil { + return err + } + + // k2Data + k2Data, err := reader.ReadFileWithTimeout(k2Path) + if err != nil { + return err + } + f.k2Data, err = hex.DecodeString(string(k2Data)) + if err != nil { + return err + } + + // macData + macData, err := reader.ReadFileWithTimeout(macPath) + if err != nil { + return err + } + f.macData, err = hex.DecodeString(string(macData)) + if err != nil { + return err + } + + // saltData + saltData, err := reader.ReadFileWithTimeout(saltPath) + if len(saltData) < saltDataMinLength { + return fmt.Errorf("invalid salt data length of %d", len(saltData)) + } + if err != nil { + return err + } + if f.saltData, err = hex.DecodeString(string(saltData)); err != nil { + return err + } + return nil +} + +// LoadRootKey Load Root Key +func LoadRootKey() (*RootKey, error) { + // k3 + resourcePath := utils.GetResourcePath() + k3Path := path.Join(resourcePath, "rdo", "v1", egg, "e.txt") + // k1Data + k3Data, err := reader.ReadFileWithTimeout(k3Path) + k3DataDecode, err := hex.DecodeString(string(k3Data)) + if err != nil { + return nil, err + } + f := &RootKeyFactor{ + // 10000 is the iter Count of Root Key Factor + iterCount: IterKeyFactoryIter, + component3: string(k3DataDecode), + component3byte: k3DataDecode, + } + err = buildRootKeyFactor(f) + if err != nil { + return nil, err + } + rootKey := encryptPBKDF2WithSHA256(f) + return rootKey, nil +} + +// LoadRootKeyWithKeyFactor Load Root Key With Key Factor +func LoadRootKeyWithKeyFactor(keyFactor []string) (*RootKey, error) { + if len(keyFactor) < KeyFactorNums { + return nil, errors.New("short key factors") + } + var err error + k3Data := keyFactor[2] + k3DataDecode, err := hex.DecodeString(k3Data) + f := &RootKeyFactor{ + // 10000 is the iter Count of Root Key Factor + iterCount: IterKeyFactoryIter, + component3: string(k3DataDecode), + component3byte: k3DataDecode, + } + f.k1Data, err = hex.DecodeString(keyFactor[0]) + if err != nil { + return nil, err + } + f.k2Data, err = hex.DecodeString(keyFactor[1]) + if err != nil { + return nil, err + } + f.macData, err = hex.DecodeString(keyFactor[3]) + if err != nil { + return nil, err + } + if f.saltData, err = hex.DecodeString(keyFactor[4]); err != nil { + return nil, err + } + rootKey := encryptPBKDF2WithSHA256(f) + return rootKey, nil +} + +// RootKey include RootKey and MacSecretKey +type RootKey struct { + RootKey []byte + MacSecretKey []byte +} + +// RootKeyFactor include Root Key Factor +type RootKeyFactor struct { + k1Data []byte + k2Data []byte + macData []byte + saltData []byte + iterCount int + component3 string + component3byte []byte +} + +// WorkKeys define Work Keys +type WorkKeys map[string]*SecretNamedWorkKeys + +// GetKeyByName Get Key By Name +func (k *WorkKeys) GetKeyByName(name string) *SecretWorkKey { + namedKey, exist := (*k)[name] + if !exist { + return nil + } + + return namedKey.Keys +} + +// SecretNamedWorkKeys include Keys and Description +type SecretNamedWorkKeys struct { + Keys *SecretWorkKey `json:"keys"` + Description string `json:"description"` +} + +// SecretWorkKey include Key and Mac +type SecretWorkKey struct { + Key string `json:"key"` + Mac string `json:"mac"` +} + +// MarshalJSON Marshal JSON +func (s *SecretWorkKey) MarshalJSON() ([]byte, error) { + if rootKey == nil || rootKey.RootKey == nil || rootKey.MacSecretKey == nil { + return nil, fmt.Errorf("rootKey is nil") + } + + key, err := Encrypt(s.Key, []byte(hex.EncodeToString(rootKey.RootKey))) + if err != nil { + return nil, err + } + mac := hmacHash([]byte(s.Key), rootKey.MacSecretKey) + + type SecretWorkKeyJSON SecretWorkKey + + return json.Marshal(SecretWorkKeyJSON(SecretWorkKey{ + Key: string(key), Mac: mac})) +} + +// UnmarshalJSON Unmarshal JSON +func (s *SecretWorkKey) UnmarshalJSON(data []byte) error { + + type SecretWorkKeyJSON SecretWorkKey + + err := json.Unmarshal(data, (*SecretWorkKeyJSON)(s)) + if err != nil { + return err + } + + key, err := decryptWorkKey(s.Key, s.Mac, rootKey) + if err != nil { + return err + } + + s.Key = key + + return nil +} + +// Signature define Signature +type Signature struct { + Method []byte + Path []byte + QueryStr string + Body []byte + AppID []byte + CurTimeTamp []byte +} + +// GetRootKey Get Root Key +func GetRootKey() []byte { + rootKeyOnce.Do(func() { + rk, err := LoadRootKey() + if err != nil { + log.GetLogger().Errorf("failed to load rootKey, err: %s", err.Error()) + return + } + rootKey = rk + }) + + if rootKey == nil { + log.GetLogger().Errorf("root key is nil") + return []byte{} + } + return []byte(hex.EncodeToString(rootKey.RootKey)) +} + +// LoadWorkKey Load work Key +func LoadWorkKey() ([]byte, error) { + resourcePath := utils.GetResourcePath() + workKeyPath := path.Join(resourcePath, "rdo", "v1", fish, "f.txt") + workKey, err := reader.ReadFileWithTimeout(workKeyPath) + return workKey, err +} + +// GetWorkKey Get Work Key +func GetWorkKey() []byte { + workKeyOnce.Do(func() { + wk, err := LoadWorkKey() + if err != nil { + log.GetLogger().Errorf("failed to load workKey, err: %s", err.Error()) + return + } + workKey = wk + }) + + if workKey == nil { + log.GetLogger().Errorf("work key is nil") + return []byte{} + } + return workKey +} diff --git a/yuanrong/pkg/common/crypto/types_test.go b/yuanrong/pkg/common/crypto/types_test.go new file mode 100644 index 0000000..b8482fa --- /dev/null +++ b/yuanrong/pkg/common/crypto/types_test.go @@ -0,0 +1,47 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package crypto + +import ( + "testing" + + "github.com/agiledragon/gomonkey" +) + +func TestGetKeyByName(t *testing.T) { + k := &WorkKeys{} + k.GetKeyByName("") + + (*k)[""] = &SecretNamedWorkKeys{} + k.GetKeyByName("") +} + +func TestLoadRootKeyWithKeyFactor(t *testing.T) { + LoadRootKeyWithKeyFactor([]string{""}) + LoadRootKeyWithKeyFactor([]string{"", "", "", "", ""}) +} + +// TestGetWorkKey is also a tool to get the work key from the pre-set resource path +func TestGetWorkKey(t *testing.T) { + GetRootKey() + + patch := gomonkey.ApplyFunc(LoadRootKey, func() (*RootKey, error) { + return nil, nil + }) + GetRootKey() + patch.Reset() +} diff --git a/yuanrong/pkg/common/engine/etcd/etcd.go b/yuanrong/pkg/common/engine/etcd/etcd.go new file mode 100644 index 0000000..34e2bee --- /dev/null +++ b/yuanrong/pkg/common/engine/etcd/etcd.go @@ -0,0 +1,219 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package etcd + +import ( + "context" + "fmt" + "time" + "yuanrong/pkg/common/engine" + + commonetcd "yuanrong/pkg/common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" +) + +// Config is the configuration with etcd +type Config struct { + Timeout time.Duration + TransactionTimeout time.Duration +} + +const ( + // DefaultTimeout is the default etcd execution timeout + DefaultTimeout = 40 * time.Second + // DefaultTransactionTimeout is the default etcd transaction timeout + DefaultTransactionTimeout = 40 * time.Second +) + +// DefaultConfig is default etcd engine configuration +var DefaultConfig = Config{ + Timeout: DefaultTimeout, + TransactionTimeout: DefaultTransactionTimeout, +} + +type etcdE struct { + cli *commonetcd.EtcdClient + cfg Config +} + +// NewEtcdEngine creates a new etcd engine +func NewEtcdEngine(cli *commonetcd.EtcdClient, cfg Config) engine.Engine { + return &etcdE{cli: cli, cfg: cfg} +} + +// Get implements engine.Engine +func (e *etcdE) Get(ctx context.Context, etcdKey string) (string, error) { + etcdCtxInfo := commonetcd.CreateEtcdCtxInfoWithTimeout(ctx, e.cfg.Timeout) + values, err := e.cli.GetValues(etcdCtxInfo, etcdKey) + if err != nil { + return "", err + } + if len(values) == 0 { + log.GetLogger().Debugf("get key: %s, key not found", etcdKey) + return "", engine.ErrKeyNotFound + } + + log.GetLogger().Debugf("get key: %s", etcdKey) + return values[0], nil +} + +// Count implements engine.Engine +func (e *etcdE) Count(ctx context.Context, prefix string) (int64, error) { + etcdCtxInfo := commonetcd.CreateEtcdCtxInfoWithTimeout(ctx, e.cfg.Timeout) + resp, err := e.cli.GetResponse(etcdCtxInfo, prefix, clientv3.WithPrefix(), clientv3.WithCountOnly()) + if err != nil { + return 0, err + } + log.GetLogger().Debugf("count prefix: %s, count: %s", prefix, resp.Count) + return resp.Count, nil +} + +func (e *etcdE) firstInRange(ctx context.Context, prefix string, last bool) (string, string, error) { + etcdCtxInfo := commonetcd.CreateEtcdCtxInfoWithTimeout(ctx, e.cfg.Timeout) + var opts []clientv3.OpOption + if last { + opts = clientv3.WithLastKey() + } else { + opts = clientv3.WithFirstKey() + } + + resp, err := e.cli.GetResponse(etcdCtxInfo, prefix, opts...) + if err != nil { + return "", "", err + } + if len(resp.Kvs) == 0 { + if last { + log.GetLogger().Debugf("last in range prefix: %s, key not found", prefix) + } else { + log.GetLogger().Debugf("first in range prefix: %s, key not found", prefix) + } + return "", "", engine.ErrKeyNotFound + } + + if last { + log.GetLogger().Debugf("last in range prefix: %s, key: %s", prefix, string(resp.Kvs[0].Key)) + } else { + log.GetLogger().Debugf("first in range prefix: %s, key: %s", prefix, string(resp.Kvs[0].Key)) + } + return e.cli.DetachAZPrefix(string(resp.Kvs[0].Key)), string(resp.Kvs[0].Value), nil +} + +// FirstInRange implements engine.Engine +func (e *etcdE) FirstInRange(ctx context.Context, prefix string) (string, string, error) { + return e.firstInRange(ctx, prefix, false) +} + +// LastInRange implements engine.Engine +func (e *etcdE) LastInRange(ctx context.Context, prefix string) (string, string, error) { + return e.firstInRange(ctx, prefix, true) +} + +// PrepareStream implements engine.Engine +func (e *etcdE) PrepareStream( + ctx context.Context, prefix string, decode engine.DecodeHandleFunc, by engine.SortBy, +) engine.PrepareStmt { + sortOp, err := e.genSortOpOption(by) + if err != nil { + return &etcdStmt{ + fn: func() ([]interface{}, error) { + return nil, err + }, + } + } + + fn := func() ([]interface{}, error) { + etcdCtxInfo := commonetcd.CreateEtcdCtxInfoWithTimeout(ctx, e.cfg.Timeout) + resp, err := e.cli.GetResponse(etcdCtxInfo, prefix, clientv3.WithPrefix(), sortOp) + if err != nil { + return nil, err + } + + log.GetLogger().Debugf("stream prefix: %s, resp num: %v", prefix, len(resp.Kvs)) + + var res []interface{} + for _, kv := range resp.Kvs { + key := e.cli.DetachAZPrefix(string(kv.Key)) + i, err := decode(key, string(kv.Value)) + if err != nil { + return nil, err + } + res = append(res, i) + } + return res, nil + } + + return &etcdStmt{ + fn: fn, + } +} + +func (e *etcdE) genSortOpOption(by engine.SortBy) (clientv3.OpOption, error) { + var ( + order clientv3.SortOrder + target clientv3.SortTarget + ) + switch by.Order { + case engine.Ascend: + order = clientv3.SortAscend + case engine.Descend: + order = clientv3.SortDescend + default: + return nil, fmt.Errorf("invalid sort order: %v", by.Order) + } + switch by.Target { + case engine.SortName: + target = clientv3.SortByKey + case engine.SortCreate: + target = clientv3.SortByCreateRevision + case engine.SortModify: + target = clientv3.SortByModRevision + default: + return nil, fmt.Errorf("invalid sort target: %v", by.Target) + } + sortOp := clientv3.WithSort(target, order) + return sortOp, nil +} + +// Put implements engine.Engine +func (e *etcdE) Put(ctx context.Context, etcdKey string, value string) error { + etcdCtxInfo := commonetcd.CreateEtcdCtxInfoWithTimeout(ctx, e.cfg.Timeout) + err := e.cli.Put(etcdCtxInfo, etcdKey, value) + if err != nil { + log.GetLogger().Debugf("put key: %s", etcdKey) + } + return err +} + +// Delete implements engine.Engine +func (e *etcdE) Delete(ctx context.Context, etcdKey string) error { + etcdCtxInfo := commonetcd.CreateEtcdCtxInfoWithTimeout(ctx, e.cfg.Timeout) + err := e.cli.Delete(etcdCtxInfo, etcdKey) + if err != nil { + log.GetLogger().Debugf("delete etcd key: %s", etcdKey) + } + return err +} + +// BeginTx implements engine.Engine +func (e *etcdE) BeginTx(ctx context.Context) engine.Transaction { + return newTransaction(ctx, e.cli, e.cfg.TransactionTimeout) +} + +// Close implements engine.Engine +func (e *etcdE) Close() error { + return e.cli.Client.Close() +} diff --git a/yuanrong/pkg/common/engine/etcd/stream.go b/yuanrong/pkg/common/engine/etcd/stream.go new file mode 100644 index 0000000..a56aacd --- /dev/null +++ b/yuanrong/pkg/common/engine/etcd/stream.go @@ -0,0 +1,71 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd implements engine.Engine +package etcd + +import ( + "io" + "yuanrong/pkg/common/engine" +) + +type etcdStmt struct { + fn func() ([]interface{}, error) + filters []engine.FilterFunc +} + +// Filter implements engine.PrepareStmt +func (s *etcdStmt) Filter(filter engine.FilterFunc) engine.PrepareStmt { + s.filters = append(s.filters, filter) + return s +} + +// Execute implements engine.PrepareStmt +func (s *etcdStmt) Execute() (engine.Stream, error) { + vs, err := s.fn() + if err != nil { + return nil, err + } + + var res []interface{} +outer: + for _, v := range vs { + for _, filter := range s.filters { + if !filter(v) { + continue outer + } + } + res = append(res, v) + } + return &etcdStream{vs: res}, nil +} + +type etcdStream struct { + vs []interface{} + pos int +} + +// Next implements engine.Stream +func (s *etcdStream) Next() (interface{}, error) { + defer func() { + s.pos++ + }() + + if s.pos == len(s.vs) { + return nil, io.EOF + } + return s.vs[s.pos], nil +} diff --git a/yuanrong/pkg/common/engine/etcd/transaction.go b/yuanrong/pkg/common/engine/etcd/transaction.go new file mode 100644 index 0000000..21bedc3 --- /dev/null +++ b/yuanrong/pkg/common/engine/etcd/transaction.go @@ -0,0 +1,309 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package etcd + +import ( + "context" + "time" + + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/constants" + "yuanrong/pkg/common/engine" + "yuanrong/pkg/common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" +) + +type reads struct { + resp *clientv3.GetResponse + modRev int64 + withPrefix bool +} + +const ( + writeOp = iota + delOp +) + +type writes struct { + value string + op int + withPrefix bool +} + +// Transaction utilities etcd v3 transaction to perform transaction with separated expressions +type Transaction struct { + etcdClient *etcd3.EtcdClient + + rds map[string]*reads + wrs map[string]*writes + + ctx context.Context + cancel func() +} + +// newTransaction creates a new transaction object to records all the method that will be performed in it. +func newTransaction(ctx context.Context, client *etcd3.EtcdClient, timeout time.Duration) *Transaction { + ctx, cancel := context.WithTimeout(ctx, timeout) + t := Transaction{etcdClient: client, ctx: ctx, cancel: cancel} + t.rds = make(map[string]*reads, constants.DefaultMapSize) + t.wrs = make(map[string]*writes, constants.DefaultMapSize) + return &t +} + +// Put caches a key-value pair, it will be replaced if the same key has been called within the transaction. +func (t *Transaction) Put(key string, value string) { + log.GetLogger().Debugf("transaction put key: %s, value: %s", key, value) + t.wrs[key] = &writes{value, writeOp, false} +} + +func getRespMaxModRev(resp *clientv3.GetResponse) int64 { + var rev int64 = 0 + for _, kv := range resp.Kvs { + if kv.ModRevision > rev { + rev = kv.ModRevision + } + } + return rev +} + +// Get returns value of the key from etcd, and cached it to perform 'If' statement in the transaction. +// The method will return cached value if the same key has been called within the transaction. +func (t *Transaction) Get(key string) (string, error) { + if v, ok := t.wrs[key]; ok { + if v.op == writeOp { + log.GetLogger().Debugf("cached: transaction get key: %s, value: %s", key, v.value) + return v.value, nil + } + log.GetLogger().Debugf("cached: transaction get key: %s, value: %s, key not found", key) + return "", engine.ErrKeyNotFound + } + if v, ok := t.rds[key]; ok { + if len(v.resp.Kvs) == 1 { + return string(v.resp.Kvs[0].Value), nil + } + return "", engine.ErrKeyNotFound + } + + etcdCtxInfo := etcd3.CreateEtcdCtxInfo(t.ctx) + resp, err := t.etcdClient.GetResponse(etcdCtxInfo, key) + if err != nil { + log.GetLogger().Errorf("failed to get from etcd, error: %s", err.Error()) + return "", err + } + t.rds[key] = &reads{resp, getRespMaxModRev(resp), false} + if len(resp.Kvs) == 1 { + log.GetLogger().Debugf("transaction get key: %s, value: %s, modRevision: %d", + key, string(resp.Kvs[0].Value), resp.Kvs[0].ModRevision) + return string(resp.Kvs[0].Value), nil + } + log.GetLogger().Debugf("transaction get key: %s, value: %s, key not found", key) + return "", engine.ErrKeyNotFound +} + +func (t *Transaction) getRespKVs(resp *clientv3.GetResponse) (keys, values []string) { + for _, kv := range resp.Kvs { + key := t.etcdClient.DetachAZPrefix(string(kv.Key)) + keys = append(keys, key) + values = append(values, string(kv.Value)) + log.GetLogger().Debugf("transaction get prefix key: %s, modRevision: %d", key, kv.ModRevision) + } + return +} + +// GetPrefix returns values of the key as a prefix from etcd, and cached them to perform 'If' statement in the +// transaction. The method will return cached values if the same key has been called within the transaction. +func (t *Transaction) GetPrefix(key string) (keys, values []string, err error) { + if v, ok := t.rds[key]; ok { + keys, values = t.getRespKVs(v.resp) + log.GetLogger().Debugf("cached: transaction get prefix: %s, resp num: %v", key, len(keys)) + return keys, values, nil + } + + etcdCtxInfo := etcd3.CreateEtcdCtxInfo(t.ctx) + resp, err := t.etcdClient.GetResponse(etcdCtxInfo, key, clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("failed to get with prefix from etcd, error: %s", err.Error()) + return nil, nil, err + } + t.rds[key] = &reads{resp, getRespMaxModRev(resp), true} + keys, values = t.getRespKVs(resp) + log.GetLogger().Debugf("transaction get prefix: %s, resp num: %v", key, len(keys)) + return keys, values, nil +} + +// Del caches a key, it will be replaced if the same key has been called within the transaction. +func (t *Transaction) Del(key string) { + log.GetLogger().Debugf("transaction delete key: %s", key) + t.wrs[key] = &writes{"", delOp, false} +} + +// DelPrefix caches a key, it will be replaced if the same key has been called within the transaction. +func (t *Transaction) DelPrefix(key string) { + log.GetLogger().Debugf("transaction delete prefix: %s", key) + t.wrs[key] = &writes{"", delOp, true} +} + +func (t *Transaction) genCmp() []clientv3.Cmp { + cs := make([]clientv3.Cmp, 0, len(t.rds)) + for k, v := range t.rds { + k = t.etcdClient.AttachAZPrefix(k) + result := "=" + rev := v.modRev + if v.withPrefix && rev != 0 { + result = "<" + rev++ + } + + c := clientv3.Compare(clientv3.ModRevision(k), result, rev) + if v.withPrefix { + c = c.WithPrefix() + } + cs = append(cs, c) + } + return cs +} + +func (t *Transaction) genOp() []clientv3.Op { + ops := make([]clientv3.Op, 0, len(t.wrs)) + for k, v := range t.wrs { + k = t.etcdClient.AttachAZPrefix(k) + var op clientv3.Op + if v.op == writeOp { + op = clientv3.OpPut(k, v.value) + } else { + if v.withPrefix { + op = clientv3.OpDelete(k, clientv3.WithPrefix()) + } else { + op = clientv3.OpDelete(k) + } + } + ops = append(ops, op) + } + return ops +} + +// Commit do the 'If' statement with all Get and GetPrefix methods, +// and do the 'Then' statement with all Put, Del, DelPrefix methods. +func (t *Transaction) Commit() error { + resp, err := t.etcdClient.Client.KV.Txn(t.ctx).If(t.genCmp()...).Then(t.genOp()...).Commit() + if err != nil { + log.GetLogger().Errorf("failed to commit transaction, error: %s", err.Error()) + return err + } + if !resp.Succeeded { + log.GetLogger().Errorf("transaction commit not succeeded") + t.printError() + return engine.ErrTransaction + } + log.GetLogger().Debugf("transaction get revision, revision: %d", resp.Header.Revision) + return nil +} + +// Cancel undoes the commit. +func (t *Transaction) Cancel() { + t.cancel() +} + +func (t *Transaction) printError() { + for k, v := range t.rds { + rev := v.modRev + if v.withPrefix { + if t.rereadPrefix(k, rev) != nil { + continue + } + } else { + if t.rereadKey(k, rev) != nil { + continue + } + } + } +} + +func (t *Transaction) rereadPrefix(k string, revision int64) error { + etcdCtxInfo := etcd3.CreateEtcdCtxInfo(t.ctx) + resp, err := t.etcdClient.GetResponse(etcdCtxInfo, k, clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("failed to reread data with prefix from etcd, error: %s", err.Error()) + return err + } + t.printPrefixValue(resp, k, revision) + return nil +} + +func (t *Transaction) printPrefixValue(resp *clientv3.GetResponse, k string, revision int64) { + for index, kv := range resp.Kvs { + if kv.ModRevision > revision { + log.GetLogger().Errorf("reread etcd data by prefix, Key : %s", string(resp.Kvs[index].Key)) + log.GetLogger().Errorf("reread etcd data by prefix, Value: %s", string(resp.Kvs[index].Value)) + log.GetLogger().Errorf("reread etcd data by prefix, CreateRevision: %d, ModRevision: %d, Version: %d", + resp.Kvs[index].CreateRevision, resp.Kvs[index].ModRevision, resp.Kvs[index].Version) + t.printCachedPrefixValue(k, string(kv.Key)) + } + } +} + +func (t *Transaction) printCachedPrefixValue(prefix string, key string) { + if v, ok := t.rds[prefix]; ok { + for i, value := range v.resp.Kvs { + if string(value.Key) == string(key) { + log.GetLogger().Errorf("get cached data by prefix, Key : %s", string(v.resp.Kvs[i].Key)) + log.GetLogger().Errorf("get cached data by prefix, Value: %s", string(v.resp.Kvs[i].Value)) + log.GetLogger().Errorf("get cached data by prefix, CreateRevision: %d, ModRevision: %d, Version: %d", + v.resp.Kvs[i].CreateRevision, v.resp.Kvs[i].ModRevision, v.resp.Kvs[i].Version) + } + } + } else { + log.GetLogger().Errorf("invalid prefix, prefix : %s", prefix) + } +} + +func (t *Transaction) rereadKey(k string, revision int64) error { + etcdCtxInfo := etcd3.CreateEtcdCtxInfo(t.ctx) + resp, err := t.etcdClient.GetResponse(etcdCtxInfo, k) + if err != nil { + log.GetLogger().Errorf("invalid key, k: %s, error: %s", k, err.Error()) + return err + } + t.printKeyValue(resp, k, revision) + return nil +} + +func (t *Transaction) printKeyValue(resp *clientv3.GetResponse, k string, revision int64) { + for index, kv := range resp.Kvs { + if kv.ModRevision != revision { + log.GetLogger().Errorf("reread etcd data, Key : %s", string(resp.Kvs[index].Key)) + log.GetLogger().Errorf("reread etcd data, Value: %s", string(resp.Kvs[index].Value)) + log.GetLogger().Errorf("reread etcd data, CreateRevision: %d, ModRevision: %d, Version: %d", + resp.Kvs[index].CreateRevision, resp.Kvs[index].ModRevision, resp.Kvs[index].Version) + t.printCachedKeyValue(k) + } + } +} + +func (t *Transaction) printCachedKeyValue(k string) { + if v, ok := t.rds[k]; ok { + if len(v.resp.Kvs) == 1 { + log.GetLogger().Errorf("get cached data, Key : %s", string(v.resp.Kvs[0].Key)) + log.GetLogger().Errorf("get cached data, Value: %s", string(v.resp.Kvs[0].Value)) + log.GetLogger().Errorf("get cached data, CreateRevision: %d, ModRevision: %d, Version: %d", + v.resp.Kvs[0].CreateRevision, v.resp.Kvs[0].ModRevision, v.resp.Kvs[0].Version) + } + } else { + log.GetLogger().Errorf("invalid key, k: %s", k) + } +} diff --git a/yuanrong/pkg/common/engine/etcd/transaction_test.go b/yuanrong/pkg/common/engine/etcd/transaction_test.go new file mode 100644 index 0000000..0014dd2 --- /dev/null +++ b/yuanrong/pkg/common/engine/etcd/transaction_test.go @@ -0,0 +1,175 @@ +package etcd + +import ( + "context" + "errors" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey" + . "github.com/smartystreets/goconvey/convey" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/etcd3" +) + +// 给一个prefix,从rds里找出对应的reads +func TestPrintCachedPrefixValue(t *testing.T) { + kv1 := &mvccpb.KeyValue{ + Key: []byte("mock-key"), + Value: []byte("mock-key"), + ModRevision: 1, + Version: 1, + } + tr := &Transaction{ + rds: map[string]*reads{ + "mock-prefix": {resp: &clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{kv1}, + }}, + }, + } + Convey("Test PrintCachedPrefixValue", t, func() { + tr.printCachedPrefixValue("mock-prefix", "mock-key") + tr.printCachedPrefixValue("mock-prefix00", "mock-key") + }) +} + +func TestPrintPrefixValue(t *testing.T) { + Convey("Test printPrefixValue", t, func() { + kv1 := &mvccpb.KeyValue{ + Key: []byte("mock-key"), + Value: []byte("mock-key"), + CreateRevision: 1, + ModRevision: 1, + Version: 1, + } + resp := &clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{kv1}, + } + tr := &Transaction{ + rds: map[string]*reads{ + "mock-prefix": {resp: resp}, + }, + } + tr.printPrefixValue(resp, "mock-prefix", 0) + }) +} + +func TestRereadPrefix(t *testing.T) { + kv1 := &mvccpb.KeyValue{ + Key: []byte("mock-key"), + Value: []byte("mock-key"), + } + resp := &clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{kv1}, + } + tr := &Transaction{ + rds: map[string]*reads{ + "mock-prefix": {resp: resp}, + }, + ctx: context.TODO(), + etcdClient: &etcd3.EtcdClient{}, + } + Convey("Test RereadPrefix", t, func() { + Convey("with err", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "GetResponse", + func(w *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, errors.New("mock err") + }) + defer patch.Reset() + err := tr.rereadPrefix("mock-prefix", 0) + So(err, ShouldNotBeNil) + }) + Convey("without err", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "GetResponse", + func(w *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return resp, nil + }) + defer patch.Reset() + err := tr.rereadPrefix("mock-prefix", 0) + So(err, ShouldBeNil) + }) + }) +} + +func TestRereadKey(t *testing.T) { + kv1 := &mvccpb.KeyValue{ + Key: []byte("mock-key"), + Value: []byte("mock-key"), + CreateRevision: 1, + ModRevision: 1, + Version: 1, + } + resp := &clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{kv1}, + } + tr := &Transaction{ + rds: map[string]*reads{ + "mock-prefix": {resp: resp}, + }, + ctx: context.TODO(), + etcdClient: &etcd3.EtcdClient{}, + } + Convey("Test rereadKey", t, func() { + Convey("with err", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "GetResponse", + func(w *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, errors.New("mock err") + }) + defer patch.Reset() + err := tr.rereadKey("mock-prefix", 0) + So(err, ShouldNotBeNil) + }) + Convey("without err", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "GetResponse", + func(w *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return resp, nil + }) + defer patch.Reset() + err := tr.rereadKey("mock-prefix", 0) + So(err, ShouldBeNil) + }) + Convey("with printCachedKeyValue err", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "GetResponse", + func(w *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return resp, nil + }) + defer patch.Reset() + delete(tr.rds, "mock-prefix") + err := tr.rereadKey("mock-prefix", 0) + So(err, ShouldBeNil) + }) + }) +} + +func TestPrintError(t *testing.T) { + kv1 := &mvccpb.KeyValue{Key: []byte("mock-key"), Value: []byte("mock-key"), CreateRevision: 1, + ModRevision: 1, Version: 1, + } + kv2 := &mvccpb.KeyValue{Key: []byte("mock-key2"), Value: []byte("mock-key2"), CreateRevision: 1, + ModRevision: 1, Version: 1, + } + resp := &clientv3.GetResponse{Kvs: []*mvccpb.KeyValue{kv1, kv2}} + tr := &Transaction{ + rds: map[string]*reads{ + "mock-prefix": {resp: resp, withPrefix: true}, + "mock-prefix2": {resp: resp, withPrefix: false}}, + ctx: context.TODO(), + etcdClient: &etcd3.EtcdClient{}, + } + Convey("Test printError", t, func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "GetResponse", + func(w *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return resp, nil + }) + defer patch.Reset() + tr.printError() + }) +} diff --git a/yuanrong/pkg/common/engine/interface.go b/yuanrong/pkg/common/engine/interface.go new file mode 100644 index 0000000..0b23d12 --- /dev/null +++ b/yuanrong/pkg/common/engine/interface.go @@ -0,0 +1,122 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package engine defines interfaces for a storage kv engine +package engine + +import ( + "context" + "errors" +) + +var ( + // ErrKeyNotFound means key does not exist + ErrKeyNotFound = errors.New("key not found") + // ErrTransaction means a transaction request has failed. User should retry. + ErrTransaction = errors.New("failed to execute a transaction") +) + +// Engine defines a storage kv engine +type Engine interface { + // Get retrieves the value for a key. It returns ErrKeyNotFound if key does not exist. + Get(ctx context.Context, key string) (val string, err error) + + // Count retries the number of key value pairs from a prefix key. + Count(ctx context.Context, prefix string) (int64, error) + + // FirstInRange retries the first key value pairs sort by keys from a prefix. It returns ErrKeyNotFound if the + // range is empty. + FirstInRange(ctx context.Context, prefix string) (key, val string, err error) + + // LastInRange is similar to FirstInRange by searches backwards. It returns ErrKeyNotFound if the range is empty. + LastInRange(ctx context.Context, prefix string) (key, val string, err error) + + // PrepareStream creates a stream that users can combine it with "Filter" and "Execute" to fuzzy search from a + // range of key value pairs. + PrepareStream(ctx context.Context, prefix string, decode DecodeHandleFunc, by SortBy) PrepareStmt + + // Put writes a key value pairs. + Put(ctx context.Context, key string, value string) error + + // Delete removes a key value pair. + Delete(ctx context.Context, key string) error + + // BeginTx starts a new transaction. + BeginTx(ctx context.Context) Transaction + + // Close cleans up any resources holds by the engine. + Close() error +} + +// FilterFunc filters a range stream +type FilterFunc func(interface{}) bool + +// PrepareStmt allows fuzzy searching of a range of key value pairs. +type PrepareStmt interface { + Filter(filter FilterFunc) PrepareStmt + Execute() (Stream, error) +} + +// Stream defines interface to range a filtered key value pairs. +type Stream interface { + // Next returns the next value from a stream. It returns io.EOF when stream ends + Next() (interface{}, error) +} + +// DecodeHandleFunc decodes key value pairs to a user defined type. +type DecodeHandleFunc func(key, value string) (interface{}, error) + +// SortOrder is the sorting order of a stream. +type SortOrder int + +const ( + // Ascend starts from small to large. + Ascend SortOrder = iota + + // Descend starts from large to small. + Descend +) + +// SortTarget is the sorting target of a stream. +type SortTarget int + +const ( + // SortName tells a stream to sort by name. + SortName SortTarget = iota + + // SortCreate tells a stream to sort by create time. + SortCreate + + // SortModify tells a stream to sort by update time. + SortModify +) + +// SortBy is the sorting order and target of a stream. +type SortBy struct { + Order SortOrder + Target SortTarget +} + +// Transaction ensures consistent view and change of the data. +type Transaction interface { + Get(key string) (val string, err error) + GetPrefix(key string) (keys, values []string, err error) + Put(key string, value string) + Del(key string) + DelPrefix(key string) + Commit() error + Cancel() +} diff --git a/yuanrong/pkg/common/etcd3/config.go b/yuanrong/pkg/common/etcd3/config.go new file mode 100644 index 0000000..1201d82 --- /dev/null +++ b/yuanrong/pkg/common/etcd3/config.go @@ -0,0 +1,169 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "context" + "crypto/tls" + "crypto/x509" + + clientv3 "go.etcd.io/etcd/client/v3" + "k8s.io/apimachinery/pkg/util/wait" + + "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/logger/log" + commontls "yuanrong/pkg/common/tls" +) + +// EtcdConfig the info to get function instance +type EtcdConfig struct { + Servers []string `json:"servers" yaml:"servers" valid:"required"` + User string `json:"user" yaml:"user" valid:"optional"` + Passwd string `json:"password" yaml:"password" valid:"optional"` + AuthType string `json:"authType" yaml:"authType" valid:"optional"` + SslEnable bool `json:"sslEnable,omitempty" yaml:"sslEnable,omitempty" valid:"optional"` + LimitRate int `json:"limitRate,omitempty" yaml:"limitRate,omitempty" valid:"optional"` + LimitBurst int `json:"limitBurst,omitempty" yaml:"limitBurst,omitempty" valid:"optional"` + LimitTimeout int `json:"limitTimeout,omitempty" yaml:"limitTimeout,omitempty" valid:"optional"` + LeaseTTL int64 `json:"leaseTTL,omitempty" yaml:"leaseTTL,omitempty" valid:"optional"` + RenewTTL int64 `json:"renewTTL,omitempty" yaml:"renewTTL,omitempty" valid:"optional"` + CaFile string `json:"cafile,omitempty" yaml:"cafile,omitempty" valid:"optional"` + CertFile string `json:"certfile,omitempty" yaml:"certfile,omitempty" valid:"optional"` + KeyFile string `json:"keyfile,omitempty" yaml:"keyfile,omitempty" valid:"optional"` + PassphraseFile string `json:"passphraseFile,omitempty" yaml:"passphraseFile,omitempty" valid:"optional"` + AZPrefix string `json:"azPrefix,omitempty" yaml:"azPrefix,omitempty" valid:"optional"` + // DisableSync will not run the sync method to avoid endpoints being replaced by the domain name, default is FALSE + DisableSync bool +} + +// GetETCDCertificatePath get the certificate path from tlsConfig. +func GetETCDCertificatePath(config EtcdConfig, tlsConfig commontls.MutualTLSConfig) EtcdConfig { + if !config.SslEnable { + return config + } + config.CaFile = tlsConfig.RootCAFile + config.CertFile = tlsConfig.ModuleCertFile + config.KeyFile = tlsConfig.ModuleKeyFile + return config +} + +type etcdAuth interface { + getEtcdConfig() (*clientv3.Config, error) + renewToken(client *clientv3.Client, stop chan struct{}) +} + +type noAuth struct { +} + +type tlsAuth struct { + cerfile string + keyfile string + cafile string +} + +type pwdAuth struct { + user string + passWd []byte +} + +// this support no tls so we can leave scc dependencies behind +func (e *EtcdConfig) getEtcdAuthTypeNoTLS() etcdAuth { + if e.SslEnable { + return &tlsAuth{ + cerfile: e.CertFile, + keyfile: e.KeyFile, + cafile: e.CaFile, + } + } + if e.User == "" || e.Passwd == "" { + return &noAuth{} + } + return &pwdAuth{ + user: e.User, + passWd: []byte(e.Passwd), + } +} + +func (n *noAuth) getEtcdConfig() (*clientv3.Config, error) { + return &clientv3.Config{}, nil +} + +func (n *noAuth) renewToken(client *clientv3.Client, stop chan struct{}) { + return +} + +func (t *tlsAuth) getEtcdConfig() (*clientv3.Config, error) { + var pool *x509.CertPool + pool, err := commontls.GetX509CACertPool(t.cafile) + if err != nil { + log.GetLogger().Errorf("failed to getX509CACertPool: %s", err.Error()) + return nil, err + } + + var certs []tls.Certificate + certs, err = commontls.LoadServerTLSCertificate(t.cerfile, t.keyfile) + if err != nil { + log.GetLogger().Errorf("failed to loadServerTLSCertificate: %s", err.Error()) + return nil, err + } + + clientAuthMode := tls.NoClientCert + cfg := &clientv3.Config{ + TLS: &tls.Config{ + RootCAs: pool, + Certificates: certs, + ClientAuth: clientAuthMode, + }, + } + return cfg, nil +} + +func (t *tlsAuth) renewToken(client *clientv3.Client, stop chan struct{}) { + return +} + +func (p *pwdAuth) getEtcdConfig() (*clientv3.Config, error) { + passWd, err := crypto.Decrypt(p.passWd, crypto.GetRootKey()) + if err != nil { + log.GetLogger().Errorf("failed to decrypt, error:%s", err.Error()) + return nil, err + } + cfg := &clientv3.Config{ + Username: p.user, + Password: passWd, + } + return cfg, nil +} + +// renewToken can keep client token not expired, because +// etcd server will renew simple token TTL if token is not expired. +func (p *pwdAuth) renewToken(client *clientv3.Client, stop chan struct{}) { + if stop == nil { + log.GetLogger().Errorf("stop chan is nil") + return + } + wait.Until(func() { + ctx, cancel := context.WithTimeout(context.Background(), DurationContextTimeout) + _, err := client.Get(ctx, renewKey) + cancel() + if err != nil { + log.GetLogger().Warnf("renew token error: %s", err.Error()) + } + }, tokenRenewTTL, stop) + log.GetLogger().Infof("stopped to renew token") +} diff --git a/yuanrong/pkg/common/etcd3/config_test.go b/yuanrong/pkg/common/etcd3/config_test.go new file mode 100644 index 0000000..45963d9 --- /dev/null +++ b/yuanrong/pkg/common/etcd3/config_test.go @@ -0,0 +1,162 @@ +package etcd3 + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "testing" + + "github.com/agiledragon/gomonkey" + . "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/crypto" + commontls "yuanrong/pkg/common/tls" +) + +func TestGetETCDCertificatePath(t *testing.T) { + Convey("Test ssl enable", t, func() { + config := EtcdConfig{ + SslEnable: true, + CaFile: "", + CertFile: "", + KeyFile: "", + } + tlsConfig := commontls.MutualTLSConfig{ + RootCAFile: "xxx.ca", + ModuleCertFile: "xxx.cert", + ModuleKeyFile: "xxx.key", + } + etcdConfig := GetETCDCertificatePath(config, tlsConfig) + So(etcdConfig.CaFile, ShouldEqual, tlsConfig.RootCAFile) + So(etcdConfig.CertFile, ShouldEqual, tlsConfig.ModuleCertFile) + So(etcdConfig.KeyFile, ShouldEqual, tlsConfig.ModuleKeyFile) + }) + Convey("Test ssl disable", t, func() { + config := EtcdConfig{ + SslEnable: false, + CaFile: "", + CertFile: "", + KeyFile: "", + } + tlsConfig := commontls.MutualTLSConfig{ + RootCAFile: "xxx.ca", + ModuleCertFile: "xxx.cert", + ModuleKeyFile: "xxx.key", + } + etcdConfig := GetETCDCertificatePath(config, tlsConfig) + So(etcdConfig.CaFile, ShouldNotEqual, tlsConfig.RootCAFile) + So(etcdConfig.CertFile, ShouldNotEqual, tlsConfig.ModuleCertFile) + So(etcdConfig.KeyFile, ShouldNotEqual, tlsConfig.ModuleKeyFile) + }) +} + +func TestGetEtcdAuthType(t *testing.T) { + Convey("Test getEtcdAuthType, tlsAuth", t, func() { + config := EtcdConfig{ + SslEnable: true, + CaFile: "", + CertFile: "", + KeyFile: "", + } + etcdAuth := config.getEtcdAuthType() + So(etcdAuth, ShouldNotBeNil) + }) + Convey("Test getEtcdAuthType, noAuth", t, func() { + config := EtcdConfig{ + SslEnable: false, + CaFile: "", + CertFile: "", + KeyFile: "", + User: "test", + Passwd: "", + } + etcdAuth := config.getEtcdAuthType() + So(etcdAuth, ShouldNotBeNil) + }) + Convey("Test getEtcdAuthType, pwdAuth", t, func() { + config := EtcdConfig{ + SslEnable: false, + CaFile: "", + CertFile: "", + KeyFile: "", + User: "test", + Passwd: "test", + } + etcdAuth := config.getEtcdAuthType() + So(etcdAuth, ShouldNotBeNil) + }) + Convey("Test getEtcdAuthType, clientTlsAuth", t, func() { + config := EtcdConfig{ + AuthType: "TLS", + SslEnable: false, + CaFile: "", + CertFile: "", + KeyFile: "", + User: "test", + Passwd: "test", + } + etcdAuth := config.getEtcdAuthType() + So(etcdAuth, ShouldNotBeNil) + }) +} + +func TestGetEtcdConfigTlsAuth(t *testing.T) { + tlsAuth := &tlsAuth{} + Convey("GetX509CACertPool success", t, func() { + patch := gomonkey.ApplyFunc(commontls.GetX509CACertPool, func(string) (*x509.CertPool, error) { + return nil, nil + }) + defer patch.Reset() + Convey("LoadServerTLSCertificate success", func() { + patch := gomonkey.ApplyFunc(commontls.LoadServerTLSCertificate, func(string, string) ([]tls.Certificate, error) { + return nil, nil + }) + defer patch.Reset() + _, err := tlsAuth.getEtcdConfig() + So(err, ShouldBeNil) + }) + Convey("LoadServerTLSCertificate fail", func() { + patch := gomonkey.ApplyFunc(commontls.LoadServerTLSCertificate, func(string, string) ([]tls.Certificate, error) { + return nil, errors.New("LoadServerTLSCertificate fail") + }) + defer patch.Reset() + _, err := tlsAuth.getEtcdConfig() + So(err, ShouldNotBeNil) + }) + }) +} + +func TestRenewTokenTlsAuth(t *testing.T) { + tlsAuth := &tlsAuth{} + cli := &EtcdClient{} + stopCh := make(chan struct{}) + + tlsAuth.renewToken(cli.Client, stopCh) +} + +func TestRenewTokenPwdAuth(t *testing.T) { + pwdAuth := &pwdAuth{} + cli := &EtcdClient{} + + pwdAuth.renewToken(cli.Client, nil) +} + +func TestGetEtcdConfigPwdAuth(t *testing.T) { + pwdAuth := &pwdAuth{} + Convey("Decrypt success", t, func() { + patch := gomonkey.ApplyFunc(crypto.Decrypt, func([]byte, []byte) (string, error) { + return "", nil + }) + defer patch.Reset() + _, err := pwdAuth.getEtcdConfig() + So(err, ShouldBeNil) + }) + Convey("Decrypt fail", t, func() { + patch := gomonkey.ApplyFunc(crypto.Decrypt, func([]byte, []byte) (string, error) { + return "", errors.New("decrypt fail") + }) + defer patch.Reset() + _, err := pwdAuth.getEtcdConfig() + So(err, ShouldNotBeNil) + }) +} diff --git a/yuanrong/pkg/common/etcd3/event.go b/yuanrong/pkg/common/etcd3/event.go new file mode 100644 index 0000000..3d0853c --- /dev/null +++ b/yuanrong/pkg/common/etcd3/event.go @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "time" + + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" +) + +const ( + // PUT event + PUT = iota + // DELETE event + DELETE + // ERROR unexpected event + ERROR + // SYNCED synced event + SYNCED +) + +// Event of databases +type Event struct { + Type int + Key string + Value []byte + PrevValue []byte + Rev int64 +} + +// parseKV converts a KeyValue retrieved from an initial sync() listing to a synthetic isCreated event. +func parseKV(kv *mvccpb.KeyValue) *Event { + return &Event{ + Type: PUT, + Key: string(kv.Key), + Value: kv.Value, + PrevValue: nil, + Rev: kv.ModRevision, + } +} + +func parseEvent(e *clientv3.Event) *Event { + eType := 0 + if e.Type == clientv3.EventTypeDelete { + eType = DELETE + } + ret := &Event{ + Type: eType, + Key: string(e.Kv.Key), + Value: e.Kv.Value, + Rev: e.Kv.ModRevision, + } + if e.PrevKv != nil { + ret.PrevValue = e.PrevKv.Value + } + return ret +} + +func parseErr(err error) *Event { + return &Event{Type: ERROR, Value: []byte(err.Error())} +} + +func parseSync(t time.Time) *Event { + return &Event{Type: SYNCED, Value: []byte(t.String())} +} diff --git a/yuanrong/pkg/common/etcd3/event_test.go b/yuanrong/pkg/common/etcd3/event_test.go new file mode 100644 index 0000000..e4f32cb --- /dev/null +++ b/yuanrong/pkg/common/etcd3/event_test.go @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "fmt" + "testing" + "time" + + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/smartystreets/goconvey/convey" + "go.etcd.io/etcd/api/v3/mvccpb" +) + +func TestParseKV(t *testing.T) { + kv := &mvccpb.KeyValue{ + Key: []byte("/sn/workeragent/abc"), + Value: []byte("value_abc"), + ModRevision: 1, + } + res := parseKV(kv) + if res == nil { + t.Errorf("failed to parse kv") + } +} + +func TestParseEvent(t *testing.T) { + e := &clientv3.Event{ + Type: clientv3.EventTypeDelete, + Kv: &mvccpb.KeyValue{ + Key: []byte("/sn/workeragent/abc"), + Value: []byte("value_abc"), + ModRevision: 1, + }, + } + + res := parseEvent(e) + if res == nil { + t.Errorf("failed to parse event") + } + e = &clientv3.Event{ + Type: clientv3.EventTypePut, + Kv: &mvccpb.KeyValue{ + Key: []byte("/sn/workeragent/abc"), + Value: []byte("value_abc"), + ModRevision: 1, + }, + PrevKv: &mvccpb.KeyValue{ + Key: []byte("/sn/workeragent/def"), + Value: []byte("value_def"), + ModRevision: 1, + }, + } + res = parseEvent(e) + convey.ShouldNotBeNil(res) +} +func TestParseErr(t *testing.T) { + res := parseErr(fmt.Errorf("test")) + convey.ShouldNotBeNil(res) +} + +func TestParseSync(t *testing.T) { + res := parseSync(time.Time{}) + convey.ShouldNotBeNil(res) +} diff --git a/yuanrong/pkg/common/etcd3/scc_config.go b/yuanrong/pkg/common/etcd3/scc_config.go new file mode 100644 index 0000000..358207d --- /dev/null +++ b/yuanrong/pkg/common/etcd3/scc_config.go @@ -0,0 +1,147 @@ +//go:build cryptoapi +// +build cryptoapi + +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "os" + + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/logger/log" +) + +type clientTlsAuth struct { + cerfile string + keyfile string + cafile string + passphrasefile string +} + +func (e *EtcdConfig) getEtcdAuthType() etcdAuth { + if e.AuthType == "TLS" { + return &clientTlsAuth{ + cerfile: e.CertFile, + keyfile: e.KeyFile, + cafile: e.CaFile, + passphrasefile: e.PassphraseFile, + } + } + if e.SslEnable { + return &tlsAuth{ + cerfile: e.CertFile, + keyfile: e.KeyFile, + cafile: e.CaFile, + } + } + if e.User == "" || e.Passwd == "" { + return &noAuth{} + } + return &pwdAuth{ + user: e.User, + passWd: []byte(e.Passwd), + } +} + +func (c *clientTlsAuth) getEtcdConfig() (*clientv3.Config, error) { + var keyPwd []byte + if _, err := os.Stat(c.passphrasefile); err == nil { + keyPwd, err = ioutil.ReadFile(c.passphrasefile) + if err != nil { + log.GetLogger().Errorf("failed to read passphrasefile, %s", err.Error()) + return nil, err + } + if crypto.SCCInitialized() { + pwd, err := crypto.SCCDecrypt(keyPwd) + if err != nil { + log.GetLogger().Errorf("failed to decrypt passphrasefile, %s", err.Error()) + return nil, err + } + keyPwd = []byte(pwd) + } + } + + certPEM, err := ioutil.ReadFile(c.cerfile) + if err != nil { + log.GetLogger().Errorf("failed to read cert file: %s", err.Error()) + return nil, err + } + + caCertPEM, err := ioutil.ReadFile(c.cafile) + if err != nil { + log.GetLogger().Errorf("failed to read ca file: %s", err.Error()) + return nil, err + } + + encryptedKeyPEM, err := ioutil.ReadFile(c.keyfile) + if err != nil { + log.GetLogger().Errorf("failed to read key file: %s", err.Error()) + return nil, err + } + + keyBlock, _ := pem.Decode(encryptedKeyPEM) + if keyBlock == nil { + log.GetLogger().Errorf("failed to decode key PEM block") + return nil, err + } + + keyDER, err := crypto.DecryptPEMBlock(keyBlock, keyPwd) + if err != nil { + log.GetLogger().Errorf("failed to decrypt key: %s", err.Error()) + return nil, err + } + + key, err := x509.ParsePKCS1PrivateKey(keyDER) + if err != nil { + log.GetLogger().Errorf("failed to parse private key: %s", err.Error()) + return nil, err + } + + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(caCertPEM) { + log.GetLogger().Errorf("failed to append CA certificate") + return nil, fmt.Errorf("failed to append CA certificate") + } + + clientCert, err := tls.X509KeyPair(certPEM, pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})) + if err != nil { + log.GetLogger().Errorf("failed to create client certificate: %s", err.Error()) + return nil, err + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{clientCert}, + RootCAs: certPool, + } + + return &clientv3.Config{ + TLS: tlsConfig, + }, nil +} + +func (n *clientTlsAuth) renewToken(client *clientv3.Client, stop chan struct{}) { + return +} diff --git a/yuanrong/pkg/common/etcd3/scc_watcher.go b/yuanrong/pkg/common/etcd3/scc_watcher.go new file mode 100644 index 0000000..a7d4d7a --- /dev/null +++ b/yuanrong/pkg/common/etcd3/scc_watcher.go @@ -0,0 +1,55 @@ +//go:build cryptoapi +// +build cryptoapi + +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + clientv3 "go.etcd.io/etcd/client/v3" + "golang.org/x/net/context" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// NewEtcdWatcher new a etcd watcher +func NewEtcdWatcher(config EtcdConfig) (*EtcdClient, error) { + cfg, err := config.getEtcdAuthType().getEtcdConfig() + if err != nil { + return nil, err + } + + cfg.DialTimeout = dialTimeout + cfg.DialKeepAliveTime = keepaliveTime + cfg.DialKeepAliveTimeout = keepaliveTimeout + + cfg.Endpoints = config.Servers + client, err := clientv3.New(*cfg) + if err != nil { + return nil, err + } + stopCh := make(chan struct{}) + go config.getEtcdAuthType().renewToken(client, stopCh) + + // fetch registered grpc-proxy endpoints + if err = client.Sync(context.Background()); err != nil { + log.GetLogger().Warnf("Sync endpoints: %s", err.Error()) + } + log.GetLogger().Infof("Etcd discovered endpoints: %v", client.Endpoints()) + return &EtcdClient{Client: client, AZPrefix: config.AZPrefix}, nil +} diff --git a/yuanrong/pkg/common/etcd3/scc_watcher_no_scc.go b/yuanrong/pkg/common/etcd3/scc_watcher_no_scc.go new file mode 100644 index 0000000..f460edb --- /dev/null +++ b/yuanrong/pkg/common/etcd3/scc_watcher_no_scc.go @@ -0,0 +1,59 @@ +//go:build !cryptoapi +// +build !cryptoapi + +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "context" + + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// NewEtcdWatcher new a etcd watcher (this function provide no scc config) +func NewEtcdWatcher(config EtcdConfig) (*EtcdClient, error) { + cfg, err := config.getEtcdAuthTypeNoTLS().getEtcdConfig() + if err != nil { + return nil, err + } + + cfg.DialTimeout = dialTimeout + cfg.DialKeepAliveTime = keepaliveTime + cfg.DialKeepAliveTimeout = keepaliveTimeout + + cfg.Endpoints = config.Servers + client, err := clientv3.New(*cfg) + if err != nil { + return nil, err + } + stopCh := make(chan struct{}) + go config.getEtcdAuthTypeNoTLS().renewToken(client, stopCh) + + // fetch registered grpc-proxy endpoints + if config.DisableSync { + return &EtcdClient{Client: client}, nil + } + if err = client.Sync(context.Background()); err != nil { + log.GetLogger().Warnf("Sync endpoints: %s", err.Error()) + } + log.GetLogger().Infof("Etcd discovered endpoints: %v", client.Endpoints()) + return &EtcdClient{Client: client, AZPrefix: config.AZPrefix}, nil +} diff --git a/yuanrong/pkg/common/etcd3/watcher.go b/yuanrong/pkg/common/etcd3/watcher.go new file mode 100644 index 0000000..871227d --- /dev/null +++ b/yuanrong/pkg/common/etcd3/watcher.go @@ -0,0 +1,215 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "fmt" + "strings" + "time" + + clientv3 "go.etcd.io/etcd/client/v3" + "golang.org/x/net/context" + + "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/logger/log" +) + +const ( + // We have set a buffer in order to reduce times of context switches. + incomingBufSize = 2000 + outgoingBufSize = 2000 +) + +// EtcdClientInterface is the interface of ETCD client +type EtcdClientInterface interface { + GetResponse(ctxInfo EtcdCtxInfo, etcdKey string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) + Put(ctxInfo EtcdCtxInfo, etcdKey string, value string, opts ...clientv3.OpOption) error + Delete(ctxInfo EtcdCtxInfo, etcdKey string, opts ...clientv3.OpOption) error +} + +// EtcdClient etcd client struct +type EtcdClient struct { + Client *clientv3.Client + AZPrefix string +} + +// EtcdWatchChan implements watch.Interface. +type EtcdWatchChan struct { + incomingEventChan chan *Event + ResultChan chan *Event + errChan chan error + watcher *EtcdClient + key string + initialRev int64 + recursive bool + ctx context.Context + cancel context.CancelFunc +} + +// EtcdCtxInfo etcd context info +type EtcdCtxInfo struct { + Ctx context.Context + Cancel context.CancelFunc +} + +const ( + // keepaliveTime is the time after which client pings the server to see if + // transport is alive. + keepaliveTime = 30 * time.Second + + // keepaliveTimeout is the time that the client waits for a response for the + // keep-alive attempt. + keepaliveTimeout = 10 * time.Second + + // dialTimeout is the timeout for establishing a connection. + // 20 seconds as times should be set shorter than that will cause TLS connections to fail + dialTimeout = 20 * time.Second + + // tokenRenewTTL the default TTL of etcd simple token is 300s, so the renew TTL should be smaller than 300s + tokenRenewTTL = 30 * time.Second + + // renewKey etcd server will renew simple token TTL if token is not expired. + // so use an random key path for querying in order to renew token. + renewKey = "/keyforrenew" + + // DurationContextTimeout etcd request timeout, default context duration timeout + DurationContextTimeout = 5 * time.Second +) + +// AttachAZPrefix - +func (w *EtcdClient) AttachAZPrefix(key string) string { + if len(w.AZPrefix) != 0 { + return fmt.Sprintf("/%s%s", w.AZPrefix, key) + } + return key +} + +// DetachAZPrefix - +func (w *EtcdClient) DetachAZPrefix(key string) string { + if len(w.AZPrefix) != 0 { + return strings.TrimPrefix(key, fmt.Sprintf("/%s", w.AZPrefix)) + } + return key +} + +// GetResponse get etcd value and return pointer of GetResponse struct +func (w *EtcdClient) GetResponse(ctxInfo EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + etcdKey = w.AttachAZPrefix(etcdKey) + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + defer cancel() + + kv := clientv3.NewKV(w.Client) + leaderCtx := clientv3.WithRequireLeader(ctx) + getResp, err := kv.Get(leaderCtx, etcdKey, opts...) + + return getResp, err +} + +// Put put context key and value +func (w *EtcdClient) Put(ctxInfo EtcdCtxInfo, etcdKey string, value string, opts ...clientv3.OpOption) error { + etcdKey = w.AttachAZPrefix(etcdKey) + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + defer cancel() + + kv := clientv3.NewKV(w.Client) + leaderCtx := clientv3.WithRequireLeader(ctx) + _, err := kv.Put(leaderCtx, etcdKey, value, opts...) + return err +} + +// Delete delete key +func (w *EtcdClient) Delete(ctxInfo EtcdCtxInfo, etcdKey string, opts ...clientv3.OpOption) error { + etcdKey = w.AttachAZPrefix(etcdKey) + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + defer cancel() + + kv := clientv3.NewKV(w.Client) + leaderCtx := clientv3.WithRequireLeader(ctx) + _, err := kv.Delete(leaderCtx, etcdKey, opts...) + return err +} + +// TxnPut transaction put operation with if key existed put failed +func (w *EtcdClient) TxnPut(ctxInfo EtcdCtxInfo, etcdKey string, value string) error { + etcdKey = w.AttachAZPrefix(etcdKey) + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + defer cancel() + kv := clientv3.NewKV(w.Client) + leaderCtx := clientv3.WithRequireLeader(ctx) + txnRsp, err := kv.Txn(leaderCtx). + If(clientv3.Compare(clientv3.CreateRevision(etcdKey), "=", 0)). + Then(clientv3.OpPut(etcdKey, value)). + Else(clientv3.OpGet(etcdKey)).Commit() + if err != nil { + return err + } + if !txnRsp.Succeeded { + log.GetLogger().Warnf("the key has already exist: %s", etcdKey) + return fmt.Errorf("duplicated key") + } + return nil +} + +// GetValues return list of object for value +func (w *EtcdClient) GetValues(ctxInfo EtcdCtxInfo, etcdKey string, opts ...clientv3.OpOption) ([]string, error) { + etcdKey = w.AttachAZPrefix(etcdKey) + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + defer cancel() + + kv := clientv3.NewKV(w.Client) + leaderCtx := clientv3.WithRequireLeader(ctx) + response, err := kv.Get(leaderCtx, etcdKey, opts...) + + if err != nil { + return nil, err + } + values := make([]string, len(response.Kvs), len(response.Kvs)) + + for index, v := range response.Kvs { + values[index] = string(v.Value) + } + return values, err +} + +// CreateEtcdCtxInfo return context with cancle function +func CreateEtcdCtxInfo(ctx context.Context) EtcdCtxInfo { + ctx, cancel := context.WithCancel(ctx) + leaderCtx := clientv3.WithRequireLeader(ctx) + return EtcdCtxInfo{leaderCtx, cancel} +} + +// CreateEtcdCtxInfoWithTimeout create a context with timeout, default timeout is DurationContextTimeout +func CreateEtcdCtxInfoWithTimeout(ctx context.Context, duration time.Duration) EtcdCtxInfo { + ctx, cancel := context.WithTimeout(ctx, duration) + leaderCtx := clientv3.WithRequireLeader(ctx) + return EtcdCtxInfo{leaderCtx, cancel} +} + +// DecryptEtcdPassword decrypts the password of etcd +func DecryptEtcdPassword(plainPwd, workKey []byte) ([]byte, error) { + var ( + etcdPwd []byte + err error + ) + etcdPwd, err = crypto.DecryptByte(plainPwd, workKey) + if err != nil { + return nil, err + } + return etcdPwd, nil +} diff --git a/yuanrong/pkg/common/etcd3/watcher_test.go b/yuanrong/pkg/common/etcd3/watcher_test.go new file mode 100644 index 0000000..5c1944b --- /dev/null +++ b/yuanrong/pkg/common/etcd3/watcher_test.go @@ -0,0 +1,203 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package etcd3 + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + clientv3 "go.etcd.io/etcd/client/v3" + "golang.org/x/net/context" + "k8s.io/apimachinery/pkg/util/wait" + + "yuanrong/pkg/common/crypto" +) + +type EtcdTestSuite struct { + suite.Suite + defaultEtcdCtx EtcdCtxInfo + etcdClient *EtcdClient +} + +func (es *EtcdTestSuite) setupEtcdClient() { + var err error + patches := gomonkey.NewPatches() + cli := clientv3.Client{} + patches.ApplyFunc(clientv3.New, func(_ clientv3.Config) (*clientv3.Client, error) { + return &cli, nil + }) + patches.ApplyMethod(reflect.TypeOf(&cli), "Sync", func(_ *clientv3.Client, _ context.Context) error { + return nil + }) + patches.ApplyMethod(reflect.TypeOf(&cli), "Endpoints", func(_ *clientv3.Client) []string { + return []string{"localhost:0"} + }) + defer patches.Reset() + // create a new etcd watcher for the following tests + auth := EtcdConfig{ + Servers: []string{"localhost:0"}, + User: "", + Passwd: "", + SslEnable: false, + } + es.etcdClient, err = NewEtcdWatcher(auth) + if err != nil { + err = fmt.Errorf("failed to create etcd watcher; err: %v", err) + } +} + +func (es *EtcdTestSuite) SetupSuite() { + es.defaultEtcdCtx = CreateEtcdCtxInfo(context.Background()) + es.setupEtcdClient() +} +func (es *EtcdTestSuite) TearDownSuite() { + es.etcdClient = nil +} + +func (es *EtcdTestSuite) TestNewEtcdWatcher() { + var ( + serverList []string + auth EtcdConfig + err error + ) + + patches := gomonkey.NewPatches() + cli := clientv3.Client{} + patches.ApplyFunc(clientv3.New, func(_ clientv3.Config) (*clientv3.Client, error) { + return &cli, nil + }) + patches.ApplyMethod(reflect.TypeOf(&cli), "Sync", func(_ *clientv3.Client, _ context.Context) error { + return nil + }) + patches.ApplyMethod(reflect.TypeOf(&cli), "Endpoints", func(_ *clientv3.Client) []string { + return []string{"localhost:0"} + }) + patches.ApplyFunc(crypto.Decrypt, func(cipherText []byte, secret []byte) (string, error) { + return "key", nil + }) + patches.ApplyFunc(wait.Until, func(f func(), period time.Duration, stopCh <-chan struct{}) { + return + }) + serverList = []string{"localhost:0"} + auth = EtcdConfig{ + Servers: serverList, + User: "", + Passwd: "", + SslEnable: false, + CaFile: "", + CertFile: "", + KeyFile: "", + } + _, err = NewEtcdWatcher(auth) + assert.Nil(es.T(), err) + + serverList = []string{"localhost:0"} + auth = EtcdConfig{ + Servers: serverList, + User: "", + Passwd: "", + SslEnable: true, + CaFile: "xxx.ca", + CertFile: "xxx.cert", + KeyFile: "xxx.key", + } + _, err = NewEtcdWatcher(auth) + assert.NotNil(es.T(), err) + + serverList = []string{"localhost:0"} + auth = EtcdConfig{ + Servers: serverList, + User: "user", + Passwd: "", + SslEnable: false, + } + _, err = NewEtcdWatcher(auth) + assert.Nil(es.T(), err) + patches.Reset() +} + +func (es *EtcdTestSuite) TestCRUD() { + cli := es.etcdClient + ctxInfo := es.defaultEtcdCtx + etcdKey := "test_key" + etcdValue := "test_value" + + patch := gomonkey.ApplyFunc(clientv3.NewKV, func(*clientv3.Client) clientv3.KV { + return KV{} + }) + err := cli.Put(ctxInfo, etcdKey, etcdValue) + assert.Nil(es.T(), err) + + _, err = cli.GetKeys(ctxInfo, etcdKey) + assert.Nil(es.T(), err) + + _, err = cli.GetValues(ctxInfo, etcdKey) + assert.Nil(es.T(), err) + + _, err = cli.GetResponse(ctxInfo, etcdKey) + assert.Nil(es.T(), err) + + err = cli.Delete(ctxInfo, etcdKey) + assert.Nil(es.T(), err) + + patch.Reset() +} + +func (es *EtcdTestSuite) TestCreateEtcdCtxInfoWithTimeout() { + ctxInfo := es.defaultEtcdCtx.Ctx + etcdCtxInfo := CreateEtcdCtxInfoWithTimeout(ctxInfo, 1) + assert.NotNil(es.T(), etcdCtxInfo) +} + +func TestEtcdTestSuite(t *testing.T) { + suite.Run(t, new(EtcdTestSuite)) +} + +func TestDecryptEtcdPassword(t *testing.T) { + DecryptEtcdPassword([]byte{}, []byte{}) +} + +type KV struct { +} + +func (k KV) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return &clientv3.GetResponse{}, nil +} + +func (k KV) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + return &clientv3.PutResponse{}, nil +} + +func (k KV) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + return &clientv3.DeleteResponse{}, nil +} + +func (k KV) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, error) { + return &clientv3.CompactResponse{}, nil +} + +func (k KV) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + return clientv3.OpResponse{}, nil +} + +func (k KV) Txn(ctx context.Context) clientv3.Txn { + return nil +} diff --git a/yuanrong/pkg/common/etcdkey/etcdkey.go b/yuanrong/pkg/common/etcdkey/etcdkey.go new file mode 100644 index 0000000..6ba0ca9 --- /dev/null +++ b/yuanrong/pkg/common/etcdkey/etcdkey.go @@ -0,0 +1,196 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcdkey contains etcd key definition and tools +package etcdkey + +import ( + "fmt" + "strings" + + "yuanrong/pkg/common/faas_common/urnutils" +) + +const ( + keyFormat = "/sn/%s/business/%s/tenant/%s/function/%s/version/%s/%s/%s" + keySeparator = "/" + stateWorkerLen = 13 + + cronTriggerTenantIndex = 8 + functionMetadataTenantIndex = 4 +) + +// Index of element in etcd key +const ( + prefixIndex = iota + 1 + typeIndex + businessIDKey + businessIDIndex + tenantIDKey + tenantIDIndex + functionKey + functionIndex + versionKey + versionIndex + zoneIndex + keyIndex +) + +// Index of instance element in etcd key +const ( + instancePrefixIndex = iota + 1 + instanceTypeIndex + instanceBusinessIDKey + instanceBusinessIDIndex + instanceTenantIDKey + instanceTenantIDIndex + instanceZoneIndex + instanceFunctionNameIndex + instanceUrnVersionIndex = 10 + instanceInstanceIDIndex = 13 + instanceKeyLen = 14 +) + +// EtcdKey etcd key interface definition +type EtcdKey interface { + String() string + ParseFrom(key string) error +} + +// StateWorkersKey state workers key +type StateWorkersKey struct { + TypeKey string + BusinessID string + TenantID string + Function string + Version string + Zone string + StateID string +} + +// String serialize state workers key struct to string +func (s *StateWorkersKey) String() string { + return fmt.Sprintf(keyFormat, s.TypeKey, s.BusinessID, s.TenantID, s.Function, s.Version, s.Zone, s.StateID) +} + +// ParseFrom parse string to state workers key struct +func (s *StateWorkersKey) ParseFrom(etcdKey string) error { + elements := strings.Split(etcdKey, keySeparator) + if len(elements) != stateWorkerLen { + return fmt.Errorf("failed to parse etcd key from %s", etcdKey) + } + s.TypeKey = elements[keyIndex] + s.BusinessID = elements[businessIDIndex] + s.TenantID = elements[tenantIDIndex] + s.Function = elements[functionIndex] + s.Version = elements[versionIndex] + s.Zone = elements[zoneIndex] + s.StateID = elements[keyIndex] + + return nil +} + +// WorkerInstanceKey is the etcd key path of worker instance +type WorkerInstanceKey struct { + TypeKey string + BusinessID string + TenantID string + Function string + Version string + Zone string + Instance string +} + +// String serialize worker instance key struct to string +func (w *WorkerInstanceKey) String() string { + return fmt.Sprintf(keyFormat, w.TypeKey, w.BusinessID, w.TenantID, w.Function, w.Version, w.Zone, w.Instance) +} + +// ParseFrom parse string to worker instance key struct +func (w *WorkerInstanceKey) ParseFrom(etcdKey string) error { + elements := strings.Split(etcdKey, keySeparator) + if len(elements) != stateWorkerLen { + return fmt.Errorf("failed to parse etcd key from %s", etcdKey) + } + w.TypeKey = elements[typeIndex] + w.BusinessID = elements[businessIDIndex] + w.TenantID = elements[tenantIDIndex] + w.Function = elements[functionIndex] + w.Version = elements[versionIndex] + w.Zone = elements[zoneIndex] + w.Instance = elements[keyIndex] + return nil +} + +// AnonymizeTenantCommonEtcdKey Anonymize tenant info in common etcd key +// /yr/functions/business/yrk/tenant/8e08d5cc0ad34032bba8d636040a278c/function/0-test1-addone/version/$latest +func AnonymizeTenantCommonEtcdKey(etcdKey string) string { + elements := strings.Split(etcdKey, keySeparator) + if len(elements) <= tenantIDIndex { + return etcdKey + } + elements[tenantIDIndex] = urnutils.Anonymize(elements[tenantIDIndex]) + return strings.Join(elements, keySeparator) +} + +// AnonymizeTenantCronTriggerEtcdKey Anonymize tenant info in cron trigger etcd key +// /sn/triggers/triggerType/CRON/business/yrk/tenant/i1fe539427b24702acc11fbb4e134e17/function/pytzip/version/$latest/398e2ca2-a160-4c22-bd05-94a90a5326e2 +func AnonymizeTenantCronTriggerEtcdKey(etcdKey string) string { + elements := strings.Split(etcdKey, keySeparator) + if len(elements) <= cronTriggerTenantIndex { + return etcdKey + } + elements[cronTriggerTenantIndex] = urnutils.Anonymize(elements[cronTriggerTenantIndex]) + return strings.Join(elements, keySeparator) +} + +// AnonymizeTenantFunctionMetadataEtcdKey Anonymize tenant info in function metadata etcd key +// /repo/FunctionVersion/business/tenant/funcName/version/ +func AnonymizeTenantFunctionMetadataEtcdKey(etcdKey string) string { + elements := strings.Split(etcdKey, keySeparator) + if len(elements) <= functionMetadataTenantIndex { + return etcdKey + } + elements[functionMetadataTenantIndex] = urnutils.Anonymize(elements[functionMetadataTenantIndex]) + return strings.Join(elements, keySeparator) +} + +// FunctionInstanceKey is the etcd key path of function instance +type FunctionInstanceKey struct { + TypeKey string + BusinessID string + TenantID string + FunctionName string + Version string + Zone string + InstanceID string +} + +// ParseFrom parse string to function instance key struct +// /sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-defaultservice-py/version/$latest +// /defaultaz/8c9fa45600e5f44f00/10000000-0000-4000-b653-c11128589d17 +func (f *FunctionInstanceKey) ParseFrom(etcdKey string) error { + elements := strings.Split(etcdKey, keySeparator) + if len(elements) != instanceKeyLen { + return fmt.Errorf("failed to parse etcd key from %s: invalid key length", etcdKey) + } + f.TypeKey = elements[instanceTypeIndex] + f.BusinessID = elements[instanceBusinessIDIndex] + f.TenantID = elements[instanceTenantIDIndex] + f.Zone = elements[instanceZoneIndex] + f.InstanceID = elements[instanceInstanceIDIndex] + return nil +} diff --git a/yuanrong/pkg/common/etcdkey/etcdkey_test.go b/yuanrong/pkg/common/etcdkey/etcdkey_test.go new file mode 100644 index 0000000..eeb3b38 --- /dev/null +++ b/yuanrong/pkg/common/etcdkey/etcdkey_test.go @@ -0,0 +1,264 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcdkey contains etcd key definition and tools +package etcdkey + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStateWorkersKey_ParseFrom(t *testing.T) { + type fields struct { + KeyType string + BusinessID string + TenantID string + Function string + Version string + Zone string + StateID string + } + type args struct { + key string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "test001", + fields: fields{ + KeyType: "stateworkers", + BusinessID: "yrk", + TenantID: "tenantID", + Function: "function", + Version: "$latest", + Zone: "defaultaz", + StateID: "stateID", + }, + args: args{ + key: "/sn/stateworkers/business/yrk/tenant/tenantID" + + "/function/function/version/$latest/defaultaz/stateID", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &StateWorkersKey{ + TypeKey: tt.fields.KeyType, + BusinessID: tt.fields.BusinessID, + TenantID: tt.fields.TenantID, + Function: tt.fields.Function, + Version: tt.fields.Version, + Zone: tt.fields.Zone, + StateID: tt.fields.StateID, + } + if err := s.ParseFrom(tt.args.key); (err != nil) != tt.wantErr { + t.Errorf("ParseFrom() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestStateWorkersKey_String(t *testing.T) { + type fields struct { + KeyType string + BusinessID string + TenantID string + Function string + Version string + Zone string + StateID string + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "test002", + fields: fields{ + KeyType: "stateworkers", + BusinessID: "yrk", + TenantID: "tenantID", + Function: "function", + Version: "$latest", + Zone: "defaultaz", + StateID: "stateID", + }, + want: "/sn/stateworkers/business/yrk/tenant/tenantID" + + "/function/function/version/$latest/defaultaz/stateID", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &StateWorkersKey{ + TypeKey: tt.fields.KeyType, + BusinessID: tt.fields.BusinessID, + TenantID: tt.fields.TenantID, + Function: tt.fields.Function, + Version: tt.fields.Version, + Zone: tt.fields.Zone, + StateID: tt.fields.StateID, + } + if got := s.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWorkerInstanceKey_ParseFrom(t *testing.T) { + tests := []struct { + arg string + want WorkerInstanceKey + err bool + }{ + { + arg: "/sn/workers/business/busid/tenant/abcd/function/" + + "0-counter-addone/version/$latest/defaultaz/defaultaz-#-pool7-500-500-python3.7-58f588848d-smss8", + want: WorkerInstanceKey{ + TypeKey: "workers", + BusinessID: "busid", + TenantID: "abcd", + Function: "0-counter-addone", + Version: "$latest", + Zone: "defaultaz", + Instance: "defaultaz-#-pool7-500-500-python3.7-58f588848d-smss8", + }, + err: false, + }, + { + arg: "/sn/workers/business/busid/tenant/abcd/function/" + + "0-counter-addone/version/$latest/defaultaz", + want: WorkerInstanceKey{}, + err: true, + }, + { + arg: "/sn/workers/business/yrk/tenant/0/function/function-task/version/$latest/defaultaz/dggphis36581", + want: WorkerInstanceKey{ + TypeKey: "workers", + BusinessID: "yrk", + TenantID: "0", + Function: "function-task", + Version: "$latest", + Zone: "defaultaz", + Instance: "dggphis36581", + }, + err: false, + }, + } + for _, tt := range tests { + worker := WorkerInstanceKey{} + err := worker.ParseFrom(tt.arg) + if tt.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, worker) + } + } +} + +func TestWorkerInstanceKey_String(t *testing.T) { + tests := []struct { + arg WorkerInstanceKey + want string + }{ + { + want: "/sn/workers/business/busid/tenant/abcd/function/" + + "0-counter-addone/version/$latest/defaultaz/defaultaz-#-pool7-500-500-python3.7-58f588848d-smss8", + arg: WorkerInstanceKey{ + TypeKey: "workers", + BusinessID: "busid", + TenantID: "abcd", + Function: "0-counter-addone", + Version: "$latest", + Zone: "defaultaz", + Instance: "defaultaz-#-pool7-500-500-python3.7-58f588848d-smss8", + }, + }, + } + for _, tt := range tests { + assert.Equal(t, tt.want, tt.arg.String()) + } +} + +func TestAnonymizeTenantCommonEtcdKey(t *testing.T) { + keyA := "/yr/functions/business/yrk" + anonymizeKeyA := AnonymizeTenantCommonEtcdKey(keyA) + assert.Equal(t, keyA, anonymizeKeyA) + keyB := "/yr/functions/business/yrk/tenant/8e08d5cc0ad34032bba8d636040a278c/function/0-test1-addone/version/$latest" + anonymizeKeyB := AnonymizeTenantCommonEtcdKey(keyB) + assert.NotEqual(t, keyB, anonymizeKeyB) +} +func TestAnonymizeTenantCronTriggerEtcdKey(t *testing.T) { + keyA := "/sn/triggers/triggerType/CRON/business/yrk/tenant" + anonymizeKeyA := AnonymizeTenantCronTriggerEtcdKey(keyA) + assert.Equal(t, keyA, anonymizeKeyA) + keyB := "/sn/triggers/triggerType/CRON/business/yrk/tenant/i1fe539427b2/function/pytzip/version/$latest/398e2ca2" + anonymizeKeyB := AnonymizeTenantCronTriggerEtcdKey(keyB) + assert.NotEqual(t, keyB, anonymizeKeyB) +} +func TestAnonymizeTenantFunctionMetadataEtcdKey(t *testing.T) { + keyA := "/repo/FunctionVersion/business" + anonymizeKeyA := AnonymizeTenantFunctionMetadataEtcdKey(keyA) + assert.Equal(t, keyA, anonymizeKeyA) + keyB := "/repo/FunctionVersion/business/tenant/funcName/version/" + anonymizeKeyB := AnonymizeTenantFunctionMetadataEtcdKey(keyB) + assert.NotEqual(t, keyB, anonymizeKeyB) +} + +func TestFunctionInstanceKey_ParseFrom(t *testing.T) { + tests := []struct { + arg string + want FunctionInstanceKey + err bool + }{ + { + arg: "/sn/instance/business/yrk/tenant/tenantID/defaultaz/instanceID", + want: FunctionInstanceKey{ + TypeKey: "instance", + BusinessID: "yrk", + TenantID: "tenantID", + InstanceID: "instanceID", + Version: "", + Zone: "defaultaz", + }, + err: false, + }, + { + arg: "/sn/instance/business/b1/tenant/t1/function/0-s1-test/version/1/defaultaz", + want: FunctionInstanceKey{}, + err: true, + }, + } + for _, tt := range tests { + instance := FunctionInstanceKey{} + err := instance.ParseFrom(tt.arg) + if tt.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, instance) + } + } +} diff --git a/yuanrong/pkg/common/faas_common/alarm/config.go b/yuanrong/pkg/common/faas_common/alarm/config.go new file mode 100644 index 0000000..a934d78 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/alarm/config.go @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package alarm alarm log by filebeat +package alarm + +import ( + "yuanrong/pkg/common/faas_common/logger/config" + "yuanrong/pkg/common/faas_common/types" +) + +// Config - +type Config struct { + EnableAlarm bool `json:"enableAlarm"` + AlarmLogConfig config.CoreInfo `json:"alarmLogConfig" valid:"optional"` + XiangYunFourConfig types.XiangYunFourConfig `json:"xiangYunFourConfig" valid:"optional"` + MinInsStartInterval int `json:"minInsStartInterval"` + MinInsCheckInterval int `json:"minInsCheckInterval"` +} diff --git a/yuanrong/pkg/common/faas_common/alarm/logalarm.go b/yuanrong/pkg/common/faas_common/alarm/logalarm.go new file mode 100644 index 0000000..f36b643 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/alarm/logalarm.go @@ -0,0 +1,242 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package alarm alarm log by filebeat +package alarm + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "strconv" + "sync" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger" + "yuanrong/pkg/common/faas_common/logger/config" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/urnutils" +) + +const ( + // ConfigKey environment variable key of alarm config + ConfigKey = "ALARM_CONFIG" + + cacheLimit = 10 * 1 << 20 // 10 mb + + // Level3 - + Level3 = "critical" + // Level2 - + Level2 = "major" + // Level1 - + Level1 = "minor" + // Level0 - + Level0 = "notice" + + // GenerateAlarmLog - + GenerateAlarmLog = "firing" + // ClearAlarmLog - + ClearAlarmLog = "resolved" + + // InsufficientMinInstance00001 alarm id + InsufficientMinInstance00001 = "InsufficientMinInstance00001" + // MetadataEtcdConnection00001 alarm id + MetadataEtcdConnection00001 = "MetadataEtcdConnection00001" + // RouterEtcdConnection00001 alarm id + RouterEtcdConnection00001 = "RouterEtcdConnection00001" + // InitStsSdkErr00001 alarm id + InitStsSdkErr00001 = "InitStsSdkErr00001" + // PullStsConfiguration00001 alarm id + PullStsConfiguration00001 = "PullStsConfiguration00001" + // ReportToXPUManageFailed00001 alarm id + ReportToXPUManageFailed00001 = "ReportToXPUManageFailed00001" + // FaaSSchedulerRemovedFromHashRing00001 alarm id + FaaSSchedulerRemovedFromHashRing00001 = "FaaSSchedulerRemovedFromHashRing00001" + // FaaSFrontendReceiptDMQMessage00001 - + FaaSFrontendReceiptDMQMessage00001 = "FaaSFrontendReceiptDMQMessage00001" + // FaaSFrontendDequeueDMQMessage00001 - + FaaSFrontendDequeueDMQMessage00001 = "FaaSFrontendDequeueDMQMessage00001" + // NoAvailableSchedulerInstance00001 没有可用的scheduler实例的告警id + NoAvailableSchedulerInstance00001 = "NoAvailableSchedulerInstance00001" +) + +var ( + alarmLogger *zap.Logger + createLoggerErr error + createLoggerOnce sync.Once +) + +// LogAlarmInfo Custom alarm info +type LogAlarmInfo struct { + AlarmID string + AlarmName string + AlarmLevel string +} + +// Detail alarm detail +type Detail struct { + SourceTag string // 告警来源 + OpType string // 告警操作类型 + Details string // 告警详情 + StartTimestamp int // 产生时间 + EndTimestamp int // 清除时间 +} + +// GetAlarmLogger - +func GetAlarmLogger() (*zap.Logger, error) { + createLoggerOnce.Do(func() { + alarmLogger, createLoggerErr = newAlarmLogger() + if createLoggerErr != nil { + return + } + if alarmLogger == nil { + createLoggerErr = errors.New("failed to new alarmLogger") + return + } + // 祥云四元组 - 站点/租户ID/产品ID/服务ID + alarmLogger = alarmLogger.With(zapcore.Field{ + Key: "site", Type: zapcore.StringType, + String: os.Getenv(constant.WiseCloudSite), + }, zapcore.Field{ + Key: "tenant_id", Type: zapcore.StringType, + String: os.Getenv(constant.TenantID), + }, zapcore.Field{ + Key: "application_id", Type: zapcore.StringType, + String: os.Getenv(constant.ApplicationID), + }, zapcore.Field{ + Key: "service_id", Type: zapcore.StringType, + String: os.Getenv(constant.ServiceID), + }) + }) + return alarmLogger, createLoggerErr +} + +func newAlarmLogger() (*zap.Logger, error) { + coreInfo, err := config.ExtractCoreInfoFromEnv(ConfigKey) + log.GetLogger().Infof("ALARM_CONFIG is: %v", coreInfo) + if err != nil { + log.GetLogger().Errorf("failed to valid log path, err: %s", err.Error()) + return nil, err + } + + coreInfo.FilePath = filepath.Join(coreInfo.FilePath, "alarm.dat") + + sink, err := logger.CreateSink(coreInfo) + if err != nil { + log.GetLogger().Errorf("failed to create sink: %s", err.Error()) + return nil, err + } + + ws := zapcore.AddSync(sink) + priority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= zapcore.DebugLevel + }) + encoderConfig := zapcore.EncoderConfig{} + rollingFileEncoder := zapcore.NewJSONEncoder(encoderConfig) + + return zap.New(zapcore.NewCore(rollingFileEncoder, ws, priority)), nil +} + +func addAlarmLogger(rollingLogger *zap.Logger, alarmInfo *LogAlarmInfo, detail *Detail) *zap.Logger { + return rollingLogger.With(zapcore.Field{ + Key: "id", Type: zapcore.StringType, + String: alarmInfo.AlarmID, + }, zapcore.Field{ + Key: "name", Type: zapcore.StringType, + String: alarmInfo.AlarmName, + }, zapcore.Field{ + Key: "level", Type: zapcore.StringType, + String: alarmInfo.AlarmLevel, + }, zapcore.Field{ + Key: "source_tag", Type: zapcore.StringType, + String: detail.SourceTag, + }, zapcore.Field{ + Key: "op_type", Type: zapcore.StringType, + String: detail.OpType, + }, zapcore.Field{ + Key: "details", Type: zapcore.StringType, + String: detail.Details, + }, zapcore.Field{ + Key: "clear_type", Type: zapcore.StringType, + String: "ADAC", + }, zapcore.Field{ + Key: "start_timestamp", Type: zapcore.StringType, + String: strconv.Itoa(detail.StartTimestamp), + }, zapcore.Field{ + Key: "end_timestamp", Type: zapcore.StringType, + String: strconv.Itoa(detail.EndTimestamp), + }) +} + +// ReportOrClearAlarm - +func ReportOrClearAlarm(alarmInfo *LogAlarmInfo, detail *Detail) { + alarmLog, err := GetAlarmLogger() + if err != nil { + log.GetLogger().Errorf("GetAlarmLogger err %v", err) + return + } + logger := addAlarmLogger(alarmLog, alarmInfo, detail) + logger.Info("") +} + +// SetAlarmEnv - +func SetAlarmEnv(alarmConfigInfo config.CoreInfo) { + alarmConfigBytes, err := json.Marshal(alarmConfigInfo) + if err != nil { + log.GetLogger().Errorf("json marshal alarmConfigInfo err %v", err) + } + if err := os.Setenv(ConfigKey, string(alarmConfigBytes)); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", ConfigKey, err.Error()) + } + log.GetLogger().Debugf("succeeded to set env of %s, value: %s", ConfigKey, string(alarmConfigBytes)) +} + +// SetXiangYunFourConfigEnv - +func SetXiangYunFourConfigEnv(xiangYunFourConfig types.XiangYunFourConfig) { + if err := os.Setenv(constant.WiseCloudSite, xiangYunFourConfig.Site); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", constant.WiseCloudSite, err.Error()) + } + if err := os.Setenv(constant.TenantID, xiangYunFourConfig.TenantID); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", constant.TenantID, err.Error()) + } + if err := os.Setenv(constant.ApplicationID, xiangYunFourConfig.ApplicationID); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", constant.ApplicationID, err.Error()) + } + if err := os.Setenv(constant.ServiceID, xiangYunFourConfig.ServiceID); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", constant.ServiceID, err.Error()) + } + log.GetLogger().Debugf("succeeded to set env, value: %v", xiangYunFourConfig) +} + +// SetPodIP - +func SetPodIP() error { + ip, err := urnutils.GetServerIP() + if err != nil { + log.GetLogger().Errorf("failed to get pod ip, err: %s", err.Error()) + return err + } + err = os.Setenv(constant.PodIPEnvKey, ip) + if err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", constant.PodIPEnvKey, err.Error()) + return err + } + return nil +} diff --git a/yuanrong/pkg/common/faas_common/alarm/logalarm_test.go b/yuanrong/pkg/common/faas_common/alarm/logalarm_test.go new file mode 100644 index 0000000..033b403 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/alarm/logalarm_test.go @@ -0,0 +1,85 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package alarm +package alarm + +import ( + "encoding/json" + "github.com/smartystreets/goconvey/convey" + "os" + "sync" + "testing" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/config" + "yuanrong/pkg/common/faas_common/urnutils" +) + +func TestGetAlarmLogger(t *testing.T) { + convey.Convey("TestGetAlarmLogger", t, func() { + convey.Convey("failed to new alarmLogger", func() { + logger, err := GetAlarmLogger() + convey.So(err, convey.ShouldBeError) + convey.So(logger, convey.ShouldBeNil) + }) + + convey.Convey("success", func() { + dir, _ := os.Getwd() + defer gomonkey.ApplyFunc(config.ExtractCoreInfoFromEnv, func(env string) (config.CoreInfo, error) { + return config.CoreInfo{FilePath: dir}, nil + }).Reset() + createLoggerOnce = sync.Once{} + logger, err := GetAlarmLogger() + convey.So(err, convey.ShouldBeNil) + convey.So(logger, convey.ShouldNotBeNil) + }) + }) +} + +func TestReportOrClearAlarm(t *testing.T) { + convey.Convey("ReportOrClearAlarm", t, func() { + convey.Convey("no test assert", func() { + ReportOrClearAlarm(&LogAlarmInfo{}, &Detail{}) + }) + }) +} + +func TestSetAlarmEnv(t *testing.T) { + convey.Convey("SetAlarmEnv", t, func() { + convey.Convey("set env", func() { + dir, _ := os.Getwd() + SetAlarmEnv(config.CoreInfo{FilePath: dir}) + getenv := os.Getenv(ConfigKey) + var cfg *config.CoreInfo + err := json.Unmarshal([]byte(getenv), &cfg) + convey.So(err, convey.ShouldBeNil) + convey.So(cfg.FilePath, convey.ShouldEqual, dir) + os.Unsetenv(ConfigKey) + }) + }) +} + +func TestSetPodIP(t *testing.T) { + convey.Convey("SetPodIP", t, func() { + convey.Convey("", func() { + ip, _ := urnutils.GetServerIP() + SetPodIP() + convey.So(os.Getenv(constant.PodIPEnvKey), convey.ShouldEqual, ip) + os.Unsetenv(constant.PodIPEnvKey) + }) + }) + +} diff --git a/yuanrong/pkg/common/faas_common/aliasroute/alias.go b/yuanrong/pkg/common/faas_common/aliasroute/alias.go new file mode 100644 index 0000000..fdc9485 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/aliasroute/alias.go @@ -0,0 +1,503 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package aliasroute alias routing in busclient +package aliasroute + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/loadbalance" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/urnutils" +) + +const ( + weightRatio = 100 // max weight of a node + routingTypeRule = "rule" + // AliasKeySeparator is the separator in an alias key + AliasKeySeparator = "/" +) + +const ( + defaultVersion = "latest" + defaultBusinessID = "yrk" +) + +// example of an aliasKey: +// ///////// +const ( + ProductIDIndex = iota + 1 + AliasSignIndex + BusinessSignIndex + BusinessIDIndex + TenantSignIndex + TenantIDIndex + FunctionSignIndex + FunctionIDIndex + AliasNameIndex + aliasKeyLength +) + +// Aliases map for stateless function alias +type Aliases struct { + AliasMap *sync.Map // Key: aliasURN -- Value: *AliasElement +} + +// aliases for alias routing +var ( + aliases = &Aliases{ + AliasMap: &sync.Map{}, + } +) + +// GetAliases - +func GetAliases() *Aliases { + return aliases +} + +// AddAlias add alias to Aliases map from etcd +func (a *Aliases) AddAlias(alias *AliasElement) { + existAliasIf, exist := a.AliasMap.Load(alias.AliasURN) + var existAlias *AliasElement + var ok bool + if !exist { + // new alias, initialize RR and Mutex + existAlias = &AliasElement{ + AliasURN: alias.AliasURN, + FunctionURN: alias.FunctionURN, + FunctionVersionURN: alias.FunctionVersionURN, + Name: alias.Name, + Description: alias.Description, + FunctionVersion: alias.FunctionVersion, + RevisionID: alias.RevisionID, + RoutingConfigs: alias.RoutingConfigs, + RoutingRules: alias.RoutingRules, + RoutingType: alias.RoutingType, + + lb: loadbalance.LBFactory(loadbalance.RoundRobinNginx), + aliasLock: &sync.RWMutex{}, + } + existAlias.resetRR() + a.AliasMap.Store(alias.AliasURN, existAlias) + return + } + existAlias, ok = existAliasIf.(*AliasElement) + if ok { + aliasUpdate(existAlias, alias) + existAlias.resetRR() + } +} + +// RemoveAlias remove alias to aliases map +func (a *Aliases) RemoveAlias(aliasURN string) { + a.AliasMap.Delete(aliasURN) +} + +// GetFuncURNFromAlias If the alias exists, the weighted route version is returned. +// If the alias does not exist, the original URN is returned. +func (a *Aliases) GetFuncURNFromAlias(urn string) string { + existAliasIf, exist := a.AliasMap.Load(urn) + if !exist { + return urn + } + existAlias, ok := existAliasIf.(*AliasElement) + if !ok { + log.GetLogger().Warnf("Failed to convert the alias urn %s", urn) + return "" + } + return existAlias.getFuncVersionURN() +} + +// GetFuncVersionURNWithParams gets the routing version URN of stateless functionName with parmas for rules +func (a *Aliases) GetFuncVersionURNWithParams(aliasURN string, params map[string]string) string { + existAliasIf, exist := a.AliasMap.Load(aliasURN) + if !exist { + return aliasURN + } + existAlias, ok := existAliasIf.(*AliasElement) + if !ok { + log.GetLogger().Warnf("Failed to convert the alias urn %s", aliasURN) + return "" + } + return existAlias.GetFuncVersionURNWithParams(params) +} + +// CheckAliasRoutingChange - return false means oldURN is not equal to newURN or alise is not exist +func (a *Aliases) CheckAliasRoutingChange(aliasURN, oldURN string, params map[string]string) bool { + existAliasIf, exist := a.AliasMap.Load(aliasURN) + if !exist { + return true + } + existAlias, ok := existAliasIf.(*AliasElement) + if ok && existAlias.RoutingType == routingTypeRule { + return oldURN != existAlias.getFuncVersionURNByRule(params) + } + // routingType is weight + for _, config := range existAlias.RoutingConfigs { + if config.FunctionVersionURN == oldURN && config.Weight > 0.0 { + return false + } + } + return true +} + +// GetAliasRoutingType - +func (a *Aliases) GetAliasRoutingType(aliasURN string) string { + existAliasIf, exist := a.AliasMap.Load(aliasURN) + if !exist { + return "" + } + if existAlias, ok := existAliasIf.(*AliasElement); ok { + return existAlias.RoutingType + } + return "" +} + +// change means the following 3 conditions +func isAliasWeightTypeChange(originAlias, srcAlias *AliasElement) map[string]int { + changedURNMap := make(map[string]int) + if originAlias.RoutingType == routingTypeRule || srcAlias.RoutingType == routingTypeRule { + return map[string]int{"": NoneUpdate} + } + // 1、delete weight alias + if len(srcAlias.RoutingConfigs) == 0 { + return map[string]int{"": UpdateAllURN} + } + srcAliasMap := make(map[string]float64, len(srcAlias.RoutingConfigs)) + for _, config := range srcAlias.RoutingConfigs { + if config.Weight <= 0 { + // 2、weight decrease to 0 + changedURNMap[config.FunctionVersionURN] = UpdateWeightGreyURN + } + srcAliasMap[config.FunctionVersionURN] = config.Weight + } + + for _, config := range originAlias.RoutingConfigs { + // 3、grey functionURN in originAlias is not in srcAlias + // old device still follow old urn, new device follow the weight + if _, ok := srcAliasMap[config.FunctionVersionURN]; !ok { + changedURNMap[config.FunctionVersionURN] = UpdateWeightGreyURN + } + } + if originAlias.FunctionVersion != srcAlias.FunctionVersion { + changedURNMap[originAlias.FunctionVersion] = UpdateMainURN + } + return changedURNMap +} + +type routingRules struct { + RuleLogic string `json:"ruleLogic"` + Rules []string `json:"rules"` + GrayVersion string `json:"grayVersion"` +} + +// AliasElement struct stores an alias configs of stateless function +type AliasElement struct { + aliasLock *sync.RWMutex + lb loadbalance.LoadBalance + AliasURN string `json:"aliasUrn"` + FunctionURN string `json:"functionUrn"` + FunctionVersionURN string `json:"functionVersionUrn"` + Name string `json:"name"` + FunctionVersion string `json:"functionVersion"` + RevisionID string `json:"revisionId"` + Description string `json:"description"` + RoutingType string `json:"routingType"` + RoutingRules routingRules `json:"routingRules"` + RoutingConfigs []*routingConfig `json:"routingconfig"` +} + +type routingConfig struct { + FunctionVersionURN string `json:"functionVersionUrn"` + Weight float64 `json:"weight"` +} + +func (a *AliasElement) getFuncVersionURN() string { + a.aliasLock.RLock() + defer a.aliasLock.RUnlock() + funcVersion := a.lb.Next("", true) + if funcVersion == nil { + return "" + } + res, ok := funcVersion.(string) + if !ok { + return "" + } + return res +} + +func (a *AliasElement) resetRR() { + a.aliasLock.Lock() + defer a.aliasLock.Unlock() + a.lb.RemoveAll() + for _, v := range a.RoutingConfigs { + a.lb.Add(v.FunctionVersionURN, int(v.Weight*weightRatio)) + } +} + +func (a *AliasElement) getFuncVersionURNByRule(params map[string]string) string { + a.aliasLock.RLock() + defer a.aliasLock.RUnlock() + if len(params) == 0 { + log.GetLogger().Warnf("params is empty, use default func version") + return a.FunctionVersionURN + } + if len(a.RoutingRules.Rules) == 0 { + log.GetLogger().Warnf("rule len is 0, use default func version") + return a.FunctionVersionURN + } + + matchRules, err := parseRules(a.RoutingRules) + if err != nil { + log.GetLogger().Warnf("parse rule error, use default func version: %s", err.Error()) + return a.FunctionVersionURN + } + + // To obtain the final matching result by matching each rule and considering the "AND" or "OR"relationship of the rules + matched := matchRule(params, matchRules, a.RoutingRules.RuleLogic) + // got to default version if not matched + if matched { + return a.RoutingRules.GrayVersion + } + return a.FunctionVersionURN +} + +// GetFuncVersionURNWithParams - +func (a *AliasElement) GetFuncVersionURNWithParams(params map[string]string) string { + if a.RoutingType == routingTypeRule { + return a.getFuncVersionURNByRule(params) + } + // default to go weight + return a.getFuncVersionURN() +} + +func aliasUpdate(destAlias, srcAlias *AliasElement) { + destAlias.AliasURN = srcAlias.AliasURN + destAlias.FunctionURN = srcAlias.FunctionURN + destAlias.FunctionVersionURN = srcAlias.FunctionVersionURN + destAlias.Name = srcAlias.Name + destAlias.FunctionVersion = srcAlias.FunctionVersion + destAlias.RevisionID = srcAlias.RevisionID + destAlias.Description = srcAlias.Description + destAlias.RoutingConfigs = srcAlias.RoutingConfigs + destAlias.RoutingRules = srcAlias.RoutingRules + destAlias.RoutingType = srcAlias.RoutingType +} + +func ifAliasRoutingChanged(originAlias, srcAlias *AliasElement) map[string]int { + changedURNMap := make(map[string]int) + if originAlias.RoutingType == srcAlias.RoutingType { + if originAlias.RoutingType == routingTypeRule { + if !reflect.DeepEqual(originAlias.RoutingRules, srcAlias.RoutingRules) { + changedURNMap[originAlias.RoutingRules.GrayVersion] = UpdateAllURN + } + if originAlias.FunctionVersionURN != srcAlias.FunctionVersionURN { + changedURNMap[originAlias.FunctionVersionURN] = UpdateMainURN + } + return changedURNMap + } + return isAliasWeightTypeChange(originAlias, srcAlias) + } + // routingTypeWeight change to routingTypeRule + if srcAlias.RoutingType == routingTypeRule { + return map[string]int{"": UpdateAllURN} + } + // routingTypeRule change to routingTypeWeight + for _, config := range srcAlias.RoutingConfigs { + if config.Weight <= 0 { + return map[string]int{"": UpdateAllURN} + } + } + if len(srcAlias.RoutingConfigs) == 0 { + return map[string]int{"": UpdateAllURN} + } + return map[string]int{"": NoneUpdate} +} + +// AliasKey contains the elements of an alias key +type AliasKey struct { + ProductID string + AliasSign string + BusinessSign string + BusinessID string + TenantSign string + TenantID string + FunctionSign string + FunctionID string + AliasName string +} + +// ParseFrom parses elements from an alias key +func (a *AliasKey) ParseFrom(aliasKeyStr string) error { + elements := strings.Split(aliasKeyStr, AliasKeySeparator) + urnLen := len(elements) + if urnLen != aliasKeyLength { + return fmt.Errorf("failed to parse an alias key %s, incorrect length", aliasKeyStr) + } + a.ProductID = elements[ProductIDIndex] + a.AliasSign = elements[AliasSignIndex] + a.BusinessSign = elements[BusinessSignIndex] + a.BusinessID = elements[BusinessIDIndex] + a.TenantSign = elements[TenantSignIndex] + a.TenantID = elements[TenantIDIndex] + a.FunctionSign = elements[FunctionSignIndex] + a.FunctionID = elements[FunctionIDIndex] + a.AliasName = elements[AliasNameIndex] + return nil +} + +// FetchInfoFromAliasKey collects alias information from an alias key +func FetchInfoFromAliasKey(aliasKeyStr string) *AliasKey { + var aliasKey AliasKey + if err := aliasKey.ParseFrom(aliasKeyStr); err != nil { + log.GetLogger().Errorf("error while parsing an URN: %s", err.Error()) + return &AliasKey{} + } + return &aliasKey +} + +// BuildURNFromAliasKey builds a URN from a alias key +func BuildURNFromAliasKey(aliasKeyStr string) string { + aliasKey := FetchInfoFromAliasKey(aliasKeyStr) + productURN := &urnutils.FunctionURN{ + ProductID: urnutils.DefaultURNProductID, + RegionID: urnutils.DefaultURNRegion, + BusinessID: aliasKey.BusinessID, + TenantID: aliasKey.TenantID, + TypeSign: urnutils.DefaultURNFuncSign, + FuncName: aliasKey.FunctionID, + FuncVersion: aliasKey.AliasName, + } + return productURN.String() +} + +func parseRules(routingRules routingRules) ([]Expression, error) { + rules := routingRules.Rules + var expressions []Expression + const expressionSize = 3 + for _, value := range rules { + partition := strings.Split(value, ":") + if len(partition) != expressionSize { + return nil, fmt.Errorf("rules (%s) fields size not equal %v", value, expressionSize) + } + expression := Expression{ + leftVal: partition[0], + operator: partition[1], + rightVal: partition[2], + } + expressions = append(expressions, expression) + } + return expressions, nil +} + +func matchRule(params map[string]string, expressions []Expression, ruleLogic string) bool { + var matchResultList []bool + + for _, exp := range expressions { + matchResultList = append(matchResultList, exp.Execute(params)) + } + if len(matchResultList) > 0 { + return isMatch(matchResultList, ruleLogic) + } + return false +} + +func isMatch(matchResultList []bool, ruleLogic string) bool { + matchResult := matchResultList[0] + if len(matchResultList) > 1 { + switch ruleLogic { + case "or": + for _, value := range matchResultList { + matchResult = matchResult || value + } + case "and": + for _, value := range matchResultList { + matchResult = matchResult && value + } + default: + log.GetLogger().Warnf("unknow rulelogic: %s, return false", ruleLogic) + return false + } + } + return matchResult +} + +// MarshalTenantAliasList marshal alias map to list with specific tenant id +func MarshalTenantAliasList(tenantID string) ([]byte, error) { + var aliasList []*AliasElement + GetAliases().AliasMap.Range(func(key, value interface{}) bool { + aliasElement, _ := value.(*AliasElement) + if urnutils.CheckAliasUrnTenant(tenantID, aliasElement.AliasURN) { + aliasList = append(aliasList, aliasElement) + return true + } + return true + }) + aliasData, err := json.Marshal(aliasList) + if err != nil { + return nil, errors.New("marshal alias list error") + } + return aliasData, nil +} + +// ProcessDelete - +func ProcessDelete(event *etcd3.Event) string { + aliasURN := BuildURNFromAliasKey(event.Key) + GetAliases().RemoveAlias(aliasURN) + return aliasURN +} + +// ProcessUpdate - +func ProcessUpdate(event *etcd3.Event) (string, error) { + alias := &AliasElement{} + err := json.Unmarshal(event.Value, alias) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal alias event, err: %s", err.Error()) + return "", err + } + GetAliases().AddAlias(alias) + return alias.AliasURN, nil +} + +// ResolveAliasedFunctionNameToURN - {functionName}:{alias|version} 解析别名路由 +func ResolveAliasedFunctionNameToURN(functionNameWithAlias string, tenantID string, params map[string]string) string { + splits := strings.Split(functionNameWithAlias, ":") + if len(splits) > 2 || len(splits) == 0 { // {functionName}:{alias|version} + return "" + } + + if len(splits) == 1 { + return urnutils.BuildURNOrAliasURNTemp(defaultBusinessID, tenantID, + urnutils.BuildStandardFunctionName(functionNameWithAlias), defaultVersion) + } + + functionName := urnutils.BuildStandardFunctionName(splits[0]) + versionOrAlias := splits[1] + _, err := strconv.Atoi(versionOrAlias) + if err != nil && versionOrAlias != defaultVersion { + aliasUrn := urnutils.BuildURNOrAliasURNTemp(defaultBusinessID, tenantID, functionName, versionOrAlias) + return GetAliases().GetFuncVersionURNWithParams(aliasUrn, params) + } + return urnutils.BuildURNOrAliasURNTemp(defaultBusinessID, tenantID, functionName, versionOrAlias) +} diff --git a/yuanrong/pkg/common/faas_common/aliasroute/alias_test.go b/yuanrong/pkg/common/faas_common/aliasroute/alias_test.go new file mode 100644 index 0000000..e850409 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/aliasroute/alias_test.go @@ -0,0 +1,515 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package aliasroute alias routing +package aliasroute + +import ( + "fmt" + "sync" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" +) + +const ( + aliasURN = "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasv1" +) + +// TestCase init +func GetFakeAliasEle() *AliasElement { + fakeAliasEle := &AliasElement{ + AliasURN: aliasURN, + FunctionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld", + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest", + Name: "myaliasv1", + FunctionVersion: "$latest", + RevisionID: "20210617023315921", + Description: "", + RoutingConfigs: []*routingConfig{ + { + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest", + Weight: 60, + }, + { + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:v1", + Weight: 40, + }, + }, + } + return fakeAliasEle +} + +func GetFakeRuleAliasEle() *AliasElement { + fakeAliasEle := &AliasElement{ + AliasURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasrulev1", + FunctionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld", + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest", + Name: "myaliasrulev1", + FunctionVersion: "$latest", + RevisionID: "20210617023315921", + Description: "", + RoutingType: "rule", + RoutingRules: routingRules{ + RuleLogic: "and", + Rules: []string{"userType:=:VIP", "age:<=:20", "devType:in:P40,P50,MATE40"}, + GrayVersion: "sn:cn:yrk:172120022620195843:function:0@default@test_func:3", + }, + } + return fakeAliasEle +} + +func GetFakeWeightAliasEle() *AliasElement { + fakeAliasEle := &AliasElement{ + AliasURN: "sn:cn:yrk:12345678901234561234567890123456:function:0@default@aliasfunc:myaliasrulev1", + FunctionURN: "sn:cn:yrk:12345678901234561234567890123456:function:0@default@aliasfunc", + FunctionVersionURN: "sn:cn:yrk:c53626012ba84727b938ca8bf03108ef:function:0@default@aliasfunc:latest", + Name: "myaliasrulev1", + FunctionVersion: "$latest", + RevisionID: "20210617023315921", + Description: "", + RoutingType: "weigh", + RoutingConfigs: []*routingConfig{{ + FunctionVersionURN: "sn:cn:yrk:c53626012ba84727b938ca8bf03108ef:function:0@default@aliasfunc:latest", + Weight: 80, + }, { + FunctionVersionURN: "sn:cn:yrk:c53626012ba84727b938ca8bf03108ef:function:0@default@aliasfunc:1", + Weight: 0, + }}, + } + return fakeAliasEle +} +func ClearAliasRoute() { + aliases = &Aliases{ + AliasMap: &sync.Map{}, + } +} + +func TestOptAlias(t *testing.T) { + ClearAliasRoute() + defer ClearAliasRoute() + convey.Convey("AddAlias success", t, func() { + fakeAliasEle := GetFakeAliasEle() + aliases.AddAlias(fakeAliasEle) + ele, ok := aliases.AliasMap.Load(fakeAliasEle.AliasURN) + convey.So(ok, convey.ShouldBeTrue) + convey.So(ele, convey.ShouldNotBeNil) + }) + convey.Convey("update Alias success", t, func() { + fakeAliasEle := GetFakeAliasEle() + fakeAliasEle.RoutingConfigs = []*routingConfig{ + { + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest", + Weight: 50, + }, + { + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:v1", + Weight: 50, + }, + } + aliases.AddAlias(fakeAliasEle) + ele, ok := aliases.AliasMap.Load(fakeAliasEle.AliasURN) + convey.So(ok, convey.ShouldBeTrue) + convey.So(ele.(*AliasElement).RoutingConfigs[0].Weight, convey.ShouldEqual, 50) + convey.So(ele.(*AliasElement).RoutingConfigs[1].Weight, convey.ShouldEqual, 50) + }) + convey.Convey("remove Alias success", t, func() { + fakeAliasEle := GetFakeAliasEle() + aliases.AddAlias(fakeAliasEle) + aliases.RemoveAlias(fakeAliasEle.AliasURN) + ele, ok := aliases.AliasMap.Load(fakeAliasEle.AliasURN) + convey.So(ok, convey.ShouldBeFalse) + convey.So(ele, convey.ShouldBeNil) + }) +} + +func TestGetFuncURNFromAlias(t *testing.T) { + ClearAliasRoute() + defer ClearAliasRoute() + convey.Convey("alias does not exist", t, func() { + urn := aliases.GetFuncURNFromAlias(aliasURN) + convey.So(urn, convey.ShouldEqual, aliasURN) + }) + + convey.Convey("alias get error", t, func() { + aliases.AliasMap.Store(aliasURN, "456") + urn := aliases.GetFuncURNFromAlias(aliasURN) + aliases.AliasMap.Delete(aliasURN) + convey.So(urn, convey.ShouldEqual, "") + }) + convey.Convey("alias get error", t, func() { + aliases.AddAlias(GetFakeAliasEle()) + urn := aliases.GetFuncURNFromAlias(aliasURN) + convey.So(urn, convey.ShouldNotEqual, aliasURN) + convey.So(urn, convey.ShouldNotEqual, "") + convey.So(urn, convey.ShouldNotContainSubstring, "myaliasv1") + }) + +} + +func TestFetchInfoFromAliasKey(t *testing.T) { + path := "/sn/aliases/business/yrk/tenant/12345678901234561234567890123456/function/helloworld/myalias" + aliasKey := FetchInfoFromAliasKey(path) + + assert.Equal(t, aliasKey.FunctionID, "helloworld") + assert.Equal(t, aliasKey.AliasName, "myalias") + + path = "/sn/aliases/business/yrk/tenant/12345678901234561234567890123456/function/helloworld" + aliasKey = FetchInfoFromAliasKey(path) + assert.Empty(t, aliasKey) +} + +func TestBuildURNFromAliasKey(t *testing.T) { + path := "/sn/aliases/business/yrk/tenant/12345678901234561234567890123456/function/helloworld/myalias" + urn := BuildURNFromAliasKey(path) + assert.Contains(t, urn, "myalias") +} + +func TestGetFuncVersionURNWithParamsMatch(t *testing.T) { + ClearAliasRoute() + defer ClearAliasRoute() + fakeAliasEle := GetFakeRuleAliasEle() + aliases.AddAlias(fakeAliasEle) + params := map[string]string{} + params["userType"] = "VIP" + params["age"] = "10" + params["devType"] = "P40" + + aliasUrn := "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasrulev1" + wantFuncVer := "sn:cn:yrk:172120022620195843:function:0@default@test_func:3" + got := GetAliases().GetFuncVersionURNWithParams(aliasUrn, params) + assert.Equal(t, wantFuncVer, got) +} + +func TestGetFuncVersionURNWithParamsNotMatch(t *testing.T) { + ClearAliasRoute() + defer ClearAliasRoute() + fakeAliasEle := GetFakeRuleAliasEle() + aliases.AddAlias(fakeAliasEle) + params := map[string]string{} + params["userType"] = "VIP" + params["age"] = "50" + params["devType"] = "P40" + + aliasUrn := "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasrulev1" + wantFuncVer := "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:$latest" + got := GetAliases().GetFuncVersionURNWithParams(aliasUrn, params) + assert.Equal(t, wantFuncVer, got) +} + +func TestMarshalTenantAliasList(t *testing.T) { + ClearAliasRoute() + defer ClearAliasRoute() + fakeAliasEle := GetFakeRuleAliasEle() + aliases.AddAlias(fakeAliasEle) + params := map[string]string{} + params["userType"] = "VIP" + params["age"] = "10" + params["devType"] = "P40" + + type args struct { + tenantID string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"case1", args{tenantID: "12345678901234561234567890123456"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := MarshalTenantAliasList(tt.args.tenantID) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalTenantAliasList() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestCheckUrnWithParamsMatchRules(t *testing.T) { + convey.Convey("CheckAliasRoutingChange", t, func() { + aliases.AddAlias(GetFakeRuleAliasEle()) + + aliasRuleURN := "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasrulev1" + urnWithParam_old := "sn:cn:yrk:172120022620195843:function:0@default@test_func:latest" + convey.So(aliases.CheckAliasRoutingChange(aliasRuleURN, urnWithParam_old, make(map[string]string)), + convey.ShouldEqual, true) + + convey.So(aliases.CheckAliasRoutingChange(aliasRuleURN, urnWithParam_old, make(map[string]string)), + convey.ShouldEqual, true) + + aliases.AddAlias(GetFakeWeightAliasEle()) + aliasWeight := "sn:cn:yrk:12345678901234561234567890123456:function:0@default@aliasfunc:myaliasrulev1" + aliasURN_old := "sn:cn:yrk:c53626012ba84727b938ca8bf03108ef:function:0@default@aliasfunc:1" + convey.So(aliases.CheckAliasRoutingChange(aliasWeight, aliasURN_old, make(map[string]string)), + convey.ShouldEqual, true) + + convey.So(aliases.CheckAliasRoutingChange(aliasWeight, "old alias urn needed update session", + make(map[string]string)), convey.ShouldEqual, true) + }) +} + +func TestAliasWeightLoadBalancer(t *testing.T) { + convey.Convey("AliasWeightLoadBalancer", t, func() { + fakeAliasEle := &AliasElement{ + AliasURN: aliasURN, + FunctionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld", + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:1", + Name: "myaliasv1", + FunctionVersion: "1", + RevisionID: "20210617023315921", + Description: "", + RoutingConfigs: []*routingConfig{ + { + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:2", + Weight: 80, + }, + { + FunctionVersionURN: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:1", + Weight: 20, + }, + }, + } + ClearAliasRoute() + aliases.AddAlias(fakeAliasEle) + + aliasElementIf, _ := aliases.AliasMap.Load(fakeAliasEle.AliasURN) + aliasElement := aliasElementIf.(*AliasElement) + urnMap1 := []string{} + urnMap2 := make([]string, 50) + for i := 0; i < 50; i++ { + urn := aliasElement.getFuncVersionURN() + urnMap1 = append(urnMap1, urn) + } + var count int + for index, urn := range urnMap1 { + if urn == "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:2" { + count++ + urnMap2[index] = urn + } + } + convey.So(count, convey.ShouldEqual, 40) + + for i := 0; i < 50; i++ { + if urnMap2[i] != "" { + newUrn := aliasElement.getFuncVersionURN() + if newUrn != urnMap2[i] { + fmt.Printf("index:%d oldUrn:%s, newUrn:%s \n", i, urnMap2[i], newUrn) + } + urnMap2[i] = newUrn + } + } + count = 0 + for _, urn := range urnMap2 { + if urn == "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:2" { + count++ + } + } + convey.So(count, convey.ShouldEqual, 32) + }) +} + +func Test_ifAliasRoutingChanged(t *testing.T) { + convey.Convey("ifAliasRoutingChanged", t, func() { + convey.Convey("same type weight UpdateAllURN", func() { + origin := &AliasElement{RoutingType: "weight"} + newAlias := &AliasElement{RoutingType: "weight"} + mapEvent := ifAliasRoutingChanged(origin, newAlias) + convey.So(mapEvent[""], convey.ShouldEqual, UpdateAllURN) + }) + + convey.Convey("same type weight UpdateWeightGreyURN UpdateMainURN", func() { + origin := &AliasElement{RoutingType: "weight", RoutingConfigs: []*routingConfig{ + { + FunctionVersionURN: "function/latest", + Weight: 100, + }, + }, + FunctionVersion: "0", + } + newAlias := &AliasElement{RoutingType: "weight", RoutingConfigs: []*routingConfig{ + { + FunctionVersionURN: "function/1", + Weight: 80, + }, + { + FunctionVersionURN: "function/2", + Weight: 20, + }, + { + FunctionVersionURN: "function/3", + Weight: 0, + }, + }, + FunctionVersion: "1", + } + mapEvent := ifAliasRoutingChanged(origin, newAlias) + convey.So(mapEvent["function/3"], convey.ShouldEqual, UpdateWeightGreyURN) + convey.So(mapEvent["function/latest"], convey.ShouldEqual, UpdateWeightGreyURN) + convey.So(mapEvent["0"], convey.ShouldEqual, UpdateMainURN) + }) + + convey.Convey("same type rule", func() { + origin := &AliasElement{RoutingType: routingTypeRule, RoutingRules: routingRules{ + RuleLogic: "and", + Rules: nil, + GrayVersion: "0", + }, + FunctionVersionURN: "function/0", + } + newAlias := &AliasElement{RoutingType: routingTypeRule, RoutingRules: routingRules{ + RuleLogic: "or", + Rules: nil, + GrayVersion: "1", + }, + FunctionVersionURN: "function/1", + } + mapEvent := ifAliasRoutingChanged(origin, newAlias) + convey.So(mapEvent[origin.RoutingRules.GrayVersion], convey.ShouldEqual, UpdateAllURN) + convey.So(mapEvent[origin.FunctionVersionURN], convey.ShouldEqual, UpdateMainURN) + }) + + convey.Convey("different type rule weight", func() { + newAlias := &AliasElement{RoutingType: routingTypeRule, RoutingRules: routingRules{ + RuleLogic: "and", + Rules: nil, + GrayVersion: "0", + }, + FunctionVersionURN: "function/0", + } + origin := &AliasElement{RoutingType: "weight", RoutingConfigs: []*routingConfig{ + { + FunctionVersionURN: "function/latest", + Weight: 100, + }, + }, + FunctionVersion: "0", + FunctionVersionURN: "function/1", + } + mapEvent := ifAliasRoutingChanged(origin, newAlias) + convey.So(mapEvent[""], convey.ShouldEqual, UpdateAllURN) + }) + + convey.Convey("different type weight rule", func() { + origin := &AliasElement{RoutingType: routingTypeRule, RoutingRules: routingRules{ + RuleLogic: "and", + Rules: nil, + GrayVersion: "0", + }, + FunctionVersionURN: "function/0", + } + newAlias := &AliasElement{RoutingType: "weight", RoutingConfigs: []*routingConfig{ + { + FunctionVersionURN: "function/latest", + Weight: 0, + }, + }, + FunctionVersion: "0", + FunctionVersionURN: "function/1", + } + mapEvent := ifAliasRoutingChanged(origin, newAlias) + convey.So(mapEvent[""], convey.ShouldEqual, UpdateAllURN) + + newAlias1 := &AliasElement{RoutingType: "weight", RoutingConfigs: []*routingConfig{ + { + FunctionVersionURN: "function/latest", + Weight: 0, + }, + }, + FunctionVersion: "0", + FunctionVersionURN: "function/1", + } + mapEvent1 := ifAliasRoutingChanged(origin, newAlias1) + convey.So(mapEvent1[""], convey.ShouldEqual, UpdateAllURN) + + newAlias2 := &AliasElement{RoutingType: "weight", RoutingConfigs: []*routingConfig{ + { + FunctionVersionURN: "function/latest", + Weight: 100, + }, + }, + FunctionVersion: "0", + FunctionVersionURN: "function/1", + } + mapEvent2 := ifAliasRoutingChanged(origin, newAlias2) + convey.So(mapEvent2[""], convey.ShouldEqual, NoneUpdate) + }) + }) +} + +func TestResolveAliasedFunctionNameToURN(t *testing.T) { + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(GetAliases().GetFuncVersionURNWithParams, func(aliasUrn string, params map[string]string) string { + return "resolved_" + aliasUrn + }) + + testCases := []struct { + name string + functionNameWithAlias string + tenantID string + params map[string]string + expectedURN string + }{ + { + name: "Simple function name without alias", + functionNameWithAlias: "myFunction", + tenantID: "tenant1", + params: nil, + expectedURN: "sn:cn:yrk:tenant1:function:0@default@myFunction:latest", + }, + { + name: "Function name with version number", + functionNameWithAlias: "myFunction:2", + tenantID: "tenant1", + params: nil, + expectedURN: "sn:cn:yrk:tenant1:function:0@default@myFunction:2", + }, + { + name: "Function name with alias", + functionNameWithAlias: "myFunction:prod", + tenantID: "tenant1", + params: map[string]string{"key": "value"}, + expectedURN: "sn:cn:yrk:tenant1:function:0@default@myFunction:prod", + }, + { + name: "Invalid function name (too many splits)", + functionNameWithAlias: "myFunction:prod:extra", + tenantID: "tenant1", + params: nil, + expectedURN: "", + }, + { + name: "Empty function name", + functionNameWithAlias: "", + tenantID: "tenant1", + params: nil, + expectedURN: "sn:cn:yrk:tenant1:function:0@default@:latest", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := ResolveAliasedFunctionNameToURN(tc.functionNameWithAlias, tc.tenantID, tc.params) + assert.Equal(t, tc.expectedURN, result, "URN resolution should match expected output") + }) + } +} diff --git a/yuanrong/pkg/common/faas_common/aliasroute/event.go b/yuanrong/pkg/common/faas_common/aliasroute/event.go new file mode 100644 index 0000000..e853d20 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/aliasroute/event.go @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package aliasroute event +package aliasroute + +const ( + // NoneUpdate - + NoneUpdate = iota + // UpdateAllURN event + UpdateAllURN + // UpdateMainURN alias change its main functionURN event + UpdateMainURN + // UpdateWeightGreyURN alias[type weight] change its grey functionURN event + UpdateWeightGreyURN + // Delete event + Delete +) + +// AliasEvent - +type AliasEvent struct { + Type int + AliasURN string + FunctionVersionURN string +} diff --git a/yuanrong/pkg/common/faas_common/aliasroute/expression.go b/yuanrong/pkg/common/faas_common/aliasroute/expression.go new file mode 100644 index 0000000..c30b956 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/aliasroute/expression.go @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package aliasroute alias routing in busclient +package aliasroute + +import ( + "strconv" + "strings" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +const ( + expressionSize = 3 +) + +// Expression rule expression struct +type Expression struct { + leftVal string + operator string + rightVal string +} + +func compareIntegerStrings(a, b string) (int, error) { + numA, err := strconv.Atoi(a) + if err != nil { + return 0, err + } + + numB, err := strconv.Atoi(b) + if err != nil { + return 0, err + } + + if numA < numB { + return -1, nil + } else if numA > numB { + return 1, nil + } else { + return 0, nil + } +} + +// Execute the rule expression +func (exp *Expression) Execute(params map[string]string) bool { + log.GetLogger().Debugf("params %v, exp.leftVal %v,exp.rightVal %v", params, exp.leftVal, exp.rightVal) + val, exist := params[exp.leftVal] + if !exist { + log.GetLogger().Warnf("cannot find val for %s in params", exp.leftVal) + return false + } + + switch exp.operator { + case "=": + return strings.TrimSpace(val) == strings.TrimSpace(exp.rightVal) + case "!=": + return strings.TrimSpace(val) != strings.TrimSpace(exp.rightVal) + case ">": + ret, err := compareIntegerStrings(val, exp.rightVal) + return err == nil && ret == 1 + case "<": + ret, err := compareIntegerStrings(val, exp.rightVal) + return err == nil && ret == -1 + case ">=": + ret, err := compareIntegerStrings(val, exp.rightVal) + return err == nil && (ret == 1 || ret == 0) + case "<=": + ret, err := compareIntegerStrings(val, exp.rightVal) + return err == nil && (ret == -1 || ret == 0) + case "in": + return matchStr(val, exp.rightVal) + default: + log.GetLogger().Warnf("unknown operator(%s), return false", val, exp.operator) + return false + } +} + +func matchStr(str string, targetStr string) bool { + tars := strings.Split(targetStr, ",") + for _, tar := range tars { + // The rvalue of the 'in' operator ignores "" + if tar != "" && strings.TrimSpace(str) == strings.TrimSpace(tar) { + return true + } + } + return false +} diff --git a/yuanrong/pkg/common/faas_common/aliasroute/expression_test.go b/yuanrong/pkg/common/faas_common/aliasroute/expression_test.go new file mode 100644 index 0000000..df57ccb --- /dev/null +++ b/yuanrong/pkg/common/faas_common/aliasroute/expression_test.go @@ -0,0 +1,172 @@ +package aliasroute + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +type ExpressionTestSuite struct { + alias AliasElement +} + +func (suite *ExpressionTestSuite) SetupTest() { + +} + +func (suite *ExpressionTestSuite) TearDownTest() { + +} + +func (suite *ExpressionTestSuite) TestEquel() { + +} + +func genExpression(str string) (Expression, error) { + partition := strings.Split(str, ":") + if len(partition) != expressionSize { + return Expression{}, fmt.Errorf("express(#{str}) string format is error") + } + return Expression{ + leftVal: partition[0], + operator: partition[1], + rightVal: partition[2], + }, nil +} + +func ExecuteExp(t *testing.T, expStr string, params map[string]string) bool { + exp, err := genExpression(expStr) + if err != nil { + t.Error("gen expression fail: ", expStr) + return false + } + return exp.Execute(params) +} + +func TestExpEq(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + + got := ExecuteExp(t, "id:=:123", params) + assert.True(t, got) + + got = ExecuteExp(t, "id:=:444", params) + assert.False(t, got) +} + +func TestExpNotEq(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + + got := ExecuteExp(t, "id:!=:200", params) + assert.True(t, got) + + got = ExecuteExp(t, "id:!=:123", params) + assert.False(t, got) +} + +func TestExpLt(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + params["type"] = "p40" + + got := ExecuteExp(t, "id:<:200", params) + assert.True(t, got) + + got = ExecuteExp(t, "id:<:100", params) + assert.False(t, got) + + got = ExecuteExp(t, "type:<:100", params) + assert.False(t, got) +} + +func TestExpLtEq(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + params["type"] = "p40" + + got := ExecuteExp(t, "id:<=:200", params) + assert.True(t, got) + + got = ExecuteExp(t, "id:<=:100", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:<=:123", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:<=:100", params) + assert.False(t, got) +} + +func TestExpGt(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + params["type"] = "p40" + + got := ExecuteExp(t, "id:>:200", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:>:100", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:>:100", params) + assert.False(t, got) +} + +func TestExpGtEq(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + params["type"] = "p40" + + got := ExecuteExp(t, "id:>=:200", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:>=:100", params) + assert.True(t, got) + + got = ExecuteExp(t, "id:>=:123", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:>=:1", params) + assert.False(t, got) +} + +func TestExpIn(t *testing.T) { + params := map[string]string{} + params["type"] = "p40" + + got := ExecuteExp(t, "type:in:p40,mate40", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:in:mate40, p40", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:in:mate40, p40 , p30", params) + assert.True(t, got) + + got = ExecuteExp(t, "type:in:mate40,p30", params) + assert.False(t, got) + + got = ExecuteExp(t, "type:in:", params) + assert.False(t, got) +} + +func TestExpExcept(t *testing.T) { + params := map[string]string{} + params["id"] = "123" + params["type"] = "p40" + + got := ExecuteExp(t, "age:<:30", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:<:", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:<:abc", params) + assert.False(t, got) + + got = ExecuteExp(t, "id:||:123", params) + assert.False(t, got) +} diff --git a/yuanrong/pkg/common/faas_common/autogc/algorithm.go b/yuanrong/pkg/common/faas_common/autogc/algorithm.go new file mode 100644 index 0000000..abfa6f0 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/autogc/algorithm.go @@ -0,0 +1,60 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package autogc + +// Algorithm is algorithm for adjusting GOGC +type Algorithm interface { + Init(totalMemory, threshold uint64) + NextGOGC(currentMemory uint64, preGOGC int) int +} + +const ( + defaultMaxGOGC = 500 + defaultMinGCStep = 50 * MB + percent = 100 +) + +func min(x, y int) int { + if x < y { + return x + } + return y +} + +// DefaultAlg defines default algorithm to adjust GOGC +// when current memory <= threshold, it will adjust GOGC to match threshold, but not above defaultMaxGOGC (500) +// when current memory > threshold, it will adjust GOGC so that GC will trigger every defaultMinGCStep (50MB) heap alloc +type DefaultAlg struct { + total uint64 + threshold uint64 + maxGOGC int +} + +// Init initializes alg with total memory and memory threshold +func (da *DefaultAlg) Init(total, threshold uint64) { + da.total = total + da.threshold = threshold + da.maxGOGC = defaultMaxGOGC +} + +// NextGOGC calculates appropriated GOGC with current memory and previous GOGC +func (da *DefaultAlg) NextGOGC(currentMemory uint64, preGOGC int) int { + if da.threshold >= currentMemory+defaultMinGCStep { + return min(da.maxGOGC, int(percent*(float64(da.threshold)/float64(currentMemory)-1.0))) + } + return int(percent * defaultMinGCStep / currentMemory) +} diff --git a/yuanrong/pkg/common/faas_common/autogc/algorithm_test.go b/yuanrong/pkg/common/faas_common/autogc/algorithm_test.go new file mode 100644 index 0000000..efe5b25 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/autogc/algorithm_test.go @@ -0,0 +1,60 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package autogc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultAlg(t *testing.T) { + assert.Equal(t, 4*GB, 4294967296) + + alg := DefaultAlg{} + alg.Init(4*GB, 3200*MB) + + tests := []struct { + current uint64 + excepted int + }{ + { + current: 40 * MB, + excepted: defaultMaxGOGC, + }, + { + current: 3200 * MB, + excepted: 1, + }, + { + current: 3201 * MB, + excepted: 1, + }, + { + current: 3100 * MB, + excepted: 3, + }, + { + current: 2000 * MB, + excepted: 60, + }, + } + + for _, test := range tests { + assert.Equal(t, test.excepted, alg.NextGOGC(test.current, 0)) + } +} diff --git a/yuanrong/pkg/common/faas_common/autogc/autogc.go b/yuanrong/pkg/common/faas_common/autogc/autogc.go new file mode 100644 index 0000000..a16b8cb --- /dev/null +++ b/yuanrong/pkg/common/faas_common/autogc/autogc.go @@ -0,0 +1,108 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package autogc adjusts GOGC automatically inspired by +// https://eng.uber.com/how-we-saved-70k-cores-across-30-mission-critical-services/ +package autogc + +import ( + "os" + "runtime" + "runtime/debug" + "strconv" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +var ( + gcChannel = make(chan struct{}, 1) + gcAlg Algorithm + previousGOGC = 100 +) + +const ( + defaultMemoryThreshold = 80 +) + +// InitAutoGOGC starts to adjust GOGC automatically +func InitAutoGOGC() { + currentThreshold, err := strconv.Atoi(os.Getenv("AUTO_GC_MEMORY_THRESHOLD")) + if err != nil { + currentThreshold = defaultMemoryThreshold + log.GetLogger().Warnf("failed to get AUTO_GC_MEMORY_THRESHOLD, use default threshold, %s", err.Error()) + } else if currentThreshold <= 0 || currentThreshold > percent { + currentThreshold = defaultMemoryThreshold + } + log.GetLogger().Infof("current auto gc memory threshold: %d", currentThreshold) + limit, err := parseCGroupMemoryLimit() + if err != nil { + log.GetLogger().Errorf("failed to read cgroup memory limit, err: %s", err.Error()) + return + } + log.GetLogger().Infof("cgroup memory limit is %d, memory %d", limit, uint64(currentThreshold)*limit/percent) + + gcAlg = &DefaultAlg{} + if percent == 0 { + return + } + gcAlg.Init(limit, uint64(currentThreshold)*limit/percent) + + newCycleRefObj() + + go runAutoGOGC() +} + +func runAutoGOGC() { + file, err := os.Open(memPath) + if err != nil { + log.GetLogger().Errorf("failed to open statm file") + return + } + defer file.Close() + buffer := make([]byte, KB) + for range gcChannel { + rss, err := parseRSS(file, buffer) + if err != nil { + log.GetLogger().Errorf("failed to parse RSS, err: %s", err.Error()) + return + } + previousGOGC = debug.SetGCPercent(gcAlg.NextGOGC(rss, previousGOGC)) + } +} + +type finalizer struct { + ref *finalizerRef +} + +type finalizerRef struct { + parent *finalizer +} + +func finalizerHandler(f *finalizerRef) { + select { + case gcChannel <- struct{}{}: + default: + } + runtime.SetFinalizer(f, finalizerHandler) +} + +func newCycleRefObj() *finalizer { + f := &finalizer{} + f.ref = &finalizerRef{parent: f} + runtime.SetFinalizer(f.ref, finalizerHandler) + f.ref = nil + return f +} diff --git a/yuanrong/pkg/common/faas_common/autogc/autogc_test.go b/yuanrong/pkg/common/faas_common/autogc/autogc_test.go new file mode 100644 index 0000000..808b676 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/autogc/autogc_test.go @@ -0,0 +1,49 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package autogc + +import ( + "os" + "runtime" + "runtime/debug" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/utils" +) + +func TestInitAutoGOGC(t *testing.T) { + InitAutoGOGC() + runtime.GC() + assert.Equal(t, 100, previousGOGC) +} + +func TestInitAutoGOGC2(t *testing.T) { + patches := utils.InitPatchSlice() + patches.Append(utils.PatchSlice{ + gomonkey.ApplyFunc(debug.SetGCPercent, + func(percent int) int { + return 100 + })}) + defer patches.ResetAll() + os.Setenv("AUTO_GC_MEMORY_THRESHOLD", "120") + InitAutoGOGC() + runtime.GC() + assert.Equal(t, 100, previousGOGC) +} diff --git a/yuanrong/pkg/common/faas_common/autogc/util.go b/yuanrong/pkg/common/faas_common/autogc/util.go new file mode 100644 index 0000000..cc100a9 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/autogc/util.go @@ -0,0 +1,75 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package autogc + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "strconv" + "strings" +) + +const ( + cgroupMemLimitPath = "/sys/fs/cgroup/memory/memory.limit_in_bytes" + rssValueFieldIndex = 1 + base = 10 + bitSize = 64 +) + +// constants of memory unit +const ( + B = 1 << (10 * iota) + KB + MB + GB +) + +var ( + pageSize = uint64(os.Getpagesize()) + memPath = fmt.Sprintf("/proc/%d/statm", os.Getpid()) +) + +func parseCGroupMemoryLimit() (uint64, error) { + v, err := ioutil.ReadFile(cgroupMemLimitPath) + if err != nil { + return 0, err + } + return strconv.ParseUint(strings.TrimSpace(string(v)), base, bitSize) +} + +func parseRSS(f io.ReadSeeker, buffer []byte) (uint64, error) { + _, err := f.Seek(0, io.SeekStart) + if err != nil { + return 0, err + } + _, err = f.Read(buffer) + if err != nil && err != io.EOF { + return 0, err + } + fields := strings.Split(string(buffer), " ") + if len(fields) < (rssValueFieldIndex + 1) { + return 0, errors.New("invalid statm fields") + } + rss, err := strconv.ParseUint(fields[rssValueFieldIndex], base, bitSize) + if err != nil { + return 0, err + } + return rss * pageSize, nil +} diff --git a/yuanrong/pkg/common/faas_common/autogc/util_test.go b/yuanrong/pkg/common/faas_common/autogc/util_test.go new file mode 100644 index 0000000..d8d52b7 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/autogc/util_test.go @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package autogc + +import ( + "bytes" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseRSS(t *testing.T) { + r := bytes.NewReader([]byte("367597 12113 4058 3810 0 47257 0\n")) + buffer := make([]byte, KB) + + for i := 0; i < 10; i++ { + rss, err := parseRSS(r, buffer) + assert.Nil(t, err, "parseRSS should return no error") + assert.Equal(t, uint64(12113*os.Getpagesize()), rss) + } + + r = bytes.NewReader([]byte("123")) + _, err := parseRSS(r, make([]byte, KB)) + assert.Error(t, err, "parseRSS should failed") + + r = bytes.NewReader([]byte("123 abcde 132")) + _, err = parseRSS(r, make([]byte, KB)) + assert.Error(t, err, "parseRSS should failed") +} diff --git a/yuanrong/pkg/common/faas_common/config/config.go b/yuanrong/pkg/common/faas_common/config/config.go new file mode 100644 index 0000000..b7269b0 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/config/config.go @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config - +package config + +// TLSConfig certificate configuration +type TLSConfig struct { + CaContent string `json:"caContent"` + KeyContent string `json:"keyContent"` + CertContent string `json:"certContent"` +} diff --git a/yuanrong/pkg/common/faas_common/constant/app.go b/yuanrong/pkg/common/faas_common/constant/app.go new file mode 100644 index 0000000..152dfe9 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/constant/app.go @@ -0,0 +1,72 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constant - +package constant + +// used for app-job +const ( + // UserMetadataKey key used for the app createOpts + UserMetadataKey = "USER_PROVIDED_METADATA" + // EntryPointKey entrypoint for starting app + EntryPointKey = "ENTRYPOINT" + // AppFN - + FunctionNameApp = "app" + // AppFuncId - + AppFuncId = "12345678901234561234567890123456/0-system-faasExecutorPosixCustom/$latest" + // AppType type for invoking create-app + AppType = "SUBMISSION" + // AppStatusPending - + AppStatusPending = "PENDING" + // AppStatusRunning - + AppStatusRunning = "RUNNING" + // AppStatusSucceeded - + AppStatusSucceeded = "SUCCEEDED" + // AppStatusFailed - + AppStatusFailed = "FAILED" + // AppStatusStopped - + AppStatusStopped = "STOPPED" + + // AppInvokeTimeout 30min + AppInvokeTimeout = 1800 +) + +// AppInfo - Ray job JobDetails +type AppInfo struct { + Key string `json:"key"` + // Enum: "SUBMISSION" "DRIVER" + Type string `json:"type"` + Entrypoint string `json:"entrypoint"` + SubmissionID string `json:"submission_id"` + DriverInfo DriverInfo `json:"driver_info" valid:",optional"` + // Status Enum: "PENDING" "RUNNING" "STOPPED" "SUCCEEDED" "FAILED" + Status string `json:"status" valid:",optional"` + StartTime string `json:"start_time" valid:",optional"` + EndTime string `json:"end_time" valid:",optional"` + Metadata map[string]string `json:"metadata" valid:",optional"` + RuntimeEnv map[string]interface{} `json:"runtime_env" valid:",optional"` + DriverAgentHttpAddress string `json:"driver_agent_http_address" valid:",optional"` + DriverNodeID string `json:"driver_node_id" valid:",optional"` + DriverExitCode int32 `json:"driver_exit_code" valid:",optional"` + ErrorType string `json:"error_type" valid:",optional"` +} + +// DriverInfo - +type DriverInfo struct { + ID string `json:"id"` + NodeIPAddress string `json:"node_ip_address"` + PID string `json:"pid"` +} diff --git a/yuanrong/pkg/common/faas_common/constant/constant.go b/yuanrong/pkg/common/faas_common/constant/constant.go new file mode 100644 index 0000000..cb0f733 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/constant/constant.go @@ -0,0 +1,524 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constant - +package constant + +import "time" + +const ( + // LibruntimeHeaderSize is the header length of libruntime package + LibruntimeHeaderSize = 16 +) + +const ( + // BusinessTypeFG is the business type of FunctionGraph + BusinessTypeFG = iota + // BusinessTypeWiseCloud is the business type of WiseCloud + BusinessTypeWiseCloud +) + +const ( + // BackendTypeKernel - + BackendTypeKernel = iota + // BackendTypeFG - + BackendTypeFG +) + +const ( + // DeployModeContainer - + DeployModeContainer = "Container" + // DeployModeProcesses - + DeployModeProcesses = "Processes" +) + +const ( + // HeaderRequestID - + HeaderRequestID = "X-Request-Id" + // HeaderTraceID - + HeaderTraceID = "X-Trace-Id" + // HeaderTraceParent - + HeaderTraceParent = "Traceparent" +) + +const ( + // KernelResourceNotEnoughErrCode is the error code of kernel resource not enough + KernelResourceNotEnoughErrCode = 1002 + // KernelInnerSystemErrCode is the error code of kernel inner system error + KernelInnerSystemErrCode = 3003 + // KernelRequestErrBetweenRuntimeAndBus is the error code of bus communicate with runtime + KernelRequestErrBetweenRuntimeAndBus = 3001 + // KernelUserCodeLoadErrCode is the error code if kernel user code load error + KernelUserCodeLoadErrCode = 2001 + // KernelUserFunctionExceptionErrCode is the error code of kernel when user function exception + KernelUserFunctionExceptionErrCode = 2002 + // KernelCreateLimitErrCode is the error code of kernel when create limited + KernelCreateLimitErrCode = 1012 + // KernelWriteEtcdCircuitErrCode is the error code of kernel when write etcd failed or circuit + KernelWriteEtcdCircuitErrCode = 3005 + // KernelDataSystemUnavailable is the error code of kernel when data system is unavailable + KernelDataSystemUnavailable = 3015 + // KernelNPUFAULTErrCode is the error code of kernel when user exit with npu card is fault + KernelNPUFAULTErrCode = 3016 +) + +const ( + // InsReqSuccessCode is the return code when instance request succeeds + InsReqSuccessCode = 6030 + // InsReqSuccessMessage is the return message when instance request succeeds + InsReqSuccessMessage = "instance request successfully" + // UnsupportedOperationErrorCode is the return code when operation is not supported + UnsupportedOperationErrorCode = 6031 + // UnsupportedOperationErrorMessage is the return message when operation is not supported + UnsupportedOperationErrorMessage = "operation not supported" + // FuncNotExistErrorCode is the return code when function does not exist + FuncNotExistErrorCode = 6032 + // FuncNotExistErrorMessage is the return message when function does not exist + FuncNotExistErrorMessage = "function not exist" + // InsNotExistErrorCode is the return code when instance does not exist + InsNotExistErrorCode = 6033 + // InsNotExistErrorMessage is the return message when instance does not exist + InsNotExistErrorMessage = "instance not exist" + // InsAcquireFailedErrorCode is the return code when acquire instance fails + InsAcquireFailedErrorCode = 6034 + // InsAcquireLeaseExistErrorCode - is the return code when acquire repeated lease + InsAcquireLeaseExistErrorCode = 6035 + // InsAcquireFailedErrorMessage is the return message when acquire instance fails + InsAcquireFailedErrorMessage = "failed to acquire instance" + // LeaseExpireOrDeletedErrorCode is the return code when lease expires or be deleted + LeaseExpireOrDeletedErrorCode = 6036 + // LeaseExpireOrDeletedErrorMessage is the return message when lease expires or be deleted + LeaseExpireOrDeletedErrorMessage = "lease expires or deleted" + // AcquireLeaseTrafficLimitErrorCode - + AcquireLeaseTrafficLimitErrorCode = 6037 + // AcquireLeaseTrafficLimitErrorMessage is reach max limit of acquiring lease concurrently + AcquireLeaseTrafficLimitErrorMessage = "reach max limit of acquiring lease concurrently" + // LeaseErrorInstanceIsAbnormalMessage - lease op failed, instance is abnormal + LeaseErrorInstanceIsAbnormalMessage = "lease op failed, instance is abnormal" + // InsAcquireTimeOutErrorCode is the return code when acquire instance timout + InsAcquireTimeOutErrorCode = 6038 + // AcquireLeaseVPCConflictErrorCode The called function instance has a VPC conflict + AcquireLeaseVPCConflictErrorCode = 6039 + // InstancesConfigEtcdPrefix - + InstancesConfigEtcdPrefix = "/instances" + // InstancePathPrefix is the etcd path where the instance info will be placed + InstancePathPrefix = "/sn/instance" + // ModuleSchedulerPrefix is the etcd path where the module scheduler info will be placed + ModuleSchedulerPrefix = "/sn/faas-scheduler/instances" + // SchedulerRolloutPrefix - + SchedulerRolloutPrefix = "/sn/faas-scheduler/rollout" + // RolloutConfigPrefix - + RolloutConfigPrefix = "/sn/faas-scheduler/rolloutConfig" + // HTTPTriggerPrefix - + HTTPTriggerPrefix = "/sn/triggers/triggerType/HTTP/business/" + // FunctionPrefix - + FunctionPrefix = "/sn/functions" + // AliasPrefix - + AliasPrefix = "/sn/aliases" + // LeasePrefix - + LeasePrefix = "/sn/lease" + // FunctionAvailClusterPrefix Used to identify whether the called function vpc conflicts with the cluster network + FunctionAvailClusterPrefix = "/sn/function/available/clusters/" + // FrontendInstancePrefix frontend instance information recorded in meta etcd + FrontendInstancePrefix = "/sn/frontend/instances" + // TenantQuotaPrefix define the key prefix of etcd for tenant metadata + TenantQuotaPrefix = "/sn/quota/cluster" + + // ETCDEventKeySeparator is the separator of ETCD event key + ETCDEventKeySeparator = "/" + + // DefaultMaxRequestBodySize frontend maximum request body size + DefaultMaxRequestBodySize = 100 * 1024 * 1024 + + // DefaultMapSize default map size + DefaultMapSize = 3 + // DefaultHostAliasesSliceSize default host aliases slice size + DefaultHostAliasesSliceSize = 4 + // MinCustomResourcesSize is min custom resource size of invoke + MinCustomResourcesSize = 0 + + // SchedulerExclusivityKey is the key for tenant exclusivity scheduler + SchedulerExclusivityKey = "exclusivity" + // SchedulerRecoverTime - + SchedulerRecoverTime = 30 * time.Second + // DefaultServerWriteTimeOut 1300s + DefaultServerWriteTimeOut = 1300 * time.Second + // SchedulerKeyTypeFunction - + SchedulerKeyTypeFunction = "function" + // SchedulerKeyTypeModule - + SchedulerKeyTypeModule = "module" + // StaticInstanceApplier mark the instance is created by static function + StaticInstanceApplier = "static_function" +) + +const ( + // KeySeparator is the separator in an ETCD key + KeySeparator = "/" + // ValidEtcdKeyLenForInstance is the valid len of an instance ETCD key + ValidEtcdKeyLenForInstance = 14 + // SysFunctionTenantID is the tenantID of a system function + SysFunctionTenantID = "0" + // FaasFrontendMark is a part of the function name of a faasfrontend system function + FaasFrontendMark = "system-faasfrontend" + // FaasSchedulerMark is a part of the function name of a faasscheduler system function + FaasSchedulerMark = "system-faasscheduler" + // FunctionsIndexForInstance is the functions index of an valid instance ETCD key + FunctionsIndexForInstance = 2 + // TenantIndexForInstance is the tenant index of an valid instance ETCD key + TenantIndexForInstance = 5 + // TenantIDIndexForInstance is the tenantID index of an valid instance ETCD key + TenantIDIndexForInstance = 6 + // FunctionIndexForInstance is the functon index of an valid instance ETCD key + FunctionIndexForInstance = 7 + // FunctionNameIndexForInstance is the functon name index of an valid instance ETCD key + FunctionNameIndexForInstance = 8 + // InstanceIDIndexForInstance is the instanceID index of an valid instance ETCD key + InstanceIDIndexForInstance = 13 + // FaasSchedulerName is function name of a faasscheduler system function + FaasSchedulerName = "0-system-faasscheduler" +) + +// InstanceStatus is stauts of instance_status object +type InstanceStatus int + +const ( + // KernelInstanceStatusExited instance is exited + KernelInstanceStatusExited InstanceStatus = -1 + // KernelInstanceStatusNew instance is not created + KernelInstanceStatusNew InstanceStatus = 0 + // KernelInstanceStatusScheduling instance is scheduling + KernelInstanceStatusScheduling InstanceStatus = 1 + // KernelInstanceStatusCreating instance is creating + KernelInstanceStatusCreating InstanceStatus = 2 + // KernelInstanceStatusRunning instance is running + KernelInstanceStatusRunning InstanceStatus = 3 + // KernelInstanceStatusFailed instance is failed + KernelInstanceStatusFailed InstanceStatus = 4 + // KernelInstanceStatusExiting instance is exiting + KernelInstanceStatusExiting InstanceStatus = 5 + // KernelInstanceStatusFatal instance abnormal exits + KernelInstanceStatusFatal InstanceStatus = 6 + // KernelInstanceStatusScheduleFailed instance is schedule failed + KernelInstanceStatusScheduleFailed InstanceStatus = 7 + // KernelInstanceStatusEvicting instance is evicting + KernelInstanceStatusEvicting InstanceStatus = 9 + // KernelInstanceStatusEvicted instance is evicted + KernelInstanceStatusEvicted InstanceStatus = 10 + // KernelInstanceStatusSubHealth instance is sub health + KernelInstanceStatusSubHealth InstanceStatus = 11 +) + +// InstanceStatusType is EXIT_TYPE of instance_status object +type InstanceStatusType int + +const ( + // KernelInstanceStatusTypeNoneExit - + KernelInstanceStatusTypeNoneExit InstanceStatusType = 0 + // KernelInstanceStatusTypeReturn - + KernelInstanceStatusTypeReturn InstanceStatusType = 1 + // KernelInstanceStatusTypeExceptionInfo - + KernelInstanceStatusTypeExceptionInfo InstanceStatusType = 2 + // KernelInstanceStatusTypeOomInfo - + KernelInstanceStatusTypeOomInfo InstanceStatusType = 3 + // KernelInstanceStatusTypeStandardInfo - + KernelInstanceStatusTypeStandardInfo InstanceStatusType = 4 + // KernelInstanceStatusTypeUnknownError - + KernelInstanceStatusTypeUnknownError InstanceStatusType = 5 + // KernelInstanceStatusTypeUserKillInfo - + KernelInstanceStatusTypeUserKillInfo InstanceStatusType = 6 +) + +const ( + // RuntimeTypeCpp - + RuntimeTypeCpp = "cpp" + // RuntimeTypeCppBin - + RuntimeTypeCppBin = "cppbin" + // RuntimeTypeJava - + RuntimeTypeJava = "java" + // RuntimeTypeNodejs - + RuntimeTypeNodejs = "nodejs" + // RuntimeTypePython - + RuntimeTypePython = "python" + // RuntimeTypeCustom - + RuntimeTypeCustom = "custom" + // RuntimeTypeFusion - + RuntimeTypeFusion = "fusion" + // RuntimeTypeHTTP - + RuntimeTypeHTTP = "http" +) + +const ( + // ExtendedCallHandler used as kernel metadata extendedMetaData.extended_handler.handler field + ExtendedCallHandler = "handler" + // ExtendedInitHandler used as kernel metadata extendedMetaData.extended_handler.initializer field + ExtendedInitHandler = "initializer" + // CallHandler - + CallHandler = "call" + // InitHandler - + InitHandler = "init" + // CheckPointHandler - + CheckPointHandler = "checkpoint" + // RecoverHandler - + RecoverHandler = "recover" + // ShutdownHandler - + ShutdownHandler = "shutdown" + // SignalHandler - + SignalHandler = "signal" +) + +const ( + // PythonCallExecutor - + PythonCallExecutor = "faas_executor.faasCallHandler" + // PythonInitExecutor - + PythonInitExecutor = "faas_executor.faasInitHandler" + // PythonCheckPointExecutor - + PythonCheckPointExecutor = "faas_executor.faasCheckPointHandler" + // PythonRecoverExecutor - + PythonRecoverExecutor = "faas_executor.faasRecoverHandler" + // PythonShutDownExecutor - + PythonShutDownExecutor = "faas_executor.faasShutDownHandler" + // PythonSignalExecutor - + PythonSignalExecutor = "faas_executor.faasSignalHandler" +) + +const ( + // MaxTraceIDLength is the max length of traceID + MaxTraceIDLength = 128 +) + +const ( + // DefaultListenIP - + DefaultListenIP = "127.0.0.1" + // BusProxyHTTPPort - + BusProxyHTTPPort = "22423" +) + +const ( + // TraceIDRuntimeCallCtx Key value of the traceID parameter in the context input parameter of CallHandler + TraceIDRuntimeCallCtx = "traceID" +) + +const ( + // DefaultURNVersion is the default version of a URN + DefaultURNVersion = "latest" + // DefaultNameSpace is the default namespace + DefaultNameSpace = "default" +) + +const ( + // ClusterNameEnvKey defines env key for cluster name + ClusterNameEnvKey = "CLUSTER_NAME" + // PodIPEnvKey define pod ip env key + PodIPEnvKey = "POD_IP" + // HostNameEnvKey defines the hostname env key + HostNameEnvKey = "HOSTNAME" + // HostIPEnvKey defines the host ip env key + HostIPEnvKey = "HOST_IP" + // ResourceCPUName - + ResourceCPUName = "CPU" + // ResourceMemoryName - + ResourceMemoryName = "Memory" + // ResourceEphemeralStorage - + ResourceEphemeralStorage = "ephemeral-storage" + // CustomContainerRuntimeType is the runtime type for http function + CustomContainerRuntimeType = "custom image" + // CustomImageExtraTimeout is the timeout to offset non-pool start of custom image + CustomImageExtraTimeout = 300 + // PosixCustomRuntimeType is the runtime type for posix custom + PosixCustomRuntimeType = "posix-custom-runtime" + // CommonExtraTimeout - + CommonExtraTimeout = 2 + // TrafficRedundantRate limit redundancy rate for traffic limitation + TrafficRedundantRate = 1.1 + // SystemExtraTimeout - + SystemExtraTimeout = 5 + // KernelScheduleTimeout is the timeout set in kernel to avoid instance schedule timeout + KernelScheduleTimeout = 5 + // ModuleScheduler - + ModuleScheduler = "ModuleScheduler" + // AffinityPoolIDKey - + AffinityPoolIDKey = "AFFINITY_POOL_ID" + // UnUseAntiOtherLabelsKey - + UnUseAntiOtherLabelsKey = "unUseAntiOtherLabels" + + // BusinessTypeServe - + BusinessTypeServe = "serve" + // URLSeparator is the separator of http url + URLSeparator = "/" + // ApplicationIndex - + ApplicationIndex = 0 + + // EnableStream 用于识别黄区测试桩环境frontend是否开启流式场景监听本地数据系统状态 + EnableStream = "ENABLE_STREAM" +) + +const ( + // TrueStr - + TrueStr = "true" +) + +const ( + // HeaderInvokeURN - + HeaderInvokeURN = "X-Tag-VersionUrn" + // HeaderStateKey - + HeaderStateKey = "X-State-Key" + // HeaderNodeLabel is node label + HeaderNodeLabel = "X-Node-Label" + // HeaderCPUSize is cpu size specified by invoke + HeaderCPUSize = "X-Instance-Cpu" + // HeaderMemorySize is cpu memory specified by invoke + HeaderMemorySize = "X-Instance-Memory" + // HeaderCustomResource is customResource specified by invoke + HeaderCustomResource = "X-Instance-CustomResource" + // HeaderCustomResourceNew is customResource specified by invoke + HeaderCustomResourceNew = "X-Instance-Custom-Resource" + // HeaderContentType - + HeaderContentType = "Content-Type" + // HeaderContentLength - + HeaderContentLength = "Content-Length" + // HeaderBillingDuration - + HeaderBillingDuration = "X-Billing-Duration" + // HeaderInnerCode - + HeaderInnerCode = "X-Inner-Code" + // HeaderInvokeSummary - + HeaderInvokeSummary = "X-Invoke-Summary" + // HeaderLogResult - + HeaderLogResult = "X-Log-Result" + // HeaderLogType - + HeaderLogType = "X-Log-Type" + // DefaultLogFlag is the default flag for log + DefaultLogFlag = "None" + // HeaderAuthTimestamp is the timestamp for authorization + HeaderAuthTimestamp = "X-Timestamp-Auth" + // HeaderAuthorization is authorization + HeaderAuthorization = "Authorization" + // HeaderInvokeAlias indicates alias of current invocation + HeaderInvokeAlias = "x-invoke-alias" + // HeaderRetryFlag - + HeaderRetryFlag = "X-Retry-Flag" + // HeaderInstanceID - + HeaderInstanceID = "X-Instance-Id" + // HeaderInstanceIP - + HeaderInstanceIP = "X-Instance-Ip" + // HeaderWorkerCost - + HeaderWorkerCost = "X-Worker-Cost" + // HeaderCallInstance - + HeaderCallInstance = "X-Call-Instance" + // HeaderCallNode - + HeaderCallNode = "X-Call-Node" + + // HeaderEventSourceID - + HeaderEventSourceID = "X-Event-Source-Id" + // HeaderCallType is the request type + HeaderCallType = "X-Call-Type" + // HeaderForceDeploy is Force Deploy + HeaderForceDeploy = "X-Force-Deploy" + // HeaderStreamAPIGEvent - + HeaderStreamAPIGEvent = "X-Stream-Apig-Event" + // HeaderRequestStreamName - + HeaderRequestStreamName = "X-Request-Stream-Name" + // HeaderResponseStreamName - + HeaderResponseStreamName = "X-Response-Stream-Name" + // HeaderFrontendResponseStreamName - + HeaderFrontendResponseStreamName = "X-Frontend-Response-Stream-Name" + // HeaderRemoteClientId - + HeaderRemoteClientId = "X-Remote-Client-Id" +) + +const ( + // NewLease for add a lease of client + NewLease = "NewLease" + // KeepAlive for keep client alive + KeepAlive = "KeepAlive" + // DelLease for del a lease of client + DelLease = "DelLease" +) +const ( + // MetaFuncKey key used to match functions within ETCD + MetaFuncKey = "/sn/functions/business/yrk/tenant/%s/function/%s/version/%s" + // SilentFuncKey key used to match silent functions within ETCD + SilentFuncKey = "/silent/sn/functions/business/yrk/tenant/%s/function/%s/version/%s" +) + +const ( + // RuntimeInstanceName is instance name specified by user + RuntimeInstanceName = "instanceName" + // InstanceCreateEvent key of instance create event + InstanceCreateEvent = "instanceCreateEvent" + // InstanceRequirementResourcesKey key of FunctionSystemClient.Invoke args[1] + InstanceRequirementResourcesKey = "resourcesData" + // InstanceRequirementInsIDKey key of FunctionSystemClient.Invoke args[1] + InstanceRequirementInsIDKey = "designateInstanceID" + // InstanceCallerPodName name of Instance Caller.Invoke args[1] + InstanceCallerPodName = "instanceCallerPodName" + // InstanceTrafficLimited - name of instance traffic limit key args[1] + InstanceTrafficLimited = "instanceTrafficLimited" + // InstanceRequirementPoolLabel - key of poolLabel + InstanceRequirementPoolLabel = "poolLabel" + // InstanceSessionConfig is the key of instance session config in instance acquiring + InstanceSessionConfig = "instanceSessionConfig" + // InstanceRequirementInvokeLabel - name of instance label args[1] + InstanceRequirementInvokeLabel = "instanceInvokeLabel" +) + +const ( + // HeaderTenantID - + HeaderTenantID = "X-Tenant-Id" + // HeaderFunctionName - + HeaderFunctionName = "X-Function-Name" + // HeaderDataSystemPayloadInfo - + HeaderDataSystemPayloadInfo = "X-Data-System-Payload-Info" + // HeaderClientID - + HeaderClientID = "X-Client-Id" + // HeaderTargetServiceID - + HeaderTargetServiceID = "X-Target-Service-Id" +) + +const ( + // PipInstallPrefix - + PipInstallPrefix = "pip3.9 install" + // WorkingDirType - + WorkingDirType = "working_dir" + // PipCheckSuffix - + PipCheckSuffix = "pip3.9 check" +) + +const ( + // KillSignalVal - + KillSignalVal = 1 + // StopAppSignalVal used for stop-app + StopAppSignalVal = 7 + + // KillSignalAliasUpdate is signal for alias update + KillSignalAliasUpdate = 64 + // KillSignalFaaSSchedulerUpdate is signal for faasscheduler update + KillSignalFaaSSchedulerUpdate = 72 +) + +const ( + // InstanceNameNote notes instance name + InstanceNameNote = "INSTANCE_NAME_NOTE" + // FunctionKeyNote - is used to describe the function + FunctionKeyNote = "FUNCTION_KEY_NOTE" + // ResourceSpecNote - is used to describe the resource + ResourceSpecNote = "RESOURCE_SPEC_NOTE" + // SchedulerIDNote - is used to decribe the schedulerID + SchedulerIDNote = "SCHEDULER_ID_NOTE" + // InstanceTypeNote - is used to decribe the instance type: "scaled", "reserved", "state" + InstanceTypeNote = "INSTANCE_TYPE_NOTE" + // InstanceLabelNode - + InstanceLabelNode = "INSTANCE_LABEL_NOTE" +) diff --git a/yuanrong/pkg/common/faas_common/constant/delegate.go b/yuanrong/pkg/common/faas_common/constant/delegate.go new file mode 100644 index 0000000..86057ff --- /dev/null +++ b/yuanrong/pkg/common/faas_common/constant/delegate.go @@ -0,0 +1,96 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package constant + +const ( + // DelegateVolumeMountKey is the key for DELEGATE_VOLUME_MOUNTS in CreateOption + DelegateVolumeMountKey = "DELEGATE_VOLUME_MOUNTS" + // DelegateInitVolumeMountKey is the key for DELEGATE_INIT_VOLUME_MOUNTS in CreateOption + DelegateInitVolumeMountKey = "DELEGATE_INIT_VOLUME_MOUNTS" + // DelegateAgentVolumeMountKey is the key for DELEGATE_AGENT_VOLUME_MOUNTS i + DelegateAgentVolumeMountKey = "DELEGATE_AGENT_VOLUME_MOUNTS" + // DelegateVolumesKey is the key for DELEGATE_VOLUMES in CreateOption + DelegateVolumesKey = "DELEGATE_VOLUMES" + // DelegateHostAliases is the key for DELEGATE_HOST_ALIASES in CreateOption + DelegateHostAliases = "DELEGATE_HOST_ALIASES" + // DelegateDownloadKey is the key for DelegateDownload in CreateOption + DelegateDownloadKey = "DELEGATE_DOWNLOAD" + // DelegateBootstrapKey is the key for DelegateStart in CreateOption + DelegateBootstrapKey = "DELEGATE_BOOTSTRAP" + // DelegateLayerDownloadKey is the key for DelegateLayerDownload in CreateOption + DelegateLayerDownloadKey = "DELEGATE_LAYER_DOWNLOAD" + // DelegateMountKey is the key for DELEGATE_MOUNT in CreateOption + DelegateMountKey = "DELEGATE_MOUNT" + // DelegateEncryptKey is the key for DELEGATE_ENCRYPT in CreateOption + DelegateEncryptKey = "DELEGATE_ENCRYPT" + // DelegateContainerKey is the key for DELEGATE_CONTAINER in CreateOption + DelegateContainerKey = "DELEGATE_CONTAINER" + // DelegateContainerSideCars is the key for DELEGATE_SIDECARS in CreateOption + DelegateContainerSideCars = "DELEGATE_SIDECARS" + // DelegateInitContainers is the key for DELEGATE_INIT_CONTAINERS in CreateOption + DelegateInitContainers = "DELEGATE_INIT_CONTAINERS" + // DelegatePodAnnotations is used to transfer pod annotations to the kernel during instance creation + DelegatePodAnnotations = "DELEGATE_POD_ANNOTATIONS" + // DelegatePodLabels is used to transfer pod labels to the kernel during instance creation + DelegatePodLabels = "DELEGATE_POD_LABELS" + // DelegatePodInitLabels - + DelegatePodInitLabels = "DELEGATE_POD_INIT_LABELS" + // DelegatePodSeccompProfile is key for DELEGATE_POD_SECCOMP_PROFILE in CreateOption + DelegatePodSeccompProfile = "DELEGATE_POD_SECCOMP_PROFILE" + // DelegateInitVolumeMounts is key for DELEGATE_INIT_VOLUME_MOUNTS in CreateOption + DelegateInitVolumeMounts = "DELEGATE_INIT_VOLUME_MOUNTS" + // DelegateNuwaRuntimeInfo is key for DELEGATE_NUWA_RUNTIME_INFO in CreateOption + DelegateNuwaRuntimeInfo = "DELEGATE_NUWA_RUNTIME_INFO" + // DelegateInitEnv is key for DelegateInitEnv in CreateOption + DelegateInitEnv = "DELEGATE_INIT_ENV" + // EnvDelegateEncrypt - + EnvDelegateEncrypt = "DELEGATE_ENCRYPT" + // DelegateTolerations is the key for DELEGATE_TOLERATIONS in CreateOption + DelegateTolerations = "DELEGATE_TOLERATIONS" + + // DelegateRuntimeManagerTag the key of runtime-manager's image tag + DelegateRuntimeManagerTag = "DELEGATE_RUNTIME_MANAGER" + // DelegateNodeAffinity is the key for DELEGATE_NODE_AFFINITY in CreateOption + DelegateNodeAffinity = "DELEGATE_NODE_AFFINITY" + + // DelegateNodeAffinityPolicy - + DelegateNodeAffinityPolicy = "DELEGATE_NODE_AFFINITY_POLICY" + // DelegateAffinity - + DelegateAffinity = "DELEGATE_AFFINITY" + // DelegateNodeAffinityPolicyCoverage - + DelegateNodeAffinityPolicyCoverage = "coverage" + // DelegateNodeAffinityPolicyAggregation - + DelegateNodeAffinityPolicyAggregation = "aggregation" + + // InstanceLifeCycle - + InstanceLifeCycle = "lifecycle" + // InstanceLifeCycleDetached - + InstanceLifeCycleDetached = "detached" + + // DelegateDirectoryInfo is the path that will be monitored its disk usage + DelegateDirectoryInfo = "DELEGATE_DIRECTORY_INFO" + // DelegateDirectoryQuota is the quota of the path + DelegateDirectoryQuota = "DELEGATE_DIRECTORY_QUOTA" + // PostStartExec - + PostStartExec = "POST_START_EXEC" + // DelegateEnvVar - + DelegateEnvVar = "DELEGATE_ENV_VAR" + // BusinessTypeTypeNote - is used to decribe the instance business type: "Serve", "FaaS", "Actor" + BusinessTypeTypeNote = "BUSINESS_TYPE_NOTE" + // FaasInvokeTimeout is function exec timeout + FaasInvokeTimeout = "INVOKE_TIMEOUT" +) diff --git a/yuanrong/pkg/common/faas_common/constant/functiongraph.go b/yuanrong/pkg/common/faas_common/constant/functiongraph.go new file mode 100644 index 0000000..03c7ed0 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/constant/functiongraph.go @@ -0,0 +1,156 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constant - +package constant + +const ( + // BusinessTypeWebSocket websocket business type + BusinessTypeWebSocket = "WEBSOCKET" + // BusinessTypeCAE cae business type + BusinessTypeCAE = "CAE" + // BusinessTypeFaaS FaaS business type + BusinessTypeFaaS = "FaaS" + // BusinessType business type key + BusinessType = "BUSINESS_TYPE" + // LanguageJava8 language java8 + LanguageJava8 = "java8" +) + +const ( + // ZoneKey zone key + ZoneKey = "KUBERNETES_IO_AVAILABLEZONE" + // ZoneNameLen define zone length + ZoneNameLen = 255 + // DefaultAZ default az + DefaultAZ = "defaultaz" + + // PodNamespaceEnvKey define pod namespace env key + PodNamespaceEnvKey = "POD_NAMESPACE" + + // FunctionLoadTimeoutEnvKey load Function timeout time + FunctionLoadTimeoutEnvKey = "LOAD_FUNCTION_TIMEOUT" + + // ResourceLimitsMemory Memory limit, in bytes + ResourceLimitsMemory = "MEMORY_LIMIT_BYTES" + + // FuncBranchEnvKey is branch env key + FuncBranchEnvKey = "FUNC_BRANCH" + + // DataSystemBranchEnvKey is branch env key + DataSystemBranchEnvKey = "DATASYSTEM_CAPABILITY" + + // HTTPort busproxy httpserver listen port + HTTPort = "22423" + // TCPort busproxy tcpserver listen port + TCPort = "32568" + // BusWorkerServerTCPort bus worker server listen port + BusWorkerServerTCPort = "32569" + // BusRuntimeServerPort bus listen port for + BusRuntimeServerPort = "32570" + // DefaultCachePort indicates the default port of a cache-manager server + DefaultCachePort = "9993" + + // PlatformTenantID is tenant ID of platform function + PlatformTenantID = "0" + + // RuntimeLogOptTail - + RuntimeLogOptTail = "Tail" + // RuntimeLayerDirName - + RuntimeLayerDirName = "layer" + // RuntimeFuncDirName - + RuntimeFuncDirName = "func" + + // FunctionTaskAppID - + FunctionTaskAppID = "function-task" + + // BackpressureCode indicate that frontend should choose another proxy/worker and retry + BackpressureCode = 211429 + // HeaderBackpressure indicate that proxy can backpressure this request + HeaderBackpressure = "X-Backpressure" + // HeaderBackpressureNums Backpressure numbers counter + HeaderBackpressureNums = "X-Backpressure-Nums" + // MonitorFileName monitor file name + MonitorFileName = "monitor-disk" + + // DefaultFuncLogIndex default function log's index + DefaultFuncLogIndex = -2 + + // IsClusterUpgrading indicate that the cluster is in upgrading phase + IsClusterUpgrading = "FAAS_CLUSTER_IS_UPGRADING" +) + +const ( + // WorkerManagerApplier mark the instance is created by minInstance + WorkerManagerApplier = "worker-manager" + // ASBResApplier mark the instance is created by ASBRes + ASBResApplier = "ASBRes" + // FunctionTaskApplier mark the instance is created by minInstance + FunctionTaskApplier = "functiontask" + // PredictionApplier mark the instance is created by smart warmer predict + PredictionApplier = "prediction" + // FaasSchedulerApplier the instance is created by faas scheduler + FaasSchedulerApplier = "faas-scheduler" + // PoolInfoPrefix pool info prefix in redis + PoolInfoPrefix = "ClusterState_Pool" + // PoolInfoSep pool info separator in redis + PoolInfoSep = "_" + // ClusterIDKey cluster id key in system env + ClusterIDKey = "CLUSTER_ID" + // DefaultRecordingInterval default pool info recording interval, unit is second + DefaultRecordingInterval = 5 + // DefaultRecordExpiredTime default pool info record expired time, unit is second + DefaultRecordExpiredTime = 900 +) + +const ( + // FunctionAccessor - defines the microservice component name. + FunctionAccessor = "FunctionAccessor" + // FunctionTask - + FunctionTask = "FunctionTask" + // InstanceManager - + InstanceManager = "FunctionInstanceManager" + // StateManager - + StateManager = "StateManager" + // CacheManager - + CacheManager = "CacheManager" + // CacheServiceName indicates the header of the cache service + CacheServiceName = "cache-manager" + // FaaSScheduler - + FaaSScheduler = "faas-scheduler" + // SnapshotManager - + SnapshotManager = "SnapshotManager" + // Autoscaler define the alarm type + Autoscaler = "Autoscaler" +) + +// header constant key for FG +const ( + // FGHeaderRequestID - + FGHeaderRequestID = "X-Request-Id" + // FGHeaderAccessKey - + FGHeaderAccessKey = "X-Access-Key" + // FGHeaderSecretKey - + FGHeaderSecretKey = "X-Secret-Key" + // FGHeaderSecurityAccessKey - + FGHeaderSecurityAccessKey = "X-Security-Access-Key" + // FGHeaderSecuritySecretKey - + FGHeaderSecuritySecretKey = "X-Security-Secret-Key" + // FGHeaderAuthToken - + FGHeaderAuthToken = "X-Auth-Token" + // FGHeaderSecurityToken - + FGHeaderSecurityToken = "X-Security-Token" +) diff --git a/yuanrong/pkg/common/faas_common/constant/wisecloud.go b/yuanrong/pkg/common/faas_common/constant/wisecloud.go new file mode 100644 index 0000000..4630d45 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/constant/wisecloud.go @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constant - +package constant + +// header constant key for caas +const ( + CaaSHeaderTraceID = "X-Caas-Trace-Id" + CaaSHeaderRequestID = "X-Cff-Request-Id" +) + +// CaaS alarm +const ( + // WiseCloudSite site + WiseCloudSite = "WISECLOUD_SITE" + // TenantID WiseCloud tenantID + TenantID = "WISECLOUD_TENANTID" + // ApplicationID WiseCloud applicationId + ApplicationID = "WISECLOUD_APPLICATIONID" + // ServiceID WiseCloud serviceId + ServiceID = "WISECLOUD_SERVICEID" + // ClusterName define cluster env key + ClusterName = "CLUSTER_NAME" + // PodNameEnvKey define pod name env key + PodNameEnvKey = "POD_NAME" + // CloudMapId define in wiseCloud about cloudMap id + CloudMapId = "X_WISECLOUD_CLOUDMAP_ID" +) diff --git a/yuanrong/pkg/common/faas_common/crypto/cryptoapi_mock.go b/yuanrong/pkg/common/faas_common/crypto/cryptoapi_mock.go new file mode 100644 index 0000000..173bfb1 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/crypto/cryptoapi_mock.go @@ -0,0 +1,42 @@ +//go:build !scc + +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crypto for auth +package crypto + +// SccConfig - +type SccConfig struct { + Enable bool `json:"enable" valid:"optional"` + SecretName string `json:"secretName" valid:"optional"` + Algorithm string `json:"algorithm" valid:"optional"` +} + +// SCCInitialized - +func SCCInitialized() bool { + return false +} + +// InitializeSCC - +func InitializeSCC(config SccConfig) error { + return nil +} + +// SCCDecrypt - +func SCCDecrypt(cipher []byte) (string, error) { + return "", nil +} diff --git a/yuanrong/pkg/common/faas_common/crypto/scc_crypto.go b/yuanrong/pkg/common/faas_common/crypto/scc_crypto.go new file mode 100644 index 0000000..87687a7 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/crypto/scc_crypto.go @@ -0,0 +1,166 @@ +//go:build scc + +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package crypto for auth +package crypto + +import ( + "cryptoapi" + "encoding/json" + "fmt" + "path" + "sync" + + corev1 "k8s.io/api/core/v1" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" +) + +var ( + sccInitialized bool = false + m sync.RWMutex +) + +const ( + // Aes256Cbc - + Aes256Cbc = "AES256_CBC" + // Aes128Gcm - + Aes128Gcm = "AES128_GCM" + // Aes256Gcm - + Aes256Gcm = "AES256_GCM" + // Sm4Cbc - + Sm4Cbc = "SM4_CBC" + // Sm4Ctr - + Sm4Ctr = "SM4_CTR" + + sccConfigDefaultPath = "/home/snuser/secret/scc" +) + +// SccConfig - +type SccConfig struct { + Enable bool `json:"enable" valid:"optional"` + SecretName string `json:"secretName" valid:"optional"` + Algorithm string `json:"algorithm" valid:"optional"` + SccConfigPath string `json:"sccConfigPath" valid:"optional"` +} + +// SCCInitialized - +func SCCInitialized() bool { + m.RLock() + defer m.RUnlock() + return sccInitialized +} + +// GetSCCAlgorithm - +func GetSCCAlgorithm(algorithm string) int { + switch algorithm { + case Aes256Cbc: + return cryptoapi.ALG_AES256_CBC + case Aes128Gcm: + return cryptoapi.ALG_AES128_GCM + case Aes256Gcm: + return cryptoapi.ALG_AES256_GCM + case Sm4Cbc: + return cryptoapi.ALG_SM4_CBC + case Sm4Ctr: + return cryptoapi.ALG_SM4_CTR + default: + return cryptoapi.ALG_AES256_GCM + } +} + +// InitializeSCC - +func InitializeSCC(config SccConfig) error { + m.Lock() + defer m.Unlock() + + if !config.Enable { + return nil + } + options := cryptoapi.NewSccOptions() + sccConfigPath := config.SccConfigPath + if sccConfigPath == "" { + sccConfigPath = sccConfigDefaultPath + } + options.PrimaryKeyFile = path.Join(sccConfigPath, "primary.ks") + options.StandbyKeyFile = path.Join(sccConfigPath, "standby.ks") + options.LogPath = "/tmp/log/" + options.LogFile = "scc" + options.DefaultAlgorithm = GetSCCAlgorithm(config.Algorithm) + options.RandomDevice = "/dev/random" + options.EnableChangeFilePermission = 1 + cryptoapi.Finalize() + err := cryptoapi.InitializeWithConfig(options) + if err != nil { + log.GetLogger().Errorf("Initialize SCC Error = [%s]", err.Error()) + return err + } + sccInitialized = true + return nil +} + +// FinalizeSCC - +func FinalizeSCC() { + m.Lock() + defer m.Unlock() + sccInitialized = false + cryptoapi.Finalize() +} + +// SCCEncrypt - +func SCCEncrypt(plainInput string) ([]byte, error) { + cipher, err := cryptoapi.Encrypt(plainInput) + if err != nil { + log.GetLogger().Errorf("SCC Encrypt Error = [%s]", err.Error()) + return nil, err + } + + return []byte(cipher), nil +} + +// SCCDecrypt - +func SCCDecrypt(cipher []byte) (string, error) { + plain, err := cryptoapi.Decrypt(string(cipher)) + if err != nil { + log.GetLogger().Errorf("SCC Decrypt Error = [%s]", err.Error()) + return "", err + } + + return plain, nil +} + +// GenerateSCCVolumesAndMounts - +func GenerateSCCVolumesAndMounts(secretName string, builder *utils.VolumeBuilder) (string, string, error) { + if builder == nil { + return "", "", fmt.Errorf("volume builder is nil") + } + builder.AddVolume(corev1.Volume{Name: "scc-ks", + VolumeSource: corev1.VolumeSource{Secret: &corev1.SecretVolumeSource{SecretName: secretName}}}) + builder.AddVolumeMount(utils.ContainerRuntimeManager, + corev1.VolumeMount{Name: "scc-ks", MountPath: "/home/snuser/resource/scc"}) + volumesData, err := json.Marshal(builder.Volumes) + if err != nil { + return "", "", err + } + volumesMountData, err := json.Marshal(builder.Mounts[utils.ContainerRuntimeManager]) + if err != nil { + return "", "", err + } + return string(volumesData), string(volumesMountData), nil +} diff --git a/yuanrong/pkg/common/faas_common/crypto/scc_crypto_test.go b/yuanrong/pkg/common/faas_common/crypto/scc_crypto_test.go new file mode 100644 index 0000000..f271f25 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/crypto/scc_crypto_test.go @@ -0,0 +1,85 @@ +//go:build scc + +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This test file can also be used as a tool to create, encrypt and decrypt our secrets and cipher texts +package crypto + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSCCEncryptDecryptInitialized(t *testing.T) { + var c = SccConfig{ + Enable: true, + Algorithm: "AES256_GCM", + } + ret := InitializeSCC(c) + assert.Nil(t, ret) + input := "text to encrypt" + encrypted, err := SCCEncrypt(input) + fmt.Printf("encrypted : %s\n", string(encrypted)) + assert.Nil(t, err) + decrypt, err := SCCDecrypt(encrypted) + fmt.Printf("decrypt : %s\n", decrypt) + assert.Nil(t, err) + assert.Equal(t, input, decrypt) + assert.NotEqual(t, encrypted, input) + FinalizeSCC() +} + +func TestSCCEncryptDecryptNotInitialized(t *testing.T) { + var c = SccConfig{ + Enable: false, + Algorithm: "AES256_GCM", + } + ret := InitializeSCC(c) + assert.Nil(t, ret) + input := "text to encrypt" + encrypted, _ := SCCEncrypt(input) + fmt.Printf("encrypted : %s\n", string(encrypted)) + decrypt, _ := SCCDecrypt(encrypted) + fmt.Printf("decrypt : %s\n", decrypt) + FinalizeSCC() +} + +func TestSCCEncryptDecryptAlgorithms(t *testing.T) { + var c = SccConfig{ + Enable: true, + Algorithm: "AES256_GCM", + } + + algorithms := []string{"AES256_CBC", "AES128_GCM", "AES256_GCM", "SM4_CBC", "SM4_CTR", "DEFAULT"} + for _, algo := range algorithms { + FinalizeSCC() + c.Algorithm = algo + ret := InitializeSCC(c) + assert.Nil(t, ret) + input := "text to encrypt" + encrypted, err := SCCEncrypt(input) + fmt.Printf("encrypted : %s\n", string(encrypted)) + assert.Nil(t, err) + decrypt, err := SCCDecrypt(encrypted) + fmt.Printf("decrypt : %s\n", decrypt) + assert.Nil(t, err) + assert.Equal(t, input, decrypt) + assert.NotEqual(t, encrypted, input) + } +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/cache.go b/yuanrong/pkg/common/faas_common/etcd3/cache.go new file mode 100644 index 0000000..3aa7239 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/cache.go @@ -0,0 +1,408 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 - +package etcd3 + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "os" + "sort" + "strconv" + "strings" + "time" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" +) + +const ( + cacheMetaFilePrefix = "etcdCacheMeta_" + cacheDataFilePrefix = "etcdCacheData_" + backupFileSuffix = "_backup" + cacheDataSplitNum = 3 +) + +var ( + // ErrInvalidCacheMeta - + ErrInvalidCacheMeta = errors.New("invalid cache meta") + // ErrCacheDataNotExist - + ErrCacheDataNotExist = errors.New("cache data not exist") + // ErrCacheDataMD5Mismatch - + ErrCacheDataMD5Mismatch = errors.New("cache data md5 mismatch") + cacheDataSeparator = "|" + cacheDataLineFeed = []byte("\n") +) + +// ETCDCacheMeta - +type ETCDCacheMeta struct { + Revision int64 `json:"revision"` + CacheMD5 string `json:"cacheMD5"` +} + +func (ew *EtcdWatcher) setCacheFilePath() { + cacheMetaFileName := fmt.Sprintf("%s%s", cacheMetaFilePrefix, strings.ReplaceAll(ew.key, "/", "#")) + cacheDataFileName := fmt.Sprintf("%s%s", cacheDataFilePrefix, strings.ReplaceAll(ew.key, "/", "#")) + ew.cacheConfig.MetaFilePath = fmt.Sprintf("%s/%s", ew.cacheConfig.PersistPath, cacheMetaFileName) + ew.cacheConfig.DataFilePath = fmt.Sprintf("%s/%s", ew.cacheConfig.PersistPath, cacheDataFileName) + ew.cacheConfig.BackupFilePath = fmt.Sprintf("%s%s", ew.cacheConfig.DataFilePath, backupFileSuffix) +} + +func (ew *EtcdWatcher) processETCDCache() { + log.GetLogger().Infof("start processing ETCD cache") + ew.setCacheFilePath() + persistInterval := ew.cacheConfig.FlushInterval + ticker := time.NewTicker(time.Minute * time.Duration(persistInterval)) + defer ticker.Stop() + // only record event with latest revision which is easier for flushCacheFile + eventCache := make(map[string]*Event, ew.cacheConfig.FlushThreshold) + for { + select { + case <-ticker.C: + log.GetLogger().Infof("ticker triggers, flushing cache now") + if err := ew.flushCacheToFile(eventCache); err == nil { + eventCache = make(map[string]*Event, ew.cacheConfig.FlushThreshold) + } + case event := <-ew.CacheChan: + log.GetLogger().Infof("threshold triggers, flushing cache now") + preEvent, exist := eventCache[event.Key] + if !exist || (exist && preEvent.Rev < event.Rev) { + eventCache[event.Key] = event + } + if len(eventCache) > ew.cacheConfig.FlushThreshold { + if err := ew.flushCacheToFile(eventCache); err == nil { + eventCache = make(map[string]*Event, ew.cacheConfig.FlushThreshold) + } + } + case <-ew.configCh: + log.GetLogger().Infof("cache config changed, new config is %+v", ew.cacheConfig) + if !ew.cacheConfig.EnableCache { + log.GetLogger().Warnf("etcd cache disabled, stop processing cache") + return + } + if ew.cacheConfig.FlushInterval != persistInterval { + persistInterval = ew.cacheConfig.FlushInterval + ticker.Reset(time.Minute * time.Duration(persistInterval)) + } + case <-ew.stopCh: + log.GetLogger().Warnf("etcd watcher stopped, stop processing cache") + return + } + } +} + +func (ew *EtcdWatcher) getCacheMeta() *ETCDCacheMeta { + var cacheMeta *ETCDCacheMeta + _, statErr := os.Stat(ew.cacheConfig.MetaFilePath) + if os.IsNotExist(statErr) { + return &ETCDCacheMeta{} + } + if statErr == nil { + cacheMetaData, err := os.ReadFile(ew.cacheConfig.MetaFilePath) + if err != nil { + log.GetLogger().Errorf("failed to read cache meta file %s error %s", ew.cacheConfig.MetaFilePath, + err.Error()) + return nil + } + cacheMeta = &ETCDCacheMeta{} + if err = json.Unmarshal(cacheMetaData, cacheMeta); err != nil { + log.GetLogger().Errorf("failed to unmarshal cache meta file %s error %s", ew.cacheConfig.MetaFilePath, + err.Error()) + return nil + } + return cacheMeta + } + return nil +} + +func (ew *EtcdWatcher) cleanCacheFile(cleanMeta, cleanData, cleanBackup bool) { + if cleanMeta { + if err := os.Remove(ew.cacheConfig.MetaFilePath); err != nil { + log.GetLogger().Errorf("failed to remove cache meta file %s error %s", ew.cacheConfig.MetaFilePath, + err.Error()) + } + } + if cleanData { + if err := os.Remove(ew.cacheConfig.DataFilePath); err != nil { + log.GetLogger().Errorf("failed to remove cache data file %s error %s", ew.cacheConfig.DataFilePath, + err.Error()) + } + } + if cleanBackup { + if err := os.Remove(ew.cacheConfig.BackupFilePath); err != nil { + log.GetLogger().Errorf("failed to remove cache backup file %s error %s", ew.cacheConfig.BackupFilePath, + err.Error()) + } + } +} + +// processDataBackup turns dataFile to backFile +func (ew *EtcdWatcher) processDataBackup(cacheMeta *ETCDCacheMeta) error { + _, statDataFileErr := os.Stat(ew.cacheConfig.DataFilePath) + _, statBackupFileErr := os.Stat(ew.cacheConfig.BackupFilePath) + // need to handle backupFife if either dataFile or backupFile exists + if statDataFileErr == nil || statBackupFileErr == nil { + // if backupFile doesn't exist, it's the normal case, rename dataFile to backupFile if it exists. + // if backupFile exists, it's the fault case where flush is interrupted, remove dataFile if it exists. + if statDataFileErr == nil && os.IsNotExist(statBackupFileErr) { + if err := os.Rename(ew.cacheConfig.DataFilePath, ew.cacheConfig.BackupFilePath); err != nil { + log.GetLogger().Errorf("failed to rename cache file to %s error %s", ew.cacheConfig.BackupFilePath, + err.Error()) + ew.cleanCacheFile(true, true, true) + return err + } + } else if statDataFileErr == nil && statBackupFileErr == nil { + if err := os.Remove(ew.cacheConfig.DataFilePath); err != nil { + log.GetLogger().Errorf("failed to remove dirty cache file %s error %s", ew.cacheConfig.DataFilePath, + err.Error()) + return err + } + } + if utils.CalcFileMD5(ew.cacheConfig.BackupFilePath) != cacheMeta.CacheMD5 { + log.GetLogger().Errorf("md5 mismatch for cache backup file %s", ew.cacheConfig.BackupFilePath) + ew.cleanCacheFile(true, true, true) + return ErrCacheDataMD5Mismatch + } + } + return nil +} + +// flushCacheToFile will modify the given eventCache during processing +func (ew *EtcdWatcher) flushCacheToFile(eventCache map[string]*Event) error { + ew.Lock() + if ew.cacheFlushing { + ew.Unlock() + return nil + } + ew.cacheFlushing = true + defer func() { + ew.Lock() + ew.cacheFlushing = false + ew.Unlock() + }() + ew.Unlock() + cacheMeta := ew.getCacheMeta() + if cacheMeta == nil { + ew.cleanCacheFile(true, true, true) + return ErrInvalidCacheMeta + } + // backup dataFile if it exists, will generate new dataFile from backupFile and eventCache + if err := ew.processDataBackup(cacheMeta); err != nil { + return err + } + var scanner *bufio.Scanner + _, statBackupFileErr := os.Stat(ew.cacheConfig.BackupFilePath) + if statBackupFileErr == nil { + backupFile, err := os.OpenFile(ew.cacheConfig.BackupFilePath, os.O_RDONLY, 0600) + if err != nil { + log.GetLogger().Errorf("failed to open cache backup file %s error %s", ew.cacheConfig.BackupFilePath, + err.Error()) + return err + } + defer func() { + if err := backupFile.Close(); err != nil { + log.GetLogger().Errorf("failed to close backup file %s error %s", ew.cacheConfig.BackupFilePath, + err.Error()) + } + if err := os.Remove(ew.cacheConfig.BackupFilePath); err != nil { + log.GetLogger().Errorf("failed to remove backup file %s error %s", ew.cacheConfig.BackupFilePath, + err.Error()) + } + }() + scanner = bufio.NewScanner(backupFile) + } + dataFile, err := os.OpenFile(ew.cacheConfig.DataFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC|os.O_SYNC, 0600) + if err != nil { + log.GetLogger().Errorf("failed to open cache file %s error %s", ew.cacheConfig.DataFilePath, err.Error()) + return err + } + eventList := generateSortedCacheList(eventCache) + offset := int64(0) + for scanner != nil && scanner.Scan() { + line := scanner.Text() + items := strings.SplitN(line, cacheDataSeparator, cacheDataSplitNum) + if len(items) != cacheDataSplitNum { + log.GetLogger().Warnf("skip invalid data %s in cache file %s", line, ew.cacheConfig.BackupFilePath) + continue + } + scanKey, scanValue := items[0], []byte(items[2]) + scanRevision, err := strconv.ParseInt(items[1], 10, 64) + if err != nil { + log.GetLogger().Errorf("invalid revision format of %s in line %s cache file %s", items[1], line, + ew.cacheConfig.BackupFilePath) + continue + } + if scanRevision > cacheMeta.Revision { + cacheMeta.Revision = scanRevision + } + index := -1 + for i, event := range eventList { + if event.Rev > cacheMeta.Revision { + cacheMeta.Revision = event.Rev + } + // eventList keeps keys in lexicographical order which is also the order we set in cache data file, this + // loop only handles eventKey <= scanKey scenario which contains two types of keys : 1. eventKey which goes + // before scanKey with PUT type 2. eventKey equals to scanKey which will update or delete scanKey if it has + // a newer revision + if event.Key < scanKey { + if event.Type == PUT { + offset = flushEventToFile(dataFile, offset, []byte(event.Key), event.Value, event.Rev) + } + index = i + continue + } + if event.Key == scanKey { + // should not update or delete if event revision is older than cacheMeta + if event.Rev > scanRevision && event.Type == PUT { + scanValue = event.Value + scanRevision = event.Rev + } else if event.Rev > scanRevision && event.Type == DELETE { + scanValue = nil + } + index = i + } + // here eventKey >= scanKey no need to go further + break + } + if index != -1 { + eventList = eventList[index+1:] + } + if scanValue != nil { + offset = flushEventToFile(dataFile, offset, []byte(scanKey), scanValue, scanRevision) + } + } + for _, event := range eventList { + if event.Rev > cacheMeta.Revision { + cacheMeta.Revision = event.Rev + } + if event.Type == PUT { + offset = flushEventToFile(dataFile, offset, []byte(event.Key), event.Value, event.Rev) + } + } + if err = dataFile.Close(); err != nil { + log.GetLogger().Errorf("failed to close cache data file %s error %s", ew.cacheConfig.DataFilePath, + err.Error()) + } + if offset == 0 { + log.GetLogger().Errorf("failed to write data file %s", ew.cacheConfig.DataFilePath) + ew.cleanCacheFile(false, true, false) + return errors.New("failed to write data file") + } + cacheMeta.CacheMD5 = utils.CalcFileMD5(ew.cacheConfig.DataFilePath) + cacheMetaData, err := json.Marshal(cacheMeta) + if err != nil { + log.GetLogger().Errorf("failed to marshal cache meta error %s", err.Error()) + return err + } + if err = os.WriteFile(ew.cacheConfig.MetaFilePath, cacheMetaData, 0600); err != nil { + log.GetLogger().Errorf("failed to write cache meta file %s error %s", ew.cacheConfig.MetaFilePath, + err.Error()) + ew.cleanCacheFile(false, true, false) + return err + } + log.GetLogger().Infof("succeed to flush cache") + return nil +} + +func (ew *EtcdWatcher) restoreCacheFromFile() error { + ew.setCacheFilePath() + _, statBackupFileErr := os.Stat(ew.cacheConfig.BackupFilePath) + if statBackupFileErr == nil { + // backupFile exists, it's the fault scenario, flushCacheToFile with nil to restore dataFile from backupFile + ew.flushCacheToFile(nil) + } + _, statDataFileErr := os.Stat(ew.cacheConfig.DataFilePath) + if os.IsNotExist(statDataFileErr) { + return ErrCacheDataNotExist + } + cacheMeta := ew.getCacheMeta() + if cacheMeta == nil { + ew.cleanCacheFile(true, true, true) + return ErrInvalidCacheMeta + } + if utils.CalcFileMD5(ew.cacheConfig.DataFilePath) != cacheMeta.CacheMD5 { + log.GetLogger().Errorf("md5 mismatch for cache data file %s", ew.cacheConfig.DataFilePath) + ew.cleanCacheFile(true, true, true) + return ErrCacheDataMD5Mismatch + } + dataFile, err := os.OpenFile(ew.cacheConfig.DataFilePath, os.O_RDONLY, 0600) + if err != nil { + log.GetLogger().Errorf("failed to open cache backup file %s error %s", ew.cacheConfig.DataFilePath, + err.Error()) + ew.cleanCacheFile(true, true, true) + return err + } + scanner := bufio.NewScanner(dataFile) + for scanner.Scan() { + line := scanner.Text() + items := strings.SplitN(line, cacheDataSeparator, cacheDataSplitNum) + if len(items) != cacheDataSplitNum { + log.GetLogger().Warnf("skip invalid data %s in cache file %s", line, ew.cacheConfig.DataFilePath) + continue + } + scanKey, scanValue := items[0], []byte(items[2]) + scanRevision, err := strconv.ParseInt(items[1], 10, 64) + if err != nil { + log.GetLogger().Errorf("invalid revision format of %s in line %s file %s", items[1], line, + ew.cacheConfig.DataFilePath) + continue + } + ew.sendEvent(&Event{ + Type: PUT, + Key: scanKey, + Value: scanValue, + Rev: scanRevision, + }) + } + if err = dataFile.Close(); err != nil { + log.GetLogger().Errorf("failed to close cache backup file %s error %s", ew.cacheConfig.DataFilePath, + err.Error()) + } + ew.initialRev = cacheMeta.Revision + log.GetLogger().Infof("succeed to restore etcd cache to revision %d", cacheMeta.Revision) + return nil +} + +func flushEventToFile(f *os.File, offset int64, key, value []byte, revision int64) int64 { + buffer := new(bytes.Buffer) + buffer.Write(key) + buffer.Write([]byte(cacheDataSeparator)) + buffer.Write([]byte(strconv.FormatInt(revision, 10))) + buffer.Write([]byte(cacheDataSeparator)) + buffer.Write(value) + buffer.Write(cacheDataLineFeed) + _, err := f.WriteAt(buffer.Bytes(), offset) + if err != nil { + log.GetLogger().Errorf("failed to write content to cache file error %s", err.Error()) + return offset + } + return offset + int64(buffer.Len()) +} + +func generateSortedCacheList(cache map[string]*Event) []*Event { + cacheList := make([]*Event, 0, len(cache)) + for _, v := range cache { + cacheList = append(cacheList, v) + } + sort.Slice(cacheList, func(i, j int) bool { + return cacheList[i].Key < cacheList[j].Key + }) + return cacheList +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/cache_test.go b/yuanrong/pkg/common/faas_common/etcd3/cache_test.go new file mode 100644 index 0000000..570cdec --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/cache_test.go @@ -0,0 +1,362 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 - +package etcd3 + +import ( + "encoding/json" + "errors" + "os" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "yuanrong/pkg/common/faas_common/utils" +) + +func TestProcessETCDCache(t *testing.T) { + hackTicker := time.NewTicker(50 * time.Millisecond) + resetDuration := time.Duration(0) + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&time.Ticker{}), "Reset", func(_ *time.Ticker, d time.Duration) { + resetDuration = d + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + convey.Convey("ticker case", t, func() { + patch := gomonkey.ApplyFunc(time.NewTicker, func(d time.Duration) *time.Ticker { + hackTicker.Reset(50 * time.Millisecond) + return hackTicker + }) + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") + stopCh := make(chan struct{}, 1) + ew := newEtcdWatcher() + ew.stopCh = stopCh + go ew.processETCDCache() + time.Sleep(500 * time.Millisecond) + ew.CacheChan <- &Event{ + Rev: 100, + Type: PUT, + Key: "/sn/function/123/hello/latest", + Value: []byte(`{"name":"hello","version":"latest"}`), + } + time.Sleep(500 * time.Millisecond) + stopCh <- struct{}{} + time.Sleep(500 * time.Millisecond) + data, err := os.ReadFile("etcdCacheData_#sn#function") + convey.So(err, convey.ShouldBeNil) + convey.So(string(data), convey.ShouldEqual, "/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n") + data, err = os.ReadFile("etcdCacheMeta_#sn#function") + convey.So(err, convey.ShouldBeNil) + convey.So(string(data), convey.ShouldEqual, "{\"revision\":100,\"cacheMD5\":\"4fca8f1c736ca30135ed16538f4aebfc\"}") + patch.Reset() + }) + convey.Convey("threshold case", t, func() { + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") + stopCh := make(chan struct{}, 1) + ew := newEtcdWatcher() + ew.stopCh = stopCh + ew.cacheConfig.FlushThreshold = 0 + go ew.processETCDCache() + time.Sleep(500 * time.Millisecond) + ew.CacheChan <- &Event{ + Rev: 100, + Type: PUT, + Key: "/sn/function/123/hello/latest", + Value: []byte(`{"name":"hello","version":"latest"}`), + } + time.Sleep(500 * time.Millisecond) + stopCh <- struct{}{} + time.Sleep(500 * time.Millisecond) + data, err := os.ReadFile("etcdCacheData_#sn#function") + convey.So(err, convey.ShouldBeNil) + convey.So(string(data), convey.ShouldEqual, "/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n") + data, err = os.ReadFile("etcdCacheMeta_#sn#function") + convey.So(err, convey.ShouldBeNil) + convey.So(string(data), convey.ShouldEqual, "{\"revision\":100,\"cacheMD5\":\"4fca8f1c736ca30135ed16538f4aebfc\"}") + }) + convey.Convey("config update case", t, func() { + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") + ew := newEtcdWatcher() + go ew.processETCDCache() + time.Sleep(500 * time.Millisecond) + ew.cacheConfig.FlushInterval = 20 + ew.configCh <- struct{}{} + time.Sleep(500 * time.Millisecond) + convey.So(resetDuration, convey.ShouldEqual, 20*time.Minute) + ew.cacheConfig.EnableCache = false + ew.configCh <- struct{}{} + }) + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") +} + +func TestFlushCacheToFile(t *testing.T) { + stopCh := make(chan struct{}, 1) + ew := &EtcdWatcher{ + key: "/sn/function", + CacheChan: make(chan *Event, 10), + configCh: make(chan struct{}, 1), + stopCh: stopCh, + cacheConfig: EtcdCacheConfig{ + EnableCache: true, + PersistPath: "./", + FlushInterval: 10, + FlushThreshold: 10, + }, + } + ew.setCacheFilePath() + convey.Convey("no dataFile and no backupFile", t, func() { + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") + eventBuffer := map[string]*Event{ + "/sn/function/123/hello/latest": &Event{ + Rev: 100, + Key: "/sn/function/123/hello/latest", + Value: []byte(`{"name":"hello","version":"latest"}`), + }, + } + err := ew.flushCacheToFile(eventBuffer) + convey.So(err, convey.ShouldBeNil) + _, stateMetaFileErr := os.Stat("./etcdCacheMeta_#sn#function") + _, stateDataFileErr := os.Stat("./etcdCacheData_#sn#function") + _, stateBackupFileErr := os.Stat("./etcdCacheData_#sn#function_backup") + convey.So(stateMetaFileErr, convey.ShouldBeNil) + convey.So(stateDataFileErr, convey.ShouldBeNil) + convey.So(os.IsNotExist(stateBackupFileErr), convey.ShouldEqual, true) + }) + convey.Convey("no backupFile", t, func() { + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") + // dataFile exists and no metaFile + os.WriteFile("./etcdCacheData_#sn#function", []byte("/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n"), 0600) + err := ew.flushCacheToFile(nil) + convey.So(err, convey.ShouldNotBeNil) + _, errStatMeta := os.Stat("./etcdCacheMeta_#sn#function") + _, errStatData := os.Stat("./etcdCacheData_#sn#function") + _, errStatBackup := os.Stat("./etcdCacheData_#sn#function_backup") + convey.So(os.IsNotExist(errStatMeta), convey.ShouldEqual, true) + convey.So(os.IsNotExist(errStatData), convey.ShouldEqual, true) + convey.So(os.IsNotExist(errStatBackup), convey.ShouldEqual, true) + // dataFile exists and metaFile exists + os.WriteFile("./etcdCacheMeta_#sn#function", []byte(`{"revision":101,"cacheMD5":"726eb6f3140438ac1cbe334777e1a272"}`), 0600) + os.WriteFile("./etcdCacheData_#sn#function", []byte("/sn/function/123/goodbye/latest|101|{\"name\":\"goodbye\",\"version\":\"latest\"}\n/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n/sn/function/123/invalid/latest|xxx|{\"name\":\"invalid\",\"version\":\"latest\"}\nThisIsInvalidData\n"), 0600) + eventBuffer := map[string]*Event{ + "/sn/function/123/goodbye/latest": &Event{ + Rev: 102, + Type: PUT, + Key: "/sn/function/123/goodbye/latest", + Value: []byte(`{"name":"goodbye","version":"v1"}`), + }, + "/sn/function/123/hello/latest": &Event{ + Rev: 103, + Type: DELETE, + Key: "/sn/function/123/hello/latest", + Value: []byte(`{"name":"hello","version":"latest"}`), + }, + "/sn/function/123/echo/latest": &Event{ + Rev: 104, + Type: PUT, + Key: "/sn/function/123/echo/latest", + Value: []byte(`{"name":"echo","version":"latest"}`), + }, + } + err = ew.flushCacheToFile(eventBuffer) + convey.So(err, convey.ShouldBeNil) + data, err := os.ReadFile("./etcdCacheMeta_#sn#function") + convey.So(err, convey.ShouldBeNil) + meta := &ETCDCacheMeta{} + err = json.Unmarshal(data, meta) + convey.So(err, convey.ShouldBeNil) + convey.So(meta.Revision, convey.ShouldEqual, 104) + convey.So(meta.CacheMD5, convey.ShouldEqual, "006731eddc832c067f9814b64ae12833") + convey.So(utils.CalcFileMD5("./etcdCacheData_#sn#function"), convey.ShouldEqual, "006731eddc832c067f9814b64ae12833") + _, stateBackupFileErr := os.Stat("./etcdCacheData_#sn#function_backup") + convey.So(os.IsNotExist(stateBackupFileErr), convey.ShouldEqual, true) + }) + convey.Convey("backupFile exists", t, func() { + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") + // backupFile mismatch with metaFile + os.WriteFile("./etcdCacheMeta_#sn#function", []byte(`{"revision":100,"cacheMD5":"4f4449c598ec58854d7104c4a64e979f"}`), 0600) + os.WriteFile("./etcdCacheData_#sn#function_backup", []byte("/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n"), 0600) + err := ew.flushCacheToFile(nil) + convey.So(err, convey.ShouldNotBeNil) + _, errStatMeta := os.Stat("./etcdCacheMeta_#sn#function") + _, errStatData := os.Stat("./etcdCacheData_#sn#function") + _, errStatBackup := os.Stat("./etcdCacheData_#sn#function_backup") + convey.So(os.IsNotExist(errStatMeta), convey.ShouldEqual, true) + convey.So(os.IsNotExist(errStatData), convey.ShouldEqual, true) + convey.So(os.IsNotExist(errStatBackup), convey.ShouldEqual, true) + // backupFile exists and dataFile exists + os.WriteFile("./etcdCacheMeta_#sn#function", []byte(`{"revision":100,"cacheMD5":"4fca8f1c736ca30135ed16538f4aebfc"}`), 0600) + os.WriteFile("./etcdCacheData_#sn#function", []byte("/sn/function/123/goodbye/latest|101|{\"name\":\"goodbye\",\"version\":\"latest\"}\n"), 0600) + os.WriteFile("./etcdCacheData_#sn#function_backup", []byte("/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n"), 0600) + err = ew.flushCacheToFile(nil) + convey.So(err, convey.ShouldBeNil) + data, err := os.ReadFile("./etcdCacheMeta_#sn#function") + convey.So(err, convey.ShouldBeNil) + meta := &ETCDCacheMeta{} + err = json.Unmarshal(data, meta) + convey.So(err, convey.ShouldBeNil) + convey.So(meta.Revision, convey.ShouldEqual, 100) + convey.So(meta.CacheMD5, convey.ShouldEqual, "4fca8f1c736ca30135ed16538f4aebfc") + convey.So(utils.CalcFileMD5("./etcdCacheData_#sn#function"), convey.ShouldEqual, "4fca8f1c736ca30135ed16538f4aebfc") + _, stateBackupFileErr := os.Stat("./etcdCacheData_#sn#function_backup") + convey.So(os.IsNotExist(stateBackupFileErr), convey.ShouldEqual, true) + }) + convey.Convey("file close fail", t, func() { + patch1 := gomonkey.ApplyMethod(reflect.TypeOf(&os.File{}), "Close", func(f *os.File) error { + return errors.New("some error") + }) + fileHack, _ := os.OpenFile("./xxx", os.O_WRONLY|os.O_CREATE|os.O_TRUNC|os.O_SYNC, 0600) + patch2 := gomonkey.ApplyFunc(os.OpenFile, func(name string, flag int, perm os.FileMode) (*os.File, error) { + return fileHack, nil + }) + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") + os.WriteFile("./etcdCacheMeta_#sn#function", []byte(`{"revision":100,"cacheMD5":"4fca8f1c736ca30135ed16538f4aebfc"}`), 0600) + os.WriteFile("./etcdCacheData_#sn#function", []byte("/sn/function/123/goodbye/latest|101|{\"name\":\"goodbye\",\"version\":\"latest\"}\n"), 0600) + os.WriteFile("./etcdCacheData_#sn#function_backup", []byte("/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n"), 0600) + err := ew.flushCacheToFile(nil) + convey.So(err, convey.ShouldNotBeNil) + os.Remove("etcdCacheMeta_#sn#function") + err = ew.flushCacheToFile(nil) + convey.So(err, convey.ShouldNotBeNil) + patch1.Reset() + patch2.Reset() + fileHack.Close() + os.Remove("./xxx") + }) + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") +} + +func TestRestoreCacheFromFile(t *testing.T) { + stopCh := make(chan struct{}, 1) + ew := &EtcdWatcher{ + key: "/sn/function", + ResultChan: make(chan *Event, 10), + CacheChan: make(chan *Event, 10), + configCh: make(chan struct{}, 1), + stopCh: stopCh, + cacheConfig: EtcdCacheConfig{ + EnableCache: true, + PersistPath: "./", + FlushInterval: 10, + FlushThreshold: 10, + }, + } + convey.Convey("no dataFile", t, func() { + err := ew.restoreCacheFromFile() + convey.So(err, convey.ShouldNotBeNil) + convey.So(len(ew.ResultChan), convey.ShouldEqual, 0) + }) + convey.Convey("backupFile exists", t, func() { + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") + // invalid metaFile + os.WriteFile("./etcdCacheMeta_#sn#function", []byte(`this is a invalid json`), 0600) + os.WriteFile("./etcdCacheData_#sn#function_backup", []byte("/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n"), 0600) + err := ew.restoreCacheFromFile() + convey.So(err, convey.ShouldNotBeNil) + convey.So(len(ew.ResultChan), convey.ShouldEqual, 0) + _, errStatMeta := os.Stat("./etcdCacheMeta_#sn#function") + _, errStatData := os.Stat("./etcdCacheData_#sn#function") + _, errStatBackup := os.Stat("./etcdCacheData_#sn#function_backup") + convey.So(os.IsNotExist(errStatMeta), convey.ShouldEqual, true) + convey.So(os.IsNotExist(errStatData), convey.ShouldEqual, true) + convey.So(os.IsNotExist(errStatBackup), convey.ShouldEqual, true) + // backupFile mismatches with metaFile + os.WriteFile("./etcdCacheMeta_#sn#function", []byte(`{"revision":100,"cacheMD5":"4f4449c598ec58854d7104c4a64e979f"}`), 0600) + os.WriteFile("./etcdCacheData_#sn#function_backup", []byte("/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n"), 0600) + err = ew.restoreCacheFromFile() + convey.So(err, convey.ShouldNotBeNil) + convey.So(len(ew.ResultChan), convey.ShouldEqual, 0) + _, errStatMeta = os.Stat("./etcdCacheMeta_#sn#function") + _, errStatData = os.Stat("./etcdCacheData_#sn#function") + _, errStatBackup = os.Stat("./etcdCacheData_#sn#function_backup") + convey.So(os.IsNotExist(errStatMeta), convey.ShouldEqual, true) + convey.So(os.IsNotExist(errStatData), convey.ShouldEqual, true) + convey.So(os.IsNotExist(errStatBackup), convey.ShouldEqual, true) + }) + convey.Convey("dataFile exist and no backupFile", t, func() { + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") + // dataFile mismatches with metaFile + os.WriteFile("./etcdCacheMeta_#sn#function", []byte(`{"revision":100,"cacheMD5":"4f4449c598ec58854d7104c4a64e979f"}`), 0600) + os.WriteFile("./etcdCacheData_#sn#function", []byte("/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n"), 0600) + err := ew.restoreCacheFromFile() + convey.So(err, convey.ShouldNotBeNil) + convey.So(len(ew.ResultChan), convey.ShouldEqual, 0) + _, errStatMeta := os.Stat("./etcdCacheMeta_#sn#function") + _, errStatData := os.Stat("./etcdCacheData_#sn#function") + _, errStatBackup := os.Stat("./etcdCacheData_#sn#function_backup") + convey.So(os.IsNotExist(errStatMeta), convey.ShouldEqual, true) + convey.So(os.IsNotExist(errStatData), convey.ShouldEqual, true) + convey.So(os.IsNotExist(errStatBackup), convey.ShouldEqual, true) + // dataFile matches with metaFile + os.WriteFile("./etcdCacheMeta_#sn#function", []byte(`{"revision":100,"cacheMD5":"03d9ff29f229e0123e427a1c84ad5afb"}`), 0600) + os.WriteFile("./etcdCacheData_#sn#function", []byte("/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\nThisIsInvalidData\n"), 0600) + err = ew.restoreCacheFromFile() + convey.So(err, convey.ShouldBeNil) + convey.So(len(ew.ResultChan), convey.ShouldEqual, 1) + event := <-ew.ResultChan + convey.So(event, convey.ShouldResemble, &Event{ + Rev: 100, + Key: "/sn/function/123/hello/latest", + Value: []byte(`{"name":"hello","version":"latest"}`), + }) + }) + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") +} + +func newEtcdWatcher() *EtcdWatcher { + return &EtcdWatcher{ + key: "/sn/function", + CacheChan: make(chan *Event, 10), + configCh: make(chan struct{}, 1), + cacheConfig: EtcdCacheConfig{ + EnableCache: true, + PersistPath: "./", + FlushInterval: 10, + FlushThreshold: 10, + }, + } +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/client.go b/yuanrong/pkg/common/faas_common/etcd3/client.go new file mode 100644 index 0000000..2ee6165 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/client.go @@ -0,0 +1,485 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 client +package etcd3 + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + "sync" + "time" + + "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" +) + +var ( + routerEtcdClient *EtcdClient + metaEtcdClient *EtcdClient + caeMetaEtcdClient *EtcdClient + dataSystemEtcdClient *EtcdClient +) + +const ( + // Router router etcd type + Router = "route" + + // Meta meta etcd type + Meta = "meta" + + // CAEMeta cae meta etcd type + CAEMeta = "CAEMeta" + + // DataSystem cae meta etcd type + DataSystem = "DataSystem" + + defaultEtcdLostContactTime = 5 * time.Minute +) + +var ( + errInitRouterEtcd = errors.New("failed to init router etcd client") + errInitMetadataEtcd = errors.New("failed to init metadata etcd client") + errInitCAEMetadataEtcd = errors.New("failed to init CAE metadata etcd client") + errInitDataSystemEtcd = errors.New("failed to init dataSystem etcd client") +) + +// GetRouterEtcdClient - +func GetRouterEtcdClient() *EtcdClient { + return routerEtcdClient +} + +// GetMetaEtcdClient - +func GetMetaEtcdClient() *EtcdClient { + return metaEtcdClient +} + +// GetCAEMetaEtcdClient - +func GetCAEMetaEtcdClient() *EtcdClient { + return caeMetaEtcdClient +} + +// GetDataSystemEtcdClient - +func GetDataSystemEtcdClient() *EtcdClient { + return dataSystemEtcdClient +} + +// GetEtcdStatusLostContact - +func (e *EtcdClient) GetEtcdStatusLostContact() bool { + return e.etcdStatusAfterLostContact +} + +// GetEtcdStatusNow - +func (e *EtcdClient) GetEtcdStatusNow() bool { + return e.etcdStatusNow +} + +// GetEtcdType - +func (e *EtcdClient) GetEtcdType() string { + return e.etcdType +} + +// InitParam - +func InitParam() *EtcdInitParam { + return new(EtcdInitParam) +} + +// WithRouteEtcdConfig - +func (e *EtcdInitParam) WithRouteEtcdConfig(config EtcdConfig) *EtcdInitParam { + e.routeEtcdConfig = &config + return e +} + +// WithMetaEtcdConfig - +func (e *EtcdInitParam) WithMetaEtcdConfig(config EtcdConfig) *EtcdInitParam { + e.metaEtcdConfig = &config + return e +} + +// WithCAEMetaEtcdConfig - +func (e *EtcdInitParam) WithCAEMetaEtcdConfig(config EtcdConfig) *EtcdInitParam { + e.CAEMetaEtcdConfig = &config + return e +} + +// WithDataSystemEtcdConfig - +func (e *EtcdInitParam) WithDataSystemEtcdConfig(config EtcdConfig) *EtcdInitParam { + e.DataSystemEtcdConfig = &config + return e +} + +// WithStopCh - +func (e *EtcdInitParam) WithStopCh(ch <-chan struct{}) *EtcdInitParam { + e.stopCh = ch + return e +} + +// WithAlarmSwitch - +func (e *EtcdInitParam) WithAlarmSwitch(enableAlarm bool) *EtcdInitParam { + e.enableAlarm = enableAlarm + return e +} + +// InitRouterEtcdClient - +func InitRouterEtcdClient(etcdConfig EtcdConfig, alarmConfig alarm.Config, stopCh <-chan struct{}) error { + if err := InitParam(). + WithRouteEtcdConfig(etcdConfig). + WithStopCh(stopCh). + WithAlarmSwitch(alarmConfig.EnableAlarm). + InitClient(); err != nil { + return err + } + if routerClient := GetRouterEtcdClient(); routerClient != nil { + if err := routerClient.EtcdHeatBeat(); err != nil { + errInfo := fmt.Sprintf("failed to check etcd client conn, err: %s", err.Error()) + log.GetLogger().Errorf(errInfo) + routerClient.reportOrClearAlarm(alarm.GenerateAlarmLog, errInfo, alarm.Level2) + time.Sleep(DurationContextTimeout) + return err + } + } + return nil +} + +// InitMetaEtcdClient - +func InitMetaEtcdClient(etcdConfig EtcdConfig, alarmConfig alarm.Config, stopCh <-chan struct{}) error { + if err := InitParam(). + WithMetaEtcdConfig(etcdConfig). + WithAlarmSwitch(alarmConfig.EnableAlarm). + WithStopCh(stopCh). + InitClient(); err != nil { + return err + } + if metaClient := GetMetaEtcdClient(); metaClient != nil { + if err := metaClient.EtcdHeatBeat(); err != nil { + errInfo := fmt.Sprintf("failed to check etcd client conn, err: %s", err.Error()) + log.GetLogger().Errorf(errInfo) + metaClient.reportOrClearAlarm(alarm.GenerateAlarmLog, errInfo, alarm.Level2) + time.Sleep(DurationContextTimeout) + return err + } + } + return nil +} + +// InitCAEMetaEtcdClient - +func InitCAEMetaEtcdClient(etcdConfig EtcdConfig, alarmConfig alarm.Config, stopCh <-chan struct{}) error { + if err := InitParam(). + WithCAEMetaEtcdConfig(etcdConfig). + WithAlarmSwitch(alarmConfig.EnableAlarm). + WithStopCh(stopCh). + InitClient(); err != nil { + return err + } + if metaClient := GetCAEMetaEtcdClient(); metaClient != nil { + if err := metaClient.EtcdHeatBeat(); err != nil { + errInfo := fmt.Sprintf("failed to check etcd client conn, err: %s", err.Error()) + log.GetLogger().Errorf(errInfo) + metaClient.reportOrClearAlarm(alarm.GenerateAlarmLog, errInfo, alarm.Level2) + time.Sleep(DurationContextTimeout) + return err + } + } + return nil +} + +// InitDataSystemEtcdClient - +func InitDataSystemEtcdClient(etcdConfig EtcdConfig, alarmConfig alarm.Config, stopCh <-chan struct{}) error { + if err := InitParam(). + WithDataSystemEtcdConfig(etcdConfig). + WithAlarmSwitch(alarmConfig.EnableAlarm). + WithStopCh(stopCh). + InitClient(); err != nil { + return err + } + if etcdClient := GetDataSystemEtcdClient(); etcdClient != nil { + if err := etcdClient.EtcdHeatBeat(); err != nil { + errInfo := fmt.Sprintf("failed to check etcd client conn, err: %s", err.Error()) + log.GetLogger().Errorf(errInfo) + etcdClient.reportOrClearAlarm(alarm.GenerateAlarmLog, errInfo, alarm.Level2) + time.Sleep(DurationContextTimeout) + return err + } + } + return nil +} + +// InitClient initialize etcdClient based on initialization parameters. +func (e *EtcdInitParam) InitClient() error { + if e.routeEtcdConfig != nil && e.initRouteEtcdClient() != nil { + return errInitRouterEtcd + } + if e.metaEtcdConfig != nil && e.initMetadataEtcdClient() != nil { + return errInitMetadataEtcd + } + if e.CAEMetaEtcdConfig != nil && e.initCAEMetadataEtcdClient() != nil { + return errInitCAEMetadataEtcd + } + if e.DataSystemEtcdConfig != nil && e.initDataSystemEtcdClient() != nil { + return errInitDataSystemEtcd + } + return nil +} + +func (e *EtcdInitParam) initRouteEtcdClient() error { + if routerEtcdClient != nil { + return nil + } + var err error + if routerEtcdClient, err = newClient(e.routeEtcdConfig, e.stopCh, e.enableAlarm, Router); err != nil { + log.GetLogger().Errorf("failed to new router etcd client with error: %s", err.Error()) + return err + } + return nil +} + +func (e *EtcdInitParam) initMetadataEtcdClient() error { + if metaEtcdClient != nil { + return nil + } + var err error + log.GetLogger().Infof("new meta etcd client") + if metaEtcdClient, err = newClient(e.metaEtcdConfig, e.stopCh, e.enableAlarm, Meta); err != nil { + log.GetLogger().Errorf("failed to new metadata etcd client with error: %s", err.Error()) + return err + } + return nil +} + +func (e *EtcdInitParam) initCAEMetadataEtcdClient() error { + if caeMetaEtcdClient != nil { + return nil + } + var err error + log.GetLogger().Infof("new CAE meta etcd client") + if caeMetaEtcdClient, err = newClient(e.CAEMetaEtcdConfig, e.stopCh, e.enableAlarm, CAEMeta); err != nil { + log.GetLogger().Errorf("failed to new CAE metadata etcd client with error: %s", err.Error()) + return err + } + return nil +} + +func (e *EtcdInitParam) initDataSystemEtcdClient() error { + if dataSystemEtcdClient != nil { + return nil + } + var err error + log.GetLogger().Infof("new DataSystem etcd client") + if dataSystemEtcdClient, err = newClient(e.DataSystemEtcdConfig, e.stopCh, e.enableAlarm, DataSystem); err != nil { + log.GetLogger().Errorf("failed to new DataSystem etcd client with error: %s", err.Error()) + return err + } + return nil +} + +func newClient(config *EtcdConfig, stopCh <-chan struct{}, enableAlarm bool, + etcdType string) (*EtcdClient, error) { + if stopCh == nil { + return nil, errors.New("etcd stopCh should not be nil") + } + client, err := buildClient(config) + if err != nil { + log.GetLogger().Errorf("failed to new %s etcd client, %s", etcdType, err.Error()) + return nil, err + } + client.stopCh = stopCh + client.config = config + client.etcdType = etcdType + client.isAlarmEnable = enableAlarm + client.etcdStatusAfterLostContact = true + client.etcdStatusNow = true + + go client.keepConnAlive() + return client, nil +} + +func buildClient(config *EtcdConfig) (*EtcdClient, error) { + cfg, err := GetEtcdAuthType(*config).GetEtcdConfig() + if err != nil { + log.GetLogger().Errorf("failed to create shared etcd client error %s", err.Error()) + return nil, err + } + cfg.DialTimeout = etcdDialTimeout + cfg.DialKeepAliveTime = etcdKeepaliveTime + cfg.DialKeepAliveTimeout = etcdKeepaliveTimeout + cfg.Endpoints = config.Servers + etcdClient, err := clientv3.New(*cfg) + if err != nil { + log.GetLogger().Errorf("failed to create shared etcd client error %s", err.Error()) + return nil, err + } + return &EtcdClient{ + Client: etcdClient, + clientExitCh: make(chan struct{}), + cond: sync.NewCond(&sync.Mutex{}), + }, nil +} + +func (e *EtcdClient) keepConnAlive() { + timer := time.NewTimer(keepConnAliveTTL) + for { + select { + case <-timer.C: + e.checkConnState() + timer.Reset(keepConnAliveTTL) + case _, ok := <-e.stopCh: + if !ok { + log.GetLogger().Warnf("stop channel is closed and quits keep %s etcd conn alive task", e.etcdType) + } + e.cond.Broadcast() + timer.Stop() + return + } + } +} + +// EtcdHeatBeat - +func (e *EtcdClient) EtcdHeatBeat() error { + ctx, cancel := context.WithTimeout(context.Background(), keepConnAliveTTL) + defer cancel() + _, err := e.Client.Get(ctx, "alive", clientv3.WithKeysOnly()) + return err +} + +func (e *EtcdClient) checkConnState() { + e.rwMutex.RLock() + err := e.EtcdHeatBeat() + e.rwMutex.RUnlock() + + if err != nil { + if e.etcdTimer == nil { + e.abnormalContinuouslyTimes++ + e.etcdTimer = time.AfterFunc(defaultEtcdLostContactTime, func() { + e.etcdStatusAfterLostContact = false + errInfo := fmt.Sprintf("etcd %s lost contact over %v, etcdStatusAfterLostContact is %v", + e.etcdType, defaultEtcdLostContactTime, e.etcdStatusAfterLostContact) + e.reportOrClearAlarm(alarm.GenerateAlarmLog, errInfo, alarm.Level3) + log.GetLogger().Warnf(errInfo) + }) + } + e.etcdStatusNow = false + e.exitOnce.Do(func() { + close(e.clientExitCh) + }) + errInfo := fmt.Sprintf("failed to check etcd client conn, err: %s", err.Error()) + log.GetLogger().Errorf(errInfo) + e.reportOrClearAlarm(alarm.GenerateAlarmLog, errInfo, alarm.Level2) + if err = e.restart(); err != nil { + log.GetLogger().Errorf("failed to restart etcd client, %s", err.Error()) + } + return + } + if e.etcdStatusAfterLostContact == false { + e.reportOrClearAlarm(alarm.ClearAlarmLog, "Clear critical alarm, "+ + "The connection to etcd has been restored", alarm.Level3) + } + if e.abnormalContinuouslyTimes > 0 { + e.reportOrClearAlarm(alarm.ClearAlarmLog, "Clear major alarm, "+ + "The connection to etcd has been restored", alarm.Level2) + e.abnormalContinuouslyTimes = 0 + } + + if e.etcdTimer != nil { + e.etcdTimer.Stop() + e.etcdTimer = nil + e.etcdStatusAfterLostContact = true + log.GetLogger().Infof("reconnect to %s etcd", e.etcdType) + } + if !e.etcdStatusNow { + e.clientExitCh = make(chan struct{}) + e.exitOnce = sync.Once{} + e.cond.Broadcast() + } + e.etcdStatusNow = true +} + +func (e *EtcdClient) reportOrClearAlarm(opType string, detail string, alarmLevel string) { + if e.isAlarmEnable { + alarmDetail := &alarm.Detail{ + SourceTag: os.Getenv(constant.PodNameEnvKey) + "|" + os.Getenv(constant.PodIPEnvKey) + + "|" + os.Getenv(constant.ClusterName) + "|MetadataEtcdConnection", + OpType: opType, + Details: detail, + StartTimestamp: 0, + EndTimestamp: 0, + } + if alarmDetail.OpType == alarm.GenerateAlarmLog { + alarmDetail.StartTimestamp = int(time.Now().Unix()) + } else { + alarmDetail.EndTimestamp = int(time.Now().Unix()) + } + alarmInfo := &alarm.LogAlarmInfo{ + AlarmID: alarm.MetadataEtcdConnection00001, + AlarmName: "MetadataEtcdConnection", + AlarmLevel: alarmLevel, + } + if e.etcdType == Router { + alarmDetail.SourceTag = os.Getenv(constant.PodNameEnvKey) + "|" + os.Getenv(constant.PodIPEnvKey) + + "|" + os.Getenv(constant.ClusterName) + "|RouterEtcdConnection" + alarmInfo.AlarmID = alarm.RouterEtcdConnection00001 + alarmInfo.AlarmName = "RouterEtcdConnection" + } + if e.etcdType == CAEMeta { + alarmDetail.SourceTag = os.Getenv(constant.PodNameEnvKey) + "|" + os.Getenv(constant.PodIPEnvKey) + + "|" + os.Getenv(constant.ClusterName) + "|CAEMetadataEtcdConnection" + alarmInfo.AlarmID = alarm.RouterEtcdConnection00001 + alarmInfo.AlarmName = "CAEMetadataEtcdConnection" + } + alarm.ReportOrClearAlarm(alarmInfo, alarmDetail) + } +} + +func (e *EtcdClient) restart() error { + log.GetLogger().Infof("start to rebuild %s etcd client", e.etcdType) + recreatedClient, err := buildClient(e.config) + if err != nil { + log.GetLogger().Errorf("failed to recreate %s etcd client, %s", e.etcdType, err.Error()) + return err + } + e.rwMutex.Lock() + e.stop() + e.Client = recreatedClient.Client + e.rwMutex.Unlock() + return nil +} + +func (e *EtcdClient) stop() { + if err := e.Client.Close(); err != nil { + log.GetLogger().Errorf("failed to close %s etcd client, %s", e.etcdType, err.Error()) + } +} + +// AttachAZPrefix - +func (e *EtcdClient) AttachAZPrefix(key string) string { + if e.config != nil && len(e.config.AZPrefix) != 0 { + return fmt.Sprintf("/%s%s", e.config.AZPrefix, key) + } + return key +} + +// DetachAZPrefix - +func (e *EtcdClient) DetachAZPrefix(key string) string { + if e.config != nil && len(e.config.AZPrefix) != 0 { + return strings.TrimPrefix(key, fmt.Sprintf("/%s", e.config.AZPrefix)) + } + return key +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/client_test.go b/yuanrong/pkg/common/faas_common/etcd3/client_test.go new file mode 100644 index 0000000..86151d0 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/client_test.go @@ -0,0 +1,363 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 client +package etcd3 + +import ( + "context" + "errors" + "reflect" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/alarm" +) + +type KvMock struct { +} + +func (k *KvMock) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + //TODO implement me + panic("implement me") +} + +func (k *KvMock) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, nil +} + +func (k *KvMock) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + //TODO implement me + panic("implement me") +} + +func (k *KvMock) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, error) { + //TODO implement me + panic("implement me") +} + +func (k *KvMock) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + //TODO implement me + panic("implement me") +} + +func (k *KvMock) Txn(ctx context.Context) clientv3.Txn { + //TODO implement me + panic("implement me") +} + +func TestInitEtcdClientOK(t *testing.T) { + stopCh := make(chan struct{}) + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(buildClient, func(config *EtcdConfig) (*EtcdClient, error) { + return &EtcdClient{clientExitCh: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{})}, nil + }), + } + defer func() { + close(stopCh) + for _, patch := range patches { + patch.Reset() + } + }() + + convey.Convey("new RouteClient", t, func() { + err := InitParam(). + WithRouteEtcdConfig(EtcdConfig{}). + WithStopCh(stopCh).InitClient() + convey.So(err, convey.ShouldBeNil) + convey.So(GetRouterEtcdClient(), convey.ShouldNotBeNil) + }) + + convey.Convey("new MetadataClient", t, func() { + err := InitParam(). + WithMetaEtcdConfig(EtcdConfig{}). + WithStopCh(stopCh).InitClient() + convey.So(err, convey.ShouldBeNil) + convey.So(GetMetaEtcdClient(), convey.ShouldNotBeNil) + }) + + convey.Convey("new CAEMetadataClient", t, func() { + err := InitParam(). + WithCAEMetaEtcdConfig(EtcdConfig{}). + WithStopCh(stopCh).InitClient() + convey.So(err, convey.ShouldBeNil) + convey.So(GetCAEMetaEtcdClient(), convey.ShouldNotBeNil) + }) + + convey.Convey("new DataSystemEtcdClient", t, func() { + err := InitParam(). + WithDataSystemEtcdConfig(EtcdConfig{}). + WithStopCh(stopCh).InitClient() + convey.So(err, convey.ShouldBeNil) + convey.So(GetDataSystemEtcdClient(), convey.ShouldNotBeNil) + }) +} + +func TestInitEtcdClientFail(t *testing.T) { + var stopCh chan struct{} + routerEtcdClient = nil + metaEtcdClient = nil + caeMetaEtcdClient = nil + convey.Convey("new RouteClient", t, func() { + err := InitParam(). + WithRouteEtcdConfig(EtcdConfig{}). + WithStopCh(stopCh).InitClient() + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("new MetadataClient", t, func() { + err := InitParam(). + WithMetaEtcdConfig(EtcdConfig{}). + WithStopCh(stopCh).InitClient() + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("new RouteClient", t, func() { + stopCh = make(chan struct{}) + err := InitParam(). + WithRouteEtcdConfig(EtcdConfig{}). + WithStopCh(stopCh).InitClient() + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("new CAE MetadataClient", t, func() { + err := InitParam(). + WithCAEMetaEtcdConfig(EtcdConfig{}). + WithStopCh(stopCh).InitClient() + convey.So(err, convey.ShouldNotBeNil) + }) + close(stopCh) +} + +func TestInitEtcdClientKeepAliveOK(t *testing.T) { + stopCh := make(chan struct{}) + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(clientv3.New, func(cfg clientv3.Config) (*clientv3.Client, error) { + return client, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(kv), "Get", func(_ *KvMock, ctx context.Context, key string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, nil + }), + gomonkey.ApplyGlobalVar(&keepConnAliveTTL, time.Duration(100)*time.Millisecond), + } + defer func() { + close(stopCh) + for _, patch := range patches { + patch.Reset() + } + }() + + convey.Convey("etcd client alive", t, func() { + err := InitParam(). + WithRouteEtcdConfig(EtcdConfig{}). + WithStopCh(stopCh).InitClient() + time.Sleep(200 * time.Millisecond) + convey.So(err, convey.ShouldBeNil) + convey.So(GetRouterEtcdClient().GetEtcdStatusNow(), convey.ShouldEqual, true) + }) +} + +func TestInitEtcdClientKeepAliveReconnect(t *testing.T) { + routerEtcdClient = nil + metaEtcdClient = nil + stopCh := make(chan struct{}) + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(clientv3.New, func(cfg clientv3.Config) (*clientv3.Client, error) { + return client, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(kv), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, errors.New("lost connection") + }), + gomonkey.ApplyMethod(reflect.TypeOf(client), "Close", func(_ *clientv3.Client) error { + return nil + }), + gomonkey.ApplyGlobalVar(&keepConnAliveTTL, time.Duration(100)*time.Millisecond), + } + defer func() { + close(stopCh) + for _, patch := range patches { + patch.Reset() + } + }() + + convey.Convey("lost etcd client and reconnect", t, func() { + err := InitParam(). + WithRouteEtcdConfig(EtcdConfig{}). + WithStopCh(stopCh).InitClient() + time.Sleep(200 * time.Millisecond) + convey.So(err, convey.ShouldBeNil) + convey.So(GetRouterEtcdClient().GetEtcdStatusNow(), convey.ShouldEqual, false) + + patches = append(patches, gomonkey.ApplyMethod(reflect.TypeOf(kv), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, nil + })) + time.Sleep(200 * time.Millisecond) + convey.So(GetRouterEtcdClient().GetEtcdStatusNow(), convey.ShouldEqual, true) + convey.So(GetRouterEtcdClient().GetEtcdStatusLostContact(), convey.ShouldEqual, true) + }) +} + +func TestInitMetaEtcdClient(t *testing.T) { + convey.Convey("InitMetaEtcdClient", t, func() { + convey.Convey("failed to init", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdInitParam{}), "InitClient", + func(e *EtcdInitParam) error { + return errors.New("failed to init") + }).Reset() + stop := make(chan struct{}) + err := InitMetaEtcdClient(EtcdConfig{}, alarm.Config{}, stop) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("failed to heat beat", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdInitParam{}), "InitClient", + func(e *EtcdInitParam) error { + metaEtcdClient = &EtcdClient{} + return nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(metaEtcdClient), "EtcdHeatBeat", func(e *EtcdClient) error { + return errors.New("failed to heart beat") + }).Reset() + stop := make(chan struct{}) + err := InitMetaEtcdClient(EtcdConfig{}, alarm.Config{}, stop) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestInitCAEMetaEtcdClient(t *testing.T) { + convey.Convey("InitCAEMetaEtcdClient", t, func() { + convey.Convey("failed to init", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdInitParam{}), "InitClient", + func(e *EtcdInitParam) error { + return errors.New("failed to init") + }).Reset() + stop := make(chan struct{}) + err := InitMetaEtcdClient(EtcdConfig{}, alarm.Config{}, stop) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("failed to heat beat", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdInitParam{}), "InitClient", + func(e *EtcdInitParam) error { + caeMetaEtcdClient = &EtcdClient{} + return nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(caeMetaEtcdClient), "EtcdHeatBeat", func(e *EtcdClient) error { + return errors.New("failed to heart beat") + }).Reset() + stop := make(chan struct{}) + err := InitCAEMetaEtcdClient(EtcdConfig{}, alarm.Config{}, stop) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestInitDataSystemEtcdClient(t *testing.T) { + convey.Convey("InitDataSystemEtcdClient", t, func() { + convey.Convey("failed to init", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdInitParam{}), "InitClient", + func(e *EtcdInitParam) error { + return errors.New("failed to init") + }).Reset() + stop := make(chan struct{}) + err := InitDataSystemEtcdClient(EtcdConfig{}, alarm.Config{}, stop) + convey.So(err.Error(), convey.ShouldContainSubstring, "failed to init") + }) + convey.Convey("failed to heat beat", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdInitParam{}), "InitClient", + func(e *EtcdInitParam) error { + dataSystemEtcdClient = &EtcdClient{} + return nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(dataSystemEtcdClient), "EtcdHeatBeat", func(e *EtcdClient) error { + return errors.New("failed to heart beat") + }).Reset() + stop := make(chan struct{}) + err := InitDataSystemEtcdClient(EtcdConfig{}, alarm.Config{}, stop) + convey.So(err.Error(), convey.ShouldContainSubstring, "failed to heart beat") + }) + }) +} + +func TestInitRouterEtcdClient(t *testing.T) { + convey.Convey("InitMetaEtcdClient", t, func() { + convey.Convey("failed to init", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdInitParam{}), "InitClient", + func(e *EtcdInitParam) error { + return errors.New("failed to init") + }).Reset() + stop := make(chan struct{}) + err := InitRouterEtcdClient(EtcdConfig{}, alarm.Config{}, stop) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("failed to heat beat", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdInitParam{}), "InitClient", + func(e *EtcdInitParam) error { + routerEtcdClient = &EtcdClient{} + return nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(routerEtcdClient), "EtcdHeatBeat", func(e *EtcdClient) error { + return errors.New("failed to heart beat") + }).Reset() + stop := make(chan struct{}) + err := InitRouterEtcdClient(EtcdConfig{}, alarm.Config{}, stop) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func Test_reportOrClearAlarm(t *testing.T) { + convey.Convey("reportOrClearAlarm", t, func() { + convey.Convey("no test assertion", func() { + e := EtcdClient{isAlarmEnable: true, etcdType: Router} + e.reportOrClearAlarm(alarm.GenerateAlarmLog, "告警", "INFO") + }) + }) +} + +func Test_AZPrefixProcess(t *testing.T) { + convey.Convey("test AZPrefix", t, func() { + convey.Convey("AttachAZPrefix", func() { + e := EtcdClient{config: &EtcdConfig{ + AZPrefix: "az1", + }} + key := e.AttachAZPrefix("/sn/instance/xxx") + convey.So(key, convey.ShouldEqual, "/az1/sn/instance/xxx") + e.config.AZPrefix = "" + key = e.AttachAZPrefix("/sn/instance/xxx") + convey.So(key, convey.ShouldEqual, "/sn/instance/xxx") + }) + convey.Convey("DetachAZPrefix", func() { + e := EtcdClient{config: &EtcdConfig{ + AZPrefix: "az1", + }} + key := e.DetachAZPrefix("/az1/sn/instance/xxx") + convey.So(key, convey.ShouldEqual, "/sn/instance/xxx") + e.config.AZPrefix = "" + key = e.DetachAZPrefix("/sn/instance/xxx") + convey.So(key, convey.ShouldEqual, "/sn/instance/xxx") + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/config.go b/yuanrong/pkg/common/faas_common/etcd3/config.go new file mode 100644 index 0000000..1a1f397 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/config.go @@ -0,0 +1,278 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "os" + + "go.etcd.io/etcd/client/v3" + + commonCrypto "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/sts" + "yuanrong/pkg/common/faas_common/sts/cert" + commontls "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/common/faas_common/utils" +) + +// EtcdAuth etcd authentication interface +type EtcdAuth interface { + GetEtcdConfig() (*clientv3.Config, error) +} + +type noAuth struct { +} + +type tlsAuth struct { + caFile string + certFile string + keyFile string + user string + password string +} + +type pwdAuth struct { + user string + password string +} + +type clientTLSAuth struct { + cerfile []byte + keyfile []byte + cafile []byte + passphrasefile []byte +} + +// GetEtcdAuthType etcd authentication type +func GetEtcdAuthType(etcdConfig EtcdConfig) EtcdAuth { + if etcdConfig.AuthType == "TLS" { + return &clientTLSAuth{ + cerfile: []byte(etcdConfig.CertFile), + keyfile: []byte(etcdConfig.KeyFile), + cafile: []byte(etcdConfig.CaFile), + passphrasefile: []byte(etcdConfig.PassphraseFile), + } + } + if etcdConfig.SslEnable { + if os.Getenv(sts.EnvSTSEnable) == "true" { + return &tlsAuth{} + } + return &tlsAuth{ + certFile: etcdConfig.CertFile, + keyFile: etcdConfig.KeyFile, + caFile: etcdConfig.CaFile, + user: etcdConfig.User, + password: etcdConfig.Password, + } + } + if etcdConfig.Password == "" { + return &noAuth{} + } + if len(etcdConfig.User) != 0 || len(etcdConfig.Password) != 0 { + return &pwdAuth{ + user: etcdConfig.User, + password: etcdConfig.Password, + } + } + return &noAuth{} +} + +func (n *noAuth) GetEtcdConfig() (*clientv3.Config, error) { + return &clientv3.Config{}, nil +} + +func (t *tlsAuth) GetEtcdConfig() (*clientv3.Config, error) { + if os.Getenv(sts.EnvSTSEnable) == "true" { + return BuildStsCfg() + } + pool, err := commontls.GetX509CACertPool(t.caFile) + if err != nil { + log.GetLogger().Errorf("failed to getX509CACertPool: %s", err.Error()) + return nil, err + } + + var certs []tls.Certificate + if certs, err = commontls.LoadServerTLSCertificate(t.certFile, t.keyFile, "", "LOCAL", false); err != nil { + log.GetLogger().Errorf("failed to loadServerTLSCertificate: %s", err.Error()) + return nil, err + } + + clientAuthMode := tls.NoClientCert + cfg := &clientv3.Config{ + TLS: &tls.Config{ + RootCAs: pool, + Certificates: certs, + ClientAuth: clientAuthMode, + }, + } + if len(t.user) != 0 && len(t.password) != 0 { + pwd, err := localauth.Decrypt(t.password) + if err != nil { + log.GetLogger().Errorf("failed to decrypt etcd config with error %s", err) + return nil, err + } + cfg.Username = t.user + cfg.Password = string(pwd) + utils.ClearStringMemory(t.password) + } + return cfg, nil +} + +func (p *pwdAuth) GetEtcdConfig() (*clientv3.Config, error) { + if len(p.user) == 0 || len(p.password) == 0 { + return nil, errors.New("etcd user or password is empty") + } + pwd, err := localauth.Decrypt(p.password) + if err != nil { + log.GetLogger().Errorf("failed to decrypt etcd config with error %s", err) + return nil, err + } + cfg := &clientv3.Config{ + Username: p.user, + Password: string(pwd), + } + utils.ClearStringMemory(p.password) + return cfg, nil +} + +func (c *clientTLSAuth) getPassphrase() ([]byte, error) { + // check whether the passphrasefile file exists. If the file exists, the client key is encrypted using a password. + // If the file does not exist, the client key is not encrypted and can be directly read. + var keyPwd []byte + var err error + if _, err = os.Stat(string(c.passphrasefile)); err == nil { + keyPwd, err = ioutil.ReadFile(string(c.passphrasefile)) + if err != nil { + log.GetLogger().Errorf("failed to read passphrasefile, err: %s", err.Error()) + return nil, err + } + if crypto.SCCInitialized() { + pwd, err := crypto.SCCDecrypt(keyPwd) + if err != nil { + log.GetLogger().Errorf("failed to decrypt passphrasefile, err: %s", err.Error()) + return nil, err + } + keyPwd = []byte(pwd) + } + } + + return keyPwd, nil +} + +func (c *clientTLSAuth) getTLSConfig(encryptedKeyPEM []byte, keyPwd []byte, + certPEM []byte, caCertPEM []byte) (*tls.Config, error) { + // Decode will find the next PEM formatted block (certificate, private key etc) in the input. + // It returns that block and the remainder of the input. + // If no PEM data is found, keyBlock is nil and the whole of the input is returned in rest. + // When keyBlock is nil, an error is reported. + // You do not need to pay attention to the content of the second return value. + keyBlock, _ := pem.Decode(encryptedKeyPEM) + if keyBlock == nil { + log.GetLogger().Errorf("failed to decode key PEM block") + return nil, fmt.Errorf("failed to decode key PEM block") + } + keyDER, err := commonCrypto.DecryptPEMBlock(keyBlock, keyPwd) + if err != nil { + log.GetLogger().Errorf("failed to decrypt key: err: %s", err.Error()) + return nil, err + } + + key, err := x509.ParsePKCS1PrivateKey(keyDER) + if err != nil { + log.GetLogger().Errorf("failed to parse private key: err: %s", err.Error()) + return nil, err + } + + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(caCertPEM) { + log.GetLogger().Errorf("failed to append CA certificate") + return nil, fmt.Errorf("failed to append CA certificate") + } + + clientCert, err := tls.X509KeyPair(certPEM, pem.EncodeToMemory( + &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})) + if err != nil { + log.GetLogger().Errorf("failed to create client certificate: %s", err.Error()) + return nil, err + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{clientCert}, + RootCAs: certPool, + } + return tlsConfig, nil +} + +func (c *clientTLSAuth) GetEtcdConfig() (*clientv3.Config, error) { + keyPwd, err := c.getPassphrase() + if err != nil { + return nil, err + } + + certPEM, err := ioutil.ReadFile(string(c.cerfile)) + if err != nil { + log.GetLogger().Errorf("failed to read cert file: %s", err.Error()) + return nil, err + } + + caCertPEM, err := ioutil.ReadFile(string(c.cafile)) + if err != nil { + log.GetLogger().Errorf("failed to read ca file: %s", err.Error()) + return nil, err + } + + encryptedKeyPEM, err := ioutil.ReadFile(string(c.keyfile)) + if err != nil { + log.GetLogger().Errorf("failed to read key file: %s", err.Error()) + return nil, err + } + tlsConfig, err := c.getTLSConfig(encryptedKeyPEM, keyPwd, certPEM, caCertPEM) + if err != nil { + return nil, err + } + return &clientv3.Config{ + TLS: tlsConfig, + }, nil +} + +// BuildStsCfg - Construct tlsConfig from sts p12 +func BuildStsCfg() (*clientv3.Config, error) { + caCertsPool, tlsCert, err := cert.LoadCerts() + if err != nil { + log.GetLogger().Errorf("failed to get X509CACertPool and TLSCertificate: %s", err.Error()) + return nil, err + } + + clientAuthMode := tls.NoClientCert + tlsConfig := &clientv3.Config{ + TLS: &tls.Config{ + RootCAs: caCertsPool, + Certificates: []tls.Certificate{*tlsCert}, + ClientAuth: clientAuthMode, + }, + } + return tlsConfig, nil +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/config_test.go b/yuanrong/pkg/common/faas_common/etcd3/config_test.go new file mode 100644 index 0000000..e518f79 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/config_test.go @@ -0,0 +1,469 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "os" + "strings" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + . "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + clientv3 "go.etcd.io/etcd/client/v3" + + commonCrypto "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/sts/cert" + commontls "yuanrong/pkg/common/faas_common/tls" + mockUtils "yuanrong/pkg/common/faas_common/utils" +) + +func TestGetEtcdAuthType(t *testing.T) { + convey.Convey("tlsAuth", t, func() { + etcdConfig := EtcdConfig{ + SslEnable: true, + } + etcdAuth := GetEtcdAuthType(etcdConfig) + convey.So(etcdAuth, convey.ShouldResemble, &tlsAuth{ + certFile: etcdConfig.CertFile, + keyFile: etcdConfig.KeyFile, + caFile: etcdConfig.CaFile, + }) + }) + convey.Convey("tlsAuth", t, func() { + etcdConfig := EtcdConfig{ + SslEnable: false, + Password: "", + } + etcdAuth := GetEtcdAuthType(etcdConfig) + convey.So(etcdAuth, convey.ShouldResemble, &noAuth{}) + }) + convey.Convey("tlsAuth", t, func() { + etcdConfig := EtcdConfig{ + SslEnable: false, + Password: "p123", + } + etcdAuth := GetEtcdAuthType(etcdConfig) + convey.So(etcdAuth, convey.ShouldResemble, &pwdAuth{ + user: etcdConfig.User, + password: etcdConfig.Password, + }) + }) + convey.Convey("clientTLSAuth", t, func() { + etcdConfig := EtcdConfig{ + AuthType: "TLS", + CaFile: "CaFile", + CertFile: "CertFile", + KeyFile: "KeyFile", + PassphraseFile: "PassphraseFile", + } + etcdAuth := GetEtcdAuthType(etcdConfig) + convey.So(etcdAuth, convey.ShouldResemble, &clientTLSAuth{ + cerfile: []byte("CertFile"), + keyfile: []byte("KeyFile"), + cafile: []byte("CaFile"), + passphrasefile: []byte("PassphraseFile"), + }) + }) +} + +func TestGetEtcdConfig(t *testing.T) { + defer gomonkey.ApplyFunc(localauth.Decrypt, func(src string) ([]byte, error) { + return []byte(strings.Clone(src)), nil + }).Reset() + convey.Convey("noAuth", t, func() { + noAuth := &noAuth{} + cfg, err := noAuth.GetEtcdConfig() + convey.So(cfg, convey.ShouldResemble, &clientv3.Config{}) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("tlsAuth", t, func() { + defer gomonkey.ApplyFunc(tls.X509KeyPair, func(certPEMBlock []byte, keyPEMBlock []byte) (tls.Certificate, error) { + return tls.Certificate{}, nil + }).Reset() + defer gomonkey.ApplyFunc(ioutil.ReadFile, func(filename string) ([]byte, error) { + return []byte{}, nil + }).Reset() + defer gomonkey.ApplyFunc(commontls.LoadServerTLSCertificate, func(certFile, keyFile, passPhase, decryptTool string, + isHTTPS bool) ([]tls.Certificate, error) { + return nil, nil + }).Reset() + tlsAuth := &tlsAuth{ + user: "root", + password: string([]byte("123")), + } + cfg, err := tlsAuth.GetEtcdConfig() + convey.So(err, convey.ShouldBeNil) + convey.So(cfg, convey.ShouldNotBeNil) + }) + convey.Convey("tlsAuth error", t, func() { + defer gomonkey.ApplyFunc(localauth.Decrypt, func(src string) ([]byte, error) { + return nil, errors.New("some error") + }).Reset() + tlsAuth := &tlsAuth{} + _, err := tlsAuth.GetEtcdConfig() + convey.So(err, convey.ShouldNotBeNil) + tlsAuth.user, tlsAuth.password = "root", "123" + _, err = tlsAuth.GetEtcdConfig() + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("pwdAuth", t, func() { + pwdAuth := &pwdAuth{ + user: "root", + password: string([]byte("123")), + } + cfg, err := pwdAuth.GetEtcdConfig() + convey.So(cfg.Password, convey.ShouldEqual, "123") + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("pwdAuth error", t, func() { + defer gomonkey.ApplyFunc(localauth.Decrypt, func(src string) ([]byte, error) { + return nil, errors.New("some error") + }).Reset() + pwdAuth := &pwdAuth{ + password: string([]byte("123")), + } + _, err := pwdAuth.GetEtcdConfig() + convey.So(err, convey.ShouldNotBeNil) + pwdAuth.user = "root" + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestGetPassphrase(t *testing.T) { + patches := []*Patches{ + ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return nil, nil + }), + ApplyFunc(os.ReadFile, func(string) ([]byte, error) { + return []byte("dummyPassphrase"), nil + }), + ApplyFunc(crypto.SCCInitialized, func() bool { + return false + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + + c := &clientTLSAuth{ + passphrasefile: []byte("path/to/passphrasefile"), + } + + passphrase, err := c.getPassphrase() + assert.Nil(t, err) + assert.Equal(t, []byte("dummyPassphrase"), passphrase) +} + +func TestBuildStsCfg(t *testing.T) { + tests := []struct { + name string + wantErr bool + patchesFunc mockUtils.PatchesFunc + }{ + {"case1", false, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(cert.LoadCerts, func() (*x509.CertPool, *tls.Certificate, + error) { + return &x509.CertPool{}, &tls.Certificate{}, nil + })}) + return patches + }}, + {"case2 LoadCerts error", true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(cert.LoadCerts, func() (*x509.CertPool, *tls.Certificate, + error) { + return &x509.CertPool{}, &tls.Certificate{}, errors.New("error") + })}) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + _, err := BuildStsCfg() + if (err != nil) != tt.wantErr { + t.Errorf("BuildStsCfg() error = %v, wantErr %v", err, tt.wantErr) + return + } + patches.ResetAll() + }) + } +} + +func TestGetTLSConfig(t *testing.T) { + // Create test data + testKey, _ := rsa.GenerateKey(rand.Reader, 2048) + keyDER := x509.MarshalPKCS1PrivateKey(testKey) + encryptedPEM := pem.EncodeToMemory(&pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Bytes: keyDER, + }) + certPEM := []byte("test cert") + caCertPEM := []byte("test CA cert") + keyPwd := []byte("test password") + + tests := []struct { + name string + mockSetups func() []*gomonkey.Patches + expectedError bool + errorContains string + }{ + { + name: "decrypt PEM block failure", + mockSetups: func() []*gomonkey.Patches { + var patches []*gomonkey.Patches + + p1 := gomonkey.ApplyFunc(pem.Decode, func(data []byte) (*pem.Block, []byte) { + return &pem.Block{}, nil + }) + patches = append(patches, p1) + + p2 := gomonkey.ApplyFunc(commonCrypto.DecryptPEMBlock, func(block *pem.Block, password []byte) ([]byte, error) { + return nil, fmt.Errorf("decrypt error") + }) + patches = append(patches, p2) + + return patches + }, + expectedError: true, + errorContains: "decrypt error", + }, + { + name: "parse private key failure", + mockSetups: func() []*gomonkey.Patches { + var patches []*gomonkey.Patches + + p1 := gomonkey.ApplyFunc(pem.Decode, func(data []byte) (*pem.Block, []byte) { + return &pem.Block{}, nil + }) + patches = append(patches, p1) + + p2 := gomonkey.ApplyFunc(commonCrypto.DecryptPEMBlock, func(block *pem.Block, password []byte) ([]byte, error) { + return []byte("invalid key"), nil + }) + patches = append(patches, p2) + + p3 := gomonkey.ApplyFunc(x509.ParsePKCS1PrivateKey, func(der []byte) (*rsa.PrivateKey, error) { + return nil, fmt.Errorf("parse error") + }) + patches = append(patches, p3) + + return patches + }, + expectedError: true, + errorContains: "parse error", + }, + { + name: "append CA cert failure", + mockSetups: func() []*gomonkey.Patches { + var patches []*gomonkey.Patches + + p1 := gomonkey.ApplyFunc(pem.Decode, func(data []byte) (*pem.Block, []byte) { + return &pem.Block{}, nil + }) + patches = append(patches, p1) + + p2 := gomonkey.ApplyFunc(commonCrypto.DecryptPEMBlock, func(block *pem.Block, password []byte) ([]byte, error) { + return keyDER, nil + }) + patches = append(patches, p2) + + p3 := gomonkey.ApplyFunc(x509.ParsePKCS1PrivateKey, func(der []byte) (*rsa.PrivateKey, error) { + return testKey, nil + }) + patches = append(patches, p3) + + p4 := gomonkey.ApplyMethod((*x509.CertPool)(nil), "AppendCertsFromPEM", + func(_ *x509.CertPool, _ []byte) bool { + return false + }) + patches = append(patches, p4) + + return patches + }, + expectedError: true, + errorContains: "failed to append CA certificate", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.mockSetups() + defer func() { + for _, p := range patches { + p.Reset() + } + }() + + c := &clientTLSAuth{} + config, err := c.getTLSConfig(encryptedPEM, keyPwd, certPEM, caCertPEM) + + if tt.expectedError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + assert.Nil(t, config) + } else { + assert.NoError(t, err) + assert.NotNil(t, config) + } + }) + } +} + +func mockTls(keyDER []byte, testKey *rsa.PrivateKey) []*Patches { + var patches []*gomonkey.Patches + + // Mock pem.Decode + p1 := gomonkey.ApplyFunc(pem.Decode, func(data []byte) (*pem.Block, []byte) { + return &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Bytes: keyDER, + }, nil + }) + patches = append(patches, p1) + + // Mock commonCrypto.DecryptPEMBlock + p2 := gomonkey.ApplyFunc(commonCrypto.DecryptPEMBlock, func(block *pem.Block, password []byte) ([]byte, error) { + return keyDER, nil + }) + patches = append(patches, p2) + + // Mock x509.ParsePKCS1PrivateKey + p3 := gomonkey.ApplyFunc(x509.ParsePKCS1PrivateKey, func(der []byte) (*rsa.PrivateKey, error) { + return testKey, nil + }) + patches = append(patches, p3) + + // Mock certPool.AppendCertsFromPEM + p4 := gomonkey.ApplyMethod((*x509.CertPool)(nil), "AppendCertsFromPEM", + func(_ *x509.CertPool, _ []byte) bool { + return true + }) + patches = append(patches, p4) + + // Mock tls.X509KeyPair + p5 := gomonkey.ApplyFunc(tls.X509KeyPair, func(certPEM, keyPEM []byte) (tls.Certificate, error) { + return tls.Certificate{}, nil + }) + patches = append(patches, p5) + + return patches +} + +func TestTlsAuthGetEtcdConfig(t *testing.T) { + patches := []*Patches{ + ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return nil, nil + }), + ApplyFunc(os.ReadFile, func(string) ([]byte, error) { + return []byte("dummyPassphrase"), nil + }), + ApplyFunc(crypto.SCCInitialized, func() bool { + return false + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + tests := []struct { + name string + mockSetups func() []*gomonkey.Patches + expectedError bool + errorContains string + }{ + { + name: "successful case", + mockSetups: func() []*gomonkey.Patches { + var patches []*gomonkey.Patches + + // Mock ioutil.ReadFile + p2 := gomonkey.ApplyFunc(ioutil.ReadFile, func(filename string) ([]byte, error) { + return []byte("test data"), nil + }) + patches = append(patches, p2) + + testKey, _ := rsa.GenerateKey(rand.Reader, 2048) + keyDER := x509.MarshalPKCS1PrivateKey(testKey) + tlsMocks := mockTls(keyDER, testKey) + patches = append(patches, tlsMocks...) + + return patches + }, + expectedError: false, + }, + { + name: "getPassphrase failure", + mockSetups: func() []*gomonkey.Patches { + p1 := gomonkey.ApplyFunc(pem.Decode, func(data []byte) (*pem.Block, []byte) { + return nil, nil + }) + return []*gomonkey.Patches{p1} + }, + expectedError: true, + errorContains: "failed to decode key PEM block", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.mockSetups() + defer func() { + for _, p := range patches { + p.Reset() + } + }() + + c := &clientTLSAuth{ + cerfile: []byte("cert.pem"), + cafile: []byte("ca.pem"), + keyfile: []byte("key.pem"), + } + config, err := c.GetEtcdConfig() + + if tt.expectedError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + assert.Nil(t, config) + } else { + assert.NoError(t, err) + assert.NotNil(t, config) + } + }) + } +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/event.go b/yuanrong/pkg/common/faas_common/etcd3/event.go new file mode 100644 index 0000000..d5e5ab7 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/event.go @@ -0,0 +1,106 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 event +package etcd3 + +import ( + "go.etcd.io/etcd/api/v3/mvccpb" + "go.etcd.io/etcd/client/v3" +) + +const ( + // PUT event + PUT = iota + // DELETE event + DELETE + // HISTORYDELETE event + HISTORYDELETE + // HISTORYUPDATE event + HISTORYUPDATE + // ERROR unexpected event + ERROR + // SYNCED synced event + SYNCED +) + +// Event of databases +type Event struct { + Type int + Key string + Value []byte + PrevValue []byte + Rev int64 + ETCDType string +} + +// only type can be used +// notice watcher, ready to watch etcd kv. +func syncedEvent() *Event { + return &Event{ + Type: SYNCED, + Key: "", + Value: nil, + PrevValue: nil, + Rev: 0, + ETCDType: "", + } +} + +// parseKV converts a KeyValue retrieved from an initial sync() listing to a synthetic isCreated event. +func parseKV(kv *mvccpb.KeyValue, etcdType string) *Event { + return &Event{ + Type: PUT, + Key: string(kv.Key), + Value: kv.Value, + PrevValue: nil, + Rev: kv.ModRevision, + ETCDType: etcdType, + } +} + +func parseEvent(e *clientv3.Event, etcdType string) *Event { + eType := PUT + if e.Type == clientv3.EventTypeDelete { + eType = DELETE + } + ret := &Event{ + Type: eType, + Key: string(e.Kv.Key), + Value: e.Kv.Value, + Rev: e.Kv.ModRevision, + ETCDType: etcdType, + } + if e.PrevKv != nil { + ret.PrevValue = e.PrevKv.Value + } + return ret +} + +func parseHistoryEvent(e *clientv3.Event, etcdType string) *Event { + event := parseEvent(e, etcdType) + if event.Type == DELETE { + event.Type = HISTORYDELETE + } + if event.Type == PUT { + event.Type = HISTORYUPDATE + } + return event +} + +func parseErr(err error, source string) *Event { + return &Event{Type: ERROR, Value: []byte(err.Error()), ETCDType: source} +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/event_test.go b/yuanrong/pkg/common/faas_common/etcd3/event_test.go new file mode 100644 index 0000000..fd2ae79 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/event_test.go @@ -0,0 +1,73 @@ +package etcd3 + +import ( + "errors" + "reflect" + "testing" + + "github.com/smartystreets/goconvey/convey" + "go.etcd.io/etcd/api/v3/mvccpb" + "go.etcd.io/etcd/client/v3" +) + +func Test_syncedEvent(t *testing.T) { + convey.Convey("syncedEvent", t, func() { + event := syncedEvent() + convey.So(event.Type, convey.ShouldEqual, SYNCED) + }) +} + +func Test_parseKV(t *testing.T) { + convey.Convey("parseKV", t, func() { + kv := parseKV(&mvccpb.KeyValue{Key: []byte("key1"), Value: []byte("value1")}, Router) + convey.So(kv.Key, convey.ShouldEqual, "key1") + convey.So(string(kv.Value), convey.ShouldEqual, "value1") + convey.So(kv.Type, convey.ShouldEqual, PUT) + }) +} + +func Test_parseEvent(t *testing.T) { + convey.Convey("parseEvent", t, func() { + event := parseEvent(&clientv3.Event{ + Type: DELETE, + Kv: &mvccpb.KeyValue{Key: []byte("key1"), Value: []byte("value1")}, + PrevKv: &mvccpb.KeyValue{Key: []byte("key2"), Value: []byte("value2")}, + }, Router) + convey.So(event.Type, convey.ShouldEqual, DELETE) + convey.So(event.Key, convey.ShouldEqual, "key1") + convey.So(string(event.Value), convey.ShouldEqual, "value1") + convey.So(string(event.PrevValue), convey.ShouldEqual, "value2") + }) +} + +func Test_parseErr(t *testing.T) { + convey.Convey("parseErr", t, func() { + err := parseErr(errors.New("parseErr"), Router) + convey.So(err.Type, convey.ShouldEqual, ERROR) + convey.So(string(err.Value), convey.ShouldEqual, "parseErr") + }) +} + +func Test_parseHistoryEvent(t *testing.T) { + type args struct { + e *clientv3.Event + } + tests := []struct { + name string + args args + wantType int + }{ + {"case1", args{e: &clientv3.Event{ + Type: DELETE, + Kv: &mvccpb.KeyValue{Key: []byte("key1"), Value: []byte("value1")}, + PrevKv: &mvccpb.KeyValue{Key: []byte("key2"), Value: []byte("value2")}, + }}, HISTORYDELETE}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseHistoryEvent(tt.args.e, Router); !reflect.DeepEqual(got.Type, tt.wantType) { + t.Errorf("parseHistoryEvent() = %v, want %v", got.Type, tt.wantType) + } + }) + } +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/instance_register.go b/yuanrong/pkg/common/faas_common/etcd3/instance_register.go new file mode 100644 index 0000000..d164e29 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/instance_register.go @@ -0,0 +1,147 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 - +package etcd3 + +import ( + "context" + "errors" + "time" + + "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +const ( + instanceEtcdKeyTTL = 30 + defaultRefreshInterval = 15 * time.Second +) + +var ( + refreshInterval = defaultRefreshInterval +) + +// EtcdRegister - register to specified ETCD +type EtcdRegister struct { + EtcdClient *EtcdClient + InstanceKey string + Value string + leaseID clientv3.LeaseID + StopCh <-chan struct{} +} + +// Register - register instance to meta etcd or router etcd +func (r *EtcdRegister) Register() error { + if r.EtcdClient != GetMetaEtcdClient() && r.EtcdClient != GetRouterEtcdClient() { + log.GetLogger().Errorf("etcdClient is not meta or route etcd") + return errors.New("etcdClient is not meta or route etcd") + } + var err error + err = r.putInstanceInfoToEtcd() + if err != nil { + log.GetLogger().Errorf("failed to register instance to %s etcd when start, error:%s", + r.EtcdClient.GetEtcdType(), err.Error()) + return err + } + go r.startRefreshLeaseJob() + return nil +} + +func (r *EtcdRegister) startRefreshLeaseJob() { + if r.StopCh == nil { + log.GetLogger().Errorf("StopCh is nil, lease in %s etcd will not be refreshed", + r.EtcdClient.GetEtcdType()) + return + } + refreshTicker := time.NewTicker(refreshInterval) + defer refreshTicker.Stop() + for { + select { + case <-refreshTicker.C: + r.refreshLease() + case <-r.StopCh: + log.GetLogger().Warnf("stopping refresh lease job") + refreshTicker.Stop() + r.stopLease() + return + } + } +} + +func (r *EtcdRegister) stopLease() { + revokeCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + err := r.EtcdClient.Revoke(revokeCtx, r.leaseID) + if err != nil { + log.GetLogger().Warnf("revoke lease in %s etcd failed, err:%s", + r.EtcdClient.GetEtcdType(), err.Error()) + } + ctx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + err = r.EtcdClient.Delete(ctx, r.InstanceKey) + if err != nil { + log.GetLogger().Errorf("delete key: %s,from %s etcd failed, err:%s", + r.InstanceKey, r.EtcdClient.GetEtcdType(), err.Error()) + } +} + +func (r *EtcdRegister) refreshLease() { + if !r.isKeyExist() { + if err := r.putInstanceInfoToEtcd(); err != nil { + return + } + } + keepAliveOnceCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + err := r.EtcdClient.KeepAliveOnce(keepAliveOnceCtx, r.leaseID) + if err != nil { + log.GetLogger().Errorf("unable to refresh lease in %s etcd:%s", + r.EtcdClient.GetEtcdType(), err.Error()) + } +} + +func (r *EtcdRegister) isKeyExist() bool { + ctx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + resp, err := r.EtcdClient.GetResponse(ctx, r.InstanceKey, + clientv3.WithKeysOnly(), clientv3.WithSerializable()) + if err != nil { + log.GetLogger().Errorf("failed to get new key:%s from %s etcd, err:%s", + r.InstanceKey, r.EtcdClient.GetEtcdType(), err.Error()) + return false + } + return len(resp.Kvs) > 0 +} + +func (r *EtcdRegister) putInstanceInfoToEtcd() error { + grantCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + id, err := r.EtcdClient.Grant(grantCtx, instanceEtcdKeyTTL) + if err != nil { + log.GetLogger().Errorf("failed to grant instance lease in %s etcd: %s", r.EtcdClient.GetEtcdType(), + err.Error()) + return err + } + + ctx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + err = r.EtcdClient.Put(ctx, r.InstanceKey, r.Value, clientv3.WithLease(id)) + if err != nil { + log.GetLogger().Errorf("unable to put new key:%s to %s etcd, err:%s", + r.InstanceKey, r.EtcdClient.GetEtcdType(), err.Error()) + return err + } + r.leaseID = id + log.GetLogger().Infof("register instance key:%s, value:%s to %s etcd successfully!", + r.InstanceKey, r.Value, r.EtcdClient.GetEtcdType()) + return nil +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/instance_register_test.go b/yuanrong/pkg/common/faas_common/etcd3/instance_register_test.go new file mode 100644 index 0000000..ec89d4b --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/instance_register_test.go @@ -0,0 +1,267 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 - +package etcd3 + +import ( + "context" + "errors" + "fmt" + "os" + "reflect" + "sync/atomic" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/api/v3/mvccpb" + "go.etcd.io/etcd/client/v3" +) + +type mockLease struct { +} + +func (m mockLease) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) { + return &clientv3.LeaseGrantResponse{ID: 1}, nil +} + +func (m mockLease) Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) { + return nil, nil +} + +func (m mockLease) TimeToLive(ctx context.Context, id clientv3.LeaseID, opts ...clientv3.LeaseOption) (*clientv3.LeaseTimeToLiveResponse, error) { + return nil, nil +} + +func (m mockLease) Leases(ctx context.Context) (*clientv3.LeaseLeasesResponse, error) { + return nil, nil +} + +func (m mockLease) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) { + return nil, nil +} + +func (m mockLease) KeepAliveOnce(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseKeepAliveResponse, error) { + return nil, nil +} + +func (m mockLease) Close() error { + panic("implement me") +} + +type mockKV struct { + put uint32 + get uint32 + delete uint32 + do uint32 +} + +func (fk *mockKV) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + atomic.AddUint32(&fk.put, 1) + return &clientv3.PutResponse{}, nil +} + +func (fk *mockKV) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + atomic.AddUint32(&fk.get, 1) + return &clientv3.GetResponse{Kvs: []*mvccpb.KeyValue{{}}}, nil +} + +func (fk *mockKV) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + atomic.AddUint32(&fk.delete, 1) + return &clientv3.DeleteResponse{}, nil +} + +func (mockKV) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, error) { + return &clientv3.CompactResponse{}, nil +} + +func (fk *mockKV) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + atomic.AddUint32(&fk.do, 1) + return clientv3.OpResponse{}, nil +} + +func (mockKV) Txn(ctx context.Context) clientv3.Txn { + return nil +} + +func TestRegisterInstance_PutInstanceToEtcd(t *testing.T) { + convey.Convey("test put instance info", t, func() { + patch := gomonkey.ApplyFunc(GetMetaEtcdClient, func() *EtcdClient { + return &EtcdClient{Client: &clientv3.Client{KV: &mockKV{}, Lease: &mockLease{}}} + }) + + defer func() { + patch.Reset() + }() + register := &EtcdRegister{ + EtcdClient: GetMetaEtcdClient(), + InstanceKey: "/sn/frontend/instances/CLUSTER_ID/HOST_IP/POD_NAME", + Value: "active", + } + convey.Convey("lease id not exist", func() { + var keyInput string + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdClient{}), "Put", func(_ *EtcdClient, + ctxInfo EtcdCtxInfo, key string, value string, opts ...clientv3.OpOption) error { + keyInput = key + return nil + }).Reset() + defer gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return key + }).Reset() + err := register.putInstanceInfoToEtcd() + convey.So(err, convey.ShouldBeNil) + convey.So(keyInput, convey.ShouldEqual, "/sn/frontend/instances/CLUSTER_ID/HOST_IP/POD_NAME") + }) + + convey.Convey("Grant error", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdClient{}), "Grant", func(_ *EtcdClient, + ctxInfo EtcdCtxInfo, ttl int64) (clientv3.LeaseID, error) { + return 111, fmt.Errorf("grant failed") + }).Reset() + defer gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return key + }).Reset() + err := register.putInstanceInfoToEtcd() + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("put error", func() { + var keyInput string + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdClient{}), "Put", func(_ *EtcdClient, + ctxInfo EtcdCtxInfo, key string, value string, opts ...clientv3.OpOption) error { + return fmt.Errorf("put failed") + }).Reset() + defer gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return key + }).Reset() + err := register.putInstanceInfoToEtcd() + convey.So(err, convey.ShouldNotBeNil) + convey.So(keyInput, convey.ShouldBeBlank) + convey.So(err.Error(), convey.ShouldEqual, "put failed") + }) + }) +} + +func Test_registerInstance(t *testing.T) { + etcdClient := &EtcdClient{Client: &clientv3.Client{Lease: &mockLease{}}} + patch := gomonkey.ApplyFunc(GetMetaEtcdClient, func() *EtcdClient { + return etcdClient + }) + patch.ApplyMethod(reflect.TypeOf(&EtcdClient{}), "Put", func(_ *EtcdClient, + ctxInfo EtcdCtxInfo, key string, value string, opts ...clientv3.OpOption) error { + return errors.New("put etcd error") + }) + defer func() { + patch.Reset() + }() + + register := &EtcdRegister{ + EtcdClient: GetMetaEtcdClient(), + InstanceKey: "/sn/frontend/instances/CLUSTER_ID/HOST_IP/POD_NAME", + Value: "active", + } + err := register.Register() + assert.NotNil(t, err) +} + +func Test_isKeyExist(t *testing.T) { + convey.Convey("Test isKeyExist", t, func() { + patch := gomonkey.ApplyFunc(GetMetaEtcdClient, func() *EtcdClient { + return &EtcdClient{ + Client: &clientv3.Client{}, + } + }) + defer patch.Reset() + register := &EtcdRegister{ + EtcdClient: GetMetaEtcdClient(), + InstanceKey: "/sn/frontend/instances/CLUSTER_ID/HOST_IP/POD_NAME", + Value: "active", + } + convey.Convey("get etcd key return empty", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&EtcdClient{}), "GetResponse", func(_ *EtcdClient, + ctxInfo EtcdCtxInfo, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return &clientv3.GetResponse{Kvs: []*mvccpb.KeyValue{}}, nil + }) + defer patch.Reset() + existed := register.isKeyExist() + convey.So(existed, convey.ShouldBeFalse) + }) + + convey.Convey("succeed to get etcd key", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&EtcdClient{}), "GetResponse", func(_ *EtcdClient, + ctxInfo EtcdCtxInfo, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return &clientv3.GetResponse{Kvs: []*mvccpb.KeyValue{{Key: []byte("key")}}}, nil + }) + defer patch.Reset() + existed := register.isKeyExist() + convey.So(existed, convey.ShouldBeTrue) + }) + + convey.Convey("failed to get etcd key", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&EtcdClient{}), "GetResponse", func(_ *EtcdClient, + ctxInfo EtcdCtxInfo, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, errors.New("failed") + }) + defer patch.Reset() + existed := register.isKeyExist() + convey.So(existed, convey.ShouldBeFalse) + }) + }) +} + +func Test_startRefreshLeaseJob(t *testing.T) { + kv := &mockKV{} + patches := gomonkey.NewPatches() + patches.ApplyFunc((*EtcdClient).Put, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, key string, value string, opts ...clientv3.OpOption) error { + return nil + }) + patches.ApplyFunc((*clientv3.Client).Ctx, func(_ *clientv3.Client) context.Context { return context.TODO() }) + patches.ApplyFunc((*clientv3.Client).Close, func(_ *clientv3.Client) error { return nil }) + patches.ApplyFunc(clientv3.NewKV, func(c *clientv3.Client) clientv3.KV { + return kv + }) + patches.ApplyFunc(GetMetaEtcdClient, func() *EtcdClient { + return &EtcdClient{Client: &clientv3.Client{KV: kv, Lease: &mockLease{}}} + }) + defer func() { + patches.Reset() + }() + refreshInterval = 1 * time.Millisecond + + register := &EtcdRegister{ + EtcdClient: GetMetaEtcdClient(), + InstanceKey: "/sn/frontend/instances/CLUSTER_ID/HOST_IP/POD_NAME", + Value: "active", + } + + // stop chan is nil, will not trigger refresh + go register.startRefreshLeaseJob() + time.Sleep(10 * time.Millisecond) + assert.Equal(t, uint32(0), atomic.LoadUint32(&kv.get)) + assert.Equal(t, uint32(0), atomic.LoadUint32(&kv.put)) + assert.Equal(t, uint32(0), atomic.LoadUint32(&kv.do)) + + stopCh := make(chan struct{}) + register.StopCh = stopCh + go register.startRefreshLeaseJob() + time.Sleep(100 * time.Millisecond) + assert.NotEqual(t, uint32(0), atomic.LoadUint32(&kv.get)) + close(stopCh) + time.Sleep(1 * time.Second) +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/lease.go b/yuanrong/pkg/common/faas_common/etcd3/lease.go new file mode 100644 index 0000000..bac561d --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/lease.go @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "go.etcd.io/etcd/client/v3" +) + +// Grant - +func (e *EtcdClient) Grant(ctxInfo EtcdCtxInfo, ttl int64) (clientv3.LeaseID, error) { + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + e.rwMutex.RLock() + resp, err := e.Client.Grant(ctx, ttl) + e.rwMutex.RUnlock() + cancel() + if err != nil { + return 0, err + } + return resp.ID, nil +} + +// KeepAliveOnce - +func (e *EtcdClient) KeepAliveOnce(ctxInfo EtcdCtxInfo, leaseID clientv3.LeaseID) error { + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + e.rwMutex.RLock() + _, err := e.Client.KeepAliveOnce(ctx, leaseID) + e.rwMutex.RUnlock() + cancel() + return err +} + +// Revoke - +func (e *EtcdClient) Revoke(ctxInfo EtcdCtxInfo, leaseID clientv3.LeaseID) error { + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + e.rwMutex.RLock() + _, err := e.Client.Revoke(ctx, leaseID) + e.rwMutex.RUnlock() + cancel() + return err +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/lease_test.go b/yuanrong/pkg/common/faas_common/etcd3/lease_test.go new file mode 100644 index 0000000..cb948aa --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/lease_test.go @@ -0,0 +1,69 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "context" + "testing" + "time" + + "github.com/smartystreets/goconvey/convey" + "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/utils" +) + +// TestEtcdClient_Grant - +func TestEtcdClient_Grant(t *testing.T) { + convey.Convey("test: grant", t, func() { + client := &EtcdClient{ + Client: &clientv3.Client{ + Lease: utils.FakeEtcdLease{}, + }, + } + id, err := client.Grant(CreateEtcdCtxInfoWithTimeout(context.Background(), 100*time.Millisecond), 10) + convey.So(id, convey.ShouldEqual, 1) + convey.So(err, convey.ShouldBeNil) + }) +} + +// TestEtcdClient_KeepAliveOnce - +func TestEtcdClient_KeepAliveOnce(t *testing.T) { + convey.Convey("test: keepAliveOnce", t, func() { + client := &EtcdClient{ + Client: &clientv3.Client{ + Lease: utils.FakeEtcdLease{}, + }, + } + err := client.KeepAliveOnce(CreateEtcdCtxInfoWithTimeout(context.Background(), 100*time.Millisecond), 1) + convey.So(err, convey.ShouldBeNil) + }) +} + +// TestEtcdClient_KeepAliveOnce - +func TestEtcdClient_Revoke(t *testing.T) { + convey.Convey("test: revoke", t, func() { + client := &EtcdClient{ + Client: &clientv3.Client{ + Lease: utils.FakeEtcdLease{}, + }, + } + err := client.Revoke(CreateEtcdCtxInfoWithTimeout(context.Background(), 100*time.Millisecond), 1) + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/lock.go b/yuanrong/pkg/common/faas_common/etcd3/lock.go new file mode 100644 index 0000000..dcd1adc --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/lock.go @@ -0,0 +1,282 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "time" + + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/client/v3/concurrency" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" +) + +const ( + defaultRequestTimeout = 30 * time.Second + refreshAheadTime = 1 * time.Second + lockedKeyHoldIndex = 1 +) + +var ( + // ErrEtcdResponseInvalid - + ErrEtcdResponseInvalid = errors.New("etcd response is invalid") + // ErrNoKeyCanBeFound - + ErrNoKeyCanBeFound = errors.New("no etcd key can be found") + // ErrNoKeyCanBeLocked - + ErrNoKeyCanBeLocked = errors.New("no etcd key can be locked") + lockFailCountLimit = 10 +) + +// EtcdLocker - +type EtcdLocker struct { + EtcdClient *EtcdClient + acquiredLock *concurrency.Mutex + LockedKey string + holderKey string + LeaseTTL int + leaseID clientv3.LeaseID + locked atomic.Uint32 + LockCallback func(locker *EtcdLocker) error + UnlockCallback func(locker *EtcdLocker) error + FailCallback func() + unlockCh chan struct{} + StopCh <-chan struct{} +} + +// GetLockedKey - +func (l *EtcdLocker) GetLockedKey() string { + return l.LockedKey +} + +// TryLockWithPrefix will get all identities(instanceID) distributed from control plane and try to lock one +func (l *EtcdLocker) TryLockWithPrefix(prefix string, filter func(k, v []byte) bool) error { + resp, err := l.EtcdClient.Get(CreateEtcdCtxInfoWithTimeout(context.TODO(), defaultRequestTimeout), prefix, + clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("failed to get prefix %s from etcd error %s", prefix, err.Error()) + return err + } + if len(resp.Kvs) == 0 { + log.GetLogger().Warnf("no etcd key is found for prefix %s", prefix) + return ErrNoKeyCanBeLocked + } + var ( + locked bool + tryLockErr error + ) + for _, kv := range resp.Kvs { + if filter(kv.Key, kv.Value) { + tryLockErr = ErrNoKeyCanBeLocked + continue + } + tryLockErr = l.TryLock(string(kv.Key)) + if tryLockErr == nil { + locked = true + break + } + } + if !locked { + if tryLockErr != nil { + return tryLockErr + } else { + return ErrNoKeyCanBeLocked + } + } + return nil +} + +// TryLock - +func (l *EtcdLocker) TryLock(key string) error { + if err := l.tryLock(key); err != nil { + return err + } + go l.lockKeeperLoop() + return nil +} + +func (l *EtcdLocker) tryLock(key string) error { + grtCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + if l.leaseID == clientv3.NoLease { + leaseID, err := l.EtcdClient.Grant(grtCtx, int64(l.LeaseTTL)) + if err != nil { + log.GetLogger().Errorf("failed to grant lease for key in %s etcd error %s", l.EtcdClient.GetEtcdType(), + err.Error()) + return err + } + l.leaseID = leaseID + } + l.holderKey = fmt.Sprintf("%s/%x", key, l.leaseID) + log.GetLogger().Infof("generate holderKey %s", l.holderKey) + var lockErr error + defer func() { + if lockErr != nil { + log.GetLogger().Errorf("failed to lock key %s, delete holder key %s", key, l.holderKey) + rvkCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + if err := l.EtcdClient.Revoke(rvkCtx, l.leaseID); err != nil { + log.GetLogger().Errorf("failed to revoke lease %d error %d", l.leaseID, err.Error()) + } + l.leaseID = clientv3.NoLease + delCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + if err := l.EtcdClient.Delete(delCtx, l.holderKey); err != nil { + log.GetLogger().Errorf("failed to delete holder key %s error %d", l.holderKey, err.Error()) + } + } + }() + cmp := clientv3.Compare(clientv3.LeaseValue(l.holderKey), "=", clientv3.NoLease) + put := clientv3.OpPut(l.holderKey, "", clientv3.WithLease(l.leaseID)) + get := clientv3.OpGet(l.holderKey) + // key is already been put, we want to get the minimum holder key so use WithLimit(2) + getKeyHolder := clientv3.OpGet(key, []clientv3.OpOption{clientv3.WithPrefix(), clientv3.WithSort( + clientv3.SortByCreateRevision, clientv3.SortAscend), clientv3.WithLimit(lockedKeyHoldIndex + 1)}...) + txnCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + var resp *clientv3.TxnResponse + resp, lockErr = l.EtcdClient.Client.Txn(txnCtx.Ctx).If(cmp).Then(put, getKeyHolder).Else(get, getKeyHolder).Commit() + if lockErr != nil { + log.GetLogger().Errorf("failed to lock key %s, transaction error %s", key, lockErr.Error()) + return lockErr + } + if len(resp.Responses) != lockedKeyHoldIndex+1 { + log.GetLogger().Errorf("failed to lock key %s, transaction response size %s is invalid", key, + len(resp.Responses)) + lockErr = ErrEtcdResponseInvalid + return lockErr + } + var myRevision int64 + if resp.Succeeded { + myRevision = resp.Header.Revision + } else { + if len(resp.Responses[0].GetResponseRange().Kvs) == 0 { + log.GetLogger().Errorf("failed to lock key %s, transaction response[0] kvs size is 0", key) + lockErr = ErrEtcdResponseInvalid + return lockErr + } + myRevision = resp.Responses[0].GetResponseRange().Kvs[0].CreateRevision + } + log.GetLogger().Infof("get holderKey %s my revision %d", l.holderKey, myRevision) + // resp.Responses[1] contains info got from getKeyHolder, ideally looks like [originKey, holderKey] after sorting, + // because originKey is put by control plane and has lower revision than any holderKey attached with a lease + holderKvs := resp.Responses[1].GetResponseRange().Kvs + // holderKvs[0] is not the originKey means originKey is deleted + if len(holderKvs) == 0 || string(holderKvs[0].Key) != key { + log.GetLogger().Warnf("failed to find key %s, key may be deleted", l.holderKey) + lockErr = ErrNoKeyCanBeFound + return lockErr + } + // holderKvs[1] has different revision from myRevision means other one has locked this key before me + if len(holderKvs) > 1 && holderKvs[1].CreateRevision != myRevision { + log.GetLogger().Warnf("failed to lock key %s, key already locked, holder revision %d", l.holderKey, + holderKvs[1].CreateRevision) + lockErr = ErrNoKeyCanBeLocked + return lockErr + } + l.LockedKey = key + l.unlockCh = make(chan struct{}) + if l.LockCallback != nil { + if lockErr = l.LockCallback(l); lockErr != nil { + log.GetLogger().Warnf("failed to process lock callback of %s error %s", key, lockErr.Error()) + return lockErr + } + } + log.GetLogger().Infof("succeed to lock key %s", key) + return nil +} + +// Unlock - +func (l *EtcdLocker) Unlock() error { + delCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + err := l.EtcdClient.Delete(delCtx, l.holderKey) + if err != nil { + log.GetLogger().Errorf("failed to unlock key %s , delete holder %s error %s", l.LockedKey, l.holderKey, + err.Error()) + } + if l.UnlockCallback != nil { + if err = l.UnlockCallback(l); err != nil { + log.GetLogger().Errorf("failed to process unlock callback of %s error %s", l.LockedKey, err.Error()) + } + } + l.LockedKey = "" + l.holderKey = "" + l.leaseID = clientv3.NoLease + utils.SafeCloseChannel(l.unlockCh) + return err +} + +func (l *EtcdLocker) lockKeeperLoop() { + leaseTicker := time.NewTicker(time.Duration(l.LeaseTTL)*time.Second - refreshAheadTime) + defer leaseTicker.Stop() + failCount := 0 + for { + select { + case _, ok := <-l.unlockCh: + if !ok { + log.GetLogger().Warnf("unlock channel triggers for etcd lock of key %s", l.LockedKey) + } + return + case _, ok := <-l.StopCh: + if !ok { + log.GetLogger().Warnf("stop channel triggers for etcd lock of key %s", l.LockedKey) + } + l.Unlock() + return + case <-leaseTicker.C: + if l.leaseID == clientv3.NoLease { + // wait for multiple leaseTTL time to make sure lease is expired at server side + time.Sleep(time.Duration(l.LeaseTTL) * time.Second) + if err := l.tryLock(l.LockedKey); err == ErrNoKeyCanBeFound || err == ErrNoKeyCanBeLocked { + log.GetLogger().Errorf("cannot keep lock key %s, lock fail count %d error %s", l.LockedKey, + failCount, err) + if failCount >= lockFailCountLimit { + l.FailCallback() + return + } + failCount++ + } else { + failCount = 0 + } + } else { + getCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + resp, err := l.EtcdClient.Get(getCtx, l.LockedKey) + if err != nil { + log.GetLogger().Errorf("unable to get locked key %s in %s etcd error %s", l.LockedKey, + l.EtcdClient.GetEtcdType(), err.Error()) + l.leaseID = clientv3.NoLease + continue + } + if len(resp.Kvs) == 0 { + log.GetLogger().Warnf("locked key %s is deleted in %s etcd unlock now", l.LockedKey, + l.EtcdClient.GetEtcdType()) + l.Unlock() + l.FailCallback() + return + } + keepAliveOnceCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + err = l.EtcdClient.KeepAliveOnce(keepAliveOnceCtx, l.leaseID) + if err != nil { + log.GetLogger().Errorf("unable to refresh lease in %s etcd error %s", l.EtcdClient.GetEtcdType(), + err.Error()) + l.leaseID = clientv3.NoLease + } + } + } + } +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/lock_test.go b/yuanrong/pkg/common/faas_common/etcd3/lock_test.go new file mode 100644 index 0000000..df2ea05 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/lock_test.go @@ -0,0 +1,341 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 implements crud and watch operations based etcd clientv3 +package etcd3 + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "go.etcd.io/etcd/api/v3/etcdserverpb" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" +) + +type fakeTxn struct { + response *clientv3.TxnResponse + err error +} + +func (t *fakeTxn) If(cs ...clientv3.Cmp) clientv3.Txn { + return t +} + +func (t *fakeTxn) Then(ops ...clientv3.Op) clientv3.Txn { + return t +} + +func (t *fakeTxn) Else(ops ...clientv3.Op) clientv3.Txn { + return t +} + +func (t *fakeTxn) Commit() (*clientv3.TxnResponse, error) { + return t.response, t.err +} + +type fakeKv struct { + txn *fakeTxn +} + +func (k *fakeKv) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + return nil, nil +} + +func (k *fakeKv) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, nil +} + +func (k *fakeKv) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + return nil, nil +} + +func (k *fakeKv) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, error) { + return nil, nil +} + +func (k *fakeKv) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + return clientv3.OpResponse{}, nil +} + +func (k *fakeKv) Txn(ctx context.Context) clientv3.Txn { + return k.txn +} + +func buildTxnResponse(success bool, revision int64, kvs1, kvs2 []*mvccpb.KeyValue) *clientv3.TxnResponse { + responses := []*etcdserverpb.ResponseOp{} + if kvs1 != nil { + responses = append(responses, &etcdserverpb.ResponseOp{ + Response: &etcdserverpb.ResponseOp_ResponseRange{ + ResponseRange: &etcdserverpb.RangeResponse{ + Kvs: kvs1, + }, + }, + }) + } + if kvs2 != nil { + responses = append(responses, &etcdserverpb.ResponseOp{ + Response: &etcdserverpb.ResponseOp_ResponseRange{ + ResponseRange: &etcdserverpb.RangeResponse{ + Kvs: kvs2, + }, + }, + }) + } + return &clientv3.TxnResponse{ + Succeeded: success, + Header: &etcdserverpb.ResponseHeader{Revision: revision}, + Responses: responses, + } +} + +func TestTryLock(t *testing.T) { + convey.Convey("test TryLock", t, func() { + ft := &fakeTxn{} + stopCh := make(chan struct{}) + lock := &EtcdLocker{EtcdClient: &EtcdClient{Client: &clientv3.Client{KV: &fakeKv{txn: ft}}}, LeaseTTL: 10, + StopCh: stopCh} + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc((*EtcdClient).Grant, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, ttl int64) (clientv3.LeaseID, + error) { + return 123, nil + }), + gomonkey.ApplyFunc((*EtcdClient).Revoke, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, leaseID clientv3.LeaseID) error { + return nil + }), + gomonkey.ApplyFunc((*EtcdClient).Delete, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) error { + return nil + }), + } + defer func() { + close(stopCh) + time.Sleep(100 * time.Millisecond) + for _, p := range patches { + p.Reset() + } + }() + convey.Convey("got and locked", func() { + patch1 := gomonkey.ApplyFunc((*EtcdClient).Get, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, key string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return &clientv3.GetResponse{Kvs: []*mvccpb.KeyValue{{Key: []byte("/test/key1")}}}, nil + }) + defer patch1.Reset() + ft.response = buildTxnResponse(true, 123, []*mvccpb.KeyValue{}, []*mvccpb.KeyValue{ + { + Key: []byte("/test/key1"), + CreateRevision: 100, + }, + }) + ft.err = nil + err := lock.TryLockWithPrefix("/test", func(k, v []byte) bool { return false }) + convey.So(err, convey.ShouldBeNil) + key := lock.GetLockedKey() + convey.So(key, convey.ShouldEqual, "/test/key1") + }) + convey.Convey("got error", func() { + patch1 := gomonkey.ApplyFunc((*EtcdClient).Get, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, key string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, errors.New("some error") + }) + defer patch1.Reset() + err := lock.TryLockWithPrefix("/test", func(k, v []byte) bool { return false }) + convey.So(err.Error(), convey.ShouldEqual, "some error") + }) + convey.Convey("lock key lost", func() { + patch1 := gomonkey.ApplyFunc((*EtcdClient).Get, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, key string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return &clientv3.GetResponse{Kvs: []*mvccpb.KeyValue{{Key: []byte("/test/key1")}}}, nil + }) + defer patch1.Reset() + ft.response = buildTxnResponse(true, 123, []*mvccpb.KeyValue{}, nil) + ft.err = nil + err := lock.TryLockWithPrefix("/test", func(k, v []byte) bool { return false }) + convey.So(err, convey.ShouldNotBeNil) + ft.response = buildTxnResponse(true, 123, []*mvccpb.KeyValue{}, []*mvccpb.KeyValue{ + { + Key: []byte("/test/key1/123"), + CreateRevision: 100, + }, + }) + ft.err = nil + err = lock.TryLockWithPrefix("/test", func(k, v []byte) bool { return false }) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("lock key locked by others", func() { + patch1 := gomonkey.ApplyFunc((*EtcdClient).Get, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, key string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return &clientv3.GetResponse{Kvs: []*mvccpb.KeyValue{{Key: []byte("/test/key1")}}}, nil + }) + defer patch1.Reset() + ft.response = buildTxnResponse(true, 123, []*mvccpb.KeyValue{}, []*mvccpb.KeyValue{ + { + Key: []byte("/test/key1"), + CreateRevision: 100, + }, + { + Key: []byte("/test/key1/xxx"), + CreateRevision: 101, + }, + }) + ft.err = nil + err := lock.TryLockWithPrefix("/test", func(k, v []byte) bool { return false }) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("lock callback", func() { + patch1 := gomonkey.ApplyFunc((*EtcdClient).Get, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, key string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return &clientv3.GetResponse{Kvs: []*mvccpb.KeyValue{{Key: []byte("/test/key1")}}}, nil + }) + defer patch1.Reset() + ft.response = buildTxnResponse(true, 123, []*mvccpb.KeyValue{}, []*mvccpb.KeyValue{ + { + Key: []byte("/test/key1"), + CreateRevision: 100, + }, + }) + ft.err = nil + lock.LockCallback = func(l *EtcdLocker) error { return errors.New("some error") } + err := lock.TryLockWithPrefix("/test", func(k, v []byte) bool { return false }) + convey.So(err.Error(), convey.ShouldEqual, "some error") + }) + }) +} + +func TestUnlock(t *testing.T) { + stopCh := make(chan struct{}) + lock := &EtcdLocker{EtcdClient: &EtcdClient{}, StopCh: stopCh} + convey.Convey("test Unlock", t, func() { + convey.Convey("unlock ok", func() { + patch := gomonkey.ApplyFunc((*EtcdClient).Delete, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) error { + return nil + }) + defer patch.Reset() + err := lock.Unlock() + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("unlock error", func() { + patch := gomonkey.ApplyFunc((*EtcdClient).Delete, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) error { + return errors.New("some error") + }) + defer patch.Reset() + err := lock.Unlock() + convey.So(err.Error(), convey.ShouldEqual, "some error") + }) + convey.Convey("unlock callback", func() { + patch := gomonkey.ApplyFunc((*EtcdClient).Delete, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) error { + return nil + }) + defer patch.Reset() + lock.UnlockCallback = func(l *EtcdLocker) error { return errors.New("some error") } + err := lock.Unlock() + convey.So(err.Error(), convey.ShouldEqual, "some error") + }) + }) +} + +func TestLockKeeperLoop(t *testing.T) { + convey.Convey("test lockKeeperLoop", t, func() { + stopCh := make(chan struct{}) + lock := &EtcdLocker{EtcdClient: &EtcdClient{}, LeaseTTL: 0, StopCh: stopCh} + getResp := &clientv3.GetResponse{} + getErr := error(nil) + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc((*EtcdClient).Get, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, key string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return getResp, getErr + }), + gomonkey.ApplyFunc((*EtcdLocker).Unlock, func(_ *EtcdLocker) error { + return nil + }), + gomonkey.ApplyGlobalVar(&lockFailCountLimit, 0), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + convey.Convey("ticker case 1", func() { + ticker := time.NewTicker(100 * time.Millisecond) + patch1 := gomonkey.ApplyFunc(time.NewTicker, func(d time.Duration) *time.Ticker { + return ticker + }) + defer patch1.Reset() + patch2 := gomonkey.ApplyFunc((*EtcdLocker).tryLock, func(_ *EtcdLocker, key string) error { + return ErrNoKeyCanBeFound + }) + defer patch2.Reset() + getErr = errors.New("get key error") + lock.leaseID = 123 + called := false + lock.FailCallback = func() { called = true } + lock.lockKeeperLoop() + convey.So(called, convey.ShouldBeTrue) + }) + convey.Convey("ticker case 2", func() { + ticker := time.NewTicker(100 * time.Millisecond) + patch1 := gomonkey.ApplyFunc(time.NewTicker, func(d time.Duration) *time.Ticker { + return ticker + }) + defer patch1.Reset() + patch2 := gomonkey.ApplyFunc((*EtcdClient).KeepAliveOnce, func(_ *EtcdClient, ctxInfo EtcdCtxInfo, + leaseID clientv3.LeaseID) error { + return errors.New("context deadline exceeded") + }) + defer patch2.Reset() + patch3 := gomonkey.ApplyFunc((*EtcdLocker).tryLock, func(_ *EtcdLocker, key string) error { + return ErrNoKeyCanBeFound + }) + defer patch3.Reset() + getErr = nil + getResp.Kvs = []*mvccpb.KeyValue{{}} + lock.leaseID = 123 + called := false + lock.FailCallback = func() { called = true } + lock.lockKeeperLoop() + convey.So(called, convey.ShouldBeTrue) + }) + convey.Convey("other case", func() { + called := false + patch1 := gomonkey.ApplyFunc((*EtcdLocker).Unlock, func(_ *EtcdLocker) error { + called = true + return nil + }) + defer patch1.Reset() + ticker := time.NewTicker(100 * time.Millisecond) + patch2 := gomonkey.ApplyFunc(time.NewTicker, func(d time.Duration) *time.Ticker { + return ticker + }) + defer patch2.Reset() + unlockCh := make(chan struct{}) + lock.unlockCh = unlockCh + close(unlockCh) + lock.lockKeeperLoop() + unlockCh = make(chan struct{}) + lock.unlockCh = unlockCh + close(stopCh) + lock.lockKeeperLoop() + convey.So(called, convey.ShouldBeTrue) + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/type.go b/yuanrong/pkg/common/faas_common/etcd3/type.go new file mode 100644 index 0000000..b54b67c --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/type.go @@ -0,0 +1,109 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 type +package etcd3 + +import ( + "sync" + "time" + + "go.etcd.io/etcd/client/v3" +) + +// EtcdWatcherFilter defines watch filter of etcd +type EtcdWatcherFilter func(*Event) bool + +// EtcdWatcherHandler defines watch handler of etcd +type EtcdWatcherHandler func(*Event) + +// EtcdClient wrapper etcd client +type EtcdClient struct { + Client *clientv3.Client + config *EtcdConfig + etcdTimer *time.Timer + rwMutex sync.RWMutex + cond *sync.Cond + // notify goroutine keepConnAlive exit + stopCh <-chan struct{} + clientExitCh chan struct{} + exitOnce sync.Once + etcdType string + // router etcd status lost contact after defaultEtcdLostContactTime, true is healthy, false is unhealthy + etcdStatusAfterLostContact bool + etcdStatusNow bool + isAlarmEnable bool + abnormalContinuouslyTimes int +} + +// EtcdWatcher - +type EtcdWatcher struct { + filter EtcdWatcherFilter + handler EtcdWatcherHandler + cacheConfig EtcdCacheConfig + watcher *EtcdClient + ResultChan chan *Event + CacheChan chan *Event + resultChanWG *sync.WaitGroup + configCh chan struct{} + stopCh <-chan struct{} + key string + etcdType string + initialRev int64 + historyRev int64 + cacheFlushing bool + sync.Mutex +} + +// EtcdInitParam - +type EtcdInitParam struct { + metaEtcdConfig *EtcdConfig + routeEtcdConfig *EtcdConfig + CAEMetaEtcdConfig *EtcdConfig + DataSystemEtcdConfig *EtcdConfig + stopCh <-chan struct{} + enableAlarm bool +} + +// EtcdConfig the info to get function instance +type EtcdConfig struct { + Servers []string `json:"servers" valid:"optional"` + AZPrefix string `json:"azPrefix" valid:"optional"` + User string `json:"user" valid:"optional"` + Password string `json:"password" valid:"optional"` + SslEnable bool `json:"sslEnable" valid:"optional"` + AuthType string `json:"authType" valid:"optional"` + UseSecret bool `json:"useSecret" valid:"optional"` + SecretName string `json:"secretName" valid:"optional"` + LimitRate int `json:"limitRate,omitempty" valid:"optional"` + LimitBurst int `json:"limitBurst,omitempty" valid:"optional"` + LimitTimeout int `json:"limitTimeout,omitempty" valid:"optional"` + CaFile string `json:"cafile,omitempty" valid:",optional"` + CertFile string `json:"certfile,omitempty" valid:",optional"` + KeyFile string `json:"keyfile,omitempty" valid:",optional"` + PassphraseFile string `json:"passphraseFile,omitempty" valid:",optional"` +} + +// EtcdCacheConfig - +type EtcdCacheConfig struct { + EnableCache bool `json:"enableCache"` + PersistPath string `json:"persistPath"` + FlushInterval int `json:"flushInterval"` + FlushThreshold int `json:"flushThreshold"` + MetaFilePath string + DataFilePath string + BackupFilePath string +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/utils.go b/yuanrong/pkg/common/faas_common/etcd3/utils.go new file mode 100644 index 0000000..fc55e67 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/utils.go @@ -0,0 +1,170 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 - +package etcd3 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "strings" + "sync" + "time" + + "go.etcd.io/etcd/client/v3" + "k8s.io/api/core/v1" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" +) + +const ( + // etcdDialTimeout is the timeout for establishing a connection. + etcdDialTimeout = 20 * time.Second + + // etcdKeepaliveTime is the time after which client pings the server to see if + etcdKeepaliveTime = 30 * time.Second + + // etcdKeepaliveTimeout is the time that the client waits for a response for the + etcdKeepaliveTimeout = 10 * time.Second + + etcdClientCerts = "etcd-client-certs" + + etcdCertsMountPath = "/home/snuser/resource/etcd" + + etcdCaFile = "/home/snuser/resource/etcd/ca.crt" + + etcdCertFile = "/home/snuser/resource/etcd/client.crt" + + etcdKeyFile = "/home/snuser/resource/etcd/client.key" + + etcdPassphraseFile = "/home/snuser/resource/etcd/passphrase" +) + +const ( + retrySleepTime = 100 * time.Millisecond + maxRetryTime = 3 +) + +var ( + etcdClientMap sync.Map +) + +// GetEtcdConfigKey generates key for etcd config +func GetEtcdConfigKey(etcdConfig *EtcdConfig) string { + sort.Strings(etcdConfig.Servers) + return strings.Join(etcdConfig.Servers, "#") +} + +func createETCDClient(config *EtcdConfig) (*clientv3.Client, error) { + cfg, err := GetEtcdAuthType(*config).GetEtcdConfig() + if err != nil { + log.GetLogger().Errorf("failed to create shared etcd client error %s", err.Error()) + return nil, err + } + cfg.DialTimeout = etcdDialTimeout + cfg.DialKeepAliveTime = etcdKeepaliveTime + cfg.DialKeepAliveTimeout = etcdKeepaliveTimeout + cfg.Endpoints = config.Servers + etcdClient, err := clientv3.New(*cfg) + if err != nil { + log.GetLogger().Errorf("failed to create shared etcd client error %s", err.Error()) + return nil, err + } + return etcdClient, nil +} + +// GetSharedEtcdClient returns a shared etcd client +func GetSharedEtcdClient(etcdConfig *EtcdConfig) (*clientv3.Client, error) { + etcdConfigKey := GetEtcdConfigKey(etcdConfig) + obj, exist := etcdClientMap.Load(etcdConfigKey) + var err error + if !exist { + if obj, err = createETCDClient(etcdConfig); err != nil { + return nil, err + } + } + etcdClient, ok := obj.(*clientv3.Client) + if !ok { + return nil, errors.New("etcd client type error") + } + etcdClientMap.Store(etcdConfigKey, etcdClient) + return etcdClient, nil +} + +// GetValueFromEtcdWithRetry query value from etcd and retry only in case of timeout +func GetValueFromEtcdWithRetry(key string, etcdClient *EtcdClient) ([]byte, error) { + if etcdClient.GetEtcdStatusLostContact() == false || etcdClient.Client == nil { + return nil, errors.New("etcd connection loss") + } + var ( + values []string + err error + ) + for i := 1; i <= maxRetryTime; i++ { + defaultEtcdCtx := CreateEtcdCtxInfoWithTimeout(context.Background(), DurationContextTimeout) + values, err = etcdClient.GetValues(defaultEtcdCtx, key) + if err == nil { + break + } + if err != context.DeadlineExceeded { + return nil, err + } + log.GetLogger().Errorf("get value from etcd with key %s timeout, try time %d", key, i) + time.Sleep(retrySleepTime) + } + + if len(values) == 0 { + log.GetLogger().Errorf("failed to get value from etcd, key: %s", key) + return nil, fmt.Errorf("the value got from etcd is empty") + } + + return []byte(values[0]), err +} + +// GenerateETCDClientCertsVolumesAndMounts - +func GenerateETCDClientCertsVolumesAndMounts(secretName string, builder *utils.VolumeBuilder) (string, string, error) { + if builder == nil { + return "", "", fmt.Errorf("etcd volume builder is nil") + } + builder.AddVolume(v1.Volume{Name: etcdClientCerts, + VolumeSource: v1.VolumeSource{Secret: &v1.SecretVolumeSource{SecretName: secretName}}}) + builder.AddVolumeMount(utils.ContainerRuntimeManager, + v1.VolumeMount{Name: etcdClientCerts, MountPath: etcdCertsMountPath}) + volumesData, err := json.Marshal(builder.Volumes) + if err != nil { + return "", "", err + } + volumesMountData, err := json.Marshal(builder.Mounts[utils.ContainerRuntimeManager]) + if err != nil { + return "", "", err + } + return string(volumesData), string(volumesMountData), nil +} + +// SetETCDTLSConfig - +func SetETCDTLSConfig(etcdConfig *EtcdConfig) { + if etcdConfig == nil { + return + } + etcdConfig.CaFile = etcdCaFile + etcdConfig.CertFile = etcdCertFile + etcdConfig.KeyFile = etcdKeyFile + etcdConfig.PassphraseFile = etcdPassphraseFile +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/utils_test.go b/yuanrong/pkg/common/faas_common/etcd3/utils_test.go new file mode 100644 index 0000000..d41fedd --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/utils_test.go @@ -0,0 +1,155 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 - +package etcd3 + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/client/v3" + "k8s.io/api/core/v1" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/utils" +) + +func TestGetSharedEtcdClient(t *testing.T) { + etcdConfig123 := &EtcdConfig{ + Servers: []string{"1", "2", "3"}, + } + convey.Convey("get client failed", t, func() { + defer gomonkey.ApplyFunc(clientv3.New, func(cfg clientv3.Config) (*clientv3.Client, error) { + return nil, errors.New("some error") + }).Reset() + _, err := GetSharedEtcdClient(etcdConfig123) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("get client success", t, func() { + defer gomonkey.ApplyFunc(clientv3.New, func(cfg clientv3.Config) (*clientv3.Client, error) { + return &clientv3.Client{}, nil + }).Reset() + _, err := GetSharedEtcdClient(etcdConfig123) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("load client", t, func() { + _, err := GetSharedEtcdClient(etcdConfig123) + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestGetValueFromEtcdWithRetry(t *testing.T) { + funcKey := "123/testFunc/1" + tenantID, funcName, funcVersion := utils.ParseFuncKey(funcKey) + silentEtcdKey := fmt.Sprintf(constant.SilentFuncKey, tenantID, funcName, funcVersion) + convey.Convey("Test GetValueFromEtcdWithRetry", t, func() { + convey.Convey("etcd connection loss", func() { + etcdClient := &EtcdClient{} + _, err := GetValueFromEtcdWithRetry(silentEtcdKey, etcdClient) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("get values error", func() { + etcdClient := &EtcdClient{ + etcdStatusAfterLostContact: true, + Client: &clientv3.Client{}, + } + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdClient{}), "GetValues", + func(_ *EtcdClient, ctxInfo EtcdCtxInfo, key string, opts ...clientv3.OpOption) ([]string, error) { + return nil, errors.New("error") + }).Reset() + _, err := GetValueFromEtcdWithRetry(silentEtcdKey, etcdClient) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("value got from etcd is empty", func() { + etcdClient := &EtcdClient{ + etcdStatusAfterLostContact: true, + Client: &clientv3.Client{}, + } + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdClient{}), "GetValues", + func(_ *EtcdClient, ctxInfo EtcdCtxInfo, key string, opts ...clientv3.OpOption) ([]string, error) { + return []string{}, nil + }).Reset() + _, err := GetValueFromEtcdWithRetry(silentEtcdKey, etcdClient) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("fetch success", func() { + etcdClient := &EtcdClient{ + etcdStatusAfterLostContact: true, + Client: &clientv3.Client{}, + } + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdClient{}), "GetValues", + func(_ *EtcdClient, ctxInfo EtcdCtxInfo, key string, opts ...clientv3.OpOption) ([]string, error) { + return []string{"silent func"}, nil + }).Reset() + value, err := GetValueFromEtcdWithRetry(silentEtcdKey, etcdClient) + convey.So(err, convey.ShouldBeNil) + convey.So(string(value), convey.ShouldEqual, "silent func") + }) + }) +} + +func TestGenerateETCDClientCertsVolumesAndMounts(t *testing.T) { + t.Run("builder is nil", func(t *testing.T) { + volumesData, volumesMountData, err := GenerateETCDClientCertsVolumesAndMounts("test-secret", nil) + assert.Empty(t, volumesData) + assert.Empty(t, volumesMountData) + assert.EqualError(t, err, "etcd volume builder is nil") + }) + + t.Run("normal case", func(t *testing.T) { + builder := utils.NewVolumeBuilder() + + secretName := "test-secret" + volumesData, volumesMountData, err := GenerateETCDClientCertsVolumesAndMounts(secretName, builder) + assert.NoError(t, err) + + var volumes []v1.Volume + err = json.Unmarshal([]byte(volumesData), &volumes) + assert.NoError(t, err) + assert.Len(t, volumes, 1) + assert.Equal(t, etcdClientCerts, volumes[0].Name) + assert.Equal(t, secretName, volumes[0].VolumeSource.Secret.SecretName) + + var volumeMounts []v1.VolumeMount + err = json.Unmarshal([]byte(volumesMountData), &volumeMounts) + assert.NoError(t, err) + assert.Len(t, volumeMounts, 1) + assert.Equal(t, etcdClientCerts, volumeMounts[0].Name) + assert.Equal(t, etcdCertsMountPath, volumeMounts[0].MountPath) + }) +} + +func TestSetETCDTLSConfig(t *testing.T) { + t.Run("etcdConfig", func(t *testing.T) { + SetETCDTLSConfig(nil) + + etcdConfig := &EtcdConfig{} + + SetETCDTLSConfig(etcdConfig) + + assert.Equal(t, etcdCaFile, etcdConfig.CaFile, "CaFile should be set correctly") + assert.Equal(t, etcdCertFile, etcdConfig.CertFile, "CertFile should be set correctly") + assert.Equal(t, etcdKeyFile, etcdConfig.KeyFile, "KeyFile should be set correctly") + assert.Equal(t, etcdPassphraseFile, etcdConfig.PassphraseFile, "PassphraseFile should be set correctly") + }) +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/watcher.go b/yuanrong/pkg/common/faas_common/etcd3/watcher.go new file mode 100644 index 0000000..4a6a57e --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/watcher.go @@ -0,0 +1,357 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcd3 - +package etcd3 + +import ( + "context" + "fmt" + "sync" + "time" + + "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +const ( + defaultEventChanSize = 1000 + // DurationContextTimeout default context duration timeout + DurationContextTimeout = 5 * time.Second +) + +var ( + // keepConnAliveTTL - + keepConnAliveTTL = 10 * time.Second +) + +// EtcdCtxInfo etcd context info +type EtcdCtxInfo struct { + Ctx context.Context + Cancel context.CancelFunc +} + +// Watcher defines watcher of registry +type Watcher interface { + StartWatch() + StartList() + EtcdHistory(revision int64) +} + +// EtcdClientInterface is the interface of ETCD client +type EtcdClientInterface interface { + GetResponse(ctxInfo EtcdCtxInfo, etcdKey string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) + Put(ctxInfo EtcdCtxInfo, etcdKey string, value string, opts ...clientv3.OpOption) error + Delete(ctxInfo EtcdCtxInfo, etcdKey string, opts ...clientv3.OpOption) error +} + +// NewEtcdWatcher create a EtcdWatcher object +func NewEtcdWatcher(prefix string, filter EtcdWatcherFilter, handler EtcdWatcherHandler, stopCh <-chan struct{}, + etcdClient *EtcdClient) *EtcdWatcher { + ew := &EtcdWatcher{ + watcher: etcdClient, + ResultChan: make(chan *Event, defaultEventChanSize), + CacheChan: make(chan *Event, defaultEventChanSize), + filter: filter, + handler: handler, + key: etcdClient.AttachAZPrefix(prefix), + resultChanWG: &sync.WaitGroup{}, + configCh: make(chan struct{}, 1), + stopCh: stopCh, + } + if etcdClient != nil { + ew.etcdType = etcdClient.GetEtcdType() + } + ew.resultChanWG.Add(1) + go ew.processEventLoop() + return ew +} + +// etcdList get current events in etcd and handle these events +func (ew *EtcdWatcher) etcdList(handler func(*clientv3.GetResponse)) error { + opts := []clientv3.OpOption{clientv3.WithPrefix()} + response, err := ew.watcher.Client.KV.Get(context.TODO(), ew.key, opts...) + if err != nil { + log.GetLogger().Errorf("failed to get value from etcd, key: %s, err: %s", ew.key, err.Error()) + return err + } + ew.initialRev = response.Header.Revision + handler(response) + return nil +} + +// EtcdHistory find if delete event happened while recovering +func (ew *EtcdWatcher) EtcdHistory(revision int64) { + if revision == 0 || revision >= ew.initialRev { + return + } + log.GetLogger().Debugf("start to find key %s history event", ew.key) + watchOption := []clientv3.OpOption{clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision), + clientv3.WithProgressNotify()} + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + watchChan := clientv3.NewWatcher(ew.watcher.Client).Watch(ctx, ew.key, watchOption...) + if watchChan == nil { + log.GetLogger().Errorf("failed to watch %s, watch channel is empty", ew.key) + return + } + events, ok := <-watchChan + if !ok { + log.GetLogger().Warnf("the channel received the result may be closed") + return + } + for _, event := range events.Events { + ew.sendEvent(parseHistoryEvent(event, ew.etcdType)) + } +} + +// StartWatch start watch etcd event +func (ew *EtcdWatcher) StartWatch() { + go ew.recoverWatch() + if !ew.watcher.etcdStatusNow { + log.GetLogger().Warnf("no connection with etcd.") + return + } + go ew.run() +} + +// recoverWatch recover watch etcd event when etcd reconnected +func (ew *EtcdWatcher) recoverWatch() { +loop: + for { + if ew.watcher.cond == nil { + log.GetLogger().Warnf("etcd client condition lock is not initialized") + return + } + ew.watcher.cond.L.Lock() + ew.watcher.cond.Wait() + ew.watcher.cond.L.Unlock() + select { + case <-ew.stopCh: + break loop + default: + } + go ew.run() + } + ew.resultChanWG.Wait() + close(ew.ResultChan) +} + +func (ew *EtcdWatcher) run() { + log.GetLogger().Infof("start to watch etcd prefix %s", ew.key) + if ew.watcher.Client == nil { + log.GetLogger().Errorf("failed to watch %s, etcd client is nil", ew.key) + return + } + if ew.cacheConfig.EnableCache { + go ew.processETCDCache() + } + ew.StartList() + watchChan, cancel, err := createWatchChan(ew) + defer cancel() + if err != nil || watchChan == nil { + return + } + for { + select { + case events, ok := <-watchChan: + if !ok { + cancel() + log.GetLogger().Warnf("the channel received the result may be closed") + watchChan, cancel, err = createWatchChan(ew) + if err != nil { + return + } + continue + } + if events.Err() != nil { + log.GetLogger().Errorf("etcd receive err events, err:%s", events.Err().Error()) + } + if ew.historyRev > 0 && ew.historyRev < ew.initialRev { + ew.EtcdHistory(ew.historyRev) + } + for _, event := range events.Events { + e := parseEvent(event, ew.etcdType) + ew.initialRev = e.Rev + ew.historyRev = ew.initialRev + ew.sendEvent(e) + } + case <-ew.stopCh: + log.GetLogger().Infof("stop watching etcd prefix %s", ew.key) + return + case <-ew.watcher.clientExitCh: + log.GetLogger().Errorf("lost %s etcd client", ew.watcher.etcdType) + return + } + } +} + +func createWatchChan(ew *EtcdWatcher) (clientv3.WatchChan, context.CancelFunc, error) { + watchOption := []clientv3.OpOption{clientv3.WithPrefix(), clientv3.WithPrevKV(), + clientv3.WithRev(ew.initialRev), clientv3.WithProgressNotify()} + ctx, cancelFunc := context.WithCancel(context.Background()) + if err := ew.etcdList(func(_ *clientv3.GetResponse) {}); err != nil { + log.GetLogger().Errorf("failed to etcdList, err: %s", err.Error()) + return nil, cancelFunc, err + } + watchChan := clientv3.NewWatcher(ew.watcher.Client).Watch(ctx, ew.key, watchOption...) + if watchChan == nil { + log.GetLogger().Errorf("failed to watch %s, watch channel is empty", ew.key) + return nil, cancelFunc, fmt.Errorf("failed to watch %s, watch channel is empty", ew.key) + } + return watchChan, cancelFunc, nil +} + +// StartList performs a ETCD List and send corresponding events, revision will be set after list +func (ew *EtcdWatcher) StartList() { + if ew.initialRev == 0 { + var restoreErr error + if ew.cacheConfig.EnableCache { + restoreErr = ew.restoreCacheFromFile() + } + if !ew.cacheConfig.EnableCache || restoreErr != nil { + if err := ew.etcdList(func(response *clientv3.GetResponse) { + for _, event := range response.Kvs { + ew.sendEvent(parseKV(event, ew.etcdType)) + } + }); err != nil { + log.GetLogger().Errorf("failed to sync with latest state, error: %s", err.Error()) + } + } + // notice watcher, ready to watch etcd kv + ew.sendEvent(syncedEvent()) + ew.historyRev = ew.initialRev + } +} + +// processEventLoop receive etcd event and process +func (ew *EtcdWatcher) processEventLoop() { + defer ew.resultChanWG.Done() + for { + select { + case event, ok := <-ew.ResultChan: + if !ok { + log.GetLogger().Warnf("event channel is closed, stop processing event") + return + } + if event.Type == SYNCED || !ew.filter(event) { + ew.handler(event) + } + case <-ew.stopCh: + log.GetLogger().Warnf("stop processing etcd event loop") + return + } + } +} + +func (ew *EtcdWatcher) sendEvent(e *Event) { + if len(ew.ResultChan) == defaultEventChanSize { + log.GetLogger().Warnf("Fast watcher, slow processing. Number of buffered events: %d."+ + "Probably caused by slow decoding, user not receiving fast, or other processing logic", + defaultEventChanSize) + } + if ew.watcher != nil { + e.Key = ew.watcher.DetachAZPrefix(e.Key) + } + select { + case ew.ResultChan <- e: + case <-ew.stopCh: + log.GetLogger().Warnf("etcd watcher chan closed") + } + if ew.cacheConfig.EnableCache && (e.Type == PUT || e.Type == DELETE) { + select { + case ew.CacheChan <- e: + case <-ew.stopCh: + log.GetLogger().Warnf("etcd watcher chan closed") + } + } +} + +// GetResponse get etcd value and return pointer of GetResponse struct +func (e *EtcdClient) GetResponse(ctxInfo EtcdCtxInfo, etcdKey string, + opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + etcdKey = e.AttachAZPrefix(etcdKey) + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + defer cancel() + + kv := clientv3.NewKV(e.Client) + getResp, err := kv.Get(ctx, etcdKey, opts...) + + return getResp, err +} + +// Put put context key and value +func (e *EtcdClient) Put(ctxInfo EtcdCtxInfo, etcdKey string, value string, opts ...clientv3.OpOption) error { + etcdKey = e.AttachAZPrefix(etcdKey) + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + defer cancel() + + kv := clientv3.NewKV(e.Client) + _, err := kv.Put(ctx, etcdKey, value, opts...) + return err +} + +// Delete delete key +func (e *EtcdClient) Delete(ctxInfo EtcdCtxInfo, etcdKey string, opts ...clientv3.OpOption) error { + etcdKey = e.AttachAZPrefix(etcdKey) + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + defer cancel() + kv := clientv3.NewKV(e.Client) + _, err := kv.Delete(ctx, etcdKey, opts...) + return err +} + +// Get gets from etcd +func (e *EtcdClient) Get(ctxInfo EtcdCtxInfo, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + key = e.AttachAZPrefix(key) + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + defer cancel() + kv := clientv3.NewKV(e.Client) + response, err := kv.Get(ctx, key, opts...) + if err != nil { + return nil, err + } + return response, nil +} + +// GetValues return list of object for key +func (e *EtcdClient) GetValues(ctxInfo EtcdCtxInfo, key string, opts ...clientv3.OpOption) ([]string, error) { + key = e.AttachAZPrefix(key) + ctx, cancel := ctxInfo.Ctx, ctxInfo.Cancel + defer cancel() + + kv := clientv3.NewKV(e.Client) + response, err := kv.Get(ctx, key, opts...) + if err != nil { + return nil, err + } + values := make([]string, len(response.Kvs)) + + for index, v := range response.Kvs { + values[index] = string(v.Value) + } + return values, err +} + +// CreateEtcdCtxInfoWithTimeout create a context with timeout, default timeout is DurationContextTimeout +func CreateEtcdCtxInfoWithTimeout(ctx context.Context, duration time.Duration) EtcdCtxInfo { + ctx, cancel := context.WithTimeout(ctx, duration) + return EtcdCtxInfo{ + Ctx: ctx, + Cancel: cancel, + } +} diff --git a/yuanrong/pkg/common/faas_common/etcd3/watcher_test.go b/yuanrong/pkg/common/faas_common/etcd3/watcher_test.go new file mode 100644 index 0000000..5938359 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/etcd3/watcher_test.go @@ -0,0 +1,456 @@ +package etcd3 + +import ( + "context" + "errors" + "fmt" + "os" + "reflect" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/api/v3/etcdserverpb" + "go.etcd.io/etcd/api/v3/mvccpb" + "go.etcd.io/etcd/client/v3" + + mockUtils "yuanrong/pkg/common/faas_common/utils" +) + +var watchChan clientv3.WatchChan +var resultCh chan *Event + +type EtcdWatcherMock struct { +} + +func (e EtcdWatcherMock) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan { + watchChan = make(chan clientv3.WatchResponse, 1) + return watchChan +} + +func (e EtcdWatcherMock) RequestProgress(ctx context.Context) error { + //TODO implement me + panic("implement me") +} + +func (e EtcdWatcherMock) Close() error { + //TODO implement me + panic("implement me") +} + +func TestNewEtcdWatcher(t *testing.T) { + prefix := "" + filter := func(event *Event) bool { return true } + handler := func(event *Event) {} + stopCh := make(chan struct{}) + + convey.Convey("Test NewEtcdWatcher", t, func() { + + convey.Convey("Test NewEtcdWatcher for success", func() { + etcdClient := GetRouterEtcdClient() + watcher := NewEtcdWatcher(prefix, filter, handler, stopCh, etcdClient) + convey.So(watcher, convey.ShouldNotBeNil) + }) + }) +} + +func TestEtcdList(t *testing.T) { + convey.Convey("StartList", t, func() { + stopCh := make(chan struct{}) + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &EtcdClient{Client: client, clientExitCh: make(chan struct{}), etcdStatusNow: true, cond: sync.NewCond(&sync.Mutex{})} + resultCh = make(chan *Event, 2) + watcher := NewEtcdWatcher("/xxx", etcdFilter, etcdHandler, stopCh, etcdClient) + defer gomonkey.ApplyMethod(reflect.TypeOf(kv), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + c := &clientv3.GetResponse{ + Header: &etcdserverpb.ResponseHeader{}, + Kvs: []*mvccpb.KeyValue{{Key: []byte("/xxx1"), Value: []byte("value1")}}, + } + c.Header.Revision = 1 + return c, nil + }).Reset() + watcher.StartList() + event := <-resultCh + convey.So(event.Type, convey.ShouldEqual, PUT) + convey.So(event.Key, convey.ShouldEqual, "/xxx1") + convey.So(string(event.Value), convey.ShouldEqual, "value1") + event = <-resultCh + convey.So(event.Type, convey.ShouldEqual, SYNCED) + close(stopCh) + }) +} + +func etcdFilter(event *Event) bool { + return false +} + +func etcdHandler(event *Event) { + resultCh <- event +} + +func erFail(t *testing.T) { + convey.Convey("failed to watch", t, func() { + convey.Convey("no connection with etcd", func() { + stopCh := make(chan struct{}) + etcdClient := &EtcdClient{clientExitCh: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{})} + watcher := NewEtcdWatcher("/xxx", etcdFilter, etcdHandler, stopCh, etcdClient) + watcher.StartWatch() + watcher.watcher.cond.Broadcast() + close(stopCh) + }) + convey.Convey("recover watcher", func() { + exitCh := make(chan struct{}, 1) + stopCh := make(chan struct{}, 1) + etcdClient := &EtcdClient{clientExitCh: exitCh, cond: sync.NewCond(&sync.Mutex{})} + etcdClient.etcdStatusNow = true + e := &EtcdWatcher{watcher: etcdClient, resultChanWG: &sync.WaitGroup{}, stopCh: stopCh, ResultChan: make(chan *Event)} + e.resultChanWG.Add(1) + go e.StartWatch() + exitCh <- struct{}{} + etcdClient.etcdStatusNow = true + time.Sleep(1 * time.Second) + close(stopCh) + e.watcher.cond.Broadcast() + e.resultChanWG.Done() + _, ok := <-e.ResultChan + convey.So(ok, convey.ShouldEqual, false) + }) + }) +} + +func TestEtcdWatcher(t *testing.T) { + stopCh := make(chan struct{}) + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &EtcdClient{Client: client, clientExitCh: make(chan struct{}), etcdStatusNow: true, cond: sync.NewCond(&sync.Mutex{})} + resultCh = make(chan *Event, 1) + watcher := NewEtcdWatcher("/xxx", etcdFilter, etcdHandler, stopCh, etcdClient) + watcher.initialRev = 1 + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(kv), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + c := &clientv3.GetResponse{ + Header: &etcdserverpb.ResponseHeader{}, + Kvs: []*mvccpb.KeyValue{}, + } + c.Header.Revision = 1 + return c, nil + }), + gomonkey.ApplyFunc(clientv3.NewWatcher, func(c *clientv3.Client) clientv3.Watcher { + return &EtcdWatcherMock{} + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + convey.Convey("watch etcd", t, func() { + go watcher.StartWatch() + e := &Event{ + Type: PUT, + Key: "/xxx", + Value: []byte("test"), + PrevValue: nil, + Rev: 0, + } + time.Sleep(500 * time.Millisecond) + watcher.sendEvent(e) + close(stopCh) + event := <-resultCh + convey.So(event, convey.ShouldEqual, e) + }) +} + +type fakeKV struct { + cache map[string]string +} + +func (f *fakeKV) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + f.cache[key] = val + return nil, nil +} + +func (f *fakeKV) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + delete(f.cache, key) + return nil, nil +} + +func (f *fakeKV) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + return clientv3.OpResponse{}, nil +} + +func (f *fakeKV) Txn(ctx context.Context) clientv3.Txn { + return nil +} + +func (f *fakeKV) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, error) { + return nil, nil +} + +func (f *fakeKV) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + if _, ok := f.cache[key]; !ok { + return nil, fmt.Errorf("Doesn't exist") + } + return &clientv3.GetResponse{Count: 1, Kvs: []*mvccpb.KeyValue{ + &mvccpb.KeyValue{Value: []byte(f.cache[key])}, + }}, nil +} + +func TestOptEtcd(t *testing.T) { + ew := &EtcdWatcher{ + watcher: &EtcdClient{ + Client: &clientv3.Client{}, + cond: sync.NewCond(&sync.Mutex{}), + }, + } + fakeKv := &fakeKV{cache: map[string]string{}} + defer gomonkey.ApplyFunc(clientv3.NewKV, func(c *clientv3.Client) clientv3.KV { + return fakeKv + }).Reset() + etcdCtx := EtcdCtxInfo{ + Cancel: func() {}, + } + key1 := "etcdKey" + val1 := "etcdValue" + key2 := "etcdKey2" + val2 := "etcdValue2" + + convey.Convey("Put", t, func() { + err := ew.watcher.Put(etcdCtx, key1, val1) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("GetResponse", t, func() { + resp, err := ew.watcher.GetResponse(etcdCtx, key1) + convey.So(err, convey.ShouldBeNil) + convey.So(resp, convey.ShouldResemble, &clientv3.GetResponse{Count: 1, Kvs: []*mvccpb.KeyValue{ + &mvccpb.KeyValue{Value: []byte(val1)}, + }}) + }) + convey.Convey("Delete", t, func() { + err := ew.watcher.Delete(etcdCtx, key1) + convey.So(err, convey.ShouldBeNil) + + resp, err := ew.watcher.GetValues(etcdCtx, key1) + convey.So(err.Error(), convey.ShouldEqual, "Doesn't exist") + convey.So(resp, convey.ShouldBeNil) + }) + convey.Convey("GetValues", t, func() { + err := ew.watcher.Put(etcdCtx, key2, val2) + convey.So(err, convey.ShouldBeNil) + resp, err := ew.watcher.GetValues(etcdCtx, key2) + convey.So(err, convey.ShouldBeNil) + convey.So(resp, convey.ShouldResemble, []string{val2}) + }) +} + +func TestCreateEtcdCtxInfoWithTimeout(t *testing.T) { + convey.Convey("CreateEtcdCtxInfoWithTimeout", t, func() { + ctxInfoWithTimeout := CreateEtcdCtxInfoWithTimeout(context.TODO(), time.Second) + convey.So(ctxInfoWithTimeout, convey.ShouldNotBeNil) + }) +} + +func Test_run(t *testing.T) { + convey.Convey("run", t, func() { + defer gomonkey.ApplyFunc(clientv3.NewWatcher, func(c *clientv3.Client) clientv3.Watcher { + return &EtcdWatcherMock{} + }).Reset() + convey.Convey("the channel received the result may be closed", func() { + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &EtcdClient{Client: client, clientExitCh: make(chan struct{}), etcdStatusNow: true, cond: sync.NewCond(&sync.Mutex{})} + receiveCh := make(chan *Event, 1) + stopCh := make(chan struct{}) + e := &EtcdWatcher{watcher: etcdClient, ResultChan: receiveCh, stopCh: stopCh} + e.initialRev = 1 + watchCh := make(chan clientv3.WatchResponse, 1) + callCount := 0 + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdWatcherMock{}), "Watch", + func(e *EtcdWatcherMock, ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan { + callCount++ + if callCount == 1 { + return watchCh + } + return make(chan clientv3.WatchResponse, 1) + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(e.watcher), "GetEtcdStatusNow", func(e *EtcdClient) bool { + return false + }).Reset() + closeCount := 0 + defer gomonkey.ApplyMethod(reflect.TypeOf(kv), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + closeCount++ + return &clientv3.GetResponse{Header: &etcdserverpb.ResponseHeader{Revision: int64(closeCount)}}, nil + }).Reset() + go e.run() + time.Sleep(100 * time.Millisecond) + close(watchCh) + time.Sleep(100 * time.Millisecond) + convey.So(callCount, convey.ShouldEqual, 2) + close(stopCh) + }) + convey.Convey("sendEvent", func() { + eventCh := make(chan clientv3.WatchResponse, 1) + defer gomonkey.ApplyMethod(reflect.TypeOf(&EtcdWatcherMock{}), "Watch", + func(e *EtcdWatcherMock, ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan { + return eventCh + }).Reset() + stopCh := make(chan struct{}) + receiveCh := make(chan *Event, defaultEventChanSize) + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &EtcdClient{Client: client, clientExitCh: make(chan struct{}), etcdStatusNow: true, cond: sync.NewCond(&sync.Mutex{})} + e := &EtcdWatcher{watcher: etcdClient, stopCh: stopCh, ResultChan: receiveCh} + e.initialRev = 1 + closeCount := 0 + defer gomonkey.ApplyMethod(reflect.TypeOf(kv), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + closeCount++ + return &clientv3.GetResponse{Header: &etcdserverpb.ResponseHeader{Revision: int64(closeCount)}}, nil + }).Reset() + go e.run() + eventCh <- clientv3.WatchResponse{Events: []*clientv3.Event{{Kv: &mvccpb.KeyValue{Key: []byte("key1"), Value: []byte("value1")}}}} + event := <-receiveCh + convey.So(event.Key, convey.ShouldEqual, "key1") + convey.So(string(event.Value), convey.ShouldEqual, "value1") + close(stopCh) + }) + convey.Convey("enable cache", func() { + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") + defer gomonkey.ApplyMethod(reflect.TypeOf(&KvMock{}), "Get", func(_ *KvMock, ctx context.Context, + key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, errors.New("some error") + }).Reset() + stopCh := make(chan struct{}) + receiveCh := make(chan *Event, defaultEventChanSize) + cacheCh := make(chan *Event, defaultEventChanSize) + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &EtcdClient{Client: client, clientExitCh: make(chan struct{}), etcdStatusNow: true} + e := &EtcdWatcher{watcher: etcdClient, stopCh: stopCh, ResultChan: receiveCh, CacheChan: cacheCh, + key: "/sn/function", cacheConfig: EtcdCacheConfig{ + EnableCache: true, + PersistPath: "./", + FlushInterval: 10, + }} + os.WriteFile("./etcdCacheMeta_#sn#function", []byte(`{"revision":101,"cacheMD5":"5642747b723c9497e2b7324b49fb0513"}`), 0600) + os.WriteFile("./etcdCacheData_#sn#function", []byte("/sn/function/123/goodbye/latest|101|{\"name\":\"goodbye\",\"version\":\"latest\"}\n/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n"), 0600) + go e.run() + time.Sleep(500 * time.Millisecond) + convey.So(len(e.ResultChan), convey.ShouldEqual, 3) + event1 := <-e.ResultChan + event2 := <-e.ResultChan + convey.So(event1, convey.ShouldResemble, &Event{ + Rev: 101, + Type: PUT, + Key: "/sn/function/123/goodbye/latest", + Value: []byte(`{"name":"goodbye","version":"latest"}`), + }) + convey.So(event2, convey.ShouldResemble, &Event{ + Rev: 100, + Type: PUT, + Key: "/sn/function/123/hello/latest", + Value: []byte(`{"name":"hello","version":"latest"}`), + }) + close(stopCh) + }) + }) + os.Remove("etcdCacheMeta_#sn#function") + os.Remove("etcdCacheData_#sn#function") + os.Remove("etcdCacheData_#sn#function_backup") +} + +func TestEtcdWatcher_EtcdHistory(t *testing.T) { + type fields struct { + filter EtcdWatcherFilter + handler EtcdWatcherHandler + watcher *EtcdClient + ResultChan chan *Event + resultChanWG *sync.WaitGroup + stopCh <-chan struct{} + key string + initialRev int64 + } + type args struct { + revision int64 + } + tests := []struct { + name string + fields fields + args args + patchesFunc mockUtils.PatchesFunc + }{ + {"case1", fields{watcher: &EtcdClient{cond: sync.NewCond(&sync.Mutex{})}}, args{revision: -1}, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(clientv3.NewWatcher, func(c *clientv3.Client) clientv3.Watcher { + return &EtcdWatcherMock{} + }), + gomonkey.ApplyMethod(reflect.TypeOf(&EtcdWatcherMock{}), "Watch", + func(e *EtcdWatcherMock, ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan { + ch := make(chan clientv3.WatchResponse, 1) + go close(ch) + return ch + })}) + return patches + }}, + {"case2 watch chan nil", fields{watcher: &EtcdClient{cond: sync.NewCond(&sync.Mutex{})}}, args{revision: -1}, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(clientv3.NewWatcher, func(c *clientv3.Client) clientv3.Watcher { + return &EtcdWatcherMock{} + }), + gomonkey.ApplyMethod(reflect.TypeOf(&EtcdWatcherMock{}), "Watch", + func(e *EtcdWatcherMock, ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan { + return nil + })}) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + ew := &EtcdWatcher{ + filter: tt.fields.filter, + handler: tt.fields.handler, + watcher: tt.fields.watcher, + ResultChan: tt.fields.ResultChan, + resultChanWG: tt.fields.resultChanWG, + stopCh: tt.fields.stopCh, + key: tt.fields.key, + initialRev: tt.fields.initialRev, + } + ew.EtcdHistory(tt.args.revision) + patches.ResetAll() + }) + } +} + +func TestEtcdClient_Get(t *testing.T) { + e := &EtcdClient{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctxInfo := EtcdCtxInfo{Ctx: ctx, Cancel: cancel} + + key := "test-key" + response := &clientv3.GetResponse{} + + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFuncReturn(clientv3.NewKV, &clientv3.Client{}) + patches.ApplyMethodReturn(&clientv3.Client{}, "Get", response, nil) + + got, err := e.Get(ctxInfo, key) + + assert.NoError(t, err) + assert.Equal(t, response, got) +} diff --git a/yuanrong/pkg/common/faas_common/instance/util.go b/yuanrong/pkg/common/faas_common/instance/util.go new file mode 100644 index 0000000..7cb786b --- /dev/null +++ b/yuanrong/pkg/common/faas_common/instance/util.go @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instance +package instance + +import ( + "encoding/json" + "fmt" + "strings" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" +) + +const ( + keySeparator = "/" + + instanceIDValueIndex = 13 + validEtcdKeyLenForInstance = 14 +) + +// GetInstanceIDFromEtcdKey gets instance id from etcd key of instance +func GetInstanceIDFromEtcdKey(etcdKey string) string { + items := strings.Split(etcdKey, keySeparator) + if len(items) != validEtcdKeyLenForInstance { + return "" + } + return fmt.Sprintf("%s", items[instanceIDValueIndex]) +} + +// GetInsSpecFromEtcdValue gets InstanceSpecification from etcd value of instance +func GetInsSpecFromEtcdValue(etcdKey string, etcdValue []byte) *types.InstanceSpecification { + insSpec := &types.InstanceSpecification{} + if len(etcdValue) != 0 { + err := json.Unmarshal(etcdValue, insSpec) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal etcd value to instance specification %s", err.Error()) + return nil + } + } else { + log.GetLogger().Warnf("etcd value is empty when get instance specification from key %s", etcdKey) + } + return insSpec +} diff --git a/yuanrong/pkg/common/faas_common/instance/util_test.go b/yuanrong/pkg/common/faas_common/instance/util_test.go new file mode 100644 index 0000000..2cad1e2 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/instance/util_test.go @@ -0,0 +1,68 @@ +package instance + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + commonTypes "yuanrong/pkg/common/faas_common/types" +) + +func TestGetInstanceIDFromEtcdKey(t *testing.T) { + etcdKey := "/sn/instance/business/yrk/tenant/123/function/faasscheduler/version/$latest/defaultaz/requestID/abc" + instanceID := GetInstanceIDFromEtcdKey(etcdKey) + assert.Equal(t, "abc", instanceID) + + instanceIDNil := GetInstanceIDFromEtcdKey("") + assert.Equal(t, "", instanceIDNil) +} + +func TestGetInsSpecFromEtcdValue(t *testing.T) { + etcdValue := []byte("{\"instanceID\":\"51f71580-3a07-4000-8000-004b56e7f471\",\"requestID\":\"7fb31" + + "b50-7c5a-11ed-a991-fa163e3523c8\",\"runtimeID\":\"runtime-e06fe343-0000-4000-8000-00bbad15e23" + + "8\",\"runtimeAddress\":\"10.244.162.129:33333\",\"functionAgentID\":\"function_agent_10.244.16" + + "2.129-33333\",\"functionProxyID\":\"dggphis35893-8490\",\"function\":\"12345678901234561234567" + + "890123456/0-system-hello/$latest\",\"resources\":{\"resources\":{\"Memory\":{\"name\":\"Memor" + + "y\",\"scalar\":{\"value\":500}},\"CPU\":{\"name\":\"CPU\",\"scalar\":{\"value\":500}}}},\"sched" + + "uleOption\":{\"affinity\":{\"instanceAffinity\":{}}},\"instanceStatus\":{\"code\":3,\"msg\":\"i" + + "nstance is running\"}}") + insSpecTrans := GetInsSpecFromEtcdValue("", etcdValue) + insSpecExpected := &commonTypes.InstanceSpecification{ + InstanceID: "51f71580-3a07-4000-8000-004b56e7f471", + RequestID: "7fb31b50-7c5a-11ed-a991-fa163e3523c8", + RuntimeID: "runtime-e06fe343-0000-4000-8000-00bbad15e238", + RuntimeAddress: "10.244.162.129:33333", + FunctionAgentID: "function_agent_10.244.162.129-33333", + FunctionProxyID: "dggphis35893-8490", + Function: "12345678901234561234567890123456/0-system-hello/$latest", + RestartPolicy: "", + Resources: commonTypes.Resources{ + Resources: map[string]commonTypes.Resource{ + "CPU": commonTypes.Resource{ + Name: "CPU", + Scalar: commonTypes.ValueScalar{Value: 500}, + }, + "Memory": commonTypes.Resource{ + Name: "Memory", + Scalar: commonTypes.ValueScalar{Value: 500}, + }, + }, + }, + ActualUse: commonTypes.Resources{}, + ScheduleOption: commonTypes.ScheduleOption{ + Affinity: commonTypes.Affinity{ + InstanceAffinity: commonTypes.InstanceAffinity{}, + }, + }, + CreateOptions: nil, + Labels: nil, + StartTime: "", + InstanceStatus: commonTypes.InstanceStatus{Code: 3, Msg: "instance is running"}, + JobID: "", + SchedulerChain: nil, + } + assert.Equal(t, insSpecExpected, insSpecTrans) + + insSpecNil := GetInsSpecFromEtcdValue("", []byte("")) + assert.Equal(t, true, insSpecNil != nil) +} diff --git a/yuanrong/pkg/common/faas_common/instanceconfig/util.go b/yuanrong/pkg/common/faas_common/instanceconfig/util.go new file mode 100644 index 0000000..e2d3f51 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/instanceconfig/util.go @@ -0,0 +1,122 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instanceconfig - +package instanceconfig + +import ( + "encoding/json" + "fmt" + "strings" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/types" + wisecloudtypes "yuanrong/pkg/common/faas_common/wisecloudtool/types" +) + +const ( + keySeparator = "/" + insConfigTenantValueIndex = 7 + insConfigFuncNameValueIndex = 9 + insConfigVersionValueIndex = 11 + validEtcdKeyLenForInsConfig = 12 + insConfigLabelValueIndex = 13 + validEtcdKeyLenForInsWithLabelConf = 14 + + insConfigKeyIndex = 1 + insConfigClusterKeyIndex = 4 + insConfigClusterValueIndex = 5 + insConfigTenantKeyIndex = 6 + insConfigFunctionKeyIndex = 8 + insConfigLabelKeyIndex = 12 + + functionClusterKeyIdx = 5 + + // InsConfigEtcdPrefix - 函数实例配置项元数据key前缀 + InsConfigEtcdPrefix = "/instances" +) + +// GetLabelFromInstanceConfigEtcdKey - +func GetLabelFromInstanceConfigEtcdKey(etcdKey string) string { + items := strings.Split(etcdKey, keySeparator) + if len(items) != validEtcdKeyLenForInsWithLabelConf { + return "" + } + return items[insConfigLabelValueIndex] +} + +// ParseInstanceConfigFromEtcdEvent - +func ParseInstanceConfigFromEtcdEvent(etcdKey string, etcdValue []byte) (*Configuration, error) { + items := strings.Split(etcdKey, keySeparator) + if len(items) != validEtcdKeyLenForInsConfig && len(items) != validEtcdKeyLenForInsWithLabelConf { + return nil, fmt.Errorf("etcdKey format error") + } + + funcKey := fmt.Sprintf("%s/%s/%s", items[insConfigTenantValueIndex], items[insConfigFuncNameValueIndex], + items[insConfigVersionValueIndex]) + + label := "" + if len(items) == validEtcdKeyLenForInsWithLabelConf { + label = items[insConfigLabelValueIndex] + } + + if len(etcdValue) == 0 { + return nil, fmt.Errorf("etcdValue is empty") + } + insConfig := &Configuration{} + err := json.Unmarshal(etcdValue, insConfig) + if err != nil { + return nil, fmt.Errorf("unmarshal etcdValue failed, err: %s", err.Error()) + } + + insConfig.FuncKey = funcKey + insConfig.InstanceLabel = label + return insConfig, nil +} + +// Configuration - +type Configuration struct { + FuncKey string + InstanceLabel string + InstanceMetaData types.InstanceMetaData `json:"instanceMetaData" valid:",optional"` + NuwaRuntimeInfo wisecloudtypes.NuwaRuntimeInfo `json:"nuwaRuntimeInfo" valid:",optional"` +} + +// DeepCopy return a Configuration Copy +func (i *Configuration) DeepCopy() *Configuration { + return &(*i) +} + +// GetWatcherFilter - +func GetWatcherFilter(clusterId string) func(event *etcd3.Event) bool { + return func(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != validEtcdKeyLenForInsConfig && len(items) != validEtcdKeyLenForInsWithLabelConf { + return true + } + if items[insConfigKeyIndex] != "instances" || items[insConfigClusterKeyIndex] != "cluster" || + items[insConfigTenantKeyIndex] != "tenant" || items[insConfigFunctionKeyIndex] != "function" { + return true + } + if len(items) == validEtcdKeyLenForInsWithLabelConf && items[insConfigLabelKeyIndex] != "label" { + return true + } + if clusterId != items[insConfigClusterValueIndex] { + return true + } + return false + } +} diff --git a/yuanrong/pkg/common/faas_common/instanceconfig/util_test.go b/yuanrong/pkg/common/faas_common/instanceconfig/util_test.go new file mode 100644 index 0000000..f8227e3 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/instanceconfig/util_test.go @@ -0,0 +1,68 @@ +package instanceconfig + +import ( + "encoding/json" + "testing" + + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instance" + "yuanrong/pkg/common/faas_common/types" + wisecloudtypes "yuanrong/pkg/common/faas_common/wisecloudtool/types" +) + +func TestParseInstanceConfigFromEtcdEvent(t *testing.T) { + convey.Convey("Test ParseInstanceConfigFromEtcdEvent", t, func() { + testConfig := &Configuration{ + InstanceMetaData: types.InstanceMetaData{PoolID: "test"}, + NuwaRuntimeInfo: wisecloudtypes.NuwaRuntimeInfo{WisecloudRuntimeId: "runtime1"}, + } + testConfigData, _ := json.Marshal(testConfig) + + convey.Convey("should parse config with label successfully", func() { + etcdKey := "/instances/business/yrk/cluster/cluster001/tenant/12345678901234561234567890123456/function/0@test111@yrfunc111/version/latest/label/aaa" + config, err := ParseInstanceConfigFromEtcdEvent(etcdKey, testConfigData) + convey.So(err, convey.ShouldBeNil) + convey.So(config.FuncKey, convey.ShouldEqual, "12345678901234561234567890123456/0@test111@yrfunc111/latest") + convey.So(config.InstanceLabel, convey.ShouldEqual, "aaa") + }) + + convey.Convey("should return error for invalid key format", func() { + key := "/invalid/key" + _, err := ParseInstanceConfigFromEtcdEvent(key, testConfigData) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestGetWatcherFilter(t *testing.T) { + convey.Convey("Test GetWatcherFilter", t, func() { + filter := GetWatcherFilter("cluster1") + + convey.Convey("should filter matching cluster key", func() { + event := &etcd3.Event{ + Key: "/instances/business/yrk/cluster/cluster1/tenant/t1/function/f1/version/v1", + } + convey.So(filter(event), convey.ShouldBeFalse) + }) + + convey.Convey("should not filter invalid key structure", func() { + event := &etcd3.Event{ + Key: "/invalid/key", + } + convey.So(filter(event), convey.ShouldBeTrue) + }) + }) +} + +func TestGetLabelFromInstanceConfigEtcdKey(t *testing.T) { + etcdKey := "/instances/business/yrk/cluster/cluster001/tenant/12345678901234561234567890123456/function/0@test111@yrfunc111/version/latest" + label := GetLabelFromInstanceConfigEtcdKey(etcdKey) + assert.Equal(t, "", label) + + etcdKey = "/instances/business/yrk/cluster/cluster001/tenant/12345678901234561234567890123456/function/0@test111@yrfunc111/version/latest/label/aaa" + label = instance.GetInstanceIDFromEtcdKey(etcdKey) + assert.Equal(t, "aaa", label) +} diff --git a/yuanrong/pkg/common/faas_common/k8sclient/tools.go b/yuanrong/pkg/common/faas_common/k8sclient/tools.go new file mode 100644 index 0000000..e674ef6 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/k8sclient/tools.go @@ -0,0 +1,311 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package k8sclient include some k8s Client operation +package k8sclient + +import ( + "context" + "fmt" + "reflect" + "sync" + + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// KubeClient - +type KubeClient struct { + Client kubernetes.Interface +} + +var ( + // KubeClientSet - + KubeClientSet *KubeClient + kubeClientOnce sync.Once + + // dynamicClient - + dynamicClient dynamic.Interface + dynamicClientOnce sync.Once +) + +// GetkubeClient is used to obtain a K8S Client +func GetkubeClient() *KubeClient { + kubeClientOnce.Do(func() { + // create Kubernetes config + config, err := rest.InClusterConfig() + if err != nil { + log.GetLogger().Errorf("Failed to create Kubernetes config: %v", err) + return + } + + // create Kubernetes Client + client, err := kubernetes.NewForConfig(config) + if err != nil { + log.GetLogger().Errorf("Failed to create Kubernetes Client: %v", err) + return + } + + KubeClientSet = &KubeClient{ + Client: client, + } + }) + return KubeClientSet +} + +// NewDynamicClient New Dynamic Client +func NewDynamicClient() dynamic.Interface { + dynamicClientOnce.Do(func() { + // create Kubernetes config + config, err := rest.InClusterConfig() + if err != nil { + log.GetLogger().Errorf("Failed to create Kubernetes config: %s", err.Error()) + return + } + // create dynamic client + dynamicClient, err = dynamic.NewForConfig(config) + if err != nil { + log.GetLogger().Errorf("failed to create dynamic client: %s", err.Error()) + return + } + }) + return dynamicClient +} + +// DeleteK8sService - +func (kc *KubeClient) DeleteK8sService(namespace string, serviceName string) error { + if kc == nil { + return fmt.Errorf("kubeclient is nil") + } + err := kc.Client.CoreV1().Services(namespace).Delete(context.TODO(), serviceName, metav1.DeleteOptions{}) + if err != nil { + if errors.IsNotFound(err) { + log.GetLogger().Infof("Service %s in namespace %s not found", serviceName, namespace) + return nil + } + return err + } + log.GetLogger().Infof("Service %s in namespace %s deleted", serviceName, namespace) + return nil +} + +// CreateK8sService - +func (kc *KubeClient) CreateK8sService(service *v1.Service) error { + if kc == nil { + return fmt.Errorf("kubeclient is nil") + } + + // delete service + if err := kc.DeleteK8sService(service.Namespace, service.Name); err != nil { + return err + } + // create Service + result, err := kc.Client.CoreV1().Services(service.Namespace).Create(context.TODO(), service, metav1.CreateOptions{}) + if err != nil { + return fmt.Errorf("failed to create Service: %s", err.Error()) + } + + log.GetLogger().Infof("created Service %q with IP %q", result.GetObjectMeta().GetName(), result.Spec.ClusterIP) + return nil +} + +// CreateK8sConfigMap - +func (kc *KubeClient) CreateK8sConfigMap(configMap *v1.ConfigMap) error { + if kc == nil { + return fmt.Errorf("kubeclient is nil") + } + + // delete configMap + if err := kc.DeleteK8sConfigMap(configMap.Namespace, configMap.Name); err != nil { + return err + } + // create configMap + result, err := kc.Client.CoreV1().ConfigMaps(configMap.Namespace).Create(context.TODO(), + configMap, metav1.CreateOptions{}) + if err != nil { + return fmt.Errorf("failed to create ConfigMap: %s", err.Error()) + } + + log.GetLogger().Infof("created ConfigMap: %s", result.GetObjectMeta().GetName()) + return nil +} + +// DeleteK8sConfigMap - +func (kc *KubeClient) DeleteK8sConfigMap(namespace string, configMapName string) error { + if kc == nil { + return fmt.Errorf("kubeclient is nil") + } + + // delete configMap + err := kc.Client.CoreV1().ConfigMaps(namespace).Delete(context.TODO(), configMapName, metav1.DeleteOptions{}) + if err != nil { + if errors.IsNotFound(err) { + log.GetLogger().Infof("configMap %s in namespace %s not found", configMapName, namespace) + return nil + } + return err + } + log.GetLogger().Infof("configMap %s in namespace %s deleted", configMapName, namespace) + return nil +} + +// UpdateK8sConfigMap - +func (kc *KubeClient) UpdateK8sConfigMap(configMap *v1.ConfigMap) error { + if kc == nil { + return fmt.Errorf("kubeclient is nil") + } + _, err := kc.Client.CoreV1().ConfigMaps(configMap.Namespace).Get(context.TODO(), configMap.Name, metav1.GetOptions{}) + if err != nil { + if errors.IsNotFound(err) { + log.GetLogger().Infof("configMap %s in namespace %s not found", configMap.Name, configMap.Namespace) + } + return err + } + _, err = kc.Client.CoreV1().ConfigMaps(configMap.Namespace).Update(context.TODO(), configMap, metav1.UpdateOptions{}) + if err != nil { + log.GetLogger().Errorf("update configmap failed, error is %s", err.Error()) + return err + } + log.GetLogger().Infof("configMap %s in namespace %s updated", configMap.Name, configMap.Namespace) + return nil +} + +// GetK8sConfigMap - +func (kc *KubeClient) GetK8sConfigMap(namespace string, configMapName string) (*v1.ConfigMap, error) { + if kc == nil { + return nil, fmt.Errorf("kubeclient is nil") + } + configmap, err := kc.Client.CoreV1().ConfigMaps(namespace).Get(context.TODO(), configMapName, metav1.GetOptions{}) + if err != nil { + if errors.IsNotFound(err) { + log.GetLogger().Infof("configMap %s in namespace %s not found", configMapName, namespace) + } + return nil, err + } + log.GetLogger().Infof("Get configMap %s in namespace %s updated", configMapName, namespace) + return configmap, nil +} + +// GetK8sSecret - +func (kc *KubeClient) GetK8sSecret(namespace string, secretName string) (*v1.Secret, error) { + if kc == nil { + return nil, fmt.Errorf("kubeclient is nil") + } + ctx := context.TODO() + secret, err := kc.Client.CoreV1().Secrets(namespace).Get(ctx, secretName, metav1.GetOptions{}) + if err != nil { + if errors.IsNotFound(err) { + log.GetLogger().Infof("secret %s not found", secretName) + return nil, err + } + log.GetLogger().Errorf("secret %s get failed, err is: %s", secretName, err) + return nil, err + } + log.GetLogger().Errorf("secret %s already exists, no need create.", secretName) + return secret, nil +} + +// CreateK8sSecret - +func (kc *KubeClient) CreateK8sSecret(namespace string, s *v1.Secret) (*v1.Secret, error) { + if kc == nil { + return nil, fmt.Errorf("kubeclient is nil") + } + ctx := context.TODO() + secret, err := kc.Client.CoreV1().Secrets(namespace).Create(ctx, s, metav1.CreateOptions{}) + if err != nil { + log.GetLogger().Errorf("k8s failed to create secret: %s, secretName: %s", err.Error(), s.Name) + return nil, err + } + log.GetLogger().Infof("secret %s in namespace %s created", secret.Name, namespace) + + return secret, nil +} + +// UpdateK8sSecret - +func (kc *KubeClient) UpdateK8sSecret(namespace string, s *v1.Secret) (*v1.Secret, error) { + if kc == nil { + return nil, fmt.Errorf("kubeclient is nil") + } + ctx := context.TODO() + secret, err := kc.Client.CoreV1().Secrets(namespace).Update(ctx, s, metav1.UpdateOptions{}) + if err != nil { + log.GetLogger().Errorf("k8s failed to update secret: %s, secretName: %s", err.Error(), s.Name) + return nil, err + } + log.GetLogger().Infof("secret %s in namespace %s updated", secret.Name, namespace) + + return secret, nil +} + +// DeleteK8sSecret - +func (kc *KubeClient) DeleteK8sSecret(namespace string, secretName string) error { + if kc == nil { + return fmt.Errorf("kubeclient is nil") + } + ctx := context.TODO() + err := kc.Client.CoreV1().Secrets(namespace).Delete(ctx, secretName, metav1.DeleteOptions{}) + + if err != nil { + if errors.IsNotFound(err) { + log.GetLogger().Infof("secret %s in namespace %s not found", secretName, namespace) + return nil + } + log.GetLogger().Errorf("k8s failed to delete secret: %s, secretName: %s", err.Error(), secretName) + return err + } + log.GetLogger().Infof("secret %s in namespace %s deleted successfully", secretName, namespace) + + return nil +} + +// CreateOrUpdateConfigMap - +func (kc *KubeClient) CreateOrUpdateConfigMap(c *v1.ConfigMap) error { + if kc == nil { + return fmt.Errorf("kubeclient is nil") + } + ctx := context.TODO() + oldConfig, getErr := kc.Client.CoreV1().ConfigMaps(c.Namespace).Get(ctx, c.Name, metav1.GetOptions{}) + if getErr != nil && errors.IsNotFound(getErr) { + log.GetLogger().Infof("Creating a new Configmap, Configmap.Name: %s", c.Name) + _, createErr := kc.Client.CoreV1().ConfigMaps(c.Namespace).Create(ctx, c, metav1.CreateOptions{}) + if createErr != nil { + log.GetLogger().Errorf("k8s failed to create configmap: %s, traceID: %s", + createErr.Error(), "TraceID") + return createErr + } + return nil + } + if getErr != nil { + log.GetLogger().Errorf("failed to get configmap: %s, err:%v", c.Name, getErr.Error()) + return getErr + } + + if !reflect.DeepEqual(oldConfig, c) { + _, updateErr := kc.Client.CoreV1().ConfigMaps(c.Namespace).Update(ctx, c, metav1.UpdateOptions{}) + if updateErr != nil { + log.GetLogger().Errorf("k8s failed to update configmap: %s, traceID: %s", updateErr.Error(), + "TraceID") + return updateErr + } + } + return nil +} diff --git a/yuanrong/pkg/common/faas_common/k8sclient/tools_test.go b/yuanrong/pkg/common/faas_common/k8sclient/tools_test.go new file mode 100644 index 0000000..c5516e4 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/k8sclient/tools_test.go @@ -0,0 +1,524 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package k8sclient include some k8s client operation +package k8sclient + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/rest" + k8testing "k8s.io/client-go/testing" +) + +var inClusterConfigFunc = rest.InClusterConfig + +type mockK8sClient struct { + createConfigError error + createClientError error + expectedConfigCalled bool +} + +func TestGetkubeClient(t *testing.T) { + defer gomonkey.ApplyFunc(rest.InClusterConfig, func() (*rest.Config, error) { + return &rest.Config{}, nil + }).Reset() + convey.Convey("get client success", t, func() { + defer gomonkey.ApplyFunc(kubernetes.NewForConfig, func(c *rest.Config) (*kubernetes.Clientset, error) { + return &kubernetes.Clientset{}, nil + }).Reset() + client := GetkubeClient() + convey.So(client, convey.ShouldNotBeNil) + }) + KubeClientSet = nil + kubeClientOnce = sync.Once{} + convey.Convey("get client error", t, func() { + defer gomonkey.ApplyFunc(kubernetes.NewForConfig, func(c *rest.Config) (*kubernetes.Clientset, error) { + return nil, fmt.Errorf("get client error") + }).Reset() + client := GetkubeClient() + convey.So(client, convey.ShouldBeNil) + }) + kubeClientOnce = sync.Once{} + convey.Convey("get cfg error", t, func() { + defer gomonkey.ApplyFunc(rest.InClusterConfig, func() (*rest.Config, error) { + return nil, fmt.Errorf("get cfg error") + }).Reset() + client := GetkubeClient() + convey.So(client, convey.ShouldBeNil) + }) +} + +func TestDeleteK8sService(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + namespace := "default" + serviceName := "frontend" + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: serviceName, + Namespace: namespace, + }, + Spec: v1.ServiceSpec{ + Selector: map[string]string{ + "app": "frontend", + }, + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: "TCP", + Port: 8888, + TargetPort: intstr.IntOrString{ + Type: intstr.Int, + IntVal: 32104, + }, + NodePort: 31222, + }, + }, + Type: v1.ServiceTypeNodePort, + }, + } + client.PrependReactor("delete", "services", func(action k8testing.Action) (handled bool, ret runtime.Object, err error) { + deleteAction := action.(k8testing.DeleteAction) + if deleteAction.GetName() == service.Name && deleteAction.GetNamespace() == service.Namespace { + return true, service, nil + } + return true, nil, fmt.Errorf("Not found") + }) + + convey.Convey("delete service success", t, func() { + err := KubeClientSet.DeleteK8sService(namespace, serviceName) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("delete service not found", t, func() { + err := KubeClientSet.DeleteK8sService(namespace, "error service name") + convey.So(err.Error(), convey.ShouldContainSubstring, "Not found") + }) + convey.Convey("delete service not found", t, func() { + client.PrependReactor("delete", "services", func(action k8testing.Action) (handled bool, ret runtime.Object, err error) { + deleteAction := action.(k8testing.DeleteAction) + if deleteAction.GetName() == service.Name && deleteAction.GetNamespace() == service.Namespace { + return true, service, fmt.Errorf("delete error") + } + return false, nil, nil + }) + + err := KubeClientSet.DeleteK8sService(namespace, serviceName) + convey.So(err.Error(), convey.ShouldContainSubstring, "delete error") + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + err := KubeClientSet.DeleteK8sService(namespace, serviceName) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestCreateK8sService(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "frontend", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Selector: map[string]string{ + "app": "frontend", + }, + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: "TCP", + Port: 8888, + TargetPort: intstr.IntOrString{ + Type: intstr.Int, + IntVal: 32104, + }, + NodePort: 31222, + }, + }, + Type: v1.ServiceTypeNodePort, + }, + } + + convey.Convey("create service success", t, func() { + err := KubeClientSet.CreateK8sService(service) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("create service error", t, func() { + client.PrependReactor("create", "services", func(action k8testing.Action) (handled bool, ret runtime.Object, err error) { + createAction := action.(k8testing.CreateAction) + if createAction.GetObject().(*v1.Service).Name == service.Name && createAction.GetNamespace() == service.Namespace { + return true, service, fmt.Errorf("failed to create service") + } + return false, nil, nil + }) + err := KubeClientSet.CreateK8sService(service) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + err := KubeClientSet.CreateK8sService(service) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestCreateK8sConfigMap(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + configmap := &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test_configmap", + Namespace: "default", + }, + Data: map[string]string{"key": "value"}, + } + + convey.Convey("create configmap success", t, func() { + err := KubeClientSet.CreateK8sConfigMap(configmap) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("create configmap error", t, func() { + client.PrependReactor("create", "configmaps", func(action k8testing.Action) (handled bool, ret runtime.Object, err error) { + createAction := action.(k8testing.CreateAction) + if createAction.GetObject().(*v1.ConfigMap).Name == configmap.Name && createAction.GetNamespace() == configmap.Namespace { + return true, configmap, fmt.Errorf("failed to create configmap") + } + return false, nil, nil + }) + err := KubeClientSet.CreateK8sConfigMap(configmap) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + err := KubeClientSet.CreateK8sConfigMap(configmap) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestDeleteK8sConfigMap(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + namespace := "default" + configmapName := "test_configmap" + configmap := &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: configmapName, + Namespace: namespace, + }, + Data: map[string]string{"key": "value"}, + } + + convey.Convey("delete configmap success", t, func() { + err := KubeClientSet.DeleteK8sConfigMap(namespace, configmapName) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("delete configmap error", t, func() { + client.PrependReactor("delete", "configmaps", func(action k8testing.Action) (handled bool, ret runtime.Object, err error) { + deleteAction := action.(k8testing.DeleteAction) + if deleteAction.GetName() == configmap.Name && deleteAction.GetNamespace() == configmap.Namespace { + return true, configmap, fmt.Errorf("failed to delete service") + } + return false, nil, nil + }) + err := KubeClientSet.CreateK8sConfigMap(configmap) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + err := KubeClientSet.DeleteK8sConfigMap(namespace, configmapName) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestUpdateK8sConfigMap(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + namespace := "default" + configmapName := "test_configmap" + configmap := &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: configmapName, + Namespace: namespace, + }, + Data: map[string]string{"key": "value"}, + } + _, err := client.CoreV1().ConfigMaps(namespace).Create(context.TODO(), configmap, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create configmap: %v", err) + } + convey.Convey("update configmap success", t, func() { + err := KubeClientSet.UpdateK8sConfigMap(configmap) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("update configmap error", t, func() { + client.PrependReactor("update", "configmaps", func(action k8testing.Action) (handled bool, ret runtime.Object, err error) { + updateAction := action.(k8testing.UpdateAction) + if updateAction.GetObject().(*v1.ConfigMap).Name == configmap.Name && updateAction.GetNamespace() == configmap.Namespace { + return true, configmap, fmt.Errorf("failed to update configmap") + } + return false, nil, nil + }) + err := KubeClientSet.UpdateK8sConfigMap(configmap) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + err := KubeClientSet.UpdateK8sConfigMap(configmap) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestGetK8sConfigMap(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + namespace := "default" + configmapName := "test_configmap" + configmap := &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: configmapName, + Namespace: namespace, + }, + Data: map[string]string{"key": "value"}, + } + _, err := client.CoreV1().ConfigMaps(namespace).Create(context.TODO(), configmap, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create configmap: %v", err) + } + convey.Convey("get configmap success", t, func() { + _, err := KubeClientSet.GetK8sConfigMap(namespace, namespace) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + err := KubeClientSet.UpdateK8sConfigMap(configmap) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestKubeClient_GetK8sSecret(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + namespace := "default" + secretName := "test_secret" + secret := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: secretName, + Namespace: namespace, + }, + Data: map[string][]byte{"key": []byte("value")}, + } + _, err := client.CoreV1().Secrets(namespace).Create(context.TODO(), secret, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create secret: %v", err) + } + convey.Convey("get secret success", t, func() { + _, err := KubeClientSet.GetK8sSecret(namespace, namespace) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + _, err := KubeClientSet.GetK8sSecret(namespace, secretName) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestKubeClient_CreateK8sSecret(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + secret := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test_secret", + Namespace: "default", + }, + Data: map[string][]byte{"key": []byte("value")}, + } + + convey.Convey("create secret success", t, func() { + _, err := KubeClientSet.CreateK8sSecret("default", secret) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("create secret error", t, func() { + client.PrependReactor("create", "secrets", func(action k8testing.Action) (handled bool, ret runtime.Object, err error) { + createAction := action.(k8testing.CreateAction) + if createAction.GetObject().(*v1.Secret).Name == secret.Name && createAction.GetNamespace() == secret.Namespace { + return true, secret, fmt.Errorf("failed to create secret") + } + return false, nil, nil + }) + _, err := KubeClientSet.CreateK8sSecret("default", secret) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + _, err := KubeClientSet.CreateK8sSecret("default", secret) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestKubeClient_UpdateK8sSecret(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + namespace := "default" + secretName := "test_secret" + secret := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: secretName, + Namespace: namespace, + }, + Data: map[string][]byte{"key": []byte("value")}, + } + _, err := client.CoreV1().Secrets(namespace).Create(context.TODO(), secret, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create secret: %v", err) + } + convey.Convey("update secret success", t, func() { + _, err := KubeClientSet.UpdateK8sSecret(namespace, secret) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("update secret error", t, func() { + client.PrependReactor("update", "secrets", func(action k8testing.Action) (handled bool, ret runtime.Object, err error) { + updateAction := action.(k8testing.UpdateAction) + if updateAction.GetObject().(*v1.Secret).Name == secret.Name && updateAction.GetNamespace() == secret.Namespace { + return true, secret, fmt.Errorf("failed to update secret") + } + return false, nil, nil + }) + _, err := KubeClientSet.UpdateK8sSecret(namespace, secret) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + _, err := KubeClientSet.UpdateK8sSecret(namespace, secret) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestKubeClient_DeleteK8sSecret(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + namespace := "default" + secretName := "test_secret" + secret := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: secretName, + Namespace: namespace, + }, + Data: map[string][]byte{"key": []byte("value")}, + } + + convey.Convey("delete secret success", t, func() { + err := KubeClientSet.DeleteK8sSecret(namespace, secretName) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("delete secret error", t, func() { + client.PrependReactor("delete", "secrets", func(action k8testing.Action) (handled bool, ret runtime.Object, err error) { + deleteAction := action.(k8testing.DeleteAction) + if deleteAction.GetName() == secret.Name && deleteAction.GetNamespace() == secret.Namespace { + return true, secret, fmt.Errorf("failed to delete service") + } + return false, nil, nil + }) + err := KubeClientSet.DeleteK8sSecret(namespace, secretName) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + err := KubeClientSet.DeleteK8sSecret(namespace, secretName) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestKubeClient_CreateOrUpdateConfigMap(t *testing.T) { + client := fake.NewSimpleClientset() + KubeClientSet = &KubeClient{client} + configmap := &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test_configmap", + Namespace: "default", + }, + Data: map[string]string{"key": "value"}, + } + configmap2 := &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test_configmap", + Namespace: "default", + }, + Data: map[string]string{"key": "value1"}, + } + _, err := client.CoreV1().ConfigMaps("default").Create(context.TODO(), configmap2, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create configmap: %v", err) + } + + convey.Convey("create configmap success", t, func() { + err := KubeClientSet.CreateOrUpdateConfigMap(configmap) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("KubeClientSet is nil", t, func() { + KubeClientSet = nil + err := KubeClientSet.CreateOrUpdateConfigMap(configmap) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func (m *mockK8sClient) InClusterConfig() (*rest.Config, error) { + m.expectedConfigCalled = true + return &rest.Config{}, m.createConfigError +} + +func (m *mockK8sClient) NewForConfig(_ *rest.Config) (dynamic.Interface, error) { + return nil, m.createClientError +} + +func TestNewDynamicClient_Success(t *testing.T) { + dynamicClient = nil + dynamicClientOnce = sync.Once{} + mock := &mockK8sClient{} + oldInClusterConfig := inClusterConfigFunc + inClusterConfigFunc = mock.InClusterConfig + defer func() { inClusterConfigFunc = oldInClusterConfig }() + client := NewDynamicClient() + if client == nil { + t.Fatal("Expected non-nil client, got nil") + } +} + +func TestNewDynamicClient_Singleton(t *testing.T) { + dynamicClient = nil + dynamicClientOnce = sync.Once{} + mock := &mockK8sClient{} + oldInClusterConfig := inClusterConfigFunc + inClusterConfigFunc = mock.InClusterConfig + defer func() { inClusterConfigFunc = oldInClusterConfig }() + client1 := NewDynamicClient() + client2 := NewDynamicClient() + if client1 != client2 { + t.Error("Expected singleton instance, got different clients") + } +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/connection/connection.go b/yuanrong/pkg/common/faas_common/kernelrpc/connection/connection.go new file mode 100644 index 0000000..2064c1c --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/connection/connection.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package connection - +package connection + +import ( + "time" + + "yuanrong/pkg/common/faas_common/grpc/pb" // production: package api + "yuanrong/pkg/common/faas_common/grpc/pb/runtime" +) + +// SendOption - +type SendOption struct { + Timeout time.Duration +} + +// SendCallback - +type SendCallback func(message *runtime.NotifyRequest) + +// Connection defines basic grpc connection +type Connection interface { + Send(message *api.StreamingMessage, option SendOption, callback SendCallback) (*api.StreamingMessage, error) + Recv() (*api.StreamingMessage, error) + Close() + CheckClose() chan struct{} +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/connection/stream_connection.go b/yuanrong/pkg/common/faas_common/kernelrpc/connection/stream_connection.go new file mode 100644 index 0000000..e305259 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/connection/stream_connection.go @@ -0,0 +1,450 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package connection - +package connection + +import ( + "errors" + "reflect" + "sync" + "time" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/grpc/pb" // production: package api + "yuanrong/pkg/common/faas_common/grpc/pb/runtime" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/uuid" +) + +var ( + defaultChannelSize = 300000 +) + +var ( + // ErrStreamConnectionBroken is the error of stream connection broken + ErrStreamConnectionBroken = errors.New("stream connection is broken") + // ErrStreamConnectionClosed is the error of stream connection closed + ErrStreamConnectionClosed = errors.New("stream connection is closed") + // ErrRequestIDAlreadyExist is the error of requestID already exist + ErrRequestIDAlreadyExist = errors.New("requestID already exist") +) + +// StreamParams - +type StreamParams struct { + PeerAddr string + SendReqConcurrentNum int + SendRspConcurrentNum int + RecvConcurrentNum int +} + +// Stream is the common interface of stream for both server and client +type Stream interface { + Send(*api.StreamingMessage) error + Recv() (*api.StreamingMessage, error) +} + +// RepairStreamFunc repairs stream +type RepairStreamFunc func() Stream + +// HealthCheckFunc checks stream health +type HealthCheckFunc func() bool + +type sendAckPack struct { + rsp *api.StreamingMessage + err error +} + +type sendCbPack struct { + t time.Time + cb SendCallback +} + +// StreamConnection is an implementation of Connection with stream +type StreamConnection struct { + stream Stream + sendAckRecord map[string]chan sendAckPack + sendCbRecord map[string]sendCbPack + peerAddr string + closed bool + repairing bool + repairFunc RepairStreamFunc + healthFunc HealthCheckFunc + sendReqCh chan *api.StreamingMessage + sendRspCh chan *api.StreamingMessage + recvCh chan *api.StreamingMessage + repairCh chan struct{} + closeCh chan struct{} + *sync.RWMutex + *sync.Cond +} + +// CreateStreamConnection creates a StreamConnection +func CreateStreamConnection(stream Stream, params StreamParams, healthFunc HealthCheckFunc, + repairFunc RepairStreamFunc) Connection { + calibrateParams(¶ms) + mutex := new(sync.RWMutex) + sc := &StreamConnection{ + stream: stream, + sendAckRecord: make(map[string]chan sendAckPack, constant.DefaultMapSize), + sendCbRecord: make(map[string]sendCbPack, constant.DefaultMapSize), + peerAddr: params.PeerAddr, + healthFunc: healthFunc, + repairFunc: repairFunc, + sendReqCh: make(chan *api.StreamingMessage, defaultChannelSize), + sendRspCh: make(chan *api.StreamingMessage, defaultChannelSize), + recvCh: make(chan *api.StreamingMessage, defaultChannelSize), + repairCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + RWMutex: mutex, + Cond: sync.NewCond(mutex), + } + startLoopProcess(func() { sc.sendLoop(sc.sendReqCh) }, params.SendReqConcurrentNum) + startLoopProcess(func() { sc.sendLoop(sc.sendRspCh) }, params.SendRspConcurrentNum) + startLoopProcess(sc.recvLoop, params.RecvConcurrentNum) + if repairFunc != nil { + startLoopProcess(sc.repairLoop, 1) + } + return sc +} + +// Send sends stream message +func (sc *StreamConnection) Send(message *api.StreamingMessage, option SendOption, callback SendCallback) ( + *api.StreamingMessage, error) { + select { + case <-sc.closeCh: + return nil, ErrStreamConnectionClosed + default: + } + if sc.healthFunc != nil && !sc.healthFunc() { + return nil, ErrStreamConnectionBroken + } + if len(message.MessageID) == 0 { + message.MessageID = uuid.New().String() + } + sc.Lock() + ackCh := make(chan sendAckPack, 1) + sc.sendAckRecord[message.MessageID] = ackCh + // message with requestID is an async message which needs a callback + requestID := getRequestID(message.GetBody()) + if len(requestID) != 0 && callback != nil { + if _, exist := sc.sendCbRecord[requestID]; exist { + sc.Unlock() + return nil, ErrRequestIDAlreadyExist + } + sc.sendCbRecord[requestID] = sendCbPack{t: time.Now(), cb: callback} + } + sc.Unlock() + defer func() { + sc.Lock() + delete(sc.sendAckRecord, message.MessageID) + sc.Unlock() + }() + select { + case sc.sendReqCh <- message: + default: + log.GetLogger().Warnf("send channel reach limit %d for connection of %s", defaultChannelSize, sc.peerAddr) + sc.Lock() + delete(sc.sendCbRecord, requestID) + sc.Unlock() + return nil, errors.New("stream send is blocked") + } + timer := time.NewTimer(option.Timeout) + select { + case <-timer.C: + // send failed, no need to record callback + sc.Lock() + delete(sc.sendCbRecord, requestID) + sc.Unlock() + return nil, errors.New("send timeout") + case ackPack, ok := <-ackCh: + // consider to add retry here + if !ok { + return nil, errors.New("send response channel closed") + } + return ackPack.rsp, ackPack.err + } +} + +// Recv receives stream message +func (sc *StreamConnection) Recv() (*api.StreamingMessage, error) { + select { + case <-sc.closeCh: + return nil, ErrStreamConnectionClosed + case msg, ok := <-sc.recvCh: + if !ok { + return nil, errors.New("recv channel is closed") + } + return msg, nil + } +} + +// Close closes stream +func (sc *StreamConnection) Close() { + sc.Lock() + if sc.closed { + sc.Unlock() + return + } + sc.closed = true + sc.Unlock() + close(sc.closeCh) +} + +// CheckClose checks if stream is closed +func (sc *StreamConnection) CheckClose() chan struct{} { + return sc.closeCh +} + +func (sc *StreamConnection) sendLoop(sendCh chan *api.StreamingMessage) { + for { + select { + case <-sc.closeCh: + log.GetLogger().Debugf("stop send loop for connection of %s", sc.peerAddr) + return + case msg, ok := <-sendCh: + if !ok { + log.GetLogger().Warnf("close stream, send channel closed for connection of %s", sc.peerAddr) + return + } + if !sc.waitForStreamFix() { + log.GetLogger().Warnf("cannot fix stream, stop send loop for connection of %s", sc.peerAddr) + return + } + err := sc.stream.Send(msg) + sc.RLock() + ackCh, exist := sc.sendAckRecord[msg.GetMessageID()] + sc.RUnlock() + if err != nil { + if exist && ackCh != nil { + ackCh <- sendAckPack{ + rsp: nil, + err: err, + } + } else { + log.GetLogger().Warnf("response channel for sending message %s doesn't exist for connection %s", + msg.MessageID, sc.peerAddr) + } + sc.repairStream() + continue + } + if !expectResponse(msg) { + if exist && ackCh != nil { + ackCh <- sendAckPack{ + rsp: nil, + err: nil, + } + } else { + log.GetLogger().Warnf("response channel for sending message %s doesn't exist for connection %s", + msg.MessageID, sc.peerAddr) + } + } + } + } +} + +func (sc *StreamConnection) recvLoop() { + for { + select { + case <-sc.closeCh: + log.GetLogger().Debugf("close stream, stop recv loop for connection of %s", sc.peerAddr) + return + default: + if !sc.waitForStreamFix() { + log.GetLogger().Warnf("cannot fix stream, stop recv loop for connection of %s", sc.peerAddr) + return + } + msg, err := sc.stream.Recv() + if err != nil { + log.GetLogger().Errorf("receive error %s for connection of %s", err.Error(), sc.peerAddr) + sc.repairStream() + continue + } + switch msg.GetBody().(type) { + case *api.StreamingMessage_CreateRsp, *api.StreamingMessage_InvokeRsp, *api.StreamingMessage_ExitRsp, + *api.StreamingMessage_SaveRsp, *api.StreamingMessage_LoadRsp, *api.StreamingMessage_KillRsp, + *api.StreamingMessage_NotifyRsp: + sc.Lock() + askCh, exist := sc.sendAckRecord[msg.GetMessageID()] + if exist { + delete(sc.sendAckRecord, msg.GetMessageID()) + } else { + log.GetLogger().Warnf("receive unexpected response messageID %s for connection %s", + msg.GetMessageID(), sc.peerAddr) + } + sc.Unlock() + if exist { + askCh <- sendAckPack{ + rsp: msg, + err: nil, + } + continue + } + + case *api.StreamingMessage_CallReq, *api.StreamingMessage_CheckpointReq, *api.StreamingMessage_RecoverReq, + *api.StreamingMessage_ShutdownReq, *api.StreamingMessage_SignalReq, *api.StreamingMessage_InvokeReq: + // StreamingMessage_InvokeReq is used in simplified server mode + select { + case sc.recvCh <- msg: + default: + log.GetLogger().Warnf("receive channel reaches limit %d for connection %s", defaultChannelSize, + sc.peerAddr) + } + case *api.StreamingMessage_NotifyReq: + notifyReq := msg.GetNotifyReq() + requestID := notifyReq.GetRequestID() + sc.Lock() + cbPack, exist := sc.sendCbRecord[requestID] + if exist { + delete(sc.sendCbRecord, requestID) + } else { + log.GetLogger().Warnf("receive unexpected notify requestID %s for connection %s", requestID, + sc.peerAddr) + } + sc.Unlock() + if exist { + go cbPack.cb(notifyReq) + } + select { + case sc.sendRspCh <- &api.StreamingMessage{ + MessageID: msg.GetMessageID(), + Body: &api.StreamingMessage_NotifyRsp{ + NotifyRsp: &runtime.NotifyResponse{}, + }, + }: + default: + log.GetLogger().Warnf("sendRsp channel reaches limit %d for connection %s", defaultChannelSize, + sc.peerAddr) + } + case *api.StreamingMessage_HeartbeatReq: + select { + case sc.sendRspCh <- &api.StreamingMessage{ + MessageID: msg.GetMessageID(), + Body: &api.StreamingMessage_HeartbeatRsp{ + HeartbeatRsp: &runtime.HeartbeatResponse{}, + }, + }: + default: + log.GetLogger().Warnf("sendRsp channel reaches limit %d for connection %s", defaultChannelSize, + sc.peerAddr) + } + default: + log.GetLogger().Warnf("receive unknown type message %s", reflect.TypeOf(msg.GetBody()).String()) + } + + } + } +} + +func (sc *StreamConnection) repairLoop() { + for { + select { + case <-sc.closeCh: + log.GetLogger().Debugf("stop recv loop for connection of %s", sc.peerAddr) + return + case _, ok := <-sc.repairCh: + if !ok { + log.GetLogger().Warnf("repair channel closed for connection of %s", sc.peerAddr) + return + } + stream := sc.repairFunc() + if stream == nil { + log.GetLogger().Warnf("failed to fix stream during fix loop") + continue + } + sc.Lock() + sc.repairing = false + sc.stream = stream + sc.Unlock() + sc.Broadcast() + } + } +} + +func (sc *StreamConnection) waitForStreamFix() bool { + sc.L.Lock() + if sc.repairing || (sc.healthFunc != nil && !sc.healthFunc()) { + sc.Wait() + } + if sc.closed { + sc.L.Unlock() + return false + } + sc.L.Unlock() + return true +} + +func (sc *StreamConnection) repairStream() { + sc.Lock() + // close stream if there is no way to fix it + if sc.repairFunc == nil { + sc.Unlock() + sc.Close() + return + } + if sc.repairing { + sc.Unlock() + return + } + sc.repairing = true + sc.Unlock() + select { + case sc.repairCh <- struct{}{}: + default: + } +} + +func calibrateParams(params *StreamParams) { + if params.SendReqConcurrentNum < 1 { + params.SendReqConcurrentNum = 1 + } + if params.SendRspConcurrentNum < 1 { + params.SendRspConcurrentNum = 1 + } + if params.RecvConcurrentNum < 1 { + params.RecvConcurrentNum = 1 + } +} + +func startLoopProcess(loop func(), num int) { + for i := 0; i < num; i++ { + go loop() + } +} + +// only async request contains requestID (create and invoke) +func getRequestID(req interface{}) string { + switch req.(type) { + case *api.StreamingMessage_CreateReq: + return req.(*api.StreamingMessage_CreateReq).CreateReq.GetRequestID() + case *api.StreamingMessage_InvokeReq: + return req.(*api.StreamingMessage_InvokeReq).InvokeReq.GetRequestID() + default: + return "" + } +} + +// server may send some message which doesn't expect response +func expectResponse(msg *api.StreamingMessage) bool { + switch msg.GetBody().(type) { + case *api.StreamingMessage_CreateReq, *api.StreamingMessage_InvokeReq, *api.StreamingMessage_ExitReq, + *api.StreamingMessage_SaveReq, *api.StreamingMessage_LoadReq, *api.StreamingMessage_KillReq, + *api.StreamingMessage_NotifyReq: + return true + default: + return false + } +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/connection/stream_connection_test.go b/yuanrong/pkg/common/faas_common/kernelrpc/connection/stream_connection_test.go new file mode 100644 index 0000000..d254348 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/connection/stream_connection_test.go @@ -0,0 +1,410 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package connection - +package connection + +import ( + "errors" + "io" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + + "github.com/smartystreets/goconvey/convey" + api "yuanrong/pkg/common/faas_common/grpc/pb" + "yuanrong/pkg/common/faas_common/grpc/pb/core" + "yuanrong/pkg/common/faas_common/grpc/pb/runtime" +) + +type fakeStream struct { + sendDelay time.Duration + sendErrCh chan error + recvErrCh chan error + sendCh chan *api.StreamingMessage + recvCh chan *api.StreamingMessage +} + +func createFakeStream(sendDelay time.Duration) *fakeStream { + return &fakeStream{ + sendDelay: sendDelay, + sendErrCh: make(chan error, 1), + recvErrCh: make(chan error, 1), + sendCh: make(chan *api.StreamingMessage, 1), + recvCh: make(chan *api.StreamingMessage, 1), + } +} + +func (f *fakeStream) Send(msg *api.StreamingMessage) error { + if f.sendDelay != 0 { + <-time.After(f.sendDelay) + } + select { + case err := <-f.sendErrCh: + return err + default: + f.sendCh <- msg + return nil + } +} + +func (f *fakeStream) Recv() (*api.StreamingMessage, error) { + select { + case err := <-f.recvErrCh: + return nil, err + case msg := <-f.recvCh: + return msg, nil + } +} + +func TestStreamSend(t *testing.T) { + convey.Convey("test steam send", t, func() { + healthReturn := true + healthFunc := func() bool { + return healthReturn + } + repairCount := 0 + repairFunc := func() Stream { + repairCount++ + return &fakeStream{} + } + var callbackRes *runtime.NotifyRequest + callbackFunc := func(message *runtime.NotifyRequest) { + callbackRes = message + } + convey.Convey("unhealthy stream", func() { + healthReturn = false + repairCount = 0 + stream := createFakeStream(0) + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + _, err := sc.Send(&api.StreamingMessage{ + Body: &api.StreamingMessage_InvokeReq{ + InvokeReq: &core.InvokeRequest{ + RequestID: "reqID-123", + }, + }, + }, SendOption{Timeout: 200 * time.Millisecond}, nil) + convey.So(err, convey.ShouldEqual, ErrStreamConnectionBroken) + }) + convey.Convey("requestID already exist", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(1 * time.Second) + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + go sc.Send(&api.StreamingMessage{ + Body: &api.StreamingMessage_InvokeReq{ + InvokeReq: &core.InvokeRequest{ + RequestID: "reqID-123", + }, + }, + }, SendOption{Timeout: 200 * time.Millisecond}, callbackFunc) + time.Sleep(100 * time.Millisecond) + _, err := sc.Send(&api.StreamingMessage{ + Body: &api.StreamingMessage_InvokeReq{ + InvokeReq: &core.InvokeRequest{ + RequestID: "reqID-123", + }, + }, + }, SendOption{Timeout: 200 * time.Millisecond}, callbackFunc) + convey.So(err, convey.ShouldEqual, ErrRequestIDAlreadyExist) + }) + convey.Convey("stream send blocked", func() { + healthReturn = true + repairCount = 0 + patch := gomonkey.ApplyGlobalVar(&defaultChannelSize, 1) + stream := createFakeStream(1 * time.Second) + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + sc.(*StreamConnection).sendReqCh <- &api.StreamingMessage{} + sc.(*StreamConnection).sendReqCh <- &api.StreamingMessage{} + _, err := sc.Send(&api.StreamingMessage{ + Body: &api.StreamingMessage_InvokeReq{ + InvokeReq: &core.InvokeRequest{ + RequestID: "reqID-789", + }, + }, + }, SendOption{Timeout: 200 * time.Millisecond}, nil) + convey.So(err.Error(), convey.ShouldEqual, "stream send is blocked") + patch.Reset() + }) + convey.Convey("stream send error and repair", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + var newStream *fakeStream + repairFunc = func() Stream { + repairCount++ + newStream = createFakeStream(0) + return newStream + } + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + stream.sendErrCh <- io.EOF + _, err := sc.Send(&api.StreamingMessage{ + Body: &api.StreamingMessage_InvokeReq{ + InvokeReq: &core.InvokeRequest{ + RequestID: "reqID-123", + }, + }, + }, SendOption{Timeout: 200 * time.Millisecond}, nil) + time.Sleep(100 * time.Millisecond) + convey.So(err, convey.ShouldEqual, io.EOF) + convey.So(repairCount, convey.ShouldEqual, 1) + stream.recvErrCh <- io.EOF + time.Sleep(100 * time.Millisecond) + go func() { + time.Sleep(100 * time.Millisecond) + msg := <-newStream.sendCh + newStream.recvCh <- &api.StreamingMessage{ + MessageID: msg.GetMessageID(), + Body: &api.StreamingMessage_InvokeRsp{ + InvokeRsp: &core.InvokeResponse{}, + }, + } + }() + _, err = sc.Send(&api.StreamingMessage{ + Body: &api.StreamingMessage_InvokeReq{ + InvokeReq: &core.InvokeRequest{ + RequestID: "reqID-123", + }, + }, + }, SendOption{Timeout: 200 * time.Millisecond}, nil) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("stream send error and no repair", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, nil) + stream.sendErrCh <- io.EOF + _, err := sc.Send(&api.StreamingMessage{ + Body: &api.StreamingMessage_InvokeReq{ + InvokeReq: &core.InvokeRequest{ + RequestID: "reqID-123", + }, + }, + }, SendOption{Timeout: 200 * time.Millisecond}, nil) + convey.So(err, convey.ShouldEqual, io.EOF) + convey.So(repairCount, convey.ShouldEqual, 0) + _, err = sc.Send(&api.StreamingMessage{ + Body: &api.StreamingMessage_InvokeReq{ + InvokeReq: &core.InvokeRequest{ + RequestID: "reqID-123", + }, + }, + }, SendOption{Timeout: 200 * time.Millisecond}, nil) + convey.So(err, convey.ShouldEqual, ErrStreamConnectionClosed) + convey.So(repairCount, convey.ShouldEqual, 0) + }) + convey.Convey("stream send timeout", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + _, err := sc.Send(&api.StreamingMessage{ + MessageID: "msgID-123", + Body: &api.StreamingMessage_InvokeReq{ + InvokeReq: &core.InvokeRequest{ + RequestID: "reqID-123", + }, + }, + }, SendOption{Timeout: 100 * time.Millisecond}, nil) + convey.So(err.Error(), convey.ShouldEqual, "send timeout") + convey.So(repairCount, convey.ShouldEqual, 0) + }) + convey.Convey("stream send expect no response", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + _, err := sc.Send(&api.StreamingMessage{ + MessageID: "msgID-123", + Body: &api.StreamingMessage_InvokeRsp{ + InvokeRsp: &core.InvokeResponse{}, + }, + }, SendOption{Timeout: 100 * time.Millisecond}, nil) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("stream send expect response", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + go func() { + time.Sleep(100 * time.Millisecond) + msg := <-stream.sendCh + stream.recvCh <- &api.StreamingMessage{ + MessageID: msg.GetMessageID(), + Body: &api.StreamingMessage_InvokeRsp{ + InvokeRsp: &core.InvokeResponse{}, + }, + } + stream.recvCh <- &api.StreamingMessage{ + Body: &api.StreamingMessage_NotifyReq{ + NotifyReq: &runtime.NotifyRequest{ + RequestID: "reqID-123", + }, + }, + } + }() + _, err := sc.Send(&api.StreamingMessage{ + Body: &api.StreamingMessage_InvokeReq{ + InvokeReq: &core.InvokeRequest{ + RequestID: "reqID-123", + }, + }, + }, SendOption{Timeout: 200 * time.Minute}, callbackFunc) + time.Sleep(100 * time.Millisecond) + convey.So(err, convey.ShouldBeNil) + convey.So(repairCount, convey.ShouldEqual, 0) + convey.So(callbackRes, convey.ShouldNotBeNil) + }) + }) +} + +func TestStreamRecv(t *testing.T) { + convey.Convey("test steam receive", t, func() { + healthReturn := true + healthFunc := func() bool { + return healthReturn + } + repairCount := 0 + repairFunc := func() Stream { + repairCount++ + return createFakeStream(0) + } + convey.Convey("stream receive error and no repair", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, nil) + stream.recvErrCh <- errors.New("some error") + time.Sleep(100 * time.Millisecond) + _, err := sc.Recv() + convey.So(err, convey.ShouldEqual, ErrStreamConnectionClosed) + }) + convey.Convey("stream receive error and repair", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + var newStream *fakeStream + repairFunc = func() Stream { + repairCount++ + newStream = createFakeStream(0) + return newStream + } + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + stream.recvErrCh <- io.EOF + time.Sleep(100 * time.Millisecond) + newStream.recvCh <- &api.StreamingMessage{ + Body: &api.StreamingMessage_CallReq{ + CallReq: &runtime.CallRequest{ + RequestID: "reqID-123", + }, + }, + } + time.Sleep(100 * time.Millisecond) + _, err := sc.Recv() + convey.So(err, convey.ShouldBeNil) + convey.So(repairCount, convey.ShouldEqual, 1) + }) + convey.Convey("receive call message", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + sc := CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + stream.recvCh <- &api.StreamingMessage{ + Body: &api.StreamingMessage_CallReq{ + CallReq: &runtime.CallRequest{ + RequestID: "reqID-123", + }, + }, + } + msg, err := sc.Recv() + convey.So(err, convey.ShouldBeNil) + callReq := msg.GetCallReq() + convey.So(callReq, convey.ShouldNotBeNil) + convey.So(callReq.GetRequestID(), convey.ShouldEqual, "reqID-123") + }) + convey.Convey("receive heartbeat message", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + stream.recvCh <- &api.StreamingMessage{ + Body: &api.StreamingMessage_HeartbeatReq{}, + } + msg := <-stream.sendCh + heartbeatRsp := msg.GetHeartbeatRsp() + convey.So(heartbeatRsp, convey.ShouldNotBeNil) + }) + convey.Convey("receive unexpected id", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + stream.recvCh <- &api.StreamingMessage{ + MessageID: "msgID-123", + Body: &api.StreamingMessage_NotifyRsp{ + NotifyRsp: &runtime.NotifyResponse{}, + }, + } + stream.recvCh <- &api.StreamingMessage{ + MessageID: "msgID-123", + Body: &api.StreamingMessage_NotifyReq{ + NotifyReq: &runtime.NotifyRequest{ + RequestID: "reqID-123", + }, + }, + } + time.Sleep(100 * time.Millisecond) + convey.So(len(stream.sendCh), convey.ShouldEqual, 1) + }) + convey.Convey("receive unsupported message", func() { + healthReturn = true + repairCount = 0 + stream := createFakeStream(0) + CreateStreamConnection(stream, StreamParams{}, healthFunc, repairFunc) + stream.recvCh <- &api.StreamingMessage{ + Body: &api.StreamingMessage_SignalRsp{}, + } + convey.So(len(stream.sendCh), convey.ShouldEqual, 0) + }) + }) +} + +func TestStreamClose(t *testing.T) { + convey.Convey("test steam close", t, func() { + repairFunc := func() Stream { + return createFakeStream(0) + } + stream := createFakeStream(0) + sc := CreateStreamConnection(stream, StreamParams{}, nil, repairFunc) + closeCh := sc.CheckClose() + sc.Close() + sc.Close() + convey.So(sc.(*StreamConnection).closed, convey.ShouldEqual, true) + select { + case <-closeCh: + default: + t.Errorf("closeCh is not closed") + } + _, err := sc.Send(nil, SendOption{}, nil) + convey.So(err, convey.ShouldEqual, ErrStreamConnectionClosed) + stream.recvErrCh <- io.EOF + _, err = sc.Recv() + convey.So(err, convey.ShouldEqual, ErrStreamConnectionClosed) + }) +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/basic_stream_client.go b/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/basic_stream_client.go new file mode 100644 index 0000000..0cfc894 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/basic_stream_client.go @@ -0,0 +1,298 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package rpcclient - +package rpcclient + +import ( + "context" + "errors" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/metadata" + + rtapi "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/grpc/pb" // production: package api + "yuanrong/pkg/common/faas_common/grpc/pb/common" + "yuanrong/pkg/common/faas_common/grpc/pb/core" + "yuanrong/pkg/common/faas_common/grpc/pb/runtime" + "yuanrong/pkg/common/faas_common/kernelrpc/connection" + "yuanrong/pkg/common/faas_common/kernelrpc/utils" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" +) + +const ( + maxMsgSize = 1024 * 1024 * 10 + maxWindowSize = 1024 * 1024 * 10 + maxBufferSize = 1024 * 1024 * 10 + dialBaseDelay = 300 * time.Millisecond + dialMultiplier = 1.2 + dialJitter = 0.1 + runtimeDialMaxDelay = 100 * time.Second +) + +var ( + // ErrUnsupportedMethod - + ErrUnsupportedMethod = snerror.New(statuscode.InternalErrorCode, "unsupported method") + dialTimeout = 5 * time.Second + dialRetryTime = 10 + dialRetryInterval = 3 * time.Second + setupStreamRetryInterval = 3 * time.Second + streamMessagePool = sync.Pool{} + invokeRequestPool = sync.Pool{} +) + +// StreamClientParams - +type StreamClientParams struct { + SendReqConcurrentNum int + SendRspConcurrentNum int + RecvConcurrentNum int +} + +// BasicSteamClient is basic implementation of KernelClient which only sends POSIX calls as a runtime +type BasicSteamClient struct { + clientConn *grpc.ClientConn + streamConn connection.Connection + peerAddr string +} + +// CreateBasicStreamClient creates BasicSteamClient +func CreateBasicStreamClient(peerAddr string, params StreamClientParams) (KernelClient, error) { + conn, err := dialConnection(peerAddr) + if err != nil { + log.GetLogger().Errorf("failed to dial connection to %s error %s", peerAddr, err.Error()) + return nil, err + } + stream, err := createStream(conn, nil) + if err != nil { + log.GetLogger().Errorf("failed to create stream to %s error %s", peerAddr, err.Error()) + return nil, err + } + client := &BasicSteamClient{ + peerAddr: peerAddr, + clientConn: conn, + } + streamConn := connection.CreateStreamConnection(stream, + connection.StreamParams{ + PeerAddr: peerAddr, + SendReqConcurrentNum: params.SendReqConcurrentNum, + SendRspConcurrentNum: params.SendRspConcurrentNum, + RecvConcurrentNum: params.RecvConcurrentNum, + }, + client.checkClientConnHealth, client.repairStream) + client.streamConn = streamConn + return client, nil +} + +// Create - +func (k *BasicSteamClient) Create(funcKey string, args []*rtapi.Arg, createParams CreateParams, + callback KernelClientCallback) (string, snerror.SNError) { + return "", ErrUnsupportedMethod +} + +// Invoke - +func (k *BasicSteamClient) Invoke(funcKey string, instanceID string, args []*rtapi.Arg, invokeParams InvokeParams, + callback KernelClientCallback) (string, snerror.SNError) { + CalibrateTransportParams(&invokeParams.TransportParams) + message := acquireStreamMessageInvokeRequest() + defer releaseStreamMessageInvokeRequest(message) + invokeReq := message.GetInvokeReq() + invokeReq.Function = funcKey + invokeReq.Args = pb2Arg(args) + invokeReq.InstanceID = instanceID + if len(invokeParams.RequestID) != 0 { + invokeReq.RequestID = invokeParams.RequestID + } else { + invokeReq.RequestID = utils.GenTaskID() + } + if len(invokeParams.TraceID) != 0 { + invokeReq.TraceID = invokeParams.TraceID + } else { + invokeReq.TraceID = utils.GenTaskID() + } + sendOption := connection.SendOption{ + Timeout: invokeParams.Timeout, + } + sendCallback := func(notifyReq *runtime.NotifyRequest) { + var ( + notifyMsg []byte + notifyErr snerror.SNError + ) + if notifyReq.Code != common.ErrorCode_ERR_NONE { + notifyErr = snerror.New(int(notifyReq.Code), notifyReq.Message) + } else { + notifyMsg = []byte(notifyReq.Message) + } + callback(notifyMsg, notifyErr) + } + msg, err := k.streamConn.Send(message, sendOption, sendCallback) + if err != nil { + return "", snerror.New(statuscode.InternalErrorCode, err.Error()) + } + sendRsp, ok := msg.GetBody().(*api.StreamingMessage_InvokeRsp) + if !ok { + return "", snerror.New(statuscode.InternalErrorCode, "invoke response type error") + } + if sendRsp.InvokeRsp.Code != common.ErrorCode_ERR_NONE { + return "", snerror.New(int(sendRsp.InvokeRsp.Code), sendRsp.InvokeRsp.Message) + } + return sendRsp.InvokeRsp.Message, nil +} + +// SaveState - +func (k *BasicSteamClient) SaveState(state []byte) (string, snerror.SNError) { + return "", ErrUnsupportedMethod +} + +// LoadState - +func (k *BasicSteamClient) LoadState(checkpointID string) ([]byte, snerror.SNError) { + return nil, ErrUnsupportedMethod +} + +// Kill - +func (k *BasicSteamClient) Kill(instanceID string, signal int32, payload []byte) snerror.SNError { + return ErrUnsupportedMethod +} + +// Exit - +func (k *BasicSteamClient) Exit() { +} + +func (k *BasicSteamClient) checkClientConnHealth() bool { + return checkClientConnHealth(k.clientConn) +} + +func (k *BasicSteamClient) repairStream() connection.Stream { + if !k.checkClientConnHealth() { + conn, err := dialConnection(k.peerAddr) + if err != nil { + log.GetLogger().Errorf("failed to repair stream, dial connection to %s error %s", k.peerAddr, err.Error()) + return nil + } + k.clientConn = conn + } + stream, err := createStream(k.clientConn, nil) + if err != nil { + log.GetLogger().Errorf("failed to repair stream, create stream to %s error %s", k.peerAddr, err.Error()) + return nil + } + return stream +} + +func dialConnection(addr string) (*grpc.ClientConn, error) { + ctx, cancel := context.WithTimeout(context.TODO(), dialTimeout) + defer cancel() + dialFunc := func() (*grpc.ClientConn, error) { + return grpc.DialContext(ctx, addr, grpc.WithInsecure(), grpc.WithBlock(), + grpc.WithInitialWindowSize(maxWindowSize), + grpc.WithInitialConnWindowSize(maxWindowSize), + grpc.WithWriteBufferSize(maxBufferSize), + grpc.WithReadBufferSize(maxBufferSize), + grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxMsgSize), grpc.MaxCallRecvMsgSize(maxMsgSize)), + grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoff.Config{BaseDelay: dialBaseDelay, Multiplier: dialMultiplier, Jitter: dialJitter, + MaxDelay: runtimeDialMaxDelay}, MinConnectTimeout: dialBaseDelay, + })) + } + var ( + conn *grpc.ClientConn + err error + ) + for i := 0; i < dialRetryTime; i++ { + conn, err = dialFunc() + if err == nil { + return conn, err + } + log.GetLogger().Warnf("failed to dial connection to %s error %s", addr, err.Error()) + time.Sleep(time.Duration(i+1) * dialRetryInterval) + } + log.GetLogger().Errorf("failed to dial connection to %s after %d retries error %s", addr, dialRetryTime, + err.Error()) + return nil, err +} + +func createStream(conn *grpc.ClientConn, mdMap map[string]string) (api.RuntimeRPC_MessageStreamClient, error) { + if !checkClientConnHealth(conn) { + log.GetLogger().Errorf("grpc connection is nil, failed to create stream rpcclient") + return nil, errors.New("conn is unhealthy") + } + client := api.NewRuntimeRPCClient(conn) + md := metadata.New(mdMap) + var ( + stream api.RuntimeRPC_MessageStreamClient + err error + ) + var retryTimes int + for i := 0; i < dialRetryTime; i++ { + stream, err = client.MessageStream(metadata.NewOutgoingContext(context.Background(), md)) + if err == nil { + log.GetLogger().Infof("succeed to get stream from function proxy") + break + } + log.GetLogger().Errorf("failed to get stream from function proxy for %d times, err: %s", + retryTimes, err.Error()) + time.Sleep(setupStreamRetryInterval) + } + if err != nil { + log.GetLogger().Errorf("failed to create stream rpcclient to %s when setup message stream error %s", conn.Target(), + err.Error()) + return nil, err + } + return stream, nil +} + +func checkClientConnHealth(conn *grpc.ClientConn) bool { + if conn == nil { + return false + } + return conn.GetState() == connectivity.Idle || conn.GetState() == connectivity.Ready +} + +func acquireStreamMessageInvokeRequest() *api.StreamingMessage { + var ( + streamMsg *api.StreamingMessage + invokeReq *api.StreamingMessage_InvokeReq + ok bool + ) + streamMsg, ok = streamMessagePool.Get().(*api.StreamingMessage) + if !ok { + streamMsg = &api.StreamingMessage{} + } + invokeReq, ok = invokeRequestPool.Get().(*api.StreamingMessage_InvokeReq) + if !ok { + invokeReq = &api.StreamingMessage_InvokeReq{InvokeReq: &core.InvokeRequest{}} + } + streamMsg.Body = invokeReq + return streamMsg +} + +func releaseStreamMessageInvokeRequest(streamMsg *api.StreamingMessage) { + invokeReq, ok := streamMsg.GetBody().(*api.StreamingMessage_InvokeReq) + if !ok { + return + } + invokeReq.InvokeReq.Reset() + invokeRequestPool.Put(invokeReq) + streamMsg.Reset() + streamMessagePool.Put(streamMsg) +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/basic_stream_client_test.go b/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/basic_stream_client_test.go new file mode 100644 index 0000000..95c2297 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/basic_stream_client_test.go @@ -0,0 +1,75 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rpcclient + +import ( + "testing" + + "github.com/stretchr/testify/assert" + rtapi "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/snerror" +) + +func TestBasicSteamClient_Create(t *testing.T) { + client := &BasicSteamClient{} + + funcKey := "testFuncKey" + args := []*rtapi.Arg{} + createParams := CreateParams{} + callback := func(result []byte, err snerror.SNError) { + } + + result, err := client.Create(funcKey, args, createParams, callback) + + assert.Equal(t, "", result) + assert.Equal(t, ErrUnsupportedMethod, err) +} + +func TestBasicSteamClient_SaveState(t *testing.T) { + client := &BasicSteamClient{} + + state := []byte("test state") + + result, err := client.SaveState(state) + + assert.Equal(t, "", result, "Expected empty string as result") + assert.Equal(t, ErrUnsupportedMethod, err, "Expected ErrUnsupportedMethod error") +} + +func TestBasicSteamClient_LoadState(t *testing.T) { + client := &BasicSteamClient{} + + checkpointID := "testCheckpointID" + + result, err := client.LoadState(checkpointID) + + assert.Nil(t, result, "Expected nil as result") + assert.Equal(t, ErrUnsupportedMethod, err, "Expected ErrUnsupportedMethod error") +} + +func TestBasicSteamClient_Kill(t *testing.T) { + client := &BasicSteamClient{} + + instanceID := "testInstanceID" + signal := int32(9) + payload := []byte("test payload") + + err := client.Kill(instanceID, signal, payload) + + assert.Equal(t, ErrUnsupportedMethod, err, "Expected ErrUnsupportedMethod error") +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/client.go b/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/client.go new file mode 100644 index 0000000..32d0963 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/client.go @@ -0,0 +1,167 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package rpcclient - +package rpcclient + +import ( + "time" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/grpc/pb/common" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" +) + +const ( + defaultTimeout = 900 * time.Second +) + +var ( + // ErrKernelClientTimeout - + ErrKernelClientTimeout = snerror.New(statuscode.InternalErrorCode, "kernel rpcclient timeout") +) + +// TransportParams - +type TransportParams struct { + Timeout time.Duration + RetryInterval time.Duration + RetryNumber int +} + +// AffinityType - +type AffinityType int32 + +// SchedulingOptions - +type SchedulingOptions struct { + Priority int32 + Resources map[string]float64 + Extension map[string]string + Affinity map[string]AffinityType + ScheduleAffinity []byte +} + +// CreateParams - +type CreateParams struct { + TransportParams + DesignatedInstanceID string + Label []string + CreateOption map[string]string + ScheduleOption SchedulingOptions +} + +// InvokeParams - +type InvokeParams struct { + TransportParams + InvokeOptions map[string]string + RequestID string + TraceID string +} + +// KernelClientCallback - +type KernelClientCallback = func(result []byte, err snerror.SNError) + +// KernelClientAsyncCreate - +type KernelClientAsyncCreate = func(function string, args []string, createParams CreateParams, + callback KernelClientCallback) (string, snerror.SNError) + +// KernelClientAsyncInvoke _ +type KernelClientAsyncInvoke = func(function string, instanceID string, args []string, invokeParams InvokeParams, + callback KernelClientCallback) snerror.SNError + +// KernelClient defines basic POSIX client methods, it's worth noting that Create and +// Invoke are original async calls while others are sync calls +type KernelClient interface { + Create(funcKey string, args []*api.Arg, createParams CreateParams, callback KernelClientCallback) (string, + snerror.SNError) + + Invoke(funcKey string, instanceID string, args []*api.Arg, invokeParams InvokeParams, + callback KernelClientCallback) (string, snerror.SNError) + + SaveState(state []byte) (string, snerror.SNError) + + LoadState(checkpointID string) ([]byte, snerror.SNError) + + Kill(instanceID string, signal int32, payload []byte) snerror.SNError + + Exit() +} + +// SyncInvoke will call invoke synchronously +func SyncInvoke(asyncInvoke KernelClientAsyncInvoke, funcKey string, instanceID string, args []string, + invokeParams InvokeParams) ([]byte, snerror.SNError) { + CalibrateTransportParams(&invokeParams.TransportParams) + var ( + resultData []byte + resultError snerror.SNError + ) + waitCh := make(chan struct{}, 1) + callback := func(result []byte, err snerror.SNError) { + resultData, resultError = result, err + waitCh <- struct{}{} + } + invokeErr := asyncInvoke(funcKey, instanceID, args, invokeParams, callback) + if invokeErr != nil { + return nil, invokeErr + } + timer := time.NewTimer(invokeParams.Timeout) + defer timer.Stop() + retryCount := 0 + for { + select { + case <-timer.C: + log.GetLogger().Errorf("sync invoke times out after %ds for function %s traceID %s", + invokeParams.Timeout.Seconds(), funcKey, invokeParams.TraceID) + return nil, ErrKernelClientTimeout + case <-waitCh: + if resultError == nil { + return resultData, nil + } + retryCount++ + if retryCount <= invokeParams.RetryNumber { + time.Sleep(invokeParams.RetryInterval) + log.GetLogger().Errorf("sync invoke reties count %d after %ds for function %s traceID %s", + retryCount, invokeParams.RetryInterval.Seconds(), funcKey, invokeParams.TraceID) + invokeErr = asyncInvoke(funcKey, instanceID, args, invokeParams, callback) + if invokeErr != nil { + return nil, invokeErr + } + continue + } + log.GetLogger().Errorf("sync invoke reties reach limit %d for function %s traceID %s", + invokeParams.RetryNumber, funcKey, invokeParams.TraceID) + return resultData, resultError + } + } +} + +func pb2Arg(args []*api.Arg) []*common.Arg { + length := len(args) + newArgs := make([]*common.Arg, 0, length) + for _, arg := range args { + newArgs = append(newArgs, &common.Arg{Type: common.Arg_ArgType(arg.Type), Value: arg.Data}) + } + return newArgs +} + +// CalibrateTransportParams calibrates transport params +func CalibrateTransportParams(params *TransportParams) { + if params.Timeout == 0 { + params.Timeout = defaultTimeout + } +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/client_test.go b/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/client_test.go new file mode 100644 index 0000000..80bd5cb --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/rpcclient/client_test.go @@ -0,0 +1,110 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rpcclient + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/snerror" +) + +func TestSyncInvoke(t *testing.T) { + funcKey := "testFuncKey" + instanceID := "testInstanceID" + args := []string{"arg1", "arg2"} + invokeParams := InvokeParams{ + TransportParams: TransportParams{ + Timeout: 500 * time.Millisecond, + RetryInterval: 100 * time.Millisecond, + RetryNumber: 2, + }, + InvokeOptions: map[string]string{"option1": "value1"}, + RequestID: "testRequestID", + TraceID: "testTraceID", + } + + t.Run("Success without retries", func(t *testing.T) { + asyncInvoke := func(funcKey, instanceID string, args []string, invokeParams InvokeParams, + callback func(result []byte, err snerror.SNError)) snerror.SNError { + go func() { + time.Sleep(50 * time.Millisecond) // 模拟异步调用的延迟 + callback([]byte("success"), nil) + }() + return nil + } + + resultData, resultError := SyncInvoke(asyncInvoke, funcKey, instanceID, args, invokeParams) + + assert.Nil(t, resultError) + assert.Equal(t, []byte("success"), resultData) + }) + + t.Run("Error then success after retries", func(t *testing.T) { + var callCount int + asyncInvoke := func(funcKey, instanceID string, args []string, invokeParams InvokeParams, + callback func(result []byte, err snerror.SNError)) snerror.SNError { + go func() { + time.Sleep(50 * time.Millisecond) + if callCount < 1 { + callCount++ + callback(nil, snerror.New(1, "temporary error")) + } else { + callback([]byte("recovered success"), nil) + } + }() + return nil + } + + resultData, resultError := SyncInvoke(asyncInvoke, funcKey, instanceID, args, invokeParams) + + assert.Nil(t, resultError) + assert.Equal(t, []byte("recovered success"), resultData) + }) + + t.Run("Error with retries exhausted", func(t *testing.T) { + asyncInvoke := func(funcKey, instanceID string, args []string, invokeParams InvokeParams, + callback func(result []byte, err snerror.SNError)) snerror.SNError { + go func() { + time.Sleep(50 * time.Millisecond) + callback(nil, snerror.New(1, "persistent error")) + }() + return nil + } + + resultData, resultError := SyncInvoke(asyncInvoke, funcKey, instanceID, args, invokeParams) + + assert.NotNil(t, resultError) + assert.Equal(t, "persistent error", resultError.Error()) + assert.Nil(t, resultData) + }) + + t.Run("Timeout without response", func(t *testing.T) { + asyncInvoke := func(funcKey, instanceID string, args []string, invokeParams InvokeParams, + callback func(result []byte, err snerror.SNError)) snerror.SNError { + return nil + } + + resultData, resultError := SyncInvoke(asyncInvoke, funcKey, instanceID, args, invokeParams) + + assert.NotNil(t, resultError) + assert.Equal(t, ErrKernelClientTimeout, resultError) + assert.Nil(t, resultData) + }) +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/server.go b/yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/server.go new file mode 100644 index 0000000..44e3090 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/server.go @@ -0,0 +1,49 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package rpcserver - +package rpcserver + +import ( + "time" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/grpc/pb/common" +) + +const ( + defaultSendTimeout = 5 * time.Second +) + +// KernelInvokeHandler - +type KernelInvokeHandler func(args []*api.Arg, traceID string) (string, error) + +// KernelServer defines basic POSIX server methods, currently only RegisterInvokeHandler is needed +type KernelServer interface { + RegisterInvokeHandler(handler KernelInvokeHandler) + Serve() error + Stop() +} + +func pb2Arg(args []*common.Arg) []*api.Arg { + length := len(args) + newArgs := make([]*api.Arg, 0, length) + for _, value := range args { + newArgs = append(newArgs, &api.Arg{Type: api.ArgType(value.Type), Data: value.Value}) + } + return newArgs +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/simplified_stream_server.go b/yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/simplified_stream_server.go new file mode 100644 index 0000000..835724d --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/simplified_stream_server.go @@ -0,0 +1,217 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package rpcserver - +package rpcserver + +import ( + "net" + "reflect" + "sync" + + "github.com/panjf2000/ants/v2" + "google.golang.org/grpc" + "google.golang.org/grpc/peer" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/grpc/pb" + "yuanrong/pkg/common/faas_common/grpc/pb/common" + "yuanrong/pkg/common/faas_common/grpc/pb/core" + "yuanrong/pkg/common/faas_common/kernelrpc/connection" + "yuanrong/pkg/common/faas_common/logger/log" +) + +var ( + streamMessagePool = sync.Pool{} + invokeResponsePool = sync.Pool{} +) + +type requestPack struct { + msg *api.StreamingMessage + conn connection.Connection +} + +// SimplifiedStreamServer is a simplified stream server which can respond to invokeRequest +type SimplifiedStreamServer struct { + api.UnimplementedRuntimeRPCServer + grpcServer *grpc.Server + taskPool *ants.PoolWithFunc + streamConnMap map[string]connection.Connection + invokeHandler KernelInvokeHandler + listenAddr string + stopped bool + stopCh chan struct{} + sync.Mutex +} + +// CreateSimplifiedStreamServer creates SimplifiedStreamServer +func CreateSimplifiedStreamServer(listenAddr string, concurrentNum int) (KernelServer, error) { + server := &SimplifiedStreamServer{ + grpcServer: grpc.NewServer(), + listenAddr: listenAddr, + streamConnMap: make(map[string]connection.Connection, constant.DefaultMapSize), + stopCh: make(chan struct{}), + } + taskPool, err := ants.NewPoolWithFunc(concurrentNum, func(arg interface{}) { + reqPack, ok := arg.(requestPack) + if !ok { + return + } + server.handleRequest(reqPack.msg, reqPack.conn) + }) + if err != nil { + log.GetLogger().Errorf("failed to create task pool error %s", err.Error()) + return nil, err + } + server.taskPool = taskPool + api.RegisterRuntimeRPCServer(server.grpcServer, server) + return server, nil +} + +// MessageStream handles stream from grpc server +func (s *SimplifiedStreamServer) MessageStream(stream api.RuntimeRPC_MessageStreamServer) error { + peerObj, _ := peer.FromContext(stream.Context()) + peerAddr := peerObj.Addr.String() + streamConn := connection.CreateStreamConnection(stream, connection.StreamParams{PeerAddr: peerAddr}, nil, nil) + closeCh := streamConn.CheckClose() + s.Lock() + s.streamConnMap[peerAddr] = streamConn + s.Unlock() + log.GetLogger().Infof("create streamConn success,peer:%s", peerAddr) + defer func() { + s.Lock() + delete(s.streamConnMap, peerAddr) + s.Unlock() + }() + for { + select { + case <-s.stopCh: + log.GetLogger().Warnf("server stops, closing stream connection to %s", peerAddr) + streamConn.Close() + return nil + case <-closeCh: + log.GetLogger().Warnf("stream connection to %s is closed", peerAddr) + return nil + default: + msg, err := streamConn.Recv() + if err != nil { + log.GetLogger().Errorf("failed to receive stream message from %s error %s", peerAddr, err.Error()) + continue + } + if err = s.taskPool.Invoke(requestPack{msg: msg, conn: streamConn}); err != nil { + log.GetLogger().Errorf("failed to invoke task pool error %s", err.Error()) + } + } + } +} + +// Serve starts serving on listenAddr +func (s *SimplifiedStreamServer) Serve() error { + lis, err := net.Listen("tcp", s.listenAddr) + if err != nil { + log.GetLogger().Errorf("failed to listen to address %s error %s\n", s.listenAddr, err.Error()) + return err + } + if err = s.grpcServer.Serve(lis); err != nil { + log.GetLogger().Errorf("failed to serve on address %s error %s", s.listenAddr, err.Error()) + return err + } + log.GetLogger().Infof("stop serve on address %s", s.listenAddr) + return nil +} + +// Stop stops server +func (s *SimplifiedStreamServer) Stop() { + s.Lock() + if s.stopped { + s.Unlock() + return + } + s.stopped = true + s.Unlock() + s.grpcServer.GracefulStop() + s.Lock() + for _, stream := range s.streamConnMap { + stream.Close() + } + s.Unlock() +} + +func (s *SimplifiedStreamServer) handleRequest(msg *api.StreamingMessage, conn connection.Connection) { + switch msg.GetBody().(type) { + case *api.StreamingMessage_InvokeReq: + invokeReq := msg.GetInvokeReq() + message := acquireStreamMessageInvokeResponse() + message.MessageID = msg.GetMessageID() + InvokeRsp := message.GetInvokeRsp() + defer func() { + if _, err := conn.Send(message, connection.SendOption{Timeout: defaultSendTimeout}, nil); err != nil { + log.GetLogger().Errorf("failed to send invoke response error %s", err.Error()) + } + releaseStreamMessageInvokeResponse(message) + }() + if s.invokeHandler == nil { + log.GetLogger().Errorf("invoke handler is nil") + InvokeRsp.Code = common.ErrorCode_ERR_USER_FUNCTION_EXCEPTION + InvokeRsp.Message = "invoke handler is nil" + return + } + rsp, err := s.invokeHandler(pb2Arg(invokeReq.GetArgs()), invokeReq.TraceID) + if err != nil { + InvokeRsp.Code = common.ErrorCode_ERR_USER_FUNCTION_EXCEPTION + InvokeRsp.Message = err.Error() + } else { + InvokeRsp.Code = common.ErrorCode_ERR_NONE + InvokeRsp.Message = rsp + } + default: + log.GetLogger().Warnf("receive unknown type message %s", reflect.TypeOf(msg.GetBody()).String()) + } +} + +// RegisterInvokeHandler registers invokeHandler +func (s *SimplifiedStreamServer) RegisterInvokeHandler(handler KernelInvokeHandler) { + s.invokeHandler = handler +} + +func acquireStreamMessageInvokeResponse() *api.StreamingMessage { + var ( + streamMsg *api.StreamingMessage + invokeRsp *api.StreamingMessage_InvokeRsp + ok bool + ) + streamMsg, ok = streamMessagePool.Get().(*api.StreamingMessage) + if !ok { + streamMsg = &api.StreamingMessage{} + } + invokeRsp, ok = invokeResponsePool.Get().(*api.StreamingMessage_InvokeRsp) + if !ok { + invokeRsp = &api.StreamingMessage_InvokeRsp{InvokeRsp: &core.InvokeResponse{}} + } + streamMsg.Body = invokeRsp + return streamMsg +} + +func releaseStreamMessageInvokeResponse(streamMsg *api.StreamingMessage) { + invokeRsp, ok := streamMsg.GetBody().(*api.StreamingMessage_InvokeRsp) + if !ok { + return + } + invokeRsp.InvokeRsp.Reset() + invokeResponsePool.Put(invokeRsp) + streamMsg.Reset() + streamMessagePool.Put(streamMsg) +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/simplified_stream_server_test.go b/yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/simplified_stream_server_test.go new file mode 100644 index 0000000..0f84070 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/rpcserver/simplified_stream_server_test.go @@ -0,0 +1,116 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package rpcserver - +package rpcserver + +import ( + "errors" + "net" + "testing" + "time" + + gomonkey "github.com/agiledragon/gomonkey/v2" + ants "github.com/panjf2000/ants/v2" + "github.com/smartystreets/goconvey/convey" + "google.golang.org/grpc" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/grpc/pb/common" + "yuanrong/pkg/common/faas_common/kernelrpc/rpcclient" +) + +func TestCreateSimplifiedStreamServer(t *testing.T) { + convey.Convey("test CreateSimplifiedStreamServer", t, func() { + patch := gomonkey.ApplyFunc(ants.NewPoolWithFunc, func(size int, pf func(interface{}), options ...ants.Option) ( + *ants.PoolWithFunc, error) { + return nil, errors.New("some error") + }) + server, err := CreateSimplifiedStreamServer("0.0.0.0:0", 100) + convey.So(err, convey.ShouldNotBeNil) + convey.So(server, convey.ShouldBeNil) + patch.Reset() + server, err = CreateSimplifiedStreamServer("0.0.0.0:0", 100) + convey.So(err, convey.ShouldBeNil) + convey.So(server, convey.ShouldNotBeNil) + }) +} + +func TestSimplifiedStreamServerServeAndClose(t *testing.T) { + convey.Convey("test SimplifiedStreamServer serve", t, func() { + patch := gomonkey.ApplyFunc(net.Listen, func(network, address string) (net.Listener, error) { + return nil, errors.New("some error") + }) + server, _ := CreateSimplifiedStreamServer("0.0.0.0:0", 100) + err := server.Serve() + convey.So(err, convey.ShouldNotBeNil) + patch.Reset() + patch = gomonkey.ApplyFunc((*grpc.Server).Serve, func(_ *grpc.Server, lis net.Listener) error { + return errors.New("some error") + }) + server, _ = CreateSimplifiedStreamServer("0.0.0.0:0", 100) + err = server.Serve() + convey.So(err, convey.ShouldNotBeNil) + patch.Reset() + server, _ = CreateSimplifiedStreamServer("0.0.0.0:0", 100) + go func() { + time.Sleep(100 * time.Millisecond) + server.Stop() + server.Stop() + }() + err = server.Serve() + convey.So(err, convey.ShouldBeNil) + server.Stop() + }) +} + +func TestSimplifiedStreamServerHandleInvoke(t *testing.T) { + convey.Convey("test SimplifiedStreamServer handleInvoke", t, func() { + server, _ := CreateSimplifiedStreamServer("0.0.0.0:5678", 100) + go server.Serve() + client, _ := rpcclient.CreateBasicStreamClient("0.0.0.0:5678", rpcclient.StreamClientParams{}) + args := []*api.Arg{{ + Type: api.Value, + Data: []byte("123"), + }} + convey.Convey("invokeHandler is nil", func() { + _, err := client.Invoke("testFunc", "testIns", args, rpcclient.InvokeParams{}, nil) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Code(), convey.ShouldEqual, common.ErrorCode_ERR_USER_FUNCTION_EXCEPTION) + }) + convey.Convey("invokeHandler return error", func() { + invokeHandler := func(args []*api.Arg, traceID string) (string, error) { + return "", errors.New("some error") + } + server.RegisterInvokeHandler(invokeHandler) + msg, err := client.Invoke("testFunc", "testIns", args, rpcclient.InvokeParams{}, nil) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Code(), convey.ShouldEqual, common.ErrorCode_ERR_USER_FUNCTION_EXCEPTION) + convey.So(msg, convey.ShouldBeEmpty) + }) + convey.Convey("invokeHandler return ok", func() { + invokeHandler := func(args []*api.Arg, traceID string) (string, error) { + return "abc", nil + } + server.RegisterInvokeHandler(invokeHandler) + msg, err := client.Invoke("testFunc", "testIns", args, rpcclient.InvokeParams{}, nil) + convey.So(err, convey.ShouldBeNil) + convey.So(msg, convey.ShouldEqual, "abc") + }) + grpcServer, _ := server.(*SimplifiedStreamServer) + grpcServer.grpcServer.Stop() + }) +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/utils/utils.go b/yuanrong/pkg/common/faas_common/kernelrpc/utils/utils.go new file mode 100644 index 0000000..2f58fff --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/utils/utils.go @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "yuanrong/pkg/common/uuid" +) + +// GenTaskID for create a task id +func GenTaskID() string { + return "task-" + uuid.New().String() +} diff --git a/yuanrong/pkg/common/faas_common/kernelrpc/utils/utils_test.go b/yuanrong/pkg/common/faas_common/kernelrpc/utils/utils_test.go new file mode 100644 index 0000000..d41509f --- /dev/null +++ b/yuanrong/pkg/common/faas_common/kernelrpc/utils/utils_test.go @@ -0,0 +1,30 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenTaskID(t *testing.T) { + taskID := GenTaskID() + assert.Equal(t, true, strings.Contains(taskID, "task")) +} diff --git a/yuanrong/pkg/common/faas_common/loadbalance/hash.go b/yuanrong/pkg/common/faas_common/loadbalance/hash.go new file mode 100644 index 0000000..ffd8a14 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/loadbalance/hash.go @@ -0,0 +1,454 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides consistent hash alogrithm +package loadbalance + +import ( + "hash/crc32" + "sort" + "sync" + "time" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" +) + +const ( + // MaxInstanceSize is the max instance size be stored in hash ring + MaxInstanceSize = 100 + defaultMapSize = 100 +) + +type uint32Slice []uint32 + +// Len returns the size +func (u uint32Slice) Len() int { + return len(u) +} + +// Swap will swap two elements +func (u uint32Slice) Swap(i, j int) { + if i < 0 || i >= len(u) || j < 0 || j >= len(u) { + return + } + u[i], u[j] = u[j], u[i] +} + +// Less returns true if i less than j +func (u uint32Slice) Less(i, j int) bool { + if i < 0 || i >= len(u) || j < 0 || j >= len(u) { + return false + } + return u[i] < u[j] +} + +type anchorInfo struct { + instanceHash uint32 + instanceKey string +} + +// CHGeneric is the generic consistent hash +type CHGeneric struct { + anchorPoint map[string]*anchorInfo + instanceMap map[uint32]string + hashPool uint32Slice + insMutex sync.RWMutex + anchorMutex sync.Mutex +} + +// NewCHGeneric creates generic consistent hash +func NewCHGeneric() *CHGeneric { + return &CHGeneric{ + hashPool: make([]uint32, 0, MaxInstanceSize), + instanceMap: make(map[uint32]string, defaultMapSize), + anchorPoint: make(map[string]*anchorInfo, defaultMapSize), + } +} + +// Next returns the next scheduled node of a function +func (c *CHGeneric) Next(name string, move bool) interface{} { + c.anchorMutex.Lock() + anchor, exist := c.anchorPoint[name] + if !exist { + anchor = c.addAnchorPoint(name) + c.anchorMutex.Unlock() + return anchor.instanceKey + } + if move { + c.moveAnchorPoint(name, anchor.instanceHash) + } + c.insMutex.RLock() + _, exist = c.instanceMap[anchor.instanceHash] + c.insMutex.RUnlock() + // check if node still exists, no maxReqCount limitation + if !exist { + c.moveAnchorPoint(name, anchor.instanceHash) + } + c.anchorMutex.Unlock() + return anchor.instanceKey +} + +// Previous - returns the previous scheduled node of a function +func (c *CHGeneric) Previous(name string, move bool) interface{} { + previous := c.getPreviousHashKey(getHashKeyCRC32([]byte(name))) + if move { + previous = c.getPreviousHashKey(previous) + } + c.insMutex.RLock() + _, exist := c.instanceMap[previous] + c.insMutex.RUnlock() + if !exist { + previous = c.getPreviousHashKey(previous) + } + return c.instanceMap[previous] +} + +// Add will add a node into hash ring +func (c *CHGeneric) Add(node interface{}, weight int) { + c.insMutex.Lock() + defer c.insMutex.Unlock() + name, ok := node.(string) + if !ok { + log.GetLogger().Errorf("unable to convert %T to string", node) + return + } + hashKey := getHashKeyCRC32([]byte(name)) + _, exist := c.instanceMap[hashKey] + if exist { + return + } + c.instanceMap[hashKey] = name + c.hashPool = append(c.hashPool, hashKey) + sort.Sort(c.hashPool) + log.GetLogger().Infof("add node %s, hashKey %d to hash ring, hashPool is %v", name, hashKey, c.hashPool) +} + +// Remove will remove a node from hash ring +func (c *CHGeneric) Remove(node interface{}) { + name, assertOK := node.(string) + if !assertOK { + log.GetLogger().Errorf("unable to convert %T to string", node) + return + } + hashKey := getHashKeyCRC32([]byte(name)) + c.insMutex.Lock() + delete(c.instanceMap, hashKey) + for i, hash := range c.hashPool { + if hash == hashKey { + copy(c.hashPool[i:], c.hashPool[i+1:]) + c.hashPool[len(c.hashPool)-1] = 0 + c.hashPool = c.hashPool[:len(c.hashPool)-1] + break + } + } + log.GetLogger().Infof("delete node %s from hash ring", name) + c.insMutex.Unlock() + +} + +// RemoveAll will remove all nodes from hash ring +func (c *CHGeneric) RemoveAll() { + c.insMutex.Lock() + c.hashPool = make([]uint32, 0, MaxInstanceSize) + c.instanceMap = make(map[uint32]string, defaultMapSize) + c.insMutex.Unlock() + return +} + +// Reset will clean all anchor infos +func (c *CHGeneric) Reset() { + c.anchorMutex.Lock() + c.anchorPoint = make(map[string]*anchorInfo, defaultMapSize) + c.anchorMutex.Unlock() + log.GetLogger().Infof("reset hash ring anchorPoint") + return +} + +// DeleteBalancer - +func (c *CHGeneric) DeleteBalancer(name string) { + c.anchorMutex.Lock() + defer c.anchorMutex.Unlock() + delete(c.anchorPoint, name) +} + +func (c *CHGeneric) addAnchorPoint(name string) *anchorInfo { + // need to be called in a thread safe context + hashKey := getHashKeyCRC32([]byte(name)) + c.insMutex.RLock() + instanceHash := c.getNextHashKey(hashKey) + c.insMutex.RUnlock() + newAnchor := &anchorInfo{ + instanceHash: instanceHash, + instanceKey: c.instanceMap[instanceHash], + } + c.anchorPoint[name] = newAnchor + log.GetLogger().Debugf("name %s hashKey %d", name, hashKey) + return newAnchor +} + +func (c *CHGeneric) moveAnchorPoint(name string, curHash uint32) { + c.insMutex.Lock() + instanceHash := c.getNextHashKey(curHash) + c.anchorPoint[name].instanceHash = instanceHash + c.anchorPoint[name].instanceKey = c.instanceMap[instanceHash] + c.insMutex.Unlock() +} + +func (c *CHGeneric) getNextHashKey(hashKey uint32) uint32 { + // need to be called with insMutex locked + if len(c.hashPool) == 0 { + return 0 + } + nextHashKey := c.hashPool[0] + for _, v := range c.hashPool { + if v > hashKey { + nextHashKey = v + break + } + } + return nextHashKey +} + +func (c *CHGeneric) getPreviousHashKey(hashKey uint32) uint32 { + // need to be called with insMutex locked + if len(c.hashPool) == 0 { + return 0 + } + hashLen := len(c.hashPool) + previousHashKey := c.hashPool[hashLen-1] + for i := hashLen - 1; i >= 0; i-- { + if c.hashPool[i] < hashKey { + previousHashKey = c.hashPool[i] + break + } + } + return previousHashKey +} + +func getHashKeyCRC32(key []byte) uint32 { + return crc32.ChecksumIEEE(key) +} + +// NewConcurrentCHGeneric return ConcurrentCHGeneric with given concurrency +func NewConcurrentCHGeneric(concurrency int) *ConcurrentCHGeneric { + return &ConcurrentCHGeneric{ + CHGeneric: NewCHGeneric(), + concurrency: concurrency, + counter: make(map[string]*concurrentCounter, constant.DefaultMapSize), + } +} + +type concurrentCounter struct { + count int + last time.Time +} + +// ConcurrentCHGeneric is concurrency balanced +type ConcurrentCHGeneric struct { + *CHGeneric + counter map[string]*concurrentCounter + countMutex sync.Mutex + concurrency int +} + +// Next returns the next scheduled node +func (c *ConcurrentCHGeneric) Next(name string, move bool) interface{} { + c.countMutex.Lock() + defer c.countMutex.Unlock() + l, ok := c.counter[name] + if !ok { + c.counter[name] = &concurrentCounter{ + last: time.Now(), + } + return c.CHGeneric.Next(name, move) + } + l.count++ + if l.count >= c.concurrency { + now := time.Now() + l.count = 0 + if now.Sub(l.last) < 1*time.Second { + move = true + } + l.last = now + } + return c.CHGeneric.Next(name, move) +} + +// Previous - returns the previous scheduled node of a function +func (c *ConcurrentCHGeneric) Previous(name string, move bool) interface{} { + return c.CHGeneric.Previous(name, move) +} + +// Add a node to hash ring +func (c *ConcurrentCHGeneric) Add(node interface{}, weight int) { + c.CHGeneric.Add(node, weight) +} + +// Remove a node from hash ring +func (c *ConcurrentCHGeneric) Remove(node interface{}) { + c.countMutex.Lock() + defer c.countMutex.Unlock() + c.CHGeneric.Remove(node) +} + +// RemoveAll remove all nodes from hash ring +func (c *ConcurrentCHGeneric) RemoveAll() { + c.countMutex.Lock() + defer c.countMutex.Unlock() + c.counter = make(map[string]*concurrentCounter, constant.DefaultMapSize) + c.CHGeneric.RemoveAll() +} + +// Reset clean all anchor infos and counters +func (c *ConcurrentCHGeneric) Reset() { + c.countMutex.Lock() + defer c.countMutex.Unlock() + c.counter = make(map[string]*concurrentCounter, constant.DefaultMapSize) + c.CHGeneric.Reset() +} + +// DeleteBalancer - +func (c *ConcurrentCHGeneric) DeleteBalancer(name string) { + c.countMutex.Lock() + delete(c.counter, name) + c.countMutex.Unlock() +} + +// NewLimiterCHGeneric return limiterCHGeneric with given concurrency +func NewLimiterCHGeneric(limiterTime time.Duration) *LimiterCHGeneric { + return &LimiterCHGeneric{ + CHGeneric: NewCHGeneric(), + limiterTime: limiterTime, + limiter: make(map[string]*concurrentLimiter, constant.DefaultMapSize), + } +} + +type concurrentLimiter struct { + head *limiterNode +} + +type limiterNode struct { + instanceKey interface{} + lastTime time.Time + next *limiterNode +} + +// LimiterCHGeneric is limiter balanced +type LimiterCHGeneric struct { + *CHGeneric + limiter map[string]*concurrentLimiter + nodeCount int + limiterMutex sync.Mutex + limiterTime time.Duration +} + +// Next returns the next scheduled node +func (c *LimiterCHGeneric) Next(name string, move bool) interface{} { + c.limiterMutex.Lock() + defer c.limiterMutex.Unlock() + if _, ok := c.limiter[name]; !ok { + c.limiter[name] = &concurrentLimiter{ + head: &limiterNode{}, + } + } + + moveFlag := move +label: + for exitFlag := 0; exitFlag <= c.nodeCount; exitFlag++ { + instanceKey := c.CHGeneric.Next(name, moveFlag) + h := c.limiter[name].head + n := h.next + for ; n != nil; n = n.next { + if n.instanceKey == instanceKey && !n.lastTime.IsZero() && time.Now().Sub(n.lastTime) < c.limiterTime { + moveFlag = true + continue label + } + if n.instanceKey == instanceKey && (n.lastTime.IsZero() || time.Now().Sub(n.lastTime) >= c.limiterTime) { + break + } + } + if n == nil { + h.next = &limiterNode{ + instanceKey: instanceKey, + next: h.next, + } + } + return instanceKey + } + return nil +} + +// Previous - returns the previous scheduled node of a function +func (c *LimiterCHGeneric) Previous(name string, move bool) interface{} { + return nil +} + +// Add a node to hash ring +func (c *LimiterCHGeneric) Add(node interface{}, weight int) { + c.limiterMutex.Lock() + defer c.limiterMutex.Unlock() + c.nodeCount++ + c.CHGeneric.Add(node, weight) +} + +// Remove a node from hash ring +func (c *LimiterCHGeneric) Remove(node interface{}) { + c.limiterMutex.Lock() + defer c.limiterMutex.Unlock() + c.nodeCount-- + c.CHGeneric.Remove(node) +} + +// RemoveAll remove all nodes from hash ring +func (c *LimiterCHGeneric) RemoveAll() { + c.limiterMutex.Lock() + defer c.limiterMutex.Unlock() + c.nodeCount = 0 + c.CHGeneric.RemoveAll() +} + +// Reset clean all anchor infos and counters +func (c *LimiterCHGeneric) Reset() { + c.limiterMutex.Lock() + defer c.limiterMutex.Unlock() + c.limiter = make(map[string]*concurrentLimiter, constant.DefaultMapSize) + c.CHGeneric.Reset() +} + +// DeleteBalancer - +func (c *LimiterCHGeneric) DeleteBalancer(name string) { + c.limiterMutex.Lock() + defer c.limiterMutex.Unlock() + c.CHGeneric.DeleteBalancer(name) + delete(c.limiter, name) +} + +// SetStain give the specified function, specify the node to set the stain +func (c *LimiterCHGeneric) SetStain(function string, node interface{}) { + c.limiterMutex.Lock() + defer c.limiterMutex.Unlock() + if _, ok := c.limiter[function]; !ok { + return + } + n := c.limiter[function].head + for ; n != nil; n = n.next { + if n.instanceKey == node { + n.lastTime = time.Now() + return + } + } +} diff --git a/yuanrong/pkg/common/faas_common/loadbalance/hash_test.go b/yuanrong/pkg/common/faas_common/loadbalance/hash_test.go new file mode 100644 index 0000000..c9ed8f2 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/loadbalance/hash_test.go @@ -0,0 +1,83 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides consistent hash algorithm +package loadbalance + +import ( + "github.com/smartystreets/goconvey/convey" + "testing" + "time" +) + +func TestConcurrentCHGeneric_Next(t *testing.T) { + convey.Convey("concurrentCHGeneric next", t, func() { + generic := NewConcurrentCHGeneric(10) + + generic.Add("node1", 0) + generic.Add("node2", 0) + next1 := generic.Next("function1", false) + next2 := generic.Next("function1", false) + convey.So(next1, convey.ShouldResemble, next2) + + generic = NewConcurrentCHGeneric(1) + generic.Add("node1", 0) + generic.Add("node2", 0) + next3 := generic.Next("function1", false) + next4 := generic.Next("function1", false) + convey.So(next3, convey.ShouldNotResemble, next4) + }) +} + +func TestCHGeneric_Previous(t *testing.T) { + convey.Convey("CHGeneric previous", t, func() { + generic := NewCHGeneric() + generic.Add("node1", 0) + generic.Add("node2", 0) + generic.Add("node3", 0) + + previous := generic.Previous("node2", false) + convey.So(previous, convey.ShouldEqual, "node1") + + previous = generic.Previous("node2", true) + convey.So(previous, convey.ShouldEqual, "node3") + }) +} + +func TestLimiterCHGeneric_DeleteBalancer(t *testing.T) { + convey.Convey("LimiterCHGeneric_DeleteBalancer", t, func() { + generic := NewLimiterCHGeneric(1 * time.Second) + generic.Add("node1", 0) + generic.Add("node2", 0) + generic.Add("node3", 0) + + next1 := generic.Next("function1", false) + convey.So(next1, convey.ShouldEqual, "node2") + next2 := generic.Next("function2", false) + convey.So(next2, convey.ShouldEqual, "node3") + + _, ok := generic.limiter["function1"] + _, exist := generic.anchorPoint["function1"] + convey.So(ok, convey.ShouldBeTrue) + convey.So(exist, convey.ShouldBeTrue) + + generic.DeleteBalancer("function1") + _, ok = generic.limiter["function1"] + _, exist = generic.anchorPoint["function1"] + convey.So(ok, convey.ShouldBeFalse) + convey.So(exist, convey.ShouldBeFalse) + }) +} diff --git a/yuanrong/pkg/common/faas_common/loadbalance/hashcache.go b/yuanrong/pkg/common/faas_common/loadbalance/hashcache.go new file mode 100644 index 0000000..66a0f9d --- /dev/null +++ b/yuanrong/pkg/common/faas_common/loadbalance/hashcache.go @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import "sync" + +type hashCache struct { + hashes sync.Map +} + +func createHashCache() *hashCache { + return &hashCache{ + hashes: sync.Map{}, + } +} + +func (cache *hashCache) getHash(key string) uint32 { + hashIf, ok := cache.hashes.Load(key) + if ok { + hash, ok := hashIf.(uint32) + if ok { + return hash + } + return 0 + } + hash := getHashKeyCRC32([]byte(key)) + cache.hashes.Store(key, hash) + return hash +} diff --git a/yuanrong/pkg/common/faas_common/loadbalance/loadbalance.go b/yuanrong/pkg/common/faas_common/loadbalance/loadbalance.go new file mode 100644 index 0000000..7b3365e --- /dev/null +++ b/yuanrong/pkg/common/faas_common/loadbalance/loadbalance.go @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides load balancing algorithm +package loadbalance + +import "time" + +const ( + // RoundRobinNginx represents type of Round Robin Nginx + RoundRobinNginx LBType = iota + // RoundRobinLVS represents type of Round Robin LVS + RoundRobinLVS + // ConsistentHashGeneric represents type of Generic Consistent Hash + ConsistentHashGeneric + // ConcurrentConsistentHashGeneric represents type of concurrent Consistent + ConcurrentConsistentHashGeneric +) + +// Request - +type Request struct { + Name string + TraceID string + Timestamp time.Time +} + +// LBType is the type of load loadbalance algorithm +type LBType int + +const defaultCHGenericConcurrency = 100 + +// LoadBalance is the interface of loadbalance algorithm +type LoadBalance interface { + Next(name string, move bool) interface{} // move parameter controls whether the hash loop moves + Previous(name string, move bool) interface{} + Add(node interface{}, weight int) + Remove(node interface{}) + RemoveAll() + Reset() + DeleteBalancer(name string) +} + +// LBFactory is the factory of loadbalance algorithm +func LBFactory(t LBType) LoadBalance { + switch t { + case RoundRobinNginx: + return &WNGINX{} + case ConsistentHashGeneric: + return NewCHGeneric() + case ConcurrentConsistentHashGeneric: + return NewConcurrentCHGeneric(defaultCHGenericConcurrency) + default: + return NewCHGeneric() + } +} diff --git a/yuanrong/pkg/common/faas_common/loadbalance/loadbalance_test.go b/yuanrong/pkg/common/faas_common/loadbalance/loadbalance_test.go new file mode 100644 index 0000000..17e7d87 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/loadbalance/loadbalance_test.go @@ -0,0 +1,238 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides consistent hash algorithm +package loadbalance + +import ( + "strconv" + "sync" + "testing" + "time" + + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type LBTestSuite struct { + suite.Suite + LoadBalance + lbType LBType + m sync.RWMutex + emptyNode interface{} +} + +func (lbs *LBTestSuite) SetupSuite() { + switch lbs.lbType { + case RoundRobinNginx, RoundRobinLVS: + lbs.emptyNode = nil + case ConsistentHashGeneric: + lbs.emptyNode = "" + default: + lbs.emptyNode = "" + } +} + +func (lbs *LBTestSuite) SetupTest() { + lbs.m = sync.RWMutex{} + lbs.LoadBalance = LBFactory(lbs.lbType) +} + +func (lbs *LBTestSuite) TearDownTest() { + lbs.LoadBalance = nil +} + +func (lbs *LBTestSuite) AddToLB(workerInstance interface{}, weight int) { + switch lbs.lbType { + case RoundRobinNginx, RoundRobinLVS: + lbs.m.Lock() + lbs.Add(workerInstance, weight) + lbs.Reset() + lbs.m.Unlock() + case ConsistentHashGeneric: + lbs.Add(workerInstance, 0) + default: + } +} + +func (lbs *LBTestSuite) DelFromLB(workerInstance interface{}) { + switch lbs.lbType { + case RoundRobinNginx, RoundRobinLVS: + lbs.m.Lock() + lbs.Remove(workerInstance) + lbs.Reset() + defer lbs.m.Unlock() + case ConsistentHashGeneric: + lbs.Remove(workerInstance) + default: + } +} + +func (lbs *LBTestSuite) TestAdd() { + lbs.AddToLB("new-node-01", 0) + lbs.AddToLB("new-node-01", 1) // test duplicate + lbs.AddToLB("new-node-02", 2) + lbs.AddToLB("new-node-03", 5) + lbs.AddToLB("", 6) + lbs.AddToLB(nil, 4) + next := lbs.Next("fn-urn-01", true) + assert.NotEqual(lbs.T(), lbs.emptyNode, next) + lbs.Reset() + next = lbs.Next("fn-urn-01", true) + assert.NotEqual(lbs.T(), lbs.emptyNode, next) +} + +func (lbs *LBTestSuite) TestNext() { + var wg sync.WaitGroup + next := lbs.Next("fn-urn-01", false) + assert.Equal(lbs.T(), lbs.emptyNode, next) + + lbs.AddToLB("new-node-01", 5) + next = lbs.Next("fn-urn-01", true) + assert.Equal(lbs.T(), "new-node-01", next) + + for i := 2; i < 5; i++ { + wg.Add(1) + go func(i int, wg *sync.WaitGroup) { + lbs.AddToLB("new-node-0"+strconv.Itoa(i), 5) + wg.Done() + }(i, &wg) + } + wg.Wait() + next = lbs.Next("fn-urn-01", true) + assert.NotEqual(lbs.T(), lbs.emptyNode, next) +} + +func (lbs *LBTestSuite) TestRemove() { + var wg sync.WaitGroup + for i := 1; i < 5; i++ { + wg.Add(1) + go func(i int, wg *sync.WaitGroup) { + lbs.AddToLB("new-node-0"+strconv.Itoa(i), 5) + wg.Done() + }(i, &wg) + } + wg.Wait() + for i := 1; i < 4; i++ { + wg.Add(1) + go func(i int, wg *sync.WaitGroup) { + lbs.DelFromLB("new-node-0" + strconv.Itoa(i)) + wg.Done() + }(i, &wg) + } + wg.Wait() + next := lbs.Next("fn-urn-01", true) + assert.Equal(lbs.T(), "new-node-04", next) +} + +func (lbs *LBTestSuite) TestRemoveAll() { + var wg sync.WaitGroup + for i := 1; i < 5; i++ { + wg.Add(1) + go func(i int, wg *sync.WaitGroup) { + lbs.Add("new-node-0"+strconv.Itoa(i), 5) + wg.Done() + }(i, &wg) + } + wg.Wait() + lbs.RemoveAll() + next := lbs.Next("fn-urn-01", true) + assert.Equal(lbs.T(), lbs.emptyNode, next) +} + +func TestLBTestSuite(t *testing.T) { + suite.Run(t, &LBTestSuite{lbType: ConsistentHashGeneric}) +} + +func TestConcurrentCHGeneric_Add(t *testing.T) { + con := NewConcurrentCHGeneric(2) + con.Add("n1", 0) + con.Add("n2", 0) + + next := con.Next("n1", false) + assert.Equal(t, "n2", next) + + con.Remove("n2") + con.RemoveAll() + con.Reset() + + next = con.Next("n1", false) + assert.Equal(t, "", next) +} + +func TestLimiterCHGeneric(t *testing.T) { + limiter := NewLimiterCHGeneric(5 * time.Second) + limiter.Add("n1", 0) + limiter.Add("n2", 0) + limiter.Add("n3", 0) + + next := limiter.Next("func1", false) + assert.Equal(t, "n1", next) + + limiter.SetStain("func1", "n1") + + next = limiter.Next("func1", false) + assert.Equal(t, "n3", next) + + limiter.SetStain("func1", "n3") + + next = limiter.Next("func1", false) + assert.Equal(t, "n2", next) + + limiter.SetStain("func1", "n2") + + next = limiter.Next("func1", false) + assert.Equal(t, nil, next) + + time.Sleep(5 * time.Second) + + next = limiter.Next("func1", false) + assert.Equal(t, "n2", next) + + limiter.Remove("n2") + + next = limiter.Next("func1", false) + assert.Equal(t, "n1", next) + + limiter.RemoveAll() + + next = limiter.Next("func1", false) + assert.Equal(t, "", next) + + limiter.Reset() +} + +func TestLBFactory(t *testing.T) { + convey.Convey("LBFactory", t, func() { + convey.Convey("RoundRobinNginx", func() { + factory := LBFactory(LBType(0)) + convey.So(factory, convey.ShouldNotBeNil) + }) + convey.Convey("ConsistentHashGeneric", func() { + factory := LBFactory(LBType(2)) + convey.So(factory, convey.ShouldNotBeNil) + }) + convey.Convey("ConcurrentConsistentHashGeneric", func() { + factory := LBFactory(LBType(3)) + convey.So(factory, convey.ShouldNotBeNil) + }) + convey.Convey("default", func() { + factory := LBFactory(LBType(1)) + convey.So(factory, convey.ShouldNotBeNil) + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/loadbalance/nolockconsistenthash.go b/yuanrong/pkg/common/faas_common/loadbalance/nolockconsistenthash.go new file mode 100644 index 0000000..852df80 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/loadbalance/nolockconsistenthash.go @@ -0,0 +1,126 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "errors" + "sort" +) + +// Node - +type Node struct { + Obj interface{} + Key string + hash uint32 +} + +// NoLockLoadBalance - +type NoLockLoadBalance interface { + Add(node *Node) error + Next(key string) *Node + Delete(nodeKey string) *Node +} + +// CreateNoLockLB - +func CreateNoLockLB() NoLockLoadBalance { + return &ConsistentHash{ + nodes: make([]*Node, 0), + cache: createHashCache(), + } +} + +type nodeSlice []*Node + +// Len returns the size +func (s nodeSlice) Len() int { + return len(s) +} + +// Swap will swap two elements +func (s nodeSlice) Swap(i, j int) { + if i < 0 || i >= len(s) || j < 0 || j >= len(s) { + return + } + s[i], s[j] = s[j], s[i] +} + +// Less returns true if i less than j +func (s nodeSlice) Less(i, j int) bool { + if i < 0 || i >= len(s) || j < 0 || j >= len(s) { + return false + } + return s[i].hash < s[j].hash +} + +// ConsistentHash - +type ConsistentHash struct { + cache *hashCache + nodes nodeSlice +} + +// Add - +func (c *ConsistentHash) Add(newNode *Node) error { + newNode.hash = getHashKeyCRC32([]byte(newNode.Key)) + for _, node := range c.nodes { + if node.Key == newNode.Key { + return errors.New("node already exist") + } + if node.hash == newNode.hash { + return errors.New("node hash already exist") + } + } + + c.nodes = append(c.nodes, newNode) + sort.Sort(c.nodes) + return nil +} + +// Next - +func (c *ConsistentHash) Next(key string) *Node { + if len(c.nodes) == 0 { + return nil + } + + keyHash := c.cache.getHash(key) + index := c.search(keyHash) + return c.nodes[index] +} + +func (c *ConsistentHash) search(keyHash uint32) int { + f := func(x int) bool { + if x >= len(c.nodes) { + return false + } + return c.nodes[x].hash > keyHash + } + index := sort.Search(len(c.nodes), f) + if index >= len(c.nodes) { + return 0 + } + return index +} + +// Delete - +func (c *ConsistentHash) Delete(nodeKey string) *Node { + for i, node := range c.nodes { + if node.Key == nodeKey { + c.nodes = append(c.nodes[:i], c.nodes[i+1:]...) + return node + } + } + return nil +} diff --git a/yuanrong/pkg/common/faas_common/loadbalance/nolockconsistenthash_test.go b/yuanrong/pkg/common/faas_common/loadbalance/nolockconsistenthash_test.go new file mode 100644 index 0000000..c57de88 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/loadbalance/nolockconsistenthash_test.go @@ -0,0 +1,96 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +const ( + nodeKey = "faas-scheduler-6b758c8b74-5zdwv" + funcKeyWithRes = "7e186a/0@base@testresourcepython36768/latest/300-128" +) + +var ( + node1 = &Node{ + Key: nodeKey, + } + + node2 = &Node{ + Key: nodeKey + "1", + } + + node3 = &Node{ + Key: nodeKey + "2", + } +) + +type mockRealNode struct { + state bool +} + +func (node *mockRealNode) IsEnable() bool { + return node.state +} + +func TestStatefulConsistent(t *testing.T) { + convey.Convey("TestStatefulConsistentHashWithOneNode", t, func() { + lb := CreateNoLockLB() + outNode := lb.Next(funcKeyWithRes) + convey.So(outNode, convey.ShouldBeNil) + + lb.Add(node1) + lb.Add(node1) + + outNode = lb.Next(funcKeyWithRes) + convey.So(outNode, convey.ShouldNotBeNil) + convey.So(outNode.Key, convey.ShouldEqual, node1.Key) + + outNode = lb.Delete(nodeKey) + convey.So(outNode, convey.ShouldNotBeNil) + convey.So(outNode.Key, convey.ShouldEqual, node1.Key) + + outNode = lb.Delete(nodeKey) + convey.So(outNode, convey.ShouldBeNil) + + outNode = lb.Next(funcKeyWithRes) + convey.So(outNode, convey.ShouldBeNil) + + lb.Add(node2) + lb.Add(node3) + lb.Add(node1) + + outNode = lb.Next(funcKeyWithRes) + convey.So(outNode, convey.ShouldNotBeNil) + convey.So(outNode.Key, convey.ShouldEqual, node2.Key) + + }) +} + +func BenchmarkStatefulConsistentHashWithThreeNode(b *testing.B) { + lb := CreateNoLockLB() + + lb.Add(node1) + lb.Add(node2) + lb.Add(node3) + + for i := 0; i < b.N; i++ { + lb.Next(funcKeyWithRes + "3") + } +} diff --git a/yuanrong/pkg/common/faas_common/loadbalance/roundrobin.go b/yuanrong/pkg/common/faas_common/loadbalance/roundrobin.go new file mode 100644 index 0000000..448d566 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/loadbalance/roundrobin.go @@ -0,0 +1,129 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package loadbalance provides roundrobin algorithm +package loadbalance + +// WeightNginx weight nginx +type WeightNginx struct { + Node interface{} + Weight int + CurrentWeight int + EffectiveWeight int +} + +// WNGINX w nginx +type WNGINX struct { + nodes []*WeightNginx +} + +// Add add node +func (w *WNGINX) Add(node interface{}, weight int) { + weightNginx := &WeightNginx{ + Node: node, + Weight: weight, + EffectiveWeight: weight} + w.nodes = append(w.nodes, weightNginx) +} + +// Remove removes a node +func (w *WNGINX) Remove(node interface{}) { + for i, weighted := range w.nodes { + if weighted.Node == node { + w.nodes = append(w.nodes[:i], w.nodes[i+1:]...) + break + } + } +} + +// RemoveAll remove all nodes +func (w *WNGINX) RemoveAll() { + w.nodes = w.nodes[:0] +} + +// Next get next node +func (w *WNGINX) Next(_ string, _ bool) interface{} { + if len(w.nodes) == 0 { + return nil + } + if len(w.nodes) == 1 { + return w.nodes[0].Node + } + return nextWeightedNode(w.nodes).Node +} + +// Previous - returns the previous scheduled node of a function +func (w *WNGINX) Previous(name string, move bool) interface{} { + return nil +} + +// DeleteBalancer - +func (w *WNGINX) DeleteBalancer(name string) { +} + +// nextWeightedNode get best next node info +func nextWeightedNode(nodes []*WeightNginx) *WeightNginx { + total := 0 + if len(nodes) == 0 { + return nil + } + best := nodes[0] + for _, w := range nodes { + w.CurrentWeight += w.EffectiveWeight + total += w.EffectiveWeight + if w.CurrentWeight > best.CurrentWeight { + best = w + } + } + best.CurrentWeight -= total + return best +} + +// Reset reset all nodes +func (w *WNGINX) Reset() { + for _, s := range w.nodes { + s.EffectiveWeight = s.Weight + s.CurrentWeight = 0 + } +} + +// Done - +func (w *WNGINX) Done(node interface{}) {} + +// NextWithRequest - +func (w *WNGINX) NextWithRequest(req *Request, move bool) interface{} { + return w.Next(req.Name, move) +} + +// SetConcurrency - +func (w *WNGINX) SetConcurrency(concurrency int) {} + +// Start - +func (w *WNGINX) Start() {} + +// Stop - +func (w *WNGINX) Stop() {} + +// NoLock - +func (w *WNGINX) NoLock() bool { + return false +} + +// WeightLvs weight lv5 +type WeightLvs struct { + Node interface{} + Weight int +} diff --git a/yuanrong/pkg/common/faas_common/loadbalance/roundrobin_test.go b/yuanrong/pkg/common/faas_common/loadbalance/roundrobin_test.go new file mode 100644 index 0000000..6ec3abf --- /dev/null +++ b/yuanrong/pkg/common/faas_common/loadbalance/roundrobin_test.go @@ -0,0 +1,95 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestNext(t *testing.T) { + convey.Convey("node length is 0", t, func() { + node := []*WeightNginx{} + wnginx := WNGINX{node} + + res := wnginx.Next("", true) + convey.So(res, convey.ShouldBeNil) + }) + convey.Convey("node length is 1", t, func() { + node := []*WeightNginx{ + {"Node1", 30, 10, 20}, + } + wnginx := WNGINX{node} + + res := wnginx.Next("", true) + convey.So(res, convey.ShouldNotBeNil) + }) + convey.Convey("node length > 1", t, func() { + node := []*WeightNginx{ + {"Node1", 30, 10, 20}, + {"Node2", 30, 60, 20}, + } + wnginx := WNGINX{node} + + res := wnginx.Next("", true) + resStr, ok := res.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(resStr, convey.ShouldEqual, "Node2") + }) + + convey.Convey("remove", t, func() { + node := []*WeightNginx{ + {"Node1", 30, 10, 20}, + } + wnginx := WNGINX{node} + wnginx.Add("Node2", 60) + res := wnginx.Next("", true) + resStr, ok := res.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(resStr, convey.ShouldEqual, "Node2") + + wnginx.Remove("Node2") + res = wnginx.Next("", true) + resStr, ok = res.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(resStr, convey.ShouldEqual, "Node1") + }) + + convey.Convey("remove", t, func() { + node := []*WeightNginx{ + {"Node1", 30, 10, 20}, + {"Node2", 30, 60, 20}, + } + wnginx := WNGINX{node} + wnginx.RemoveAll() + convey.So(len(wnginx.nodes), convey.ShouldEqual, 0) + }) +} + +func TestReset(t *testing.T) { + convey.Convey("Reset success", t, func() { + weightNginx := &WeightNginx{"Node1", 30, 10, 20} + var node []*WeightNginx + node = append(node, weightNginx) + wnginx := WNGINX{node} + + wnginx.Reset() + convey.So(weightNginx.EffectiveWeight, convey.ShouldEqual, weightNginx.Weight) + }) + +} diff --git a/yuanrong/pkg/common/faas_common/localauth/authcache.go b/yuanrong/pkg/common/faas_common/localauth/authcache.go new file mode 100644 index 0000000..49defb1 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/localauth/authcache.go @@ -0,0 +1,195 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package localauth authenticates requests by local configmaps +package localauth + +import ( + "errors" + "sync" + "sync/atomic" + "time" + + "k8s.io/client-go/tools/cache" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/signals" +) + +const ( + senderCacheDuration = 1 * time.Minute +) + +// AuthCache cache interface +type AuthCache interface { + GetSignForSender() (string, string, error) + GetSignForReceiver(auth string) (string, bool, error) + updateReceiver(string) error + updateSender(string, string) +} + +type authCache struct { + // use an atomic value to promise concurrent safety, which stores the authorization token and time. + senderCache *atomic.Value + // sign-time + receiverCache cache.Store + appID string + AuthConfig +} + +type senderValue struct { + auth string + time string +} + +var localCache *authCache +var doOnce sync.Once + +// GetLocalAuthCache you have to create it before you get it. +func GetLocalAuthCache(aKey, sKey, appID string, duration int) AuthCache { + doOnce.Do(func() { + var c cache.Store + stopCh := signals.WaitForSignal() + // cache.Store ttl the minimum valid value is 1 second. If this parameter is set to 0, + // the cache does not need to be increased. Therefore, set the cache to nil. + if duration == 0 { + c = nil + } else { + c = cache.NewTTLStore(receiverCacheKey, time.Duration(duration)*time.Minute) + } + atom := &atomic.Value{} + localCache = &authCache{ + senderCache: atom, + receiverCache: c, + } + localCache.appID = appID + localCache.AKey = aKey + localCache.SKey = sKey + localCache.initSenderCache(stopCh) + localCache.Duration = duration + // clean expired keys by ticker could avoid worker-manager oom problem + // because receiver cache clean expired keys is lazy by calling GetByKeys method or List method + go localCache.startCleanExpiredKeysByTicker(stopCh, time.Duration(duration)*time.Minute) + }) + if localCache == nil { + return nil + } + return localCache +} + +func (c *authCache) startCleanExpiredKeysByTicker(stopCh <-chan struct{}, duration time.Duration) { + if stopCh == nil || c.receiverCache == nil { + return + } + log.GetLogger().Infof("start to clean expired keys by ticker duration %s", duration.String()) + ticker := time.NewTicker(duration) + defer ticker.Stop() + for { + select { + case <-ticker.C: + // call receiver cache list method will clean all expired keys + length := len(c.receiverCache.List()) + log.GetLogger().Debugf("receiver cache length is %d after clean expired keys once by ticker", length) + case <-stopCh: + log.GetLogger().Infof("stop channel is closed") + return + } + } +} + +// GetSignForSender return time auth error +func (c *authCache) GetSignForSender() (string, string, error) { + loaded := c.senderCache.Load() + value, ok := loaded.(senderValue) + if !ok { + return "", "", errors.New("no sender cache") + } + if value.time == "" || value.auth == "" { + return "", "", errors.New("no sender time") + } + return value.time, value.auth, nil +} + +// GetSignForReceiver value exit error +func (c *authCache) GetSignForReceiver(auth string) (string, bool, error) { + if c.receiverCache == nil { + return "", false, nil + } + key, b, err := c.receiverCache.GetByKey(auth) + if !b { + key = "" + } + return key.(string), b, err +} + +func (c *authCache) updateReceiver(sign string) error { + if c.receiverCache == nil { + return nil + } + err := c.receiverCache.Add(sign) + if err != nil { + return err + } + return nil +} + +func (c *authCache) updateSender(auth, time string) { + c.senderCache.Store(senderValue{ + auth: auth, + time: time, + }) +} + +func (c *authCache) waitForDoneSignal(stopCh <-chan struct{}) { + if stopCh == nil { + return + } + ticker := time.NewTicker(senderCacheDuration) + for { + select { + case <-ticker.C: + // update senderCache + c.createAndUpdateSender() + case <-stopCh: + ticker.Stop() + return + } + } +} + +func (c *authCache) initSenderCache(stopCh <-chan struct{}) { + c.createAndUpdateSender() + go c.waitForDoneSignal(stopCh) +} + +func (c *authCache) createAndUpdateSender() { + var data []byte + authorization, t := CreateAuthorization( + c.AKey, + c.SKey, + "", + c.appID, + data, + ) + c.updateSender(authorization, t) + if c.Duration != 0 { + log.GetLogger().Debugf("the length of receiver cache is: %d", len(c.receiverCache.ListKeys())) + } +} + +func receiverCacheKey(obj interface{}) (string, error) { + return obj.(string), nil +} diff --git a/yuanrong/pkg/common/faas_common/localauth/authcache_test.go b/yuanrong/pkg/common/faas_common/localauth/authcache_test.go new file mode 100644 index 0000000..154a8f8 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/localauth/authcache_test.go @@ -0,0 +1,220 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package localauth + +import ( + "reflect" + "sync/atomic" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + "k8s.io/client-go/tools/cache" +) + +func Test_receiverCacheKey(t *testing.T) { + type args struct { + obj interface{} + } + var a args + a.obj = "aaa" + tests := []struct { + name string + args args + want string + wantErr bool + }{ + {"case1", a, "aaa", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := receiverCacheKey(tt.args.obj) + if (err != nil) != tt.wantErr { + t.Errorf("receiverCacheKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("receiverCacheKey() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_authCache_GetSignForSender(t *testing.T) { + type fields struct { + senderCache *atomic.Value + receiverCache cache.Store + appID string + AuthConfig AuthConfig + } + var f fields + senderCache := &atomic.Value{} + f.senderCache = senderCache + + var f2 fields + senderCache2 := &atomic.Value{} + senderCache2.Store(senderValue{ + time: "aaa", + auth: "aaa", + }) + f2.senderCache = senderCache2 + tests := []struct { + name string + fields fields + wantS string + wantD string + wantErr bool + }{ + {"case1", f, "", "", true}, + {"case2", f2, "aaa", "aaa", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &authCache{ + senderCache: tt.fields.senderCache, + receiverCache: tt.fields.receiverCache, + appID: tt.fields.appID, + AuthConfig: tt.fields.AuthConfig, + } + gotS, gotD, err := c.GetSignForSender() + if (err != nil) != tt.wantErr { + t.Errorf("GetSignForSender() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotS != tt.wantS { + t.Errorf("GetSignForSender() gotS = %v, want %v", gotS, tt.wantS) + } + if gotD != tt.wantD { + t.Errorf("GetSignForSender() gotD = %v, want %v", gotD, tt.wantD) + } + }) + } +} + +func Test_authCache_updateReceiver(t *testing.T) { + type fields struct { + senderCache *atomic.Value + receiverCache cache.Store + appID string + AuthConfig AuthConfig + } + type args struct { + sign string + } + var f fields + var a args + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"case1", f, a, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &authCache{ + senderCache: tt.fields.senderCache, + receiverCache: tt.fields.receiverCache, + appID: tt.fields.appID, + AuthConfig: tt.fields.AuthConfig, + } + if err := c.updateReceiver(tt.args.sign); (err != nil) != tt.wantErr { + t.Errorf("updateReceiver() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_authCache_createAndUpdateSender(t *testing.T) { + receiverCache := cache.NewTTLStore(receiverCacheKey, time.Duration(5)*time.Second) + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(receiverCache), "ListKeys", + func(_ *cache.ExpirationCache) []string { + return []string{} + }), + gomonkey.ApplyFunc(DecryptKeys, func(inputAKey string, inputSKey string) ([]byte, []byte, error) { + return []byte("aaa"), []byte("aaa"), nil + }), + } + + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + type fields struct { + senderCache *atomic.Value + receiverCache cache.Store + appID string + AuthConfig AuthConfig + } + var f fields + f.AuthConfig.Duration = 1 + f.receiverCache = receiverCache + f.senderCache = &atomic.Value{} + f.senderCache.Store(senderValue{ + time: "aaa", + auth: "aaa", + }) + tests := []struct { + name string + fields fields + }{ + {"case1", f}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &authCache{ + senderCache: tt.fields.senderCache, + receiverCache: tt.fields.receiverCache, + appID: tt.fields.appID, + AuthConfig: tt.fields.AuthConfig, + } + c.createAndUpdateSender() + }) + } +} + +func TestWaitForDoneSignal(t *testing.T) { + c := &authCache{ + senderCache: &atomic.Value{}, + } + c.senderCache.Store(senderValue{ + time: "aaa", + auth: "bbb", + }) + c.waitForDoneSignal(nil) + stopChan := make(chan struct{}) + go c.waitForDoneSignal(stopChan) + close(stopChan) + assert.NotEqual(t, c, nil) +} + +func TestGetSignForReceiver(t *testing.T) { + c := &authCache{ + senderCache: &atomic.Value{}, + receiverCache: cache.NewTTLStore(receiverCacheKey, time.Duration(1)*time.Minute), + } + c.senderCache.Store(senderValue{ + time: "aaa", + auth: "aaa", + }) + c.updateReceiver("sign") + _, _, err := c.GetSignForReceiver("auth") + assert.Equal(t, err, nil) +} diff --git a/yuanrong/pkg/common/faas_common/localauth/authcheck.go b/yuanrong/pkg/common/faas_common/localauth/authcheck.go new file mode 100644 index 0000000..ebfbc93 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/localauth/authcheck.go @@ -0,0 +1,407 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package localauth authenticates requests by local configmaps +package localauth + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "fmt" + "math" + "net/http" + "net/url" + "os" + "sort" + "strconv" + "strings" + "time" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" +) + +const ( + modeSDK = "SDKMode" + modeHWS = "HWSMode" + // the difference limit of a timestamp + defaultTimestampDiffLimit = 5 + // 7 days + maxTimestampDiffLimit = 10080 + maxHeaderLength = 20 + minLengthOfAuthValue = 2 + base = 10 + bitSize = 64 +) + +var timestampDiffLimit = getTimestampDiffLimit() + +type modeOptions struct { + authHeaderPrefix string + timeFormat string + shortTimeFormat string + terminalString string + name string + date string +} + +var modeOption = &modeOptions{ + authHeaderPrefix: "", + timeFormat: "", + shortTimeFormat: "", + terminalString: "", + name: "", + date: "", +} + +// Signer is a struct of +type Signer struct { + signTime time.Time + serviceName string + region string +} + +// AuthConfig represents configurations of local auth +type AuthConfig struct { + AKey string `json:"aKey" yaml:"aKey" valid:"optional"` + SKey string `json:"sKey" yaml:"sKey" valid:"optional"` + Duration int `json:"duration" yaml:"duration" valid:"optional"` +} + +// Authentication represents aKey and sKey Decrypted from ak and sk +type Authentication struct { + AKey []byte + SKey []byte +} + +// signLocalAuthRequest returns the authentication header +func signLocalAuthRequest(rawURL, timeStamp, appID string, key *Authentication, data []byte) (string, []byte) { + signer := getSigner("SDKMode", "", "") + timeStampInt, err := strconv.ParseInt(timeStamp, base, bitSize) + if err != nil { + log.GetLogger().Errorf("failed to parse the timestamp string") + return "", data + } + signer.signTime = time.Unix(timeStampInt, 0) + // default text of data + if len(data) == 0 { + data = []byte(`signature verification`) + } + header := make(map[string][]string, maxHeaderLength) + header["Content-Type"] = []string{"application/json"} + + parsedURL, err := url.Parse(rawURL) + if err != nil { + log.GetLogger().Errorf("failed to parse a URL") + return "", data + } + request := &http.Request{Method: "POST", URL: parsedURL, Header: header} + signerHeader := signer.sign(request, key.AKey, key.SKey, data, appID) + return signerHeader["X-Identity-Sign"], data +} + +func getSigner(mode, serviceName, region string) *Signer { + if mode == modeSDK { + setSDKMode() + } else { + setHWSMode() + } + return &Signer{ + signTime: time.Now(), + serviceName: serviceName, + region: region, + } +} + +func setSDKMode() { + modeOption = &modeOptions{ + authHeaderPrefix: "SDK-HMAC-SHA256", + timeFormat: "20060102T150405Z", + shortTimeFormat: "20060102", + terminalString: "sdk_request", + name: "SDK", + date: "X-Sdk-Date", + } +} + +func setHWSMode() { + modeOption = &modeOptions{ + authHeaderPrefix: "HWS-HMAC-SHA256", + timeFormat: "20060102T150405Z", + shortTimeFormat: "20060102", + terminalString: "hws_request", + name: "HWS", + date: "X-Hws-Date", + } +} + +func (sig *Signer) sign(request *http.Request, aKey, sKey []byte, body []byte, + appID string) map[string]string { + header := map[string]string{} + request.Header.Add(modeOption.date, sig.signTime.UTC().Format(modeOption.timeFormat)) + contentSha256 := makeSha256Hex(body) + canonicalString := sig.buildCanonicalRequest(request, contentSha256) + stringToSign := sig.buildStringToSign(canonicalString) + signatureStr := sig.buildSignature(sKey, stringToSign) + credentialString := sig.buildCredentialString() + signedHeaders := sig.buildSignedHeadersString(request) + aKeyString := string(aKey) + utils.ClearByteMemory(aKey) + parts := []string{ + modeOption.authHeaderPrefix + " Credential=" + aKeyString + "/" + credentialString, + "SignedHeaders=" + signedHeaders, + "Signature=" + signatureStr, + } + if appID != "" { + parts = append(parts, "appid="+appID) + } + utils.ClearStringMemory(aKeyString) + + signResult := strings.Join(parts, ", ") + header["host"] = request.Host + header[modeOption.date] = sig.signTime.UTC().Format(modeOption.timeFormat) + header["Content-Type"] = "application/json;charset=UTF-8" + header["Accept"] = "application/json" + header["X-Identity-Sign"] = signResult + return header +} + +// buildSignature generate a signature with request and secret key +func (sig *Signer) buildSignature(sKey []byte, stringtoSign string) string { + var secretBuf bytes.Buffer + secretBuf.Write([]byte(modeOption.name)) + secretBuf.Write(sKey) + utils.ClearByteMemory(sKey) + sigTime := []byte(sig.signTime.UTC().Format(modeOption.shortTimeFormat)) + date := makeHmac(secretBuf.Bytes(), sigTime) + secretBuf.Reset() + region := makeHmac(date, []byte(sig.region)) + service := makeHmac(region, []byte(sig.serviceName)) + credentials := makeHmac(service, []byte(modeOption.terminalString)) + toSignature := makeHmac(credentials, []byte(stringtoSign)) + signature := hex.EncodeToString(toSignature) + return signature +} + +// buildStringToSign prepare data for building signature +func (sig *Signer) buildStringToSign(canonicalString string) string { + stringToSign := strings.Join([]string{ + modeOption.authHeaderPrefix, + sig.signTime.UTC().Format(modeOption.timeFormat), + sig.buildCredentialString(), + hex.EncodeToString(makeSha256([]byte(canonicalString))), + }, "\n") + return stringToSign +} + +// buildCanonicalRequest converts the request info into canonical format +func (sig *Signer) buildCanonicalRequest(request *http.Request, hexbody string) string { + canonicalHeadersOut := sig.buildCanonicalHeaders(request) + signedHeaders := sig.buildSignedHeadersString(request) + canonicalRequestStr := strings.Join([]string{ + request.Method, + request.URL.Path + "/", + request.URL.RawQuery, + canonicalHeadersOut, + signedHeaders, + hexbody, + }, "\n") + return canonicalRequestStr +} + +// buildCanonicalHeaders generate canonical headers +func (sig *Signer) buildCanonicalHeaders(request *http.Request) string { + var headers []string + + for header := range request.Header { + standardized := strings.ToLower(strings.TrimSpace(header)) + headers = append(headers, standardized) + } + sort.Strings(headers) + + for i, header := range headers { + headers[i] = header + ":" + strings.Replace(request.Header.Get(header), "\n", " ", -1) + } + + if len(headers) > 0 { + return strings.Join(headers, "\n") + "\n" + } + + return "" +} + +// buildSignedHeadersString convert the header in request to a certain format +func (sig *Signer) buildSignedHeadersString(request *http.Request) string { + var headers []string + for header := range request.Header { + headers = append(headers, strings.ToLower(header)) + } + sort.Strings(headers) + return strings.Join(headers, ";") +} + +// buildCredentialString add date and several other information to signature header +func (sig *Signer) buildCredentialString() string { + credentialString := strings.Join([]string{ + sig.signTime.UTC().Format(modeOption.shortTimeFormat), + sig.region, + sig.serviceName, + modeOption.terminalString, + }, "/") + return credentialString +} + +// makeHmac convert data into sha256 format with certain key +func makeHmac(key []byte, data []byte) []byte { + hash := hmac.New(sha256.New, key) + _, err := hash.Write(data) + if err != nil { + log.GetLogger().Errorf("failed to write in makeHmac, error: %s", err.Error()) + } + return hash.Sum(nil) + +} + +// makeHmac convert data into sha256 format +func makeSha256(data []byte) []byte { + hash := sha256.New() + _, err := hash.Write(data) + if err != nil { + log.GetLogger().Errorf("failed to write in makeSha256, error: %s", err.Error()) + } + return hash.Sum(nil) +} + +// makeHmac convert data into Hex format +func makeSha256Hex(data []byte) string { + hash := sha256.New() + _, err := hash.Write(data) + if err != nil { + log.GetLogger().Errorf("failed to write in makeSha256Hex, error: %s", err.Error()) + } + md := hash.Sum(nil) + hexBody := hex.EncodeToString(md) + return hexBody +} + +func getTimestampDiffLimit() float64 { + var tsDiffLimit float64 + envTimestampDiffLimit, err := strconv.Atoi(os.Getenv("AUTH_VALID_TIME_MINUTE")) + if err == nil && envTimestampDiffLimit > 0 && envTimestampDiffLimit <= maxTimestampDiffLimit { + tsDiffLimit = float64(envTimestampDiffLimit) + } else { + tsDiffLimit = float64(defaultTimestampDiffLimit) + } + log.GetLogger().Infof("current timestampDiffLimit is %f", tsDiffLimit) + return tsDiffLimit +} + +// AuthCheckLocally authenticates requests by local auth +func AuthCheckLocally(ak string, sk string, requestSign string, timestamp string, duration int) error { + if len(requestSign) == 0 { + return fmt.Errorf("authentication string is nil") + } + curTime := time.Now().Unix() + timeUnix, err := strconv.ParseInt(timestamp, base, bitSize) + if err != nil { + return fmt.Errorf("invalid timestamp") + } + // the default timestamp limit is 5 minutes + if math.Abs(float64(curTime-timeUnix)) >= timestampDiffLimit*time.Minute.Seconds() { + return fmt.Errorf("the request is timeout") + } + appID, err := getAppIDFromRequestSign(requestSign) + if err != nil { + return err + } + _, exist, err := GetLocalAuthCache(ak, sk, appID, duration).GetSignForReceiver(requestSign) + if err != nil { + log.GetLogger().Errorf("failed to get sign from receiver cache") + return err + } + if exist { + return nil + } + aKey, sKey, err := DecryptKeys(ak, sk) + if err != nil { + utils.ClearByteMemory(aKey) + utils.ClearByteMemory(sKey) + return err + } + key := &Authentication{ + AKey: aKey, + SKey: sKey, + } + var data []byte + signature, _ := signLocalAuthRequest("", timestamp, appID, key, data) + utils.ClearByteMemory(aKey) + utils.ClearByteMemory(sKey) + if signature == "" || signature != requestSign { + return fmt.Errorf("auth check failed") + } + if err := GetLocalAuthCache(ak, sk, appID, duration).updateReceiver(signature); err != nil { + log.GetLogger().Errorf("failed to update receiver cache") + return err + } + return nil +} + +func getAppIDFromRequestSign(sign string) (string, error) { + arrays := strings.Split(sign, "appid=") + if len(arrays) < minLengthOfAuthValue { + return "", fmt.Errorf("failed to parse authorization appid= %s", "*****") + } + arrays = strings.Split(arrays[1], ", ") + return arrays[0], nil +} + +// SignLocally makes signatures by local auth +func SignLocally(ak, sk, appID string, duration int) (string, string) { + t, auth, err := GetLocalAuthCache(ak, sk, appID, duration).GetSignForSender() + if err != nil { + var data []byte + log.GetLogger().Warnf("failed to get sender cache: %s", err.Error()) + return CreateAuthorization(ak, sk, "", appID, data) + } + return auth, t +} + +// SignOMSVC make signatures for request send to OMSVC +func SignOMSVC(ak, sk, url string, data []byte) (string, string) { + return CreateAuthorization(ak, sk, url, "", data) +} + +// CreateAuthorization create Authentication Information +func CreateAuthorization(ak, sk, url, appID string, data []byte) (string, string) { + timestamp := strconv.FormatInt(time.Now().Unix(), base) + aKey, sKey, err := DecryptKeys(ak, sk) + if err != nil { + utils.ClearByteMemory(aKey) + utils.ClearByteMemory(sKey) + log.GetLogger().Errorf("failed to decrypt SKey when create auth, error: %s", err.Error()) + return "", "" + } + key := &Authentication{ + AKey: aKey, + SKey: sKey, + } + authorization, _ := signLocalAuthRequest(url, timestamp, appID, key, data) + utils.ClearByteMemory(aKey) + utils.ClearByteMemory(sKey) + return authorization, timestamp +} diff --git a/yuanrong/pkg/common/faas_common/localauth/authcheck_test.go b/yuanrong/pkg/common/faas_common/localauth/authcheck_test.go new file mode 100644 index 0000000..edf4d0e --- /dev/null +++ b/yuanrong/pkg/common/faas_common/localauth/authcheck_test.go @@ -0,0 +1,292 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package localauth + +import ( + "errors" + + "net/http" + "net/url" + "os" + "reflect" + "strconv" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" +) + +func TestAuthCheckLocally(t *testing.T) { + type args struct { + ak string + sk string + requestSign string + timestamp string + duration int + } + var a args + var b args + b.requestSign = "aaa" + var c args + c.requestSign = "aaa" + c.timestamp = strconv.FormatInt(time.Now().AddDate(1, 0, 0).Unix(), 10) + var d args + d.requestSign = "aaa" + d.timestamp = strconv.FormatInt(time.Now().Unix(), 10) + var e args + e.requestSign = "aaa,appid=aaa" + e.timestamp = strconv.FormatInt(time.Now().Unix(), 10) + tests := []struct { + name string + args args + wantErr bool + }{ + {"case1", a, true}, + {"case2", b, true}, + {"case3", c, true}, + {"case4", d, true}, + {"case5", e, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := AuthCheckLocally(tt.args.ak, tt.args.sk, tt.args.requestSign, tt.args.timestamp, tt.args.duration); (err != nil) != tt.wantErr { + t.Errorf("AuthCheckLocally() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_GetTimestampDiffLimit(t *testing.T) { + tsDiffLimit := getTimestampDiffLimit() + assert.Equal(t, 5, int(tsDiffLimit)) + patches := gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return "100" + }) + tsDiffLimit = getTimestampDiffLimit() + assert.Equal(t, 100, int(tsDiffLimit)) + defer patches.Reset() +} + +func TestCreateAuthorization(t *testing.T) { + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(DecryptKeys, + func(_ string, _ string) ([]byte, []byte, error) { + return []byte{}, []byte{}, errors.New("aaa") + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + type args struct { + ak string + sk string + url string + appID string + data []byte + } + var a args + tests := []struct { + name string + args args + want string + want1 string + }{ + {"case1", a, "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := CreateAuthorization(tt.args.ak, tt.args.sk, tt.args.url, tt.args.appID, tt.args.data) + if got != tt.want { + t.Errorf("CreateAuthorization() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("CreateAuthorization() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestSignOMSVC(t *testing.T) { + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(DecryptKeys, + func(_ string, _ string) ([]byte, []byte, error) { + return []byte{}, []byte{}, errors.New("aaa") + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + type args struct { + ak string + sk string + url string + data []byte + } + var a args + tests := []struct { + name string + args args + want string + want1 string + }{ + {"case1", a, "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := SignOMSVC(tt.args.ak, tt.args.sk, tt.args.url, tt.args.data) + if got != tt.want { + t.Errorf("SignOMSVC() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("SignOMSVC() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestSigner_buildCanonicalHeaders(t *testing.T) { + type fields struct { + signTime time.Time + serviceName string + region string + } + type args struct { + request *http.Request + } + var f fields + var a args + request := &http.Request{ + Method: "", + URL: nil, + Proto: "", + ProtoMajor: 0, + ProtoMinor: 0, + Header: nil, + Body: nil, + GetBody: nil, + ContentLength: 0, + TransferEncoding: nil, + Close: false, + Host: "", + Form: nil, + PostForm: nil, + MultipartForm: nil, + Trailer: nil, + RemoteAddr: "", + RequestURI: "", + TLS: nil, + Response: nil, + } + a.request = request + tests := []struct { + name string + fields fields + args args + want string + }{ + {"case1", f, a, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sig := &Signer{ + signTime: tt.fields.signTime, + serviceName: tt.fields.serviceName, + region: tt.fields.region, + } + if got := sig.buildCanonicalHeaders(tt.args.request); got != tt.want { + t.Errorf("buildCanonicalHeaders() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getSigner(t *testing.T) { + type args struct { + mode string + serviceName string + region string + } + var a args + a.mode = "aaa" + signer := &Signer{} + tests := []struct { + name string + args args + want *Signer + }{ + {"case1", a, signer}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getSigner(tt.args.mode, tt.args.serviceName, tt.args.region); !reflect.DeepEqual(got.serviceName, tt.want.serviceName) { + t.Errorf("getSigner() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_signLocalAuthRequest(t *testing.T) { + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(url.Parse, + func(_ string) (*url.URL, error) { + return nil, errors.New("aaa") + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + type args struct { + rawURL string + timeStamp string + appID string + key *Authentication + data []byte + } + var a args + var b args + b.timeStamp = strconv.FormatInt(time.Now().Unix(), 10) + tests := []struct { + name string + args args + want string + }{ + {"case1", a, ""}, + {"case2", b, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got, _ := signLocalAuthRequest(tt.args.rawURL, tt.args.timeStamp, tt.args.appID, tt.args.key, tt.args.data); got != tt.want { + t.Errorf("signLocalAuthRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSignLocally(t *testing.T) { + convey.Convey("TestSignLocally", t, func() { + auth, time := SignLocally("ak", "sk", "appID", 0) + convey.So(auth, convey.ShouldNotBeEmpty) + convey.So(time, convey.ShouldNotBeEmpty) + }) +} diff --git a/yuanrong/pkg/common/faas_common/localauth/crypto.go b/yuanrong/pkg/common/faas_common/localauth/crypto.go new file mode 100644 index 0000000..483b80d --- /dev/null +++ b/yuanrong/pkg/common/faas_common/localauth/crypto.go @@ -0,0 +1,98 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package localauth authenticates requests by local configmaps +package localauth + +import ( + "errors" + "sync" + + // Register aeswithkey engine + _ "huaweicloud.com/containers/security/cbb_adapt/src/go/aeswithkey" + "huaweicloud.com/containers/security/cbb_adapt/src/go/gcrypto" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" +) + +var ( + gcryptoEngine gcrypto.Engine + algorithm = "aeswithkey" + once sync.Once +) + +func initCrypto() error { + engine, err := gcrypto.New(algorithm) + if err != nil { + log.GetLogger().Errorf("failed to initialize the crypto engine, error is %s", err.Error()) + return err + } + gcryptoEngine = engine + return nil +} + +// Decrypt decrypts a cypher text using a certain algorithm +func Decrypt(src string) ([]byte, error) { + var err error + once.Do(func() { + err = initCrypto() + }) + if gcryptoEngine == nil { + return nil, err + } + + plaintext, err := gcryptoEngine.Decrypt(0, src) + if err != nil { + // error message may contain some sensitive content which should not be printed + return nil, errors.New("failed to decrypt the ciphertext") + } + text := []byte(plaintext) + utils.ClearStringMemory(plaintext) + return text, nil +} + +// Encrypt encrypts a cypher text using a certain algorithm +func Encrypt(src string) (string, error) { + var err error + once.Do(func() { + err = initCrypto() + }) + if gcryptoEngine == nil { + return "", errors.New("gcrypto engine is null") + } + + ciperText, err := gcryptoEngine.Encrypt(0, src) + if err != nil { + return "", errors.New("failed to encrypt the data") + } + return ciperText, nil +} + +// DecryptKeys decrypts a set of aKey and sKey +func DecryptKeys(inputAKey string, inputSKey string) ([]byte, []byte, error) { + aKey, err := Decrypt(inputAKey) + if err != nil { + log.GetLogger().Errorf("failed to decrypt AKey, error: %s", err.Error()) + return nil, nil, err + } + sKey, err := Decrypt(inputSKey) + if err != nil { + log.GetLogger().Errorf("failed to decrypt SKey, error: %s", err.Error()) + return nil, nil, err + } + return aKey, sKey, nil +} diff --git a/yuanrong/pkg/common/faas_common/localauth/crypto_test.go b/yuanrong/pkg/common/faas_common/localauth/crypto_test.go new file mode 100644 index 0000000..ba393c1 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/localauth/crypto_test.go @@ -0,0 +1,52 @@ +package localauth + +import ( + "sync" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "huaweicloud.com/containers/security/cbb_adapt/src/go/gcrypto" + + "yuanrong/pkg/common/faas_common/utils" +) + +type mockEngine struct { + gcrypto.Engine +} + +func (m mockEngine) Encrypt(domainId int, encData string) (string, error) { + return encData, nil +} + +func (m mockEngine) Decrypt(domainId int, encData string) (string, error) { + return encData, nil +} + +func TestEncrypt(t *testing.T) { + convey.Convey("Encrypt", t, func() { + defer gomonkey.ApplyFunc(gcrypto.New, func(algo string) (engine gcrypto.Engine, err error) { + return &mockEngine{}, nil + }).Reset() + once = sync.Once{} + encrypt, err := Encrypt("123") + convey.So(encrypt, convey.ShouldEqual, "123") + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestDecrypt(t *testing.T) { + convey.Convey("Encrypt", t, func() { + defer gomonkey.ApplyFunc(gcrypto.New, func(algo string) (engine gcrypto.Engine, err error) { + return &mockEngine{}, nil + }).Reset() + defer gomonkey.ApplyFunc(utils.ClearByteMemory, func(b []byte) { + return + }).Reset() + once = sync.Once{} + src := "123" + decrypt, err := Decrypt(src) + convey.So(string(decrypt), convey.ShouldEqual, "123") + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/pkg/common/faas_common/localauth/env.go b/yuanrong/pkg/common/faas_common/localauth/env.go new file mode 100644 index 0000000..c91ea7b --- /dev/null +++ b/yuanrong/pkg/common/faas_common/localauth/env.go @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package localauth authenticates requests by local configmaps +package localauth + +import ( + "encoding/json" + "os" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// GetDecryptFromEnv - +func GetDecryptFromEnv() (map[string]string, error) { + res := make(map[string]string) + value := os.Getenv("ENV_DELEGATE_DECRYPT") + err := json.Unmarshal([]byte(value), &res) + if err != nil { + log.GetLogger().Warnf("ENV_DELEGATE_DECRYPT unmarshal error, it is null") + } + return res, nil +} diff --git a/yuanrong/pkg/common/faas_common/localauth/env_test.go b/yuanrong/pkg/common/faas_common/localauth/env_test.go new file mode 100644 index 0000000..bc5a8b5 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/localauth/env_test.go @@ -0,0 +1,73 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package localauth authenticates requests by local configmaps +package localauth + +import ( + "encoding/json" + "errors" + "os" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + + mockUtils "yuanrong/pkg/common/faas_common/utils" +) + +func TestGetDecryptFromEnv(t *testing.T) { + tests := []struct { + name string + want map[string]string + wantErr bool + patchesFunc mockUtils.PatchesFunc + }{ + {"case1 failed to unmarshal", make(map[string]string), false, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(json.Unmarshal, func(data []byte, v interface{}) error { + return errors.New("failed to unmarshal json") + }), + }) + return patches + }}, + {"case2 succeed to unmarshal", map[string]string{"test": "test"}, + false, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return `{"test":"test"}` + }), + }) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + got, err := GetDecryptFromEnv() + if (err != nil) != tt.wantErr { + t.Errorf("GetDecryptFromEnv() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetDecryptFromEnv() got = %v, want %v", got, tt.want) + } + patches.ResetAll() + }) + } +} diff --git a/yuanrong/pkg/common/faas_common/logger/async/writer.go b/yuanrong/pkg/common/faas_common/logger/async/writer.go new file mode 100644 index 0000000..bc2e984 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/async/writer.go @@ -0,0 +1,176 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package async makes io.Writer write async +package async + +import ( + "bytes" + "fmt" + "io" + "sync/atomic" + "time" + + "go.uber.org/zap/buffer" +) + +const ( + diskBufferSize = 1024 * 1024 + diskFlushSize = diskBufferSize >> 1 + diskFlushTime = 500 * time.Millisecond + defaultChannelSize = 200000 + softLimitFactor = 0.8 // must be smaller than 1 +) + +var ( + linePool = buffer.NewPool() +) + +// Opt - +type Opt func(*Writer) + +// WithCachedLimit - +func WithCachedLimit(limit int) Opt { + return func(w *Writer) { + w.cachedLimit = limit + w.cachedSoftLimit = int(float64(limit) * softLimitFactor) + w.cachedLow = w.cachedSoftLimit >> 1 + } +} + +// NewAsyncWriteSyncer wrappers io.Writer to async zapcore.WriteSyncer +func NewAsyncWriteSyncer(w io.Writer, opts ...Opt) *Writer { + writer := &Writer{ + w: w, + diskBuf: bytes.NewBuffer(make([]byte, 0, diskBufferSize)), + lines: make(chan *buffer.Buffer, defaultChannelSize), + sync: make(chan struct{}), + syncDone: make(chan struct{}), + } + for _, opt := range opts { + opt(writer) + } + go writer.logConsumer() + return writer +} + +// Writer - +type Writer struct { + diskBuf *bytes.Buffer + lines chan *buffer.Buffer + w io.Writer + sync chan struct{} + syncDone chan struct{} + + cachedLimit int + cachedSoftLimit int + cachedLow int + cached int64 // atomic +} + +// Write sends data to channel non-blocking +func (w *Writer) Write(data []byte) (int, error) { + // note: data will be put back to zap's inner pool after Write, so we couldn't send it to channel directly + lp := linePool.Get() + lp.Write(data) + select { + case w.lines <- lp: + if w.cachedLimit != 0 && atomic.AddInt64(&w.cached, int64(len(data))) > int64(w.cachedLimit) { + w.doSync() + } + default: + fmt.Println("failed to push log to channel, skip") + lp.Free() + } + return len(data), nil +} + +// Sync implements zapcore.WriteSyncer. Current do nothing. +func (w *Writer) Sync() error { + w.doSync() + return nil +} + +func (w *Writer) doSync() { + w.sync <- struct{}{} + <-w.syncDone +} + +func (w *Writer) logConsumer() { + ticker := time.NewTicker(diskFlushTime) +loop: + for { + select { + case line := <-w.lines: + w.write(line) + if w.cachedLimit != 0 && atomic.LoadInt64(&w.cached) > int64(w.cachedSoftLimit) { + w.flushLines(len(w.lines), w.cachedLow) + } + case <-ticker.C: + if w.diskBuf.Len() == 0 { + continue + } + if _, err := w.w.Write(w.diskBuf.Bytes()); err != nil { + fmt.Println("failed to write", err.Error()) + } + w.diskBuf.Reset() + case _, ok := <-w.sync: + if !ok { + close(w.syncDone) + break loop + } + nLines := len(w.lines) + if nLines == 0 && w.diskBuf.Len() == 0 { + w.syncDone <- struct{}{} + continue + } + w.flushLines(nLines, -1) + if _, err := w.w.Write(w.diskBuf.Bytes()); err != nil { + fmt.Println("failed to write", err.Error()) + } + w.diskBuf.Reset() + w.syncDone <- struct{}{} + } + } + ticker.Stop() +} + +func (w *Writer) flushLines(nLines int, upTo int) { + nBytes := 0 + for i := 0; i < nLines; i++ { + line := <-w.lines + nBytes += line.Len() + w.write(line) + if upTo >= 0 && nBytes > upTo { + break + } + } +} + +func (w *Writer) write(line *buffer.Buffer) { + w.diskBuf.Write(line.Bytes()) + if w.cachedLimit != 0 { + atomic.AddInt64(&w.cached, -int64(line.Len())) + } + line.Free() + if w.diskBuf.Len() < diskFlushSize { + return + } + if _, err := w.w.Write(w.diskBuf.Bytes()); err != nil { + fmt.Println("failed to write", err.Error()) + } + w.diskBuf.Reset() +} diff --git a/yuanrong/pkg/common/faas_common/logger/async/writer_test.go b/yuanrong/pkg/common/faas_common/logger/async/writer_test.go new file mode 100644 index 0000000..3f34a76 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/async/writer_test.go @@ -0,0 +1,125 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package async + +import ( + "io" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type mockWriter struct { + buf []byte + delay time.Duration + sync.Mutex +} + +func (m *mockWriter) Write(data []byte) (int, error) { + m.Lock() + m.buf = data + if m.delay != 0 { + time.Sleep(m.delay) + } + m.Unlock() + return len(data), nil +} + +func (m *mockWriter) Clear() []byte { + m.Lock() + ret := m.buf + m.buf = nil + m.Unlock() + return ret +} + +func (m *mockWriter) SetWriteDelay(delay time.Duration) { + m.delay = delay +} + +func TestWriter_Write(t *testing.T) { + w := &mockWriter{} + + asyncWriter := NewAsyncWriteSyncer(w) + + data := []byte("hello world") + + // write small data, will be cached in inner buffer + asyncWriter.Write(data) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 0, len(w.Clear())) + + // small data will be written after flush time + time.Sleep(diskFlushTime) + assert.Equal(t, data, w.Clear()) + + // big data will be flushed immediately + asyncWriter.Write(make([]byte, diskFlushSize+1)) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, diskFlushSize+1, len(w.Clear())) + + // Sync() will flush buffer immediately + asyncWriter.Write(data) + assert.Equal(t, 0, len(w.Clear())) + asyncWriter.Sync() + assert.Equal(t, len(data), len(w.Clear())) + + for i := 0; i < 100; i++ { + go asyncWriter.Sync() + } + time.Sleep(10 * time.Millisecond) + asyncWriter.Sync() +} + +func TestCachedLimit(t *testing.T) { + w := &mockWriter{} + w.SetWriteDelay(150 * time.Millisecond) + + asyncWriter := NewAsyncWriteSyncer(w, WithCachedLimit(diskFlushSize*4)) // softLimit = 512kb * 4 * 0.8 = 1.6mb + + size := float64(diskFlushSize)*2*softLimitFactor + 1 + data := make([]byte, int(size)) + + // big data will be flushed immediately and triggers the mockWriter's write + asyncWriter.Write(make([]byte, diskFlushSize+1)) + + // logConsumer is blocked in mockWriter's write + asyncWriter.Write(data) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 1, len(asyncWriter.lines)) + + // this write should hit the soft limit + asyncWriter.Write(data) + time.Sleep(100 * time.Millisecond) // mockWriter's write finishes + assert.Equal(t, 1, len(asyncWriter.lines)) + + time.Sleep(200 * time.Millisecond) + assert.Equal(t, 0, len(asyncWriter.lines)) +} + +func BenchmarkWrite(b *testing.B) { + asyncWriter := NewAsyncWriteSyncer(io.Discard, WithCachedLimit(diskFlushSize*4)) + data := []byte("hello world") + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + asyncWriter.Write(data) + } + }) +} diff --git a/yuanrong/pkg/common/faas_common/logger/config/config.go b/yuanrong/pkg/common/faas_common/logger/config/config.go new file mode 100644 index 0000000..1d52732 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/config/config.go @@ -0,0 +1,121 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config is common logger client +package config + +import ( + "encoding/json" + "errors" + "os" + + "github.com/asaskevich/govalidator/v11" + "go.uber.org/zap/zapcore" + + "yuanrong/pkg/common/faas_common/utils" +) + +const ( + configPath = "/home/sn/config/log.json" + fileMode = 0750 + logConfigKey = "LOG_CONFIG" +) + +var ( + defaultCoreInfo CoreInfo + // LogLevel - + LogLevel zapcore.Level = zapcore.InfoLevel +) + +func init() { + defaultFilePath := os.Getenv("GLOG_log_dir") + if defaultFilePath == "" { + defaultFilePath = "/home/snuser/log" + } + defaultLevel := "INFO" + // defaultCoreInfo default logger config + defaultCoreInfo = CoreInfo{ + FilePath: defaultFilePath, + Level: defaultLevel, + Tick: 0, // Unit: Second + First: 0, // Unit: Number of logs + Thereafter: 0, // Unit: Number of logs + SingleSize: 100, + Threshold: 10, + Tracing: false, // tracing log switch + Disable: false, // Disable file logger + } +} + +// CoreInfo contains the core info +type CoreInfo struct { + FilePath string `json:"filepath" valid:",optional"` + Level string `json:"level" valid:",optional"` + Tick int `json:"tick" valid:"range(0|86400),optional"` + First int `json:"first" valid:"range(0|20000),optional"` + Thereafter int `json:"thereafter" valid:"range(0|1000),optional"` + Tracing bool `json:"tracing" valid:",optional"` + Disable bool `json:"disable" valid:",optional"` + SingleSize int64 `json:"singlesize" valid:",optional"` + Threshold int `json:"threshold" valid:",optional"` + IsUserLog bool `json:"-"` + IsWiseCloudAlarmLog bool `json:"isWiseCloudAlarmLog" valid:",optional"` +} + +// GetDefaultCoreInfo get defaultCoreInfo +func GetDefaultCoreInfo() CoreInfo { + return defaultCoreInfo +} + +// GetCoreInfoFromEnv extracts the logger config and ensures that the log file is available +func GetCoreInfoFromEnv() (CoreInfo, error) { + coreInfo, err := ExtractCoreInfoFromEnv(logConfigKey) + if err != nil { + return defaultCoreInfo, err + } + if err = utils.ValidateFilePath(coreInfo.FilePath); err != nil { + return defaultCoreInfo, err + } + if err = os.MkdirAll(coreInfo.FilePath, fileMode); err != nil && !os.IsExist(err) { + return defaultCoreInfo, err + } + + return coreInfo, nil +} + +// ExtractCoreInfoFromEnv extracts the logger config from ENV +func ExtractCoreInfoFromEnv(env string) (CoreInfo, error) { + var coreInfo CoreInfo + conf := os.Getenv(env) + if conf == "" { + return defaultCoreInfo, errors.New(env + " is empty") + } + err := json.Unmarshal([]byte(conf), &coreInfo) + if err != nil { + return defaultCoreInfo, err + } + + // if the file path is empty, return error + // if the log file is not writable, zap will create a new file with the configured file path and file name + if coreInfo.FilePath == "" { + return defaultCoreInfo, errors.New("the log file path is empty") + } + if _, err = govalidator.ValidateStruct(coreInfo); err != nil { + return defaultCoreInfo, err + } + + return coreInfo, nil +} diff --git a/yuanrong/pkg/common/faas_common/logger/config/config_test.go b/yuanrong/pkg/common/faas_common/logger/config/config_test.go new file mode 100644 index 0000000..a658828 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/config/config_test.go @@ -0,0 +1,315 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config is common logger client +package config + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/utils" + mockUtils "yuanrong/pkg/common/faas_common/utils" +) + +func TestInitConfig(t *testing.T) { + convey.Convey("TestInitConfig", t, func() { + convey.Convey("test 1", func() { + patches := gomonkey.ApplyFunc(GetCoreInfoFromEnv, func() (CoreInfo, error) { + return defaultCoreInfo, nil + }) + defer patches.Reset() + coreInfo, err := GetCoreInfoFromEnv() + fmt.Printf("log config:%+v\n", coreInfo) + convey.So(err, convey.ShouldEqual, nil) + }) + }) +} + +func TestInitConfigWithReadFileError(t *testing.T) { + convey.Convey("TestInitConfigWithEmptyPath", t, func() { + convey.Convey("test 1", func() { + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(ioutil.ReadFile, + func(filename string) ([]byte, error) { + return nil, errors.New("mock read file error") + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + coreInfo, err := GetCoreInfoFromEnv() + fmt.Printf("error:%s\n", err) + fmt.Printf("log config:%+v\n", coreInfo) + convey.So(err, convey.ShouldNotEqual, nil) + }) + }) +} + +func TestInitConfigWithErrorJson(t *testing.T) { + convey.Convey("TestInitConfigWithEmptyPath", t, func() { + convey.Convey("test 1", func() { + mockErrorJson := "{\n\"filepath\": \"/home/sn/mock\",\n\"level\": \"INFO\",\n\"maxsize\": " + + "500,\n\"maxbackups\": 1,\n\"maxage\": 1,\n\"compress\": true\n" + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(ioutil.ReadFile, + func(filename string) ([]byte, error) { + return []byte(mockErrorJson), nil + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + coreInfo, err := GetCoreInfoFromEnv() + fmt.Printf("error:%s\n", err) + fmt.Printf("log config:%+v\n", coreInfo) + convey.So(err, convey.ShouldNotEqual, nil) + }) + }) +} + +func TestInitConfigWithEmptyPath(t *testing.T) { + convey.Convey("TestInitConfigWithEmptyPath", t, func() { + convey.Convey("test 1", func() { + mockCfgInfo := "{\n\"filepath\": \"\",\n\"level\": \"INFO\",\n\"maxsize\": " + + "500,\n\"maxbackups\": 1,\n\"maxage\": 1,\n\"compress\": true\n}" + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(ioutil.ReadFile, + func(filename string) ([]byte, error) { + return []byte(mockCfgInfo), nil + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + coreInfo, err := GetCoreInfoFromEnv() + fmt.Printf("error:%s\n", err) + fmt.Printf("log config:%+v\n", coreInfo) + convey.So(err, convey.ShouldNotEqual, nil) + }) + }) +} + +func TestInitConfigWithValidateError(t *testing.T) { + convey.Convey("TestInitConfigWithEmptyPath", t, func() { + convey.Convey("test 1", func() { + mockErrorJson := "{\n\"filepath\": \"some_relative_path\",\n\"level\": \"INFO\",\n\"maxsize\": " + + "500,\n\"maxbackups\": 1,\n\"maxage\": 1}" + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(ioutil.ReadFile, + func(filename string) ([]byte, error) { + return []byte(mockErrorJson), nil + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + coreInfo, err := GetCoreInfoFromEnv() + fmt.Printf("error:%s\n", err) + fmt.Printf("log config:%+v\n", coreInfo) + convey.So(err, convey.ShouldNotEqual, nil) + }) + }) +} + +func TestGetDefaultCoreInfo(t *testing.T) { + tests := []struct { + name string + want CoreInfo + }{ + { + name: "test001", + want: CoreInfo{ + FilePath: "/home/snuser/log", + Level: "INFO", + Tick: 0, // Unit: Second + First: 0, // Unit: Number of logs + Thereafter: 0, // Unit: Number of logs + SingleSize: 100, + Threshold: 10, + Tracing: false, // tracing log switch + Disable: false, // Disable file logger + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetDefaultCoreInfo(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetDefaultCoreInfo() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestExtractCoreInfoFromEnv(t *testing.T) { + normalInfo, _ := json.Marshal(defaultCoreInfo) + abnormal1 := mockUtils.PatchSlice{} + abnormalInfo1, _ := json.Marshal(abnormal1) + abnormal2 := CoreInfo{ + FilePath: "", + Level: "INFO", + Tick: 10, // Unit: Second + First: 10, // Unit: Number of logs + Thereafter: 5, // Unit: Number of logs + Tracing: false, // tracing log switch + Disable: false, // Disable file logger + } + abnormalInfo2, _ := json.Marshal(abnormal2) + type args struct { + env string + } + tests := []struct { + name string + args args + want CoreInfo + wantErr bool + patchesFunc mockUtils.PatchesFunc + }{ + { + name: "case1", + args: args{logConfigKey}, + want: defaultCoreInfo, + wantErr: false, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(os.Getenv, + func(key string) string { + return string(normalInfo) + }), + }) + return patches + }, + }, + { + name: "case2", + args: args{logConfigKey}, + want: defaultCoreInfo, + wantErr: true, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(os.Getenv, + func(key string) string { + return string(abnormalInfo1) + }), + }) + return patches + }, + }, + { + name: "case3", + args: args{logConfigKey}, + want: defaultCoreInfo, + wantErr: true, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(os.Getenv, + func(key string) string { + return string(abnormalInfo2) + }), + }) + return patches + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + got, err := ExtractCoreInfoFromEnv(tt.args.env) + if (err != nil) != tt.wantErr { + t.Errorf("ExtractCoreInfoFromEnv() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ExtractCoreInfoFromEnv() got = %v, want %v", got, tt.want) + } + patches.ResetAll() + }) + } +} + +func TestGetCoreInfoFromEnv(t *testing.T) { + convey.Convey("GetCoreInfoFromEnv", t, func() { + convey.Convey("ValidateFilePath error", func() { + defer gomonkey.ApplyFunc(ExtractCoreInfoFromEnv, func(env string) (CoreInfo, error) { + return CoreInfo{FilePath: "../test"}, nil + }).Reset() + _, err := GetCoreInfoFromEnv() + convey.So(err, convey.ShouldBeError) + }) + + convey.Convey("MkdirAll error", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(ExtractCoreInfoFromEnv, func(env string) (CoreInfo, error) { + return CoreInfo{FilePath: "/home/test"}, nil + }), + gomonkey.ApplyFunc(utils.ValidateFilePath, func(path string) error { + return nil + }), + gomonkey.ApplyFunc(os.MkdirAll, func(path string, perm os.FileMode) error { + return errors.New("create dir error") + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + _, err := GetCoreInfoFromEnv() + convey.So(err, convey.ShouldBeError) + }) + + convey.Convey("success", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(ExtractCoreInfoFromEnv, func(env string) (CoreInfo, error) { + return CoreInfo{FilePath: "/home/test"}, nil + }), + gomonkey.ApplyFunc(utils.ValidateFilePath, func(path string) error { + return nil + }), + gomonkey.ApplyFunc(os.MkdirAll, func(path string, perm os.FileMode) error { + return nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + env, err := GetCoreInfoFromEnv() + convey.So(err, convey.ShouldBeNil) + convey.So(env.FilePath, convey.ShouldEqual, "/home/test") + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/logger/custom_encoder.go b/yuanrong/pkg/common/faas_common/logger/custom_encoder.go new file mode 100644 index 0000000..0365000 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/custom_encoder.go @@ -0,0 +1,391 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger log +package logger + +import ( + "math" + "os" + "regexp" + "strings" + "sync" + "time" + + "go.uber.org/zap/buffer" + "go.uber.org/zap/zapcore" + + "yuanrong/pkg/common/faas_common/constant" +) + +const ( + float64bitSize = 64 + float32bitSize = 32 + headerSeparator = ' ' + elementSeparator = " " + customDefaultLineEnding = "\n" + logMsgMaxLen = 1024 + fieldSeparator = " | " +) + +var ( + _customBufferPool = buffer.NewPool() + + _customPool = sync.Pool{New: func() interface{} { + return &customEncoder{} + }} + + replComp = regexp.MustCompile(`\s+`) + + clusterName = os.Getenv("CLUSTER_ID") +) + +// customEncoder represents the encoder for zap logger +// project's interface log +type customEncoder struct { + *zapcore.EncoderConfig + buf *buffer.Buffer + podName string +} + +// NewConsoleEncoder new custom console encoder to zap log module +func NewConsoleEncoder(cfg zapcore.EncoderConfig) (zapcore.Encoder, error) { + return &customEncoder{ + EncoderConfig: &cfg, + buf: _customBufferPool.Get(), + podName: os.Getenv(constant.HostNameEnvKey), + }, nil +} + +// NewCustomEncoder new custom encoder to zap log module +func NewCustomEncoder(cfg *zapcore.EncoderConfig) zapcore.Encoder { + return &customEncoder{ + EncoderConfig: cfg, + buf: _customBufferPool.Get(), + podName: os.Getenv(constant.HostNameEnvKey), + } +} + +// Clone return zap core Encoder +func (enc *customEncoder) Clone() zapcore.Encoder { + clone := enc.clone() + if enc.buf.Len() > 0 { + _, _ = clone.buf.Write(enc.buf.Bytes()) + } + return clone +} + +// EncodeEntry - +func (enc *customEncoder) EncodeEntry(ent zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) { + final := enc.clone() + // add time + final.AppendString(ent.Time.UTC().Format("2006-01-02 15:04:05.000")) + final.buf.AppendString(fieldSeparator) + + final.EncodeLevel(ent.Level, final) + final.buf.AppendString(fieldSeparator) + + // add caller + if ent.Caller.Defined { + final.EncodeCaller(ent.Caller, final) + final.buf.AppendString(fieldSeparator) + } + // add podName + if enc.podName != "" { + final.buf.AppendString(enc.podName) + final.buf.AppendString(fieldSeparator) + } + // add clusterName + if clusterName != "" { + final.buf.AppendString(clusterName) + final.buf.AppendString(fieldSeparator) + } + if enc.buf.Len() > 0 { + final.buf.Write(enc.buf.Bytes()) + } + // add msg + if len(ent.Message) > logMsgMaxLen { + final.AppendString(ent.Message[0:logMsgMaxLen]) + } else { + final.AppendString(ent.Message) + } + if ent.Stack != "" && final.StacktraceKey != "" { + final.buf.AppendString(elementSeparator) + final.AddString(final.StacktraceKey, ent.Stack) + } + for _, field := range fields { + field.AddTo(final) + } + final.buf.AppendString(customDefaultLineEnding) + ret := final.buf + putCustomEncoder(final) + return ret, nil +} + +func putCustomEncoder(enc *customEncoder) { + enc.EncoderConfig = nil + enc.buf = nil + _customPool.Put(enc) +} + +func getCustomEncoder() *customEncoder { + return _customPool.Get().(*customEncoder) +} + +func (enc *customEncoder) clone() *customEncoder { + clone := getCustomEncoder() + clone.buf = _customBufferPool.Get() + clone.EncoderConfig = enc.EncoderConfig + clone.podName = enc.podName + return clone +} + +func (enc *customEncoder) writeField(k string, writeVal func()) *customEncoder { + enc.buf.AppendString("(" + k + ":") + writeVal() + enc.buf.AppendString(")") + return enc +} + +// AddArray Add Array +func (enc *customEncoder) AddArray(k string, marshaler zapcore.ArrayMarshaler) error { + return nil +} + +// AddObject Add Object +func (enc *customEncoder) AddObject(k string, marshaler zapcore.ObjectMarshaler) error { + return nil +} + +// AddBinary Add Binary +func (enc *customEncoder) AddBinary(k string, v []byte) { + enc.AddString(k, string(v)) +} + +// AddByteString Add Byte String +func (enc *customEncoder) AddByteString(k string, v []byte) { + enc.AddString(k, string(v)) +} + +// AddBool Add Bool +func (enc *customEncoder) AddBool(k string, v bool) { + enc.writeField(k, func() { + enc.AppendBool(v) + }) +} + +// AddComplex128 Add Complex128 +func (enc *customEncoder) AddComplex128(k string, val complex128) {} + +// AddComplex64 Add Complex64 +func (enc *customEncoder) AddComplex64(k string, v complex64) {} + +// AddDuration Add Duration +func (enc *customEncoder) AddDuration(k string, val time.Duration) { + enc.writeField(k, func() { + enc.AppendString(val.String()) + }) +} + +// AddFloat64 Add Float64 +func (enc *customEncoder) AddFloat64(k string, val float64) { + enc.writeField(k, func() { + enc.AppendFloat64(val) + }) +} + +// AddFloat32 Add Float32 +func (enc *customEncoder) AddFloat32(k string, v float32) { + enc.writeField(k, func() { + enc.AppendFloat64(float64(v)) + }) +} + +// AddInt Add Int +func (enc *customEncoder) AddInt(k string, v int) { + enc.writeField(k, func() { + enc.AppendInt64(int64(v)) + }) +} + +// AddInt64 Add Int64 +func (enc *customEncoder) AddInt64(k string, val int64) { + enc.writeField(k, func() { + enc.AppendInt64(val) + }) +} + +// AddInt32 Add Int32 +func (enc *customEncoder) AddInt32(k string, v int32) { + enc.writeField(k, func() { + enc.AppendInt64(int64(v)) + }) +} + +// AddInt16 Add Int16 +func (enc *customEncoder) AddInt16(k string, v int16) { + enc.writeField(k, func() { + enc.AppendInt64(int64(v)) + }) +} + +// AddInt8 Add Int8 +func (enc *customEncoder) AddInt8(k string, v int8) { + enc.writeField(k, func() { + enc.AppendInt64(int64(v)) + }) +} + +// AddString Append String +func (enc *customEncoder) AddString(k, v string) { + enc.writeField(k, func() { + v = replComp.ReplaceAllString(v, " ") + if strings.Contains(v, " ") { + enc.buf.AppendString("(" + v + ")") + return + } + enc.AppendString(v) + }) +} + +// AddTime Add Time +func (enc *customEncoder) AddTime(k string, v time.Time) { + enc.writeField(k, func() { + enc.AppendString(v.UTC().Format("2006-01-02 15:04:05.000")) + }) +} + +// AddUint Add Uint +func (enc *customEncoder) AddUint(k string, v uint) { + enc.writeField(k, func() { + enc.AppendUint64(uint64(v)) + }) +} + +// AddUint64 Add Uint64 +func (enc *customEncoder) AddUint64(k string, v uint64) { + enc.writeField(k, func() { + enc.AppendUint64(v) + }) +} + +// AddUint32 Add Uint32 +func (enc *customEncoder) AddUint32(k string, v uint32) { + enc.writeField(k, func() { + enc.AppendUint64(uint64(v)) + }) +} + +// AddUint16 Add Uint16 +func (enc *customEncoder) AddUint16(k string, v uint16) { + enc.writeField(k, func() { + enc.AppendUint64(uint64(v)) + }) +} + +// AddUint8 Add Uint8 +func (enc *customEncoder) AddUint8(k string, v uint8) { + enc.writeField(k, func() { + enc.AppendUint64(uint64(v)) + }) +} + +// AddUintptr Add Uint ptr +func (enc *customEncoder) AddUintptr(k string, v uintptr) { + enc.writeField(k, func() { + enc.AppendUint64(uint64(v)) + }) +} + +// AddReflected uses reflection to serialize arbitrary objects, so it's slow +// and allocation-heavy. +func (enc *customEncoder) AddReflected(k string, v interface{}) error { + return nil +} + +// OpenNamespace opens an isolated namespace where all subsequent fields will +// be added. Applications can use namespaces to prevent key collisions when +// injecting loggers into sub-components or third-party libraries. +func (enc *customEncoder) OpenNamespace(k string) {} + +// AppendBool Append Bool +func (enc *customEncoder) AppendBool(v bool) { enc.buf.AppendBool(v) } + +// AppendByteString Append Byte String +func (enc *customEncoder) AppendByteString(v []byte) { enc.AppendString(string(v)) } + +// AppendComplex128 Append Complex128 +func (enc *customEncoder) AppendComplex128(v complex128) {} + +// AppendComplex64 Append Complex64 +func (enc *customEncoder) AppendComplex64(v complex64) {} + +// AppendFloat64 Append Float64 +func (enc *customEncoder) AppendFloat64(v float64) { enc.appendFloat(v, float64bitSize) } + +// AppendFloat32 Append Float32 +func (enc *customEncoder) AppendFloat32(v float32) { enc.appendFloat(float64(v), float32bitSize) } + +func (enc *customEncoder) appendFloat(v float64, bitSize int) { + switch { + // If the condition is not met, a string is returned to prevent blankness. + // IsNaN reports whether f is an IEEE 754 ``not-a-number'' value. + case math.IsNaN(v): + enc.buf.AppendString(`"NaN"`) + case math.IsInf(v, 1): + enc.buf.AppendString(`"+Inf"`) + case math.IsInf(v, -1): + enc.buf.AppendString(`"-Inf"`) + default: + enc.buf.AppendFloat(v, bitSize) + } +} + +// AppendInt Append Int +func (enc *customEncoder) AppendInt(v int) { enc.buf.AppendInt(int64(v)) } + +// AppendInt64 Append Int64 +func (enc *customEncoder) AppendInt64(v int64) { enc.buf.AppendInt(v) } + +// AppendInt32 Append Int32 +func (enc *customEncoder) AppendInt32(v int32) { enc.buf.AppendInt(int64(v)) } + +// AppendInt16 Append Int16 +func (enc *customEncoder) AppendInt16(v int16) { enc.buf.AppendInt(int64(v)) } + +// AppendInt8 Append Int8 +func (enc *customEncoder) AppendInt8(v int8) { enc.buf.AppendInt(int64(v)) } + +// AppendString Append String +func (enc *customEncoder) AppendString(val string) { enc.buf.AppendString(val) } + +// AppendUint Append Uint +func (enc *customEncoder) AppendUint(v uint) { enc.buf.AppendUint(uint64(v)) } + +// AppendUint64 Append Uint64 +func (enc *customEncoder) AppendUint64(v uint64) { enc.buf.AppendUint(v) } + +// AppendUint32 Append Uint32 +func (enc *customEncoder) AppendUint32(v uint32) { enc.buf.AppendUint(uint64(v)) } + +// AppendUint16 Append Uint16 +func (enc *customEncoder) AppendUint16(v uint16) { enc.buf.AppendUint(uint64(v)) } + +// AppendUint8 Append Uint8 +func (enc *customEncoder) AppendUint8(v uint8) { enc.buf.AppendUint(uint64(v)) } + +// AppendUintptr Append Uint ptr +func (enc *customEncoder) AppendUintptr(v uintptr) { enc.buf.AppendUint(uint64(v)) } diff --git a/yuanrong/pkg/common/faas_common/logger/custom_encoder_test.go b/yuanrong/pkg/common/faas_common/logger/custom_encoder_test.go new file mode 100644 index 0000000..32b05f3 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/custom_encoder_test.go @@ -0,0 +1,113 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logger + +import ( + "math" + "os" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" + "go.uber.org/zap/zapcore" + + "yuanrong/pkg/common/faas_common/constant" +) + +func TestNewCustomEncoder(t *testing.T) { + encoderConfig := zapcore.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "Logger", + MessageKey: "M", + CallerKey: "C", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + encoder := NewCustomEncoder(&encoderConfig) + clone := encoder.Clone() + assert.NotEmpty(t, clone) + encoder.AddBool("3", true) + err := encoder.AddArray("4", nil) + assert.Empty(t, err) + err = encoder.AddObject("4", nil) + assert.Empty(t, err) + encoder.AddBinary("4", []byte{}) + encoder.AddComplex128("4", complex(1, 2)) + encoder.AddComplex64("4", complex(1, 2)) + encoder.AddDuration("4", time.Second) + encoder.AddByteString("4", []byte{}) + encoder.AddFloat64("2", 3.14) + encoder.AddFloat32("2", 3.14) + encoder.AddInt("1", 1) + encoder.AddInt8("1", 1) + encoder.AddInt16("1", 1) + encoder.AddInt32("1", 1) + encoder.AddInt64("1", 1) + encoder.AddString("5", "12") + encoder.AddString("5", "1 2") + encoder.AddTime("6", time.Time{}) + encoder.AddUint("1", uint(1)) + encoder.AddUint8("1", uint8(10)) + encoder.AddUint16("1", uint16(100)) + encoder.AddUint32("1", uint32(1000)) + encoder.AddUint64("1", uint64(1000)) + b := make([]int, 1) + encoder.AddUintptr("12", uintptr(unsafe.Pointer(&b[0]))) + encoder.OpenNamespace("3") + err = encoder.AddReflected("3", 1) + assert.Empty(t, err) + +} + +func Test_customEncoder_Append(t *testing.T) { + encoderConfig := zapcore.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "Logger", + MessageKey: "M", + CallerKey: "C", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + encoder := &customEncoder{ + EncoderConfig: &encoderConfig, + buf: _customBufferPool.Get(), + podName: os.Getenv(constant.HostNameEnvKey), + } + encoder.AppendInt16(1) + encoder.AppendUint32(2) + encoder.AppendByteString([]byte("abc")) + encoder.AppendFloat32(3) + encoder.appendFloat(math.Inf(1), 10) + encoder.appendFloat(math.Inf(-1), 10) + encoder.AppendComplex64(1 + 2i) + encoder.AppendUintptr(uintptr(1)) + encoder.AppendUint8(1) + encoder.AppendInt32(2) + encoder.AppendUint16(3) + encoder.AppendUint(4) + encoder.AppendInt8(0) + encoder.AppendInt(5) + encoder.AppendInt32(7) + assert.NotEmpty(t, encoder.buf.Len()) +} diff --git a/yuanrong/pkg/common/faas_common/logger/healthlog/healthlog.go b/yuanrong/pkg/common/faas_common/logger/healthlog/healthlog.go new file mode 100644 index 0000000..e1caafe --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/healthlog/healthlog.go @@ -0,0 +1,46 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package healthlog is for printing health logs +package healthlog + +import ( + "time" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +const logInterval = 5 * time.Minute + +// PrintHealthLog prints timing health logs of components +func PrintHealthLog(stopCh <-chan struct{}, inputLog func(), name string) { + if stopCh == nil { + log.GetLogger().Errorf("stop channel is nil") + return + } + ticker := time.NewTicker(logInterval) + defer ticker.Stop() + time.After(logInterval) + for { + select { + case <-ticker.C: + inputLog() + case <-stopCh: + log.GetLogger().Warnf("%s receives a terminating signal", name) + return + } + } +} diff --git a/yuanrong/pkg/common/faas_common/logger/healthlog/healthlog_test.go b/yuanrong/pkg/common/faas_common/logger/healthlog/healthlog_test.go new file mode 100644 index 0000000..f866d10 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/healthlog/healthlog_test.go @@ -0,0 +1,49 @@ +package healthlog + +import "testing" + +func TestPrintHealthLog(t *testing.T) { + type args struct { + stopCh chan struct{} + inputLog func() + name string + } + var a args + a.stopCh = nil + a.inputLog = func() { + return + } + + tests := []struct { + name string + args args + }{ + { + name: "case1", + args: args{ + stopCh: nil, + inputLog: func() { + return + }, + }, + }, + { + name: "case2", + args: args{ + stopCh: make(chan struct{}), + inputLog: func() { + return + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.args.stopCh != nil { + close(tt.args.stopCh) + } + PrintHealthLog(tt.args.stopCh, tt.args.inputLog, tt.args.name) + }) + } +} diff --git a/yuanrong/pkg/common/faas_common/logger/interface_encoder.go b/yuanrong/pkg/common/faas_common/logger/interface_encoder.go new file mode 100644 index 0000000..dacf594 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/interface_encoder.go @@ -0,0 +1,346 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger log +package logger + +import ( + "errors" + "math" + "os" + "sync" + "time" + + "go.uber.org/zap/buffer" + "go.uber.org/zap/zapcore" + + "yuanrong/pkg/common/faas_common/constant" +) + +var ( + _bufferPool = buffer.NewPool() + + _interfacePool = sync.Pool{New: func() interface{} { + return &interfaceEncoder{} + }} +) + +// InterfaceEncoderConfig holds interface log encoder config +type InterfaceEncoderConfig struct { + ModuleName string + HTTPMethod string + ModuleFrom string + TenantID string + FuncName string + FuncVer string + EncodeCaller zapcore.CallerEncoder +} + +// interfaceEncoder represents the encoder for interface log +// project's interface log +type interfaceEncoder struct { + *InterfaceEncoderConfig + buf *buffer.Buffer + podName string + spaced bool +} + +func getInterfaceEncoder() *interfaceEncoder { + return _interfacePool.Get().(*interfaceEncoder) +} + +func putInterfaceEncoder(enc *interfaceEncoder) { + enc.InterfaceEncoderConfig = nil + enc.spaced = false + enc.buf = nil + _interfacePool.Put(enc) +} + +// NewInterfaceEncoder create a new interface log encoder +func NewInterfaceEncoder(cfg InterfaceEncoderConfig, spaced bool) zapcore.Encoder { + return newInterfaceEncoder(cfg, spaced) +} + +func newInterfaceEncoder(cfg InterfaceEncoderConfig, spaced bool) *interfaceEncoder { + return &interfaceEncoder{ + InterfaceEncoderConfig: &cfg, + buf: _bufferPool.Get(), + spaced: spaced, + podName: os.Getenv(constant.PodNameEnvKey), + } +} + +// Clone return zap core Encoder +func (enc *interfaceEncoder) Clone() zapcore.Encoder { + return enc.clone() +} + +func (enc *interfaceEncoder) clone() *interfaceEncoder { + clone := getInterfaceEncoder() + clone.InterfaceEncoderConfig = enc.InterfaceEncoderConfig + clone.spaced = enc.spaced + clone.buf = _bufferPool.Get() + return clone +} + +// EncodeEntry Encode Entry +func (enc *interfaceEncoder) EncodeEntry(ent zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) { + final := enc.clone() + // add time + final.AppendString(ent.Time.UTC().Format("2006-01-02 15:04:05.000")) + // add level + // Level of interfaceLog is eternally INFO + final.buf.AppendString(fieldSeparator) + final.AppendString("INFO") + // add caller + if ent.Caller.Defined { + final.buf.AppendString(fieldSeparator) + final.EncodeCaller(ent.Caller, final) + } + final.buf.AppendString(fieldSeparator) + // add podName + if enc.podName != "" { + final.buf.AppendString(enc.podName) + } + final.buf.AppendString(fieldSeparator) + if enc.buf.Len() > 0 { + _, err := final.buf.Write(enc.buf.Bytes()) + if err != nil { + return nil, err + } + } + // add msg + final.AppendString(ent.Message) + for _, field := range fields { + field.AddTo(final) + } + final.buf.AppendString(customDefaultLineEnding) + ret := final.buf + putInterfaceEncoder(final) + return ret, nil +} + +// AddString Append String +func (enc *interfaceEncoder) AddString(key, val string) { + enc.buf.AppendString(val) +} + +// AppendString Append String +func (enc *interfaceEncoder) AppendString(val string) { + enc.buf.AppendString(val) +} + +// AddDuration Add Duration +func (enc *interfaceEncoder) AddDuration(key string, val time.Duration) { + enc.AppendDuration(val) +} + +func (enc *interfaceEncoder) addElementSeparator() { + last := enc.buf.Len() - 1 + if last < 0 { + return + } + switch enc.buf.Bytes()[last] { + case headerSeparator: + return + default: + enc.buf.AppendByte(headerSeparator) + if enc.spaced { + enc.buf.AppendByte(' ') + } + } +} + +// AppendTime Append Time +func (enc *interfaceEncoder) AppendTime(val time.Time) { + cur := enc.buf.Len() + interfaceTimeEncode(val, enc) + if cur == enc.buf.Len() { + // User-supplied EncodeTime is a no-op. Fall back to nanos since epoch to keep + // output JSON valid. + enc.AppendInt64(val.UnixNano()) + } +} + +// AddArray Add Array +func (enc *interfaceEncoder) AddArray(key string, marshaler zapcore.ArrayMarshaler) error { + return errors.New("unsupported method") +} + +// AddObject Add Object +func (enc *interfaceEncoder) AddObject(key string, marshaler zapcore.ObjectMarshaler) error { + return errors.New("unsupported method") +} + +// AddBinary Add Binary +func (enc *interfaceEncoder) AddBinary(key string, value []byte) {} + +// AddByteString Add Byte String +func (enc *interfaceEncoder) AddByteString(key string, val []byte) { + enc.AppendByteString(val) +} + +// AddBool Add Bool +func (enc *interfaceEncoder) AddBool(key string, value bool) {} + +// AddComplex64 Add Complex64 +func (enc *interfaceEncoder) AddComplex64(k string, v complex64) { enc.AddComplex128(k, complex128(v)) } + +// AddFloat32 Add Float32 +func (enc *interfaceEncoder) AddFloat32(k string, v float32) { enc.AddFloat64(k, float64(v)) } + +// AddInt Add Int +func (enc *interfaceEncoder) AddInt(k string, v int) { enc.AddInt64(k, int64(v)) } + +// AddInt32 Add Int32 +func (enc *interfaceEncoder) AddInt32(k string, v int32) { enc.AddInt64(k, int64(v)) } + +// AddInt16 Add Int16 +func (enc *interfaceEncoder) AddInt16(k string, v int16) { enc.AddInt64(k, int64(v)) } + +// AddInt8 Add Int8 +func (enc *interfaceEncoder) AddInt8(k string, v int8) { enc.AddInt64(k, int64(v)) } + +// AddUint Add Uint +func (enc *interfaceEncoder) AddUint(k string, v uint) { enc.AddUint64(k, uint64(v)) } + +// AddUint32 Add Uint32 +func (enc *interfaceEncoder) AddUint32(k string, v uint32) { enc.AddUint64(k, uint64(v)) } + +// AddUint16 Add Uint16 +func (enc *interfaceEncoder) AddUint16(k string, v uint16) { enc.AddUint64(k, uint64(v)) } + +// AddUint8 Add Uint8 +func (enc *interfaceEncoder) AddUint8(k string, v uint8) { enc.AddUint64(k, uint64(v)) } + +// AddUintptr Add Uint ptr +func (enc *interfaceEncoder) AddUintptr(k string, v uintptr) { enc.AddUint64(k, uint64(v)) } + +// AddComplex128 Add Complex128 +func (enc *interfaceEncoder) AddComplex128(key string, val complex128) { + enc.AppendComplex128(val) +} + +// AddFloat64 Add Float64 +func (enc *interfaceEncoder) AddFloat64(key string, val float64) { + enc.AppendFloat64(val) +} + +// AddInt64 Add Int64 +func (enc *interfaceEncoder) AddInt64(key string, val int64) { + enc.AppendInt64(val) +} + +// AddTime Add Time +func (enc *interfaceEncoder) AddTime(key string, value time.Time) { + enc.AppendTime(value) +} + +// AddUint64 Add Uint64 +func (enc *interfaceEncoder) AddUint64(key string, value uint64) {} + +// AddReflected uses reflection to serialize arbitrary objects, so it's slow +// and allocation-heavy. +func (enc *interfaceEncoder) AddReflected(key string, value interface{}) error { + return nil +} + +// OpenNamespace opens an isolated namespace where all subsequent fields will +// be added. Applications can use namespaces to prevent key collisions when +// injecting loggers into sub-components or third-party libraries. +func (enc *interfaceEncoder) OpenNamespace(key string) {} + +// AppendComplex128 Append Complex128 +func (enc *interfaceEncoder) AppendComplex128(val complex128) {} + +// AppendInt64 Append Int64 +func (enc *interfaceEncoder) AppendInt64(val int64) { + enc.addElementSeparator() + enc.buf.AppendInt(val) +} + +// AppendBool Append Bool +func (enc *interfaceEncoder) AppendBool(val bool) { + enc.addElementSeparator() + enc.buf.AppendBool(val) +} + +func (enc *interfaceEncoder) appendFloat(val float64, bitSize int) { + enc.addElementSeparator() + switch { + case math.IsNaN(val): + enc.buf.AppendString(`"NaN"`) + case math.IsInf(val, 1): + enc.buf.AppendString(`"+Inf"`) + case math.IsInf(val, -1): + enc.buf.AppendString(`"-Inf"`) + default: + enc.buf.AppendFloat(val, bitSize) + } +} + +// AppendUint64 Append Uint64 +func (enc *interfaceEncoder) AppendUint64(val uint64) { + enc.addElementSeparator() + enc.buf.AppendUint(val) +} + +// AppendByteString Append Byte String +func (enc *interfaceEncoder) AppendByteString(val []byte) {} + +// AppendDuration Append Duration +func (enc *interfaceEncoder) AppendDuration(val time.Duration) {} + +// AppendComplex64 Append Complex64 +func (enc *interfaceEncoder) AppendComplex64(v complex64) { enc.AppendComplex128(complex128(v)) } + +// AppendFloat64 Append Float64 +func (enc *interfaceEncoder) AppendFloat64(v float64) { enc.appendFloat(v, float64bitSize) } + +// AppendFloat32 Append Float32 +func (enc *interfaceEncoder) AppendFloat32(v float32) { enc.appendFloat(float64(v), float32bitSize) } + +// AppendInt Append Int +func (enc *interfaceEncoder) AppendInt(v int) { enc.AppendInt64(int64(v)) } + +// AppendInt32 Append Int32 +func (enc *interfaceEncoder) AppendInt32(v int32) { enc.AppendInt64(int64(v)) } + +// AppendInt16 Append Int16 +func (enc *interfaceEncoder) AppendInt16(v int16) { enc.AppendInt64(int64(v)) } + +// AppendInt8 Append Int8 +func (enc *interfaceEncoder) AppendInt8(v int8) { enc.AppendInt64(int64(v)) } + +// AppendUint Append Uint +func (enc *interfaceEncoder) AppendUint(v uint) { enc.AppendUint64(uint64(v)) } + +// AppendUint32 Append Uint32 +func (enc *interfaceEncoder) AppendUint32(v uint32) { enc.AppendUint64(uint64(v)) } + +// AppendUint16 Append Uint16 +func (enc *interfaceEncoder) AppendUint16(v uint16) { enc.AppendUint64(uint64(v)) } + +// AppendUint8 Append Uint8 +func (enc *interfaceEncoder) AppendUint8(v uint8) { enc.AppendUint64(uint64(v)) } + +// AppendUintptr Append Uint ptr +func (enc *interfaceEncoder) AppendUintptr(v uintptr) { enc.AppendUint64(uint64(v)) } + +func interfaceTimeEncode(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + t = t.UTC() + enc.AppendString(t.Format("2006-01-02 15:04:05.000")) +} diff --git a/yuanrong/pkg/common/faas_common/logger/interface_encoder_test.go b/yuanrong/pkg/common/faas_common/logger/interface_encoder_test.go new file mode 100644 index 0000000..1708bcc --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/interface_encoder_test.go @@ -0,0 +1,122 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logger + +import ( + "math" + "os" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong/pkg/common/faas_common/constant" +) + +// TestNewInterfaceEncoder Test New Interface Encoder +func TestNewInterfaceEncoder(t *testing.T) { + cfg := InterfaceEncoderConfig{ + ModuleName: "FunctionWorker", + HTTPMethod: "POST", + ModuleFrom: "FrontendInvoke", + TenantID: "tenant2", + FuncName: "myFunction", + FuncVer: "latest", + } + + encoder := NewInterfaceEncoder(cfg, false) + + priority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= zapcore.InfoLevel + }) + + sink := zapcore.Lock(os.Stdout) + core := zapcore.NewCore(encoder, sink, priority) + logger := zap.New(core) + + logger.Info("e1b71add-cb24-4ef8-93eb-af8d3ceb74e8|0|success|1") + + clone := encoder.Clone() + assert.NotEmpty(t, clone) + encoder.AddBool("3", true) + err := encoder.AddArray("4", nil) + assert.NotEmpty(t, err) + err = encoder.AddObject("4", nil) + assert.NotEmpty(t, err) + encoder.AddBinary("4", []byte{}) + encoder.AddComplex128("4", complex(1, 2)) + encoder.AddComplex64("4", complex(1, 2)) + encoder.AddDuration("4", time.Second) + encoder.AddByteString("4", []byte{}) + encoder.AddFloat64("2", 3.14) + encoder.AddFloat32("2", 3.14) + encoder.AddInt("1", 1) + encoder.AddInt8("1", 1) + encoder.AddInt16("1", 1) + encoder.AddInt32("1", 1) + encoder.AddInt64("1", 1) + encoder.AddString("5", "12") + encoder.AddString("5", "1 2") + encoder.AddTime("6", time.Time{}) + encoder.AddUint("1", uint(1)) + encoder.AddUint8("1", uint8(10)) + encoder.AddUint16("1", uint16(100)) + encoder.AddUint32("1", uint32(1000)) + encoder.AddUint64("1", uint64(1000)) + b := make([]int, 1) + encoder.AddUintptr("12", uintptr(unsafe.Pointer(&b[0]))) + encoder.OpenNamespace("3") + err = encoder.AddReflected("3", 1) + assert.Empty(t, err) +} + +func Test_interfaceEncoder_Append(t *testing.T) { + encoderConfig := InterfaceEncoderConfig{ + ModuleName: "FunctionWorker", + HTTPMethod: "POST", + ModuleFrom: "FrontendInvoke", + TenantID: "tenant2", + FuncName: "myFunction", + FuncVer: "latest", + } + encoder := &interfaceEncoder{ + InterfaceEncoderConfig: &encoderConfig, + buf: _bufferPool.Get(), + spaced: false, + podName: os.Getenv(constant.HostNameEnvKey), + } + encoder.AppendInt16(1) + encoder.AppendUint32(2) + encoder.AppendByteString([]byte("abc")) + encoder.AppendFloat32(3) + encoder.appendFloat(math.Inf(1), 10) + encoder.appendFloat(math.Inf(-1), 10) + encoder.AppendComplex64(1 + 2i) + encoder.AppendUintptr(uintptr(1)) + encoder.AppendUint8(1) + encoder.AppendInt32(2) + encoder.AppendUint16(3) + encoder.AppendUint(4) + encoder.AppendInt8(0) + encoder.AppendInt(5) + encoder.AppendInt32(7) + encoder.AppendBool(false) + assert.NotEmpty(t, encoder.buf.Len()) +} diff --git a/yuanrong/pkg/common/faas_common/logger/interfacelogger.go b/yuanrong/pkg/common/faas_common/logger/interfacelogger.go new file mode 100644 index 0000000..cc5d159 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/interfacelogger.go @@ -0,0 +1,100 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger log +package logger + +import ( + "fmt" + "io" + "os" + "path/filepath" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong/pkg/common/faas_common/logger/config" +) + +const defaultPerm = 0666 + +// NewInterfaceLogger returns a new interface logger +func NewInterfaceLogger(logPath, fileName string, cfg InterfaceEncoderConfig) (*InterfaceLogger, error) { + coreInfo, err := config.GetCoreInfoFromEnv() + if err != nil { + coreInfo = config.GetDefaultCoreInfo() + } + filePath := filepath.Join(coreInfo.FilePath, fileName+".log") + + coreInfo.FilePath = filePath + cfg.EncodeCaller = zapcore.ShortCallerEncoder + // skip level to print caller line of origin log + const skipLevel = 3 + core, err := newCore(coreInfo, cfg) + if err != nil { + return nil, err + } + logger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(skipLevel)) + + return &InterfaceLogger{log: logger}, nil +} + +// InterfaceLogger interface logger which implements by zap logger +type InterfaceLogger struct { + log *zap.Logger +} + +// Write writes message information +func (logger *InterfaceLogger) Write(msg string) { + logger.log.Info(msg) +} + +func newCore(coreInfo config.CoreInfo, cfg InterfaceEncoderConfig) (zapcore.Core, error) { + w, err := CreateSink(coreInfo) + if err != nil { + return nil, err + } + syncer := zapcore.AddSync(w) + + encoder := NewInterfaceEncoder(cfg, false) + + priority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + var customLevel zapcore.Level + if err := customLevel.UnmarshalText([]byte(coreInfo.Level)); err != nil { + customLevel = zapcore.InfoLevel + } + return lvl >= customLevel + }) + + return zapcore.NewCore(encoder, syncer, priority), nil +} + +// CreateSink creates a new zap log sink +func CreateSink(coreInfo config.CoreInfo) (io.Writer, error) { + // create directory if not already exist + dir := filepath.Dir(coreInfo.FilePath) + err := os.MkdirAll(dir, os.ModePerm) + if err != nil { + fmt.Printf("failed to mkdir: %s", dir) + return nil, err + } + w, err := initRollingLog(coreInfo, os.O_WRONLY|os.O_APPEND|os.O_CREATE, defaultPerm) + if err != nil { + fmt.Printf("failed to open log file: %s, err: %s\n", coreInfo.FilePath, err.Error()) + return nil, err + } + return w, nil +} diff --git a/yuanrong/pkg/common/faas_common/logger/interfacelogger_test.go b/yuanrong/pkg/common/faas_common/logger/interfacelogger_test.go new file mode 100644 index 0000000..50f989c --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/interfacelogger_test.go @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logger + +import ( + "errors" + "os" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/logger/config" +) + +func TestInterfaceLogger(t *testing.T) { + cfg := InterfaceEncoderConfig{ModuleName: "WorkerManager"} + interfaceLog, err := NewInterfaceLogger("", "worker-manager-interface", cfg) + interfaceLog.Write("123") + assert.Empty(t, err) + assert.NotEmpty(t, interfaceLog) +} + +func TestCreateSink(t *testing.T) { + convey.Convey("Test Create Sink Error", t, func() { + coreInfo := config.CoreInfo{} + w, err := CreateSink(coreInfo) + convey.So(w, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("Test Create Sink Error 2", t, func() { + patch := gomonkey.ApplyFunc(os.MkdirAll, func(path string, perm os.FileMode) error { + return errors.New("err") + }) + defer patch.Reset() + coreInfo := config.CoreInfo{} + w, err := CreateSink(coreInfo) + convey.So(w, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }) + +} diff --git a/yuanrong/pkg/common/faas_common/logger/log/logger.go b/yuanrong/pkg/common/faas_common/logger/log/logger.go new file mode 100644 index 0000000..3d1c054 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/log/logger.go @@ -0,0 +1,263 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package log - +package log + +import ( + "fmt" + "path/filepath" + "strings" + "sync" + + uberZap "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/config" + "yuanrong/pkg/common/faas_common/logger/zap" +) + +const ( + skipLevel = 1 + snuserLogPath = "/home/snuser/log" +) + +type loggerWrapper struct { + real api.FormatLogger +} + +func (l *loggerWrapper) With(fields ...zapcore.Field) api.FormatLogger { + return &loggerWrapper{ + real: l.real.With(fields...), + } +} + +func (l *loggerWrapper) Infof(format string, paras ...interface{}) { + l.real.Infof(format, paras...) +} +func (l *loggerWrapper) Errorf(format string, paras ...interface{}) { + l.real.Errorf(format, paras...) +} +func (l *loggerWrapper) Warnf(format string, paras ...interface{}) { + l.real.Warnf(format, paras...) +} +func (l *loggerWrapper) Debugf(format string, paras ...interface{}) { + l.real.Debugf(format, paras...) +} +func (l *loggerWrapper) Fatalf(format string, paras ...interface{}) { + l.real.Fatalf(format, paras...) +} +func (l *loggerWrapper) Info(msg string, fields ...uberZap.Field) { + l.real.Info(msg, fields...) +} +func (l *loggerWrapper) Error(msg string, fields ...uberZap.Field) { + l.real.Error(msg, fields...) +} +func (l *loggerWrapper) Warn(msg string, fields ...uberZap.Field) { + l.real.Warn(msg, fields...) +} +func (l *loggerWrapper) Debug(msg string, fields ...uberZap.Field) { + l.real.Debug(msg, fields...) +} +func (l *loggerWrapper) Fatal(msg string, fields ...uberZap.Field) { + l.real.Fatal(msg, fields...) +} +func (l *loggerWrapper) Sync() { + l.real.Sync() +} + +var ( + once sync.Once + formatLogger api.FormatLogger + defaultLogger, _ = uberZap.NewProduction() +) + +// InitRunLog init run log with log.json file +func InitRunLog(fileName string, isAsync bool) error { + coreInfo, err := config.GetCoreInfoFromEnv() + if err != nil { + return err + } + if coreInfo.Disable { + return nil + } + formatLogger, err = NewFormatLogger(fileName, isAsync, coreInfo) + return err +} + +// SetupLoggerLibruntime setup logger +func SetupLoggerLibruntime(runtimeLogger api.FormatLogger) { + if runtimeLogger == nil { + return + } + wrapLogger := &loggerWrapper{real: runtimeLogger} + formatLogger = wrapLogger +} + +// SetupLogger setup logger +func SetupLogger(runtimeLogger api.FormatLogger) { + if runtimeLogger == nil { + return + } + formatLogger = runtimeLogger +} + +// GetLogger get logger directly +func GetLogger() api.FormatLogger { + if formatLogger == nil { + once.Do(func() { + formatLogger = NewConsoleLogger() + }) + } + return formatLogger +} + +// NewConsoleLogger returns a console logger +func NewConsoleLogger() api.FormatLogger { + logger, err := newConsoleLog() + if err != nil { + fmt.Println("new console log error", err) + logger = defaultLogger + } + return &zapLoggerWithFormat{ + Logger: logger, + SLogger: logger.Sugar(), + } +} + +// NewFormatLogger new formatLogger with log config info +func NewFormatLogger(fileName string, isAsync bool, coreInfo config.CoreInfo) (api.FormatLogger, error) { + if strings.Compare(constant.MonitorFileName, fileName) == 0 { + coreInfo.FilePath = snuserLogPath + } + coreInfo.FilePath = filepath.Join(coreInfo.FilePath, fileName+"-run.log") + logger, err := zap.NewWithLevel(coreInfo, isAsync) + if err != nil { + return nil, err + } + + return &zapLoggerWithFormat{ + Logger: logger, + SLogger: logger.Sugar(), + }, nil +} + +// newConsoleLog returns a console logger based on uber zap +func newConsoleLog() (*uberZap.Logger, error) { + outputPaths := []string{"stdout"} + cfg := uberZap.Config{ + Level: uberZap.NewAtomicLevelAt(uberZap.InfoLevel), + Development: false, + DisableCaller: false, + DisableStacktrace: true, + Encoding: "custom_console", + OutputPaths: outputPaths, + ErrorOutputPaths: outputPaths, + EncoderConfig: zapcore.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "Logger", + MessageKey: "M", + CallerKey: "C", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + }, + } + consoleLogger, err := cfg.Build() + if err != nil { + return nil, err + } + return consoleLogger.WithOptions(uberZap.AddCaller(), uberZap.AddCallerSkip(skipLevel)), nil +} + +// zapLoggerWithFormat define logger +type zapLoggerWithFormat struct { + Logger *uberZap.Logger + SLogger *uberZap.SugaredLogger +} + +// With add fields to log header +func (z *zapLoggerWithFormat) With(fields ...zapcore.Field) api.FormatLogger { + logger := z.Logger.With(fields...) + return &zapLoggerWithFormat{ + Logger: logger, + SLogger: logger.Sugar(), + } +} + +// Infof stdout format and paras +func (z *zapLoggerWithFormat) Infof(format string, paras ...interface{}) { + z.SLogger.Infof(format, paras...) +} + +// Errorf stdout format and paras +func (z *zapLoggerWithFormat) Errorf(format string, paras ...interface{}) { + z.SLogger.Errorf(format, paras...) +} + +// Warnf stdout format and paras +func (z *zapLoggerWithFormat) Warnf(format string, paras ...interface{}) { + z.SLogger.Warnf(format, paras...) +} + +// Debugf stdout format and paras +func (z *zapLoggerWithFormat) Debugf(format string, paras ...interface{}) { + z.SLogger.Debugf(format, paras...) +} + +// Fatalf stdout format and paras +func (z *zapLoggerWithFormat) Fatalf(format string, paras ...interface{}) { + z.SLogger.Fatalf(format, paras...) +} + +// Info stdout format and paras +func (z *zapLoggerWithFormat) Info(msg string, fields ...uberZap.Field) { + z.Logger.Info(msg, fields...) +} + +// Error stdout format and paras +func (z *zapLoggerWithFormat) Error(msg string, fields ...uberZap.Field) { + z.Logger.Error(msg, fields...) +} + +// Warn stdout format and paras +func (z *zapLoggerWithFormat) Warn(msg string, fields ...uberZap.Field) { + z.Logger.Warn(msg, fields...) +} + +// Debug stdout format and paras +func (z *zapLoggerWithFormat) Debug(msg string, fields ...uberZap.Field) { + z.Logger.Debug(msg, fields...) +} + +// Fatal stdout format and paras +func (z *zapLoggerWithFormat) Fatal(msg string, fields ...uberZap.Field) { + z.Logger.Fatal(msg, fields...) +} + +// Sync calls the underlying Core's Sync method, flushing any buffered log +// entries. Applications should take care to call Sync before exiting. +func (z *zapLoggerWithFormat) Sync() { + err := z.Logger.Sync() + if err != nil { + return + } +} diff --git a/yuanrong/pkg/common/faas_common/logger/log/logger_test.go b/yuanrong/pkg/common/faas_common/logger/log/logger_test.go new file mode 100644 index 0000000..4ed4e5d --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/log/logger_test.go @@ -0,0 +1,98 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package log + +import ( + "errors" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + uberZap "go.uber.org/zap" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/config" + "yuanrong/pkg/common/faas_common/logger/zap" +) + +func TestSetupLoggerRuntime(t *testing.T) { + SetupLoggerLibruntime(nil) + assert.Equal(t, formatLogger, nil) +} + +func TestInitLogger(t *testing.T) { + errCtrl := "" + patch := gomonkey.ApplyFunc(config.GetCoreInfoFromEnv, func() (config.CoreInfo, error) { + if errCtrl == "returnError" { + return config.CoreInfo{}, errors.New("some error") + } + return config.CoreInfo{}, nil + }) + defer patch.Reset() + SetupLogger(nil) + SetupLogger(NewConsoleLogger()) + assert.NotNil(t, formatLogger) + errCtrl = "returnError" + err := InitRunLog("test", false) + assert.NotNil(t, err) + errCtrl = "" + err = InitRunLog("test", false) + assert.Nil(t, err) +} + +func TestGetLogger(t *testing.T) { + convey.Convey("log", t, func() { + logger := GetLogger() + logger.With(uberZap.Any("name", "test-log")) + logger.Info("info log") + logger.Infof("info log") + logger.Debug("debug log") + logger.Debugf("debug log") + logger.Warn("warn log") + logger.Warnf("warn log") + logger.Error("error log") + logger.Errorf("error log") + }) +} + +func TestFormatLogger(t *testing.T) { + convey.Convey("new log error", t, func() { + patch := gomonkey.ApplyFunc(zap.NewWithLevel, func(coreInfo config.CoreInfo, isAsync bool) (*uberZap.Logger, error) { + return nil, errors.New("1") + }) + defer patch.Reset() + _, err := NewFormatLogger(constant.MonitorFileName, true, config.CoreInfo{}) + assert.NotNil(t, err) + }) + convey.Convey("new log success", t, func() { + logger, err := NewFormatLogger(constant.MonitorFileName, true, config.CoreInfo{}) + assert.Nil(t, err) + logger.With(uberZap.Any("name", "test-log")) + logger.Info("info log") + logger.Infof("info log") + logger.Debug("debug log") + logger.Debugf("debug log") + //logger.Fatal("fatal log") + //logger.Fatalf("fatal log") + logger.Warn("warn log") + logger.Warnf("warn log") + logger.Error("error log") + logger.Errorf("error log") + logger.Sync() + }) +} diff --git a/yuanrong/pkg/common/faas_common/logger/rollinglog.go b/yuanrong/pkg/common/faas_common/logger/rollinglog.go new file mode 100644 index 0000000..e88c496 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/rollinglog.go @@ -0,0 +1,276 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logger rollingLog +package logger + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + "sync" + "time" + + "yuanrong/pkg/common/faas_common/logger/config" +) + +const ( + megabyte = 1024 * 1024 + defaultFileSize = 100 + defaultBackups = 20 +) + +var logNameCache = struct { + m map[string]string + sync.Mutex +}{ + m: make(map[string]string, 1), + Mutex: sync.Mutex{}, +} + +type rollingLog struct { + file *os.File + reg *regexp.Regexp + mu sync.RWMutex + sinks []string + dir string + nameTemplate string + maxSize int64 + size int64 + maxBackups int + flag int + perm os.FileMode + isUserLog bool + isWiseCloudAlarmLog bool +} + +func initRollingLog(coreInfo config.CoreInfo, flag int, perm os.FileMode) (*rollingLog, error) { + if coreInfo.FilePath == "" { + return nil, errors.New("empty log file path") + } + log := &rollingLog{ + dir: filepath.Dir(coreInfo.FilePath), + nameTemplate: filepath.Base(coreInfo.FilePath), + flag: flag, + perm: perm, + maxSize: coreInfo.SingleSize * megabyte, + maxBackups: coreInfo.Threshold, + isUserLog: coreInfo.IsUserLog, + isWiseCloudAlarmLog: coreInfo.IsWiseCloudAlarmLog, + } + if log.maxBackups < 1 { + log.maxBackups = defaultBackups + } + if log.maxSize < megabyte { + log.maxSize = defaultFileSize * megabyte + } + if log.isUserLog { + return log, log.tidySinks() + } + extension := filepath.Ext(log.nameTemplate) + regExp := fmt.Sprintf(`^%s(?:(?:-|\.)\d*)?\%s$`, + log.nameTemplate[:len(log.nameTemplate)-len(extension)], extension) + reg, err := regexp.Compile(regExp) + if err != nil { + return nil, err + } + log.reg = reg + return log, log.tidySinks() +} + +func (r *rollingLog) tidySinks() error { + if r.isUserLog || r.file != nil { + return r.newSink() + } + // scan and reuse past log file when service restarted + r.scanLogFiles() + if len(r.sinks) > 0 { + fullName := r.sinks[len(r.sinks)-1] + info, err := os.Stat(fullName) + if err != nil || info.Size() >= r.maxSize { + return r.newSink() + } + file, err := os.OpenFile(fullName, r.flag, r.perm) + if err == nil { + r.file = file + r.size = info.Size() + return nil + } + } + return r.newSink() +} + +func (r *rollingLog) scanLogFiles() { + dirEntrys, err := os.ReadDir(r.dir) + if err != nil { + fmt.Printf("failed to read dir: %s\n", r.dir) + return + } + infos := make([]os.FileInfo, 0, r.maxBackups) + for _, entry := range dirEntrys { + if r.reg.MatchString(entry.Name()) { + info, err := entry.Info() + if err == nil { + infos = append(infos, info) + } + } + } + if len(infos) > 0 { + sort.Slice(infos, func(i, j int) bool { + return infos[i].ModTime().Before(infos[j].ModTime()) + }) + for i := range infos { + r.sinks = append(r.sinks, filepath.Join(r.dir, infos[i].Name())) + } + r.cleanRedundantSinks() + } +} + +func (r *rollingLog) cleanRedundantSinks() { + if len(r.sinks) < r.maxBackups { + return + } + curSinks := make([]string, 0, len(r.sinks)) + for _, name := range r.sinks { + if isAvailable(name) { + curSinks = append(curSinks, name) + } + + } + r.sinks = curSinks + sinkNum := len(r.sinks) + if sinkNum > r.maxBackups { + removes := r.sinks[:sinkNum-r.maxBackups] + go removeFiles(removes) + r.sinks = r.sinks[sinkNum-r.maxBackups:] + } + return +} + +func removeFiles(paths []string) { + for _, path := range paths { + err := os.Remove(path) + if err != nil && !os.IsNotExist(err) { + fmt.Printf("failed remove file %s\n", path) + } + } +} + +func isAvailable(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func (r *rollingLog) newSink() error { + fullName := filepath.Join(r.dir, r.newName()) + if isAvailable(fullName) && r.file != nil && r.file.Name() == filepath.Base(fullName) { + return errors.New("log file already opened: " + fullName) + } + file, err := os.OpenFile(fullName, r.flag, r.perm) + if err != nil { + return err + } + if r.file != nil { + err = r.file.Close() + } + if err != nil { + fmt.Printf("failed to close file: %s\n", err.Error()) + } + r.file = file + info, err := file.Stat() + if err != nil { + r.size = 0 + } else { + r.size = info.Size() + } + r.sinks = append(r.sinks, fullName) + r.cleanRedundantSinks() + if r.isUserLog { + logNameCache.Lock() + logNameCache.m[r.nameTemplate] = fullName + logNameCache.Unlock() + } + return nil +} + +func (r *rollingLog) newName() string { + if r.isWiseCloudAlarmLog { + timeNow := time.Now().Format("2006010215040506") + ext := filepath.Ext(r.nameTemplate) + return fmt.Sprintf("%s.%s%s", timeNow, r.nameTemplate[:len(r.nameTemplate)-len(ext)], ext) + } + if !r.isUserLog { + timeNow := time.Now().Format("2006010215040506") + ext := filepath.Ext(r.nameTemplate) + return fmt.Sprintf("%s.%s%s", r.nameTemplate[:len(r.nameTemplate)-len(ext)], timeNow, ext) + } + if r.file == nil { + return r.nameTemplate + } + timeNow := time.Now().Format("2006010215040506") + var prefix, suffix string + if index := strings.LastIndex(r.nameTemplate, "@") + 1; index <= len(r.nameTemplate) { + prefix = r.nameTemplate[:index] + } + if index := strings.Index(r.nameTemplate, "#"); index >= 0 { + suffix = r.nameTemplate[index:] + } + if prefix == "" || suffix == "" { + return "" + } + return fmt.Sprintf("%s%s%s", prefix, timeNow, suffix) +} + +// Write data to file and check whether to rotate log +func (r *rollingLog) Write(data []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r == nil || r.file == nil { + return 0, errors.New("log file is nil") + } + n, err := r.file.Write(data) + r.size += int64(n) + if r.size > r.maxSize { + r.tryRotate() + } + if syncErr := r.file.Sync(); syncErr != nil { + fmt.Printf("failed to sync log err: %s\n", syncErr.Error()) + } + return n, err +} + +func (r *rollingLog) tryRotate() { + if info, err := r.file.Stat(); err == nil && info.Size() < r.maxSize { + return + } + err := r.tidySinks() + if err != nil { + fmt.Printf("failed to rotate log err: %s\n", err.Error()) + } + return +} + +// GetLogName get current log name when refreshing user log mod time +func GetLogName(nameTemplate string) string { + logNameCache.Lock() + name := logNameCache.m[nameTemplate] + logNameCache.Unlock() + return name +} diff --git a/yuanrong/pkg/common/faas_common/logger/rollinglog_test.go b/yuanrong/pkg/common/faas_common/logger/rollinglog_test.go new file mode 100644 index 0000000..fcc5b79 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/rollinglog_test.go @@ -0,0 +1,147 @@ +package logger + +import ( + "errors" + "io/fs" + "os" + "reflect" + "strings" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/logger/config" +) + +type mockInfo struct { + name string + isDir bool + size int64 +} + +func (m mockInfo) Name() string { + return m.name +} + +func (m mockInfo) IsDir() bool { + return m.isDir +} + +func (m mockInfo) Type() fs.FileMode { + return 0 +} + +func (m mockInfo) Info() (fs.FileInfo, error) { + return m, nil +} + +func (m mockInfo) Size() int64 { + return m.size +} + +func (m mockInfo) Mode() fs.FileMode { + return 0 +} + +func (m mockInfo) ModTime() time.Time { + return time.Now() +} + +func (m mockInfo) Sys() interface{} { + return nil +} + +func Test_initRollingLog(t *testing.T) { + coreInfo := config.CoreInfo{ + FilePath: "./test-run.log", + } + defer gomonkey.ApplyFunc(os.ReadDir, func(string) ([]os.DirEntry, error) { + return []os.DirEntry{ + mockInfo{name: "test-run.2006010215040507.log"}, + mockInfo{name: "test-run.2006010215040508.log"}, + mockInfo{name: "{funcName}@ABCabc@latest@pool22-300-128-fusion-85c55c66d7-zzj9x@{timeNow}#{logGroupID}#{logStreamID}#cff-log.log"}, + }, nil + }).ApplyFunc(os.OpenFile, func(string, int, os.FileMode) (*os.File, error) { + return nil, nil + }).Reset() + convey.Convey("init service log", t, func() { + defer gomonkey.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return mockInfo{name: strings.TrimPrefix(name, "./")}, nil + }).Reset() + log, err := initRollingLog(coreInfo, os.O_WRONLY|os.O_APPEND|os.O_CREATE, defaultPerm) + convey.So(log, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("init user log", t, func() { + defer gomonkey.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return nil, &os.PathError{} + }).Reset() + coreInfo.FilePath = "{funcName}@ABCabc@latest@pool22-300-128-fusion-85c55c66d7-zzj9x@{timeNow}#{logGroupID}#{logStreamID}#cff-log.log" + coreInfo.IsUserLog = true + log, _ := initRollingLog(coreInfo, os.O_WRONLY|os.O_APPEND|os.O_CREATE, defaultPerm) + convey.So(log, convey.ShouldNotBeNil) + convey.So(GetLogName(coreInfo.FilePath), convey.ShouldNotBeEmpty) + }) + convey.Convey("init wisecloud alarm log", t, func() { + defer gomonkey.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return nil, &os.PathError{} + }).Reset() + coreInfo.FilePath = "{funcName}@ABCabc@latest@pool22-300-128-fusion-85c55c66d7-zzj9x@{timeNow}#{logGroupID}#{logStreamID}#cff-log.log" + coreInfo.IsWiseCloudAlarmLog = true + log, _ := initRollingLog(coreInfo, os.O_WRONLY|os.O_APPEND|os.O_CREATE, defaultPerm) + convey.So(log, convey.ShouldNotBeNil) + convey.So(GetLogName(coreInfo.FilePath), convey.ShouldNotBeEmpty) + }) +} + +func Test_rollingLog_Write(t *testing.T) { + log := &rollingLog{} + log.maxSize = 0 + log.isUserLog = true + log.file = &os.File{} + log.nameTemplate = "{funcName}@ABCabc@latest@pool22-300-128-fusion-85c55c66d7-zzj9x@{timeNow}#{logGroupID}#{logStreamID}#cff-log.log" + convey.Convey("write rolling log", t, func() { + convey.Convey("case1: failed to write rolling log", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(log.file), "Write", func(f *os.File, b []byte) (n int, err error) { + return len(b), nil + }).ApplyMethod(reflect.TypeOf(log.file), "Stat", func(f *os.File) (info os.FileInfo, err error) { + return mockInfo{size: 3}, nil + }).ApplyMethod(reflect.TypeOf(log.file), "Sync", func(f *os.File) error { + return nil + }).Reset() + n, err := log.Write([]byte("abc")) + convey.So(n, convey.ShouldEqual, 3) + convey.So(err, convey.ShouldBeNil) + }) + + convey.Convey("case2: failed to write rolling log", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(log.file), "Write", func(f *os.File, b []byte) (n int, err error) { + return len(b), nil + }).ApplyMethod(reflect.TypeOf(log.file), "Stat", func(f *os.File) (info os.FileInfo, err error) { + return mockInfo{size: 3}, nil + }).ApplyMethod(reflect.TypeOf(log.file), "Sync", func(f *os.File) error { + return errors.New("test") + }).Reset() + n, err := log.Write([]byte("abc")) + convey.So(n, convey.ShouldEqual, 3) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func Test_rollingLog_cleanRedundantSinks(t *testing.T) { + log := &rollingLog{} + log.maxBackups = 0 + tn := time.Now().String() + os.Create("test_log_1#" + tn) + os.Create("test_log_2#" + tn) + log.sinks = []string{"test_log_1#" + tn, "test_log_2#" + tn} + convey.Convey("rollingLog_cleanRedundantSinks", t, func() { + log.cleanRedundantSinks() + time.Sleep(50 * time.Millisecond) + convey.So(isAvailable("test_log_1#"+tn), convey.ShouldEqual, false) + convey.So(isAvailable("test_log_2#"+tn), convey.ShouldEqual, false) + }) +} diff --git a/yuanrong/pkg/common/faas_common/logger/zap/zaplog.go b/yuanrong/pkg/common/faas_common/logger/zap/zaplog.go new file mode 100644 index 0000000..fa95c45 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/zap/zaplog.go @@ -0,0 +1,160 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package zap zapper log +package zap + +import ( + "fmt" + "time" + + uberZap "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong/pkg/common/faas_common/logger" + "yuanrong/pkg/common/faas_common/logger/async" + "yuanrong/pkg/common/faas_common/logger/config" +) + +const ( + skipLevel = 1 +) + +func init() { + uberZap.RegisterEncoder("custom_console", logger.NewConsoleEncoder) +} + +// NewDevelopmentLog returns a development logger based on uber zap and it output entry to stdout and stderr +func NewDevelopmentLog() (*uberZap.Logger, error) { + cfg := uberZap.NewDevelopmentConfig() + cfg.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder + return cfg.Build() +} + +// NewConsoleLog returns a console logger based on uber zap +func NewConsoleLog() (*uberZap.Logger, error) { + outputPaths := []string{"stdout"} + cfg := uberZap.Config{ + Level: uberZap.NewAtomicLevelAt(uberZap.InfoLevel), + Development: false, + DisableCaller: false, + DisableStacktrace: true, + Encoding: "custom_console", + OutputPaths: outputPaths, + ErrorOutputPaths: outputPaths, + EncoderConfig: zapcore.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "Logger", + MessageKey: "M", + CallerKey: "C", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + }, + } + consoleLogger, err := cfg.Build() + if err != nil { + return nil, err + } + return consoleLogger.WithOptions(uberZap.AddCaller(), uberZap.AddCallerSkip(skipLevel)), nil +} + +// NewWithLevel returns a log based on zap with Level +func NewWithLevel(coreInfo config.CoreInfo, isAsync bool) (*uberZap.Logger, error) { + core, err := newCore(coreInfo, isAsync) + if err != nil { + return nil, err + } + + return uberZap.New(core, uberZap.AddCaller(), uberZap.AddCallerSkip(skipLevel)), nil +} + +func newCore(coreInfo config.CoreInfo, isAsync bool) (zapcore.Core, error) { + w, err := logger.CreateSink(coreInfo) + if err != nil { + return nil, err + } + + var syncer zapcore.WriteSyncer + if isAsync { + syncer = async.NewAsyncWriteSyncer(w) + } else { + syncer = zapcore.AddSync(w) + } + + encoderConfig := zapcore.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "Logger", + MessageKey: "M", + CallerKey: "C", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + + fileEncoder := logger.NewCustomEncoder(&encoderConfig) + + if err := config.LogLevel.UnmarshalText([]byte(coreInfo.Level)); err != nil { + config.LogLevel = zapcore.InfoLevel + } + + priority := uberZap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= config.LogLevel + }) + + if coreInfo.Tick == 0 || coreInfo.First == 0 || coreInfo.Thereafter == 0 { + return zapcore.NewCore(fileEncoder, syncer, priority), nil + } + return zapcore.NewSamplerWithOptions(zapcore.NewCore(fileEncoder, syncer, priority), + time.Duration(coreInfo.Tick)*time.Second, coreInfo.First, coreInfo.Thereafter), nil +} + +// LoggerWithFormat zap logger +type LoggerWithFormat struct { + *uberZap.Logger +} + +// Infof stdout format and paras +func (z *LoggerWithFormat) Infof(format string, paras ...interface{}) { + z.Logger.Info(fmt.Sprintf(format, paras...)) +} + +// Errorf stdout format and paras +func (z *LoggerWithFormat) Errorf(format string, paras ...interface{}) { + z.Logger.Error(fmt.Sprintf(format, paras...)) +} + +// Warnf stdout format and paras +func (z *LoggerWithFormat) Warnf(format string, paras ...interface{}) { + z.Logger.Warn(fmt.Sprintf(format, paras...)) +} + +// Debugf stdout format and paras +func (z *LoggerWithFormat) Debugf(format string, paras ...interface{}) { + if config.LogLevel > zapcore.DebugLevel { + return + } + z.Logger.Debug(fmt.Sprintf(format, paras...)) +} + +// Fatalf stdout format and paras +func (z *LoggerWithFormat) Fatalf(format string, paras ...interface{}) { + z.Logger.Fatal(fmt.Sprintf(format, paras...)) +} diff --git a/yuanrong/pkg/common/faas_common/logger/zap/zaplog_test.go b/yuanrong/pkg/common/faas_common/logger/zap/zaplog_test.go new file mode 100644 index 0000000..96e9c9e --- /dev/null +++ b/yuanrong/pkg/common/faas_common/logger/zap/zaplog_test.go @@ -0,0 +1,183 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package zap + +import ( + "fmt" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + uberZap "go.uber.org/zap" + + "yuanrong/pkg/common/faas_common/logger/config" +) + +// TestNewDvelopmentLog Test New Dvelopment Log +func TestNewDvelopmentLog(t *testing.T) { + if _, err := NewDevelopmentLog(); err != nil { + t.Errorf("NewDevelopmentLog() = %q, wants *logger", err) + } +} + +func TestNewConsoleLog(t *testing.T) { + tests := []struct { + name string + want *uberZap.Logger + wantErr bool + }{ + {"case1", nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewConsoleLog() + if (err != nil) != tt.wantErr { + t.Errorf("NewConsoleLog() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestNewWithLevel(t *testing.T) { + type args struct { + coreInfo config.CoreInfo + isAsync bool + } + var a args + tests := []struct { + name string + args args + want *uberZap.Logger + wantErr bool + }{ + {"case1", a, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewWithLevel(tt.args.coreInfo, tt.args.isAsync) + if (err != nil) != tt.wantErr { + t.Errorf("NewWithLevel() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewWithLevel() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLoggerWithFormat_Infof(t *testing.T) { + type fields struct { + Logger *uberZap.Logger + } + type args struct { + format string + paras []interface{} + } + coreInfo := config.CoreInfo{ + FilePath: "tmp", + Level: "DEBUG", + Tick: 0, + First: 0, + Thereafter: 0, + Tracing: false, + Disable: false, + } + logger, err := NewWithLevel(coreInfo, true) + if err != nil { + fmt.Println(err) + } + var f fields + f.Logger = logger + var a args + tests := []struct { + name string + fields fields + args args + }{ + {"case1", f, a}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + z := &LoggerWithFormat{ + Logger: tt.fields.Logger, + } + z.Infof(tt.args.format, tt.args.paras...) + }) + } +} + +func TestNewCoreWithDebugLevel(t *testing.T) { + convey.Convey("TestNewCoreWithInfoLevel", t, func() { + coreInfo := config.CoreInfo{ + FilePath: "tmp", + Level: "DEBUG", + Tick: 0, + First: 0, + Thereafter: 0, + Tracing: false, + Disable: false, + } + logger, err := NewWithLevel(coreInfo, true) + convey.So(logger, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + type fields struct { + Logger *uberZap.Logger + } + z := &LoggerWithFormat{ + Logger: logger, + } + cnt := 0 + gomonkey.ApplyMethod(reflect.TypeOf(logger), "Debug", + func(log *uberZap.Logger, msg string, fields ...uberZap.Field) { + cnt += 1 + }) + z.Debugf("should print") + convey.So(cnt, convey.ShouldEqual, 1) + }) +} + +func TestNewCoreWithInfoLevel(t *testing.T) { + convey.Convey("TestNewCoreWithInfoLevel", t, func() { + coreInfo := config.CoreInfo{ + FilePath: "tmp", + Level: "INFO", + Tick: 0, + First: 0, + Thereafter: 0, + Tracing: false, + Disable: false, + } + logger, err := NewWithLevel(coreInfo, true) + convey.So(logger, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + type fields struct { + Logger *uberZap.Logger + } + z := &LoggerWithFormat{ + Logger: logger, + } + cnt := 0 + gomonkey.ApplyMethod(reflect.TypeOf(logger), "Debug", + func(log *uberZap.Logger, msg string, fields ...uberZap.Field) { + cnt += 1 + }) + z.Debugf("should not print") + convey.So(cnt, convey.ShouldEqual, 0) + }) +} diff --git a/yuanrong/pkg/common/faas_common/monitor/defaultfilewatcher.go b/yuanrong/pkg/common/faas_common/monitor/defaultfilewatcher.go new file mode 100644 index 0000000..422a91e --- /dev/null +++ b/yuanrong/pkg/common/faas_common/monitor/defaultfilewatcher.go @@ -0,0 +1,176 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package monitor provide memory and file monitor +package monitor + +import ( + "crypto/sha256" + "encoding/hex" + "io/ioutil" + "os" + "path/filepath" + "time" + + "github.com/fsnotify/fsnotify" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +var ( + hashRetry = 60 +) + +type defaultFileWatcher struct { + watcher *fsnotify.Watcher + filename string + callback FileChangedCallback + hash string + stopCh <-chan struct{} +} + +func createDefaultFileWatcher(stopCh <-chan struct{}) (FileWatcher, error) { + fsWatcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + w := &defaultFileWatcher{ + watcher: fsWatcher, + stopCh: stopCh, + } + return w, nil +} + +// RegisterCallback impl +func (w *defaultFileWatcher) RegisterCallback(filename string, callback FileChangedCallback) { + realPath, err := w.getRealPath(filename) + if err != nil { + log.GetLogger().Errorf("filename %s getRealPath failed err %s", filename, err.Error()) + return + } + + if callback == nil { + log.GetLogger().Errorf("filename %s callback is nil", filename) + return + } + + hash := w.getFileHashRetry(filename) + w.filename = filename + w.callback = callback + w.hash = hash + if err := w.watcher.Add(realPath); err != nil { + log.GetLogger().Warnf("watch file %s failed", filename) + } else { + log.GetLogger().Infof("file %s RegisterCallback, success", filename) + } +} + +func (w *defaultFileWatcher) getRealPath(filename string) (string, error) { + realPath, err := filepath.EvalSymlinks(filename) + if err != nil { + return "", err + } + + if _, err := os.Stat(realPath); err != nil { + return "", err + } + return realPath, nil +} + +func (w *defaultFileWatcher) handleFileRemove(event fsnotify.Event) { + // remove old watcher + w.watcher.Remove(event.Name) + w.watcher.Remove(w.filename) + + // re-add new watcher + realPath, err := w.getRealPath(w.filename) + if err != nil { + log.GetLogger().Warnf("filename %s getRealPath failed err %s", w.filename, err.Error()) + } else { + if err := w.watcher.Add(realPath); err != nil { + log.GetLogger().Warnf("re-add watcher %s failed", realPath) + } else { + log.GetLogger().Infof("re-add watcher %s success", realPath) + } + } + + if err := w.watcher.Add(w.filename); err != nil { + log.GetLogger().Warnf("re-add watcher %s failed", w.filename) + } else { + log.GetLogger().Infof("re-add watcher %s success", w.filename) + } +} + +// Start impl +func (w *defaultFileWatcher) Start() { + for { + select { + case event, ok := <-w.watcher.Events: + if !ok { + log.GetLogger().Errorf("watcher event chan not ok") + continue + } + w.invokeCallback(event) + if event.Op == fsnotify.Remove { + w.handleFileRemove(event) + } + case err, ok := <-w.watcher.Errors: + if !ok { + log.GetLogger().Errorf("errors chan not ok, err %s", err.Error()) + } + case <-w.stopCh: + w.watcher.Close() + return + } + } +} + +func (w *defaultFileWatcher) invokeCallback(event fsnotify.Event) { + newHash := w.getFileHashRetry(w.filename) + if newHash != w.hash { + begin := time.Now() + log.GetLogger().Infof("file event %s happen, start invoke callback", event.String()) + w.hash = newHash + w.callback(w.filename, OpType(event.Op)) + log.GetLogger().Infof("file event %s invoke callback success, cost %v", + event.String(), time.Since(begin)) + } +} + +func (w *defaultFileWatcher) getFileHashRetry(filename string) string { + for i := 0; i < hashRetry; i++ { + hash := w.getFileHash(filename) + if len(hash) > 0 { + return hash + } + time.Sleep(1 * time.Second) + } + return "" +} + +func (w *defaultFileWatcher) getFileHash(filename string) string { + content, err := ioutil.ReadFile(filename) + if err != nil { + return "" + } + hash := sha256.New() + _, err = hash.Write(content) + if err != nil { + return "" + } + return hex.EncodeToString(hash.Sum(nil)) +} diff --git a/yuanrong/pkg/common/faas_common/monitor/defaultfilewatcher_test.go b/yuanrong/pkg/common/faas_common/monitor/defaultfilewatcher_test.go new file mode 100644 index 0000000..7fea059 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/monitor/defaultfilewatcher_test.go @@ -0,0 +1,90 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package monitor + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/fsnotify/fsnotify" + "github.com/stretchr/testify/assert" +) + +func TestDefaultFileWatcherStart(t *testing.T) { + eventChan := make(chan fsnotify.Event) + stopCh := make(chan struct{}) + errorChan := make(chan error) + mockWatcher := &fsnotify.Watcher{ + Events: eventChan, + Errors: errorChan, + } + + invokeCallbackCh := make(chan bool) + closeCh := make(chan bool) + + watcher := &defaultFileWatcher{ + watcher: mockWatcher, + filename: "/path/testfile.txt", + stopCh: stopCh, + callback: func(filename string, opType OpType) { + invokeCallbackCh <- true + }, + } + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(ioutil.ReadFile, func(filename string) ([]byte, error) { + return []byte("mock content"), nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(watcher.watcher), "Remove", func(_ *fsnotify.Watcher, _ string) error { + return nil + }), + gomonkey.ApplyFunc(filepath.EvalSymlinks, func(path string) (string, error) { + return "/mock/symlink/path", nil + }), + gomonkey.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return nil, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(watcher.watcher), "Add", func(_ *fsnotify.Watcher, _ string) error { + return nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(watcher.watcher), "Close", func(_ *fsnotify.Watcher) error { + closeCh <- true + return nil + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + + go watcher.Start() + + errorChan <- fmt.Errorf("err") + eventChan <- fsnotify.Event{Name: "/path/test1.txt", Op: fsnotify.Remove} + + invokeCallback := <-invokeCallbackCh + close(stopCh) + + assert.Equal(t, invokeCallback, true) + isClosed := <-closeCh + assert.Equal(t, isClosed, true) +} diff --git a/yuanrong/pkg/common/faas_common/monitor/filewatcher.go b/yuanrong/pkg/common/faas_common/monitor/filewatcher.go new file mode 100644 index 0000000..076fced --- /dev/null +++ b/yuanrong/pkg/common/faas_common/monitor/filewatcher.go @@ -0,0 +1,77 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package monitor provide memory and file monitor +package monitor + +import ( + "errors" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// OpType describes file operation type +type OpType uint32 + +const ( + // Create op type + Create OpType = 1 << iota + // Write op type + Write + // Remove op type + Remove + // Rename op type + Rename + // Chmod op type + Chmod +) + +var ( + creator Creator = createDefaultFileWatcher +) + +// FileChangedCallback describes callback function, when file changed, callback function will be invoked +type FileChangedCallback func(filename string, opType OpType) + +// Creator describes watcher create function +type Creator func(stopCh <-chan struct{}) (FileWatcher, error) + +// FileWatcher describes interface of general FileWatcher +type FileWatcher interface { + Start() + RegisterCallback(filename string, callback FileChangedCallback) +} + +// SetCreator set file watcher creator func, if not set, use createDefaultFileWatcher +func SetCreator(newCreator Creator) { + creator = newCreator +} + +// CreateFileWatcher create a file watcher +// notice: one FileWatcher can only watcher one file +func CreateFileWatcher(stopCh <-chan struct{}) (FileWatcher, error) { + watcher, err := creator(stopCh) + if err != nil { + log.GetLogger().Errorf("create watcher failed %s", err.Error()) + return nil, err + } + if watcher == nil { + log.GetLogger().Errorf("watcher is nil") + return nil, errors.New("watcher is nil") + } + go watcher.Start() + return watcher, nil +} diff --git a/yuanrong/pkg/common/faas_common/monitor/filewatcher_test.go b/yuanrong/pkg/common/faas_common/monitor/filewatcher_test.go new file mode 100644 index 0000000..73ccf97 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/monitor/filewatcher_test.go @@ -0,0 +1,116 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package monitor + +import ( + "errors" + "fmt" + "os" + "strings" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" +) + +func buildTestFile() string { + path, _ := os.Getwd() + if strings.Contains(path, "\\") { + path = path + "\\test.json" + } else { + path = path + "/test.json" + } + + return path +} + +func TestCreateFileWatcher(t *testing.T) { + convey.Convey("TestCreateFileWatcher error", t, func() { + defer gomonkey.ApplyFunc(createDefaultFileWatcher, func(stopCh <-chan struct{}) (FileWatcher, error) { + return nil, fmt.Errorf("fsnotify.NewWatcher error") + }).Reset() + stopCh := make(chan struct{}, 1) + _, err := CreateFileWatcher(stopCh) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestInitFileWatcher(t *testing.T) { + convey.Convey("TestInitFileWatcher", t, func() { + stopCh := make(chan struct{}, 1) + watcher, err := CreateFileWatcher(stopCh) + convey.So(err, convey.ShouldBeNil) + + filename := buildTestFile() + handler, _ := os.Create(filename) + defer func() { + handler.Close() + os.Remove(filename) + }() + callbackChan := make(chan int, 5) + watcher.RegisterCallback("", nil) + watcher.RegisterCallback(filename, func(filename string, t OpType) { + callbackChan <- 1 + }) + + os.WriteFile(filename, []byte{'a'}, os.ModePerm) + res := <-callbackChan + convey.So(res, convey.ShouldBeGreaterThan, 0) + time.Sleep(5 * time.Millisecond) + close(stopCh) + }) +} + +func TestInitFileWatcherWithInvalidCreator(t *testing.T) { + convey.Convey("TestInitFileWatcherWithInvalidCreator", t, func() { + defer SetCreator(createDefaultFileWatcher) + SetCreator(func(stopCh <-chan struct{}) (FileWatcher, error) { + return nil, errors.New("error") + }) + + stopCh := make(chan struct{}, 1) + _, err := CreateFileWatcher(stopCh) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestFileWatcher_Start(t *testing.T) { + convey.Convey("TestFileWatcher_Start", t, func() { + stopCh := make(chan struct{}, 1) + watcher, _ := CreateFileWatcher(stopCh) + filename := "./TestFileWatcher_Start.tmp" + f, _ := os.Create(filename) + f.Close() + tmp := hashRetry + hashRetry = 1 + defer func() { + hashRetry = tmp + }() + callbackChan := make(chan int, 1) + watcher.RegisterCallback(filename, func(filename string, t OpType) { + callbackChan <- 1 + }) + os.Remove(filename) + time.AfterFunc(3*time.Second, func() { + close(callbackChan) + }) + convey.So(<-callbackChan, convey.ShouldEqual, 1) + time.Sleep(50 * time.Millisecond) + close(stopCh) + }) +} diff --git a/yuanrong/pkg/common/faas_common/monitor/memory.go b/yuanrong/pkg/common/faas_common/monitor/memory.go new file mode 100644 index 0000000..e4f95ee --- /dev/null +++ b/yuanrong/pkg/common/faas_common/monitor/memory.go @@ -0,0 +1,353 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package monitor monitors and controls resource usage +package monitor + +import ( + "io/ioutil" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" +) + +// MemMonitor monitor memory usage +type MemMonitor interface { + // Allow returns whether you can take some memory to use (in bytes) + Allow(uint64) bool + // AllowByLowerThreshold - + AllowByLowerThreshold(string, string, uint64) bool + // ReleaseFunctionMem function mem when request finished + ReleaseFunctionMem(urn string, size uint64) +} + +const ( + defaultMemoryRefreshInterval = 50 + highMemoryPercent = 0.9 + statefulHighMemPercent = 0.9 + base = 10 + bitSize = 64 + lowerMemoryPercent = 0.7 + bodyThreshold = 10000 + zero = 0 + defaultFuncNum = 1 +) + +var ( + memory = struct { + sync.Once + monitor *memMonitor + err error + }{ + monitor: &memMonitor{}, + } + mu sync.Mutex +) + +var ( + config = &types.MemoryControlConfig{ + LowerMemoryPercent: lowerMemoryPercent, + HighMemoryPercent: highMemoryPercent, + StatefulHighMemPercent: statefulHighMemPercent, + BodyThreshold: bodyThreshold, + MemDetectIntervalMs: defaultMemoryRefreshInterval, + } +) + +type memMonitor struct { + enable bool + used uint64 + threshold uint64 + statefulThreshold uint64 + stopCh <-chan struct{} + lowerThreshold uint64 + memMapMutex sync.Mutex + functionMemMap map[string]uint64 + totalMemCnt uint64 + isStateful bool +} + +// SetMemoryControlConfig set memory control config from different service +func SetMemoryControlConfig(memoryControlConfig *types.MemoryControlConfig) { + if memoryControlConfig == nil { + return + } + if memoryControlConfig.LowerMemoryPercent > 0 { + config.LowerMemoryPercent = memoryControlConfig.LowerMemoryPercent + } + if memoryControlConfig.BodyThreshold > 0 { + config.BodyThreshold = memoryControlConfig.BodyThreshold + } + if memoryControlConfig.MemDetectIntervalMs > 0 { + config.MemDetectIntervalMs = memoryControlConfig.MemDetectIntervalMs + } + if memoryControlConfig.HighMemoryPercent > 0 { + config.HighMemoryPercent = memoryControlConfig.HighMemoryPercent + } + if memoryControlConfig.StatefulHighMemPercent > 0 { + config.StatefulHighMemPercent = memoryControlConfig.StatefulHighMemPercent + } + log.GetLogger().Infof("LowerMemoryPercent %f, HighMemoryPercent %f, "+ + "StatefulHighMemPercent %f, BodyThreshold %d, MemDetectIntervalMs %d", + config.LowerMemoryPercent, config.HighMemoryPercent, + config.StatefulHighMemPercent, config.BodyThreshold, config.MemDetectIntervalMs) + + if memory.monitor != nil { + memory.monitor.updateConfig() + } +} + +// InitMemMonitor initialize global memory monitor +func InitMemMonitor(stopCh <-chan struct{}) error { + memory.Do(func() { + memory.err = memory.monitor.init(stopCh) + }) + return memory.err +} + +// GetMemInstance returns global memory monitor +func GetMemInstance() MemMonitor { + return memory.monitor +} + +func readValue(path string) (uint64, error) { + v, err := ioutil.ReadFile(path) + if err != nil { + return 0, err + } + return parseValue(strings.TrimSpace(string(v)), base, bitSize) +} + +func parseValue(s string, base, bitSize int) (uint64, error) { + v, err := strconv.ParseUint(s, base, bitSize) + if err != nil { + intValue, intErr := strconv.ParseInt(s, base, bitSize) + if intErr == nil && intValue < 0 { + return 0, nil + } + if intErr != nil && + intErr.(*strconv.NumError).Err == strconv.ErrRange && + intValue < 0 { + return 0, nil + } + return 0, err + } + return v, nil +} + +// refresh actual memory usage +func (m *memMonitor) refreshActualMemoryUsage() { + interval := config.MemDetectIntervalMs + parser, err := NewCGroupMemoryParser() + if err != nil { + log.GetLogger().Warnf("failed to create cgroup memory parser: %s", err.Error()) + return + } + defer parser.Close() + ticker := time.NewTicker(time.Duration(interval) * time.Millisecond) + for { + select { + case <-ticker.C: + val, err := parser.Read() + if err != nil { + log.GetLogger().Errorf("GetSystemMemoryUsed failed, err: %s", err.Error()) + continue + } + used, ok := val.(uint64) + if !ok { + log.GetLogger().Errorf("GetSystemMemoryUsed failed, err: failed to assert parser data") + continue + } + atomic.StoreUint64(&m.used, used) + if interval != config.MemDetectIntervalMs { + log.GetLogger().Infof("MemDetectIntervalMs updated, old: %d, new: %d, reset timer", + interval, config.MemDetectIntervalMs) + interval = config.MemDetectIntervalMs + ticker.Reset(time.Duration(interval) * time.Millisecond) + } + case <-m.stopCh: + log.GetLogger().Info("memory monitor stopped") + ticker.Stop() + return + } + } +} + +func (m *memMonitor) init(stopCh <-chan struct{}) error { + memLimit, err := readValue("/sys/fs/cgroup/memory/memory.limit_in_bytes") + if err != nil { + log.GetLogger().Warn("failed to read limit_in_bytes") + return nil + } + m.threshold = uint64(float64(memLimit) * config.HighMemoryPercent) + m.statefulThreshold = uint64(float64(memLimit) * config.StatefulHighMemPercent) + m.enable = true + m.memMapMutex = sync.Mutex{} + m.functionMemMap = map[string]uint64{} + m.lowerThreshold = uint64(float64(memLimit) * config.LowerMemoryPercent) + log.GetLogger().Infof("memory threshold is %d, stateful memory threshold is %d, lowerThreshold is %d", + m.threshold, m.statefulThreshold, m.lowerThreshold) + m.stopCh = stopCh + go m.refreshActualMemoryUsage() + return nil +} + +// Allow returns whether you can take some memory to use (in bytes) +func (m *memMonitor) Allow(want uint64) bool { + if !m.enable { + return true + } + for { + threshold := m.threshold + if m.isStateful { + threshold = m.statefulThreshold + } + current := atomic.LoadUint64(&m.used) + if current > threshold || want > threshold-current { + log.GetLogger().Errorf("memory threshold triggered, current=%d want=%d threshold=%d", + current, want, threshold) + return false + } + if atomic.CompareAndSwapUint64(&m.used, current, current+want) { + return true + } + } +} + +func (m *memMonitor) increaseMemCnt(size uint64) { + m.totalMemCnt += size +} + +func (m *memMonitor) decreaseMemCnt(size uint64) { + if m.totalMemCnt < size { + log.GetLogger().Warnf("invalid mem cnt %d, size %d", m.totalMemCnt, size) + m.totalMemCnt = 0 + } else { + m.totalMemCnt -= size + } +} + +// ReleaseFunctionMem release function mem when function req finished +func (m *memMonitor) ReleaseFunctionMem(urn string, size uint64) { + if !m.enable || size <= config.BodyThreshold { + return + } + + m.memMapMutex.Lock() + defer m.memMapMutex.Unlock() + + memUsed, ok := m.functionMemMap[urn] + if !ok { + return + } + + m.decreaseMemCnt(size) + if memUsed <= size { + delete(m.functionMemMap, urn) + } else { + m.functionMemMap[urn] = memUsed - size + } +} + +// mallocFunctionMem malloc function mem when function req enter +func (m *memMonitor) mallocFunctionMem(urn string, realSize uint64) { + m.increaseMemCnt(realSize) + memUsed, ok := m.functionMemMap[urn] + if !ok { + m.functionMemMap[urn] = realSize + } else { + m.functionMemMap[urn] = memUsed + realSize + } +} + +// AllowByLowerThreshold control memory use by LowerThreshold +// if used memory > LowerThreshold and function mem use > average, this function just return heavy load +func (m *memMonitor) AllowByLowerThreshold(urn string, traceID string, size uint64) bool { + if !m.enable || size <= config.BodyThreshold { + return true + } + + m.memMapMutex.Lock() + defer m.memMapMutex.Unlock() + // if current mem lower than lowerThreshold, allow + current := atomic.LoadUint64(&m.used) + if current <= m.lowerThreshold && m.totalMemCnt <= m.lowerThreshold { + m.mallocFunctionMem(urn, size) + return true + } + + memUsed, ok := m.functionMemMap[urn] + // if it's new function, allow + if !ok { + m.increaseMemCnt(size) + m.functionMemMap[urn] = size + return true + } + + functionNum := uint64(len(m.functionMemMap)) + if functionNum <= zero { + functionNum = defaultFuncNum + } + // if function use mem lower than averageMem allow + averageMem := m.totalMemCnt / functionNum + if memUsed <= averageMem { + m.increaseMemCnt(size) + m.functionMemMap[urn] = memUsed + size + return true + } + + log.GetLogger().Errorf("lower memory threshold triggered, currentFromSys=%d,currentFromEvaluator=%d,"+ + "lowerThreshold=%d,functionUsed=%d,functionNum=%d,traceID=%s,bodyLength=%d", + current, m.totalMemCnt, m.lowerThreshold, memUsed, functionNum, traceID, size) + return false +} + +func (m *memMonitor) updateConfig() { + memLimit, err := readValue("/sys/fs/cgroup/memory/memory.limit_in_bytes") + if err != nil { + log.GetLogger().Warn("failed to read limit_in_bytes") + return + } + + m.threshold = uint64(float64(memLimit) * config.HighMemoryPercent) + m.statefulThreshold = uint64(float64(memLimit) * config.StatefulHighMemPercent) + m.lowerThreshold = uint64(float64(memLimit) * config.LowerMemoryPercent) + + log.GetLogger().Infof("config updated, memory threshold is %d, stateful memory threshold is %d,lowerThreshold is %d", + m.threshold, m.statefulThreshold, m.lowerThreshold) +} + +// IsAllowByMemory returns whether you can take some memory to use +func IsAllowByMemory(urn string, memoryWant uint64, traceID string) bool { + if !GetMemInstance().Allow(memoryWant) { + log.GetLogger().Errorf("request is limited by higher threshold, urn %s traceID %s want %d", + urn, traceID, memoryWant) + return false + } + + if !GetMemInstance().AllowByLowerThreshold(urn, traceID, memoryWant) { + log.GetLogger().Errorf("request is limited by lower threshold, urn %s traceID %s want %d", + urn, traceID, memoryWant) + return false + } + + return true +} diff --git a/yuanrong/pkg/common/faas_common/monitor/memory_test.go b/yuanrong/pkg/common/faas_common/monitor/memory_test.go new file mode 100644 index 0000000..2937b52 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/monitor/memory_test.go @@ -0,0 +1,164 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package monitor + +import ( + "reflect" + "sync" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/types" +) + +func TestInitMemMonitor(t *testing.T) { + convey.Convey("TestInitMemMonitor", t, func() { + convey.Convey("success", func() { + patches := [...]*Patches{ + ApplyMethod(reflect.TypeOf(new(Parser)), "Read", func(_ *Parser) (interface{}, error) { + return uint64(100), nil + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + stopCh := make(chan struct{}) + err := InitMemMonitor(stopCh) + assert.Nil(t, err) + assert.Equal(t, uint64(0x0), memory.monitor.used) + + time.Sleep(2 * time.Second) + assert.NotEqual(t, uint64(0x0), memory.monitor.used) + }) + }) + +} + +func TestMemMonitor_Allow(t *testing.T) { + memMonitor := &memMonitor{enable: true, threshold: 1024, used: 10} + result := memMonitor.Allow(1000) + assert.Equal(t, true, result) + result = memMonitor.Allow(15) + assert.Equal(t, false, result) +} + +func TestAllowByLowerThreshold(t *testing.T) { + memMonitor := &memMonitor{ + enable: true, + threshold: 200000, + lowerThreshold: 140000, + memMapMutex: sync.Mutex{}, + functionMemMap: map[string]uint64{}, + } + + allow := memMonitor.Allow(100) + assert.Equal(t, true, allow) + allow = memMonitor.AllowByLowerThreshold("1", "1", 100) + assert.Equal(t, true, allow) + + allow = memMonitor.Allow(100000) + assert.Equal(t, true, allow) + allow = memMonitor.AllowByLowerThreshold("1", "1", 100000) + assert.Equal(t, true, allow) + + allow = memMonitor.Allow(20000) + assert.Equal(t, true, allow) + allow = memMonitor.AllowByLowerThreshold("2", "2", 20000) + assert.Equal(t, true, allow) + + allow = memMonitor.Allow(30000) + assert.Equal(t, true, allow) + allow = memMonitor.AllowByLowerThreshold("1", "1", 30000) + assert.Equal(t, false, allow) + + memMonitor.ReleaseFunctionMem("1", 100000) + memMonitor.used = memMonitor.used - 100000 + + allow = memMonitor.Allow(100000) + assert.Equal(t, true, allow) + allow = memMonitor.AllowByLowerThreshold("1", "1", 100000) + assert.Equal(t, true, allow) + + allow = memMonitor.Allow(20000) + assert.Equal(t, true, allow) + allow = memMonitor.AllowByLowerThreshold("2", "2", 20000) + assert.Equal(t, true, allow) +} + +func TestSetMemoryControlConfig(t *testing.T) { + convey.Convey("TestSetMemoryControlConfig", t, func() { + convey.Convey("nil config", func() { + SetMemoryControlConfig(nil) + }) + convey.Convey("SetMemoryControlConfig", func() { + cfg := &types.MemoryControlConfig{ + LowerMemoryPercent: 0.5, + BodyThreshold: 1024, + MemDetectIntervalMs: 3, + HighMemoryPercent: 0.5, + StatefulHighMemPercent: 0.9, + } + SetMemoryControlConfig(cfg) + convey.So(*config == *cfg, convey.ShouldEqual, true) + }) + }) +} + +func Test_parseValue(t *testing.T) { + convey.Convey("parseValue", t, func() { + v, err := parseValue("100", 10, 64) + convey.So(v, convey.ShouldEqual, 100) + convey.So(err, convey.ShouldBeNil) + v, err = parseValue("-100", 10, 64) + convey.So(v, convey.ShouldEqual, 0) + convey.So(err, convey.ShouldBeNil) + v, err = parseValue("1.01", 10, 64) + convey.So(v, convey.ShouldEqual, 0) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestIsAllowByMemory(t *testing.T) { + convey.Convey("TestIsAllowByMemory", t, func() { + memory.monitor = &memMonitor{ + enable: true, + threshold: 200000, + lowerThreshold: 140000, + memMapMutex: sync.Mutex{}, + functionMemMap: map[string]uint64{}, + } + memory.monitor.decreaseMemCnt(100) + + allow := IsAllowByMemory("1", 200001, "") + convey.So(allow, convey.ShouldBeFalse) + + allow = IsAllowByMemory("1", 100000, "") + convey.So(allow, convey.ShouldBeTrue) + + allow = IsAllowByMemory("2", 50000, "") + convey.So(allow, convey.ShouldBeTrue) + + allow = IsAllowByMemory("1", 20000, "") + convey.So(allow, convey.ShouldBeFalse) + }) +} diff --git a/yuanrong/pkg/common/faas_common/monitor/mockfilewatcher.go b/yuanrong/pkg/common/faas_common/monitor/mockfilewatcher.go new file mode 100644 index 0000000..97340d9 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/monitor/mockfilewatcher.go @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package monitor + +// MockFileWatcher - +type MockFileWatcher struct { + Callbacks map[string]FileChangedCallback + StopCh <-chan struct{} + EventChan chan string +} + +// Start - +func (watcher *MockFileWatcher) Start() { + for { + select { + case event, _ := <-watcher.EventChan: + callback, _ := watcher.Callbacks[event] + callback(event, Write) + case <-watcher.StopCh: + return + } + } +} + +// RegisterCallback - +func (watcher *MockFileWatcher) RegisterCallback(filename string, callback FileChangedCallback) { + watcher.Callbacks[filename] = callback +} diff --git a/yuanrong/pkg/common/faas_common/monitor/mockfilewatcher_test.go b/yuanrong/pkg/common/faas_common/monitor/mockfilewatcher_test.go new file mode 100644 index 0000000..f2f5f5c --- /dev/null +++ b/yuanrong/pkg/common/faas_common/monitor/mockfilewatcher_test.go @@ -0,0 +1,51 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package monitor + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestMockFileWatcherStart(t *testing.T) { + stopCh := make(chan struct{}) + eventChan := make(chan string) + + watcher := &MockFileWatcher{ + EventChan: eventChan, + Callbacks: make(map[string]FileChangedCallback), + StopCh: stopCh, + } + + callbackCalled := false + watcher.RegisterCallback("test_event", func(filename string, opType OpType) { + assert.Equal(t, "test_event", filename) + assert.Equal(t, Write, opType) + callbackCalled = true + }) + + go watcher.Start() + + watcher.EventChan <- "test_event" + + assert.Eventually(t, func() bool { return callbackCalled }, 1*time.Second, 10*time.Millisecond, + "Callback function should be called") + + close(stopCh) +} diff --git a/yuanrong/pkg/common/faas_common/monitor/parser.go b/yuanrong/pkg/common/faas_common/monitor/parser.go new file mode 100644 index 0000000..cf293e4 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/monitor/parser.go @@ -0,0 +1,110 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package monitor + +import ( + "bufio" + "bytes" + "io" + "os" + "strconv" +) + +const ( + cgroupMemoryPath = "/sys/fs/cgroup/memory/memory.stat" +) + +var ( + rssPrefix = []byte("rss ") +) + +// NewCGroupMemoryParser creates parser of /sys/fs/cgroup/memory/memory.stat +func NewCGroupMemoryParser() (*Parser, error) { + return NewParser(cgroupMemoryPath, cgroupMemoryParserFunc) +} + +var cgroupMemoryParserFunc = func(reader *bufio.Reader) (interface{}, error) { + for { + lineBytes, _, err := reader.ReadLine() + if err != nil { + return uint64(0), err + } + + if bytes.HasPrefix(lineBytes, rssPrefix) { + lineBytes = bytes.TrimSpace(lineBytes[len(rssPrefix):]) + return strconv.ParseUint(string(lineBytes), base, bitSize) + } + } +} + +// ParserFunc func that parser content of reader to uint64 +type ParserFunc func(reader *bufio.Reader) (interface{}, error) + +// NewParser creates new Parser with file path and ParserFunc +func NewParser(path string, parserFunc ParserFunc) (*Parser, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + return &Parser{ + f: f, + reader: bufio.NewReader(nil), + parser: parserFunc, + }, nil +} + +type nopCloser struct { + io.ReadSeeker +} + +// Close does nothing. It wraps io.ReadSeeker to io.ReadSeekCloser +func (nopCloser) Close() error { return nil } + +// NewReadSeekerParser creates new Parser with io.ReadSeeker and ParserFunc +func NewReadSeekerParser(reader io.ReadSeeker, parserFunc ParserFunc) *Parser { + return &Parser{ + f: nopCloser{reader}, + reader: bufio.NewReader(nil), + parser: parserFunc, + } +} + +// Parser aims to parse file content that updated frequently (such as cgroup file) with high performance. +// It opens file only once and seek to start every time before read. +// NOTICE: Parser is not thread safe +type Parser struct { + reader *bufio.Reader + f io.ReadSeekCloser + parser ParserFunc +} + +// Close closes file to parse +func (p *Parser) Close() error { + p.reader.Reset(nil) + return p.f.Close() +} + +// Read resets reader to the start of the file and parses it. +// This method is not thread safe +func (p *Parser) Read() (interface{}, error) { + _, err := p.f.Seek(0, io.SeekStart) + if err != nil { + return uint64(0), err + } + p.reader.Reset(p.f) + return p.parser(p.reader) +} diff --git a/yuanrong/pkg/common/faas_common/monitor/parser_test.go b/yuanrong/pkg/common/faas_common/monitor/parser_test.go new file mode 100644 index 0000000..2ef7c65 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/monitor/parser_test.go @@ -0,0 +1,95 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package monitor + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewReadSeekerParser(t *testing.T) { + tests := []struct { + name string + content []byte + parser ParserFunc + hasError bool + expected uint64 + }{ + { + name: "parse cgroup memory", + content: []byte(`cache 10150707200 +rss 880640 +rss_huge 0 +shmem 0 +mapped_file 946176 +dirty 135168 +writeback 270336 +swap 0 +pgpgin 3158595 +pgpgout 680215 +pgfault 992277 +pgmajfault 0 +inactive_anon 0 +active_anon 0 +inactive_file 8343023616 +active_file 1808744448 +unevictable 0 +hierarchical_memory_limit 9223372036854771712 +hierarchical_memsw_limit 9223372036854771712 +total_cache 21492334592 +total_rss 9384980480 +total_rss_huge 5515509760 +total_shmem 654385152 +total_mapped_file 2744586240 +total_dirty 8110080 +total_writeback 2027520 +total_swap 0 +total_pgpgin 1336448421 +total_pgpgout 1354048239 +total_pgfault 1405894809 +total_pgmajfault 50622 +total_inactive_anon 199806976 +total_active_anon 8360579072 +total_inactive_file 19150966784 +total_active_file 3246854144 +total_unevictable 0`), + parser: cgroupMemoryParserFunc, + expected: 880640, + }, + { + name: "parse cgroup memory no such line", + content: []byte(`880640`), + parser: cgroupMemoryParserFunc, + hasError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := NewReadSeekerParser(bytes.NewReader(tt.content), tt.parser) + data, err := parser.Read() + if tt.hasError { + assert.Error(t, err) + } else { + assert.Equal(t, tt.expected, data) + } + assert.Nil(t, parser.Close()) + }) + } +} diff --git a/yuanrong/pkg/common/faas_common/queue/fifoqueue.go b/yuanrong/pkg/common/faas_common/queue/fifoqueue.go new file mode 100644 index 0000000..cce922b --- /dev/null +++ b/yuanrong/pkg/common/faas_common/queue/fifoqueue.go @@ -0,0 +1,152 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package queue - +package queue + +import ( + "container/list" +) + +const ( + defaultMapSize = 16 +) + +// FifoQueue implements a fifo scheduling queue. +type FifoQueue struct { + queue *list.List + identityFunc IdentityFunc + elementRecord map[string]*list.Element +} + +// NewFifoQueue return fifo queue +func NewFifoQueue(identityFunc IdentityFunc) *FifoQueue { + return &FifoQueue{ + queue: list.New(), + identityFunc: identityFunc, + elementRecord: make(map[string]*list.Element, defaultMapSize), + } +} + +// Front return front item of queue +func (fq *FifoQueue) Front() interface{} { + if fq.queue.Len() == 0 { + return nil + } + obj := fq.queue.Front().Value + return obj +} + +// Back return rear item of queue +func (fq *FifoQueue) Back() interface{} { + if fq.queue.Len() == 0 { + return nil + } + obj := fq.queue.Back().Value + return obj +} + +// PopFront pops an object from front +func (fq *FifoQueue) PopFront() interface{} { + if fq.queue.Len() == 0 { + return nil + } + elem := fq.queue.Front() + if elem == nil { + return nil + } + obj := elem.Value + if fq.identityFunc != nil { + delete(fq.elementRecord, fq.identityFunc(obj)) + } + fq.queue.Remove(elem) + return obj +} + +// PopBack pops an object from back +func (fq *FifoQueue) PopBack() interface{} { + if fq.queue.Len() == 0 { + return nil + } + elem := fq.queue.Back() + if elem == nil { + return nil + } + obj := elem.Value + if fq.identityFunc != nil { + delete(fq.elementRecord, fq.identityFunc(obj)) + } + fq.queue.Remove(elem) + return obj +} + +// PushBack adds an object into queue +func (fq *FifoQueue) PushBack(obj interface{}) error { + if fq.identityFunc != nil { + fq.elementRecord[fq.identityFunc(obj)] = fq.queue.PushBack(obj) + } else { + fq.queue.PushBack(obj) + } + return nil +} + +// GetByID gets an object in queue by its ID +func (fq *FifoQueue) GetByID(objID string) interface{} { + elem, exist := fq.elementRecord[objID] + if !exist { + return nil + } + return elem.Value +} + +// DelByID deletes an object in queue by its ID +func (fq *FifoQueue) DelByID(objID string) error { + elem, exist := fq.elementRecord[objID] + if !exist { + return ErrObjectNotFound + } + delete(fq.elementRecord, objID) + fq.queue.Remove(elem) + return nil +} + +// Len returns length of queue +func (fq *FifoQueue) Len() int { + return fq.queue.Len() +} + +// UpdateObjByID will update an object in queue by its ID and fix the order +func (fq *FifoQueue) UpdateObjByID(objID string, obj interface{}) error { + return ErrMethodUnsupported +} + +// Range iterates item in queue and process item with given function +func (fq *FifoQueue) Range(f func(obj interface{}) bool) { + for item := fq.queue.Front(); item != nil; item = item.Next() { + if !f(item) { + break + } + } +} + +// SortedRange iterates item in queue and process item with given function in order +func (fq *FifoQueue) SortedRange(f func(obj interface{}) bool) { + for item := fq.queue.Front(); item != nil; item = item.Next() { + if !f(item) { + break + } + } +} diff --git a/yuanrong/pkg/common/faas_common/queue/fifoqueue_test.go b/yuanrong/pkg/common/faas_common/queue/fifoqueue_test.go new file mode 100644 index 0000000..3a8f268 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/queue/fifoqueue_test.go @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package queue - +package queue + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestSetupLogger(t *testing.T) { + fq := NewFifoQueue(nil) + convey.Convey("front/back is nil", t, func() { + res := fq.Front() + convey.So(res, convey.ShouldBeNil) + res = fq.Back() + convey.So(res, convey.ShouldBeNil) + }) + convey.Convey("do support by id", t, func() { + res := fq.GetByID("test") + convey.So(res, convey.ShouldBeNil) + err := fq.UpdateObjByID("test", "test") + convey.So(err, convey.ShouldEqual, ErrMethodUnsupported) + err = fq.DelByID("test") + convey.So(err, convey.ShouldEqual, ErrObjectNotFound) + }) + convey.Convey("pushback one ele", t, func() { + res := fq.PushBack("obj1") + convey.So(res, convey.ShouldBeNil) + front := fq.Front() + back := fq.Back() + convey.So(front, convey.ShouldEqual, back) + len := fq.Len() + convey.So(len, convey.ShouldEqual, 1) + }) + convey.Convey("pushback other ele", t, func() { + res := fq.PushBack("obj2") + convey.So(res, convey.ShouldBeNil) + len := fq.Len() + convey.So(len, convey.ShouldEqual, 2) + + front := fq.PopFront() + convey.So(front, convey.ShouldEqual, "obj1") + back := fq.PopBack() + convey.So(back, convey.ShouldEqual, "obj2") + }) +} diff --git a/yuanrong/pkg/common/faas_common/queue/priorityqueue.go b/yuanrong/pkg/common/faas_common/queue/priorityqueue.go new file mode 100644 index 0000000..5898e13 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/queue/priorityqueue.go @@ -0,0 +1,410 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package queue - +package queue + +import ( + "math/bits" + "math/rand" + "time" +) + +const ( + defaultQueueLength = 20 +) + +// Item is element stored in heap +type Item struct { + ObjID string + // Obj should be a pointer, otherwise UpdateObjByID will fail + Obj interface{} + Priority int +} + +// PriorityFunc returns priority of an object +type PriorityFunc func(interface{}) (int, error) + +// UpdateObjFunc updates object inside queue +type UpdateObjFunc func(interface{}) error + +// PriorityQueue is a two-ended priority queue which keeps item with max priority at front and item with min priority +// at rear using DeHeap +type PriorityQueue struct { + deHeap *DeHeap + identityFunc IdentityFunc + priorityFunc PriorityFunc +} + +// NewPriorityQueue creates priority queue +func NewPriorityQueue(idFunc IdentityFunc, priorityFunc PriorityFunc) *PriorityQueue { + return &PriorityQueue{ + deHeap: NewDeHeap(), + identityFunc: idFunc, + priorityFunc: priorityFunc, + } +} + +// Front returns the item with max priority +func (pq *PriorityQueue) Front() interface{} { + if item, ok := pq.deHeap.GetMax().(*Item); ok { + return item.Obj + } + return nil +} + +// Range iterates item in queue and process item with given function +func (pq *PriorityQueue) Range(f func(obj interface{}) bool) { + for _, item := range pq.deHeap.items { + if !f(item.Obj) { + break + } + } +} + +// SortedRange iterates item in queue and process item with given function in order +func (pq *PriorityQueue) SortedRange(f func(obj interface{}) bool) { + tmpHeap := pq.deHeap.Copy() + for { + item, ok := tmpHeap.PopMax().(*Item) + if !ok { + break + } + if !f(item.Obj) { + break + } + } +} + +// Back returns the item with min priority +func (pq *PriorityQueue) Back() interface{} { + if item, ok := pq.deHeap.GetMin().(*Item); ok { + return item.Obj + } + return nil +} + +// PopFront pops the item with max priority +func (pq *PriorityQueue) PopFront() interface{} { + if item, ok := pq.deHeap.PopMax().(*Item); ok { + return item.Obj + } + return nil +} + +// PopBack pops the item with min priority +func (pq *PriorityQueue) PopBack() interface{} { + if item, ok := pq.deHeap.PopMin().(*Item); ok { + return item.Obj + } + return nil +} + +// PushBack adds an object into queue +func (pq *PriorityQueue) PushBack(obj interface{}) error { + priority, err := pq.priorityFunc(obj) + if err != nil { + return err + } + pq.deHeap.Push(&Item{ObjID: pq.identityFunc(obj), Obj: obj, Priority: priority}) + return nil +} + +// GetByID gets an object in queue by its ID +func (pq *PriorityQueue) GetByID(objID string) interface{} { + index, item := pq.getIndexAndItemByObjID(objID) + if index == keyNotFoundIndex { + return nil + } + return item.Obj +} + +// DelByID deletes an object in queue by its ID +func (pq *PriorityQueue) DelByID(objID string) error { + index, _ := pq.getIndexAndItemByObjID(objID) + if index != keyNotFoundIndex { + pq.deHeap.Remove(index) + return nil + } + return ErrObjectNotFound +} + +// Len returns length of queue +func (pq *PriorityQueue) Len() int { + return pq.deHeap.Len() +} + +// UpdateObjByID will update an object in queue by its ID and fix the order +func (pq *PriorityQueue) UpdateObjByID(objID string, obj interface{}) error { + var err error + index, item := pq.getIndexAndItemByObjID(objID) + if index == keyNotFoundIndex { + return ErrObjectNotFound + } + item.Obj = obj + // update this object's priority and fix the heap + if item.Priority, err = pq.priorityFunc(obj); err != nil { + return err + } + pq.deHeap.Fix(index) + return nil +} + +// UpdatePriorityFunc - +func (pq *PriorityQueue) UpdatePriorityFunc(priorityFunc PriorityFunc) { + pq.priorityFunc = priorityFunc +} + +func (pq *PriorityQueue) getIndexAndItemByObjID(objID string) (int, *Item) { + for i := 0; i < pq.deHeap.Len(); i++ { + if pq.deHeap.items[i].ObjID == objID { + return i, pq.deHeap.items[i] + } + } + return keyNotFoundIndex, nil +} + +// DeHeap is a max-min heap which stores items in max and min levels, root contains the item with max value of all +// levels and one of root's children contains the item with min value of all levels +type DeHeap struct { + items []*Item + count int +} + +// NewDeHeap creates a DeHeap +func NewDeHeap() *DeHeap { + rand.Seed(time.Now().UnixNano()) + return &DeHeap{ + items: make([]*Item, 0, defaultQueueLength), + } +} + +// Copy creates a shallow copy of DeHeap +func (dh *DeHeap) Copy() *DeHeap { + copyItems := make([]*Item, len(dh.items)) + copy(copyItems, dh.items) + return &DeHeap{ + items: copyItems, + count: dh.count, + } +} + +// Len returns the number of deHeap in heap +func (dh *DeHeap) Len() int { return len(dh.items) } + +// Compare is used to compare two items in heap +func (dh *DeHeap) Compare(i, j int) bool { + if i >= len(dh.items) || j >= len(dh.items) { + return false + } + return dh.items[i].Priority > dh.items[j].Priority +} + +// Swap swaps two items in heap +func (dh *DeHeap) Swap(i, j int) { + if i >= len(dh.items) || j >= len(dh.items) { + return + } + dh.items[i], dh.items[j] = dh.items[j], dh.items[i] +} + +// Push pushes an item to heap +func (dh *DeHeap) Push(x interface{}) { + item, ok := x.(*Item) + if !ok { + return + } + dh.items = append(dh.items, item) + dh.shiftUp(dh.Len() - 1) +} + +// Fix fixes heap's order +func (dh *DeHeap) Fix(i int) { + if j := dh.shiftDown(i); j > 0 { + dh.shiftUp(j) + } +} + +// Remove removes an item from heap +func (dh *DeHeap) Remove(i int) { + n := dh.Len() - 1 + if i > n { + return + } + dh.Swap(i, n) + dh.items[n] = nil + dh.items = dh.items[0:n] + dh.shiftDown(i) +} + +// GetMax returns the item with max value +func (dh *DeHeap) GetMax() interface{} { + if dh.Len() < 1 { + return nil + } + return dh.items[0] +} + +// GetMin returns the item with min value +func (dh *DeHeap) GetMin() interface{} { + n := dh.Len() - 1 + if n < 0 { + return nil + } + lChd := 1 + if lChd > n { + return dh.items[0] + } + rChd := 2 + min := lChd + if rChd <= n && dh.Compare(lChd, rChd) { + min = rChd + } + return dh.items[min] +} + +// PopMax pops item with max value +func (dh *DeHeap) PopMax() interface{} { + n := dh.Len() - 1 + if n < 0 { + return nil + } + item := dh.items[0] + dh.Swap(0, n) + dh.items[n] = nil + dh.items = dh.items[0:n] + dh.shiftDown(0) + return item +} + +// PopMin pops item with min value +func (dh *DeHeap) PopMin() interface{} { + n := dh.Len() - 1 + if n < 0 { + return nil + } + lc := leftChild(0) + rc := rightChild(0) + if lc > n { + item := dh.items[0] + dh.items[0] = nil + dh.items = dh.items[0:n] + return item + } + t := lc + if rc <= n && dh.Compare(lc, rc) { + t = rc + } + if t >= len(dh.items) || n >= len(dh.items) { + return nil + } + item := dh.items[t] + dh.Swap(t, n) + dh.items[n] = nil + dh.items = dh.items[0:n] + dh.shiftDown(t) + return item +} + +func (dh *DeHeap) shiftUp(i int) int { + if i < 0 { + return i + } + isMax := isMaxLevel(i) + p := parent(i) + if p >= 0 { + if dh.Compare(p, i) == isMax { + dh.Swap(p, i) + i = p + isMax = !isMax + } + } + for g := grandparent(i); g >= 0; g = grandparent(i) { + if dh.Compare(g, i) == isMax { + break + } + dh.Swap(g, i) + i = g + } + return i +} + +func (dh *DeHeap) shiftDown(i int) int { + if i < 0 { + return i + } + n := dh.Len() + for i < n { + isMax := isMaxLevel(i) + t := i + // check i's children + lc, rc := leftChild(i), rightChild(i) + // no need to go further if lc reaches n but should handle rc reaches n and lc doesn't + if lc >= n { + break + } + if dh.Compare(lc, t) == isMax { + t = lc + } + if rc < n && dh.Compare(rc, t) == isMax { + t = rc + } + // check i's grandchildren + for gc := leftChild(lc); gc < n && gc <= rightChild(rc); gc++ { + if dh.Compare(gc, t) == isMax { + t = gc + } + } + if t == i { + break + } + dh.Swap(i, t) + i = t + // t is i's children, which means i has no conflict with its grandchildren who stand in the same max/min level + // with i, no need to go further + if t == lc || t == rc { + break + } + // t is i's grandchildren, need to check if t has conflict with t's parent + p := parent(t) + if dh.Compare(p, t) == isMax { + dh.Swap(p, t) + i = p + } + } + return i +} + +func isMaxLevel(i int) bool { + level := bits.Len(uint(i)+1) - 1 + return level%2 == 0 // whether the given integer i is at the maximum level +} + +func parent(i int) int { + return (i - 1) / 2 // find the parent node's index +} + +func grandparent(i int) int { + return ((i + 1) / 4) - 1 // find the grandparent node's index +} + +func leftChild(i int) int { + return i*2 + 1 // find the leftChild node's index +} + +func rightChild(i int) int { + return i*2 + 2 // find the rightChild node's index +} diff --git a/yuanrong/pkg/common/faas_common/queue/priorityqueue_test.go b/yuanrong/pkg/common/faas_common/queue/priorityqueue_test.go new file mode 100644 index 0000000..64e2228 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/queue/priorityqueue_test.go @@ -0,0 +1,274 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package queue - +package queue + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDeHeap(t *testing.T) { + items := []*Item{ + { + Obj: "apple", + Priority: 16, + }, + { + Obj: "banana", + Priority: 15, + }, + { + Obj: "berry", + Priority: 17, + }, + { + Obj: "cherry", + Priority: 14, + }, + { + Obj: "grape", + Priority: 18, + }, + { + Obj: "lemon", + Priority: 13, + }, + { + Obj: "haw", + Priority: 12, + }, + { + Obj: "mango", + Priority: 19, + }, + { + Obj: "orange", + Priority: 20, + }, + { + Obj: "watermelon", + Priority: 11, + }, + } + dh := NewDeHeap() + for _, item := range items { + dh.Push(item) + } + popMax1 := dh.PopMax().(*Item).Obj + popMin1 := dh.PopMin().(*Item).Obj + assert.Equal(t, "orange", popMax1.(string)) + assert.Equal(t, "watermelon", popMin1.(string)) + getMax1 := dh.GetMax().(*Item).Obj + getMin1 := dh.GetMin().(*Item).Obj + assert.Equal(t, "mango", getMax1.(string)) + assert.Equal(t, "haw", getMin1.(string)) + dh.items[3].Priority = 11 + dh.Fix(3) + getMin2 := dh.GetMin().(*Item).Obj + assert.Equal(t, "grape", getMin2.(string)) + dh.items[1].Priority = 20 + dh.Fix(1) + getMax2 := dh.GetMax().(*Item).Obj + assert.Equal(t, "grape", getMax2.(string)) + dh.Remove(2) + getMin3 := dh.GetMin().(*Item).Obj + assert.Equal(t, "lemon", getMin3.(string)) +} + +func TestPriorityQueue(t *testing.T) { + type testItem struct { + id string + priority int + } + identityFunc := func(obj interface{}) string { + if item, ok := obj.(*testItem); ok { + return item.id + } + return "" + } + priorityFunc := func(obj interface{}) (int, error) { + if item, ok := obj.(*testItem); ok { + return item.priority, nil + } + return -1, fmt.Errorf("failed to get priority") + } + item1 := &testItem{id: "1", priority: 50} + item2 := &testItem{id: "2", priority: 51} + item3 := &testItem{id: "3", priority: 51} + item4 := &testItem{id: "4", priority: 51} + item5 := &testItem{id: "5", priority: 60} + queue := NewPriorityQueue(identityFunc, priorityFunc) + frontItem1 := queue.Front() + backItem1 := queue.Back() + assert.Equal(t, nil, frontItem1) + assert.Equal(t, nil, backItem1) + + popBack1 := queue.PopBack() + assert.Equal(t, nil, popBack1) + popFront1 := queue.PopFront() + assert.Equal(t, nil, popFront1) + + queue.PushBack(item1) + frontItem2 := queue.Front().(*testItem) + backItem2 := queue.Back().(*testItem) + assert.Equal(t, "1", frontItem2.id) + assert.Equal(t, "1", backItem2.id) + + queue.PushBack(item2) + queue.PushBack(item3) + frontItem3 := queue.Front().(*testItem) + backItem3 := queue.Back().(*testItem) + assert.Equal(t, 51, frontItem3.priority) + assert.Equal(t, 50, backItem3.priority) + + queue.PushBack(item4) + queue.PushBack(item5) + frontItem4 := queue.Front().(*testItem) + backItem4 := queue.Back().(*testItem) + assert.Equal(t, 60, frontItem4.priority) + assert.Equal(t, 50, backItem4.priority) + + item2.priority = 40 + queue.UpdateObjByID("2", item2) + backItem5 := queue.Back().(*testItem) + assert.Equal(t, "2", backItem5.id) + item3.priority = 70 + queue.UpdateObjByID("3", item3) + frontItem5 := queue.Front().(*testItem) + assert.Equal(t, "3", frontItem5.id) + + queue.DelByID("4") + frontItem6 := queue.Front().(*testItem) + backItem6 := queue.Back().(*testItem) + assert.Equal(t, "3", frontItem6.id) + assert.Equal(t, "2", backItem6.id) + + item3.priority = 40 + queue.UpdateObjByID("3", item3) + frontItem7 := queue.Front().(*testItem) + assert.Equal(t, "5", frontItem7.id) + + item2.priority = 30 + queue.UpdateObjByID("2", item2) + backItem7 := queue.Back().(*testItem) + assert.Equal(t, "2", backItem7.id) + + getByID1 := queue.GetByID("qwe") + assert.Equal(t, getByID1, nil) + + getByID2 := queue.GetByID("2").(*testItem) + assert.Equal(t, "2", getByID2.id) + + popBack2 := queue.PopBack().(*testItem) + assert.Equal(t, "2", popBack2.id) + popFront2 := queue.PopFront().(*testItem) + assert.Equal(t, "5", popFront2.id) + + length := queue.Len() + assert.Equal(t, 2, length) +} + +func TestPriorityQueueUpdateFrontInSequence(t *testing.T) { + type testItem struct { + id string + priority int + } + identityFunc := func(obj interface{}) string { + if item, ok := obj.(*testItem); ok { + return item.id + } + return "" + } + priorityFunc := func(obj interface{}) (int, error) { + if item, ok := obj.(*testItem); ok { + return item.priority, nil + } + return -1, fmt.Errorf("failed to get priority") + } + items := []*testItem{ + &testItem{id: "1", priority: 2}, + &testItem{id: "2", priority: 2}, + &testItem{id: "3", priority: 2}, + &testItem{id: "4", priority: 2}, + } + queue := NewPriorityQueue(identityFunc, priorityFunc) + for _, item := range items { + queue.PushBack(item) + } + for { + front := queue.Front().(*testItem) + if front.priority == 0 { + break + } + front.priority -= 1 + queue.UpdateObjByID(front.id, front) + } + queue.Range(func(obj interface{}) bool { + item := obj.(*testItem) + if item.priority != 0 { + t.Errorf("item %s priority %d should be 0", item.id, item.priority) + } + return true + }) +} + +func TestPriorityQueue_SortedRange(t *testing.T) { + type testItem struct { + id string + priority int + } + identityFunc := func(obj interface{}) string { + if item, ok := obj.(*testItem); ok { + return item.id + } + return "" + } + priorityFunc := func(obj interface{}) (int, error) { + if item, ok := obj.(*testItem); ok { + return item.priority, nil + } + return -1, fmt.Errorf("failed to get priority") + } + items := []*testItem{ + &testItem{id: "1", priority: 1}, + &testItem{id: "2", priority: 2}, + &testItem{id: "3", priority: 3}, + &testItem{id: "4", priority: 4}, + } + queue := NewPriorityQueue(identityFunc, priorityFunc) + for _, item := range items { + queue.PushBack(item) + } + rangeItems := make([]*testItem, 0, 4) + queue.SortedRange(func(obj interface{}) bool { + item := obj.(*testItem) + rangeItems = append(rangeItems, item) + return true + }) + for i, item := range rangeItems { + if i+item.priority != 4 { + t.Errorf("range item %+v in wrong order %d\n", item, i) + } + } + front := queue.PopFront().(*testItem) + back := queue.PopBack().(*testItem) + assert.Equal(t, 4, front.priority) + assert.Equal(t, 1, back.priority) +} diff --git a/yuanrong/pkg/common/faas_common/queue/queue.go b/yuanrong/pkg/common/faas_common/queue/queue.go new file mode 100644 index 0000000..4635a9d --- /dev/null +++ b/yuanrong/pkg/common/faas_common/queue/queue.go @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package queue - +package queue + +import "errors" + +const ( + // keyNotFoundIndex stands for index of a non exist key + keyNotFoundIndex = -1 +) + +var ( + // ErrObjectNotFound is the error of object not found + ErrObjectNotFound = errors.New("object not found") + // ErrMethodUnsupported is the error of method unsupported + ErrMethodUnsupported = errors.New("method unsupported") +) + +// IdentityFunc will get ID from object in queue +type IdentityFunc func(interface{}) string + +// Queue is interface of queue used in faas pattern +type Queue interface { + Front() interface{} + Back() interface{} + PopFront() interface{} + PopBack() interface{} + PushBack(obj interface{}) error + GetByID(objID string) interface{} + DelByID(objID string) error + UpdateObjByID(objID string, obj interface{}) error + Len() int + Range(f func(obj interface{}) bool) + SortedRange(f func(obj interface{}) bool) +} diff --git a/yuanrong/pkg/common/faas_common/redisclient/redisclient.go b/yuanrong/pkg/common/faas_common/redisclient/redisclient.go new file mode 100644 index 0000000..3cb83d7 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/redisclient/redisclient.go @@ -0,0 +1,510 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package redisclient new a redis client +package redisclient + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "strings" + "sync" + "time" + + "github.com/redis/go-redis/v9" + + "yuanrong/pkg/common/faas_common/logger/log" + commonTLS "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/common/faas_common/utils" +) + +const ( + // timeout : allow TCP reconnection for 3 times(1, 2, 4) + dialTimeout = 8 * time.Second + readTimeout = 8 * time.Second + writeTimeout = 8 * time.Second + idleTimeout = 300 * time.Second + defaultDialTimeout = 8 + defaultReadTimeout = 8 + defaultWriteTimeout = 8 + defaultIdleTimeout = 300 + defaultRedisConn = 20 + // TTL - + TTL = 1 * time.Minute + maxRetryTimes = 3 + // DefaultCAFile is the default ca file for tls client + DefaultCAFile = "/home/sn/resource/redis-secret/ca.pem" + // DefaultCertFile is the default cert file for tls client + DefaultCertFile = "/home/sn/resource/redis-secret/cert.pem" + // DefaultKeyFile is the default key file for tls client + DefaultKeyFile = "/home/sn/resource/redis-secret/key.pem" + // redisStringFile is the temp file to store string type data of redis + redisStringFile = "/tmp/redis-string" + // redisStringFile is the temp file to store slice type data of redis + redisSliceFile = "/tmp/redis-slice" + redisSeparator = "%WITH%" + + // the detection is performed every 5 seconds. + healthCheckIntervalTime = 5 + // 2 * 60min * 60s / 5 second, trigger every 5 minutes + twoHoursCount = 2 * 60 * 60 / healthCheckIntervalTime + success = 0 + fail = 1 + redisValueIndex = 2 + redisReconnectionInternal = 10 * time.Second + // DefaultRedisContextTimeout - + DefaultRedisContextTimeout = time.Second + + redisRetryTimes = 3 + redisRetryInterval = 100 * time.Millisecond +) + +var ( + errMode = errors.New("serverMode is not single or cluster") + // RedisClient - + redisClient = &Client{ + client: &redis.Client{}, + option: redisClientOption{}, + connected: false, + RWMutex: sync.RWMutex{}, + } + // defaultTimeoutConf is the default timeout conf + defaultTimeoutConf = TimeoutConf{ + DialTimeout: defaultDialTimeout, + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, + IdleTimeout: defaultIdleTimeout, + } +) + +var ( + mu sync.RWMutex + redisCmd *Client +) + +// Option - +type Option func(*redisClientOption) + +type redisClientOption struct { + tlsConfig *tls.Config + serverAddr string + dialTimeout time.Duration + readTimeout time.Duration + writeTimeout time.Duration + idleTimeout time.Duration + password string + serverMode string + enableTLS bool + hotloadConfFunc func() (string, TimeoutConf, error) + enableAlarm bool +} + +// RedisOperation - +type RedisOperation struct { + Key string + Value string + Method string + TTL time.Duration +} + +// Client - +type Client struct { + client redis.Cmdable + option redisClientOption + connected bool + sync.RWMutex +} + +// Config is the config of redis client +type Config struct { + ClusterID string `json:"clusterID,omitempty" valid:",optional"` + ServerAddr string `json:"serverAddr,omitempty" valid:",optional"` + ServerMode string `json:"serverMode,omitempty" valid:",optional"` + Password string `json:"password,omitempty" valid:",optional"` + EnableTLS bool `json:"enableTLS,omitempty" valid:",optional"` + TimeoutConf TimeoutConf `json:"timeoutConf,omitempty" valid:",optional"` +} + +// TimeoutConf A variety of timeout configurations +type TimeoutConf struct { + DialTimeout int `json:"dialTimeout,omitempty" valid:",optional"` + ReadTimeout int `json:"readTimeout,omitempty" valid:",optional"` + WriteTimeout int `json:"writeTimeout,omitempty" valid:",optional"` + IdleTimeout int `json:"idleTimeout,omitempty" valid:",optional"` +} + +// NewRedisClientParam parameters of a new redis client +type NewRedisClientParam struct { + ServerMode string + ServerAddr string + Password string + Timeout TimeoutConf + EnableTLS bool `json:"enableTLS,omitempty" valid:",optional"` + HotloadConfFunc func() (string, TimeoutConf, error) +} + +// GetRedisCmd - +func GetRedisCmd() *Client { + mu.Lock() + client := redisCmd + mu.Unlock() + return client +} + +// SetRedisCmd - +func SetRedisCmd(client *Client) { + mu.Lock() + redisCmd = client + mu.Unlock() +} + +// SetEnableTLS - +func SetEnableTLS(enableTLS bool) Option { + return func(c *redisClientOption) { + c.enableTLS = enableTLS + } +} + +// ZCard - +func (c *Client) ZCard(ctx context.Context, key string) *redis.IntCmd { + c.RLock() + cli := c.client + c.RUnlock() + return cli.ZCard(ctx, key) +} + +// ZRange - +func (c *Client) ZRange(ctx context.Context, key string, start, stop int64) *redis.StringSliceCmd { + c.RLock() + cli := c.client + c.RUnlock() + return cli.ZRange(ctx, key, start, stop) +} + +// ZRem - +func (c *Client) ZRem(ctx context.Context, key string, members ...interface{}) *redis.IntCmd { + c.RLock() + cli := c.client + c.RUnlock() + return cli.ZRem(ctx, key, members...) +} + +// ZAdd - +func (c *Client) ZAdd(ctx context.Context, key string, members ...redis.Z) *redis.IntCmd { + c.RLock() + cli := c.client + c.RUnlock() + return cli.ZAdd(ctx, key, members...) +} + +// Ping - +func (c *Client) Ping(ctx context.Context) *redis.StatusCmd { + c.RLock() + cli := c.client + c.RUnlock() + return cli.Ping(ctx) +} + +// Expire - +func (c *Client) Expire(ctx context.Context, key string, expiration time.Duration) *redis.BoolCmd { + c.RLock() + cli := c.client + c.RUnlock() + return cli.Expire(ctx, key, expiration) +} + +// Get - +func (c *Client) Get(ctx context.Context, key string) *redis.StringCmd { + c.RLock() + cli := c.client + c.RUnlock() + return cli.Get(ctx, key) +} + +// Del - +func (c *Client) Del(ctx context.Context, keys ...string) *redis.IntCmd { + c.RLock() + cli := c.client + c.RUnlock() + return cli.Del(ctx, keys...) +} + +// ZADDMetricsToRedis - +func ZADDMetricsToRedis(key string, metrics interface{}, limit int64, expireTime time.Duration) error { + redisCmd := GetRedisCmd() + if redisCmd == nil { + log.GetLogger().Errorf("redis client is nil") + return errors.New("redis client is nil") + } + count, err := redisCmd.ZCard(context.TODO(), key).Result() + if err != nil { + log.GetLogger().Errorf("failed to ZCard metrics key %s from redis, err: %s", key, err.Error()) + return err + } + // if count reach limit, delete the earliest metric + if count >= limit { + earliestValues, err := redisCmd.ZRange(context.TODO(), key, 0, 0).Result() + if err != nil { + log.GetLogger().Errorf("failed to ZRange metrics key %s from redis, err: %s", key, err.Error()) + return err + } + _, err = redisCmd.ZRem(context.TODO(), key, earliestValues[0]).Result() + if err != nil { + log.GetLogger().Errorf("failed to ZRem metrics key %s to redis, err: %s", key, err.Error()) + return err + } + } + // Add a new value to the sorted set, with the score being the current timestamp + score := time.Now().Unix() + err = redisCmd.ZAdd(context.TODO(), key, redis.Z{Score: float64(score), Member: metrics}).Err() + if err != nil { + log.GetLogger().Errorf("failed to ZAdd metrics key %s to redis, err: %s", key, err.Error()) + return err + } + redisCmd.Expire(context.TODO(), key, expireTime) + return nil +} + +// New create a redis client +func New(newClientParam NewRedisClientParam, stopCh <-chan struct{}, options ...Option) (*Client, error) { + o := getNewRedisOption(newClientParam) + for _, option := range options { + option(&o) + } + + var redisCMD redis.Cmdable + switch newClientParam.ServerMode { + case "single": + redisCMD = newSingleClient(o) + case "cluster": + redisCMD = newClusterClient(o) + default: + utils.ClearStringMemory(o.password) + return nil, errMode + } + + if redisCMD == nil { + return nil, errors.New("failed to new redis cmd") + } + finished := make(chan int) + go connectRedis(redisCMD, finished, o) + select { + case i, ok := <-finished: + if ok && i == fail { + return nil, errors.New("failed to connect redis server") + } + case <-time.After(o.dialTimeout): + log.GetLogger().Errorf("dialing redis server error with incorrect ip address:%s.", o.serverAddr) + return nil, errors.New("dialing redis server timeout") + } + redisClient = &Client{ + client: redisCMD, + option: o, + connected: true, + RWMutex: sync.RWMutex{}, + } + return redisClient, nil +} + +func getNewRedisOption(param NewRedisClientParam) redisClientOption { + o := redisClientOption{ + serverAddr: param.ServerAddr, + password: param.Password, + serverMode: param.ServerMode, + } + if param.Timeout.DialTimeout > 0 { + o.dialTimeout = time.Duration(param.Timeout.DialTimeout) * time.Second + log.GetLogger().Infof("new dialTimeout: %d", param.Timeout.DialTimeout) + } else { + o.dialTimeout = dialTimeout + } + if param.Timeout.ReadTimeout > 0 { + o.readTimeout = time.Duration(param.Timeout.ReadTimeout) * time.Second + log.GetLogger().Infof("new readTimeout: %d", param.Timeout.ReadTimeout) + } else { + o.readTimeout = readTimeout + } + if param.Timeout.WriteTimeout > 0 { + o.writeTimeout = time.Duration(param.Timeout.WriteTimeout) * time.Second + log.GetLogger().Infof("new writeTimeout: %d", param.Timeout.WriteTimeout) + } else { + o.writeTimeout = writeTimeout + } + if param.Timeout.IdleTimeout > 0 { + o.idleTimeout = time.Duration(param.Timeout.IdleTimeout) * time.Second + log.GetLogger().Infof("new idleTimeout: %d", param.Timeout.IdleTimeout) + } else { + o.idleTimeout = idleTimeout + } + return o +} + +func newSingleClient(o redisClientOption) redis.Cmdable { + options := &redis.Options{ + PoolSize: defaultRedisConn, + Addr: o.serverAddr, + Password: o.password, + DialTimeout: o.dialTimeout, + ReadTimeout: o.readTimeout, + WriteTimeout: o.writeTimeout, + ConnMaxIdleTime: o.idleTimeout, + MaxRetries: maxRetryTimes, + } + if o.enableTLS { + tlsConfig, err := buildCfg(DefaultCAFile, DefaultCertFile, DefaultKeyFile) + if err != nil { + utils.ClearStringMemory(options.Password) + log.GetLogger().Errorf("failed to build single client tls config: %s", err.Error()) + return nil + } + options.TLSConfig = tlsConfig + } + return redis.NewClient(options) +} + +func connectRedis(redisCmd redis.Cmdable, finished chan<- int, o redisClientOption) { + if finished == nil { + return + } + var err error + for i := 0; i < maxRetryTimes; i++ { + if redisCmd == nil { + log.GetLogger().Errorf("redis is not ready") + continue + } + _, err = redisCmd.Ping(context.Background()).Result() + if err == nil { + finished <- success + return + } + } + // The key relies on go's GC for memory cleanup + log.GetLogger().Errorf("dialing redis server error: %s", err.Error()) + finished <- fail + return +} + +func newClusterClient(o redisClientOption) redis.Cmdable { + options := &redis.ClusterOptions{ + PoolSize: defaultRedisConn, + Addrs: strings.Split(o.serverAddr, ","), + Password: o.password, + DialTimeout: o.dialTimeout, + ReadTimeout: o.readTimeout, + WriteTimeout: o.writeTimeout, + ConnMaxIdleTime: o.idleTimeout, + MaxRetries: maxRetryTimes, + } + if o.enableTLS { + tlsConfig, err := buildCfg(DefaultCAFile, DefaultCertFile, DefaultKeyFile) + if err != nil { + utils.ClearStringMemory(options.Password) + log.GetLogger().Errorf("failed to build redis ClusterClient tls config: %s", err.Error()) + return nil + } + options.TLSConfig = tlsConfig + } + return redis.NewClusterClient(options) +} + +func buildCfg(caFile string, certFile string, keyFile string) (*tls.Config, error) { + var pools *x509.CertPool + var err error + pools, err = commonTLS.GetX509CACertPool(caFile) + if err != nil { + log.GetLogger().Errorf("failed to get X509 CACert Pool: %s", err.Error()) + return nil, err + } + + var certs []tls.Certificate + if certs, err = commonTLS.LoadServerTLSCertificate(certFile, keyFile, "", "LOCAL", false); err != nil { + log.GetLogger().Errorf("failed to load Server TLS Certificate: %s", err.Error()) + return nil, err + } + + clientAuth := tls.NoClientCert + tlsConfig := &tls.Config{ + RootCAs: pools, + Certificates: certs, + ClientAuth: clientAuth, + } + return tlsConfig, nil +} + +// CheckRedisConnectivity - +func CheckRedisConnectivity(clientRedisConfig *NewRedisClientParam, client *Client, stopCh <-chan struct{}) { + if stopCh == nil { + log.GetLogger().Errorf("stopCh is nil") + return + } + ticker := time.NewTicker(redisReconnectionInternal) + for { + select { + case <-ticker.C: + if err := checkAndReconnectRedis(clientRedisConfig, client, stopCh); err != nil { + log.GetLogger().Errorf("failed to check or reconnect redis client, err:%s", err.Error()) + } + case <-stopCh: + log.GetLogger().Errorf("module process exit") + ticker.Stop() + return + } + } +} + +func checkAndReconnectRedis(clientRedisConfig *NewRedisClientParam, client *Client, stopCh <-chan struct{}) error { + log.GetLogger().Debug("redis check redis connection start") + if client != nil { + _, err := (*client).Ping(context.TODO()).Result() + if err == nil { + log.GetLogger().Debug("redis periodically checks availability") + return nil + } + } + newClient, err := initClient(clientRedisConfig, stopCh) + if err != nil { + return err + } + if client != nil { + client = newClient + } + SetRedisCmd(newClient) + return nil +} + +func initClient(clientRedisConfig *NewRedisClientParam, stopCh <-chan struct{}) (*Client, error) { + c, err := New(NewRedisClientParam{ + ServerMode: clientRedisConfig.ServerMode, + ServerAddr: clientRedisConfig.ServerAddr, + Password: clientRedisConfig.Password, + Timeout: clientRedisConfig.Timeout, + }, stopCh, SetEnableTLS(clientRedisConfig.EnableTLS), + SetGetRealTimeServerAddrFunc(clientRedisConfig.HotloadConfFunc)) + if err != nil { + log.GetLogger().Errorf("failed to new a redis Client, %s", err.Error()) + return nil, err + } + return c, nil +} + +// SetGetRealTimeServerAddrFunc hot update server address when disconnected +func SetGetRealTimeServerAddrFunc(getServerAddr func() (string, TimeoutConf, error)) Option { + return func(c *redisClientOption) { + c.hotloadConfFunc = getServerAddr + } +} diff --git a/yuanrong/pkg/common/faas_common/redisclient/redisclient_test.go b/yuanrong/pkg/common/faas_common/redisclient/redisclient_test.go new file mode 100644 index 0000000..0033086 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/redisclient/redisclient_test.go @@ -0,0 +1,457 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package redisclient + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/redis/go-redis/v9" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + commonTLS "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/common/faas_common/utils" +) + +func TestZADDMetricsToRedis(t *testing.T) { + err := ZADDMetricsToRedis("mockKey", 1, 3, 5*time.Second) + assert.NotNil(t, err) + convey.Convey("TestZADDMetricsToRedis", t, func() { + convey.Convey("ZCard exception", func() { + redisCmd = &Client{client: &redis.Client{}} + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&redis.Client{}), "ZCard", + func(cli *redis.Client, ctx context.Context, key string) *redis.IntCmd { + return &redis.IntCmd{} + }), + gomonkey.ApplyMethod(reflect.TypeOf(&redis.IntCmd{}), "Result", + func(_ *redis.IntCmd) (int64, error) { + return 0, errors.New("mock ZCard error") + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + err = ZADDMetricsToRedis("mockKey", 1, 3, 5*time.Second) + assert.NotNil(t, err) + assert.Equal(t, "mock ZCard error", err.Error()) + }) + + convey.Convey("ZRange exception", func() { + redisCmd = &Client{client: &redis.Client{}} + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&redis.Client{}), "ZCard", + func(cli *redis.Client, ctx context.Context, key string) *redis.IntCmd { + return &redis.IntCmd{} + }), + gomonkey.ApplyMethod(reflect.TypeOf(&redis.IntCmd{}), "Result", + func(_ *redis.IntCmd) (int64, error) { + return 3, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&redis.Client{}), "ZRange", + func(cli *redis.Client, ctx context.Context, key string, start, stop int64) *redis.StringSliceCmd { + return &redis.StringSliceCmd{} + }), + gomonkey.ApplyMethod(reflect.TypeOf(&redis.StringSliceCmd{}), "Result", + func(_ *redis.StringSliceCmd) ([]string, error) { + return nil, errors.New("mock ZRange error") + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + err = ZADDMetricsToRedis("mockKey", 1, 3, 5*time.Second) + assert.NotNil(t, err) + assert.Equal(t, "mock ZRange error", err.Error()) + }) + + convey.Convey("ZAdd success", func() { + redisCmd = &Client{client: &redis.Client{}} + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&redis.Client{}), "ZCard", + func(cli *redis.Client, ctx context.Context, key string) *redis.IntCmd { + return &redis.IntCmd{} + }), + gomonkey.ApplyMethod(reflect.TypeOf(&redis.IntCmd{}), "Result", + func(_ *redis.IntCmd) (int64, error) { + return 3, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&redis.Client{}), "ZRange", + func(cli *redis.Client, ctx context.Context, key string, start, stop int64) *redis.StringSliceCmd { + return &redis.StringSliceCmd{} + }), + gomonkey.ApplyMethod(reflect.TypeOf(&redis.StringSliceCmd{}), "Result", + func(_ *redis.StringSliceCmd) ([]string, error) { + return []string{"0", "1", "2"}, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&redis.Client{}), "ZRem", + func(cli *redis.Client, ctx context.Context, key string, members ...interface{}) *redis.IntCmd { + return &redis.IntCmd{} + }), + gomonkey.ApplyMethod(reflect.TypeOf(&redis.Client{}), "ZAdd", + func(cli *redis.Client, ctx context.Context, key string, members ...redis.Z) *redis.IntCmd { + return &redis.IntCmd{} + }), + gomonkey.ApplyMethod(reflect.TypeOf(&redis.Client{}), "Expire", + func(cli *redis.Client, ctx context.Context, key string, expiration time.Duration) *redis.BoolCmd { + return &redis.BoolCmd{} + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + err = ZADDMetricsToRedis("mockKey", 1, 3, 5*time.Second) + assert.Nil(t, err) + }) + }) +} + +func TestNew(t *testing.T) { + type args struct { + serverMode string + serverAddr string + password string + options []Option + } + var a args + var b args + option := SetEnableTLS(false) + b.serverMode = "single" + b.options = append(b.options, option) + var c args + c.serverMode = "cluster" + c.options = append(b.options, option) + c.options = append(b.options, SetGetRealTimeServerAddrFunc(func() (string, TimeoutConf, error) { + return "", TimeoutConf{}, nil + })) + tests := []struct { + name string + args args + want redis.Cmdable + wantErr bool + }{ + {"case1", a, nil, true}, + {"case2", b, nil, true}, + {"case3", c, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := New(NewRedisClientParam{ + tt.args.serverMode, + tt.args.serverAddr, + tt.args.password, + TimeoutConf{}, + false, + nil, + }, nil, tt.args.options...) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != nil { + t.Errorf("New() got = %v, want %v", got, tt.want) + } + }) + } + + patches := utils.InitPatchSlice() + statusCMD := &redis.StatusCmd{} + patches.Append(utils.PatchSlice{gomonkey.ApplyFunc((*redis.Client).Ping, + func(_ *redis.Client, _ context.Context) *redis.StatusCmd { + return statusCMD + })}) + defer patches.ResetAll() + _, err := New(NewRedisClientParam{ + "single", + "", + "", + TimeoutConf{}, + false, + nil, + }, nil) + if err != nil { + t.Errorf("failed to test new client with alarm switch on: %s", err.Error()) + } +} + +func Test_buildCfg(t *testing.T) { + type args struct { + caFile string + certFile string + keyFile string + } + var a args + tests := []struct { + name string + args args + want *tls.Config + wantErr bool + }{ + {"case1", a, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, _ := buildCfg(tt.args.caFile, tt.args.certFile, tt.args.keyFile) + assert.Equalf(t, tt.want, got, "buildCfg(%v, %v, %v)", tt.args.caFile, tt.args.certFile, tt.args.keyFile) + }) + } +} + +func TestEmptyClients(t *testing.T) { + opt := redisClientOption{enableTLS: true} + redisCMD := newSingleClient(opt) + assert.Equal(t, redisCMD, nil) + redisCMD = newClusterClient(opt) + assert.Equal(t, redisCMD, nil) +} + +func TestBuildCfg(t *testing.T) { + patches := utils.InitPatchSlice() + patches.Append(utils.PatchSlice{gomonkey.ApplyFunc(commonTLS.GetX509CACertPool, + func(caCertFilePath string) (caCertPool *x509.CertPool, err error) { + return nil, nil + })}) + defer patches.ResetAll() + convey.Convey("Test build cfg error", t, func() { + tlsConfig, err := buildCfg(DefaultCAFile, DefaultCertFile, DefaultKeyFile) + convey.So(tlsConfig, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func Test_getNewRedisOption(t *testing.T) { + type args struct { + param NewRedisClientParam + } + tests := []struct { + name string + args args + want redisClientOption + }{ + { + name: "case1", + args: args{ + param: NewRedisClientParam{ + "single", + "127.0.0.1", + "aaa", + TimeoutConf{}, + false, + nil, + }, + }, + want: redisClientOption{ + serverAddr: "127.0.0.1", + serverMode: "single", + password: "aaa", + dialTimeout: dialTimeout, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + idleTimeout: idleTimeout, + }, + }, + { + name: "case1", + args: args{ + param: NewRedisClientParam{ + "single", + "127.0.0.1", + "aaa", + TimeoutConf{ + DialTimeout: 1, + ReadTimeout: 1, + WriteTimeout: 1, + IdleTimeout: 1, + }, + false, + nil, + }, + }, + want: redisClientOption{ + serverAddr: "127.0.0.1", + serverMode: "single", + password: "aaa", + dialTimeout: 1 * time.Second, + readTimeout: 1 * time.Second, + writeTimeout: 1 * time.Second, + idleTimeout: 1 * time.Second, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getNewRedisOption(tt.args.param); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getNewRedisOption() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_initClient(t *testing.T) { + convey.Convey("Test redis Client is success", t, func() { + param := NewRedisClientParam{ + ServerMode: "122", + ServerAddr: "333", + Password: "1222", + Timeout: TimeoutConf{}, + EnableTLS: false, + HotloadConfFunc: nil, + } + defer gomonkey.ApplyFunc(New, func(newClientParam NewRedisClientParam, stopCh <-chan struct{}, options ...Option) (*Client, error) { + return &Client{}, nil + }).Reset() + stopCh := make(chan struct{}) + redisClient, _ := initClient(¶m, stopCh) + convey.So(redisClient, convey.ShouldNotBeNil) + }) + convey.Convey("Test to not init redis client", t, func() { + param := NewRedisClientParam{ + ServerMode: "122", + ServerAddr: "333", + Password: "1222", + Timeout: TimeoutConf{}, + EnableTLS: false, + HotloadConfFunc: nil, + } + defer gomonkey.ApplyFunc(New, func(newClientParam NewRedisClientParam, stopCh <-chan struct{}, options ...Option) (*Client, error) { + return &Client{}, errors.New("redis is not ready") + }).Reset() + stopCh := make(chan struct{}) + _, err := initClient(¶m, stopCh) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestCheckRedisConnectivity(t *testing.T) { + var isCalled = 0 + convey.Convey("Test check redis to connect in cyclist", t, func() { + param := NewRedisClientParam{ + ServerMode: "122", + ServerAddr: "333", + Password: "1222", + Timeout: TimeoutConf{}, + EnableTLS: false, + HotloadConfFunc: nil, + } + patch1 := gomonkey.ApplyFunc(New, func(newClientParam NewRedisClientParam, stopCh <-chan struct{}, options ...Option) (*Client, error) { + isCalled++ + return &Client{}, nil + }) + patch := gomonkey.ApplyFunc((*redis.Client).Ping, + func(_ *redis.Client, _ context.Context) *redis.StatusCmd { + return &redis.StatusCmd{} + }) + defer patch.Reset() + stopCh := make(chan struct{}, 0) + tickerCh := make(chan time.Time) + patch.ApplyFunc(time.NewTicker, func(_ time.Duration) *time.Ticker { + return &time.Ticker{C: tickerCh} + }) + + go CheckRedisConnectivity(¶m, nil, stopCh) + tickerCh <- time.Time{} + stopCh <- struct{}{} + convey.So(isCalled, convey.ShouldEqual, 1) + patch1.Reset() + patch.ApplyFunc(New, func(newClientParam NewRedisClientParam, stopCh <-chan struct{}, options ...Option) (*Client, error) { + isCalled++ + return &Client{}, errors.New("state is not ready") + }) + stopCh = make(chan struct{}, 0) + go CheckRedisConnectivity(¶m, nil, stopCh) + tickerCh <- time.Time{} + tickerCh <- time.Time{} + stopCh <- struct{}{} + convey.So(isCalled, convey.ShouldEqual, 3) + + CheckRedisConnectivity(¶m, nil, nil) + convey.So(isCalled, convey.ShouldEqual, 3) + }) +} + +func TestClient_Del(t *testing.T) { + client := &Client{ + client: &redis.Client{}, + } + ctx := context.Background() + keys := []string{"key1", "key2"} + + mockResult := &redis.IntCmd{} + patches := gomonkey.ApplyMethod( + reflect.TypeOf(client.client), "Del", + func(_ redis.Cmdable, _ context.Context, _ ...string) *redis.IntCmd { + return mockResult + }, + ) + defer patches.Reset() + + result := client.Del(ctx, keys...) + + assert.Equal(t, mockResult, result) +} + +func TestClient_Get(t *testing.T) { + client := &Client{ + client: &redis.Client{}, + } + ctx := context.Background() + key := "key2" + + mockResult := &redis.StringCmd{} + patches := gomonkey.ApplyMethod( + reflect.TypeOf(client.client), "Get", + func(_ redis.Cmdable, _ context.Context, _ string) *redis.StringCmd { + return mockResult + }, + ) + defer patches.Reset() + + result := client.Get(ctx, key) + + assert.Equal(t, mockResult, result) +} +func TestClient_Ping(t *testing.T) { + client := &Client{ + client: &redis.Client{}, + } + ctx := context.Background() + + mockResult := &redis.StatusCmd{} + patches := gomonkey.ApplyMethod( + reflect.TypeOf(client.client), "Ping", + func(_ redis.Cmdable, _ context.Context) *redis.StatusCmd { + return mockResult + }, + ) + defer patches.Reset() + + result := client.Ping(ctx) + + assert.Equal(t, mockResult, result) +} diff --git a/yuanrong/pkg/common/faas_common/resspeckey/type.go b/yuanrong/pkg/common/faas_common/resspeckey/type.go new file mode 100644 index 0000000..c11747a --- /dev/null +++ b/yuanrong/pkg/common/faas_common/resspeckey/type.go @@ -0,0 +1,120 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package resspeckey - +package resspeckey + +import ( + "encoding/json" + "fmt" + "sort" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" +) + +// ResourceSpecification contains resource specification of a requested instance +type ResourceSpecification struct { + CPU int64 `json:"cpu"` + Memory int64 `json:"memory"` + InvokeLabel string `json:"invokeLabels"` + CustomResources map[string]int64 `json:"customResources"` + CustomResourcesSpec map[string]interface{} `json:"customResourcesSpec"` + EphemeralStorage int `json:"ephemeral_storage"` +} + +// DeepCopy return a ResourceSpecification Copy +func (rs *ResourceSpecification) DeepCopy() *ResourceSpecification { + customResource := map[string]int64{} + for k, v := range rs.CustomResources { + customResource[k] = v + } + customResourcesSpec := map[string]interface{}{} + for k, v := range rs.CustomResourcesSpec { + customResourcesSpec[k] = v + } + return &ResourceSpecification{ + CPU: rs.CPU, + Memory: rs.Memory, + CustomResources: customResource, + InvokeLabel: rs.InvokeLabel, + CustomResourcesSpec: customResourcesSpec, + EphemeralStorage: rs.EphemeralStorage, + } +} + +// String returns ResourceSpecification as string +func (rs *ResourceSpecification) String() string { + resourceExpression := fmt.Sprintf("cpu-%d-mem-%d", rs.CPU, rs.Memory) + for key, value := range rs.CustomResources { + if value <= constant.MinCustomResourcesSize { + continue + } + resourceExpression += fmt.Sprintf("-%s-%d", key, value) + } + keys := make([]string, 0, len(rs.CustomResourcesSpec)) + for k := range rs.CustomResourcesSpec { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + v := rs.CustomResourcesSpec[k] + resourceExpression += fmt.Sprintf("-%s-%v", k, v) + } + if rs.InvokeLabel != "" { + resourceExpression += fmt.Sprintf("-invoke-label-%s", rs.InvokeLabel) + } + resourceExpression += fmt.Sprintf("-ephemeral-storage-%v", rs.EphemeralStorage) + return resourceExpression +} + +// ResSpecKey is a representation of ResourceSpecification which can be used as key of map +type ResSpecKey struct { + CPU int64 + Memory int64 + EphemeralStorage int + CustomResources string + CustomResourcesSpec string + InvokeLabel string +} + +// String returns ResSpecKey as string +func (rsk *ResSpecKey) String() string { + return fmt.Sprintf("cpu-%d-mem-%d-storage-%d-cstRes-%s-cstResSpec-%s-invokeLabel-%s", rsk.CPU, rsk.Memory, + rsk.EphemeralStorage, rsk.CustomResources, rsk.CustomResourcesSpec, rsk.InvokeLabel) +} + +// ToResSpec convert ResSpecKey to ResourceSpecification +func (rsk *ResSpecKey) ToResSpec() *ResourceSpecification { + cstRes := map[string]int64{} + err := json.Unmarshal([]byte(rsk.CustomResources), &cstRes) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal to customResources error %s", err.Error()) + } + cstResSpec := map[string]interface{}{} + err = json.Unmarshal([]byte(rsk.CustomResourcesSpec), &cstResSpec) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal to customResourceSpec error %s", err.Error()) + } + return &ResourceSpecification{ + CPU: rsk.CPU, + Memory: rsk.Memory, + EphemeralStorage: rsk.EphemeralStorage, + CustomResources: cstRes, + CustomResourcesSpec: cstResSpec, + InvokeLabel: rsk.InvokeLabel, + } +} diff --git a/yuanrong/pkg/common/faas_common/resspeckey/util.go b/yuanrong/pkg/common/faas_common/resspeckey/util.go new file mode 100644 index 0000000..e34c8d7 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/resspeckey/util.go @@ -0,0 +1,103 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package resspeckey - +package resspeckey + +import ( + "encoding/json" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" +) + +const ( + ascendResourceD910B = "huawei.com/ascend-1980" + ascendResourceD910BInstanceType = "instanceType" +) + +// ConvertToResSpecKey converts ResourceSpecification to ResSpecKey +func ConvertToResSpecKey(resSpec *ResourceSpecification) ResSpecKey { + // for Go 1.7+ version, json.Marshal sorts the keys of map, same kv pairs will get same serialization result + var ( + cstResExp string + cstResSpecExp string + ) + if resSpec.CustomResources != nil && len(resSpec.CustomResources) != 0 { + cstResBytes, err := json.Marshal(resSpec.CustomResources) + if err != nil { + log.GetLogger().Errorf("failed to marshal customResources %#v error %s", resSpec.CustomResources, err.Error()) + } + cstResExp = string(cstResBytes) + } + if len(cstResExp) != 0 && resSpec.CustomResourcesSpec != nil && len(resSpec.CustomResourcesSpec) != 0 { + cstResSpecBytes, err := json.Marshal(resSpec.CustomResourcesSpec) + if err != nil { + log.GetLogger().Errorf("failed to marshal customResourcesSpec %#v error %s", resSpec.CustomResourcesSpec, + err.Error()) + } + cstResSpecExp = string(cstResSpecBytes) + } + return ResSpecKey{ + CPU: resSpec.CPU, + Memory: resSpec.Memory, + EphemeralStorage: resSpec.EphemeralStorage, + CustomResources: cstResExp, + CustomResourcesSpec: cstResSpecExp, + InvokeLabel: resSpec.InvokeLabel, + } +} + +// GetResKeyFromStr - +func GetResKeyFromStr(note string) (ResSpecKey, error) { + resSpec := &ResourceSpecification{} + err := json.Unmarshal([]byte(note), resSpec) + if err != nil { + return ResSpecKey{}, err + } + return ConvertToResSpecKey(resSpec), nil +} + +// ConvertResourceMetaDataToResSpec will convert resource metadata +func ConvertResourceMetaDataToResSpec(resMeta types.ResourceMetaData) *ResourceSpecification { + customResources := map[string]int64{} + if resMeta.CustomResources != "" { + if err := json.Unmarshal([]byte(resMeta.CustomResources), &customResources); err != nil { + log.GetLogger().Warnf("failed to unmarshal custom resources %s, err: %s", + resMeta.CustomResources, err.Error()) + } + } + customResourcesSpec := make(map[string]interface{}) + // npu tag may be unspecified and be updated to 376T, default value is needed to be set, otherwise reserved instance + // will be recreated + err := json.Unmarshal([]byte(resMeta.CustomResourcesSpec), &customResourcesSpec) + if resMeta.CustomResourcesSpec != "" && err != nil { + log.GetLogger().Warnf("failed to unmarshal custom resourcesSpec: %s, err: %s", + resMeta.CustomResourcesSpec, err.Error()) + } + if _, ok := customResources[ascendResourceD910B]; ok { + if _, ok := customResourcesSpec[ascendResourceD910BInstanceType]; !ok { + customResourcesSpec[ascendResourceD910BInstanceType] = "376T" + } + } + return &ResourceSpecification{ + CPU: resMeta.CPU, + Memory: resMeta.Memory, + CustomResources: customResources, + CustomResourcesSpec: customResourcesSpec, + EphemeralStorage: resMeta.EphemeralStorage, + } +} diff --git a/yuanrong/pkg/common/faas_common/resspeckey/util_test.go b/yuanrong/pkg/common/faas_common/resspeckey/util_test.go new file mode 100644 index 0000000..007be96 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/resspeckey/util_test.go @@ -0,0 +1,52 @@ +package resspeckey + +import ( + "encoding/json" + "testing" + + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/types" +) + +func TestResSpecKey(t *testing.T) { + resSpec := &ResourceSpecification{ + CPU: 100, + Memory: 100, + CustomResources: map[string]int64{"NPU": 1}, + CustomResourcesSpec: map[string]interface{}{"Type": "type1"}, + InvokeLabel: "label1", + } + resKey := ConvertToResSpecKey(resSpec) + resKeyString := resKey.String() + assert.Equal(t, "cpu-100-mem-100-storage-0-cstRes-{\"NPU\":1}-cstResSpec-{\"Type\":\"type1\"}-invokeLabel-label1", resKeyString) + resSpec1 := resKey.ToResSpec() + assert.Equal(t, int64(100), resSpec1.CPU) + assert.Equal(t, int64(100), resSpec1.Memory) + assert.Equal(t, "label1", resSpec1.InvokeLabel) +} + +func TestConvertResourceMetaData(t *testing.T) { + convey.Convey("test ConvertResourceMetaData", t, func() { + convey.Convey("Unmarshal error", func() { + resMeta := types.ResourceMetaData{ + CustomResourcesSpec: "huawei.com/ascend-1980:D910B", + CustomResources: "", + } + resource := ConvertResourceMetaDataToResSpec(resMeta) + convey.So(len(resource.CustomResources), convey.ShouldEqual, 0) + }) + convey.Convey("Convert success", func() { + customResources := map[string]int64{"huawei.com/ascend-1980": 10} + data, _ := json.Marshal(customResources) + resMeta := types.ResourceMetaData{ + CustomResourcesSpec: "CustomResourcesSpec", + CustomResources: string(data), + } + resource := ConvertResourceMetaDataToResSpec(resMeta) + convey.So(resource.CustomResourcesSpec[ascendResourceD910BInstanceType], + convey.ShouldEqual, "376T") + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/signals/signal.go b/yuanrong/pkg/common/faas_common/signals/signal.go new file mode 100644 index 0000000..8a43d53 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/signals/signal.go @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package signals - +package signals + +import ( + "os" + "os/signal" + "syscall" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +var ( + shutdownSignals = []os.Signal{os.Interrupt, syscall.SIGTERM, syscall.SIGKILL} + onlyOneSignalHandler = make(chan struct{}) + shutdownHandler chan os.Signal + stopCh = make(chan struct{}) +) + +const channelCount = 2 + +func init() { + // 2 is the length of shutdown Handler channel + shutdownHandler = make(chan os.Signal, channelCount) + + signal.Notify(shutdownHandler, shutdownSignals...) + + go func() { + <-shutdownHandler + close(stopCh) + <-shutdownHandler + log.GetLogger().Sync() + os.Exit(1) + }() +} + +// WaitForSignal defines signal handler process. +func WaitForSignal() <-chan struct{} { + return stopCh +} diff --git a/yuanrong/pkg/common/faas_common/signals/signal_test.go b/yuanrong/pkg/common/faas_common/signals/signal_test.go new file mode 100644 index 0000000..7314a2a --- /dev/null +++ b/yuanrong/pkg/common/faas_common/signals/signal_test.go @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package signals + +import ( + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestWaitForSignal(t *testing.T) { + stopCh := WaitForSignal() + + go func() { + time.Sleep(200 * time.Millisecond) + shutdownHandler <- syscall.SIGTERM + }() + select { + case <-stopCh: + t.Log("received termination signal") + case <-time.After(time.Second): + t.Fatal("failed to signal in 1s") + } + + _, ok := <-stopCh + assert.Equal(t, ok, false) +} diff --git a/yuanrong/pkg/common/faas_common/snerror/snerror.go b/yuanrong/pkg/common/faas_common/snerror/snerror.go new file mode 100644 index 0000000..1410d12 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/snerror/snerror.go @@ -0,0 +1,86 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package snerror is basic information contained in the SN error. +package snerror + +const ( + // UserErrorMax is maximum value of user error + UserErrorMax = 4999 + // UserErrorMin is minimal value of user error + UserErrorMin = 4000 + // ErrorSeparator split error codes and error information. + ErrorSeparator = "|" +) + +// BadResponse HTTP request message that does not return 200 +type BadResponse struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// SNError defines the action contained in the SN error information. +type SNError interface { + // Code Returned error code + Code() int + + Error() string +} + +type snError struct { + code int + message string +} + +// New returns an error. +// message is a complete English sentence with punctuation. +func New(code int, message string) SNError { + return &snError{ + code: code, + message: message, + } +} + +// NewWithError err not nil. +func NewWithError(code int, err error) SNError { + var message = "" + if err != nil { + message = err.Error() + } + return &snError{ + code: code, + message: message, + } +} + +// Code Returned error code +func (s *snError) Code() int { + return s.code +} + +// Error Implement the native error interface. +func (s *snError) Error() string { + return s.message +} + +// IsUserError true if a user error occurs +func IsUserError(s SNError) bool { + // The user error is a four-digit integer. + if UserErrorMin <= s.Code() && s.Code() <= UserErrorMax { + return true + } + return false +} diff --git a/yuanrong/pkg/common/faas_common/snerror/snerror_test.go b/yuanrong/pkg/common/faas_common/snerror/snerror_test.go new file mode 100644 index 0000000..514edd7 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/snerror/snerror_test.go @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package snerror - +package snerror + +import ( + "fmt" + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestSnError(t *testing.T) { + convey.Convey("New", t, func() { + snErr := New(1000, "test error") + convey.So(snErr.Code(), convey.ShouldEqual, 1000) + convey.So(snErr.Error(), convey.ShouldEqual, "test error") + res := IsUserError(snErr) + convey.So(res, convey.ShouldEqual, false) + }) + convey.Convey("NewWithError", t, func() { + snErr := NewWithError(1000, fmt.Errorf("test error")) + convey.So(snErr.Code(), convey.ShouldEqual, 1000) + convey.So(snErr.Error(), convey.ShouldEqual, "test error") + }) +} diff --git a/yuanrong/pkg/common/faas_common/state/observer.go b/yuanrong/pkg/common/faas_common/state/observer.go new file mode 100644 index 0000000..ce568f2 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/state/observer.go @@ -0,0 +1,109 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package state - +package state + +import ( + "context" + "fmt" + + "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/etcd3" +) + +// Observer - +type Observer interface { + Update(value interface{}, tags ...string) // add & update state to datasystem +} + +// Queue is used to cache the state processing queue +type Queue struct { + client *etcd3.EtcdClient + queue chan stateData +} + +// stateData is a state input parameter structure +type stateData struct { + data interface{} + tags []string +} + +const ( + maxQueueSize = 10000 + defaultQueueSize = 1000 +) + +// NewStateQueue - +func NewStateQueue(size int) *Queue { + if size > maxQueueSize || size <= 0 { + size = defaultQueueSize + } + client := etcd3.GetRouterEtcdClient() + if client == nil { + return nil + } + return &Queue{ + queue: make(chan stateData, size), + client: client, + } +} + +// SaveState - +func (q *Queue) SaveState(state []byte, key string) error { + ctx := etcd3.CreateEtcdCtxInfoWithTimeout(context.Background(), etcd3.DurationContextTimeout) + return q.client.Put(ctx, key, string(state)) +} + +// GetState - get state from etcd with key +func (q *Queue) GetState(key string) ([]byte, error) { + ctx := etcd3.CreateEtcdCtxInfoWithTimeout(context.Background(), etcd3.DurationContextTimeout) + response, err := q.client.GetResponse(ctx, key, clientv3.WithSerializable()) + if err != nil { + return nil, err + } + if len(response.Kvs) == 0 { + return nil, fmt.Errorf("get empty state from etcd") + } + return response.Kvs[0].Value, nil +} + +// DeleteState - +func (q *Queue) DeleteState(key string) error { + ctx := etcd3.CreateEtcdCtxInfoWithTimeout(context.Background(), etcd3.DurationContextTimeout) + return q.client.Delete(ctx, key, clientv3.WithPrefix()) +} + +// Push - +func (q *Queue) Push(value interface{}, tags ...string) error { + select { + case q.queue <- stateData{ + data: value, + tags: tags, + }: + return nil + default: + return fmt.Errorf("state queue is full, can not write data") + } +} + +// Run - +func (q *Queue) Run(handler func(value interface{}, tags ...string)) { + for state := range q.queue { + handler(state.data, state.tags...) + } +} diff --git a/yuanrong/pkg/common/faas_common/state/observer_test.go b/yuanrong/pkg/common/faas_common/state/observer_test.go new file mode 100644 index 0000000..ae15106 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/state/observer_test.go @@ -0,0 +1,72 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package state - +package state + +import ( + "testing" + + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "yuanrong/pkg/common/faas_common/etcd3" +) + +func TestNewStateQueue(t *testing.T) { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + convey.Convey("get queue", t, func() { + q := NewStateQueue(10) + q.queue <- stateData{} + convey.So(len(q.queue), convey.ShouldEqual, 1) + }) + convey.Convey("get queue", t, func() { + q := NewStateQueue(-1) + q.queue <- stateData{} + convey.So(len(q.queue), convey.ShouldEqual, 1) + }) +} + +func TestStateOperation(t *testing.T) { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + q := NewStateQueue(10) + convey.Convey("save state", t, func() { + defer gomonkey.ApplyFunc((*etcd3.EtcdClient).Put, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, value string, opts ...clientv3.OpOption) error { + return nil + }).Reset() + err := q.SaveState(nil, "testKey") + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("get state", t, func() { + defer gomonkey.ApplyFunc((*etcd3.EtcdClient).GetResponse, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return &clientv3.GetResponse{Kvs: []*mvccpb.KeyValue{{Key: nil, Value: nil}}}, nil + }).Reset() + _, err := q.GetState("testKey") + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("get state", t, func() { + err := q.Push("someData", "someKey") + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/pkg/common/faas_common/statuscode/statuscode.go b/yuanrong/pkg/common/faas_common/statuscode/statuscode.go new file mode 100644 index 0000000..7db7b41 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/statuscode/statuscode.go @@ -0,0 +1,490 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package statuscode define status code of Frontend +package statuscode + +import ( + "errors" + "net/http" + "regexp" + "strconv" + "strings" + + "github.com/valyala/fasthttp" +) + +// system error code +const ( + // InnerResponseSuccessCode - + InnerResponseSuccessCode = 0 + // InternalErrorCode if the value is 331404, try again. + InternalErrorCode = 330404 + InternalRetryErrCode = 331404 + InternalErrorMessage = "internal system error" + + // InnerInstanceCircuitCode need retry + InnerInstanceCircuitCode = 4011 + + // BackpressureCode indicate that frontend should choose another proxy/worker and retry + BackpressureCode = 211429 +) + +// frontend error code +const ( + // FrontendStatusOk ok code + FrontendStatusOk = 200200 + // FrontendStatusAccepted - + FrontendStatusAccepted = 200202 + // FrontendStatusNoContent - + FrontendStatusNoContent = 200204 + + // FrontendStatusBadRequest - + FrontendStatusBadRequest = 200400 + // FrontendStatusUnAuthorized - + FrontendStatusUnAuthorized = 200401 + // FrontendStatusForbidden - + FrontendStatusForbidden = 200403 + // FrontendStatusNotFound - + FrontendStatusNotFound = 200404 + // FrontendStatusRequestEntityTooLarge - + FrontendStatusRequestEntityTooLarge = 200413 + // FrontendStatusTooManyRequests - + FrontendStatusTooManyRequests = 200429 + + // FrontendStatusInternalError - + FrontendStatusInternalError = 200500 + // HTTPStreamNOTEnableError - + HTTPStreamNOTEnableError = 200600 + // CreateStreamProducerError - + CreateStreamProducerError = 200601 + // QueryStreamCustomerError - + QueryStreamCustomerError = 200602 + // SendDataToStreamError - + SendDataToStreamError = 200603 + // WriteResponseError - + WriteResponseError = 200604 + + // DsUploadFailed - upload to data system failed + DsUploadFailed = 200701 + // DsDownloadFailed - download from data system failed + DsDownloadFailed = 200702 + // DsDeleteFailed - delete from data system failed + DsDeleteFailed = 200703 + // DsKeyNotFound - key not found on data system + DsKeyNotFound = 200704 + + // UserFunctionInvokeError - user function error + UserFunctionInvokeError = 200705 + + // FuncMetaNotFound function meta not found, this error occurs only when the internal service is abnormal. + FuncMetaNotFound = 150424 + // HeavyLoadCode indicate the server's memory usage reaches threshold + HeavyLoadCode = 214503 +) + +// User error code +const ( + // UserFuncEntryNotFoundErrCode - + UserFuncEntryNotFoundErrCode = 4001 + // UserFuncRunningExceptionErrCode - + UserFuncRunningExceptionErrCode = 4002 + // StateContentTooLargeErrCode state content is too large + StateContentTooLargeErrCode = 4003 + // UserFuncRspExceedLimitErrCode response of user function exceeds the platform limit + UserFuncRspExceedLimitErrCode = 4004 + // UndefinedStateErrCode state is undefined + UndefinedStateErrCode = 4005 + // HeartBeatFunctionInvalidErrCode heart beat function of user invalid + HeartBeatFunctionInvalidErrCode = 4006 + // FunctionResultInvalidErrCode user function result is invalid + FunctionResultInvalidErrCode = 4007 + // InitializeFunctionErrorErrCode user initialize function error + InitializeFunctionErrorErrCode = 4009 + // UserFuncInvokeTimeout - + UserFuncInvokeTimeout = 4010 + // FrontendStatusWorkerIoTimeout - + FrontendStatusWorkerIoTimeout = 4014 + // FrontendStatusTrafficLimitEffective is the error code for traffic limitation + FrontendStatusTrafficLimitEffective = 4021 + // FrontendStatusLabelUnavailable - + FrontendStatusLabelUnavailable = 4022 + // FrontendStatusFuncMetaNotFound is error code of function meta not found + FrontendStatusFuncMetaNotFound = 4024 + // FrontendStatusUnableSpecifyResource unable to specify resource in a scene where no resource specified + FrontendStatusUnableSpecifyResource = 4026 + // FrontendStatusMaxRequestBodySize - + FrontendStatusMaxRequestBodySize = 4140 + // UserFuncInitFailCode code of user function initialization failed + UserFuncInitFailCode = 4201 + // ErrSharedMemoryLimited - + ErrSharedMemoryLimited = 4202 + // ErrOperateDiskFailed - + ErrOperateDiskFailed = 4203 + // ErrInsufficientDiskSpace - + ErrInsufficientDiskSpace = 4204 + + // UserFuncInitTimeoutCode code of initialing runtime timed out + UserFuncInitTimeoutCode = 4211 + // StsConfigErrCode sts config set error code + StsConfigErrCode = 4036 + // InstanceSessionInvalidErrCode - + InstanceSessionInvalidErrCode = 4037 + // ErrFinalized - + ErrFinalized = 9000 + // ErrAllSchedulerUnavailable - + ErrAllSchedulerUnavailable = 9009 + // InnerUserErrBase - + InnerUserErrBase = 50_0000 + // InnerRuntimeInitTimeoutCode - + InnerRuntimeInitTimeoutCode = InnerUserErrBase + UserFuncInitTimeoutCode +) + +// proxy internal error codes which suggests to retry in cluster +const ( + // ClientExitErrCode function instance is exiting (proxy side) + ClientExitErrCode = 211503 + + // WorkerExitErrCode function instance is exiting (worker side) + WorkerExitErrCode = 211504 + + // UserFuncIsUpdatedCode - + UserFuncIsUpdatedCode = 211411 + // SendReqErrCode call request sending error + SendReqErrCode = 211406 +) + +// executor error code +const ( + // ExecutorErrCodeInitFail - + ExecutorErrCodeInitFail = 6001 +) + +// The kernel and faaspattern should maintain an appropriate set of error codes. +// Common, such as a unified understanding of whether retry is required. +// In addition, the current transmission involves various character string conversions, +// which increases transcoding and matching barriers and causes high overheads. +// These are important, otherwise it will cause a lot of unclear boundaries and rework :) +const ( + // ErrInstanceNotFound - + ErrInstanceNotFound = 1003 + // ErrInstanceExitedCode - + ErrInstanceExitedCode = 1007 + // ErrInstanceCircuitCode - + ErrInstanceCircuitCode = 1009 + // ErrInstanceEvicted - + ErrInstanceEvicted = 1013 + + // ErrRequestBetweenRuntimeBusCode - + ErrRequestBetweenRuntimeBusCode = 3001 + // ErrInnerCommunication - + ErrInnerCommunication = 3002 + // ErrRequestBetweenRuntimeFrontendCode - + ErrRequestBetweenRuntimeFrontendCode = 3008 + // ErrAcquireTimeoutCode - + ErrAcquireTimeoutCode = 3009 +) + +// errors comes from faas scheduler (FG worker manager error) +const ( + // StatusInternalServerError status internal server error + StatusInternalServerError = 150500 + // VIPClusterOverloadCode cluster has no available resource + VIPClusterOverloadCode = 150510 + // FuncMetaNotFoundErrCode function meta not found, this error occurs only when the internal service is abnormal. + FuncMetaNotFoundErrCode = 150424 + // FuncMetaNotFoundErrMsg is error message of function metadata not found + FuncMetaNotFoundErrMsg = "function metadata not found" + // InstanceNotFoundErrCode is error code of instance not found + InstanceNotFoundErrCode = 150425 + // InstanceNotFoundErrMsg is error message of instance not found + InstanceNotFoundErrMsg = "instance not exist" + // NoInstanceAvailableErrCode is error message of no available instance + NoInstanceAvailableErrCode = 150431 + // InstanceStatusAbnormalCode - + InstanceStatusAbnormalCode = 150427 + // InstanceStatusAbnormalMsg - + InstanceStatusAbnormalMsg = "instance status is abnormal" + // ReachMaxInstancesCode reach function max instances + ReachMaxInstancesCode = 150429 + // ReachMaxInstancesErrMsg is error message of reach max instance + ReachMaxInstancesErrMsg = "reach max instance num" + // InsThdReqTimeoutCode acquire instance lease timeout, FG: cluster is overload and unavailable now + InsThdReqTimeoutCode = 150430 + // InsThdReqTimeoutErrMsg acquire instance lease timeout + InsThdReqTimeoutErrMsg = "instance thread request timeout" + // ReachMaxInstancesPerTenantErrCode reach tenant max on-demand instances + ReachMaxInstancesPerTenantErrCode = 150432 + // GettingPodErrorCode getting pod error code + GettingPodErrorCode = 150431 + // ReachMaxOnDemandInstancesPerTenant reach tenant max on-demand instances + ReachMaxOnDemandInstancesPerTenant = 150432 + // ReachMaxInstancesPerTenantErrMsg reach tenant max on-demand instances + ReachMaxInstancesPerTenantErrMsg = "reach max instance number per tenant" + // ReachMaxReversedInstancesPerTenant reach tenant max reversed instances + ReachMaxReversedInstancesPerTenant = 150433 + // FunctionIsDisabled function is disabled + FunctionIsDisabled = 150434 + // RefreshSilentFunc waiting for silent function to refresh, retry required + RefreshSilentFunc = 150435 + // NotEnoughNIC marked that there were not enough network cards + NotEnoughNIC = 150436 + // InsufficientEphemeralStorage marked that ephemeral storage is insufficient + InsufficientEphemeralStorage = 150438 + // ClusterIsUpgrading - + ClusterIsUpgrading = 150439 + // DesignateInsNotAvailableErrCode - + DesignateInsNotAvailableErrCode = 150440 + // InstanceLabelNotFoundErrCode - + InstanceLabelNotFoundErrCode = 150444 + // InstanceLabelNotFoundErrMsg - + InstanceLabelNotFoundErrMsg = "instance label not found" + // CancelGeneralizePod user update function metadata to cancel generalize pod while generalizing is not finished + CancelGeneralizePod = 150439 + + // ScaleUpRequestErrCode failed to send scale up request to worker-manager + ScaleUpRequestErrCode = 214501 + // ScaleUpRequestErrMsg - + ScaleUpRequestErrMsg = "send scale up request to worker-manager error" + + // SpecificInstanceNotFound - + SpecificInstanceNotFound = 150460 + // InstanceExceedConcurrency - + InstanceExceedConcurrency = 150461 + + LeaseIDIllegalCode = 150462 + LeaseIDIllegalMsg = "lease id is illegal" + LeaseIDNotFoundCode = 150463 + LeaseIDNotFoundMsg = "lease id is not found" +) + +var ( + // ErrMap frontend code map to http code + // Only return 200 to the management interface if the execution is successful + ErrMap = map[int]int{ + // system error + InnerResponseSuccessCode: http.StatusOK, + InternalErrorCode: http.StatusInternalServerError, + // frontend error + FrontendStatusOk: http.StatusOK, + FrontendStatusAccepted: http.StatusAccepted, + FrontendStatusNoContent: http.StatusNoContent, + FrontendStatusBadRequest: http.StatusBadRequest, + FrontendStatusUnAuthorized: http.StatusUnauthorized, + FrontendStatusForbidden: http.StatusForbidden, + FrontendStatusNotFound: http.StatusNotFound, + FrontendStatusRequestEntityTooLarge: http.StatusRequestEntityTooLarge, + FrontendStatusTooManyRequests: http.StatusTooManyRequests, + FrontendStatusInternalError: http.StatusInternalServerError, + FuncMetaNotFound: http.StatusInternalServerError, + HeavyLoadCode: http.StatusInternalServerError, + FrontendStatusTrafficLimitEffective: http.StatusInternalServerError, + HTTPStreamNOTEnableError: http.StatusInternalServerError, + CreateStreamProducerError: http.StatusInternalServerError, + QueryStreamCustomerError: http.StatusInternalServerError, + SendDataToStreamError: http.StatusInternalServerError, + WriteResponseError: http.StatusInternalServerError, + // frontend caas / multidata error + // 500 + DsUploadFailed: http.StatusInternalServerError, + DsDownloadFailed: http.StatusInternalServerError, + DsDeleteFailed: http.StatusInternalServerError, + DsKeyNotFound: http.StatusInternalServerError, + UserFunctionInvokeError: http.StatusInternalServerError, + // user error + UserFuncEntryNotFoundErrCode: http.StatusInternalServerError, + UserFuncRunningExceptionErrCode: http.StatusInternalServerError, + UserFuncRspExceedLimitErrCode: http.StatusInternalServerError, + FrontendStatusMaxRequestBodySize: http.StatusInternalServerError, + FrontendStatusUnableSpecifyResource: http.StatusInternalServerError, + UserFuncInvokeTimeout: http.StatusInternalServerError, + UserFuncInitFailCode: http.StatusInternalServerError, + UserFuncInitTimeoutCode: http.StatusInternalServerError, + StsConfigErrCode: http.StatusInternalServerError, + // executor error + ExecutorErrCodeInitFail: http.StatusInternalServerError, + } +) + +const ( + // VpcNoOperationalPermissions vpc has no operational permissions + VpcNoOperationalPermissions = 4212 + // VPCNotFound error code of VPC not found + VPCNotFound = 4219 + // VPCXRoleNotFound vcp xrole not func + VPCXRoleNotFound = 4222 +) + +// vpc err comes from vpc controller +var ( + // ErrNoOperationalPermissionsVpc no operational permissions vpc + ErrNoOperationalPermissionsVpc = errors.New("no operational permissions vpc, check the func xrole permissions") + // ErrNoAvailableVpcPatInstance no available vpc pat instance + ErrNoAvailableVpcPatInstance = errors.New("no available vpc pat instance") + // ErrVPCNotFound VPC item not found error + ErrVPCNotFound = errors.New("vpc item not found") + // ErrVPCXRoleNotFound VPC xrole not found error + ErrVPCXRoleNotFound = errors.New("can't find xrole") + + vpcErrorMap = map[string]int{ + ErrNoOperationalPermissionsVpc.Error(): VpcNoOperationalPermissions, + ErrNoAvailableVpcPatInstance.Error(): NotEnoughNIC, + ErrVPCNotFound.Error(): VPCNotFound, + ErrVPCXRoleNotFound.Error(): VPCXRoleNotFound, + } + + vpcErrorCodeMsg = map[int]string{ + VpcNoOperationalPermissions: "no operational permissions vpc, check the func xrole permissions", + NotEnoughNIC: "not enough network cards", + VPCNotFound: "VPC item not found", + VPCXRoleNotFound: "VPC can't find xrole", + } +) + +const ( + // InvalidState - + InvalidState = 4040 + // InvalidStateErrMsg - + InvalidStateErrMsg = "invalid state, expect not blank" + // StateMismatch - + StateMismatch = 4006 + // StateMismatchErrMsg - + StateMismatchErrMsg = "invoke state id and function stateful flag are not matched" + // StateExistedErrCode - + StateExistedErrCode = 4027 + // StateExistedErrMsg - + StateExistedErrMsg = "state cannot be created repeatedly" + // StateNotExistedErrCode - + StateNotExistedErrCode = 4026 + // StateNotExistedErrMsg - + StateNotExistedErrMsg = "state not existed" + // StateInstanceNotExistedErrCode - + StateInstanceNotExistedErrCode = 4028 + // StateInstanceNotExistedErrMsg - + StateInstanceNotExistedErrMsg = "state instance not existed" + // StateInstanceNoLease - + StateInstanceNoLease = 4025 + // StateInstanceNoLeaseMsg - + StateInstanceNoLeaseMsg = "maximum number of leases reached" + // FaaSSchedulerInternalErrCode - + FaaSSchedulerInternalErrCode = 4029 + // FaaSSchedulerInternalErrMsg - + FaaSSchedulerInternalErrMsg = "internal system error" +) + +// worker error code +const ( + // WorkerInternalErrorCode code of unexpected error in worker + WorkerInternalErrorCode = 161900 + // ReadingCodeTimeoutCode reading code package timed out + ReadingCodeTimeoutCode = 161901 + // CallFunctionErrorCode code of calling other function error + CallFunctionErrorCode = 161902 + // FuncInsExceptionCode function instance exception + FuncInsExceptionCode = 161903 + // CheckSumErrorCode code of check sum error + CheckSumErrorCode = 161904 + // DownLoadCodeErrorCode code of download code error + DownLoadCodeErrorCode = 161905 + // RPCClientEmptyErrorCode code of when rpc client is nil + RPCClientEmptyErrorCode = 161906 + // RuntimeManagerProcessExited runtime-manager process exited code + RuntimeManagerProcessExited = 161907 + // WorkerPingVpcGatewayError code of worker ping vpc gateway error + WorkerPingVpcGatewayError = 161908 + // UploadSnapshotErrorCode code of worker upload snapshot error + UploadSnapshotErrorCode = 161909 + // RestoreDeadErrorCode code of restore is dead + RestoreDeadErrorCode = 161910 + // ContentInconsistentErrorCode code of worker content inconsistent error + ContentInconsistentErrorCode = 161911 + // CreateLimitErrorCode code of POSIX create limit error + CreateLimitErrorCode = 161912 + // KernelEtcdWriteFailedCode code of core write etcd failed or circuit + KernelEtcdWriteFailedCode = 161913 + // KernelResourceNotEnoughErrCode code of core resource not enough or schedule failure + KernelResourceNotEnoughErrCode = 161914 + // WiseCloudNuwaColdStartErrCode code of use nuwa cold start failed + WiseCloudNuwaColdStartErrCode = 161915 +) + +// Code trans frontend code to http code +func Code(frontendCode int) int { + httpCode, exist := ErrMap[frontendCode] + if !exist { + return http.StatusInternalServerError + } + return httpCode +} + +// Message trans frontend code to message +func Message(frontendCode int) string { + httpCode, exist := ErrMap[frontendCode] + if !exist { + return "" + } + + return fasthttp.StatusMessage(int(httpCode)) +} + +// VpcCode vpc controller err map to vpc err code +func VpcCode(errMsg string) int { + if errCode, ok := vpcErrorMap[errMsg]; ok { + return errCode + } + return 0 +} + +// VpcErMsg vpc err code map to err msg +func VpcErMsg(errCode int) string { + if errMsg, ok := vpcErrorCodeMsg[errCode]; ok { + return errMsg + } + return "" +} + +var ( + errCodeRegCompile = regexp.MustCompile("code:[ 0-9]+,") + errMsgRegCompile = regexp.MustCompile("message:.+") + codeRegCompile = regexp.MustCompile("[0-9]+") +) + +// GetKernelErrorCode will get kernel error code from error message +func GetKernelErrorCode(errMsg string) int { + res := errCodeRegCompile.FindStringSubmatch(errMsg) + if len(res) < 1 { + return InternalErrorCode + } + res = codeRegCompile.FindStringSubmatch(errMsg) + if len(res) != 1 { + return InternalErrorCode + } + code, err := strconv.Atoi(res[0]) + if err != nil { + return InternalErrorCode + } + return code +} + +// GetKernelErrorMessage will get kernel error message from error message +func GetKernelErrorMessage(errMsg string) string { + res := errMsgRegCompile.FindStringSubmatch(errMsg) + if len(res) < 1 { + return "" + } + trimRes := strings.TrimPrefix(res[0], "message: ") + return trimRes +} diff --git a/yuanrong/pkg/common/faas_common/statuscode/statuscode_test.go b/yuanrong/pkg/common/faas_common/statuscode/statuscode_test.go new file mode 100644 index 0000000..64e5fcb --- /dev/null +++ b/yuanrong/pkg/common/faas_common/statuscode/statuscode_test.go @@ -0,0 +1,86 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package statuscode define status code of Frontend +package statuscode + +import ( + "net/http" + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestStatusCode(t *testing.T) { + convey.Convey("get code", t, func() { + code := Code(InnerResponseSuccessCode) + convey.So(code, convey.ShouldEqual, http.StatusOK) + }) + convey.Convey("get message", t, func() { + msg := Message(InnerResponseSuccessCode) + convey.So(msg, convey.ShouldEqual, "OK") + }) + convey.Convey("error code get message", t, func() { + msg := Message(999999) + convey.So(msg, convey.ShouldEqual, "") + }) + convey.Convey("error code get message", t, func() { + code := Code(999999) + convey.So(code, convey.ShouldEqual, http.StatusInternalServerError) + }) +} + +func TestGetKernelErrorCode(t *testing.T) { + type args struct { + errMsg string + } + tests := []struct { + name string + args args + want int + }{ + {"case1 unknow error", args{errMsg: "unknown error"}, InternalErrorCode}, + {"case2 get code", args{errMsg: "code: 1007,"}, 1007}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetKernelErrorCode(tt.args.errMsg); got != tt.want { + t.Errorf("GetKernelErrorCode() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetKernelErrorMessage(t *testing.T) { + type args struct { + errMsg string + } + tests := []struct { + name string + args args + want string + }{ + {"case1 unknow message", args{errMsg: "unknown message"}, ""}, + {"case2 get message", args{errMsg: "message: yes"}, "yes"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetKernelErrorMessage(tt.args.errMsg); got != tt.want { + t.Errorf("GetKernelErrorMessage() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/yuanrong/pkg/common/faas_common/sts/cert/cert.go b/yuanrong/pkg/common/faas_common/sts/cert/cert.go new file mode 100644 index 0000000..003c1a2 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/sts/cert/cert.go @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package cert parsing certificate +package cert + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + + "golang.org/x/crypto/pkcs12" + "huawei.com/wisesecurity/sts-sdk/pkg/cryptosts" + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + + "yuanrong/pkg/common/faas_common/utils" +) + +// LoadCerts - parsing certificate +func LoadCerts() (*x509.CertPool, *tls.Certificate, error) { + keyStorePath, err := cryptosts.GetKeyStorePath() + if err != nil { + return nil, nil, err + } + caCertsPool := x509.NewCertPool() + bytes, err := stsgoapi.GetPassphrase() + if err != nil { + return nil, nil, err + } + fileContent, err := os.ReadFile(keyStorePath) + if err != nil { + return nil, nil, err + } + pemBlocks, err := pkcs12.ToPEM(fileContent, string(bytes)) + utils.ClearByteMemory(fileContent) + if err != nil { + return nil, nil, err + } + + caBytes, certByte, keyByte, err := parseSTSCerts(pemBlocks) + if err != nil { + return nil, nil, err + + } + for _, caByte := range caBytes { + caCertsPool.AppendCertsFromPEM(caByte) + } + tlsCert, err := tls.X509KeyPair(certByte, keyByte) + utils.ClearByteMemory(certByte) + utils.ClearByteMemory(keyByte) + if err != nil { + return nil, nil, err + + } + return caCertsPool, &tlsCert, nil +} + +func parseSTSCerts(pemBlocks []*pem.Block) ([][]byte, []byte, []byte, error) { + var certByte, keyByte []byte + var err error + var caBytes [][]byte + for _, pemBlock := range pemBlocks { + pemEncoded := pem.EncodeToMemory(pemBlock) + if pemBlock.Type == "PRIVATE KEY" { + keyByte = pemEncoded + } else { + var cert *x509.Certificate + if cert, err = x509.ParseCertificate(pemBlock.Bytes); err != nil { + return nil, nil, nil, err + } + if cert == nil { + return nil, nil, nil, fmt.Errorf("parse certificate err: cert is empty") + } + if cert.IsCA { + pemBlock.Headers = map[string]string{} + caBytes = append(caBytes, pem.EncodeToMemory(pemBlock)) + } else { + certByte = append(certByte, pemEncoded...) + } + } + } + if len(caBytes) == 0 { + return caBytes, certByte, keyByte, fmt.Errorf("ca certs not exists") + } + if len(certByte) == 0 { + return caBytes, certByte, keyByte, fmt.Errorf("certs not exists") + } + if len(keyByte) == 0 { + return caBytes, certByte, keyByte, fmt.Errorf("private key not exists") + } + return caBytes, certByte, keyByte, nil +} diff --git a/yuanrong/pkg/common/faas_common/sts/cert/cert_test.go b/yuanrong/pkg/common/faas_common/sts/cert/cert_test.go new file mode 100644 index 0000000..6167ba5 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/sts/cert/cert_test.go @@ -0,0 +1,238 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "os" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "golang.org/x/crypto/pkcs12" + "huawei.com/wisesecurity/sts-sdk/pkg/cryptosts" + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + + mockUtils "yuanrong/pkg/common/faas_common/utils" +) + +func TestLoadCerts(t *testing.T) { + tests := []struct { + name string + wantErr bool + patchesFunc mockUtils.PatchesFunc + }{ + {"case1 succeed to load certificates", false, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(cryptosts.GetKeyStorePath, func() (string, error) { + return "", nil + }), + gomonkey.ApplyFunc(x509.NewCertPool, func() *x509.CertPool { + return &x509.CertPool{} + }), + gomonkey.ApplyFunc(stsgoapi.GetPassphrase, func() (passphrase []byte, err error) { + return []byte{}, nil + }), + gomonkey.ApplyFunc(os.ReadFile, func(name string) ([]byte, error) { + return []byte{}, nil + }), + gomonkey.ApplyFunc(pkcs12.ToPEM, func(pfxData []byte, password string) ([]*pem.Block, error) { + return []*pem.Block{}, nil + }), + gomonkey.ApplyFunc(parseSTSCerts, func(pemBlocks []*pem.Block) ([][]byte, []byte, []byte, error) { + return [][]byte{[]byte("1")}, []byte("a"), []byte("b"), nil + }), + gomonkey.ApplyFunc(tls.X509KeyPair, func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) { + return tls.Certificate{}, nil + }), + gomonkey.ApplyMethod(reflect.TypeOf(&x509.CertPool{}), "AppendCertsFromPEM", + func(_ *x509.CertPool, pemCerts []byte) (ok bool) { + return true + })}) + return patches + }}, + {"case2 failed to load certificates", true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(cryptosts.GetKeyStorePath, func() (string, error) { + return "", nil + }), + gomonkey.ApplyFunc(x509.NewCertPool, func() *x509.CertPool { + return &x509.CertPool{} + }), + gomonkey.ApplyFunc(stsgoapi.GetPassphrase, func() (passphrase []byte, err error) { + return []byte{}, nil + }), + gomonkey.ApplyFunc(os.ReadFile, func(name string) ([]byte, error) { + return []byte{}, nil + }), + gomonkey.ApplyFunc(pkcs12.ToPEM, func(pfxData []byte, password string) ([]*pem.Block, error) { + return []*pem.Block{}, nil + }), + gomonkey.ApplyFunc(parseSTSCerts, func(pemBlocks []*pem.Block) ([][]byte, []byte, []byte, error) { + return [][]byte{[]byte("1")}, []byte("a"), []byte("b"), nil + }), + gomonkey.ApplyFunc(tls.X509KeyPair, func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) { + return tls.Certificate{}, errors.New("error") + }), + gomonkey.ApplyMethod(reflect.TypeOf(&x509.CertPool{}), "AppendCertsFromPEM", + func(_ *x509.CertPool, pemCerts []byte) (ok bool) { + return true + })}) + return patches + }}, + {"case3 failed to load certificates", true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(cryptosts.GetKeyStorePath, func() (string, error) { + return "", nil + }), + gomonkey.ApplyFunc(x509.NewCertPool, func() *x509.CertPool { + return &x509.CertPool{} + }), + gomonkey.ApplyFunc(stsgoapi.GetPassphrase, func() (passphrase []byte, err error) { + return []byte{}, nil + }), + gomonkey.ApplyFunc(os.ReadFile, func(name string) ([]byte, error) { + return []byte{}, nil + }), + gomonkey.ApplyFunc(pkcs12.ToPEM, func(pfxData []byte, password string) ([]*pem.Block, error) { + return []*pem.Block{}, nil + }), + gomonkey.ApplyFunc(parseSTSCerts, func(pemBlocks []*pem.Block) ([][]byte, []byte, []byte, error) { + return [][]byte{[]byte("1")}, []byte("a"), []byte("b"), errors.New("error") + })}) + return patches + }}, + {"case4 failed to load certificates", true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(cryptosts.GetKeyStorePath, func() (string, error) { + return "", nil + }), + gomonkey.ApplyFunc(x509.NewCertPool, func() *x509.CertPool { + return &x509.CertPool{} + }), + gomonkey.ApplyFunc(stsgoapi.GetPassphrase, func() (passphrase []byte, err error) { + return []byte{}, nil + }), + gomonkey.ApplyFunc(os.ReadFile, func(name string) ([]byte, error) { + return []byte{}, nil + }), + gomonkey.ApplyFunc(pkcs12.ToPEM, func(pfxData []byte, password string) ([]*pem.Block, error) { + return []*pem.Block{}, errors.New("error") + })}) + return patches + }}, + {"case5 failed to load certificates", true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(cryptosts.GetKeyStorePath, func() (string, error) { + return "", nil + }), + gomonkey.ApplyFunc(x509.NewCertPool, func() *x509.CertPool { + return &x509.CertPool{} + }), + gomonkey.ApplyFunc(stsgoapi.GetPassphrase, func() (passphrase []byte, err error) { + return []byte{}, nil + }), + gomonkey.ApplyFunc(os.ReadFile, func(name string) ([]byte, error) { + return []byte{}, errors.New("error") + })}) + return patches + }}, + {"case6 failed to load certificates", true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(cryptosts.GetKeyStorePath, func() (string, error) { + return "", nil + }), + gomonkey.ApplyFunc(x509.NewCertPool, func() *x509.CertPool { + return &x509.CertPool{} + }), + gomonkey.ApplyFunc(stsgoapi.GetPassphrase, func() (passphrase []byte, err error) { + return []byte{}, errors.New("error") + })}) + return patches + }}, + {"case7 failed to load certificates", true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(cryptosts.GetKeyStorePath, func() (string, error) { + return "", errors.New("error") + })}) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + _, _, err := LoadCerts() + if (err != nil) != tt.wantErr { + t.Errorf("LoadCerts() error = %v, wantErr %v", err, tt.wantErr) + return + } + patches.ResetAll() + }) + } +} + +func Test_parseSTSCerts(t *testing.T) { + type args struct { + pemBlocks []*pem.Block + } + tests := []struct { + name string + args args + wantErr bool + patchesFunc mockUtils.PatchesFunc + }{ + {"case1 succeed to parse", args{pemBlocks: []*pem.Block{ + &pem.Block{Type: "PRIVATE KEY"}, &pem.Block{}, &pem.Block{Bytes: []byte("a")}}}, + false, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(pem.EncodeToMemory, func(b *pem.Block) []byte { + return []byte("a") + }), + gomonkey.ApplyFunc(x509.ParseCertificate, func(der []byte) (*x509.Certificate, error) { + if string(der) == "a" { + return &x509.Certificate{}, nil + } + return &x509.Certificate{IsCA: true}, nil + }), + }) + return patches + }}, + {"case2 failed to parse", args{pemBlocks: []*pem.Block{ + &pem.Block{Type: "PRIVATE KEY"}}}, + true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(pem.EncodeToMemory, func(b *pem.Block) []byte { + return []byte("a") + }), + gomonkey.ApplyFunc(x509.ParseCertificate, func(der []byte) (*x509.Certificate, error) { + if string(der) == "a" { + return &x509.Certificate{}, nil + } + return &x509.Certificate{IsCA: true}, nil + }), + }) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + _, _, _, err := parseSTSCerts(tt.args.pemBlocks) + if (err != nil) != tt.wantErr { + t.Errorf("parseSTSCerts() error = %v, wantErr %v", err, tt.wantErr) + return + } + patches.ResetAll() + }) + } +} diff --git a/yuanrong/pkg/common/faas_common/sts/common.go b/yuanrong/pkg/common/faas_common/sts/common.go new file mode 100644 index 0000000..fe07259 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/sts/common.go @@ -0,0 +1,214 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package sts - +package sts + +import ( + "encoding/json" + "fmt" + + "huawei.com/wisesecurity/sts-sdk/pkg/cloudsoa" + "k8s.io/api/core/v1" + + "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/common/faas_common/utils" +) + +// SecretConfig - +type SecretConfig struct{} + +const ( + // FaasfrontendName - + FaasfrontendName = "faasfrontend" + // FaaSSchedulerName - + FaaSSchedulerName = "faasscheduler" + mountPath = "/opt/certs/HMSClientCloudAccelerateService/HMSCaaSYuanRongWorker/" + faasSchedulerMountPath = "/opt/certs/HMSClientCloudAccelerateService/HMSCaaSYuanRongWorkerManager/" + // HTTPSMountPath mount https certs + HTTPSMountPath = "/home/sn/resource/https" + // LocalSecretMountPath mount local secrets + LocalSecretMountPath = "/home/sn/resource/cipher" +) + +var readOnlyVolumeMode int32 = 0440 + +// ConfigVolume - +func (u *SecretConfig) ConfigVolume(b *utils.VolumeBuilder) { + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-config", + MountPath: mountPath + "HMSCaaSYuanRongWorker/apple/a", + SubPath: "a", + }) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-config", + MountPath: mountPath + "HMSCaaSYuanRongWorker/boy/b", + SubPath: "b", + }) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-config", + MountPath: mountPath + "HMSCaaSYuanRongWorker/cat/c", + SubPath: "c", + }) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-config", + MountPath: mountPath + "HMSCaaSYuanRongWorker/dog/d", + SubPath: "d", + }) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-config", + MountPath: mountPath + "HMSCaaSYuanRongWorker.ini", + SubPath: "HMSCaaSYuanRongWorker.ini", + }) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-config", + MountPath: mountPath + "HMSCaaSYuanRongWorker.sts.p12", + SubPath: "HMSCaaSYuanRongWorker.sts.p12", + }) +} + +// ConfigFaasSchedulerVolume - +func (u *SecretConfig) ConfigFaasSchedulerVolume(b *utils.VolumeBuilder) { + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-workermanager-config", + MountPath: faasSchedulerMountPath + "HMSCaaSYuanRongWorkerManager/apple/a", + SubPath: "a", + }) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-workermanager-config", + MountPath: faasSchedulerMountPath + "HMSCaaSYuanRongWorkerManager/boy/b", + SubPath: "b", + }) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-workermanager-config", + MountPath: faasSchedulerMountPath + "HMSCaaSYuanRongWorkerManager/cat/c", + SubPath: "c", + }) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-workermanager-config", + MountPath: faasSchedulerMountPath + "HMSCaaSYuanRongWorkerManager/dog/d", + SubPath: "d", + }) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-workermanager-config", + MountPath: faasSchedulerMountPath + "HMSCaaSYuanRongWorkerManager.ini", + SubPath: "HMSCaaSYuanRongWorkerManager.ini", + }) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "sts-workermanager-config", + MountPath: faasSchedulerMountPath + "HMSCaaSYuanRongWorkerManager.sts.p12", + SubPath: "HMSCaaSYuanRongWorkerManager.sts.p12", + }) +} + +// ConfigHTTPSAndLocalSecretVolume - +func (u *SecretConfig) ConfigHTTPSAndLocalSecretVolume(b *utils.VolumeBuilder, httpsConfig tls.InternalHTTPSConfig) { + b.AddVolume(buildVolumeOfSecretSource("https", httpsConfig.SecretName)) + b.AddVolumeMount(utils.ContainerRuntimeManager, v1.VolumeMount{ + Name: "https", + MountPath: httpsConfig.SSLBasePath, + }) +} + +func buildVolumeOfSecretSource(name string, secretName string) v1.Volume { + return v1.Volume{ + Name: name, + VolumeSource: v1.VolumeSource{ + Secret: &v1.SecretVolumeSource{ + DefaultMode: &readOnlyVolumeMode, + SecretName: secretName, + }, + }, + } +} + +// GenerateSecretVolumeMounts - +func GenerateSecretVolumeMounts(systemFunctionName string, builder *utils.VolumeBuilder) ([]byte, error) { + if builder == nil { + return nil, fmt.Errorf("sts volume builder is nil") + } + sc := &SecretConfig{} + if systemFunctionName == FaaSSchedulerName { + sc.ConfigFaasSchedulerVolume(builder) + } else { + sc.ConfigVolume(builder) + } + bytesData, err := json.Marshal(builder.Mounts[utils.ContainerRuntimeManager]) + if err != nil { + return nil, err + } + return bytesData, nil +} + +// CustomKeyProvider - +type CustomKeyProvider struct { + key []byte + tenantID string +} + +// NewCustomKeyProvider - +func NewCustomKeyProvider(tenantID string, key []byte) *CustomKeyProvider { + return &CustomKeyProvider{tenantID: tenantID, key: key} +} + +// GetKey - +func (c *CustomKeyProvider) GetKey(keyLabel cloudsoa.KeyLabel) ([]byte, int64, error) { + return c.key, 0, nil +} + +// GetHmacKey - +func (c *CustomKeyProvider) GetHmacKey(keyLabel cloudsoa.KeyLabel) cloudsoa.HmacKeyEntry { + return nil +} + +// GetKeyPair - +func (c *CustomKeyProvider) GetKeyPair(keyLabel cloudsoa.KeyLabel) cloudsoa.KeyPairEntry { + return nil +} + +// Load - +func (c *CustomKeyProvider) Load() { +} + +// GetName - +func (c *CustomKeyProvider) GetName() string { + return c.tenantID +} + +// GetKeyWithVersion - +func (c *CustomKeyProvider) GetKeyWithVersion(keyLabel cloudsoa.KeyLabel, version int64) ([]byte, error) { + return c.key, nil +} + +// GenerateHTTPSAndLocalSecretVolumeMounts - +func GenerateHTTPSAndLocalSecretVolumeMounts( + httpsConfig tls.InternalHTTPSConfig, builder *utils.VolumeBuilder) (string, string, error) { + if builder == nil { + return "", "", fmt.Errorf("https volume builder is nil") + } + sc := &SecretConfig{} + sc.ConfigHTTPSAndLocalSecretVolume(builder, httpsConfig) + + volumesData, err := json.Marshal(builder.Volumes) + if err != nil { + return "", "", err + } + volumesMountData, err := json.Marshal(builder.Mounts[utils.ContainerRuntimeManager]) + if err != nil { + return "", "", err + } + return string(volumesData), string(volumesMountData), nil +} diff --git a/yuanrong/pkg/common/faas_common/sts/common_test.go b/yuanrong/pkg/common/faas_common/sts/common_test.go new file mode 100644 index 0000000..8be8a21 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/sts/common_test.go @@ -0,0 +1,118 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package sts - +package sts + +import ( + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/magiconair/properties" + "github.com/smartystreets/goconvey/convey" + "huawei.com/wisesecurity/sts-sdk/pkg/cloudsoa" + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + + "yuanrong/pkg/common/faas_common/sts/raw" + "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/common/faas_common/utils" + mockUtils "yuanrong/pkg/common/faas_common/utils" +) + +func TestGenerateSecretVolumeMounts(t *testing.T) { + type args struct { + systemFunctionName string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"case1 faasshceduler generate", args{systemFunctionName: FaaSSchedulerName}, false}, + {"case2 faasfrontend generate", args{systemFunctionName: FaasfrontendName}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + build := utils.NewVolumeBuilder() + _, err := GenerateSecretVolumeMounts(tt.args.systemFunctionName, build) + if (err != nil) != tt.wantErr { + t.Errorf("GenerateSecretVolumeMounts() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestInitStsSDK(t *testing.T) { + type args struct { + serverCfg raw.ServerConfig + } + tests := []struct { + name string + args args + wantErr bool + patchesFunc mockUtils.PatchesFunc + }{ + {"case1", args{serverCfg: raw.ServerConfig{}}, false, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(stsgoapi.InitWith, func(property properties.Properties) error { return nil })}) + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc((*cloudsoa.AESCryptorBuilder).Builder, func( + _ *cloudsoa.AESCryptorBuilder) (*cloudsoa.AESCryptor, error) { + return nil, nil + })}) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + if err := InitStsSDK(tt.args.serverCfg); (err != nil) != tt.wantErr { + t.Errorf("InitStsSDK() error = %v, wantErr %v", err, tt.wantErr) + } + patches.ResetAll() + }) + } +} + +func TestCustomKeyProvider(t *testing.T) { + convey.Convey("test custom key provider", t, func() { + provider := NewCustomKeyProvider("aaa", []byte("bbb")) + convey.So(provider, convey.ShouldNotBeNil) + key, i, err := provider.GetKey(cloudsoa.KeyLabel{}) + convey.So(string(key), convey.ShouldEqual, "bbb") + convey.So(i, convey.ShouldEqual, 0) + convey.So(err, convey.ShouldBeNil) + key, err = provider.GetKeyWithVersion(cloudsoa.KeyLabel{}, 0) + convey.So(string(key), convey.ShouldEqual, "bbb") + convey.So(err, convey.ShouldBeNil) + convey.So(provider.GetName(), convey.ShouldEqual, "aaa") + }) +} + +func TestGenerateHTTPSAndLocalSecretVolumeMounts(t *testing.T) { + convey.Convey("TestGenerateHTTPSAndLocalSecretVolumeMounts", t, func() { + httpsConfig := tls.InternalHTTPSConfig{} + volumeData, volumeMountData, err := GenerateHTTPSAndLocalSecretVolumeMounts(httpsConfig, nil) + convey.So(volumeData, convey.ShouldEqual, "") + convey.So(volumeMountData, convey.ShouldEqual, "") + convey.So(err, convey.ShouldNotBeNil) + + volumeData, volumeMountData, err = GenerateHTTPSAndLocalSecretVolumeMounts(httpsConfig, utils.NewVolumeBuilder()) + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/pkg/common/faas_common/sts/raw/crypto.go b/yuanrong/pkg/common/faas_common/sts/raw/crypto.go new file mode 100644 index 0000000..404fd41 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/sts/raw/crypto.go @@ -0,0 +1,84 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package raw use work key to encrypt and decrypt +package raw + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" +) + +const ( + defaultSaltSize = 12 +) + +// AesGCMDecrypt decrypt a cypher text using AES_GCM algorithm +func AesGCMDecrypt(secret, salt, cipherBytes []byte) ([]byte, error) { + defer postRecover() + block, err := aes.NewCipher(secret) + if err != nil { + return nil, err + } + // salt 长度和 nonceSize 保持一致 + // cipher.NewGCM(block) 使用的是默认12字节的nonceSize,也代表盐值长度必须是12Byte;为了适应性强,我们使用自定义的 nonceSize + gcm, err := cipher.NewGCMWithNonceSize(block, len(salt)) + if err != nil { + return nil, err + } + plainBytes, err := gcm.Open(nil, salt, cipherBytes, nil) + if err != nil { + return nil, err + } + return plainBytes, nil +} + +func postRecover() { + var err error + if r := recover(); r != nil { + switch value := r.(type) { + case string: + err = fmt.Errorf("%s", value) + case error: + err = value + default: + err = fmt.Errorf("unexpect panic error: %w", err) + } + err = fmt.Errorf("panic error: %w", err) + } +} + +// AesGCMEncrypt will encrypt plainBytes to cipherBytes +func AesGCMEncrypt(secret, plainBytes []byte) ([]byte, []byte, error) { + defer postRecover() + block, err := aes.NewCipher(secret) + if err != nil { + return nil, nil, err + } + gcm, err := cipher.NewGCMWithNonceSize(block, defaultSaltSize) + if err != nil { + return nil, nil, fmt.Errorf("failed NewGCM: %w", err) + } + salt := make([]byte, gcm.NonceSize()) + _, err = rand.Read(salt) + if err != nil { + return nil, nil, err + } + cipherBytes := gcm.Seal(nil, salt, plainBytes, nil) + return salt, cipherBytes, nil +} diff --git a/yuanrong/pkg/common/faas_common/sts/raw/crypto_test.go b/yuanrong/pkg/common/faas_common/sts/raw/crypto_test.go new file mode 100644 index 0000000..8f6603d --- /dev/null +++ b/yuanrong/pkg/common/faas_common/sts/raw/crypto_test.go @@ -0,0 +1,64 @@ +package raw + +import ( + "encoding/base64" + "encoding/hex" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + saltKeySep = ":" + + shareKey = "1752F862B5176946F18D45D67E256642F115D2D6A3D77773FAF1E5874AC5211D" + + plain = "{\"key1\":\"value1\",\"key2\":\"value2\"}" + + saltKey = "R8Mi3gSG3ou4X6eY:VIQASOEBJTQT3yd4qGrpqSbLrgemB5eTaD5KRefaOcXh/r18YSwhtv0j0A==" + plain2 = "{\"key1\":\"va1\",\"key2\":\"va2\"}" +) + +func TestAesGCMDecrypt(t *testing.T) { + shareKey2 := make([]byte, hex.DecodedLen(len(shareKey))) + _, err := hex.Decode(shareKey2, []byte(shareKey)) + if err != nil { + t.Errorf("%s", err) + } + + salt, cipherBytes, err := AesGCMEncrypt(shareKey2, []byte(plain)) + fmt.Println(string(salt), string(cipherBytes)) + saltBase64 := base64.StdEncoding.EncodeToString(salt) + cipherBase64 := base64.StdEncoding.EncodeToString(cipherBytes) + fmt.Println(saltBase64, cipherBase64) + + if err != nil { + t.Errorf("%s", err) + } + blocks1, err := AesGCMDecrypt(shareKey2, salt, cipherBytes) + + assert.Equal(t, string(blocks1), plain) +} + +func TestAesGCMDecrypt2(t *testing.T) { + shareKey2 := make([]byte, hex.DecodedLen(len(shareKey))) + _, err := hex.Decode(shareKey2, []byte(shareKey)) + if err != nil { + t.Errorf("%s", err) + } + + fields := strings.Split(saltKey, saltKeySep) + salt1, err := base64.StdEncoding.DecodeString(fields[0]) + if err != nil { + t.Errorf("%s", err) + } + cipher, err := base64.StdEncoding.DecodeString(fields[1]) + blocks, err := AesGCMDecrypt(shareKey2, salt1, cipher) + if err != nil { + t.Errorf("%s", err) + } + + assert.Equal(t, string(blocks), plain2) +} diff --git a/yuanrong/pkg/common/faas_common/sts/raw/raw.go b/yuanrong/pkg/common/faas_common/sts/raw/raw.go new file mode 100644 index 0000000..d914b0b --- /dev/null +++ b/yuanrong/pkg/common/faas_common/sts/raw/raw.go @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package raw define the sts structure +package raw + +// StsConfig - +type StsConfig struct { + StsEnable bool `json:"stsEnable,omitempty"` + SensitiveConfigs SensitiveConfigs `json:"sensitiveConfigs,omitempty"` + ServerConfig ServerConfig `json:"serverConfig,omitempty"` + MgmtServerConfig MgmtServerConfig `json:"mgmtServerConfig"` + StsDomainForRuntime string `json:"stsDomainForRuntime"` +} + +// SensitiveConfigs - +type SensitiveConfigs struct { + ShareKeys map[string]string `json:"shareKeys"` +} + +// ServerConfig - +type ServerConfig struct { + Domain string `json:"domain,omitempty" validate:"max=255"` + Path string `json:"path,omitempty" validate:"max=255"` +} + +// MgmtServerConfig - +type MgmtServerConfig struct { + Domain string `json:"domain,omitempty"` +} diff --git a/yuanrong/pkg/common/faas_common/sts/sts.go b/yuanrong/pkg/common/faas_common/sts/sts.go new file mode 100644 index 0000000..967328e --- /dev/null +++ b/yuanrong/pkg/common/faas_common/sts/sts.go @@ -0,0 +1,88 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package sts used for init sts +package sts + +import ( + "os" + "time" + + "github.com/magiconair/properties" + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/config" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/sts/raw" +) + +// EnvSTSEnable flag +const EnvSTSEnable = "STS_ENABLE" +const fileMode = 0640 + +// InitStsSDK - Configure sts go sdk +func InitStsSDK(serverCfg raw.ServerConfig) error { + initStsSdkLog() + stsProperties := properties.LoadMap( + map[string]string{ + "sts.server.domain": serverCfg.Domain, + "sts.config.path": serverCfg.Path, + "sts.connect.timeout": "20000", + "sts.handshake.timeout": "20000", + }, + ) + err := stsgoapi.InitWith(*stsProperties) + if err != nil { + reportStsAlarm(err.Error()) + } + return err +} + +func reportStsAlarm(errMsg string) { + alarmDetail := &alarm.Detail{ + SourceTag: os.Getenv(constant.PodNameEnvKey) + "|" + os.Getenv(constant.PodIPEnvKey) + + "|" + os.Getenv(constant.ClusterName), + OpType: alarm.GenerateAlarmLog, + Details: "Init sts err, " + errMsg, + StartTimestamp: int(time.Now().Unix()), + EndTimestamp: 0, + } + alarmInfo := &alarm.LogAlarmInfo{ + AlarmID: alarm.InitStsSdkErr00001, + AlarmName: "InitStsSdkErr", + AlarmLevel: alarm.Level3, + } + + alarm.ReportOrClearAlarm(alarmInfo, alarmDetail) +} + +func initStsSdkLog() { + coreInfo, err := config.GetCoreInfoFromEnv() + if err != nil { + coreInfo = config.GetDefaultCoreInfo() + } + stsSdkLogFilePath := coreInfo.FilePath + "/sts.sdk.log" + stsgoapi.SetLogFile(stsSdkLogFilePath) + file, err := os.OpenFile(stsSdkLogFilePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fileMode) + if err != nil { + log.GetLogger().Errorf("failed to open stsSdkLogFile") + return + } + defer file.Close() + return +} diff --git a/yuanrong/pkg/common/faas_common/timewheel/simpletimewheel.go b/yuanrong/pkg/common/faas_common/timewheel/simpletimewheel.go new file mode 100644 index 0000000..ba994be --- /dev/null +++ b/yuanrong/pkg/common/faas_common/timewheel/simpletimewheel.go @@ -0,0 +1,242 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package timewheel - +package timewheel + +import ( + "errors" + "fmt" + "sync" + "time" +) + +const ( + minPace = 2 * time.Millisecond + minSlotNum = 1 + notifyChannelSize = 1000 +) + +var ( + timeTriggerPool = sync.Pool{New: func() interface{} { + return &timeTrigger{} + }} +) + +type timeTrigger struct { + taskID string + times int + index int64 + circle int64 + circleCount int64 + disable bool + ch chan struct{} + prev *timeTrigger + next *timeTrigger +} + +// SimpleTimeWheel will trigger task at given interval by given times, it contains a certain number of slots and moves +// from one slot to another with a pace which is also the granularity of time wheel, task interval will be measured +// with a number of slots and recorded in the slot arrays, each slot has a linked list to trigger a series of tasks +// when time wheel moves to this slot +type SimpleTimeWheel struct { + ticker *time.Ticker + pace time.Duration + perimeter int64 + slotNum int64 + curSlot int64 + pendingTask int + slots []*timeTrigger + readyList []string + record *sync.Map + notifyCh chan struct{} + readyCh chan struct{} + stopCh chan struct{} + sync.RWMutex +} + +// NewSimpleTimeWheel will create a SimpleTimeWheel +func NewSimpleTimeWheel(pace time.Duration, slotNum int64) TimeWheel { + if pace < minPace { + pace = minPace + } + if slotNum < minSlotNum { + slotNum = minSlotNum + } + timeWheel := &SimpleTimeWheel{ + ticker: time.NewTicker(pace), + pace: pace, + perimeter: slotNum * int64(pace), + slotNum: slotNum, + curSlot: 0, + slots: make([]*timeTrigger, slotNum, slotNum), + record: new(sync.Map), + notifyCh: make(chan struct{}, notifyChannelSize), + readyCh: make(chan struct{}, 1), + stopCh: make(chan struct{}), + } + go timeWheel.run() + return timeWheel +} + +func (gt *SimpleTimeWheel) run() { + for { + select { + case <-gt.ticker.C: + gt.Lock() + gt.curSlot = (gt.curSlot + 1) % int64(len(gt.slots)) + gt.Unlock() + gt.checkAndFireTrigger() + case <-gt.stopCh: + gt.ticker.Stop() + return + } + } +} + +func (gt *SimpleTimeWheel) checkAndFireTrigger() { + trigger := gt.slots[gt.curSlot] + var readyList []string + for trigger != nil { + if !trigger.disable && trigger.circleCount == trigger.circle { + trigger.circleCount = 0 + if trigger.times == 0 { + trigger.disable = true + gt.record.Delete(trigger.taskID) + gt.removeTrigger(trigger) + continue + } + readyList = append(readyList, trigger.taskID) + select { + case trigger.ch <- struct{}{}: + default: + } + if trigger.times > 0 { + trigger.times-- + } + } + trigger.circleCount++ + trigger = trigger.next + } + gt.Lock() + gt.readyList = readyList + gt.Unlock() + if len(readyList) != 0 { + gt.readyCh <- struct{}{} + } +} + +// Wait will block until tasks are triggered and returns triggered task list +func (gt *SimpleTimeWheel) Wait() []string { + select { + case _, ok := <-gt.readyCh: + if !ok { + return nil + } + } + gt.RLock() + readyList := gt.readyList + gt.RUnlock() + return readyList +} + +// AddTask will add a task which will be triggered periodically over an given interval with given times (-1 means to +// run endlessly), considering that pace has a reasonable size and the logic below won't cost more time than that, +// AddTask won't catch up with the curSlot, so we don't need a mutex. it's also worth noticing that interval can't be +// smaller than the circumference of this time wheel +func (gt *SimpleTimeWheel) AddTask(taskID string, interval time.Duration, times int) (<-chan struct{}, error) { + if interval < time.Duration(gt.perimeter) { + return nil, ErrInvalidTaskInterval + } + if _, exist := gt.record.Load(taskID); exist { + return nil, fmt.Errorf("%s, taskId: %s", ErrTaskAlreadyExist.Error(), taskID) + } + trigger, ok := timeTriggerPool.Get().(*timeTrigger) + if !ok { + return nil, errors.New("not a timeTrigger type") + } + gt.Lock() + curSlot := gt.curSlot + circle := (int64(interval)/int64(gt.pace) + curSlot + 1) / gt.slotNum + circleCount := int64(1) + index := (int64(interval)/int64(gt.pace) + curSlot + 1) % gt.slotNum + if index > curSlot { + circleCount-- + } + trigger.taskID = taskID + trigger.times = times + trigger.circle = circle + trigger.circleCount = circleCount + trigger.index = index + trigger.disable = false + trigger.ch = make(chan struct{}, 1) + trigger.prev = nil + trigger.next = gt.slots[index] + if gt.slots[index] != nil { + gt.slots[index].prev = trigger + } + gt.slots[index] = trigger + gt.Unlock() + gt.record.Store(taskID, trigger) + return trigger.ch, nil +} + +// DelTask will delete a task in SimpleTimeWheel and remove its trigger +func (gt *SimpleTimeWheel) DelTask(taskID string) error { + object, exist := gt.record.Load(taskID) + if !exist { + return nil + } + gt.record.Delete(taskID) + trigger, ok := object.(*timeTrigger) + if !ok { + return errors.New("not a timeTrigger type") + } + // since caller no longer need this task, it's ok that this trigger still fires + trigger.disable = true + gt.removeTrigger(trigger) + timeTriggerPool.Put(trigger) + return nil +} + +// Stop will stop time wheel +func (gt *SimpleTimeWheel) Stop() { + close(gt.stopCh) + close(gt.readyCh) +} + +// removeTrigger won't set trigger's prev and next to nil since checkAndFireTrigger may processing this trigger right +// now and we don't want to lose track of the next trigger +func (gt *SimpleTimeWheel) removeTrigger(trigger *timeTrigger) { + gt.Lock() + defer gt.Unlock() + // special treatment if this trigger is the head of linked list + if trigger.prev == nil { + if trigger.index >= int64(len(gt.slots)) { + fmt.Errorf("trigger.index is out of slots slice") + } else { + gt.slots[trigger.index] = trigger.next + } + if trigger.next != nil { + trigger.next.prev = nil + } + } else { + trigger.prev.next = trigger.next + if trigger.next != nil { + trigger.next.prev = trigger.prev + } + } +} diff --git a/yuanrong/pkg/common/faas_common/timewheel/simpletimewheel_test.go b/yuanrong/pkg/common/faas_common/timewheel/simpletimewheel_test.go new file mode 100644 index 0000000..d78fdc2 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/timewheel/simpletimewheel_test.go @@ -0,0 +1,192 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package timewheel - +package timewheel + +import ( + "math" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSimpleTimeWheelBasic(t *testing.T) { + timeWheel := NewSimpleTimeWheel(5*time.Millisecond, 10) + defer timeWheel.Stop() + time.Sleep(11 * time.Millisecond) + taskName := "TestSimpleTimeWheelBasic_" + "task-1" + ch, err := timeWheel.AddTask(taskName, 500*time.Millisecond, -1) + addTime := time.Now() + if err != nil { + t.Errorf("failed to add task error %s", err) + } + var triggerTime time.Time + select { + case <-time.NewTimer(750 * time.Millisecond).C: + t.Errorf("timeout waiting for timeWheel to trigger after %d", time.Now().Sub(addTime).Milliseconds()) + case <-ch: + triggerTime = time.Now() + interval := int(math.Floor(float64(triggerTime.Sub(addTime).Milliseconds()))) + assert.Equal(t, true, interval >= 450 && interval <= 750) + } + + err = timeWheel.DelTask(taskName) + if err != nil { + t.Errorf("failed to delete task error %s", err) + } + select { + case <-time.NewTimer(200 * time.Millisecond).C: + case <-ch: + t.Errorf("trigger should not fire") + } +} + +func TestSimpleTimeWheel_Wait(t *testing.T) { + readyCh := make(chan struct{}) + readyList := []string{"TestSimpleTimeWheel_Wait_task1", "TestSimpleTimeWheel_Wait_task2"} + + wheel := &SimpleTimeWheel{ + readyCh: readyCh, + readyList: readyList, + } + + go func() { + readyCh <- struct{}{} + }() + + result := wheel.Wait() + assert.Equal(t, readyList, result, "The readyList should be returned") + close(readyCh) + result = wheel.Wait() + assert.Nil(t, result, "The result should be nil when channel is closed") +} + +func TestSimpleTimeWheelCombination(t *testing.T) { + timeWheel := NewSimpleTimeWheel(5*time.Millisecond, 10) + defer timeWheel.Stop() + var ( + err error + task1Ch <-chan struct{} + task2Ch <-chan struct{} + task3Ch <-chan struct{} + task2AddTime time.Time + task3AddTime time.Time + ) + wg := sync.WaitGroup{} + wg.Add(1) + task1Name := "TestSimpleTimeWheelCombination_" + "task-1" + task2Name := "TestSimpleTimeWheelCombination_" + "task-2" + task3Name := "TestSimpleTimeWheelCombination_" + "task-3" + go func() { + task1Ch, err = timeWheel.AddTask(task1Name, time.Duration(500)*time.Millisecond, -1) + if err != nil { + t.Errorf("failed to add task error %s", err) + } + wg.Done() + }() + wg.Add(1) + go func() { + task2Ch, err = timeWheel.AddTask(task2Name, time.Duration(500)*time.Millisecond, -1) + task2AddTime = time.Now() + if err != nil { + t.Errorf("failed to add task error %s", err) + } + wg.Done() + }() + wg.Add(1) + go func() { + task3Ch, err = timeWheel.AddTask(task3Name, time.Duration(500)*time.Millisecond, -1) + task3AddTime = time.Now() + if err != nil { + t.Errorf("failed to add task error %s", err) + } + wg.Done() + }() + wg.Wait() + err = timeWheel.DelTask(task1Name) + if err != nil { + t.Errorf("failed to delete task error %s", err) + } + done := 0 + timer := time.NewTimer(900 * time.Millisecond) + defer timer.Stop() + for done != 2 { + select { + case <-timer.C: + t.Errorf("timeout waiting for timeWheel to trigger") + case <-task1Ch: + t.Errorf("trigger should not fire") + case <-task2Ch: + interval := int(math.Floor(float64(time.Now().Sub(task2AddTime).Milliseconds()))) + if interval < 450 || interval > 800 { + t.Errorf("task2's trigger interval %d is out of range [450, 800]", interval) + } + done++ + case <-task3Ch: + interval := int(math.Floor(float64(time.Now().Sub(task3AddTime).Milliseconds()))) + if interval < 450 || interval > 800 { + t.Errorf("task3's trigger interval %d is out of range [450, 800]", interval) + } + done++ + } + } +} + +func TestSimpleTimeWheel_Stop(t *testing.T) { + timeWheel := NewSimpleTimeWheel(2*time.Millisecond, 10) + timeWheel.Stop() +} + +func TestNewSimpleTimeWheel(t *testing.T) { + timeWheel := NewSimpleTimeWheel(minPace-1, 0) + defer timeWheel.Stop() + assert.NotNil(t, timeWheel) +} + +func TestSimpleTimeWheelBasic1(t *testing.T) { + timeWheel := NewSimpleTimeWheel(10*time.Millisecond, 10) + defer timeWheel.Stop() + time.Sleep(11 * time.Millisecond) + task1Name := "TestSimpleTimeWheelBasic1_" + "task-1" + ch, err := timeWheel.AddTask(task1Name, 1000*time.Millisecond, -1) + addTime := time.Now() + if err != nil { + t.Errorf("failed to add task error %s", err) + } + var triggerTime time.Time + select { + case <-time.NewTimer(10000 * time.Millisecond).C: + t.Errorf("timeout waiting for timeWheel to trigger %s", time.Now().Format(time.RFC3339Nano)) + case <-ch: + triggerTime = time.Now() + interval := int(math.Floor(float64(triggerTime.Sub(addTime).Milliseconds()))) + t.Logf("show invterval %d\n", interval) + assert.Equal(t, true, interval >= 800 && interval <= 1400) + } + + err = timeWheel.DelTask(task1Name) + if err != nil { + t.Errorf("failed to delete task error %s", err) + } + select { + case <-time.NewTimer(1000 * time.Millisecond).C: + case <-ch: + t.Errorf("trigger should not fire") + } +} diff --git a/yuanrong/pkg/common/faas_common/timewheel/timewheel.go b/yuanrong/pkg/common/faas_common/timewheel/timewheel.go new file mode 100644 index 0000000..5300d9f --- /dev/null +++ b/yuanrong/pkg/common/faas_common/timewheel/timewheel.go @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package timewheel - +package timewheel + +import ( + "errors" + "time" +) + +var ( + // ErrTaskAlreadyExist is the error of task already exist + ErrTaskAlreadyExist = errors.New("task already exist") + // ErrInvalidTaskInterval is the error of invalid interval for TimeWheel task + ErrInvalidTaskInterval = errors.New("interval of task is invalid") +) + +// TimeWheel can trigger tasks periodically by given intervals +type TimeWheel interface { + Wait() []string + AddTask(taskID string, interval time.Duration, times int) (<-chan struct{}, error) + DelTask(taskID string) error + Stop() +} diff --git a/yuanrong/pkg/common/faas_common/tls/https.go b/yuanrong/pkg/common/faas_common/tls/https.go new file mode 100644 index 0000000..392be48 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/tls/https.go @@ -0,0 +1,401 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package tls - +package tls + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "path/filepath" + "strings" + "sync" + + commonCrypto "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" +) + +const urlIndex = 1 + +// HTTPSConfig is for needed HTTPS config +type HTTPSConfig struct { + CipherSuite []uint16 + MinVers uint16 + MaxVers uint16 + CACertFile string + CertFile string + SecretKeyFile string + PwdFilePath string + KeyPassPhase string + SecretName string + DecryptTool string + DisableClientCertVerify bool +} + +// InternalHTTPSConfig is for input config +type InternalHTTPSConfig struct { + HTTPSEnable bool `json:"httpsEnable" yaml:"httpsEnable" valid:"optional"` + TLSProtocol string `json:"tlsProtocol" yaml:"tlsProtocol" valid:"optional"` + TLSCiphers string `json:"tlsCiphers" yaml:"tlsCiphers" valid:"optional"` + SSLBasePath string `json:"sslBasePath" yaml:"sslBasePath" valid:"optional"` + RootCAFile string `json:"rootCAFile" yaml:"rootCAFile" valid:"optional"` + ModuleCertFile string `json:"moduleCertFile" yaml:"moduleCertFile" valid:"optional"` + ModuleKeyFile string `json:"moduleKeyFile" yaml:"moduleKeyFile" valid:"optional"` + PwdFile string `json:"pwdFile" yaml:"pwdFile" valid:"optional"` + SecretName string `json:"secretName" yaml:"secretName" valid:"optional"` + SSLDecryptTool string `json:"sslDecryptTool" yaml:"sslDecryptTool" valid:"optional"` + DisableClientCertVerify bool `json:"disEnableClientCertVerify" yaml:"disEnableClientCertVerify" valid:"optional"` +} + +var ( + // tlsVersionMap is a set of TLS versions + tlsVersionMap = map[string]uint16{ + "TLSv1.2": tls.VersionTLS12, + } + // httpsConfigs is a global variable of HTTPS config + httpsConfigs = &HTTPSConfig{} + // tlsConfig is a global variable of TLS config + tlsConfig *tls.Config + once sync.Once +) + +// GetURLScheme returns "http" or "https" +func GetURLScheme(https bool) string { + if https { + return "https" + } + return "http" +} + +// tlsCipherSuiteMap is a set of supported TLS algorithms +var tlsCipherSuiteMap = map[string]uint16{ + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, +} + +// GetClientTLSConfig - +func GetClientTLSConfig() *tls.Config { + if tlsConfig == nil { + return nil + } + certs := make([]tls.Certificate, len(tlsConfig.Certificates)) + copy(certs, tlsConfig.Certificates) + suits := make([]uint16, len(tlsConfig.CipherSuites)) + copy(suits, tlsConfig.CipherSuites) + newCfg := &tls.Config{ + ClientCAs: tlsConfig.ClientCAs, + Certificates: certs, + CipherSuites: suits, + PreferServerCipherSuites: tlsConfig.PreferServerCipherSuites, + ClientAuth: tlsConfig.ClientAuth, + InsecureSkipVerify: tlsConfig.InsecureSkipVerify, + MinVersion: tlsConfig.MinVersion, + MaxVersion: tlsConfig.MaxVersion, + Renegotiation: tlsConfig.Renegotiation, + } + return newCfg +} + +func loadCerts(path string, filename string) string { + certPath, err := filepath.Abs(filepath.Join(path, filename)) + if err != nil { + log.GetLogger().Errorf("failed to return an absolute representation of filename: %s", filename) + return "" + } + ok := utils.FileExists(certPath) + if !ok { + log.GetLogger().Errorf("failed to load the cert file: %s", certPath) + return "" + } + return certPath +} + +func loadTLSConfig() error { + clientAuthMode := tls.RequireAndVerifyClientCert + if httpsConfigs.DisableClientCertVerify { + clientAuthMode = tls.NoClientCert + } + var pool *x509.CertPool + + pool, err := GetX509CACertPool(httpsConfigs.CACertFile) + if err != nil { + log.GetLogger().Errorf("failed to GetX509CACertPool: %s", err.Error()) + return err + } + + var certs []tls.Certificate + certs, err = LoadServerTLSCertificate(httpsConfigs.CertFile, httpsConfigs.SecretKeyFile, + httpsConfigs.KeyPassPhase, httpsConfigs.DecryptTool, true) + if err != nil { + log.GetLogger().Errorf("failed to loadServerTLSCertificate: %s", err.Error()) + return err + } + + tlsConfig = &tls.Config{ + ClientCAs: pool, + Certificates: certs, + CipherSuites: httpsConfigs.CipherSuite, + PreferServerCipherSuites: true, + ClientAuth: clientAuthMode, + InsecureSkipVerify: true, + MinVersion: httpsConfigs.MinVers, + MaxVersion: httpsConfigs.MaxVers, + Renegotiation: tls.RenegotiateNever, + } + + return nil +} + +// loadHTTPSConfig loads the protocol and ciphers of TLS +func loadHTTPSConfig(config InternalHTTPSConfig) error { + httpsConfigs = &HTTPSConfig{ + MinVers: tls.VersionTLS12, + MaxVers: tls.VersionTLS12, + CipherSuite: nil, + CACertFile: loadCerts(config.SSLBasePath, config.RootCAFile), + CertFile: loadCerts(config.SSLBasePath, config.ModuleCertFile), + SecretKeyFile: loadCerts(config.SSLBasePath, config.ModuleKeyFile), + PwdFilePath: loadCerts(config.SSLBasePath, config.PwdFile), + KeyPassPhase: "", + SecretName: config.SecretName, + DecryptTool: config.SSLDecryptTool, + DisableClientCertVerify: config.DisableClientCertVerify, + } + + minVersion := parseSSLProtocol(config.TLSProtocol) + if httpsConfigs.MinVers == 0 { + return errors.New("invalid TLS protocol") + } + if minVersion == 0 { + minVersion = tls.VersionTLS12 + } + httpsConfigs.MinVers = minVersion + cipherSuites := parseSSLCipherSuites(config.TLSCiphers) + if len(cipherSuites) == 0 { + return errors.New("invalid TLS ciphers") + } + httpsConfigs.CipherSuite = cipherSuites + + keyPassPhase, err := ioutil.ReadFile(httpsConfigs.PwdFilePath) + if err != nil { + log.GetLogger().Errorf("failed to read file cert_pwd: %s", err.Error()) + return err + } + httpsConfigs.KeyPassPhase = string(keyPassPhase) + utils.ClearByteMemory(keyPassPhase) + + return nil +} + +// InitTLSConfig inits config of HTTPS +func InitTLSConfig(config InternalHTTPSConfig) error { + var err error + once.Do(func() { + err = loadHTTPSConfig(config) + if err != nil { + err = fmt.Errorf("failed to load HTTPS config,err %s", err.Error()) + return + } + + err = loadTLSConfig() + if err != nil { + return + } + }) + return err +} + +// GetX509CACertPool generates CACertPool by CA certificate +func GetX509CACertPool(caCertFilePath string) (*x509.CertPool, error) { + pool := x509.NewCertPool() + caCertContent, err := loadCACertBytes(caCertFilePath) + if err != nil { + return nil, err + } + + pool.AppendCertsFromPEM(caCertContent) + return pool, nil + +} + +// LoadServerTLSCertificate generates tls certificate by certfile and keyfile +func LoadServerTLSCertificate(certFile, keyFile, passPhase, decryptTool string, + isHTTPS bool) ([]tls.Certificate, error) { + certContent, keyContent, err := loadCertAndKeyBytes(certFile, keyFile, passPhase, decryptTool, isHTTPS) + utils.ClearStringMemory(passPhase) + utils.ClearStringMemory(httpsConfigs.KeyPassPhase) + if err != nil { + utils.ClearByteMemory(certContent) + utils.ClearByteMemory(keyContent) + return nil, err + } + + cert, err := tls.X509KeyPair(certContent, keyContent) + utils.ClearByteMemory(certContent) + utils.ClearByteMemory(keyContent) + if err != nil { + log.GetLogger().Errorf("failed to load the X509 key pair from cert file with key file: %s", + err.Error()) + return nil, err + } + var certs []tls.Certificate + certs = append(certs, cert) + return certs, nil +} + +func containPassPhase(keyContent []byte, passPhase string, decryptTool string, + isHTTPS bool) (Content []byte, err error) { + if !isHTTPS { + plainkeyContent, err := localauth.Decrypt(string(keyContent)) + if err != nil { + log.GetLogger().Errorf("failed to decrypt keyContent: %s", err.Error()) + return nil, err + } + return plainkeyContent, nil + } + + keyBlock, _ := pem.Decode(keyContent) + if keyBlock == nil { + log.GetLogger().Errorf("failed to decode key file ") + return nil, errors.New("failed to decode key file") + } + + if commonCrypto.IsEncryptedPEMBlock(keyBlock) { + var plainPassPhase []byte + var err error + var decrypted string + if len(passPhase) > 0 { + if decryptTool == "SCC" { + decrypted, err = crypto.SCCDecrypt([]byte(passPhase)) + plainPassPhase = []byte(decrypted) + } else if decryptTool == "LOCAL" { + plainPassPhase, err = localauth.Decrypt(passPhase) + } + if err != nil { + log.GetLogger().Errorf("failed to decrypt the ssl passPhase(%d): %s", len(passPhase), + err.Error()) + return nil, err + } + } + + keyData, err := commonCrypto.DecryptPEMBlock(keyBlock, plainPassPhase) + clearByteMemory(plainPassPhase) + utils.ClearStringMemory(decrypted) + + if err != nil { + log.GetLogger().Errorf("failed to decrypt key file, error: %s", err.Error()) + return nil, err + } + + // The decryption is successful, then the file is re-encoded to a PEM file + plainKeyBlock := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: keyData, + } + + keyContent = pem.EncodeToMemory(plainKeyBlock) + } + return keyContent, nil + +} + +func loadCertAndKeyBytes(certFilePath, keyFilePath, passPhase string, decryptTool string, isHTTPS bool) ( + certPEMBlock, keyPEMBlock []byte, err error) { + certContent, err := ioutil.ReadFile(certFilePath) + if err != nil { + log.GetLogger().Errorf("failed to read cert file %s: %s", certFilePath, err.Error()) + return nil, nil, err + } + + keyContent, err := ioutil.ReadFile(keyFilePath) + if err != nil { + log.GetLogger().Errorf("failed to read key file %s: %s", keyFilePath, err.Error()) + return nil, nil, err + } + keyContent, err = containPassPhase(keyContent, passPhase, decryptTool, isHTTPS) + if err != nil { + log.GetLogger().Errorf("failed to decode keyContent, error is %s", err.Error()) + return nil, nil, err + } + + return certContent, keyContent, nil + +} + +func clearByteMemory(src []byte) { + for idx := 0; idx < len(src)&32; idx++ { + src[idx] = 0 + } +} + +func loadCACertBytes(caCertFilePath string) ([]byte, error) { + caCertContent, err := ioutil.ReadFile(caCertFilePath) + if err != nil { + log.GetLogger().Errorf("failed to read ca cert file %s, err: %s", caCertFilePath, err.Error()) + return nil, err + } + + return caCertContent, nil +} + +func parseSSLProtocol(rawProtocol string) uint16 { + if protocol, ok := tlsVersionMap[rawProtocol]; ok { + return protocol + } + log.GetLogger().Errorf("invalid SSL version: %s, use the default protocol version", rawProtocol) + return 0 +} + +func parseSSLCipherSuites(ciphers string) []uint16 { + cipherSuiteNameList := strings.Split(ciphers, ",") + if len(cipherSuiteNameList) == 0 { + log.GetLogger().Errorf("input cipher suite is empty") + return nil + } + cipherSuites := make([]uint16, 0, len(cipherSuiteNameList)) + for _, cipherSuiteItem := range cipherSuiteNameList { + cipherSuiteItem = strings.TrimSpace(cipherSuiteItem) + if len(cipherSuiteItem) == 0 { + continue + } + + if cipherSuite, ok := tlsCipherSuiteMap[cipherSuiteItem]; ok { + cipherSuites = append(cipherSuites, cipherSuite) + } else { + log.GetLogger().Errorf("cipher %s does not exist", cipherSuiteItem) + } + } + + return cipherSuites +} + +// ParseURL URL may be: ip:port | http://ip:port | https://ip:port +func ParseURL(rawURL string) string { + urls := strings.Split(rawURL, "//") + if len(urls) > urlIndex { + return urls[urlIndex] + } + return rawURL +} diff --git a/yuanrong/pkg/common/faas_common/tls/https_test.go b/yuanrong/pkg/common/faas_common/tls/https_test.go new file mode 100644 index 0000000..942a7e4 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/tls/https_test.go @@ -0,0 +1,256 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tls + +import ( + "crypto/tls" + "encoding/pem" + "errors" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/crypto" +) + +func TestGetURLScheme(t *testing.T) { + if "https" != GetURLScheme(true) { + t.Error("GetURLScheme failed") + } + if "http" != GetURLScheme(false) { + t.Error("GetURLScheme failed") + } +} + +func TestInitTLSConfig(t *testing.T) { + p := gomonkey.ApplyFunc(ioutil.ReadFile, func(filename string) ([]byte, error) { + return nil, nil + }) + p.ApplyFunc(containPassPhase, func(keyContent []byte, passPhase string, decryptTool string, isHttps bool) (Content []byte, err error) { + return nil, nil + }) + p.ApplyFunc(tls.X509KeyPair, func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) { + var cert tls.Certificate + return cert, nil + }) + defer p.Reset() + os.Setenv("SSL_ROOT", "/home/sn/resource/https") + var config InternalHTTPSConfig + config.TLSProtocol = "TLSv1.2" + config.TLSCiphers = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, TLS_TEST" + err := InitTLSConfig(config) + assert.Equal(t, nil, err) +} + +func TestGetClientTLSConfig(t *testing.T) { + actual := GetClientTLSConfig() + assert.Equal(t, tlsConfig, actual) +} + +func TestContainPassPhase(t *testing.T) { + convey.Convey("ContainPassPhase", t, func() { + errCtrl := "" + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(crypto.DecryptPEMBlock, func(b *pem.Block, password []byte) ([]byte, error) { + if errCtrl == "returnError" { + return nil, errors.New("some error") + } + return nil, nil + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + convey.Convey("http error case 1", func() { + keyContent := []byte{} + passPhase := "" + isHttps := false + content, err := containPassPhase(keyContent, passPhase, "LOCAL", isHttps) + convey.So(err, convey.ShouldNotBeNil) + convey.So(content, convey.ShouldBeNil) + }) + convey.Convey("https error case 1", func() { + keyContent := []byte{} + passPhase := "" + isHttps := true + content, err := containPassPhase(keyContent, passPhase, "LOCAL", isHttps) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldEqual, "failed to decode key file") + convey.So(content, convey.ShouldBeNil) + }) + convey.Convey("Decrypt error", func() { + keyContent := pem.EncodeToMemory(&pem.Block{ + Type: "MESSAGE", + Headers: map[string]string{"DEK-Info": "test"}, + Bytes: []byte("test containPassPhase")}) + passPhase := "abc" + isHttps := true + errCtrl = "returnError" + content, err := containPassPhase(keyContent, passPhase, "LOCAL", isHttps) + convey.So(err, convey.ShouldNotBeNil) + convey.So(content, convey.ShouldBeNil) + }) + convey.Convey("Decrypt success", func() { + keyContent := pem.EncodeToMemory(&pem.Block{ + Type: "MESSAGE", + Headers: map[string]string{"DEK-Info": "test"}, + Bytes: []byte("test containPassPhase")}) + passPhase := "abc" + isHttps := true + errCtrl = "" + content, err := containPassPhase(keyContent, passPhase, "LOCAL", isHttps) + convey.So(err, convey.ShouldBeNil) + convey.So(content, convey.ShouldNotBeNil) + }) + }) +} + +func TestLoadCerts(t *testing.T) { + convey.Convey("https Load Certs 1", t, func() { + patch := gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return "aaa" + }) + defer patch.Reset() + cert := loadCerts("./test", "trust.cer") + convey.So(cert, convey.ShouldNotBeNil) + }) + convey.Convey("https Load Certs 2", t, func() { + patch := gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return "aaa" + }) + patch2 := gomonkey.ApplyFunc(filepath.Abs, func(path string) (string, error) { + return "a", errors.New("bbb") + }) + defer patch.Reset() + defer patch2.Reset() + cert := loadCerts("1", "trust.cer") + convey.So(cert, convey.ShouldNotBeNil) + }) + +} + +func Test_parseSSLProtocol(t *testing.T) { + convey.Convey("Test_parseSSLProtocol", t, func() { + convey.So(parseSSLProtocol("TLSv1.2"), convey.ShouldEqual, tls.VersionTLS12) + convey.So(parseSSLProtocol("abc"), convey.ShouldEqual, 0) + }) +} + +func Test_parseURL(t *testing.T) { + url := ParseURL("http://test.com") + assert.Equal(t, url, "test.com") + url1 := ParseURL("test.com") + assert.Equal(t, url1, "test.com") +} + +func TestGetClientTLSConfig_Multi(t *testing.T) { + old := tlsConfig + + tlsConfig = &tls.Config{} + + defer func() { + tlsConfig = old + }() + + a := GetClientTLSConfig() + a.CipherSuites = append(a.CipherSuites, 10) + + b := GetClientTLSConfig() + + assert.NotEqual(t, a, b) + assert.NotSame(t, a, b) + assert.Equal(t, 1, len(a.CipherSuites)) + assert.Equal(t, 0, len(b.CipherSuites)) +} + +func TestLoadServerTLSCertificate(t *testing.T) { + readFileCtrl := "" + readFileCtrlCount := 0 + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(tls.X509KeyPair, func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) { + return tls.Certificate{}, errors.New("some error") + }), + gomonkey.ApplyFunc(ioutil.ReadFile, func(filename string) ([]byte, error) { + if readFileCtrl == "successOnce" { + if readFileCtrlCount == 0 { + readFileCtrlCount++ + return nil, nil + } + return nil, errors.New("some error") + } + readFileCtrlCount = 0 + return nil, nil + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + passLiteral := "testPassPhase" + passByteArray := []byte(passLiteral) + passPhase := string(passByteArray) + readFileCtrl = "successOnce" + certs, err := LoadServerTLSCertificate("testCertFile", "testKeyFile", passPhase, "LOCAL", true) + assert.NotNil(t, err) + assert.Empty(t, certs) + certs, err = LoadServerTLSCertificate("testCertFile", "testKeyFile", passPhase, "LOCAL", true) + assert.NotNil(t, err) + assert.Empty(t, certs) + readFileCtrl = "" + certs, err = LoadServerTLSCertificate("testCertFile", "testKeyFile", passPhase, "LOCAL", true) + assert.NotNil(t, err) + assert.Empty(t, certs) +} + +func Test_loadCertAndKeyBytes(t *testing.T) { + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(ioutil.ReadFile, func(filename string) ([]byte, error) { + return []byte("abc"), nil + }), + gomonkey.ApplyFunc(pem.Decode, func(data []byte) (p *pem.Block, rest []byte) { + return &pem.Block{}, []byte{} + }), + gomonkey.ApplyFunc(crypto.IsEncryptedPEMBlock, func(b *pem.Block) bool { + return true + }), + gomonkey.ApplyFunc(crypto.DecryptPEMBlock, func(b *pem.Block, password []byte) ([]byte, error) { + return []byte{}, nil + }), + gomonkey.ApplyFunc(pem.EncodeToMemory, func(b *pem.Block) []byte { + return []byte("abc") + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + convey.Convey("loadCertAndKeyBytes", t, func() { + bytes, keyPEMBlock, err := loadCertAndKeyBytes("path1", "path2", "", "", true) + convey.So(err, convey.ShouldBeNil) + convey.So(string(bytes), convey.ShouldEqual, "abc") + convey.So(string(keyPEMBlock), convey.ShouldEqual, "abc") + }) +} diff --git a/yuanrong/pkg/common/faas_common/tls/option.go b/yuanrong/pkg/common/faas_common/tls/option.go new file mode 100644 index 0000000..d8ccf21 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/tls/option.go @@ -0,0 +1,113 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package tls - +package tls + +import ( + "crypto/tls" + "crypto/x509" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +const ( + // DefaultCAFile is the default file for tls client + DefaultCAFile = "/home/sn/resource/ca/ca.pem" +) + +// NewTLSConfig returns tls.Config with given options +func NewTLSConfig(opts ...Option) *tls.Config { + config := &tls.Config{ + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + // for TLS1.2 + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + // for TLS1.3 + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, + }, + PreferServerCipherSuites: true, + Renegotiation: tls.RenegotiateNever, + } + for _, opt := range opts { + opt.apply(config) + } + return config +} + +// Option is optional argument for tls.Config +type Option interface { + apply(*tls.Config) +} + +type rootCAOption struct { + cas *x509.CertPool +} + +func (r *rootCAOption) apply(config *tls.Config) { + config.RootCAs = r.cas +} + +// WithRootCAs returns Option that applies root CAs to tls.Config +func WithRootCAs(caFiles ...string) Option { + rootCAs, err := LoadRootCAs(caFiles...) + if err != nil { + log.GetLogger().Warnf("failed to load root ca, err: %s", err.Error()) + rootCAs = nil + } + return &rootCAOption{ + cas: rootCAs, + } +} + +type certsOption struct { + certs []tls.Certificate +} + +func (c *certsOption) apply(config *tls.Config) { + config.Certificates = c.certs +} + +// WithCerts returns Option that applies cert file and key file to tls.Config +func WithCerts(certFile, keyFile string) Option { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.GetLogger().Warnf("load cert.pem and key.pem error: %s", err) + cert = tls.Certificate{} + } + return &certsOption{ + certs: []tls.Certificate{cert}, + } +} + +type skipVerifyOption struct { +} + +func (s *skipVerifyOption) apply(config *tls.Config) { + config.InsecureSkipVerify = true +} + +// WithSkipVerify returns Option that skips to verify certificates +func WithSkipVerify() Option { + return &skipVerifyOption{} +} diff --git a/yuanrong/pkg/common/faas_common/tls/option_test.go b/yuanrong/pkg/common/faas_common/tls/option_test.go new file mode 100644 index 0000000..3cf0a64 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/tls/option_test.go @@ -0,0 +1,146 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package tls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + "net/http" + "os" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type TestSuite struct { + suite.Suite + server http.Server + rootKEY string + rootPEM string + rootSRL string + serverKEY string + serverPEM string + serverCSR string +} + +func (s *TestSuite) SetupSuite() { + certificatePath, err := os.Getwd() + if err != nil { + s.T().Errorf("failed to get current working dictionary: %s", err.Error()) + return + } + + certificatePath += "/../../../test/" + s.rootKEY = certificatePath + "ca.key" + s.rootPEM = certificatePath + "ca.crt" + s.rootSRL = certificatePath + "ca.srl" + s.serverKEY = certificatePath + "server.key" + s.serverPEM = certificatePath + "server.crt" + s.serverCSR = certificatePath + "server.csr" + + body := "Hello" + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "%s", body) + }) + + s.server = http.Server{ + Addr: "127.0.0.1:6061", + Handler: handler, + } +} + +func (s *TestSuite) TearDownSuite() { + s.server.Shutdown(context.Background()) + + os.Remove(s.serverKEY) + os.Remove(s.serverPEM) + os.Remove(s.serverCSR) + os.Remove(s.rootKEY) + os.Remove(s.rootPEM) + os.Remove(s.rootSRL) +} + +func TestOptionTestSuite(t *testing.T) { + suite.Run(t, new(TestSuite)) +} + +func TestVerifyCert(t *testing.T) { + var raw [][]byte + tlsConfig = &tls.Config{} + tlsConfig.ClientCAs = x509.NewCertPool() + err := VerifyCert(raw, nil) + assert.NotNil(t, err) + + raw = [][]byte{ + []byte("0"), + []byte("1"), + } + err = VerifyCert(raw, nil) + assert.NotNil(t, err) +} + +func TestNewTLSConfig(t *testing.T) { + defaultCertFile := "/home/sn/resource/secret/cert.pem" + defaultKeyFile := "/home/sn/resource/secret/key.pem" + p := gomonkey.ApplyFunc(ioutil.ReadFile, func(filename string) ([]byte, error) { + return nil, nil + }) + + p.ApplyFunc(tls.LoadX509KeyPair, func(certFile, keyFile string) (tls.Certificate, error) { + return tls.Certificate{}, nil + }) + defer p.Reset() + actual := NewTLSConfig(WithRootCAs(DefaultCAFile), + WithCerts(defaultCertFile, defaultKeyFile), WithSkipVerify()) + expect := &tls.Config{ + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + // for TLS1.2 + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + // for TLS1.3 + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, + }, + PreferServerCipherSuites: true, + Renegotiation: tls.RenegotiateNever, + InsecureSkipVerify: true, + RootCAs: nil, + Certificates: []tls.Certificate{{}}, + } + + assert.Equal(t, expect, actual) +} + +func TestWithCerts(t *testing.T) { + defer gomonkey.ApplyFunc(tls.LoadX509KeyPair, func(certFile, keyFile string) (tls.Certificate, error) { + return tls.Certificate{}, errors.New("LoadX509KeyPair error") + }).Reset() + certs := WithCerts("", "") + option := certs.(*certsOption) + assert.Nil(t, option.certs[0].Certificate) +} diff --git a/yuanrong/pkg/common/faas_common/tls/tls.go b/yuanrong/pkg/common/faas_common/tls/tls.go new file mode 100644 index 0000000..585e795 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/tls/tls.go @@ -0,0 +1,74 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package tls provides tls utils +package tls + +import ( + "crypto/x509" + "errors" + "io/ioutil" + "time" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// LoadRootCAs returns system cert pool with caFiles added +func LoadRootCAs(caFiles ...string) (*x509.CertPool, error) { + rootCAs, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + for _, file := range caFiles { + cert, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + if !rootCAs.AppendCertsFromPEM(cert) { + return nil, err + } + } + return rootCAs, nil +} + +// VerifyCert Used to verity the server certificate +func VerifyCert(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, len(rawCerts)) + if len(certs) == 0 { + log.GetLogger().Errorf("cert number is 0") + return errors.New("cert number is 0") + } + opts := x509.VerifyOptions{ + Roots: tlsConfig.ClientCAs, + CurrentTime: time.Now(), + DNSName: "", + Intermediates: x509.NewCertPool(), + } + for i, asn1Data := range rawCerts { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + log.GetLogger().Errorf("failed to parse certificate from server: %s", err.Error()) + return err + } + certs[i] = cert + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + _, err := certs[0].Verify(opts) + return err +} diff --git a/yuanrong/pkg/common/faas_common/tls/tls_test.go b/yuanrong/pkg/common/faas_common/tls/tls_test.go new file mode 100644 index 0000000..e6329ee --- /dev/null +++ b/yuanrong/pkg/common/faas_common/tls/tls_test.go @@ -0,0 +1,79 @@ +package tls + +import ( + "crypto/x509" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/require" +) + +// TestLoadRootCAs is used to test the root certificate loading error. +func TestLoadRootCAs(t *testing.T) { + convey.Convey("LoadRootCAs", t, func() { + convey.Convey("error case 1", func() { + caFiles := "" + _, err := LoadRootCAs(caFiles) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("error case 2", func() { + dir, err := ioutil.TempDir("", "*") + require.NoError(t, err) + cryptoFile, err := ioutil.TempFile(dir, "crypto") + require.NoError(t, err) + _, err = LoadRootCAs(cryptoFile.Name()) + convey.So(err, convey.ShouldBeNil) + }) + + convey.Convey("error case 3", func() { + dir, err := ioutil.TempDir("", "*") + require.NoError(t, err) + cryptoFile, err := ioutil.TempFile(dir, "test") + require.NoError(t, err) + err = ioutil.WriteFile(filepath.Join(dir, "test"), []byte("a"), os.ModePerm) + require.NoError(t, err) + _, err = LoadRootCAs(cryptoFile.Name()) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +// TestVerifyCert2 is used to test certificate modification errors. +func TestVerifyCert2(t *testing.T) { + convey.Convey("VerifyCert", t, func() { + convey.Convey("error case 1", func() { + rawCerts := [][]byte{} + verifiedChains := [][]*x509.Certificate{} + err := VerifyCert(rawCerts, verifiedChains) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldEqual, "cert number is 0") + }) + + convey.Convey("error case 2", func() { + rawCerts := [][]byte{[]byte("test1"), []byte("test2")} + verifiedChains := [][]*x509.Certificate{} + err := VerifyCert(rawCerts, verifiedChains) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("success", func() { + defer gomonkey.ApplyFunc(x509.ParseCertificate, func(der []byte) (*x509.Certificate, error) { + return &x509.Certificate{}, nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&x509.Certificate{}), "Verify", + func(_ *x509.Certificate, opts x509.VerifyOptions) (chains [][]*x509.Certificate, err error) { + return nil, nil + }).Reset() + rawCerts := [][]byte{[]byte("test1"), []byte("test2")} + verifiedChains := [][]*x509.Certificate{} + err := VerifyCert(rawCerts, verifiedChains) + convey.So(err, convey.ShouldBeNil) + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/trafficlimit/trafficlimit.go b/yuanrong/pkg/common/faas_common/trafficlimit/trafficlimit.go new file mode 100644 index 0000000..5456104 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/trafficlimit/trafficlimit.go @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package trafficlimit - +package trafficlimit + +import ( + "math" + "sync" + + "golang.org/x/time/rate" +) + +const ( + // DefaultFunctionLimitRate default function limit rate for traffic limitation + DefaultFunctionLimitRate = 5000 + // TrafficRedundantRate limit redundancy rate for traffic limitation + TrafficRedundantRate = 1.1 + // DefaultAccessorInitCopies Initial number of copies + DefaultAccessorInitCopies = 3 +) + +var functionLimitRate int + +// LimiterContainer function key and function's Limiter +type LimiterContainer struct { + funcLimiterMap *sync.Map +} + +// FunctionLimiter - +type FunctionLimiter struct { + Quota int + Limiter *rate.Limiter +} + +var ( + // FunctionBuf - + funcLimiterContainer = &LimiterContainer{ + funcLimiterMap: &sync.Map{}, + } +) + +// RateLimiter rate limiter struct +type RateLimiter struct { + *rate.Limiter +} + +// Take return if a function request is allowed +func (r *RateLimiter) take() bool { + return r.Limiter.Allow() +} + +// SetFunctionLimitRate - +func SetFunctionLimitRate(limit int) { + if limit <= 0 { + limit = DefaultFunctionLimitRate + } + functionLimitRate = limit +} + +// FuncTrafficLimit is the main function of function traffic limitation +func FuncTrafficLimit(funcKey string) bool { + return funcLimiterContainer.funcTakeOneToken(funcKey) +} + +func (t *LimiterContainer) funcTakeOneToken(funcKey string) bool { + funcLimiter := t.getFunctionLimiter(funcKey) + if funcLimiter.Limiter == nil { + return true + } + return funcLimiter.Limiter.Allow() +} + +// getFunctionInfo to generator the function limiter +func (t *LimiterContainer) getFunctionLimiter(functionKey string) FunctionLimiter { + funcLimiter, ok := t.funcLimiterMap.Load(functionKey) + if !ok { + if functionLimitRate <= 0 { + functionLimitRate = DefaultFunctionLimitRate + } + limiter := FunctionLimiter{Limiter: t.getLimiter(functionLimitRate), Quota: DefaultFunctionLimitRate} + t.funcLimiterMap.Store(functionKey, limiter) + return limiter + } + return funcLimiter.(FunctionLimiter) +} + +func (t *LimiterContainer) getLimiter(quota int) *rate.Limiter { + limitRate := float64(quota) / DefaultAccessorInitCopies + limitBucketSize := int(math.Ceil(float64(quota)) / + DefaultAccessorInitCopies * TrafficRedundantRate) + tenantLimiter := rate.NewLimiter(rate.Limit(limitRate), limitBucketSize) + return tenantLimiter +} diff --git a/yuanrong/pkg/common/faas_common/trafficlimit/trafficlimit_test.go b/yuanrong/pkg/common/faas_common/trafficlimit/trafficlimit_test.go new file mode 100644 index 0000000..41f9da3 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/trafficlimit/trafficlimit_test.go @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package trafficlimit - +package trafficlimit + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestFuncTrafficLimit(t *testing.T) { + SetFunctionLimitRate(0) + functionName := "funcTest1" + allow := FuncTrafficLimit(functionName) + assert.Equal(t, allow, true) + + for i := 0; i < DefaultFunctionLimitRate; i++ { + FuncTrafficLimit(functionName) + } + allow = FuncTrafficLimit(functionName) + assert.Equal(t, allow, false) + + time.Sleep(5 * time.Second) + allow = FuncTrafficLimit(functionName) + assert.Equal(t, allow, true) +} diff --git a/yuanrong/pkg/common/faas_common/types/serve.go b/yuanrong/pkg/common/faas_common/types/serve.go new file mode 100644 index 0000000..087f794 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/types/serve.go @@ -0,0 +1,236 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types + +import ( + "fmt" + "regexp" + + "yuanrong/pkg/common/faas_common/constant" +) + +const ( + defaultServeAppRuntime = "python3.9" + defaultServeAppTimeout = 900 + defaultServeAppCpu = 1000 + defaultServeAppMemory = 1024 + defaultServeAppConcurrentNum = 1000 +) + +// ServeDeploySchema - +type ServeDeploySchema struct { + Applications []ServeApplicationSchema `json:"applications"` +} + +// ServeApplicationSchema - +type ServeApplicationSchema struct { + Name string `json:"name"` + RoutePrefix string `json:"route_prefix"` + ImportPath string `json:"import_path"` + RuntimeEnv ServeRuntimeEnvSchema `json:"runtime_env"` + Deployments []ServeDeploymentSchema `json:"deployments"` +} + +// ServeDeploymentSchema - +type ServeDeploymentSchema struct { + Name string `json:"name"` + NumReplicas int64 `json:"num_replicas"` + HealthCheckPeriodS int64 `json:"health_check_period_s"` + HealthCheckTimeoutS int64 `json:"health_check_timeout_s"` +} + +// ServeRuntimeEnvSchema - +type ServeRuntimeEnvSchema struct { + Pip []string `json:"pip"` + WorkingDir string `json:"working_dir"` + EnvVars map[string]any `json:"env_vars"` +} + +// ServeFuncWithKeysAndFunctionMetaInfo - +type ServeFuncWithKeysAndFunctionMetaInfo struct { + FuncMetaKey string + InstanceMetaKey string + FuncMetaInfo *FunctionMetaInfo +} + +// Validate serve deploy schema by set of rules +func (s *ServeDeploySchema) Validate() error { + // 1. app name unique + appNameSet := make(map[string]struct{}) + for _, app := range s.Applications { + if _, ok := appNameSet[app.Name]; ok { + return fmt.Errorf("duplicated application name: %s", app.Name) + } + appNameSet[app.Name] = struct{}{} + } + // 2. app routes unique + appRouteSet := make(map[string]struct{}) + for _, app := range s.Applications { + if _, ok := appRouteSet[app.RoutePrefix]; ok { + return fmt.Errorf("duplicated application route prefix: %s", app.RoutePrefix) + } + appRouteSet[app.RoutePrefix] = struct{}{} + } + // 3. app name non empty + for _, app := range s.Applications { + if app.Name == "" { + return fmt.Errorf("application names must be nonempty") + } + } + return nil +} + +// ToFaaSFuncMetas - +func (s *ServeDeploySchema) ToFaaSFuncMetas() []*ServeFuncWithKeysAndFunctionMetaInfo { + var allMetas []*ServeFuncWithKeysAndFunctionMetaInfo + for _, a := range s.Applications { + // we don't really check it there are some repeated part? and just assume translate won't fail + for _, deploymentFuncMeta := range a.ToFaaSFuncMetas() { + allMetas = append(allMetas, deploymentFuncMeta) + } + } + return allMetas +} + +// ToFaaSFuncMetas - +func (s *ServeApplicationSchema) ToFaaSFuncMetas() []*ServeFuncWithKeysAndFunctionMetaInfo { + var allMetas []*ServeFuncWithKeysAndFunctionMetaInfo + for _, d := range s.Deployments { + meta := d.ToFaaSFuncMeta(s) + allMetas = append(allMetas, meta) + } + return allMetas +} + +// ToFaaSFuncMeta - +func (s *ServeDeploymentSchema) ToFaaSFuncMeta( + belongedApp *ServeApplicationSchema) *ServeFuncWithKeysAndFunctionMetaInfo { + faasFuncUrn := NewServeFunctionKeyWithDefault() + faasFuncUrn.AppName = belongedApp.Name + faasFuncUrn.DeploymentName = s.Name + + // make a copied app to make it contains only this deployment info + copiedApp := *belongedApp + copiedApp.Deployments = []ServeDeploymentSchema{*s} + + return &ServeFuncWithKeysAndFunctionMetaInfo{ + FuncMetaKey: faasFuncUrn.ToFuncMetaKey(), + InstanceMetaKey: faasFuncUrn.ToInstancesMetaKey(), + FuncMetaInfo: &FunctionMetaInfo{ + FuncMetaData: FuncMetaData{ + Name: faasFuncUrn.DeploymentName, + Runtime: defaultServeAppRuntime, + Timeout: defaultServeAppTimeout, + Version: faasFuncUrn.Version, + FunctionURN: faasFuncUrn.ToFaasFunctionUrn(), + TenantID: faasFuncUrn.TenantID, + FunctionVersionURN: faasFuncUrn.ToFaasFunctionVersionUrn(), + FuncName: faasFuncUrn.DeploymentName, + BusinessType: constant.BusinessTypeServe, + }, + ResourceMetaData: ResourceMetaData{ + CPU: defaultServeAppCpu, + Memory: defaultServeAppMemory, + }, + InstanceMetaData: InstanceMetaData{ + MaxInstance: s.NumReplicas, + MinInstance: s.NumReplicas, + ConcurrentNum: defaultServeAppConcurrentNum, + IdleMode: false, + }, + ExtendedMetaData: ExtendedMetaData{ + ServeDeploySchema: ServeDeploySchema{ + Applications: []ServeApplicationSchema{ + copiedApp, + }, + }, + }, + }, + } +} + +const ( + defaultTenantID = "12345678901234561234567890123456" + defaultFuncVersion = "latest" + + faasMetaKey = constant.MetaFuncKey + instanceMetaKey = "/instances/business/yrk/cluster/cluster001/tenant/%s/function/%s/version/%s" + faasFuncURN6tuplePattern = "sn:cn:yrk:%s:function:%s" + faasFuncURN7tuplePattern = "sn:cn:yrk:%s:function:%s:%s" +) + +// ServeFunctionKey is a faas urn with necessary parts +type ServeFunctionKey struct { + TenantID string + AppName string + DeploymentName string + Version string +} + +// NewServeFunctionKeyWithDefault returns a struct with default values +func NewServeFunctionKeyWithDefault() *ServeFunctionKey { + return &ServeFunctionKey{ + TenantID: defaultTenantID, + Version: defaultFuncVersion, + } +} + +// ToFuncNameTriplet - 0@svc@func +func (f *ServeFunctionKey) ToFuncNameTriplet() string { + return fmt.Sprintf("0@%s@%s", f.AppName, f.DeploymentName) +} + +// ToFuncMetaKey - /sn/functions/business/yrk/tenant/12345678901234561234567890123456/function/0@svc@func/version/latest +func (f *ServeFunctionKey) ToFuncMetaKey() string { + return fmt.Sprintf(faasMetaKey, f.TenantID, f.ToFuncNameTriplet(), f.Version) +} + +// ToInstancesMetaKey - /instances/business/yrk/cluster/cluster001/tenant/125...346/function/0@svc@func/version/latest +func (f *ServeFunctionKey) ToInstancesMetaKey() string { + return fmt.Sprintf(instanceMetaKey, f.TenantID, f.ToFuncNameTriplet(), f.Version) +} + +// ToFaasFunctionUrn - sn:cn:yrk:12345678901234561234567890123456:function:0@service@function +func (f *ServeFunctionKey) ToFaasFunctionUrn() string { + return fmt.Sprintf(faasFuncURN6tuplePattern, f.TenantID, f.ToFuncNameTriplet()) +} + +// ToFaasFunctionVersionUrn - sn:cn:yrk:12345678901234561234567890123456:function:0@svc@func:latest +func (f *ServeFunctionKey) ToFaasFunctionVersionUrn() string { + return fmt.Sprintf(faasFuncURN7tuplePattern, f.TenantID, f.ToFuncNameTriplet(), f.Version) +} + +// FromFaasFunctionKey - 12345678901234561234567890123456/0@svc@func/latest +func (f *ServeFunctionKey) FromFaasFunctionKey(funcKey string) error { + const ( + serveFaasFuncKeyMatchesIdxTenantID = iota + 1 + serveFaasFuncKeyMatchesIdxAppName + serveFaasFuncKeyMatchesIdxDeploymentName + serveFaasFuncKeyMatchesIdxVersion + serveFaasFuncKeyMatchesIdxMax + ) + re := regexp.MustCompile(`^([a-zA-Z0-9]*)/.*@([^@]+)@([^/]+)/(.*)$`) + matches := re.FindStringSubmatch(funcKey) + if len(matches) < serveFaasFuncKeyMatchesIdxMax { + return fmt.Errorf("extract failed from %s", funcKey) + } + f.TenantID = matches[serveFaasFuncKeyMatchesIdxTenantID] + f.AppName = matches[serveFaasFuncKeyMatchesIdxAppName] + f.DeploymentName = matches[serveFaasFuncKeyMatchesIdxDeploymentName] + f.Version = matches[serveFaasFuncKeyMatchesIdxVersion] + return nil +} diff --git a/yuanrong/pkg/common/faas_common/types/serve_test.go b/yuanrong/pkg/common/faas_common/types/serve_test.go new file mode 100644 index 0000000..3b91f50 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/types/serve_test.go @@ -0,0 +1,198 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package types - +package types + +import ( + "github.com/smartystreets/goconvey/convey" + "testing" +) + +func TestServeFunctionKeyTrans(t *testing.T) { + k := NewServeFunctionKeyWithDefault() + k.AppName = "svc" + k.DeploymentName = "func" + convey.Convey("Given a serve function key", t, func() { + convey.Convey("When trans to a func name triplet", func() { + convey.So(k.ToFuncNameTriplet(), convey.ShouldEqual, "0@svc@func") + }) + convey.Convey("When trans to a func meta key", func() { + convey.So(k.ToFuncMetaKey(), convey.ShouldEqual, + "/sn/functions/business/yrk/tenant/12345678901234561234567890123456/function/0@svc@func/version/latest") + }) + convey.Convey("When trans to a instance meta key", func() { + convey.So(k.ToInstancesMetaKey(), convey.ShouldEqual, + "/instances/business/yrk/cluster/cluster001/tenant/12345678901234561234567890123456/function/0@svc@func/version/latest") + }) + convey.Convey("When trans to a FaasFunctionUrn", func() { + convey.So(k.ToFaasFunctionUrn(), convey.ShouldEqual, + "sn:cn:yrk:12345678901234561234567890123456:function:0@svc@func") + }) + convey.Convey("When trans ToFaasFunctionVersionUrn", func() { + convey.So(k.ToFaasFunctionVersionUrn(), convey.ShouldEqual, + "sn:cn:yrk:12345678901234561234567890123456:function:0@svc@func:latest") + }) + }) +} + +func TestServeDeploySchema_ToFaaSFuncMetas(t *testing.T) { + convey.Convey("Test ServeDeploySchema ToFaaSFuncMetas", t, func() { + // Setup mock data + app1 := ServeApplicationSchema{ + Name: "app1", + RoutePrefix: "/app1", + ImportPath: "path1", + RuntimeEnv: ServeRuntimeEnvSchema{ + Pip: []string{"package1", "package2"}, + WorkingDir: "/app1", + EnvVars: map[string]any{"key1": "value1"}, + }, + Deployments: []ServeDeploymentSchema{ + { + Name: "deployment1", + NumReplicas: 2, + HealthCheckPeriodS: 30, + HealthCheckTimeoutS: 10, + }, + }, + } + + serveDeploy := ServeDeploySchema{ + Applications: []ServeApplicationSchema{app1}, + } + + convey.Convey("It should return correct faas function metas", func() { + result := serveDeploy.ToFaaSFuncMetas() + convey.So(len(result), convey.ShouldBeGreaterThan, 0) + convey.So(result[0].FuncMetaKey, convey.ShouldNotBeEmpty) + }) + }) +} + +func TestServeFunctionKey(t *testing.T) { + convey.Convey("Test FromFaasFunctionKey", t, func() { + convey.Convey("It should return correct faas function metas", func() { + key := "12345678901234561234567890123456/0@svc@func/latest" + sfk := ServeFunctionKey{} + err := sfk.FromFaasFunctionKey(key) + convey.So(err, convey.ShouldBeNil) + convey.So(sfk.Version, convey.ShouldEqual, "latest") + convey.So(sfk.AppName, convey.ShouldEqual, "svc") + convey.So(sfk.DeploymentName, convey.ShouldEqual, "func") + convey.So(sfk.TenantID, convey.ShouldEqual, "12345678901234561234567890123456") + }) + convey.Convey("It should return incorrect faas function metas", func() { + key := "12345678901234561234567890123456/0@svc@func" + sfk := ServeFunctionKey{} + err := sfk.FromFaasFunctionKey(key) + convey.So(err, convey.ShouldNotBeNil) + }) + }) + + convey.Convey("Test FaasKey Test", t, func() { + convey.Convey("test default faas key", func() { + sfk := NewServeFunctionKeyWithDefault() + convey.So(sfk.TenantID, convey.ShouldEqual, defaultTenantID) + convey.So(sfk.Version, convey.ShouldEqual, defaultFuncVersion) + }) + + convey.Convey("test convert", func() { + sfk := NewServeFunctionKeyWithDefault() + sfk.AppName = "svc" + sfk.DeploymentName = "func" + + convey.So(sfk.ToFuncNameTriplet(), + convey.ShouldEqual, + "0@svc@func") + convey.So(sfk.ToFuncMetaKey(), + convey.ShouldEqual, + "/sn/functions/business/yrk/tenant/12345678901234561234567890123456/function/0@svc@func/version/latest") + convey.So(sfk.ToInstancesMetaKey(), + convey.ShouldEqual, + "/instances/business/yrk/cluster/cluster001/tenant/12345678901234561234567890123456/function/0@svc@func/version/latest") + convey.So(sfk.ToFaasFunctionUrn(), + convey.ShouldEqual, + "sn:cn:yrk:12345678901234561234567890123456:function:0@svc@func") + convey.So(sfk.ToFaasFunctionVersionUrn(), + convey.ShouldEqual, + "sn:cn:yrk:12345678901234561234567890123456:function:0@svc@func:latest") + }) + + convey.Convey("It should return incorrect faas function metas", func() { + sfk := NewServeFunctionKeyWithDefault() + convey.So(sfk.TenantID, convey.ShouldEqual, defaultTenantID) + convey.So(sfk.Version, convey.ShouldEqual, defaultFuncVersion) + }) + }) +} + +func TestServeDeploySchemaValidate(t *testing.T) { + convey.Convey("Test Validate", t, func() { + sds := ServeDeploySchema{ + Applications: []ServeApplicationSchema{ + { + Name: "app1", + RoutePrefix: "/app1", + ImportPath: "path1", + RuntimeEnv: ServeRuntimeEnvSchema{ + Pip: []string{"package1", "package2"}, + WorkingDir: "/app1", + EnvVars: map[string]any{"key1": "value1"}, + }, + Deployments: []ServeDeploymentSchema{ + { + Name: "deployment1", + NumReplicas: 2, + HealthCheckPeriodS: 30, + HealthCheckTimeoutS: 10, + }, + }, + }, + }} + convey.Convey("on repeated app name", func() { + sdsOther := sds + app0 := sdsOther.Applications[0] + sdsOther.Applications = append(sdsOther.Applications, app0) + + err := sdsOther.Validate() + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("on repeated route prefix", func() { + sdsOther := sds + app0 := sdsOther.Applications[0] + app0.Name = "othername" + sdsOther.Applications = append(sdsOther.Applications, app0) + + err := sdsOther.Validate() + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("on empty app name", func() { + sdsOther := sds + app0 := sdsOther.Applications[0] + app0.Name = "" + app0.RoutePrefix = "/other" + sdsOther.Applications = append(sdsOther.Applications, app0) + + err := sdsOther.Validate() + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("ok", func() { + err := sds.Validate() + convey.So(err, convey.ShouldBeNil) + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/types/types.go b/yuanrong/pkg/common/faas_common/types/types.go new file mode 100644 index 0000000..4661164 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/types/types.go @@ -0,0 +1,895 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package types - +package types + +import ( + "yuanrong.org/kernel/runtime/libruntime/api" +) + +// HTTPResponse is general http response +type HTTPResponse struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// InnerInstanceData is the function instance data stored in ETCD +type InnerInstanceData struct { + IP string `json:"ip"` + Port string `json:"port"` + Status string `json:"status"` + P2pPort string `json:"p2pPort"` + GrpcPort string `json:"grpcPort,omitempty"` + NodeIP string `json:"nodeIP,omitempty"` + NodePort string `json:"nodePort,omitempty"` + NodeName string `json:"nodeName,omitempty"` + NodeID string `json:"nodeID,omitempty"` + Applier string `json:"applier,omitempty"` // silimar to OwnerIP + OwnerIP string `json:"ownerIP,omitempty"` + FuncSig string `json:"functionSignature,omitempty"` + Reserved bool `json:"reserved,omitempty"` + CPU int64 `json:"cpu,omitempty"` + Memory int64 `json:"memory,omitempty"` + GroupID string `json:"groupID,omitempty"` + StackID string `json:"stackID,omitempty"` + CustomResources map[string]int64 `json:"customResources,omitempty" valid:"optional"` +} + +// LogTankService - +type LogTankService struct { + GroupID string `json:"logGroupId" valid:",optional"` + StreamID string `json:"logStreamId" valid:",optional"` +} + +// TraceService - +type TraceService struct { + TraceAK string `json:"tracing_ak" valid:",optional"` + TraceSK string `json:"tracing_sk" valid:",optional"` + ProjectName string `json:"project_name" valid:",optional"` +} + +// Initializer include initializer handler and timeout +type Initializer struct { + Handler string `json:"initializer_handler" valid:",optional"` + Timeout int64 `json:"initializer_timeout" valid:",optional"` +} + +// FuncMountConfig function mount config +type FuncMountConfig struct { + FuncMountUser FuncMountUser `json:"mount_user" valid:",optional"` + FuncMounts []FuncMount `json:"func_mounts" valid:",optional"` +} + +// FuncMountUser function mount user +type FuncMountUser struct { + UserID int `json:"user_id" valid:",optional"` + GroupID int `json:"user_group_id" valid:",optional"` +} + +// FuncMount function mount +type FuncMount struct { + MountType string `json:"mount_type" valid:",optional"` + MountResource string `json:"mount_resource" valid:",optional"` + MountSharePath string `json:"mount_share_path" valid:",optional"` + LocalMountPath string `json:"local_mount_path" valid:",optional"` + Status string `json:"status" valid:",optional"` +} + +// Role include x_role and app_x_role +type Role struct { + XRole string `json:"xrole" valid:",optional"` + AppXRole string `json:"app_xrole" valid:",optional"` +} + +// FunctionDeploymentSpec define function deployment spec +type FunctionDeploymentSpec struct { + BucketID string `json:"bucket_id"` + ObjectID string `json:"object_id"` + Layers string `json:"layers"` + DeployDir string `json:"deploydir"` +} + +// InstanceResource describes the cpu and memory info of an instance +type InstanceResource struct { + CPU string `json:"cpu"` + Memory string `json:"memory"` + CustomResources map[string]int64 `json:"customresources"` +} + +// Worker define a worker +type Worker struct { + Instances []*Instance `json:"instances"` + FunctionName string `json:"functionname"` + FunctionVersion string `json:"functionversion"` + Tenant string `json:"tenant"` + Business string `json:"business"` +} + +// Instance define a instance +type Instance struct { + IP string `json:"ip"` + Port string `json:"port"` + GrpcPort string `json:"grpcPort"` + InstanceID string `json:"instanceID,omitempty"` + DeployedIP string `json:"deployed_ip"` + DeployedNode string `json:"deployed_node"` + DeployedNodeID string `json:"deployed_node_id"` + TenantID string `json:"tenant_id"` +} + +// InstanceCreationRequest is used to create instance +type InstanceCreationRequest struct { + LogicInstanceID string `json:"logicInstanceID"` + FuncName string `json:"functionName"` + Applier string `json:"applier"` + DeployNode string `json:"deployNode"` + Business string `json:"business"` + TenantID string `json:"tenantID"` + Version string `json:"version"` + OwnerIP string `json:"ownerIP"` + TraceID string `json:"traceID"` + TriggerFlag string `json:"triggerFlag"` + VersionUrn string `json:"versionUrn"` + CPU int64 `json:"cpu"` + Memory int64 `json:"memory"` + GroupID string `json:"groupID"` + StackID string `json:"stackID"` + CustomResources map[string]int64 `json:"customResources,omitempty" valid:"optional"` +} + +// InstanceCreationSuccessResponse is the struct returned by workermanager upon successful instance creation +type InstanceCreationSuccessResponse struct { + HTTPResponse + Worker *Worker `json:"worker"` + Instance *Instance `json:"instance"` +} + +// InstanceDeletionRequest is used to delete instance +type InstanceDeletionRequest struct { + InstanceID string `json:"instanceID"` + FuncName string `json:"functionName"` + FuncVersion string `json:"functionVersion"` + TenantID string `json:"tenantID"` + BusinessID string `json:"businessID"` + Applier string `json:"applier"` + Force bool `json:"force"` +} + +// InstanceDeletionResponse is the struct returned by workermanager upon successful instance deletion +type InstanceDeletionResponse struct { + HTTPResponse + Reserved bool `json:"reserved"` +} + +// HookArgs keeps args of hook +type HookArgs struct { + FuncArgs []byte // Call() request in worker + SrcTenant string + DstTenant string + StateID string + LogType string + StateKey string // for trigger state call + FunctionVersion string // for trigger state call + ExternalRequest bool // for trigger state call + ServiceID string + TraceID string + InvokeType string +} + +// ResourceStack stores properties of resource stack +type ResourceStack struct { + StackID string `json:"id" valid:"required"` + CPU int64 `json:"cpu" valid:"required"` + Mem int64 `json:"mem" valid:"required"` + CustomResources map[string]int64 `json:"customResources,omitempty" valid:"optional"` +} + +// ResourceGroup stores properties of resource group +type ResourceGroup struct { + GroupID string `json:"id" valid:"required"` + DeployOption string `json:"deployOption" valid:"required"` + GroupState string `json:"groupState" valid:"required"` + ResourceStacks []ResourceStack `json:"resourceStacks" valid:"required"` + ScheduledStacks map[string][]ResourceStack `json:"scheduledStacks,omitempty" valid:"optional"` +} + +// AffinityInfo is data affinity information +type AffinityInfo struct { + AffinityRequest AffinityRequest + AffinityNode string // if AffinityNode is not empty, the affinity node has been calculated + NeedToForward bool +} + +// AffinityRequest is affinity request parameter +type AffinityRequest struct { + Strategy string `json:"strategy"` + ObjectIDs []string `json:"object_ids"` +} + +// GroupInfo stores groupID and stackID +type GroupInfo struct { + GroupID string `json:"groupID"` + StackID string `json:"stackID"` +} + +// InvokeOption contains invoke options +type InvokeOption struct { + AffinityRequest AffinityRequest + GroupInfo GroupInfo + ResourceMetaData map[string]float32 +} + +// ScheduleConfig defines schedule config +type ScheduleConfig struct { + Policy int `json:"policy" valid:"optional"` + ForwardScheduleFirst bool `json:"forwardScheduleResourceNotEnough" valid:"optional"` + SleepingMemThreshold float32 `json:"sleepingMemoryThreshold" valid:"optional"` + SelectInstanceToSleepingPolicy string `json:"selectInstanceToSleepingPolicy" valid:"optional"` +} + +// MetricsData shows the quantities of a specific resource +type MetricsData struct { + TotalResource float32 `json:"totalResource"` + InUseResource float32 `json:"inUseResource"` +} + +// ResourceMetrics contains several resources' MetricsData +type ResourceMetrics map[string]MetricsData + +// WorkerMetrics stores metrics used for scheduler +type WorkerMetrics struct { + SystemResources ResourceMetrics + // key levels: functionUrn instanceID + FunctionResources map[string]map[string]ResourceMetrics +} + +// InnerWorkerData is the worker data stored in ETCD +type InnerWorkerData struct { + IP string `json:"ip"` + Port string `json:"port"` + NodeIP string `json:"nodeIP"` + P2pPort string `json:"p2pPort"` + NodeName string `json:"nodeName"` + NodeID string `json:"nodeID"` + WorkerAgentID string `json:"workerAgentID"` + AllocatableCPU int64 `json:"allocatableCPU"` + AllocatableMemory int64 `json:"allocatableMemory"` + AllocatableCustomResource map[string]int64 `json:"allocatableCustomResource"` +} + +// TerminateRequest sent from worker manager to worker to delete function instance +type TerminateRequest struct { + RuntimeID string `json:"runtime_id"` + FuncName string `json:"function_name"` + FuncVersion string `json:"function_version"` + TenantID string `json:"tenant_id"` + BusinessID string `json:"business_id" valid:"optional"` +} + +// UserAgency define AK/SK of user's agency +type UserAgency struct { + AccessKey string `json:"accessKey"` + SecretKey string `json:"secretKey"` + Token string `json:"token"` + SecurityAk string `json:"securityAk"` + SecuritySk string `json:"securitySk"` + SecurityToken string `json:"securityToken"` +} + +// CustomHealthCheck custom health check +type CustomHealthCheck struct { + TimeoutSeconds int `json:"timeoutSeconds" valid:",optional"` + PeriodSeconds int `json:"periodSeconds" valid:",optional"` + FailureThreshold int `json:"failureThreshold" valid:",optional"` +} + +// FuncCode include function code file and link info +type FuncCode struct { + File string `json:"file" valid:",optional"` + Link string `json:"link" valid:",optional"` +} + +// StrategyConfig - +type StrategyConfig struct { + Concurrency int `json:"concurrency" valid:",optional"` +} + +// FuncSpec contains specifications of a function +type FuncSpec struct { + ETCDType string `json:"-"` + FunctionKey string `json:"-"` + FuncMetaSignature string `json:"-"` + FuncMetaData FuncMetaData `json:"funcMetaData" valid:",optional"` + S3MetaData S3MetaData `json:"s3MetaData" valid:",optional"` + CodeMetaData CodeMetaData `json:"codeMetaData" valid:",optional"` + EnvMetaData EnvMetaData `json:"envMetaData" valid:",optional"` + StsMetaData StsMetaData `json:"stsMetaData" valid:",optional"` + ResourceMetaData ResourceMetaData `json:"resourceMetaData" valid:",optional"` + InstanceMetaData InstanceMetaData `json:"instanceMetaData" valid:",optional"` + ExtendedMetaData ExtendedMetaData `json:"extendedMetaData" valid:",optional"` +} + +// FunctionMetaInfo define function meta info for FunctionGraph +type FunctionMetaInfo struct { + FuncMetaData FuncMetaData `json:"funcMetaData" valid:",optional"` + S3MetaData S3MetaData `json:"s3MetaData" valid:",optional"` + CodeMetaData CodeMetaData `json:"codeMetaData" valid:",optional"` + EnvMetaData EnvMetaData `json:"envMetaData" valid:",optional"` + StsMetaData StsMetaData `json:"stsMetaData" valid:",optional"` + ResourceMetaData ResourceMetaData `json:"resourceMetaData" valid:",optional"` + InstanceMetaData InstanceMetaData `json:"instanceMetaData" valid:",optional"` + ExtendedMetaData ExtendedMetaData `json:"extendedMetaData" valid:",optional"` +} + +// FuncMetaData define meta data of functions +type FuncMetaData struct { + Layers []*Layer `json:"layers" valid:",optional"` + Name string `json:"name"` + FunctionDescription string `json:"description" valid:"stringlength(1|1024)"` + FunctionURN string `json:"functionUrn"` + TenantID string `json:"tenantId"` + Tags map[string]string `json:"tags" valid:",optional"` + FunctionUpdateTime string `json:"functionUpdateTime" valid:",optional"` + FunctionVersionURN string `json:"functionVersionUrn"` + RevisionID string `json:"revisionId" valid:"stringlength(1|20),optional"` + CodeSize int `json:"codeSize" valid:"int"` + CodeSha512 string `json:"codeSha512" valid:"stringlength(1|128),optional"` + Handler string `json:"handler" valid:"stringlength(1|255)"` + Runtime string `json:"runtime" valid:"stringlength(1|63)"` + Timeout int64 `json:"timeout" valid:"required"` + Version string `json:"version" valid:"stringlength(1|32)"` + DeadLetterConfig string `json:"deadLetterConfig" valid:"stringlength(1|255)"` + BusinessID string `json:"businessId" valid:"stringlength(1|32)"` + FunctionType string `json:"functionType" valid:",optional"` + FuncID string `json:"func_id" valid:",optional"` + FuncName string `json:"func_name" valid:",optional"` + DomainID string `json:"domain_id" valid:",optional"` + ProjectName string `json:"project_name" valid:",optional"` + Service string `json:"service" valid:",optional"` + Dependencies string `json:"dependencies" valid:",optional"` + EnableCloudDebug string `json:"enable_cloud_debug" valid:",optional"` + IsStatefulFunction bool `json:"isStatefulFunction" valid:"optional"` + IsBridgeFunction bool `json:"isBridgeFunction" valid:"optional"` + IsStreamEnable bool `json:"isStreamEnable" valid:"optional"` + Type string `json:"type" valid:"optional"` + EnableAuthInHeader bool `json:"enable_auth_in_header" valid:"optional"` + DNSDomainCfg []DNSDomainInfo `json:"dns_domain_cfg" valid:",optional"` + VPCTriggerImage string `json:"vpcTriggerImage" valid:",optional"` + StateConfig StateConfig `json:"stateConfig" valid:",optional"` + BusinessType string `json:"businessType" valid:"optional"` +} + +// StateConfig ConsistentWithInstance- The lifecycle is consistent with that of the instance. +// Independent - The lifecycle is independent of instances. +type StateConfig struct { + LifeCycle string `json:"lifeCycle"` +} + +// S3MetaData define meta function info for OBS +type S3MetaData struct { + AppID string `json:"appId" valid:"stringlength(1|128),optional"` + BucketID string `json:"bucketId" valid:"stringlength(1|255),optional"` + ObjectID string `json:"objectId" valid:"stringlength(1|255),optional"` + BucketURL string `json:"bucketUrl" valid:"url,optional"` + CodeType string `json:"code_type" valid:",optional"` + CodeURL string `json:"code_url" valid:",optional"` + CodeFileName string `json:"code_filename" valid:",optional"` + FuncCode FuncCode `json:"func_code" valid:",optional"` +} + +// LocalMetaData - +type LocalMetaData struct { + StorageType string `json:"storage_type" valid:",optional"` + CodePath string `json:"code_path" valid:",optional"` +} + +// CodeMetaData - +type CodeMetaData struct { + Sha512 string `json:"sha512" valid:",optional"` + LocalMetaData + S3MetaData +} + +// EnvMetaData - +type EnvMetaData struct { + Environment string `json:"environment"` + EncryptedUserData string `json:"encrypted_user_data"` + EnvKey string `json:"envKey" valid:",optional"` + CryptoAlgorithm string `json:"cryptoAlgorithm" valid:",optional"` +} + +// StsMetaData define sts info of functions +type StsMetaData struct { + EnableSts bool `json:"enableSts"` + ServiceName string `json:"serviceName,omitempty"` + MicroService string `json:"microService,omitempty"` + SensitiveConfigs map[string]string `json:"sensitiveConfigs,omitempty"` + StsCertConfig map[string]string `json:"stsCertConfig,omitempty"` +} + +// ResourceMetaData include resource data such as cpu and memory +type ResourceMetaData struct { + CPU int64 `json:"cpu"` + Memory int64 `json:"memory"` + GpuMemory int64 `json:"gpu_memory"` + EnableDynamicMemory bool `json:"enable_dynamic_memory" valid:",optional"` + CustomResources string `json:"customResources" valid:",optional"` + EnableTmpExpansion bool `json:"enable_tmp_expansion" valid:",optional"` + EphemeralStorage int `json:"ephemeral_storage" valid:"int,optional"` + CustomResourcesSpec string `json:"CustomResourcesSpec" valid:",optional"` +} + +// InstanceMetaData define instance meta data of FG functions +type InstanceMetaData struct { + MaxInstance int64 `json:"maxInstance" valid:",optional"` + MinInstance int64 `json:"minInstance" valid:",optional"` + ConcurrentNum int `json:"concurrentNum" valid:",optional"` + DiskLimit int64 `json:"diskLimit" valid:",optional"` + InstanceType string `json:"instanceType" valid:",optional"` + SchedulePolicy string `json:"schedulePolicy" valid:",optional"` + ScalePolicy string `json:"scalePolicy" valid:",optional"` + IdleMode bool `json:"idleMode" valid:",optional"` + PoolLabel string `json:"poolLabel"` + PoolID string `json:"poolId" valid:",optional"` +} + +// ExtendedMetaData define external meta data of functions +type ExtendedMetaData struct { + ImageName string `json:"image_name" valid:",optional"` + Role Role `json:"role" valid:",optional"` + VpcConfig *VpcConfig `json:"func_vpc" valid:",optional"` + EndpointTenantVpc *VpcConfig `json:"endpoint_tenant_vpc" valid:",optional"` + FuncMountConfig *FuncMountConfig `json:"mount_config" valid:",optional"` + StrategyConfig StrategyConfig `json:"strategy_config" valid:",optional"` + ExtendConfig string `json:"extend_config" valid:",optional"` + Initializer Initializer `json:"initializer" valid:",optional"` + Heartbeat Heartbeat `json:"heartbeat" valid:",optional"` + EnterpriseProjectID string `json:"enterprise_project_id" valid:",optional"` + LogTankService LogTankService `json:"log_tank_service" valid:",optional"` + TraceService TraceService `json:"tracing_config" valid:",optional"` + CustomContainerConfig CustomContainerConfig `json:"custom_container_config" valid:",optional"` + AsyncConfigLoaded bool `json:"async_config_loaded" valid:",optional"` + RestoreHook RestoreHook `json:"restore_hook,omitempty" valid:",optional"` + NetworkController NetworkController `json:"network_controller" valid:",optional"` + UserAgency UserAgency `json:"user_agency" valid:",optional"` + CustomFilebeatConfig CustomFilebeatConfig `json:"custom_filebeat_config"` + CustomHealthCheck CustomHealthCheck `json:"custom_health_check" valid:",optional"` + DynamicConfig DynamicConfigEvent `json:"dynamic_config" valid:",optional"` + CustomGracefulShutdown CustomGracefulShutdown `json:"runtime_graceful_shutdown"` + PreStop PreStop `json:"pre_stop"` + RaspConfig RaspConfig `json:"rasp_config"` + ServeDeploySchema ServeDeploySchema `json:"serveDeploySchema" valid:"optional"` +} + +// CustomGracefulShutdown define the option of custom container's runtime graceful shutdown +type CustomGracefulShutdown struct { + MaxShutdownTimeout int `json:"maxShutdownTimeout"` +} + +// PreStop include pre_stop handler and timeout +type PreStop struct { + Handler string `json:"pre_stop_handler" valid:",optional"` + Timeout int `json:"pre_stop_timeout" valid:",optional"` +} + +// DynamicConfigEvent dynamic config etcd event +type DynamicConfigEvent struct { + Enabled bool `json:"enabled"` // use for signature + UpdateTime string `json:"update_time"` + ConfigContent []KV `json:"config_content"` +} + +// KV config key and value +type KV struct { + Name string `json:"name"` + Value string `json:"value"` +} + +// Heartbeat define user custom heartbeat function config +type Heartbeat struct { + // Handler define heartbeat function entry + Handler string `json:"heartbeat_handler" valid:",optional"` +} + +// CustomContainerConfig contains the metadata for custom container +type CustomContainerConfig struct { + ControlPath string `json:"control_path" valid:",optional"` + Image string `json:"image" valid:",optional"` + Command []string `json:"command" valid:",optional"` + Args []string `json:"args" valid:",optional"` + WorkingDir string `json:"working_dir" valid:",optional"` + UID int `json:"uid" valid:",optional"` + GID int `json:"gid" valid:",optional"` +} + +// CustomFilebeatConfig custom filebeat config +type CustomFilebeatConfig struct { + SidecarConfigInfo *SidecarConfigInfo `json:"sidecarConfigInfo"` + CPU int64 `json:"cpu"` + Memory int64 `json:"memory"` + Version string `json:"version"` + ImageAddress string `json:"imageAddress"` +} + +// RaspConfig rasp config key and value +type RaspConfig struct { + InitImage string `json:"init-image"` + RaspImage string `json:"rasp-image"` + RaspServerIP string `json:"rasp-server-ip"` + RaspServerPort string `json:"rasp-server-port"` + Envs []KV `json:"envs"` +} + +// SidecarConfigInfo sidecat config info +type SidecarConfigInfo struct { + ConfigFiles []CustomLogConfigFile `json:"configFiles"` + LiveNessShell string `json:"livenessShell"` + ReadNessShell string `json:"readnessShell"` + PreStopCommands string `json:"preStopCommands"` +} + +// CustomLogConfigFile custom log config file +type CustomLogConfigFile struct { + Path string `json:"path"` + Data string `json:"data"` + Secret bool `json:"secret"` +} + +// RestoreHook include restorehook handler and timeout +type RestoreHook struct { + Handler string `json:"restore_hook_handler,omitempty" valid:",optional"` + Timeout int64 `json:"restore_hook_timeout,omitempty" valid:",optional"` +} + +// NetworkController contains some special network settings +type NetworkController struct { + DisablePublicNetwork bool `json:"disable_public_network" valid:",optional"` + TriggerAccessVpcs []VpcInfo `json:"trigger_access_vpcs" valid:",optional"` +} + +// VpcInfo contains the information of VPC access restriction +type VpcInfo struct { + VpcName string `json:"vpc_name,omitempty"` + VpcID string `json:"vpc_id,omitempty"` +} + +// VpcConfig include info of function vpc +type VpcConfig struct { + ID string `json:"id,omitempty"` + DomainID string `json:"domain_id,omitempty"` + Namespace string `json:"namespace,omitempty"` + VpcName string `json:"vpc_name,omitempty"` + VpcID string `json:"vpc_id,omitempty"` + SubnetName string `json:"subnet_name,omitempty"` + SubnetID string `json:"subnet_id,omitempty"` + TenantCidr string `json:"tenant_cidr,omitempty"` + HostVMCidr string `json:"host_vm_cidr,omitempty"` + Gateway string `json:"gateway,omitempty"` + Xrole string `json:"xrole,omitempty"` +} + +// Layer define layer info +type Layer struct { + BucketURL string `json:"bucketUrl" valid:"url,optional"` + ObjectID string `json:"objectId" valid:"stringlength(1|255),optional"` + BucketID string `json:"bucketId" valid:"stringlength(1|255),optional"` + AppID string `json:"appId" valid:"stringlength(1|128),optional"` + ETag string `json:"etag" valid:"optional"` + Link string `json:"link" valid:"optional"` + Name string `json:"name" valid:",optional"` + Sha256 string `json:"sha256" valid:"optional"` + DependencyType string `json:"dependencyType" valid:",optional"` +} + +// DNSDomainInfo dns domain info +type DNSDomainInfo struct { + ID string `json:"id"` + DomainName string `json:"domain_name"` + Type string `json:"type" valid:",optional"` + ZoneType string `json:"zone_type" valid:",optional"` +} + +// DataSystemConfig data system client config +type DataSystemConfig struct { + TimeoutMs int `json:"timeoutMs" validate:"required"` + Clusters []string `json:"clusters"` +} + +// XiangYunFourConfig - +type XiangYunFourConfig struct { + Site string `json:"site"` + TenantID string `json:"tenantID"` + ApplicationID string `json:"applicationID"` + ServiceID string `json:"serviceID"` +} + +// MemoryControlConfig Memory use control config +type MemoryControlConfig struct { + LowerMemoryPercent float64 `json:"lowerMemoryPercent" valid:",optional"` + HighMemoryPercent float64 `json:"highMemoryPercent" valid:",optional"` + StatefulHighMemPercent float64 `json:"statefulHighMemoryPercent" valid:",optional"` + BodyThreshold uint64 `json:"bodyThreshold" valid:",optional"` + MemDetectIntervalMs int `json:"memDetectIntervalMs" valid:",optional"` +} + +// InstanceStatus Instance status, controlled by the kernel +type InstanceStatus struct { + Code int32 `json:"code" validate:"required"` + Msg string `json:"msg" validate:"required"` + Type int32 `json:"type" validate:"optional"` + ExitCode int32 `json:"exitCode" validate:"optional"` + ErrorCode int32 `json:"errCode" validate:"optional"` +} + +// PodResourceInfo describe actual resource info of pod +type PodResourceInfo struct { + Worker ResourceConfig `json:"worker,omitempty"` + Runtime ResourceConfig `json:"runtime,omitempty"` +} + +// ResourceConfig sub-struct of FuncInstanceInfo +type ResourceConfig struct { + CPULimit int64 `json:"cpuLimit" valid:",optional"` // unit: milli-cores(m) + CPURequest int64 `json:"cpuRequest" valid:",optional"` + MemoryLimit int64 `json:"memoryLimit" valid:",optional"` // unit: byte + MemoryRequest int64 `json:"memoryRequest" valid:",optional"` +} + +// Extensions - +type Extensions struct { + Source string `json:"source"` + CreateTimestamp string `json:"createTimestamp"` + UpdateTimestamp string `json:"updateTimestamp"` + PID string `json:"pid"` + PodName string `json:"podName"` + PodNamespace string `json:"podNamespace"` + PodDeploymentName string `json:"podDeploymentName"` +} + +// InstanceSpecification contains specification of a instance in etcd +type InstanceSpecification struct { + InstanceID string `json:"instanceID" validate:"required"` + DataSystemHost string `json:"dataSystemHost" validate:"required"` + RequestID string `json:"requestID" valid:",optional"` + RuntimeID string `json:"runtimeID" valid:",optional"` + RuntimeAddress string `json:"runtimeAddress" valid:",optional"` + FunctionAgentID string `json:"functionAgentID" valid:",optional"` + FunctionProxyID string `json:"functionProxyID" valid:",optional"` + Function string `json:"function"` + RestartPolicy string `json:"restartPolicy" valid:",optional"` + Resources Resources `json:"resources"` + ActualUse Resources `json:"actualUse" valid:",optional"` + ScheduleOption ScheduleOption `json:"scheduleOption"` + CreateOptions map[string]string `json:"createOptions"` + Labels []string `json:"labels"` + StartTime string `json:"startTime"` + InstanceStatus InstanceStatus `json:"instanceStatus"` + JobID string `json:"jobID"` + SchedulerChain []string `json:"schedulerChain" valid:",optional"` + ParentID string `json:"parentID"` + DeployTimes int32 `json:"deployTimes"` + Extensions Extensions `json:"extensions" valid:",optional"` +} + +// InstanceSpecificationFG contains specification of instance in etcd for functionGraph +type InstanceSpecificationFG struct { + OwnerIP string `json:"ownerIP"` + CreationTime int `json:"creationTime"` + Applier string `json:"applier"` + NodeIP string `json:"nodeIP"` + NodePort string `json:"nodePort"` + InstanceIP string `json:"ip"` + InstancePort string `json:"port"` + CPU int `json:"cpu"` + Memory int `json:"memory"` + BusinessType string `json:"businessType"` + Resource PodResourceInfo `json:"resource,omitempty"` +} + +// Resources - +type Resources struct { + Resources map[string]Resource `json:"resources"` +} + +// Resource - +type Resource struct { + Name string `json:"name"` + Type ValueType `json:"type"` + Scalar ValueScalar `json:"scalar"` + Ranges ValueRanges `json:"ranges"` + Set ValueSet `json:"set"` + Runtime string `json:"runtime"` + Driver string `json:"driver"` + Disk DiskInfo `json:"disk"` +} + +// ValueType - +type ValueType int32 + +// ValueScalar - +type ValueScalar struct { + Value float64 `json:"value"` + Limit float64 `json:"limit"` +} + +// ValueRanges - +type ValueRanges struct { + Range []ValueRange `protobuf:"bytes,1,rep,name=range,proto3" json:"range,omitempty"` +} + +// ValueSet - +type ValueSet struct { + Items string `json:"items"` +} + +// ValueRange - +type ValueRange struct { + Begin uint64 `json:"begin"` + End uint64 `json:"end"` +} + +// DiskInfo - +type DiskInfo struct { + Volume Volume `json:"volume"` + Type string `json:"type"` + DevPath string `json:"devPath"` + MountPath string `json:"mountPath"` +} + +// Volume - +type Volume struct { + Mode int32 `json:"mode"` + SourceType int32 `json:"sourceType"` + HostPaths string `json:"hostPaths"` + ContainerPath string `json:"containerPath"` + ConfigMapPath string `json:"configMapPath"` + EmptyDir string `json:"emptyDir"` + ElaraPath string `json:"elaraPath"` +} + +// ScheduleOption - +type ScheduleOption struct { + SchedPolicyName string `json:"schedPolicyName"` + Priority int32 `json:"priority"` + Affinity Affinity `json:"affinity"` +} + +// Affinity - +type Affinity struct { + NodeAffinity NodeAffinity `json:"nodeAffinity"` + InstanceAffinity InstanceAffinity `json:"instanceAffinity"` + InstanceAntiAffinity InstanceAffinity `json:"instanceAntiAffinity"` +} + +// NodeAffinity - +type NodeAffinity struct { + Affinity map[string]string `json:"affinity"` +} + +// InstanceAffinity - +type InstanceAffinity struct { + Affinity map[string]string `json:"affinity"` +} + +// InstanceInfo the instance info which can be parsed from the etcd path, instanceName is used to hold a place in the +// hash ring while instanceID is used to invoke this instance +type InstanceInfo struct { + TenantID string + FunctionName string + Version string + InstanceName string `json:"instanceName"` + InstanceID string `json:"instanceId"` + Exclusivity string + Address string +} + +// InstanceResponse is the response returned by faas scheduler's CallHandler +type InstanceResponse struct { + InstanceAllocationInfo + ErrorCode int `json:"errorCode"` + ErrorMessage string `json:"errorMessage"` + SchedulerTime float64 `json:"schedulerTime"` +} + +// BatchInstanceResponse is the batch response returned by faas scheduler's CallHandler +type BatchInstanceResponse struct { + InstanceAllocSucceed map[string]InstanceAllocationSucceedInfo `json:"instanceAllocSucceed"` + InstanceAllocFailed map[string]InstanceAllocationFailedInfo `json:"instanceAllocFailed"` + LeaseInterval int64 `json:"leaseInterval"` + SchedulerTime float64 `json:"schedulerTime"` +} + +// RolloutResponse - +type RolloutResponse struct { + AllocRecord map[string][]string `json:"allocRecord"` + RegisterKey string `json:"registerKey"` + ErrorCode int `json:"errorCode"` + ErrorMessage string `json:"errorMessage"` +} + +// InstanceAllocationSucceedInfo is the response returned by faas scheduler's CallHandler +type InstanceAllocationSucceedInfo struct { + FuncKey string `json:"funcKey"` + FuncSig string `json:"funcSig"` + InstanceID string `json:"instanceID"` + ThreadID string `json:"threadID"` +} + +// InstanceAllocationFailedInfo contains err info for allocation failed info +type InstanceAllocationFailedInfo struct { + ErrorCode int `json:"errorCode"` + ErrorMessage string `json:"errorMessage"` +} + +// InstanceAllocationInfo contains instance router info and lease returned to function accessor +type InstanceAllocationInfo struct { + FuncKey string `json:"funcKey"` + FuncSig string `json:"funcSig"` + InstanceID string `json:"instanceID"` + ThreadID string `json:"threadID"` + InstanceIP string `json:"instanceIP"` + InstancePort string `json:"instancePort"` + NodeIP string `json:"nodeIP"` + NodePort string `json:"nodePort"` + LeaseInterval int64 `json:"leaseInterval"` + CPU int64 `json:"cpu"` + Memory int64 `json:"memory"` + ForceInvoke bool `json:"forceInvoke"` +} + +// ExtraParams for interface CreateInstance +type ExtraParams struct { + DesignatedInstanceID string + Label []string + Resources map[string]float64 + CustomResources map[string]float64 + CreateOpt map[string]string + CustomExtensions map[string]string + ScheduleAffinities []api.Affinity +} + +// NuwaRuntimeInfo contains ers workload info for function +type NuwaRuntimeInfo struct { + WisecloudRuntimeId string `json:"wisecloudRuntimeId"` + WisecloudSite string `json:"wisecloudSite"` + WisecloudTenantId string `json:"wisecloudTenantId"` + WisecloudApplicationId string `json:"wisecloudApplicationId"` + WisecloudServiceId string `json:"wisecloudServiceId"` + WisecloudEnvironmentId string `json:"wisecloudEnvironmentId"` + EnvLabel string `json:"envLabel"` +} + +// InstanceSessionConfig - +type InstanceSessionConfig struct { + SessionID string `json:"sessionID"` + SessionTTL int `json:"sessionTTL"` + Concurrency int `json:"concurrency"` +} + +// CallHandlerResponse is the response returned by faas manager's CallHandler +type CallHandlerResponse struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// LeaseEvent - +type LeaseEvent struct { + Type string `json:"type"` + RemoteClientID string `json:"remoteClientId"` + Timestamp int64 `json:"timestamp"` + TraceID string `json:"traceId"` +} diff --git a/yuanrong/pkg/common/faas_common/urnutils/gadgets.go b/yuanrong/pkg/common/faas_common/urnutils/gadgets.go new file mode 100644 index 0000000..a7dfb73 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/urnutils/gadgets.go @@ -0,0 +1,98 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package urnutils contains URN element definitions and tools +package urnutils + +import ( + "strings" +) + +var ( + separator = "-" +) + +const ( + // ServiceIDPrefix is the prefix of the function with serviceID. + ServiceIDPrefix = "0" + + // DefaultSeparator is a character that separates functions and services. + DefaultSeparator = "-" + + // ServicePrefix is the prefix of the function with serviceID. + ServicePrefix = "0@" + + // TenantProductSplitStr separator between a tenant and a product + TenantProductSplitStr = "@" + + minEleSize = 3 +) + +// ComplexFuncName contains service ID and raw function name +type ComplexFuncName struct { + prefix string + ServiceID string + FuncName string +} + +// NewComplexFuncName - +func NewComplexFuncName(svcID, funcName string) *ComplexFuncName { + return &ComplexFuncName{ + prefix: ServiceIDPrefix, + ServiceID: svcID, + FuncName: funcName, + } +} + +// IsComplexFuncName - +func IsComplexFuncName(funcName string) bool { + return strings.Contains(funcName, separator) +} + +// ParseFrom parse ComplexFuncName from string +func (c *ComplexFuncName) ParseFrom(name string) *ComplexFuncName { + fields := strings.Split(name, separator) + if len(fields) < minEleSize || fields[0] != ServiceIDPrefix { + c.prefix = "" + c.ServiceID = "" + c.FuncName = name + return c + } + idx := 0 + c.prefix = fields[idx] + idx++ + c.ServiceID = fields[idx] + // $prefix$separator$ServiceID$separator$FuncName equals name + c.FuncName = name[(len(c.prefix) + len(separator) + len(c.ServiceID) + len(separator)):] + return c +} + +// String - +func (c *ComplexFuncName) String() string { + return strings.Join([]string{c.prefix, c.ServiceID, c.FuncName}, separator) +} + +// GetSvcIDWithPrefix get serviceID with prefix from function name +func (c *ComplexFuncName) GetSvcIDWithPrefix() string { + return c.prefix + separator + c.ServiceID +} + +// SetSeparator - +func SetSeparator(sep string) { + if sep != "" { + separator = sep + } +} diff --git a/yuanrong/pkg/common/faas_common/urnutils/gadgets_test.go b/yuanrong/pkg/common/faas_common/urnutils/gadgets_test.go new file mode 100644 index 0000000..3386531 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/urnutils/gadgets_test.go @@ -0,0 +1,152 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package urnutils contains URN element definitions and tools +package urnutils + +import ( + "github.com/stretchr/testify/assert" + "reflect" + "testing" +) + +func TestComplexFuncName_GetSvcIDWithPrefix(t *testing.T) { + tests := []struct { + name string + fields ComplexFuncName + want string + }{ + { + name: "normal", + fields: ComplexFuncName{ + prefix: ServiceIDPrefix, + ServiceID: "absserviceid", + FuncName: "absFuncName", + }, + want: "0-absserviceid", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ComplexFuncName{ + prefix: tt.fields.prefix, + ServiceID: tt.fields.ServiceID, + FuncName: tt.fields.FuncName, + } + if got := c.GetSvcIDWithPrefix(); got != tt.want { + t.Errorf("GetSvcIDWithPrefix() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestComplexFuncName_ParseFrom(t *testing.T) { + type args struct { + name string + } + tests := []struct { + name string + args args + want *ComplexFuncName + }{ + { + name: "normal", + args: args{ + name: "0-absserviceid-absFunc-Name", + }, + want: &ComplexFuncName{ + prefix: ServiceIDPrefix, + ServiceID: "absserviceid", + FuncName: "absFunc-Name", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ComplexFuncName{} + if got := c.ParseFrom(tt.args.name); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseFrom() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestComplexFuncName_String(t *testing.T) { + tests := []struct { + name string + fields ComplexFuncName + want string + }{ + { + name: "normal", + fields: ComplexFuncName{ + prefix: ServiceIDPrefix, + ServiceID: "absserviceid", + FuncName: "absFunc-Name", + }, + want: "0-absserviceid-absFunc-Name", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ComplexFuncName{ + prefix: tt.fields.prefix, + ServiceID: tt.fields.ServiceID, + FuncName: tt.fields.FuncName, + } + if got := c.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewComplexFuncName(t *testing.T) { + type args struct { + svcID string + funcName string + } + tests := []struct { + name string + args args + want *ComplexFuncName + }{ + { + name: "normal", + args: args{ + svcID: "absserviceid", + funcName: "absFunc-Name", + }, + want: &ComplexFuncName{ + prefix: ServiceIDPrefix, + ServiceID: "absserviceid", + FuncName: "absFunc-Name", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewComplexFuncName(tt.args.svcID, tt.args.funcName); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewComplexFuncName() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSetSeparator(t *testing.T) { + SetSeparator("@") + assert.Equal(t, "@", separator) +} diff --git a/yuanrong/pkg/common/faas_common/urnutils/urn_utils.go b/yuanrong/pkg/common/faas_common/urnutils/urn_utils.go new file mode 100644 index 0000000..dde5562 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/urnutils/urn_utils.go @@ -0,0 +1,561 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package urnutils contains URN element definitions and tools +package urnutils + +import ( + "errors" + "fmt" + "net" + "os" + "regexp" + "strconv" + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" +) + +var ( + once sync.Once + serverIP = "" +) + +const ( + funcNamePrefix = "0@default@" + shortFuncNameSplit = 1 + standardFuncNameSplit = 3 +) + +// example of function URN: :::::: +// Indices of elements in FunctionURN +const ( + // ProductIDIndex is the index of the product ID in a URN + ProductIDIndex = iota + // RegionIDIndex is the index of the region ID in a URN + RegionIDIndex + // BusinessIDIndex is the index of the business ID in a URN + BusinessIDIndex + // TenantIDIndex is the index of the tenant ID in a URN + TenantIDIndex + // FunctionSignIndex is the index of the product ID in a URN + FunctionSignIndex + // FunctionNameIndex is the index of the product name in a URN + FunctionNameIndex + // VersionIndex is the index of the version in a URN + VersionIndex + // URNLenWithVersion is the normal URN length with a version + URNLenWithVersion +) + +// An example of a function functionkey: // +const ( + // TenantIDIndexKey is the index of the tenant ID in a functionkey + TenantIDIndexKey = iota + // FunctionNameIndexKey is the index of the function name in a functionkey + FunctionNameIndexKey + // VersionIndexKey is the index of the version in a functionkey + VersionIndexKey +) + +const ( + // TenantMetadataTenantIndex is the index of the tenant ID in a tenantMetadataEtcdKey + TenantMetadataTenantIndex = 6 +) + +const ( + urnLenWithoutVersion = URNLenWithVersion - 1 + // URNSep is a URN separator of functions + URNSep = ":" + // FunctionKeySep is a functionkey separator of functions + FunctionKeySep = "/" + // DefaultURNProductID is the default product ID of a URN + DefaultURNProductID = "sn" + // DefaultURNRegion is the default region of a URN + DefaultURNRegion = "cn" + // DefaultURNFuncSign is the default function sign of a URN + DefaultURNFuncSign = "function" + defaultURNLayerSign = "layer" + anonymization = "****" + anonymizeLen = 3 + + // BranchAliasPrefix is used to remove "!" from aliasing rules at the begining of "!" + BranchAliasPrefix = 1 + // BranchAliasRule is an aliased rule that begins with an "!" + BranchAliasRule = "!" + functionNameStartIndex = 2 + // ServiceNameIndex is index of service name in urn + ServiceNameIndex = 1 + funcNameMinLen = 3 + // defaultFunctionMaxLen is max length of function name + defaultFunctionMaxLen = 128 +) + +// An example of a worker-manager URN: +// +// /sn/workers/business/iot/tenant/j0f4413f7b4b4c33be576d432f7ee085/function/functest/version/$latest +// /cn-north-1a/cn-north-1a-#-ws-j0f4413f7b-functest-faaslatest-deployment-55b5f9dcb7-r2dsv +const ( + // URNIndexZero URN index 0 + URNIndexZero = iota + // URNIndexOne URN index 1 + URNIndexOne + // URNIndexTwo URN index 2 + URNIndexTwo + // URNIndexThree URN index 3 + URNIndexThree + // URNIndexFour URN index 4 + URNIndexFour + // URNIndexFive URN index 5 + URNIndexFive + // URNIndexSix URN index 6 + URNIndexSix + // URNIndexSeven URN index 7 + URNIndexSeven + // URNIndexEight URN index 8 + URNIndexEight + // URNIndexNine URN index 9 + URNIndexNine + // URNIndexTen URN index 10 + URNIndexTen + // URNIndexEleven URN index 11 + URNIndexEleven + // URNIndexTwelve URN index 12 + URNIndexTwelve + // URNIndexThirteen URN index 13 + URNIndexThirteen +) +const ( + k8sLabelLen = 63 + otherStrLen = 4 + crHashMaxLen = 10 + versionManLen = 30 +) + +const ( + // OwnerReadWrite - + OwnerReadWrite = 416 // 640:rw- r-- --- + // DefaultMode - + DefaultMode = 420 // 644:rw- r-- r-- + // CertMode - + CertMode = 384 // 600:rw- --- --- +) + +var ( + functionGraphFuncNameRegexp = regexp.MustCompile("^[a-zA-Z]([a-zA-Z0-9_-]*[a-zA-Z0-9])?$") +) + +// FunctionURN contains elements of a product URN. It can expand to FunctionURN, LayerURN and WorkerURN +type FunctionURN struct { + ProductID string + RegionID string + BusinessID string + TenantID string + TypeSign string + FuncName string + FuncVersion string +} + +// String serializes elements of function URN struct to string +func (p *FunctionURN) String() string { + urn := fmt.Sprintf("%s:%s:%s:%s:%s:%s", p.ProductID, p.RegionID, + p.BusinessID, p.TenantID, p.TypeSign, p.FuncName) + if p.FuncVersion != "" { + return fmt.Sprintf("%s:%s", urn, p.FuncVersion) + } + return urn +} + +// ParseFrom parses elements from a function URN +func (p *FunctionURN) ParseFrom(urn string) error { + elements := strings.Split(urn, URNSep) + urnLen := len(elements) + if urnLen < urnLenWithoutVersion || urnLen > URNLenWithVersion { + return fmt.Errorf("failed to parse urn from: %s, invalid length: %d", urn, urnLen) + } + p.ProductID = elements[ProductIDIndex] + p.RegionID = elements[RegionIDIndex] + p.BusinessID = elements[BusinessIDIndex] + p.TenantID = elements[TenantIDIndex] + p.TypeSign = elements[FunctionSignIndex] + p.FuncName = elements[FunctionNameIndex] + if urnLen == URNLenWithVersion { + p.FuncVersion = elements[VersionIndex] + } + return nil +} + +// StringWithoutVersion return string without version +func (p *FunctionURN) StringWithoutVersion() string { + return fmt.Sprintf("%s:%s:%s:%s:%s:%s", p.ProductID, p.RegionID, + p.BusinessID, p.TenantID, p.TypeSign, p.FuncName) +} + +// GetFunctionInfo collects function information from a URN +func GetFunctionInfo(urn string) (FunctionURN, error) { + var parsedURN FunctionURN + if err := parsedURN.ParseFrom(urn); err != nil { + log.GetLogger().Errorf("error while parsing an URN: %s", err.Error()) + return FunctionURN{}, fmt.Errorf("parsing an URN error: %s", err) + } + return parsedURN, nil +} + +// GetFuncInfoWithVersion collects function information and distinguishes if the URN contains a version +func GetFuncInfoWithVersion(urn string) (FunctionURN, error) { + parsedURN, err := GetFunctionInfo(urn) + if err != nil { + return parsedURN, err + } + if parsedURN.FuncVersion == "" { + log.GetLogger().Errorf("incorrect URN length: %s", Anonymize(urn)) + return parsedURN, errors.New("incorrect URN length, no version") + } + return parsedURN, nil +} + +// ParseAliasURN is used to remove "!" from the beginning of the alias +func ParseAliasURN(aliasURN string) string { + elements := strings.Split(aliasURN, URNSep) + if len(elements) == URNLenWithVersion { + if strings.HasPrefix(elements[VersionIndex], BranchAliasRule) { + elements[VersionIndex] = elements[VersionIndex][BranchAliasPrefix:] + } + return strings.Join(elements, ":") + } + return aliasURN +} + +// GetAlias returns an alias +func (p *FunctionURN) GetAlias() string { + if p.FuncVersion == constant.DefaultURNVersion { + return "" + } + if _, err := strconv.Atoi(p.FuncVersion); err == nil { + return "" + } + return p.FuncVersion +} + +// GetAliasForFuncBranch returns an alias for function branch +func (p *FunctionURN) GetAliasForFuncBranch() string { + if strings.HasPrefix(p.FuncVersion, BranchAliasRule) { + // remove "!" from the beginning of the alias + return p.FuncVersion[BranchAliasPrefix:] + } + return "" +} + +// Valid check whether the self-verification function name complies with the specifications. +func (p *FunctionURN) Valid() error { + serviceID, functionName, err := GetFunctionNameAndServiceName(p.FuncName) + if err != nil { + log.GetLogger().Errorf("failed to get serviceID and functionName") + return err + } + if !(functionGraphFuncNameRegexp.MatchString(serviceID) || + functionGraphFuncNameRegexp.MatchString(functionName)) { + errmsg := "failed to match reg%s" + log.GetLogger().Errorf(errmsg, functionGraphFuncNameRegexp) + return fmt.Errorf(errmsg, functionGraphFuncNameRegexp) + } + if len(serviceID) > defaultFunctionMaxLen || len(functionName) > defaultFunctionMaxLen { + errmsg := "serviceID or functionName's len is out of range %d" + log.GetLogger().Errorf(errmsg, defaultFunctionMaxLen) + return fmt.Errorf(errmsg, defaultFunctionMaxLen) + } + return nil +} + +// GetFunctionNameAndServiceName returns serviceName and FunctionName +func GetFunctionNameAndServiceName(funcName string) (string, string, error) { + if strings.HasPrefix(funcName, ServiceIDPrefix) { + split := strings.Split(funcName, separator) + if len(split) < funcNameMinLen { + log.GetLogger().Errorf("incorrect function name length: %s", len(split)) + return "", "", errors.New("parsing a function name error") + } + return split[ServiceNameIndex], strings.Join(split[functionNameStartIndex:], separator), nil + } + log.GetLogger().Errorf("incorrect function name: %s", funcName) + return "", "", errors.New("parsing a function name error") +} + +// Anonymize anonymize input str to xxx****xxx +func Anonymize(str string) string { + if len(str) < anonymizeLen+1+anonymizeLen { + return anonymization + } + return str[:anonymizeLen] + anonymization + str[len(str)-anonymizeLen:] +} + +// AnonymizeTenantURN Anonymize tenant info in urn +func AnonymizeTenantURN(urn string) string { + elements := strings.Split(urn, URNSep) + urnLen := len(elements) + if urnLen < urnLenWithoutVersion || urnLen > URNLenWithVersion { + return urn + } + elements[TenantIDIndex] = Anonymize(elements[TenantIDIndex]) + return strings.Join(elements, URNSep) +} + +// AnonymizeTenantKey Anonymize tenant info in functionkey +func AnonymizeTenantKey(functionKey string) string { + elements := strings.Split(functionKey, FunctionKeySep) + keyLen := len(elements) + if TenantIDIndexKey >= keyLen { + return functionKey + } + elements[TenantIDIndexKey] = Anonymize(elements[TenantIDIndexKey]) + return strings.Join(elements, FunctionKeySep) +} + +// AnonymizeTenantURNSlice Anonymize tenant info in urn slice +func AnonymizeTenantURNSlice(urns []string) []string { + var anonymizeUrns []string + for i := 0; i < len(urns); i++ { + anonymizeUrn := AnonymizeTenantURN(urns[i]) + anonymizeUrns = append(anonymizeUrns, anonymizeUrn) + } + return anonymizeUrns +} + +// AnonymizeTenantMetadataEtcdKey Anonymize tenant info in tenant metadata etcd key +// /sn/quota/cluster/cluster001/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/instancemetadata +func AnonymizeTenantMetadataEtcdKey(etcdKey string) string { + elements := strings.Split(etcdKey, "/") + if len(elements) <= TenantMetadataTenantIndex { + return etcdKey + } + elements[TenantMetadataTenantIndex] = Anonymize(elements[TenantMetadataTenantIndex]) + return strings.Join(elements, "/") +} + +// AnonymizeKeys - anonymize the input slice of string to slice of xxx****xxx +// data system key example: 638cf733-a625-4850-9f23-9ef49873f5a3;2ba6f9cd-c8d3-4655-a9d0-e67d7abcfb3f +func AnonymizeKeys(keys []string) []string { + res := make([]string, len(keys)) + for i, str := range keys { + res[i] = Anonymize(str) + } + return res +} + +// BuildURNOrAliasURNTemp - build urn format +func BuildURNOrAliasURNTemp(business, tenant, function, versionOrAlias string) string { + if business == "" || tenant == "" || function == "" || versionOrAlias == "" { + return "" + } + return fmt.Sprintf("%s:%s:%s:%s:%s:%s:%s", DefaultURNProductID, DefaultURNRegion, + business, tenant, DefaultURNFuncSign, function, versionOrAlias) +} + +// GetServerIP - +func GetServerIP() (string, error) { + var err error + once.Do(func() { + addr, errMsg := GetHostAddr() + if errMsg != nil { + err = errMsg + return + } + serverIP = addr[0] + }) + return serverIP, err +} + +// GetHostAddr - +func GetHostAddr() ([]string, error) { + name, err := os.Hostname() + if err != nil { + log.GetLogger().Errorf("get hostname failed: %v", err) + return nil, err + } + + addrs, err := net.LookupHost(name) + if err != nil || len(addrs) == 0 { + log.GetLogger().Errorf("look up host by name failed") + return nil, fmt.Errorf("look up host by name failed") + } + return addrs, nil +} + +// CrNameByURN returns a CR name by URN +func CrNameByURN(urn string) string { + if len(urn) == URNIndexZero { + return "" + } + baseUrn, err := GetFunctionInfo(urn) + if err != nil { + return "" + } + return CrName(baseUrn.BusinessID, baseUrn.TenantID, baseUrn.FuncName, baseUrn.FuncVersion) +} + +// CrName CR Name +// [y/z]brief-functionname-version-hash +func CrName(business, tenant, funcName, version string) string { + hashStr := genFunctionCRStr(business, tenant, funcName, version) + crHash := utils.FnvHash(hashStr) + if len(crHash) > crHashMaxLen { + crHash = crHash[:crHashMaxLen] + } + brief := acquireBrief(business, tenant) + ver := VersionConvForBranch(version) + // cannot contain (urnutils.separator, ususually @) or _. If contains, replace it with -. + funcName = strings.ReplaceAll(funcName, "@", "-") + funcName = strings.ReplaceAll(funcName, "_", "-") + + // otherStrLen is 4 contains three - and a z or y. + shortFunctionNameLen := k8sLabelLen - len(brief) - len(ver) - len(crHash) - otherStrLen + // funcName prefix is 0- means funcName has joint sn service id + // k8s label max length is 63, so cr name need to delete sn service id + // otherwise, cr name length more than 63 characters, error + if strings.HasPrefix(funcName, ServiceIDPrefix) && len(funcName) > shortFunctionNameLen { + funcName = acquireShorter(funcName, shortFunctionNameLen) + } + + crName := brief + "-" + funcName + "-" + ver + "-" + crHash + crNameLower := strings.ToLower(crName) + if crName == crNameLower { + return "y" + crNameLower + } + + return "z" + crNameLower +} + +func genFunctionCRStr(business string, tenant string, funcName string, version string) string { + return business + "-" + tenant + "-" + funcName + "-" + version +} + +func acquireBrief(business, tenant string) string { + if len(business) > URNIndexFour { + business = business[:URNIndexFour] + } + product, tenant := splitTenant(tenant) + if len(tenant) > URNIndexFour { + tenant = tenant[:URNIndexFour] + } + + if len(product) > URNIndexFour { + product = product[:URNIndexFour] + } + + return business + tenant + product +} + +func splitTenant(tenant string) (string, string) { + var product string + t := strings.Split(tenant, TenantProductSplitStr) + l := len(t) + if l == URNIndexOne { + return product, tenant + } + if l == URNIndexTwo { + tenant = t[URNIndexZero] + product = t[URNIndexOne] + return product, tenant + } + return "", product +} + +// VersionConvForBranch return version Conv for branch +func VersionConvForBranch(v string) string { + // cannot contain _. If the version cr contains _, replace it with -. + version := strings.ReplaceAll(v, "_", "-") + if len(version) > versionManLen { + version = version[:versionManLen] + } + return version +} + +// if funcName contains sn service id, this method can acquire +// first 4 character of sn id and real function name with split _ +// return shorter serviceID and shorter funcName +func acquireShorter(funcName string, functionNameLen int) string { + shorterFuncName := []rune(funcName) + return string(shorterFuncName[len(shorterFuncName)-functionNameLen : len(shorterFuncName)-1]) +} + +// GetTenantFromFuncKey - +func GetTenantFromFuncKey(funcKey string) string { + elements := strings.Split(funcKey, FunctionKeySep) + keyLen := len(elements) + if keyLen != URNIndexThree { + return "" + } + return elements[TenantIDIndexKey] +} + +// GetFuncNameFromFuncKey - +func GetFuncNameFromFuncKey(funcKey string) string { + elements := strings.Split(funcKey, FunctionKeySep) + keyLen := len(elements) + if keyLen != URNIndexThree { + return "" + } + return elements[TenantIDIndexKey] + FunctionKeySep + elements[FunctionNameIndexKey] +} + +// GetTenantFromAliasUrn - +func GetTenantFromAliasUrn(aliasUrn string) string { + elements := strings.Split(aliasUrn, URNSep) + keyLen := len(elements) + if keyLen != URNIndexSeven { + return "" + } + return elements[URNIndexThree] +} + +// CheckAliasUrnTenant - +func CheckAliasUrnTenant(tenantID string, aliasUrn string) bool { + if GetTenantFromAliasUrn(aliasUrn) != "" && + GetTenantFromAliasUrn(aliasUrn) == tenantID { + return true + } + return false +} + +// CombineFunctionKey will generate funcKey from three IDs +func CombineFunctionKey(tenantID, funcName, version string) string { + return fmt.Sprintf("%s/%s/%s", tenantID, funcName, version) +} + +// GetShortFuncName - +func GetShortFuncName(funcName string) string { + if len(funcName) > k8sLabelLen { + // labels must begin and end with an alphanumeric character, so set first character always X + funcName = "X" + funcName[len(funcName)-k8sLabelLen+1:] + } + return funcName +} + +// BuildStandardFunctionName - 将不带版本、别名的方法名拼接成0@default@开头的完整方法名 +func BuildStandardFunctionName(functionName string) string { + splits := strings.Split(functionName, "@") + if len(splits) != shortFuncNameSplit && len(splits) != standardFuncNameSplit { + return "" + } + standardFunctionName := functionName + if len(splits) == shortFuncNameSplit { + standardFunctionName = funcNamePrefix + standardFunctionName + } + return standardFunctionName +} diff --git a/yuanrong/pkg/common/faas_common/urnutils/urn_utils_test.go b/yuanrong/pkg/common/faas_common/urnutils/urn_utils_test.go new file mode 100644 index 0000000..a3e42b9 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/urnutils/urn_utils_test.go @@ -0,0 +1,475 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package urnutils + +import ( + "net" + "os" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/constant" + mockUtils "yuanrong/pkg/common/faas_common/utils" +) + +func TestProductUrn_ParseFrom(t *testing.T) { + absURN := FunctionURN{ + "absPrefix", + "absZone", + "absBusinessID", + "absTenantID", + "absProductID", + "absName", + "latest", + } + absURNStr := "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName:latest" + type args struct { + urn string + } + tests := []struct { + name string + fields FunctionURN + args args + want FunctionURN + }{ + { + name: "normal test", + args: args{ + absURNStr, + }, + want: absURN, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &FunctionURN{} + if _ = p.ParseFrom(tt.args.urn); !reflect.DeepEqual(*p, tt.want) { + t.Errorf("ParseFrom() p = %v, want %v", *p, tt.want) + } + }) + } +} + +func TestProductUrn_String(t *testing.T) { + tests := []struct { + name string + fields FunctionURN + want string + }{ + { + "stringify with version", + FunctionURN{ + "absPrefix", + "absZone", + "absBusinessID", + "absTenantID", + "absProductID", + "absName", + "latest", + }, + "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName:latest", + }, + { + "stringify without version", + FunctionURN{ + ProductID: "absPrefix", + RegionID: "absZone", + BusinessID: "absBusinessID", + TenantID: "absTenantID", + TypeSign: "absProductID", + FuncName: "absName", + }, + "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &FunctionURN{ + ProductID: tt.fields.ProductID, + RegionID: tt.fields.RegionID, + BusinessID: tt.fields.BusinessID, + TenantID: tt.fields.TenantID, + TypeSign: tt.fields.TypeSign, + FuncName: tt.fields.FuncName, + FuncVersion: tt.fields.FuncVersion, + } + if got := p.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestProductUrn_StringWithoutVersion(t *testing.T) { + tests := []struct { + name string + fields FunctionURN + want string + }{ + { + "stringify without version", + FunctionURN{ + "absPrefix", + "absZone", + "absBusinessID", + "absTenantID", + "absProductID", + "absName", + "latest", + }, + "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &FunctionURN{ + ProductID: tt.fields.ProductID, + RegionID: tt.fields.RegionID, + BusinessID: tt.fields.BusinessID, + TenantID: tt.fields.TenantID, + TypeSign: tt.fields.TypeSign, + FuncName: tt.fields.FuncName, + FuncVersion: tt.fields.FuncVersion, + } + if got := p.StringWithoutVersion(); got != tt.want { + t.Errorf("StringWithoutVersion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAnonymize(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"0", anonymization}, + {"123", anonymization}, + {"123456", anonymization}, + {"1234567", "123****567"}, + {"12345678901234546", "123****546"}, + } + for _, tt := range tests { + assert.Equal(t, tt.want, Anonymize(tt.input)) + } +} + +func TestAnonymizeTenantURN(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName", "absPrefix:absZone:absBusinessID:abs****tID:absProductID:absName"}, + {"absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName:latest", "absPrefix:absZone:absBusinessID:abs****tID:absProductID:absName:latest"}, + {"a:b:c", "a:b:c"}, + } + for _, tt := range tests { + assert.Equal(t, tt.want, AnonymizeTenantURN(tt.input)) + } +} + +func TestBaseURN_Valid(t *testing.T) { + separator = "@" + urn := FunctionURN{ + ProductID: "", + RegionID: "", + BusinessID: "", + TenantID: "", + TypeSign: "", + FuncName: "0@a_-9AA@AA", + FuncVersion: "", + } + success := urn.Valid() + assert.Equal(t, nil, success) + + urn.FuncName = "0@a_-9AA@tttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttt" + success = urn.Valid() + assert.Equal(t, nil, success) + + urn.FuncName = "0@a_-9AA@ttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttt" + err := urn.Valid() + assert.NotEqual(t, nil, err) + + urn.FuncName = "@func" + err = urn.Valid() + assert.NotEqual(t, nil, err) + + urn.FuncName = "0@func" + err = urn.Valid() + assert.NotEqual(t, nil, err) + + urn.FuncName = "0@^@^" + err = urn.Valid() + assert.NotEqual(t, nil, err) + + separator = "-" +} + +func TestBaseURN_GetAlias(t *testing.T) { + urn := FunctionURN{ + ProductID: "", + RegionID: "", + BusinessID: "", + TenantID: "", + TypeSign: "", + FuncName: "0@a_-9AA@AA", + FuncVersion: constant.DefaultURNVersion, + } + + alias := urn.GetAlias() + assert.Equal(t, "", alias) + + urn.FuncVersion = "old" + alias = urn.GetAlias() + assert.Equal(t, "old", alias) +} + +func TestGetFuncInfoWithVersion(t *testing.T) { + urn := "urn" + _, err := GetFuncInfoWithVersion(urn) + assert.NotEqual(t, nil, err) + + urn = "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName" + _, err = GetFuncInfoWithVersion(urn) + assert.NotEqual(t, nil, err) + + urn = "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName:latest" + parsedURN, err := GetFuncInfoWithVersion(urn) + assert.Equal(t, "absName", parsedURN.FuncName) +} + +func TestAnonymizeTenantKey(t *testing.T) { + inputKey := "" + outputKey := AnonymizeTenantKey(inputKey) + assert.Equal(t, "****", outputKey) + + inputKey = "input/key" + outputKey = AnonymizeTenantKey(inputKey) + assert.Equal(t, "****/key", outputKey) +} + +func TestParseAliasURN(t *testing.T) { + urn := "" + alias := ParseAliasURN(urn) + assert.Equal(t, urn, alias) + + urn = "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName:!latest" + alias = ParseAliasURN(urn) + assert.Equal(t, "absPrefix:absZone:absBusinessID:absTenantID:absProductID:absName:latest", alias) +} + +func TestAnonymizeTenantURNSlice(t *testing.T) { + inUrn := []string{"in", "in/urn"} + outUrn := AnonymizeTenantURNSlice(inUrn) + assert.Equal(t, "in", outUrn[0]) + assert.Equal(t, "in/urn", outUrn[1]) +} + +func TestBaseURN_GetAliasForFuncBranch(t *testing.T) { + urn := FunctionURN{ + ProductID: "", + RegionID: "", + BusinessID: "", + TenantID: "", + TypeSign: "", + FuncName: "0@a_-9AA@AA", + FuncVersion: "!latest", + } + + alias := urn.GetAliasForFuncBranch() + assert.Equal(t, "latest", alias) + + urn.FuncVersion = "latest" + alias = urn.GetAliasForFuncBranch() + assert.Equal(t, "", alias) +} + +func TestAnonymizeKeys(t *testing.T) { + type args struct { + keys []string + } + tests := []struct { + name string + args args + want []string + }{ + {"case", args{keys: []string{"123", "1234567"}}, []string{"****", "123****567"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, AnonymizeKeys(tt.args.keys), "AnonymizeKeys(%v)", tt.args.keys) + }) + } +} + +func TestBuildURNOrAliasURNTemp(t *testing.T) { + type args struct { + business string + tenant string + function string + versionOrAlias string + } + tests := []struct { + name string + args args + want string + }{ + {"empty", args{}, ""}, + {"empty", args{"1", "2", "3", "4"}, "sn:cn:1:2:function:3:4"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, BuildURNOrAliasURNTemp(tt.args.business, tt.args.tenant, tt.args.function, tt.args.versionOrAlias), "BuildURNOrAliasURNTemp(%v, %v, %v, %v)", tt.args.business, tt.args.tenant, tt.args.function, tt.args.versionOrAlias) + }) + } +} + +func TestCrNameByUrn(t *testing.T) { + type args struct { + args string + } + var a args + a.args = "sn:cn:yrk:12345678901234561234567890123456:function:0@yrservice@test_func:v1" + var b args + b.args = "" + tests := []struct { + name string + args args + want string + }{ + {"case1", a, "yyrk1234-0-yrservice-test-func-v1-2966683772"}, + {"case2", b, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CrNameByURN(tt.args.args); got != tt.want { + t.Errorf("CrNameByURN() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetServerIP(t *testing.T) { + tests := []struct { + name string + want string + wantErr bool + patchesFunc mockUtils.PatchesFunc + }{ + {"case1 succeed to get ip", "127.0.0.1", false, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(os.Hostname, func() (name string, err error) { return "127.0.0.1", nil })}) + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(net.LookupHost, + func(host string) (addrs []string, err error) { return []string{"127.0.0.1", "0"}, nil })}) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + got, err := GetServerIP() + if (err != nil) != tt.wantErr { + t.Errorf("GetServerIP() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetServerIP() got = %v, want %v", got, tt.want) + } + patches.ResetAll() + }) + } +} + +func TestCheckAliasUrnTenant(t *testing.T) { + type args struct { + tenantID string + aliasUrn string + } + tests := []struct { + name string + args args + want bool + }{ + {"case1", args{tenantID: "12345678901234561234567890123456", + aliasUrn: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld:myaliasv1"}, true}, + {"case2 error", args{tenantID: "12345678901234561234567890123456", + aliasUrn: "sn:cn:yrk:12345678901234561234567890123456:function:helloworld"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CheckAliasUrnTenant(tt.args.tenantID, tt.args.aliasUrn); got != tt.want { + t.Errorf("CheckAliasUrnTenant() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetTenantFormFuncKey(t *testing.T) { + type args struct { + funcKey string + } + tests := []struct { + name string + args args + want string + }{ + {"case1", args{funcKey: "12345678901234561234567890123456/0-system-faasscheduler/$latest"}, + "12345678901234561234567890123456"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetTenantFromFuncKey(tt.args.funcKey); got != tt.want { + t.Errorf("GetTenantFromFuncKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetShortFuncName(t *testing.T) { + funcName := "testFunc1111111111111111111111111111111111111111111111111111111" + shortFuncName := GetShortFuncName(funcName) + assert.Equal(t, "testFunc1111111111111111111111111111111111111111111111111111111", shortFuncName) + + funcName = "testFunc1111111111111111111111111111111111111111111111111111111111111111111111111111111" + shortFuncName = GetShortFuncName(funcName) + assert.Equal(t, "X11111111111111111111111111111111111111111111111111111111111111", shortFuncName) +} + +func TestGetFuncNameFromFuncKey(t *testing.T) { + funcKey := "12345/test_func/latest/1" + funcName := GetFuncNameFromFuncKey(funcKey) + assert.Equal(t, "", funcName) + + funcKey = "12345/test_func/latest" + funcName = GetFuncNameFromFuncKey(funcKey) + assert.Equal(t, "12345/test_func", funcName) +} + +func TestAnonymizeTenantMetadataEtcdKey(t *testing.T) { + etcdKey := "/sn/quota/cluster/cluster001/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a" + AnonymizedKey := AnonymizeTenantMetadataEtcdKey(etcdKey) + assert.Equal(t, "/sn/quota/cluster/cluster001/tenant/7e1****86a", AnonymizedKey) + + etcdKey = "/sn/quota/cluster/cluster001/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/instancemetadata" + AnonymizedKey = GetFuncNameFromFuncKey(etcdKey) + assert.Equal(t, "", AnonymizedKey) +} diff --git a/yuanrong/pkg/common/faas_common/urnutils/urnconv.go b/yuanrong/pkg/common/faas_common/urnutils/urnconv.go new file mode 100644 index 0000000..30bacf8 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/urnutils/urnconv.go @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package urnutils - +package urnutils + +import ( + "strings" +) + +// FunctionInfo defines Function Info +type FunctionInfo struct { + Business string + Tenant string + FuncName string + Version string +} + +// CrNameByKey return Cr Name By function key +func CrNameByKey(funcKey string) string { + functionInfo := GetFunctionInfoByKey(funcKey) + business, tenant, funcName, version := functionInfo.Business, functionInfo.Tenant, + functionInfo.FuncName, functionInfo.Version + + return CrName(business, tenant, funcName, version) +} + +// GetFunctionInfoByKey - +func GetFunctionInfoByKey(key string) FunctionInfo { + var functionInfo FunctionInfo + keyFields := strings.Split(key, "/") + + if len(keyFields) != URNIndexEleven && len(keyFields) != URNIndexThirteen { + return functionInfo + } + + functionInfo.Business = keyFields[URNIndexFour] + functionInfo.Tenant = keyFields[URNIndexSix] + functionInfo.FuncName = keyFields[URNIndexEight] + functionInfo.Version = keyFields[URNIndexTen] + + return functionInfo +} diff --git a/yuanrong/pkg/common/faas_common/urnutils/urnconv_test.go b/yuanrong/pkg/common/faas_common/urnutils/urnconv_test.go new file mode 100644 index 0000000..4181f0c --- /dev/null +++ b/yuanrong/pkg/common/faas_common/urnutils/urnconv_test.go @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package urnutils - +package urnutils + +import "testing" + +func TestCrNameByKey(t *testing.T) { + type args struct { + funcKey string + } + tests := []struct { + name string + args args + want string + }{ + {"case1 succeed to get CrNameByKey", args{funcKey: "/sn/functions/business/yrk/tenant" + + "/172120022624850603/function/0@default@testurpccustomoom002/version/latest"}, + "yyrk1721-0-default-testurpccustomoom002-latest-1257561201"}, + {"case2 long funcName", args{funcKey: "/sn/functions/business/yrk/tenant/12345678901234561234567890123456/" + + "function/0-actordemo-test-actor-support-version-publish-delete-version/version/$latest"}, + "yyrk1234-port-version-publish-delete-versio-$latest-4279038269"}, + {"case3 long version", args{funcKey: "/sn/functions/business/yrk/tenant/12345678901234561234567890123456/function" + + "/0-actordemo-test-actor-support-version-publish-delete-version/version/123456789123456789123456789123456789123456789123456789123456"}, + "yyrk1234-lete-versio-123456789123456789123456789123-3816641367"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CrNameByKey(tt.args.funcKey); got != tt.want { + t.Errorf("CrNameByKey() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/yuanrong/pkg/common/faas_common/utils/component_util.go b/yuanrong/pkg/common/faas_common/utils/component_util.go new file mode 100644 index 0000000..b89c712 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/component_util.go @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "k8s.io/api/core/v1" +) + +type container string + +const ( + // ContainerRuntimeManager - + ContainerRuntimeManager container = "runtime-manager" +) + +// VolumeBuilder - +type VolumeBuilder struct { + Volumes []v1.Volume + Mounts map[container][]v1.VolumeMount +} + +// AddVolume - +func (vc *VolumeBuilder) AddVolume(volume v1.Volume) { + vc.Volumes = append(vc.Volumes, volume) +} + +// AddVolumeMount - +func (vc *VolumeBuilder) AddVolumeMount(name container, mount v1.VolumeMount) { + vc.Mounts[name] = append(vc.Mounts[name], mount) +} + +// NewVolumeBuilder - +func NewVolumeBuilder() *VolumeBuilder { + return &VolumeBuilder{ + Mounts: make(map[container][]v1.VolumeMount), + } +} diff --git a/yuanrong/pkg/common/faas_common/utils/file_test.go b/yuanrong/pkg/common/faas_common/utils/file_test.go new file mode 100644 index 0000000..5c9ec26 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/file_test.go @@ -0,0 +1,53 @@ +package utils + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestFileExists(t *testing.T) { + Convey("Given a temp file", t, func() { + file, err := ioutil.TempFile("", "test-file") + So(err, ShouldBeNil) + filename := file.Name() + + Convey("When it is created", func() { + Convey("Then it should return true", func() { + So(FileExists(filename), ShouldBeTrue) + }) + }) + + Convey("When we delete the file", func() { + err := file.Close() + So(err, ShouldBeNil) + err = os.Remove(filename) + So(err, ShouldBeNil) + + Convey("Then it should return false", func() { + So(FileExists(filename), ShouldBeFalse) + }) + }) + }) +} + +func TestValidateFilePath(t *testing.T) { + Convey("Given a abs file path and a rel file path", t, func() { + relPath := "a/b" + absPath, err := filepath.Abs(relPath) + So(err, ShouldBeNil) + + Convey("The abs path should not return an error", func() { + err = ValidateFilePath(absPath) + So(err, ShouldBeNil) + }) + + Convey("The rel path should return an error", func() { + err := ValidateFilePath(relPath) + So(err, ShouldNotBeNil) + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/utils/func_meta_util.go b/yuanrong/pkg/common/faas_common/utils/func_meta_util.go new file mode 100644 index 0000000..23c7c35 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/func_meta_util.go @@ -0,0 +1,193 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "encoding/json" + "fmt" + "hash/fnv" + "strconv" + "strings" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/types" +) + +const ( + funcInfoMinLen = 3 + // InstanceScalePolicyStaticFunction is the schedule policy for static function + InstanceScalePolicyStaticFunction = "staticFunction" +) + +// GetFuncMetaSignature will calculate function signature based on essentials +func GetFuncMetaSignature(metaInfo *types.FunctionMetaInfo, filterFlag bool) string { + // static function set revisionID as signature + if metaInfo.InstanceMetaData.ScalePolicy == InstanceScalePolicyStaticFunction { + return metaInfo.FuncMetaData.RevisionID + } + metaInfoCopy := &types.FunctionMetaInfo{} + if err := DeepCopyObj(metaInfo, metaInfoCopy); err != nil { + return "invalid function meta info" + } + if filterFlag { + metaInfoFieldFilter(metaInfoCopy) + } + metaInfoCopy.FuncMetaData.FuncID = "" + metaInfoCopy.FuncMetaData.Type = "" + metaInfoCopy.FuncMetaData.EnableCloudDebug = "" + metaInfoCopy.FuncMetaData.Dependencies = "" + metaInfoCopy.FuncMetaData.CodeSize = 0 + metaInfoCopy.FuncMetaData.CodeSha512 = "" + metaInfoCopy.FuncMetaData.FunctionType = "" + metaInfoCopy.FuncMetaData.Tags = nil + metaInfoCopy.FuncMetaData.FunctionDescription = "" + metaInfoCopy.FuncMetaData.FunctionUpdateTime = "" + metaInfoCopy.InstanceMetaData.ScalePolicy = "" + metaInfoCopy.InstanceMetaData.MaxInstance = 0 + metaInfoCopy.InstanceMetaData.MinInstance = 0 + metaInfoCopy.ExtendedMetaData.DynamicConfig.UpdateTime = "" + metaInfoCopy.ExtendedMetaData.DynamicConfig.ConfigContent = []types.KV{} + metaInfoCopy.ExtendedMetaData.StrategyConfig = types.StrategyConfig{} + metaInfoCopy.ExtendedMetaData.ExtendConfig = "" + metaInfoCopy.ExtendedMetaData.EnterpriseProjectID = "" + metaInfoCopy.ExtendedMetaData.AsyncConfigLoaded = false + metaInfoCopy.ExtendedMetaData.NetworkController = types.NetworkController{} + metaInfoCopy.ResourceMetaData.CustomResourcesSpec = + getCustomResourceSpec(metaInfo.ResourceMetaData.CustomResources, metaInfo.ResourceMetaData.CustomResourcesSpec) + data, err := json.Marshal(metaInfoCopy) + if err != nil { + return "invalid function meta info" + } + return FnvHash(string(data)) +} +func getCustomResourceSpec(customResources string, customResourceSpec string) string { + // customResources为空,customResourceSpec必然为空 + if customResources == "" { + return "" + } + customResourcesJSON := make(map[string]int64) + customResourcesSpecJSON := make(map[string]interface{}) + err1 := json.Unmarshal([]byte(customResources), &customResourcesJSON) + + err2 := json.Unmarshal([]byte(customResourceSpec), &customResourcesSpecJSON) + if err1 != nil || (err2 != nil && customResourceSpec != "") { + return "" + } + for k := range customResourcesJSON { + if k == "huawei.com/ascend-1980" { + _, ok := customResourcesSpecJSON["instanceType"] + if !ok { + customResourcesSpecJSON["instanceType"] = "376T" + } + break + } + } + v, err3 := json.Marshal(customResourcesSpecJSON) + if err3 != nil { + return "" + } + return string(v) +} + +func metaInfoFieldFilter(metaInfoCopy *types.FunctionMetaInfo) { + metaInfoCopy.FuncMetaData.Service = "" + metaInfoCopy.S3MetaData = types.S3MetaData{} + + metaInfoCopy.EnvMetaData = types.EnvMetaData{} + + metaInfoCopy.ResourceMetaData.EnableDynamicMemory = false + metaInfoCopy.ResourceMetaData.EnableTmpExpansion = false + metaInfoCopy.ResourceMetaData.GpuMemory = 0 + metaInfoCopy.ResourceMetaData.EphemeralStorage = 0 + + metaInfoCopy.ExtendedMetaData.ImageName = "" + if metaInfoCopy.ExtendedMetaData.VpcConfig != nil { + metaInfoCopy.ExtendedMetaData.VpcConfig.Xrole = "" + } + metaInfoCopy.ExtendedMetaData.UserAgency = types.UserAgency{} +} + +// FnvHash a hash function +func FnvHash(s string) string { + h := fnv.New32a() + _, err := h.Write([]byte(s)) + if err != nil { + return "" + } + + // for 2 <= base <= 36. The result uses the lower-case letters 'a' to 'z' + return strconv.FormatUint(uint64(h.Sum32()), 10) +} + +// DeepCopyObj deal with src and dst +func DeepCopyObj(src interface{}, dst interface{}) error { + if dst == nil { + return fmt.Errorf("dst cannot be nil") + } + if src == nil { + return fmt.Errorf("src cannot be nil") + } + + bytes, err := json.Marshal(src) + if err != nil { + return fmt.Errorf("unable to marshal src: %s", err) + } + + err = json.Unmarshal(bytes, dst) + if err != nil { + return fmt.Errorf("unable to unmarshal into dst: %s", err) + } + return nil +} + +// SetFuncMetaDynamicConfEnable will calculate DynamicConfig and set DynamicConfig.Enabled +func SetFuncMetaDynamicConfEnable(metaInfo *types.FunctionMetaInfo) { + // The DynamicConfig.Enabled will use for calculate function signature. + // When DynamicConfig.Enabled changes, the instance will be restarted. + // If function version is not latest,DynamicConfig.Enabled will never change + if len(metaInfo.ExtendedMetaData.DynamicConfig.UpdateTime) == 0 { + metaInfo.ExtendedMetaData.DynamicConfig.Enabled = false + return + } + // + if metaInfo.FuncMetaData.Version == constant.DefaultURNVersion && + len(metaInfo.ExtendedMetaData.DynamicConfig.ConfigContent) == 0 { + metaInfo.ExtendedMetaData.DynamicConfig.Enabled = false + return + } + metaInfo.ExtendedMetaData.DynamicConfig.Enabled = true +} + +// ParseFuncKey parse funcKey with format "tenantID/funcName/funcVersion" or "tenantID/funcName/funcVersion/CPU-memory" +func ParseFuncKey(funcKey string) (string, string, string) { + funcInfo := strings.Split(funcKey, "/") + if len(funcInfo) < funcInfoMinLen { + return "", "", "" + } + return funcInfo[0], funcInfo[1], funcInfo[2] +} + +// GetAPIType - +func GetAPIType(BusinessType string) api.ApiType { + if BusinessType == constant.BusinessTypeServe { + return api.ServeApi + } + return api.FaaSApi +} diff --git a/yuanrong/pkg/common/faas_common/utils/func_meta_util_test.go b/yuanrong/pkg/common/faas_common/utils/func_meta_util_test.go new file mode 100644 index 0000000..12feb84 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/func_meta_util_test.go @@ -0,0 +1,87 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/types" +) + +func TestGetFuncMetaSignature(t *testing.T) { + convey.Convey("success", t, func() { + signature := GetFuncMetaSignature(&types.FunctionMetaInfo{}, true) + convey.So(signature, convey.ShouldEqual, "2778597263") + }) + convey.Convey("marshal error", t, func() { + defer gomonkey.ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, fmt.Errorf("marshal error") + }).Reset() + str := GetFuncMetaSignature(&types.FunctionMetaInfo{}, true) + convey.So(str, convey.ShouldContainSubstring, "invalid function meta info") + }) + convey.Convey("unmarshal error", t, func() { + defer gomonkey.ApplyFunc(json.Unmarshal, func(data []byte, v interface{}) error { + return fmt.Errorf("unmarshal error") + }).Reset() + str := GetFuncMetaSignature(&types.FunctionMetaInfo{}, true) + convey.So(str, convey.ShouldContainSubstring, "invalid function meta info") + }) +} + +func TestSetFuncMetaDynamicConfEnable(t *testing.T) { + type args struct { + metaInfo *types.FunctionMetaInfo + } + tests := []struct { + name string + args args + }{ + {"case1", args{metaInfo: &types.FunctionMetaInfo{}}}, + {"case2", args{metaInfo: &types.FunctionMetaInfo{FuncMetaData: types.FuncMetaData{Version: constant.DefaultURNVersion}, + ExtendedMetaData: types.ExtendedMetaData{DynamicConfig: types.DynamicConfigEvent{UpdateTime: "1"}}}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + SetFuncMetaDynamicConfEnable(tt.args.metaInfo) + }) + } +} + +func TestGetCustomResource(t *testing.T) { + convey.Convey("success", t, func() { + customResources := getCustomResourceSpec("{\"huawei.com/ascend-1980\":8}", "") + convey.So(customResources, convey.ShouldEqual, "{\"instanceType\":\"376T\"}") + }) + + convey.Convey("success", t, func() { + customResources := getCustomResourceSpec("{\"huawei.com/ascend-1980\": 8}", "{\"instanceType\": \"376T\"}") + convey.So(customResources, convey.ShouldEqual, "{\"instanceType\":\"376T\"}") + }) + + convey.Convey("success", t, func() { + customResources := getCustomResourceSpec("{\"huawei.com/ascend-1980\":8}", "{ \"instanceType\": \"280T\"}") + convey.So(customResources, convey.ShouldEqual, "{\"instanceType\":\"280T\"}") + }) +} diff --git a/yuanrong/pkg/common/faas_common/utils/helper.go b/yuanrong/pkg/common/faas_common/utils/helper.go new file mode 100644 index 0000000..9578063 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/helper.go @@ -0,0 +1,170 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils for common functions +package utils + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "path" + "path/filepath" + "strings" +) + +const ( + envPathSeparators = ":" +) + +// IsFile returns true if the path is a file +func IsFile(path string) bool { + file, err := os.Stat(path) + if err != nil { + return false + } + return file.Mode().IsRegular() +} + +// IsDir returns true if the path is a dir +func IsDir(path string) bool { + dir, err := os.Stat(path) + if err != nil { + return false + } + + return dir.IsDir() +} + +// FileExists returns true if the path exists +func FileExists(path string) bool { + _, err := os.Stat(path) + if err != nil { + return false + } + return true +} + +// IsHexString judge If Hex String +func IsHexString(str string) bool { + + str = strings.ToLower(str) + + for _, c := range str { + if c < '0' || (c > '9' && c < 'a') || c > 'f' { + return false + } + } + + return true +} + +// ValidateFilePath verify the legitimacy of the file path +func ValidateFilePath(path string) error { + absPath, err := filepath.Abs(path) + if err != nil || !strings.HasPrefix(path, absPath) { + return errors.New("invalid file path, expect to be configured as an absolute path") + } + return nil +} + +// ValidEnvValuePath verify the legitimacy of the env path +func ValidEnvValuePath(envValues string) error { + if envValues == "" { + return nil + } + envByte := strings.Split(envValues, envPathSeparators) + for _, envValue := range envByte { + if err := ValidateFilePath(envValue); err != nil { + return err + } + } + return nil +} + +// copyFile copies a single file from src to dst +func copyFile(srcPath, dstPath string) error { + var err error + var fromFd *os.File + var toFd *os.File + var fromFdInfo os.FileInfo + + if fromFd, err = os.Open(srcPath); err != nil { + return err + } + defer func(fromFd *os.File) { + if fromFd != nil { + err = fromFd.Close() + } + }(fromFd) + + if fromFdInfo, err = os.Stat(srcPath); err != nil { + return err + } + + toFd, err = os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fromFdInfo.Mode()) + defer func(toFd *os.File) { + if toFd != nil { + err = toFd.Close() + } + }(toFd) + + if err != nil { + return err + } + + if _, err = io.Copy(toFd, fromFd); err != nil { + return err + } + + return err +} + +// CopyDir copies a whole directory recursively +func CopyDir(srcPath string, dstPath string) error { + var err error + var dirFds []os.FileInfo + var fromInfo os.FileInfo + + if fromInfo, err = os.Stat(srcPath); err != nil { + return err + } + + if err = os.MkdirAll(dstPath, fromInfo.Mode()); err != nil { + return err + } + + if dirFds, err = ioutil.ReadDir(srcPath); err != nil { + return err + } + for _, fd := range dirFds { + fromPath := path.Join(srcPath, fd.Name()) + toPath := path.Join(dstPath, fd.Name()) + + if fd.IsDir() { + if err = CopyDir(fromPath, toPath); err != nil { + fmt.Println(err) + } + } else { + if err = copyFile(fromPath, toPath); err != nil { + fmt.Println(err) + } + } + } + return nil +} diff --git a/yuanrong/pkg/common/faas_common/utils/helper_test.go b/yuanrong/pkg/common/faas_common/utils/helper_test.go new file mode 100644 index 0000000..0478f9f --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/helper_test.go @@ -0,0 +1,196 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type IsFileTestSuite struct { + suite.Suite + tempDir string +} + +// SetupSuite Setup Suite +func (suite *IsFileTestSuite) SetupSuite() { + var err error + + // Create temp dir for IsFileTestSuite + suite.tempDir, err = ioutil.TempDir("", "isfile-test") + suite.Require().NoError(err) +} + +// TearDownSuite TearDown Suite +func (suite *IsDirTestSuite) TearDownSuite() { + defer os.RemoveAll(suite.tempDir) +} + +// TestPositive Test Positive +func (suite *IsFileTestSuite) TestPositive() { + + // Create temp file + tempFile, err := ioutil.TempFile(suite.tempDir, "temp_file") + suite.Require().NoError(err) + defer os.Remove(tempFile.Name()) + + // Verify that function isFile() returns true when file is created + suite.Require().True(IsFile(tempFile.Name())) + +} + +// TestFileIsNotExist Test File Is Not Exist +func (suite *IsFileTestSuite) TestFileIsNotExist() { + + // Set path to unexisted file + tempFile := filepath.Join(suite.tempDir, "somePath.txt") + + // Verify that function isFile() returns false when file doesn't exist in the system + suite.Require().False(IsFile(tempFile)) +} + +// TestFileIsADirectory Test File Is A Directory +func (suite *IsFileTestSuite) TestFileIsADirectory() { + suite.Require().False(IsFile(suite.tempDir)) +} + +type IsDirTestSuite struct { + suite.Suite + tempDir string +} + +// SetupSuite Setup Suite +func (suite *IsDirTestSuite) SetupSuite() { + var err error + + // Create temp dir for IsDirTestSuite + suite.tempDir, err = ioutil.TempDir("", "isdir-test") + suite.Require().NoError(err) +} + +// TearDownSuite TearDown Suite +func (suite *IsFileTestSuite) TearDownSuite() { + defer os.RemoveAll(suite.tempDir) +} + +// TestPositive Test Positive +func (suite *IsDirTestSuite) TestPositive() { + + // Verify that function IsDir() returns true when directory exists in the system + suite.Require().True(IsDir(suite.tempDir)) +} + +// TestNegative Test Negative +func (suite *IsDirTestSuite) TestNegative() { + + // Create temp file + tempFile, err := ioutil.TempFile(suite.tempDir, "temp_file") + suite.Require().NoError(err) + defer os.Remove(tempFile.Name()) + + // Verify that function IsDir( returns false when file instead of directory is function argument + suite.Require().False(IsDir(tempFile.Name())) +} + +type FileExistTestSuite struct { + suite.Suite + tempDir string +} + +// SetupSuite Setup Suite +func (suite *FileExistTestSuite) SetupSuite() { + var err error + + // Create temp dir for FileExistTestSuite + suite.tempDir, err = ioutil.TempDir("", "file_exists-test") + suite.Require().NoError(err) +} + +// TearDownSuite TearDown Suite +func (suite *FileExistTestSuite) TearDownSuite() { + defer os.RemoveAll(suite.tempDir) +} + +// TestPositive Test Positive +func (suite *FileExistTestSuite) TestPositive() { + + // Create temp file + tempFile, err := ioutil.TempFile(suite.tempDir, "temp_file") + suite.Require().NoError(err) + defer os.Remove(tempFile.Name()) + + // Verify that function FileExists() returns true when file is exist + suite.Require().True(FileExists(tempFile.Name())) +} + +// TestFileNotExist Test File Not Exist +func (suite *FileExistTestSuite) TestFileNotExist() { + + // Set path to unexisted file + tempFile := filepath.Join(suite.tempDir, "somePath.txt") + + // Verify that function FileExists() returns false when file doesn't exist + suite.Require().False(FileExists(tempFile)) +} + +// TestFileIsNotAFile Test File Is Not A File +func (suite *FileExistTestSuite) TestFileIsNotAFile() { + + // Verify that function returns true when folder is exist in the system + suite.Require().True(FileExists(suite.tempDir)) +} + +// TestHelperTestSuite Test Helper Test Suite +func TestHelperTestSuite(t *testing.T) { + suite.Run(t, new(FileExistTestSuite)) + suite.Run(t, new(IsDirTestSuite)) + suite.Run(t, new(IsFileTestSuite)) +} + +func TestValidEnvValuePath(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"", true}, + {"/home/sn/test", true}, + {"../../home/sn", false}, + {"/home/sn:/home/test:/opt", true}, + } + for _, tt := range tests { + assert.Equal(t, tt.want, ValidEnvValuePath(tt.input) == nil) + } +} + +func TestCopyDir(t *testing.T) { + convey.Convey("CopyDir", t, func() { + convey.Convey("CopyDir case 1", func() { + srcPath, _ := ioutil.TempDir("", "src") + dstPath, _ := ioutil.TempDir("", "dst") + fileName := "fastfreeze.log" + _, err := ioutil.TempFile(srcPath, fileName) + err = CopyDir(srcPath, dstPath) + convey.So(err, convey.ShouldBeNil) + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/utils/libruntimeapi_mock.go b/yuanrong/pkg/common/faas_common/utils/libruntimeapi_mock.go new file mode 100644 index 0000000..0ba980f --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/libruntimeapi_mock.go @@ -0,0 +1,305 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils is sdk +package utils + +import ( + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/uuid" +) + +// FakeLibruntimeSdkClient - +type FakeLibruntimeSdkClient struct{} + +// CreateInstance - +func (f *FakeLibruntimeSdkClient) CreateInstance(funcMeta api.FunctionMeta, args []api.Arg, + invokeOpt api.InvokeOptions) (string, error) { + + InstanceID := uuid.New().String() + return InstanceID, nil +} + +// InvokeByInstanceId - +func (f *FakeLibruntimeSdkClient) InvokeByInstanceId(funcMeta api.FunctionMeta, + instanceID string, args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + return "", nil +} + +// InvokeByFunctionName - +func (f *FakeLibruntimeSdkClient) InvokeByFunctionName(funcMeta api.FunctionMeta, + args []api.Arg, invokeOpt api.InvokeOptions) (string, error) { + return "", nil +} + +// AcquireInstance - +func (f *FakeLibruntimeSdkClient) AcquireInstance(state string, funcMeta api.FunctionMeta, + acquireOpt api.InvokeOptions) (api.InstanceAllocation, error) { + return api.InstanceAllocation{}, nil +} + +// ReleaseInstance - +func (f *FakeLibruntimeSdkClient) ReleaseInstance(allocation api.InstanceAllocation, + stateID string, abnormal bool, option api.InvokeOptions) { + return +} + +// Kill - +func (f *FakeLibruntimeSdkClient) Kill(instanceID string, signal int, payload []byte) error { + return nil +} + +// CreateInstanceRaw - +func (f *FakeLibruntimeSdkClient) CreateInstanceRaw(createReqRaw []byte) ([]byte, error) { + return nil, nil +} + +// InvokeByInstanceIdRaw - +func (f *FakeLibruntimeSdkClient) InvokeByInstanceIdRaw(invokeReqRaw []byte) ([]byte, error) { + return nil, nil +} + +// KillRaw - +func (f *FakeLibruntimeSdkClient) KillRaw(killReqRaw []byte) ([]byte, error) { + return nil, nil +} + +// SaveState - +func (f *FakeLibruntimeSdkClient) SaveState(state []byte) (string, error) { + return "", nil +} + +// LoadState - +func (f *FakeLibruntimeSdkClient) LoadState(checkpointID string) ([]byte, error) { + return nil, nil +} + +// Exit - +func (f *FakeLibruntimeSdkClient) Exit(code int, message string) { + return +} + +// Finalize - +func (f *FakeLibruntimeSdkClient) Finalize() { + return +} + +// KVSet - +func (f *FakeLibruntimeSdkClient) KVSet(key string, value []byte, param api.SetParam) error { + return nil +} + +// KVSetWithoutKey - +func (f *FakeLibruntimeSdkClient) KVSetWithoutKey(value []byte, param api.SetParam) (string, error) { + return "", nil +} + +// KVMSetTx - +func (f *FakeLibruntimeSdkClient) KVMSetTx(keys []string, values [][]byte, param api.MSetParam) error { + return nil +} + +// KVGet - +func (f *FakeLibruntimeSdkClient) KVGet(key string, timeoutms uint) ([]byte, error) { + return nil, nil +} + +// KVGetMulti - +func (f *FakeLibruntimeSdkClient) KVGetMulti(keys []string, timeoutms uint) ([][]byte, error) { + return nil, nil +} + +// KVDel - +func (f *FakeLibruntimeSdkClient) KVDel(key string) error { + return nil +} + +// KVDelMulti - +func (f *FakeLibruntimeSdkClient) KVDelMulti(keys []string) ([]string, error) { + return []string{}, nil +} + +// CreateProducer - +func (f *FakeLibruntimeSdkClient) CreateProducer(streamName string, + producerConf api.ProducerConf) (api.StreamProducer, error) { + return &FakeStreamProducer{}, nil +} + +// Subscribe - +func (f *FakeLibruntimeSdkClient) Subscribe(streamName string, + config api.SubscriptionConfig) (api.StreamConsumer, error) { + return &FakeStreamConsumer{}, nil +} + +// DeleteStream - +func (f *FakeLibruntimeSdkClient) DeleteStream(streamName string) error { + return nil +} + +// QueryGlobalProducersNum - +func (f *FakeLibruntimeSdkClient) QueryGlobalProducersNum(streamName string) (uint64, error) { + return 0, nil +} + +// QueryGlobalConsumersNum - +func (f *FakeLibruntimeSdkClient) QueryGlobalConsumersNum(streamName string) (uint64, error) { + return 0, nil +} + +// SetTraceID - +func (f *FakeLibruntimeSdkClient) SetTraceID(traceID string) { + return +} + +// SetTenantID - +func (f *FakeLibruntimeSdkClient) SetTenantID(tenantID string) error { + return nil +} + +// Put - +func (f *FakeLibruntimeSdkClient) Put(objectID string, value []byte, + param api.PutParam, nestedObjectIDs ...string) error { + return nil +} + +// PutRaw - +func (f *FakeLibruntimeSdkClient) PutRaw(objectID string, value []byte, + param api.PutParam, nestedObjectIDs ...string) error { + return nil +} + +// Get - +func (f *FakeLibruntimeSdkClient) Get(objectIDs []string, timeoutMs int) ([][]byte, error) { + return nil, nil +} + +// GetRaw - +func (f *FakeLibruntimeSdkClient) GetRaw(objectIDs []string, timeoutMs int) ([][]byte, error) { + return nil, nil +} + +// Wait - +func (f *FakeLibruntimeSdkClient) Wait(objectIDs []string, + waitNum uint64, timeoutMs int) ([]string, []string, map[string]error) { + return nil, nil, nil +} + +// GIncreaseRef - +func (f *FakeLibruntimeSdkClient) GIncreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + return nil, nil +} + +// GIncreaseRefRaw - +func (f *FakeLibruntimeSdkClient) GIncreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + return nil, nil +} + +// GDecreaseRef - +func (f *FakeLibruntimeSdkClient) GDecreaseRef(objectIDs []string, remoteClientID ...string) ([]string, error) { + return nil, nil +} + +// GDecreaseRefRaw - +func (f *FakeLibruntimeSdkClient) GDecreaseRefRaw(objectIDs []string, remoteClientID ...string) ([]string, error) { + return nil, nil +} + +// GetAsync - +func (f *FakeLibruntimeSdkClient) GetAsync(objectID string, cb api.GetAsyncCallback) { + return +} + +// GetFormatLogger - +func (f *FakeLibruntimeSdkClient) GetFormatLogger() api.FormatLogger { + return nil +} + +// CreateClient - +func (f *FakeLibruntimeSdkClient) CreateClient(config api.ConnectArguments) (api.KvClient, error) { + return nil, nil +} + +// ReleaseGRefs - +func (f *FakeLibruntimeSdkClient) ReleaseGRefs(remoteClientID string) error { + return nil +} + +// GetCredential - +func (f *FakeLibruntimeSdkClient) GetCredential() api.Credential { + return api.Credential{} +} + +// UpdateSchdulerInfo - +func (f *FakeLibruntimeSdkClient) UpdateSchdulerInfo(schedulerName string, schedulerId string, option string) { + return +} + +// IsHealth - +func (f *FakeLibruntimeSdkClient) IsHealth() bool { + return true +} + +// IsDsHealth - +func (f *FakeLibruntimeSdkClient) IsDsHealth() bool { + return true +} + +// FakeStreamProducer - +type FakeStreamProducer struct{} + +// Send - +func (fsp *FakeStreamProducer) Send(element api.Element) error { + return nil +} + +// SendWithTimeout - +func (fsp *FakeStreamProducer) SendWithTimeout(element api.Element, timeoutMs int64) error { + return nil +} + +// Flush - +func (fsp *FakeStreamProducer) Flush() error { + return nil +} + +// Close - +func (fsp *FakeStreamProducer) Close() error { + return nil +} + +// FakeStreamConsumer - +type FakeStreamConsumer struct{} + +// ReceiveExpectNum - +func (fsc *FakeStreamConsumer) ReceiveExpectNum(expectNum uint32, timeoutMs uint32) ([]api.Element, error) { + return nil, nil +} + +// Receive - +func (fsc *FakeStreamConsumer) Receive(timeoutMs uint32) ([]api.Element, error) { + return nil, nil +} + +// Ack - +func (fsc *FakeStreamConsumer) Ack(elementId uint64) error { + return nil +} + +// Close - +func (fsc *FakeStreamConsumer) Close() error { + return nil +} diff --git a/yuanrong/pkg/common/faas_common/utils/libruntimeapi_mock_test.go b/yuanrong/pkg/common/faas_common/utils/libruntimeapi_mock_test.go new file mode 100644 index 0000000..3f07776 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/libruntimeapi_mock_test.go @@ -0,0 +1,143 @@ +package utils + +import ( + "github.com/stretchr/testify/assert" + "testing" + "yuanrong.org/kernel/runtime/libruntime/api" +) + +func TestFakeLibruntimeSdkClient(t *testing.T) { + fakeLibruntimeSdkClient := FakeLibruntimeSdkClient{} + instanceID, err := fakeLibruntimeSdkClient.CreateInstance(api.FunctionMeta{}, []api.Arg{}, api.InvokeOptions{}) + assert.NotEqual(t, 0, len(instanceID)) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.InvokeByInstanceId(api.FunctionMeta{}, "", []api.Arg{}, api.InvokeOptions{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.InvokeByFunctionName(api.FunctionMeta{}, []api.Arg{}, api.InvokeOptions{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.AcquireInstance("", api.FunctionMeta{}, api.InvokeOptions{}) + assert.Equal(t, nil, err) + + err = fakeLibruntimeSdkClient.Kill("", 0, []byte{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.CreateInstanceRaw([]byte{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.InvokeByInstanceIdRaw([]byte{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.KillRaw([]byte{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.SaveState([]byte{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.LoadState("") + assert.Equal(t, nil, err) + + err = fakeLibruntimeSdkClient.KVSet("", []byte{}, api.SetParam{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.KVSetWithoutKey([]byte{}, api.SetParam{}) + assert.Equal(t, nil, err) + + err = fakeLibruntimeSdkClient.KVMSetTx([]string{}, [][]byte{}, api.MSetParam{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.KVGet("", 1) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.KVGetMulti([]string{}, 1) + assert.Equal(t, nil, err) + + err = fakeLibruntimeSdkClient.KVDel("") + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.KVDelMulti([]string{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.CreateProducer("", api.ProducerConf{}) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.Subscribe("", api.SubscriptionConfig{}) + assert.Equal(t, nil, err) + + err = fakeLibruntimeSdkClient.DeleteStream("") + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.QueryGlobalProducersNum("") + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.QueryGlobalConsumersNum("") + assert.Equal(t, nil, err) + + err = fakeLibruntimeSdkClient.SetTenantID("") + assert.Equal(t, nil, err) + + err = fakeLibruntimeSdkClient.Put("", []byte{}, api.PutParam{}, "") + assert.Equal(t, nil, err) + + err = fakeLibruntimeSdkClient.PutRaw("", []byte{}, api.PutParam{}, "") + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.Get([]string{}, 1) + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.GetRaw([]string{}, 1) + assert.Equal(t, nil, err) + + _, _, aa := fakeLibruntimeSdkClient.Wait([]string{}, 1, 1) + assert.Equal(t, map[string]error(map[string]error(nil)), aa) + + _, err = fakeLibruntimeSdkClient.GIncreaseRef([]string{}, "") + assert.Equal(t, nil, err) + + _, err = fakeLibruntimeSdkClient.GDecreaseRefRaw([]string{}, "") + assert.Equal(t, nil, err) + + bb := fakeLibruntimeSdkClient.GetFormatLogger() + assert.Equal(t, nil, bb) + + _, err = fakeLibruntimeSdkClient.CreateClient(api.ConnectArguments{}) + assert.Equal(t, nil, err) + + err = fakeLibruntimeSdkClient.ReleaseGRefs("") + assert.Equal(t, nil, err) + + credential := fakeLibruntimeSdkClient.GetCredential() + assert.NotEqual(t, nil, credential) +} + +func TestFakeStreamProducer(t *testing.T) { + fakeStreamProducer := FakeStreamProducer{} + err := fakeStreamProducer.Send(api.Element{}) + assert.Equal(t, nil, err) + + err = fakeStreamProducer.SendWithTimeout(api.Element{}, 1) + assert.Equal(t, nil, err) + + err = fakeStreamProducer.Flush() + assert.Equal(t, nil, err) + + err = fakeStreamProducer.Close() + assert.Equal(t, nil, err) +} + +func TestFakeStreamConsumer(t *testing.T) { + fakeStreamConsumer := FakeStreamConsumer{} + _, err := fakeStreamConsumer.ReceiveExpectNum(1, 1) + assert.Equal(t, nil, err) + + _, err = fakeStreamConsumer.Receive(1) + assert.Equal(t, nil, err) + + err = fakeStreamConsumer.Ack(1) + assert.Equal(t, nil, err) + + err = fakeStreamConsumer.Close() + assert.Equal(t, nil, err) +} diff --git a/yuanrong/pkg/common/faas_common/utils/memory_test.go b/yuanrong/pkg/common/faas_common/utils/memory_test.go new file mode 100644 index 0000000..611b18f --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/memory_test.go @@ -0,0 +1,22 @@ +package utils + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestClearStringMemory(t *testing.T) { + Convey("Given a string", t, func() { + testStr := "helloworld" + + b := []byte(testStr) + s := string(b) + Convey("When we clear the string", func() { + ClearStringMemory(s) + Convey("The string should be empty", func() { + So(s, ShouldEqual, string([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0})) + }) + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/utils/mock_utils.go b/yuanrong/pkg/common/faas_common/utils/mock_utils.go new file mode 100644 index 0000000..ff3714b --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/mock_utils.go @@ -0,0 +1,138 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "context" + "errors" + + "github.com/agiledragon/gomonkey/v2" + "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "yuanrong.org/kernel/runtime/libruntime/api" +) + +// PatchSlice - +type PatchSlice []*gomonkey.Patches + +// PatchesFunc - +type PatchesFunc func() PatchSlice + +// InitPatchSlice - +func InitPatchSlice() PatchSlice { + return make([]*gomonkey.Patches, 0) +} + +// Append - +func (p *PatchSlice) Append(patches PatchSlice) { + if len(patches) > 0 { + *p = append(*p, patches...) + } +} + +// ResetAll - +func (p PatchSlice) ResetAll() { + for _, item := range p { + item.Reset() + } +} + +// FakeLogger - +type FakeLogger struct{} + +// With - +func (f *FakeLogger) With(fields ...zapcore.Field) api.FormatLogger { + return f +} + +// Infof - +func (f *FakeLogger) Infof(format string, paras ...interface{}) {} + +// Errorf - +func (f *FakeLogger) Errorf(format string, paras ...interface{}) {} + +// Warnf - +func (f *FakeLogger) Warnf(format string, paras ...interface{}) {} + +// Debugf - +func (f *FakeLogger) Debugf(format string, paras ...interface{}) {} + +// Fatalf - +func (f *FakeLogger) Fatalf(format string, paras ...interface{}) {} + +// Info - +func (f *FakeLogger) Info(msg string, fields ...zap.Field) {} + +// Error - +func (f *FakeLogger) Error(msg string, fields ...zap.Field) {} + +// Warn - +func (f *FakeLogger) Warn(msg string, fields ...zap.Field) {} + +// Debug - +func (f *FakeLogger) Debug(msg string, fields ...zap.Field) {} + +// Fatal - +func (f *FakeLogger) Fatal(msg string, fields ...zap.Field) {} + +// Sync - +func (f *FakeLogger) Sync() {} + +// FakeEtcdLease - +type FakeEtcdLease struct { +} + +// Grant - +func (m FakeEtcdLease) Grant(_ context.Context, _ int64) (*clientv3.LeaseGrantResponse, error) { + return &clientv3.LeaseGrantResponse{ID: 1}, nil +} + +// Revoke - +func (m FakeEtcdLease) Revoke(_ context.Context, _ clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) { + return nil, nil +} + +// TimeToLive - +func (m FakeEtcdLease) TimeToLive(_ context.Context, _ clientv3.LeaseID, + _ ...clientv3.LeaseOption) (*clientv3.LeaseTimeToLiveResponse, error) { + return nil, nil +} + +// Leases - +func (m FakeEtcdLease) Leases(_ context.Context) (*clientv3.LeaseLeasesResponse, error) { + return nil, nil +} + +// KeepAlive - +func (m FakeEtcdLease) KeepAlive(_ context.Context, _ clientv3.LeaseID) ( + <-chan *clientv3.LeaseKeepAliveResponse, error) { + return nil, nil +} + +// KeepAliveOnce - +func (m FakeEtcdLease) KeepAliveOnce(_ context.Context, _ clientv3.LeaseID) ( + *clientv3.LeaseKeepAliveResponse, error) { + return nil, nil +} + +// Close - +func (m FakeEtcdLease) Close() error { + return errors.New("close error") +} diff --git a/yuanrong/pkg/common/faas_common/utils/resourcepath.go b/yuanrong/pkg/common/faas_common/utils/resourcepath.go new file mode 100644 index 0000000..17e3f46 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/resourcepath.go @@ -0,0 +1,59 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils for common functions +package utils + +import ( + "fmt" + "os" + "os/exec" + "path" + "path/filepath" + "strings" +) + +// GetResourcePath Get Resource Path +func GetResourcePath() string { + return getPath("ResourcePath", "resource") +} + +// GetServicesPath Get Services Path +func GetServicesPath() string { + return getPath("ServicesPath", "service-config") +} + +func getPath(env, defaultPath string) string { + envPath := os.Getenv(env) + if envPath == "" { + var err error + cliPath, err := exec.LookPath(os.Args[0]) + if err != nil { + return envPath + } + envPath, err = filepath.Abs(filepath.Dir(cliPath)) + // do not return this error + if err != nil { + fmt.Printf("GetResourcePath abs filepath dir error") + } + envPath = strings.Replace(envPath, "\\", "/", -1) + envPath = path.Join(path.Dir(envPath), defaultPath) + } else { + envPath = strings.Replace(envPath, "\\", "/", -1) + } + + return envPath +} diff --git a/yuanrong/pkg/common/faas_common/utils/scheduler_option.go b/yuanrong/pkg/common/faas_common/utils/scheduler_option.go new file mode 100644 index 0000000..ba6c4c9 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/scheduler_option.go @@ -0,0 +1,108 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "fmt" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/types" +) + +const ( + schedulePolicyKey = "schedule_policy" + scheduleCPU = "CPU" + scheduleMemory = "Memory" +) + +const ( + // NodeSelectorKey - + NodeSelectorKey = "node_selector" + // MonopolyPolicyValue - + MonopolyPolicyValue = "monopoly" + // SharedPolicyValue - + SharedPolicyValue = "shared" +) + +// CreateCustomExtensions create customExtensions +func CreateCustomExtensions(customExtensions map[string]string, schedulePolicy string) map[string]string { + if customExtensions == nil { + customExtensions = make(map[string]string, 1) + } + customExtensions[schedulePolicyKey] = schedulePolicy + return customExtensions +} + +// CreatePodAffinity - create pod affinity +func CreatePodAffinity(key, label string, affinityType api.AffinityType) []api.Affinity { + var ( + operators []api.LabelOperator + affinity []api.Affinity + ) + if label != "" { + operators = append(operators, api.LabelOperator{ + Type: api.LabelOpIn, + LabelKey: key, + LabelValues: []string{label}, + }) + } else { + operators = append(operators, api.LabelOperator{ + Type: api.LabelOpExists, + LabelKey: key, + LabelValues: []string{}, + }) + } + affinity = append(affinity, api.Affinity{ + Kind: api.AffinityKindInstance, + Affinity: affinityType, + PreferredPriority: false, + PreferredAntiOtherLabels: false, + LabelOps: operators, + }) + return affinity +} + +// CreateCreateOptions create CreateOptions +func CreateCreateOptions(createOptions map[string]string, key, value string) map[string]string { + if createOptions == nil { + return make(map[string]string) + } + createOptions[key] = value + return createOptions +} + +// GenerateResourcesMap - +func GenerateResourcesMap(cpu, memory float64) map[string]float64 { + resourcesMap := make(map[string]float64) + resourcesMap[scheduleCPU] = cpu + resourcesMap[scheduleMemory] = memory + return resourcesMap +} + +// AddNodeSelector - +func AddNodeSelector(nodeSelectorMap map[string]string, extraParams *types.ExtraParams) { + if extraParams.CustomExtensions == nil { + extraParams.CustomExtensions = make(map[string]string, 1) + } + if nodeSelectorMap != nil && len(nodeSelectorMap) != 0 { + for k, v := range nodeSelectorMap { + extraParams.CustomExtensions[NodeSelectorKey] = fmt.Sprintf(`{"%s": "%s"}`, k, v) + } + } +} diff --git a/yuanrong/pkg/common/faas_common/utils/scheduler_option_test.go b/yuanrong/pkg/common/faas_common/utils/scheduler_option_test.go new file mode 100644 index 0000000..d34011f --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/scheduler_option_test.go @@ -0,0 +1,97 @@ +package utils + +import ( + "reflect" + "testing" + + . "github.com/smartystreets/goconvey/convey" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/types" +) + +func TestCreateCustomExtensions(t *testing.T) { + Convey("Test CreateCustomExtensions", t, func() { + got := CreateCustomExtensions(nil, MonopolyPolicyValue) + So(got, ShouldNotBeNil) + }) +} + +func TestCreateCreateOptions(t *testing.T) { + Convey("Test CreateSchedulingOptions", t, func() { + expectValue := "test" + createOptions := make(map[string]string, 20) + got := CreateCreateOptions(createOptions, "test", "test") + So(got["test"], ShouldEqual, expectValue) + }) +} + +func TestGenerateResourcesMap(t *testing.T) { + Convey("Test GenerateResourcesMap", t, func() { + res := GenerateResourcesMap(300, 128) + So(res, ShouldResemble, map[string]float64{ + scheduleCPU: 300, + scheduleMemory: 128, + }) + }) +} + +func TestCreatePodAffinity(t *testing.T) { + type args struct { + key string + label string + affinityType api.AffinityType + } + tests := []struct { + name string + args args + want []api.Affinity + }{ + {"case1", args{ + key: "faasfrontend", + label: "faasfrontend", + affinityType: api.PreferredAntiAffinity, + }, []api.Affinity{ + api.Affinity{ + Kind: api.AffinityKindInstance, + Affinity: api.PreferredAntiAffinity, + PreferredPriority: false, + PreferredAntiOtherLabels: false, + LabelOps: []api.LabelOperator{{ + Type: api.LabelOpIn, + LabelKey: "faasfrontend", + LabelValues: []string{"faasfrontend"}, + }, + }, + }}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CreatePodAffinity(tt.args.key, tt.args.label, tt.args.affinityType); !reflect.DeepEqual(got, tt.want) { + t.Errorf("CreatePodAffinity() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAddNodeSelector(t *testing.T) { + type args struct { + nodeSelectorMap map[string]string + extraParams *types.ExtraParams + } + tests := []struct { + name string + args args + }{ + {"case1", args{ + nodeSelectorMap: map[string]string{"k": "v"}, + extraParams: &types.ExtraParams{CustomExtensions: make(map[string]string)}, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + AddNodeSelector(tt.args.nodeSelectorMap, tt.args.extraParams) + }) + } +} diff --git a/yuanrong/pkg/common/faas_common/utils/tools.go b/yuanrong/pkg/common/faas_common/utils/tools.go new file mode 100644 index 0000000..face6d7 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/tools.go @@ -0,0 +1,656 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils for common functions +package utils + +import ( + "bufio" + "crypto/md5" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "fmt" + "hash/fnv" + "io" + "io/ioutil" + "math" + "math/rand" + "net" + "os" + "path" + "path/filepath" + "reflect" + "strings" + "sync" + "syscall" + "time" + "unsafe" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/uuid" +) + +const ( + // OriginDefaultTimeout is 900 + OriginDefaultTimeout = 900 + // maxTimeout is 100 days + maxTimeout = 100 * 24 * 3600 + bytesToMb = 1024 * 1024 + uint64ArrayLength = 8 + uint32Len = 4 + // DirMode dir mode + DirMode = 0700 + // FileMode file mode + FileMode = 0600 + readSize = 32 * 1024 + // ObsMaxRetry obs max retry times 0 + ObsMaxRetry = 0 + // ObsDefaultTimeout 30 seconds + ObsDefaultTimeout = 30 + // ObsDefaultConnectTimeout 10 seconds + ObsDefaultConnectTimeout = 5 + // LayerListSep define the LayerList separation character + LayerListSep = "-#-" + instanceIDLength = 2 + dnsPairLength = 2 + hostFilePath = "/etc/hosts" + defaultMessageLen = 256 +) + +const ( + minimumMemoryUnit = 128 + minimumCPUUnit = 100 + minimumReservedCPUUnit = 200 +) + +const ( + tenantValueIndex = 6 + funcNameValueIndex = 8 + versionValueIndex = 10 + instanceIDValueIndex = 13 + functionSchedulerKeyLen = 14 + moduleSchedulerKeyLen = 7 + functionNameIndex = 6 + defaultVersion = "latest" + defaultTenant = "0" + defaultFunctionName = "faas-scheduler" +) + +type hostFileInfo struct { + Sha256 string + Content []byte + Mutex sync.Mutex +} + +// HostFile /etc/hosts file info +var HostFile hostFileInfo + +// SetClusterNameEnv - +func SetClusterNameEnv(clusterName string) error { + if err := os.Setenv(constant.ClusterNameEnvKey, clusterName); err != nil { + return fmt.Errorf("failed to set env of %s, err: %s", constant.ClusterNameEnvKey, err.Error()) + } + return nil +} + +// CalculateCPUByMemory CPU and memory calculation methods presented by fg: cpu=memory/128*100+200 +func CalculateCPUByMemory(memory int) int { + return memory/minimumMemoryUnit*minimumCPUUnit + minimumReservedCPUUnit +} + +var azEnv = parseAzEnv() + +func parseAzEnv() string { + az := os.Getenv(constant.ZoneKey) + if az == "" { + az = constant.DefaultAZ + } + if len(az) > constant.ZoneNameLen { + az = az[0 : constant.ZoneNameLen-1] + } + return az +} + +// AzEnv set defaultaz env +func AzEnv() string { + return azEnv +} + +// GenerateInstanceID - +func GenerateInstanceID(podName string) string { + return AzEnv() + "-#-" + podName +} + +// GetPodNameByInstanceID - +func GetPodNameByInstanceID(instanceID string) string { + elements := strings.Split(instanceID, LayerListSep) + if len(elements) < instanceIDLength { + return "" + } + return elements[1] +} + +// Domain2IP convert domain to ip +func Domain2IP(endpoint string) (string, error) { + var host, port string + var err error + host = endpoint + if strings.Contains(endpoint, ":") { + host, port, err = net.SplitHostPort(endpoint) + if err != nil { + return "", err + } + } + if net.ParseIP(host) != nil { + return endpoint, nil + } + ips, err := net.LookupHost(host) + if err != nil { + return "", err + } + if port == "" { + return ips[0], nil + } + return net.JoinHostPort(ips[0], port), nil +} + +// DeepCopy will generate a new copy of original collection type +// currently this function is not recursive so elements will not be deep copied +func DeepCopy(origin interface{}) interface{} { + oriTyp := reflect.TypeOf(origin) + oriVal := reflect.ValueOf(origin) + switch oriTyp.Kind() { + case reflect.Slice: + elemType := oriTyp.Elem() + length := oriVal.Len() + capacity := oriVal.Cap() + newObj := reflect.MakeSlice(reflect.SliceOf(elemType), length, capacity) + reflect.Copy(newObj, oriVal) + return newObj.Interface() + case reflect.Map: + newObj := reflect.MakeMapWithSize(oriTyp, len(oriVal.MapKeys())) + for _, key := range oriVal.MapKeys() { + value := oriVal.MapIndex(key) + newObj.SetMapIndex(key, value) + } + return newObj.Interface() + default: + return nil + } +} + +// ValidateTimeout check timeout +func ValidateTimeout(timeout *int64, defaultTimeout int64) { + if *timeout <= 0 { + *timeout = defaultTimeout + return + } + if *timeout > maxTimeout { + *timeout = maxTimeout + } +} + +// ClearStringMemory - +func ClearStringMemory(s string) { + if len(s) == 0 { + return + } + bs := *(*[]byte)(unsafe.Pointer(&s)) + ClearByteMemory(bs) +} + +// ClearByteMemory - +func ClearByteMemory(b []byte) { + for i := 0; i < len(b); i++ { + b[i] = 0 + } +} + +// Float64ToByte - +func Float64ToByte(float float64) []byte { + bits := math.Float64bits(float) + bytes := make([]byte, 8) + binary.LittleEndian.PutUint64(bytes, bits) + return bytes +} + +// ByteToFloat64 - +func ByteToFloat64(bytes []byte) float64 { + // bounds check to guarantee safety of function Uint64 + if len(bytes) != uint64ArrayLength { + return 0 + } + bits := binary.LittleEndian.Uint64(bytes) + return math.Float64frombits(bits) +} + +// ExistPath whether path exists +func ExistPath(path string) bool { + _, err := os.Stat(path) + if err != nil && os.IsNotExist(err) { + return false + } + return true +} + +// IsInputParameterValid check if input parameter is valid +func IsInputParameterValid(cmdName string) bool { + if strings.Contains(cmdName, "&") || + strings.Contains(cmdName, "|") || + strings.Contains(cmdName, ";") || + strings.Contains(cmdName, "$") || + strings.Contains(cmdName, "'") || + strings.Contains(cmdName, "`") || + strings.Contains(cmdName, "(") || + strings.Contains(cmdName, ")") || + strings.Contains(cmdName, "\"") { + return false + } + return true +} + +// UniqueID get unique ID +func UniqueID() string { + return uuid.New().String() +} + +// ShortUUID return short uuid encode by base64 +func ShortUUID() string { + id := uuid.New() + buf := make([]byte, base64.StdEncoding.EncodedLen(len(id))) + base64.StdEncoding.Encode(buf, id[:]) + for i := range buf { + if buf[i] == '=' || buf[i] == '+' || buf[i] == '/' { + buf[i] = '-' + } + } + return strings.ToLower(strings.Trim(string(buf), "-")) +} + +// WriteFileToPath write file to path +func WriteFileToPath(writePath string, buffer []byte) error { + baseDir := path.Dir(writePath) + err := os.MkdirAll(baseDir, DirMode) + if err != nil { + return err + } + if err = ioutil.WriteFile(writePath, buffer, FileMode); err != nil { + return err + } + return nil +} + +// IsConnRefusedErr - +func IsConnRefusedErr(err error) bool { + netErr, ok := err.(net.Error) + if !ok { + return false + } + opErr, ok := netErr.(*net.OpError) + if !ok { + return false + } + syscallErr, ok := opErr.Err.(*os.SyscallError) + if !ok { + return false + } + if errno, ok := syscallErr.Err.(syscall.Errno); ok { + if errno == syscall.ECONNREFUSED { + return true + } + } + return false +} + +// ContainsConnRefusedErr - +func ContainsConnRefusedErr(err error) bool { + const connRefusedStr = "connection refused" + return strings.Contains(err.Error(), connRefusedStr) +} + +// DefaultStringEnv return environment variable named by key and return val when not exist +func DefaultStringEnv(key string, val string) string { + if env := os.Getenv(key); env != "" { + return env + } + return val +} + +// ReplaceByDNS update /etc/hosts +func ReplaceByDNS(filePath string, domainNames map[string]string) error { + lines, err := ReadLines(filePath) + if err != nil { + return err + } + checkedDNSNames := make(map[string]bool, len(domainNames)) + var hasChange bool + for i := range lines { + arr := strings.Fields(lines[i]) + if len(arr) != dnsPairLength { + continue + } + for name, ipAddress := range domainNames { + if arr[0] == name || arr[1] == name { + originLine := lines[i] + lines[i] = ipAddress + " " + name + checkedDNSNames[name] = true + if lines[i] != originLine { + hasChange = true + } + break + } + } + } + for name, ipAddress := range domainNames { + // domain name is not in hosts file will append to hosts file + if !checkedDNSNames[name] { + lines = append(lines, ipAddress+" "+name) + hasChange = true + } + } + if !hasChange { + return nil + } + HostFile.Mutex.Lock() + defer HostFile.Mutex.Unlock() + if err := WriteLines(filePath, lines); err != nil { + return err + } + if err := HostFile.SaveHostFileInfo(); err != nil { + return err + } + return nil +} + +func (hostFileInfo) SaveHostFileInfo() error { + _, sha, err := GetFileHashInfo(hostFilePath) + if err != nil { + return err + } + HostFile.Sha256 = sha + content, err := ioutil.ReadFile(hostFilePath) + if err != nil { + return err + } + HostFile.Content = content + return nil +} + +// ReadLines read the lines of the given file. +func ReadLines(path string) ([]string, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + var lines []string + scanner := bufio.NewScanner(file) + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + return lines, scanner.Err() +} + +// WriteLines writes the lines to the given file. +func WriteLines(path string, lines []string) error { + file, err := os.Create(path) + if err != nil { + return err + } + defer file.Close() + + w := bufio.NewWriter(file) + for _, line := range lines { + fmt.Fprintln(w, line) + } + return w.Flush() +} + +// GenStateIDByKey returns stateID by serviceID, functionName and key +func GenStateIDByKey(tenantID, serviceID, funcName, key string) string { + // if stateKey is empty, stateID is generated by default. + if len(key) == 0 { + return uuid.New().String() + } + preAllocationSlice := make([]byte, 0, len(tenantID)+len(serviceID)+len(funcName)+len(key)) + preAllocationSlice = append(preAllocationSlice, tenantID...) + preAllocationSlice = append(preAllocationSlice, serviceID...) + preAllocationSlice = append(preAllocationSlice, funcName...) + preAllocationSlice = append(preAllocationSlice, key...) + stateID := uuid.NewSHA1(uuid.NameSpaceURL, preAllocationSlice) + return stateID.String() +} + +// GetFileHashInfo get file hash info +func GetFileHashInfo(path string) (int64, string, error) { + var fileSize int64 + realPath, err := filepath.Abs(path) + if err != nil { + return 0, "", err + } + file, err := os.Open(realPath) + if err != nil { + return 0, "", err + } + defer file.Close() + stat, err := file.Stat() + if err != nil { + return 0, "", err + } + fileSize = stat.Size() + fileHash := sha256.New() + if _, err := io.Copy(fileHash, file); err != nil { + return 0, "", err + } + hashValue := hex.EncodeToString(fileHash.Sum(nil)) + return fileSize, hashValue, nil +} + +// IsNetworkError judge whether it is a network error +func IsNetworkError(err error) bool { + if err == nil { + return false + } + _, ok := err.(net.Error) + if !ok { + return false + } + return true +} + +// IsUserError - +func IsUserError(err error) bool { + newErr, ok := err.(snerror.SNError) + if !ok { + return false + } + return snerror.IsUserError(newErr) +} + +// FnvHashInt a hash function +func FnvHashInt(s string) int { + h := fnv.New32a() + _, err := h.Write([]byte(s)) + if err != nil { + return 0 + } + + // for 2 <= base <= 36. The result uses the lower-case letters 'a' to 'z' + return int(h.Sum32()) +} + +// FileMD5 calculate the md5 of file +func FileMD5(filePath string) (string, error) { + file, err := os.Open(filePath) + defer file.Close() + if err != nil { + return "", err + } + hash := md5.New() + _, err = io.Copy(hash, file) + if err != nil { + return "", err + } + return hex.EncodeToString(hash.Sum(nil)), nil +} + +// ShuffleOneArray - +func ShuffleOneArray(arr []string) []string { + arrLength := len(arr) + if arrLength <= 1 { + return arr + } + copyArr := make([]string, arrLength) + copy(copyArr, arr) + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(arrLength, func(i, j int) { copyArr[i], copyArr[j] = copyArr[j], copyArr[i] }) + return copyArr +} + +// IsCAEFunc judge whether it is a CAE function +func IsCAEFunc(businessType string) bool { + return businessType == constant.BusinessTypeCAE +} + +// IsWebSocketFunc return true if the business type is websocket or cae with enable remote debug +func IsWebSocketFunc(businessType string, enableRemoteDebug bool) bool { + return businessType == constant.BusinessTypeWebSocket || + (businessType == constant.BusinessTypeCAE && enableRemoteDebug) +} + +var directFunctions = map[string]struct{}{ + "javax": {}, +} + +// IsDirectFunc check whether it if a direct function (runtime connect to bus directly) +func IsDirectFunc(language string) bool { + _, ok := directFunctions[language] + return ok +} + +// IsStringInArray - +func IsStringInArray(str string, arr []string) bool { + for _, s := range arr { + if s == str { + return true + } + } + return false +} + +// GetFunctionInstanceInfoFromEtcdKey parses the instance info from the etcd path +// e.g. /sn/instance/business/yrk/tenant/0/function/xxx/version/lastest/defaultaz/ +// job-9e54951c-task-77156757-fb16-4b4a-ad61-6646c7d1c57c-d4ad6c74-0/3f079541-15fc-4009-8c41-50b2b2936772 +func GetFunctionInstanceInfoFromEtcdKey(path string) (*types.InstanceInfo, error) { + elements := strings.Split(path, "/") + if len(elements) != functionSchedulerKeyLen { + return nil, fmt.Errorf("unexpected etcd path format: %s", path) + } + return &types.InstanceInfo{ + TenantID: elements[tenantValueIndex], + FunctionName: elements[funcNameValueIndex], + Version: elements[versionValueIndex], + InstanceName: elements[instanceIDValueIndex], + InstanceID: elements[instanceIDValueIndex], + }, nil +} + +// GetModuleSchedulerInfoFromEtcdKey /sn/faas-scheduler/instances/cluster001/7.xx.xx.25/faas-scheduler-xxxx-8xdjf +func GetModuleSchedulerInfoFromEtcdKey(path string) (*types.InstanceInfo, error) { + elements := strings.Split(path, "/") + if len(elements) != moduleSchedulerKeyLen { + return nil, fmt.Errorf("unexpected etcd path format: %s", path) + } + return &types.InstanceInfo{ + TenantID: defaultTenant, + FunctionName: defaultFunctionName, + Version: defaultVersion, + InstanceName: elements[functionNameIndex], + }, nil +} + +// CheckFaaSSchedulerInstanceFault - +func CheckFaaSSchedulerInstanceFault(status types.InstanceStatus) bool { + faultInstanceStatusMap := map[constant.InstanceStatus]struct{}{ + constant.KernelInstanceStatusFatal: {}, + constant.KernelInstanceStatusScheduleFailed: {}, + constant.KernelInstanceStatusEvicting: {}, + constant.KernelInstanceStatusEvicted: {}, + constant.KernelInstanceStatusExiting: {}, + constant.KernelInstanceStatusExited: {}, + } + + _, ok := faultInstanceStatusMap[constant.InstanceStatus(status.Code)] + return ok +} + +// IsNil checks if an object (could be an interface) is nil +func IsNil(i interface{}) bool { + return i == nil || (reflect.ValueOf(i).Kind() == reflect.Ptr && reflect.ValueOf(i).IsNil()) +} + +// CalcFileMD5 calculates file MD5 +func CalcFileMD5(filepath string) string { + file, err := os.Open(filepath) + if err != nil { + return "" + } + defer file.Close() + hash := md5.New() + _, err = io.Copy(hash, file) + if err != nil { + return "" + } + return hex.EncodeToString(hash.Sum(nil)) +} + +// ReceiveWithinTimeout first element is the chan, second element is the timeout +func ReceiveWithinTimeout[T any](ch <-chan T, timeout time.Duration) (T, bool) { + var val T + select { + case val, ok := <-ch: + return val, ok + case <-time.After(timeout): + return val, false + } +} + +func MessageTruncation(message string) string { + if len(message) > defaultMessageLen { + return message[:defaultMessageLen] + } + return message +} + +// SafeCloseChannel will close channel in a safe way +func SafeCloseChannel(stopCh chan struct{}) { + if stopCh == nil { + return + } + select { + case _, ok := <-stopCh: + if ok { + close(stopCh) + } + default: + close(stopCh) + } +} diff --git a/yuanrong/pkg/common/faas_common/utils/tools_test.go b/yuanrong/pkg/common/faas_common/utils/tools_test.go new file mode 100644 index 0000000..5824a79 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/utils/tools_test.go @@ -0,0 +1,649 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "errors" + "fmt" + "io/ioutil" + "net" + "os" + "strings" + "syscall" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/types" +) + +// TestDomain2IP convert domain to ip +func TestDomain2IP(t *testing.T) { + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(net.LookupHost, func(_ string) ([]string, error) { + return []string{"1.1.1.1"}, nil + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + + type args struct { + endpoint string + } + tests := []struct { + args args + want string + wantErr bool + }{ + { + args{endpoint: "1.1.1.1:9000"}, + "1.1.1.1:9000", + false, + }, + { + args{endpoint: "1.1.1.1"}, + "1.1.1.1", + false, + }, + { + args{endpoint: "test:9000"}, + "1.1.1.1:9000", + false, + }, + { + args{endpoint: "test"}, + "1.1.1.1", + false, + }, + } + for _, tt := range tests { + got, err := Domain2IP(tt.args.endpoint) + if (err != nil) != tt.wantErr { + t.Errorf("Domain2IP() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Domain2IP() got = %v, want %v", got, tt.want) + } + } +} + +func TestGenStateIDByKey(t *testing.T) { + convey.Convey("Test gen stateID by UUID", t, func() { + stateID := GenStateIDByKey("tenantID", "serviceID", "funcName", "") + convey.So(stateID, convey.ShouldNotBeNil) + }) + convey.Convey("Test gen stateID by params", t, func() { + stateID := GenStateIDByKey("tenantID", "serviceID", "funcName", "key") + convey.So(stateID, convey.ShouldEqual, "993e96b4-0550-523f-a412-a4b58682cb2e") + }) +} + +func TestGetFileHashInfo(t *testing.T) { + convey.Convey("Test get file hashInfo failed", t, func() { + _, _, err := GetFileHashInfo("/xyz") + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestFloat64ToByte(t *testing.T) { + value := 123.45 + bytesValue := Float64ToByte(value) + if ByteToFloat64(bytesValue) != value { + t.Errorf("Float64ToByte and ByteToFloat64 failed") + } +} + +func TestExistPath(t *testing.T) { + path := os.Args[0] + if !ExistPath(path) { + t.Errorf("test path exist true failed, path: %s", path) + } + if ExistPath(path + "abc") { + t.Errorf("test path exist false failed, path: %s", path+"abc") + } +} + +func TestUniqueID(t *testing.T) { + uuid1 := UniqueID() + uuid2 := UniqueID() + assert.NotEqual(t, uuid1, uuid2) +} + +func Test_parseAzEnv(t *testing.T) { + assert.Equal(t, constant.DefaultAZ, AzEnv()) + tests := []struct { + name string + zoneValue string + want string + }{ + { + name: "empty zoneValue", + zoneValue: "", + want: constant.DefaultAZ, + }, + { + name: fmt.Sprintf("ZoneName > %d", constant.ZoneNameLen), + zoneValue: "12345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890" + + "123456", + want: "12345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890" + + "1234", + }, + { + name: "Normal", + zoneValue: "1234567890", + want: "1234567890", + }, + } + for _, tt := range tests { + if err := os.Setenv(constant.ZoneKey, tt.zoneValue); err != nil { + t.Errorf("failed to set Zone env, %s", err) + } + actual := parseAzEnv() + assert.Equal(t, tt.want, actual) + } +} + +func TestIsConnRefusedErr(t *testing.T) { + err := errors.New("abc") + assert.False(t, IsConnRefusedErr(err)) + err = syscall.EADDRINUSE + assert.False(t, IsConnRefusedErr(err)) + _, err = net.Dial("tcp", "127.0.0.1:33334") + assert.True(t, IsConnRefusedErr(err)) +} + +func TestContainsConnRefusedErr(t *testing.T) { + err := errors.New("dial tcp 10.249.0.54:22668: connect: connection refused") + assert.True(t, ContainsConnRefusedErr(err)) +} + +// TestWriteFileToPath is used to test the function of writing a file to a specified path. +func TestWriteFileToPath(t *testing.T) { + dir, err := ioutil.TempDir("", "test") + assert.Equal(t, err, nil) + addFile, err := ioutil.TempFile(dir, "test") + err = WriteFileToPath(addFile.Name(), []byte("test")) + assert.Equal(t, err, nil) +} + +// TestIsHexString is used to test whether the character string meets the requirements. +func TestIsHexString(t *testing.T) { + flag := IsHexString("2345") + assert.True(t, flag) + flag = IsHexString("test") + assert.False(t, flag) +} + +// TestValidateTimeout: indicates whether the timeout interval exceeds the maximum value or is the default value. +func TestValidateTimeout(t *testing.T) { + var timeout int64 = -1 + var defaultTimeout int64 = 1 + ValidateTimeout(&timeout, defaultTimeout) + assert.Equal(t, timeout, int64(1)) + timeout = 100*24*3600 + 1 + ValidateTimeout(&timeout, defaultTimeout) + assert.Equal(t, timeout, int64(100*24*3600)) +} + +// TestDeepCopy is used to test the deep copy of maps and slices. +func TestDeepCopy(t *testing.T) { + str := []string{"test1", "test2"} + cpyStr := DeepCopy(str) + curStr, ok := cpyStr.([]string) + assert.True(t, ok) + assert.Equal(t, len(curStr), 2) + assert.Equal(t, curStr[0], "test1") + assert.Equal(t, curStr[1], "test2") + + tmpMap := make(map[string]string) + tmpMap["test1"] = "test1" + tmpMap["test2"] = "test2" + cpyMap := DeepCopy(tmpMap) + curMap, ok := cpyMap.(map[string]string) + assert.True(t, ok) + assert.Equal(t, len(curMap), 2) + assert.Equal(t, curMap["test1"], "test1") + assert.Equal(t, curMap["test2"], "test2") +} + +func TestIsInputParameterValid(t *testing.T) { + res1 := IsInputParameterValid("|") + assert.Equal(t, res1, false) + res2 := IsInputParameterValid("ddd") + assert.Equal(t, res2, true) + res3 := IsInputParameterValid("ab(d)e") + assert.Equal(t, res3, false) + res4 := IsInputParameterValid("abde;") + assert.Equal(t, res4, false) + res5 := IsInputParameterValid("&abde") + assert.Equal(t, res5, false) +} + +func TestDefaultString(t *testing.T) { + convey.Convey("TestDefaultString", t, func() { + convey.So(DefaultStringEnv("abc", "def"), convey.ShouldEqual, "def") + }) +} + +func Test_replaceByDNS(t *testing.T) { + convey.Convey("Test_replaceByDNSError", t, func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(ReadLines, func(path string) ([]string, error) { + return nil, errors.New("mock error") + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + err := ReplaceByDNS("", nil) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("Test_replaceByDNS", t, func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(ReadLines, func(path string) ([]string, error) { + return []string{"192.168.1.1 www.example.com"}, nil + }), + gomonkey.ApplyFunc(WriteLines, func(path string, lines []string) error { + return nil + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + err := ReplaceByDNS("", map[string]string{"www.example.com": "192.168.1.2"}) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("Test_replaceByDNSFileReWrite", t, func() { + testLine := []string{"192.168.1.1 www.example.com"} + err1 := WriteLines("/tmp/dnsTestFile", testLine) + convey.So(err1, convey.ShouldBeNil) + + // 能够将第一次的文件内容覆盖 + testLine = []string{"192.168.1.1 www.example.com", + "192.168.1.2 www.example2.com", + "192.168.1.3 www.example3.com", + "192.168.1.4 www.example4.com", + "192.168.1.5 www.example5.com", + "192.168.1.6 www.example6.com", + "192.168.1.7 www.example7.com", + "192.168.1.8 www.example8.com", + "192.168.1.9 www.example9.com", + "192.168.1.10 www.example10.com", + "192.168.1.11 www.example11.com", + "192.168.1.12 www.example12.com", + "192.168.1.13 www.example13.com test-array-len-3-exception"} + err1 = WriteLines("/tmp/dnsTestFile", testLine) + lineContext, err1 := ReadLines("/tmp/dnsTestFile") + convey.So(err1, convey.ShouldBeNil) + convey.So(len(lineContext), convey.ShouldEqual, 13) + var rst = 1 + for i := range lineContext { + if strings.Contains(lineContext[i], "www.example.com") { + if strings.Contains(lineContext[i], "192.168.1.1") { + rst = 0 + } + } + } + convey.So(rst, convey.ShouldEqual, 0) + + // 修改其中的一个条,其他内容不变 + err := ReplaceByDNS("/tmp/dnsTestFile", map[string]string{"www.example.com": "192.168.1.4"}) + err = ReplaceByDNS("/tmp/dnsTestFile", map[string]string{"www.example.com": "192.168.1.4"}) + convey.So(err, convey.ShouldBeNil) + lineContext, err1 = ReadLines("/tmp/dnsTestFile") + fmt.Println(lineContext) + convey.So(err1, convey.ShouldBeNil) + convey.So(len(lineContext), convey.ShouldEqual, 13) + lineContext, err = ReadLines("/tmp/dnsTestFile") + rst = 1 + for i := range lineContext { + if strings.Contains(lineContext[i], "www.example.com") { + if strings.Contains(lineContext[i], "192.168.1.4") { + rst = 0 + } + } + } + convey.So(rst, convey.ShouldEqual, 0) + rst = 1 + for i := range lineContext { + if strings.Contains(lineContext[i], "www.example2.com") { + if strings.Contains(lineContext[i], "192.168.1.2") { + rst = 0 + } + } + } + convey.So(rst, convey.ShouldEqual, 0) + + // 新增一条,修改一条,一条不变,能够保持成功 + testLine = []string{"192.168.1.1 www.example.com", + "192.168.1.2 www.example2.com", + "192.168.1.14 www.example14.com"} + err = ReplaceByDNS("/tmp/dnsTestFile", map[string]string{"www.example.com": "192.168.1.1", + "www.example14.com": "192.168.1.14", + "www.example2.com": "192.168.1.2"}) + convey.So(err, convey.ShouldBeNil) + lineContext, err1 = ReadLines("/tmp/dnsTestFile") + convey.So(err1, convey.ShouldBeNil) + convey.So(len(lineContext), convey.ShouldEqual, 14) + lineContext, err = ReadLines("/tmp/dnsTestFile") + rst = 1 + for i := range lineContext { + if strings.Contains(lineContext[i], "www.example.com") { + if strings.Contains(lineContext[i], "192.168.1.1") { + rst = 0 + } + } + } + convey.So(rst, convey.ShouldEqual, 0) + rst = 1 + for i := range lineContext { + if strings.Contains(lineContext[i], "www.example2.com") { + if strings.Contains(lineContext[i], "192.168.1.2") { + rst = 0 + } + } + } + convey.So(rst, convey.ShouldEqual, 0) + rst = 1 + for i := range lineContext { + if strings.Contains(lineContext[i], "www.example14.com") { + if strings.Contains(lineContext[i], "192.168.1.14") { + rst = 0 + } + } + } + convey.So(rst, convey.ShouldEqual, 0) + }) +} + +func TestReadLines(t *testing.T) { + convey.Convey("Test_replaceByDNSFileReWrite", t, func() { + defer gomonkey.ApplyFunc(os.Open, func(name string) (*os.File, error) { + return nil, fmt.Errorf("os.Open error") + }).Reset() + _, err := ReadLines("/test") + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestShortUUID(t *testing.T) { + assert.Greater(t, 30, len(ShortUUID())) +} + +func TestIsNetworkError(t *testing.T) { + assert.Equal(t, false, IsNetworkError(nil)) + assert.Equal(t, true, IsNetworkError(syscall.EHOSTUNREACH)) + assert.Equal(t, true, IsNetworkError(os.ErrDeadlineExceeded)) + assert.Equal(t, false, IsNetworkError(errors.New("test error"))) +} + +func Test_IsUserError(t *testing.T) { + flag := IsUserError(errors.New("test")) + assert.Equal(t, false, flag) + + snErr := snerror.New(100, "test") + flag = IsUserError(snErr) + assert.Equal(t, false, flag) + + snErr = snerror.New(10500, "test") + flag = IsUserError(snErr) + assert.Equal(t, false, flag) + + snErr = snerror.New(4001, "test") + flag = IsUserError(snErr) + assert.Equal(t, true, flag) +} + +func TestCalculateCPUByMemory(t *testing.T) { + cpuInfo := CalculateCPUByMemory(10) + assert.Equal(t, cpuInfo, 200) +} + +func TestGenerateInstanceID(t *testing.T) { + instanceID := GenerateInstanceID("podName") + assert.Equal(t, instanceID, "defaultaz-#-podName") +} + +func TestGetPodNameByInstanceID(t *testing.T) { + podName := GetPodNameByInstanceID("defaultaz-#-podName") + assert.Equal(t, podName, "podName") +} + +func TestShuffleOneArray(t *testing.T) { + arr1 := []string{"1"} + arr2 := ShuffleOneArray(arr1) + assert.Equal(t, arr1, arr2) + + arr3 := []string{"1", "2", "5", "6", "7"} + arr4 := ShuffleOneArray(arr3) + assert.NotEqual(t, arr3, arr4) + + arr5 := make([]string, 0) + arr6 := ShuffleOneArray(arr5) + assert.Equal(t, len(arr6), 0) +} + +func TestIsCAEFunc(t *testing.T) { + assert.Equal(t, true, IsCAEFunc(constant.BusinessTypeCAE)) + assert.Equal(t, false, IsCAEFunc(constant.WorkerManagerApplier)) +} + +func TestIsDirectFunc(t *testing.T) { + tests := []struct { + name string + expected bool + }{ + { + name: "python3.6", + expected: false, + }, + { + name: "java8", + expected: false, + }, + { + name: "javax", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, IsDirectFunc(tt.name), tt.name) + }) + } +} + +func TestIsNil(t *testing.T) { + var obj *os.File + assert.Equal(t, true, IsNil(obj)) + getObjFunc := func() interface{} { + return obj + } + assert.Equal(t, true, IsNil(getObjFunc())) +} + +func TestCalcFileMD5(t *testing.T) { + os.Remove("./testFile") + assert.Equal(t, "", CalcFileMD5("invalidPath")) + os.WriteFile("./testFile", + []byte("/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n"), 0600) + assert.Equal(t, "4fca8f1c736ca30135ed16538f4aebfc", CalcFileMD5("./testFile")) + os.Remove("./testFile") +} + +func TestFileMD5(t *testing.T) { + os.Remove("./testFile") + md5, err := FileMD5("invalidPath") + assert.NotNil(t, err) + assert.Equal(t, "", md5) + os.WriteFile("./testFile", + []byte("/sn/function/123/hello/latest|100|{\"name\":\"hello\",\"version\":\"latest\"}\n"), 0600) + md5, err = FileMD5("./testFile") + assert.Nil(t, err) + assert.Equal(t, "4fca8f1c736ca30135ed16538f4aebfc", md5) + os.Remove("./testFile") +} + +func TestFnvHashInt(t *testing.T) { + hashInt := FnvHashInt("123") + assert.Equal(t, 1916298011, hashInt) +} + +func TestSafeCloseStopCh(t *testing.T) { + convey.Convey("stopCh", t, func() { + stopCh := make(chan struct{}, 1) + stopCh <- struct{}{} + SafeCloseChannel(stopCh) + _, ok := <-stopCh + assert.Equal(t, false, ok) + }) + convey.Convey("default", t, func() { + stopCh := make(chan struct{}, 1) + SafeCloseChannel(stopCh) + _, ok := <-stopCh + assert.Equal(t, false, ok) + }) + convey.Convey("chan is nil", t, func() { + SafeCloseChannel(nil) + }) +} + +func TestMessageTruncation(t *testing.T) { + message := "aaaaaaaaaaaaaaaaaaa" + truncationMessage := MessageTruncation(message) + assert.Equal(t, message, truncationMessage) + rawMessage := "" + for i := 0; i < 300; i++ { + rawMessage = rawMessage + "a" + } + truncationMessage = MessageTruncation(rawMessage) + assert.Equal(t, len(truncationMessage), 256) +} + +func TestGetFunctionInstanceInfoFromEtcdKey(t *testing.T) { + convey.Convey("Test GetFunctionInstanceInfoFromEtcdKey", t, func() { + key := "/sn/instance/business/yrk/tenant/0/function/faasscheduler/version/latest/defaultaz/falseParam/requestID/3f079541-15fc-4009-8c41-50b2b2936772" + _, err := GetFunctionInstanceInfoFromEtcdKey(key) + convey.So(err, convey.ShouldNotBeNil) + + key = "/sn/instance/business/yrk/tenant/0/function/faasscheduler/version/latest/defaultaz/requestID/3f079541-15fc-4009-8c41-50b2b2936772" + info, err := GetFunctionInstanceInfoFromEtcdKey(key) + convey.So(err, convey.ShouldBeNil) + convey.So(info.FunctionName, convey.ShouldEqual, "faasscheduler") + convey.So(info.TenantID, convey.ShouldEqual, "0") + convey.So(info.Version, convey.ShouldEqual, "latest") + convey.So(info.InstanceName, convey.ShouldEqual, "3f079541-15fc-4009-8c41-50b2b2936772") + + key = "/sn/instance/business/yrk/tenant/0/function/faasscheduler/version/$latest/defaultaz/requestID/876a3352-44ea-4f0f-83b2-851c50aa89e1" + info, err = GetFunctionInstanceInfoFromEtcdKey(key) + convey.So(err, convey.ShouldBeNil) + convey.So(info.FunctionName, convey.ShouldEqual, "faasscheduler") + convey.So(info.TenantID, convey.ShouldEqual, "0") + convey.So(info.Version, convey.ShouldEqual, "$latest") + convey.So(info.InstanceName, convey.ShouldEqual, "876a3352-44ea-4f0f-83b2-851c50aa89e1") + }) +} + +func TestGetModuleSchedulerInfoFromEtcdKey(t *testing.T) { + convey.Convey("Test GetModuleSchedulerInfoFromEtcdKey", t, func() { + key := "/sn/faas-scheduler/instances/cluster1/node1/falseParam/faas-scheduler-123" + _, err := GetModuleSchedulerInfoFromEtcdKey(key) + convey.So(err, convey.ShouldNotBeNil) + + key = "/sn/faas-scheduler/instances/cluster1/node1/faas-scheduler-123" + info, err := GetModuleSchedulerInfoFromEtcdKey(key) + convey.So(err, convey.ShouldBeNil) + convey.So(info.FunctionName, convey.ShouldEqual, defaultFunctionName) + convey.So(info.TenantID, convey.ShouldEqual, defaultTenant) + convey.So(info.Version, convey.ShouldEqual, defaultVersion) + convey.So(info.InstanceName, convey.ShouldEqual, "faas-scheduler-123") + }) +} + +func TestCheckFaaSSchedulerInstanceFault(t *testing.T) { + convey.Convey("Test CheckFaaSSchedulerInstanceFault", t, func() { + testCases := []struct { + name string + input types.InstanceStatus + expected bool + }{ + { + name: "should return true for KernelInstanceStatusFatal", + input: types.InstanceStatus{Code: int32(constant.KernelInstanceStatusFatal)}, + expected: true, + }, + { + name: "should return true for KernelInstanceStatusScheduleFailed", + input: types.InstanceStatus{Code: int32(constant.KernelInstanceStatusScheduleFailed)}, + expected: true, + }, + { + name: "should return true for KernelInstanceStatusEvicting", + input: types.InstanceStatus{Code: int32(constant.KernelInstanceStatusEvicting)}, + expected: true, + }, + { + name: "should return true for KernelInstanceStatusEvicted", + input: types.InstanceStatus{Code: int32(constant.KernelInstanceStatusEvicted)}, + expected: true, + }, + { + name: "should return true for KernelInstanceStatusExiting", + input: types.InstanceStatus{Code: int32(constant.KernelInstanceStatusExiting)}, + expected: true, + }, + { + name: "should return true for KernelInstanceStatusExited", + input: types.InstanceStatus{Code: int32(constant.KernelInstanceStatusExited)}, + expected: true, + }, + { + name: "should return false for unknown status", + input: types.InstanceStatus{Code: 999}, + expected: false, + }, + } + + for _, tc := range testCases { + convey.Convey(tc.name, func() { + result := CheckFaaSSchedulerInstanceFault(tc.input) + convey.So(result, convey.ShouldEqual, tc.expected) + }) + } + }) +} diff --git a/yuanrong/pkg/common/faas_common/wisecloudtool/pod_operator.go b/yuanrong/pkg/common/faas_common/wisecloudtool/pod_operator.go new file mode 100644 index 0000000..6227c17 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/wisecloudtool/pod_operator.go @@ -0,0 +1,185 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package wisecloudtool - +package wisecloudtool + +import ( + "crypto/tls" + "fmt" + "time" + + "github.com/json-iterator/go" + "github.com/valyala/fasthttp" + "go.uber.org/zap" + "k8s.io/apimachinery/pkg/util/wait" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount" + "yuanrong/pkg/common/faas_common/wisecloudtool/types" +) + +const ( + queryRetryTime = 10 + queryRetryDuration = 200 * time.Millisecond // 初始等待时间 + queryRetryFactor = 4 // 倍数因子(每次翻4倍) + queryRetryJitter = 0.5 // 随机抖动系数 + queryRetryCap = 20 * time.Second // 最大等待时间上限 +) + +var ( + coldStartBackoff = wait.Backoff{ + Duration: queryRetryDuration, + Factor: queryRetryFactor, + Jitter: queryRetryJitter, + Steps: queryRetryTime, + Cap: queryRetryCap, + } +) + +// PodOperator - +type PodOperator struct { + nuwaConsoleAddr string // + nuwaGatewayAddr string + *types.ServiceAccountJwt + *fasthttp.Client + logger api.FormatLogger +} + +// NewColdStarter - +func NewColdStarter(serviceAccountJwt *types.ServiceAccountJwt, logger api.FormatLogger) *PodOperator { + return &PodOperator{ + nuwaConsoleAddr: serviceAccountJwt.NuwaRuntimeAddr, + nuwaGatewayAddr: serviceAccountJwt.NuwaGatewayAddr, + ServiceAccountJwt: serviceAccountJwt, + Client: &fasthttp.Client{ + TLSConfig: &tls.Config{ + InsecureSkipVerify: serviceAccountJwt.TlsConfig.HttpsInsecureSkipVerify, + CipherSuites: serviceAccountJwt.TlsConfig.TlsCipherSuites, + MinVersion: tls.VersionTLS12, + }, + MaxIdemponentCallAttempts: 3, + }, + logger: logger, + } +} + +// ColdStart - +func (p *PodOperator) ColdStart(funcKeyWithRes string, resSpec resspeckey.ResSpecKey, + info *types.NuwaRuntimeInfo) error { + req := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(req) + req.SetRequestURI(fmt.Sprintf("%s/activator/coldstart", p.nuwaConsoleAddr)) + req.Header.SetMethod(fasthttp.MethodPost) + createInstanceReq := types.NuwaColdCreateInstanceReq{ + RuntimeId: info.WisecloudRuntimeId, + RuntimeType: "Function", + PoolType: "noPool", + Memory: resSpec.Memory, + CPU: resSpec.CPU, + EnvLabel: info.EnvLabel, + } + logger := p.logger.With(zap.Any("funcKeyWithRes", funcKeyWithRes), zap.Any("resKey", resSpec.String())) + + body, err := jsoniter.Marshal(createInstanceReq) + if err != nil { + return err + } + err = serviceaccount.GenerateJwtSignedHeaders(req, body, *info, p.ServiceAccountJwt) + if err != nil { + return err + } + req.SetBodyRaw(body) + + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(resp) + backoffErr := wait.ExponentialBackoff( + coldStartBackoff, func() (bool, error) { + err = p.Client.Do(req, resp) + if err != nil { + return false, nil + } + return true, nil + }) + if backoffErr != nil { + logger.Warnf("cold start error, backoffErr: %s", backoffErr.Error()) + return backoffErr + } + if err != nil { + logger.Warnf("cold start error, backoffErr: %s", err.Error()) + return err + } + if resp.StatusCode()/100 != 2 { // resp http code != 2xx + logger.Warnf("cold start error, code: %d, body: %s", resp.StatusCode(), string(resp.Body())) + return fmt.Errorf("failed to cold start") + } + logger.Infof("cold start %s succeed", info.WisecloudRuntimeId) + return nil +} + +// DelPod will send a req to erase runtime pod +func (p *PodOperator) DelPod(nuwaRuntimeInfo *types.NuwaRuntimeInfo, deploymentName string, + podId string) error { + p.logger.Infof("delete nuwa pod %s", podId) + req := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(req) + req.SetRequestURI(fmt.Sprintf("%s/runtime/instance", p.NuwaGatewayAddr)) + req.Header.SetMethod(fasthttp.MethodDelete) + destroyInsReq := types.NuwaDestroyInstanceReq{ + RuntimeType: "Function", + RuntimeId: nuwaRuntimeInfo.WisecloudRuntimeId, + InstanceId: podId, + WorkLoadName: deploymentName, + } + + reqBody, err := jsoniter.Marshal(destroyInsReq) + if err != nil { + return err + } + err = serviceaccount.GenerateJwtSignedHeaders(req, reqBody, *nuwaRuntimeInfo, p.ServiceAccountJwt) + if err != nil { + return err + } + req.SetBodyRaw(reqBody) + + logger := p.logger.With(zap.Any("deployment", deploymentName), zap.Any("podId", podId)) + rsp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(rsp) + backoffError := wait.ExponentialBackoff( + coldStartBackoff, func() (bool, error) { + err = p.Client.Do(req, rsp) + if err != nil { + return false, nil + } + return true, nil + }) + if backoffError != nil { + logger.Warnf("delete runtime pod error, backoffErr: %s", backoffError.Error()) + return backoffError + } + if err != nil { + logger.Warnf("delete runtime pod error, err: %s", err.Error()) + return err + } + if rsp.StatusCode()/100 != 2 { // resp http code != 2xx + logger.Warnf("delete runtime pod error, code: %d, body: %s", rsp.StatusCode(), string(rsp.Body())) + return fmt.Errorf("failed to delete runtime pod") + } + logger.Infof("succeed to delete runtime pod %s", podId) + return nil +} diff --git a/yuanrong/pkg/common/faas_common/wisecloudtool/pod_operator_test.go b/yuanrong/pkg/common/faas_common/wisecloudtool/pod_operator_test.go new file mode 100644 index 0000000..838cf38 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/wisecloudtool/pod_operator_test.go @@ -0,0 +1,116 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package wisecloudtool + +import ( + "crypto/tls" + "errors" + "net" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/valyala/fasthttp" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount" + "yuanrong/pkg/common/faas_common/wisecloudtool/types" +) + +func TestNewColdStarter(t *testing.T) { + saJwt := &types.ServiceAccountJwt{ + NuwaRuntimeAddr: "http://test-addr", + NuwaGatewayAddr: "http://gateway-addr", + TlsConfig: &types.TLSConfig{ + HttpsInsecureSkipVerify: true, + TlsCipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + }, + } + + po := NewColdStarter(saJwt, log.GetLogger()) + + if po.nuwaConsoleAddr != saJwt.NuwaRuntimeAddr { + t.Errorf("expected nuwaConsoleAddr %s, got %s", saJwt.NuwaRuntimeAddr, po.nuwaConsoleAddr) + } + if po.Client == nil { + t.Error("expected non-nil client") + } +} + +func TestColdStart_Success(t *testing.T) { + po := NewColdStarter(&types.ServiceAccountJwt{ + ServiceAccount: &types.ServiceAccount{}, + TlsConfig: &types.TLSConfig{}, + }, log.GetLogger()) + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(serviceaccount.GenerateJwtSignedHeaders, func(*fasthttp.Request, []byte, types.NuwaRuntimeInfo, *types.ServiceAccountJwt) error { + return nil + }) + patches.ApplyMethodFunc(&fasthttp.Client{}, "Do", func(*fasthttp.Request, *fasthttp.Response) error { + return nil + }) + + err := po.ColdStart("funcKey", resspeckey.ResSpecKey{}, &types.NuwaRuntimeInfo{}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } +} + +func TestDelPod_Success(t *testing.T) { + po := NewColdStarter(&types.ServiceAccountJwt{ + ServiceAccount: &types.ServiceAccount{}, + TlsConfig: &types.TLSConfig{}, + }, log.GetLogger()) + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(serviceaccount.GenerateJwtSignedHeaders, func(*fasthttp.Request, []byte, types.NuwaRuntimeInfo, *types.ServiceAccountJwt) error { + return nil + }) + patches.ApplyMethodFunc(&fasthttp.Client{}, "Do", func(*fasthttp.Request, *fasthttp.Response) error { + return nil + }) + runtimeInfo := &types.NuwaRuntimeInfo{ + WisecloudRuntimeId: "test-runtime", + } + err := po.DelPod(runtimeInfo, "deploy1", "pod1") + if err != nil { + t.Errorf("expected nil error, got %v", err) + } +} + +func TestDelPod_Error(t *testing.T) { + po := NewColdStarter(&types.ServiceAccountJwt{ + ServiceAccount: &types.ServiceAccount{}, + TlsConfig: &types.TLSConfig{}, + }, log.GetLogger()) + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(net.Listen, func(string, string) (net.Listener, error) { + return nil, errors.New("test error") + }) + + err := po.DelPod(&types.NuwaRuntimeInfo{}, "deploy1", "pod1") + if err == nil { + t.Error("expected error, got nil") + } +} diff --git a/yuanrong/pkg/common/faas_common/wisecloudtool/prometheus_metrics.go b/yuanrong/pkg/common/faas_common/wisecloudtool/prometheus_metrics.go new file mode 100644 index 0000000..b02043a --- /dev/null +++ b/yuanrong/pkg/common/faas_common/wisecloudtool/prometheus_metrics.go @@ -0,0 +1,285 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package wisecloudtool - +package wisecloudtool + +import ( + "fmt" + "strings" + "sync" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + k8stype "k8s.io/apimachinery/pkg/types" + + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/urnutils" +) + +const defaultLabel = "UNKNOWN_LABEL" +const labelLen = 8 + +var ( + concurrencyGauge = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "yuanrong_concurrency_num", + Help: "The current concurrency number of the application.", + }, + []string{"businessid", "tenantid", "funcname", "version", "label", "namespace", "deployment_name", "pod_name"}, + ) + + leaseRequestTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "yuanrong_lease_total", + Help: "The lease total number of the application.", + }, + []string{"businessid", "tenantid", "funcname", "version", "label", "namespace", "deployment_name", "pod_name"}, + ) +) + +// GetLeaseRequestTotal - +func GetLeaseRequestTotal() *prometheus.CounterVec { + return leaseRequestTotal +} + +// GetConcurrencyGauge - +func GetConcurrencyGauge() *prometheus.GaugeVec { + return concurrencyGauge +} + +// MetricProvider - +type MetricProvider struct { + sync.RWMutex + // key is {funcKey}#{invokeLabel}, subKey namespace value is {namespace, podName} + WorkLoadMap map[string]map[string]*k8stype.NamespacedName +} + +// NewMetricProvider - +func NewMetricProvider() *MetricProvider { + return &MetricProvider{ + RWMutex: sync.RWMutex{}, + WorkLoadMap: make(map[string]map[string]*k8stype.NamespacedName), + } +} + +// AddWorkLoad - +func (m *MetricProvider) AddWorkLoad(funcKey string, invokeLabel string, namespaceName *k8stype.NamespacedName) { + workload := getWorkloadName(funcKey, invokeLabel) + m.Lock() + defer m.Unlock() + + deployments, ok := m.WorkLoadMap[workload] + if !ok { + deployments = make(map[string]*k8stype.NamespacedName) + m.WorkLoadMap[workload] = deployments + } + if _, ok = deployments[namespaceName.String()]; !ok { + deployments[namespaceName.String()] = namespaceName + } +} + +// EnsureConcurrencyGaugeWithLabel - +func (m *MetricProvider) EnsureConcurrencyGaugeWithLabel(labels []string) error { + if len(labels) != labelLen { + return fmt.Errorf("labels len must be 8") + } + + m.RLock() + defer m.RUnlock() + _, err := concurrencyGauge.GetMetricWithLabelValues(labels...) + return err +} + +// EnsureLeaseRequestTotalWithLabel - +func (m *MetricProvider) EnsureLeaseRequestTotalWithLabel(labels []string) error { + if len(labels) != labelLen { + return fmt.Errorf("labels len must be 8") + } + + m.RLock() + defer m.RUnlock() + _, err := leaseRequestTotal.GetMetricWithLabelValues(labels...) + return err +} + +// Exist - +func (m *MetricProvider) Exist(funcKey string, invokeLabel string) bool { + return m.GetRandomDeployment(funcKey, invokeLabel) != nil +} + +// GetRandomDeployment - +func (m *MetricProvider) GetRandomDeployment(funcKey string, invokeLabel string) *k8stype.NamespacedName { + workName := getWorkloadName(funcKey, invokeLabel) + m.RLock() + defer m.RUnlock() + deployments, ok := m.WorkLoadMap[workName] + if !ok { + return nil + } + if len(deployments) == 0 { + return nil + } + for _, namespaceName := range deployments { + return namespaceName + } + return nil +} + +// IncLeaseRequestTotalWithLabel - +func (m *MetricProvider) IncLeaseRequestTotalWithLabel(labels []string) error { + if len(labels) != labelLen { + return fmt.Errorf("labels len must be 8") + } + counter, err := leaseRequestTotal.GetMetricWithLabelValues(labels...) + if err != nil { + return err + } + counter.Inc() + return nil +} + +// IncConcurrencyGaugeWithLabel - +func (m *MetricProvider) IncConcurrencyGaugeWithLabel(labels []string) error { + if len(labels) != labelLen { + return fmt.Errorf("labels len must be 8") + } + gauge, err := concurrencyGauge.GetMetricWithLabelValues(labels...) + if err != nil { + return err + } + gauge.Inc() + return nil +} + +// DecConcurrencyGaugeWithLabel - +func (m *MetricProvider) DecConcurrencyGaugeWithLabel(labels []string) error { + if len(labels) != labelLen { + return fmt.Errorf("labels len must be 8") + } + gauge, err := concurrencyGauge.GetMetricWithLabelValues(labels...) + if err != nil { + return err + } + gauge.Dec() + return nil +} + +// ClearConcurrencyGaugeWithLabel - +func (m *MetricProvider) ClearConcurrencyGaugeWithLabel(labels []string) error { + if len(labels) != labelLen { + return fmt.Errorf("labels len must be 8") + } + concurrencyGauge.DeleteLabelValues(labels...) + return nil +} + +// ClearLeaseRequestTotalWithLabel - +func (m *MetricProvider) ClearLeaseRequestTotalWithLabel(labels []string) error { + if len(labels) != labelLen { + return fmt.Errorf("labels len must be 8") + } + leaseRequestTotal.DeleteLabelValues(labels...) + return nil +} + +// ClearMetricsForFunction - +func (m *MetricProvider) ClearMetricsForFunction(funcMetaData *types.FuncMetaData) { + funcKey0 := urnutils.CombineFunctionKey(funcMetaData.TenantID, funcMetaData.FuncName, funcMetaData.Version) + m.Lock() + defer m.Unlock() + for workload, _ := range m.WorkLoadMap { + funcKey1, invokeLabel := GetFuncKeyAndLabelFromWorkload(workload) + if funcKey0 == funcKey1 { + m.clearMetricsForInsConfigWithoutLock(funcMetaData, invokeLabel) + } + } +} + +// ClearMetricsForInsConfig - +func (m *MetricProvider) ClearMetricsForInsConfig(funcMetaData *types.FuncMetaData, invokeLabel string) { + m.Lock() + m.clearMetricsForInsConfigWithoutLock(funcMetaData, invokeLabel) + m.Unlock() +} + +func (m *MetricProvider) clearMetricsForInsConfigWithoutLock(funcMetaData *types.FuncMetaData, invokeLabel string) { + // 得看下和FunctionVersion有啥区别 + funcKey := urnutils.CombineFunctionKey(funcMetaData.TenantID, funcMetaData.FuncName, funcMetaData.Version) + workload := getWorkloadName(funcKey, invokeLabel) + deployments, ok := m.WorkLoadMap[workload] + if !ok { + return + } + delete(m.WorkLoadMap, workload) + + if invokeLabel == "" { + invokeLabel = defaultLabel + } + + for _, deployment := range deployments { + labels := map[string]string{ + "businessid": funcMetaData.BusinessID, + "tenantid": funcMetaData.TenantID, + "funcname": funcMetaData.FuncName, + "version": funcMetaData.Version, + "label": invokeLabel, + "namespace": deployment.Namespace, + "deployment_name": deployment.Name, + } + concurrencyGauge.DeletePartialMatch(labels) + leaseRequestTotal.DeletePartialMatch(labels) + } +} + +// GetMetricLabels - +// 判断label是否符合预期 +func GetMetricLabels(funcMetaData *types.FuncMetaData, invokeLabel string, + namespace string, deploymentName string, podName string) []string { + var metricLabelValue []string + if namespace != "" && deploymentName != "" && podName != "" && funcMetaData != nil { + if invokeLabel == "" { + invokeLabel = defaultLabel + } + metricLabelValue = []string{ + funcMetaData.BusinessID, + funcMetaData.TenantID, + funcMetaData.FuncName, + funcMetaData.Version, + invokeLabel, + namespace, + deploymentName, + podName} + } + return metricLabelValue +} + +// GetFuncKeyAndLabelFromWorkload - +func GetFuncKeyAndLabelFromWorkload(workload string) (string, string) { + strs := strings.Split(workload, "#") + if len(strs) == 2 { // deployment key must be 2 + return strs[0], strs[1] + } + return "", "" +} + +// getWorkloadName - get deploymentforfunckey +func getWorkloadName(funcKey, invokeLabel string) string { + if invokeLabel == "" { + invokeLabel = defaultLabel + } + return fmt.Sprintf("%s#%s", funcKey, invokeLabel) +} diff --git a/yuanrong/pkg/common/faas_common/wisecloudtool/prometheus_metrics_test.go b/yuanrong/pkg/common/faas_common/wisecloudtool/prometheus_metrics_test.go new file mode 100644 index 0000000..2101f97 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/wisecloudtool/prometheus_metrics_test.go @@ -0,0 +1,313 @@ +package wisecloudtool + +import ( + "fmt" + "github.com/agiledragon/gomonkey/v2" + "github.com/prometheus/client_golang/prometheus" + "github.com/smartystreets/goconvey/convey" + "testing" + + "github.com/stretchr/testify/assert" + k8stype "k8s.io/apimachinery/pkg/types" + + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/urnutils" +) + +func TestNewMetricProvider(t *testing.T) { + provider := NewMetricProvider() + assert.NotNil(t, provider) + assert.NotNil(t, provider.WorkLoadMap) + assert.Equal(t, 0, len(provider.WorkLoadMap)) +} + +func TestMetricProvider_AddWorkLoad(t *testing.T) { + t.Run("Add new workload", func(t *testing.T) { + provider := NewMetricProvider() + funcKey := "test-func" + invokeLabel := "test-label" + namespaceName := &k8stype.NamespacedName{ + Namespace: "test-ns", + Name: "test-name", + } + + provider.AddWorkLoad(funcKey, invokeLabel, namespaceName) + + assert.Equal(t, 1, len(provider.WorkLoadMap)) + assert.Equal(t, 1, len(provider.WorkLoadMap[getWorkloadName(funcKey, invokeLabel)])) + }) + + t.Run("Add duplicate workload", func(t *testing.T) { + provider := NewMetricProvider() + funcKey := "test-func" + invokeLabel := "test-label" + namespaceName := &k8stype.NamespacedName{ + Namespace: "test-ns", + Name: "test-name", + } + + // Add twice + provider.AddWorkLoad(funcKey, invokeLabel, namespaceName) + provider.AddWorkLoad(funcKey, invokeLabel, namespaceName) + + assert.Equal(t, 1, len(provider.WorkLoadMap)) + assert.Equal(t, 1, len(provider.WorkLoadMap[getWorkloadName(funcKey, invokeLabel)])) + }) +} + +func TestMetricProvider_Exist(t *testing.T) { + provider := NewMetricProvider() + funcKey := "test-func" + invokeLabel := "test-label" + + t.Run("Workload does not exist", func(t *testing.T) { + assert.False(t, provider.Exist(funcKey, invokeLabel)) + }) + + t.Run("Workload exists", func(t *testing.T) { + provider.AddWorkLoad(funcKey, invokeLabel, &k8stype.NamespacedName{ + Namespace: "test-ns", + Name: "test-name", + }) + assert.True(t, provider.Exist(funcKey, invokeLabel)) + }) +} + +func TestMetricProvider_GetRandomDeployment(t *testing.T) { + provider := NewMetricProvider() + funcKey := "test-func" + invokeLabel := "test-label" + testDeployment0 := &k8stype.NamespacedName{ + Namespace: "test-ns-0", + Name: "test-name-0", + } + + testDeployment1 := &k8stype.NamespacedName{ + Namespace: "test-ns-1", + Name: "test-name-1", + } + + t.Run("Get non-existent deployment", func(t *testing.T) { + assert.Nil(t, provider.GetRandomDeployment(funcKey, invokeLabel)) + }) + + t.Run("Get existing deployment", func(t *testing.T) { + provider.AddWorkLoad(funcKey, invokeLabel, testDeployment0) + provider.AddWorkLoad(funcKey, invokeLabel, testDeployment1) + flag0 := false + flag1 := false + for i := 0; i < 100; i++ { + result := provider.GetRandomDeployment(funcKey, invokeLabel) + switch result.Name { + case "test-name-0": + flag0 = true + case "test-name-1": + flag1 = true + } + if flag1 && flag0 { + break + } + } + assert.True(t, flag0 && flag1) + }) +} + +func TestMetricProvider_ClearMetrics(t *testing.T) { + provider := NewMetricProvider() + funcMeta := &types.FuncMetaData{ + TenantID: "tenant1", + FuncName: "func1", + Version: "v1", + BusinessID: "biz1", + } + invokeLabel := "test-label" + workload := getWorkloadName(urnutils.CombineFunctionKey(funcMeta.TenantID, funcMeta.FuncName, funcMeta.Version), invokeLabel) + + // Add test data + provider.AddWorkLoad( + urnutils.CombineFunctionKey(funcMeta.TenantID, funcMeta.FuncName, funcMeta.Version), + invokeLabel, + &k8stype.NamespacedName{ + Namespace: "test-ns", + Name: "test-name", + }, + ) + + t.Run("Clear function metrics", func(t *testing.T) { + provider.ClearMetricsForFunction(funcMeta) + assert.Equal(t, 0, len(provider.WorkLoadMap)) + }) + + t.Run("Clear instance config metrics", func(t *testing.T) { + // Re-add data + provider.AddWorkLoad( + urnutils.CombineFunctionKey(funcMeta.TenantID, funcMeta.FuncName, funcMeta.Version), + invokeLabel, + &k8stype.NamespacedName{ + Namespace: "test-ns", + Name: "test-name", + }, + ) + + provider.ClearMetricsForInsConfig(funcMeta, invokeLabel) + assert.Nil(t, provider.WorkLoadMap[workload]) + }) +} + +func TestGetMetricLabels(t *testing.T) { + funcMeta := &types.FuncMetaData{ + BusinessID: "biz1", + TenantID: "tenant1", + FuncName: "func1", + Version: "v1", + } + + t.Run("Generate complete labels", func(t *testing.T) { + labels := GetMetricLabels(funcMeta, "label1", "ns1", "deploy1", "pod1") + assert.Equal(t, []string{"biz1", "tenant1", "func1", "v1", "label1", "ns1", "deploy1", "pod1"}, labels) + }) + + t.Run("Use default label", func(t *testing.T) { + labels := GetMetricLabels(funcMeta, "", "ns1", "deploy1", "pod1") + assert.Equal(t, "UNKNOWN_LABEL", labels[4]) + }) + + t.Run("Return nil when missing required parameters", func(t *testing.T) { + assert.Nil(t, GetMetricLabels(nil, "label1", "ns1", "deploy1", "pod1")) + assert.Nil(t, GetMetricLabels(funcMeta, "label1", "", "deploy1", "pod1")) + }) +} + +func TestWorkloadHelpers(t *testing.T) { + t.Run("Get workload name", func(t *testing.T) { + name := getWorkloadName("func1", "label1") + assert.Equal(t, "func1#label1", name) + assert.Equal(t, "func1#UNKNOWN_LABEL", getWorkloadName("func1", "")) + }) + + t.Run("Parse from workload name", func(t *testing.T) { + funcKey, label := GetFuncKeyAndLabelFromWorkload("func1#label1") + assert.Equal(t, "func1", funcKey) + assert.Equal(t, "label1", label) + + funcKey, label = GetFuncKeyAndLabelFromWorkload("invalid") + assert.Equal(t, "", funcKey) + assert.Equal(t, "", label) + }) +} + +func TestMetricProvider(t *testing.T) { + convey.Convey("Test MetricProvider Functions", t, func() { + m := &MetricProvider{} + validLabels := make([]string, labelLen) + invalidLabels := make([]string, labelLen-1) + + convey.Convey("Test IncLeaseRequestTotalWithLabel", func() { + convey.Convey("should return error for invalid label length", func() { + err := m.IncLeaseRequestTotalWithLabel(invalidLabels) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "labels len must be 8") + }) + + convey.Convey("should handle GetMetricWithLabelValues error", func() { + patches := gomonkey.ApplyMethodFunc(leaseRequestTotal, "GetMetricWithLabelValues", func(...string) (prometheus.Counter, error) { + return nil, fmt.Errorf("mock error") + }) + defer patches.Reset() + + err := m.IncLeaseRequestTotalWithLabel(validLabels) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("should increment counter successfully", func() { + patches := gomonkey.ApplyMethodFunc(leaseRequestTotal, "GetMetricWithLabelValues", func(...string) (prometheus.Counter, error) { + counter := &fakeCounter{} + return counter, nil + }) + defer patches.Reset() + + err := m.IncLeaseRequestTotalWithLabel(validLabels) + convey.So(err, convey.ShouldBeNil) + }) + }) + + convey.Convey("Test IncConcurrencyGaugeWithLabel", func() { + convey.Convey("should return error for invalid label length", func() { + err := m.IncConcurrencyGaugeWithLabel(invalidLabels) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "labels len must be 8") + }) + + convey.Convey("should handle GetMetricWithLabelValues error", func() { + patches := gomonkey.ApplyMethodFunc(concurrencyGauge, "GetMetricWithLabelValues", func(...string) (prometheus.Gauge, error) { + return nil, fmt.Errorf("mock error") + }) + defer patches.Reset() + + err := m.IncConcurrencyGaugeWithLabel(validLabels) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("should increment gauge successfully", func() { + patches := gomonkey.ApplyMethodFunc(concurrencyGauge, "GetMetricWithLabelValues", func(...string) (prometheus.Gauge, error) { + gauge := &fakeGauge{} + return gauge, nil + }) + defer patches.Reset() + + err := m.IncConcurrencyGaugeWithLabel(validLabels) + convey.So(err, convey.ShouldBeNil) + }) + }) + + convey.Convey("Test DecConcurrencyGaugeWithLabel", func() { + convey.Convey("should decrement gauge successfully", func() { + patches := gomonkey.ApplyMethodFunc(concurrencyGauge, "GetMetricWithLabelValues", func(...string) (prometheus.Gauge, error) { + gauge := &fakeGauge{} + return gauge, nil + }) + defer patches.Reset() + + err := m.DecConcurrencyGaugeWithLabel(validLabels) + convey.So(err, convey.ShouldBeNil) + }) + }) + + convey.Convey("Test ClearConcurrencyGaugeWithLabel", func() { + convey.Convey("should clear gauge successfully", func() { + patches := gomonkey.ApplyMethodFunc(concurrencyGauge, "DeleteLabelValues", func(...string) bool { + return true + }) + defer patches.Reset() + + err := m.ClearConcurrencyGaugeWithLabel(validLabels) + convey.So(err, convey.ShouldBeNil) + }) + }) + + convey.Convey("Test ClearLeaseRequestTotalWithLabel", func() { + convey.Convey("should clear counter successfully", func() { + patches := gomonkey.ApplyMethodFunc(leaseRequestTotal, "DeleteLabelValues", func(...string) bool { + return true + }) + defer patches.Reset() + + err := m.ClearLeaseRequestTotalWithLabel(validLabels) + convey.So(err, convey.ShouldBeNil) + }) + }) + }) +} + +type fakeCounter struct { + prometheus.Counter +} + +func (f *fakeCounter) Inc() {} + +type fakeGauge struct { + prometheus.Gauge +} + +func (f *fakeGauge) Inc() {} +func (f *fakeGauge) Dec() {} diff --git a/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/jwtsign.go b/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/jwtsign.go new file mode 100644 index 0000000..7bd335e --- /dev/null +++ b/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/jwtsign.go @@ -0,0 +1,203 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package serviceaccount sign http request by jwttoken +package serviceaccount + +import ( + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "fmt" + "sort" + "strconv" + "strings" + "time" + + "github.com/valyala/fasthttp" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/wisecloudtool/types" +) + +const defaultExp = 300 * time.Second + +// GenerateJwtSignedHeaders put header authorization to request header +func GenerateJwtSignedHeaders(req *fasthttp.Request, body []byte, wiseCloudCtx types.NuwaRuntimeInfo, + serviceAccountJwt *types.ServiceAccountJwt) error { + headers := map[string]string{} + req.Header.Set("x-wisecloud-site", wiseCloudCtx.WisecloudSite) + req.Header.Set("x-wisecloud-service-id", wiseCloudCtx.WisecloudServiceId) + req.Header.Set("x-wisecloud-environment-id", wiseCloudCtx.WisecloudEnvironmentId) + headers = map[string]string{ + "x-wisecloud-site": wiseCloudCtx.WisecloudSite, + "x-wisecloud-service-id": wiseCloudCtx.WisecloudServiceId, + "x-wisecloud-environment-id": wiseCloudCtx.WisecloudEnvironmentId, + } + + jwtToken, err := generateJWTToken(req, string(body), headers, serviceAccountJwt) + if err != nil { + return err + } + // Set headers + req.Header.Set(constant.HeaderAuthorization, jwtToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-client-id", strconv.FormatInt(serviceAccountJwt.ClientId, 10)) // decimal notation + return nil +} + +func generateJWTToken(req *fasthttp.Request, body string, headers map[string]string, + serviceAccountJwt *types.ServiceAccountJwt) (string, error) { + return generateJWTTokenGeneric(&types.ServiceAccount{ + PrivateKey: serviceAccountJwt.PrivateKey, + ClientId: serviceAccountJwt.ClientId, + KeyId: serviceAccountJwt.KeyId, + }, headers, buildQueryPayload(string(req.URI().Path()), string(req.Header.Method()), + string(req.URI().QueryString()), body), serviceAccountJwt.OauthTokenUrl, "JWT-PRO2") +} + +func buildQueryPayload(queryPath, method, queryString, body string) string { + var payloadBuilder strings.Builder + payloadBuilder.WriteString(body) + if queryPath != "" { + payloadBuilder.WriteString("\n") + payloadBuilder.WriteString(queryPath) + } + if method != "" { + payloadBuilder.WriteString("\n") + payloadBuilder.WriteString(method) + } + if queryString != "" { + payloadBuilder.WriteString("\n") + payloadBuilder.WriteString(queryString) + } + return payloadBuilder.String() +} + +func generateJWTTokenGeneric(sa *types.ServiceAccount, headers map[string]string, + body string, aud string, jwtTokenType string) (string, error) { + requestSign, err := getRequestSignature(headers, body) + if err != nil { + return "", err + } + iat := time.Now() + exp := iat.Add(defaultExp) + token := &Token{ + Header: map[string]interface{}{ + "typ": jwtTokenType, + "sdkVersion": 20200, + "clientVersion": 2, + "alg": "RS256", + "kid": sa.KeyId, + }, + Claims: map[string]interface{}{ + "aud": aud, + "iss": strconv.FormatInt(sa.ClientId, 10), + "exp": exp.Unix(), + "iat": iat.Unix(), + "signedHeaders": getSignedHeaders1(headers), + "requestSignature": requestSign, + }, + } + rsaPrikey, err := getRSAPrivateKey(sa.PrivateKey) + if err != nil { + return "", err + } + signToken, err := token.Sign(rsaPrikey) + if err != nil { + return "", err + } + return "Bearer " + signToken, nil +} + +func getRequestSignature(headers map[string]string, payload string) (string, error) { + canonicalHeaders := "" + if headers != nil && len(headers) != 0 { + var keys []string + for k := range headers { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + canonicalHeaders += strings.ToLower(k) + canonicalHeaders += ":" + canonicalHeaders += strings.TrimSpace(headers[k]) + canonicalHeaders += "\n" + } + } + + if len(canonicalHeaders) == 0 { + canonicalHeaders += "\n" + } + if len(payload) != 0 { + ch, err := sha256String(payload) + if err != nil { + return "", err + } + canonicalHeaders += hex.EncodeToString(ch) + } + + ch, err := sha256String(canonicalHeaders) + if err != nil { + return "", err + } + + return hex.EncodeToString(ch), nil +} + +func getSignedHeaders1(headMap map[string]string) string { + if headMap != nil && len(headMap) != 0 { + var keyArray []string + for key := range headMap { + keyArray = append(keyArray, key) + } + + sort.Strings(keyArray) + return strings.Join(keyArray, ";") + } + + return "" +} + +func sha256String(input string) ([]byte, error) { + h := sha256.New() + _, err := h.Write([]byte(input)) + if err != nil { + return nil, err + } + + output := h.Sum(nil) + return output, nil +} + +func getRSAPrivateKey(privateKey string) (interface{}, error) { + priKeyByte, err := hex.DecodeString(privateKey) + if err != nil { + return nil, err + } + private := []byte(fmt.Sprintf("-----BEGIN PRIVATE KEY-----\n%s\n-----END PRIVATE KEY-----", + base64.StdEncoding.EncodeToString(priKeyByte))) + pkPem, _ := pem.Decode(private) + + privateRsa, err := x509.ParsePKCS8PrivateKey(pkPem.Bytes) + if err != nil { + return nil, err + } + + return privateRsa, nil +} diff --git a/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/jwtsign_test.go b/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/jwtsign_test.go new file mode 100644 index 0000000..4878fed --- /dev/null +++ b/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/jwtsign_test.go @@ -0,0 +1,19 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package serviceaccount sign http request by jwttoken +package serviceaccount + diff --git a/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/parse.go b/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/parse.go new file mode 100644 index 0000000..8fa4770 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/parse.go @@ -0,0 +1,79 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package serviceaccount sign http request by jwttoken +package serviceaccount + +import ( + "crypto/tls" + "fmt" + + "github.com/json-iterator/go" + + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + + "yuanrong/pkg/common/faas_common/wisecloudtool/types" +) + +// ParseServiceAccount - +func ParseServiceAccount(serviceAccountKeyStr string) (*types.ServiceAccount, error) { + if len(serviceAccountKeyStr) <= 0 { + return nil, fmt.Errorf("serviceAccountKeyStr is empty") + } + + decryptedByte, err := stsgoapi.DecryptSensitiveConfig(serviceAccountKeyStr) + if err != nil { + return nil, fmt.Errorf("decrypt service account key failed") + } + serviceAccount := &types.ServiceAccount{} + err = jsoniter.Unmarshal(decryptedByte, &serviceAccount) + if err != nil { + return nil, fmt.Errorf("unmarshal service account key failed, err: %s", err.Error()) + } + return serviceAccount, nil +} + +// ParseTlsCipherSuites - +func ParseTlsCipherSuites(tlsCipherSuitesStrs []string) ([]uint16, error) { + if len(tlsCipherSuitesStrs) <= 0 { + return nil, fmt.Errorf("tlsCipherSuitesStr is empty") + } + + return cipherSuitesID(cipherSuitesFromName(tlsCipherSuitesStrs)), nil +} + +func cipherSuitesFromName(names []string) []*tls.CipherSuite { + m := make(map[string]*tls.CipherSuite, len(tls.CipherSuites())) + for _, cipher := range tls.CipherSuites() { + m[cipher.Name] = cipher + } + + r := make([]*tls.CipherSuite, 0) + for _, n := range names { + if _, ok := m[n]; ok { + r = append(r, m[n]) + } + } + return r +} + +func cipherSuitesID(cs []*tls.CipherSuite) []uint16 { + ids := make([]uint16, 0) + for _, value := range cs { + ids = append(ids, value.ID) + } + return ids +} diff --git a/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/parse_test.go b/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/parse_test.go new file mode 100644 index 0000000..f535ec5 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/parse_test.go @@ -0,0 +1,75 @@ +package serviceaccount + +import ( + "fmt" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/json-iterator/go" + "github.com/smartystreets/goconvey/convey" + + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + + "yuanrong/pkg/common/faas_common/wisecloudtool/types" +) + +func TestCipherSuitesFromName(t *testing.T) { + convey.Convey("Test cipherSuitesFromName", t, func() { + convey.Convey("success", func() { + cipherSuitesArr := []string{"TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"} + tlsSuite := cipherSuitesID(cipherSuitesFromName(cipherSuitesArr)) + convey.So(len(tlsSuite), convey.ShouldEqual, 2) + }) + }) +} + +func TestParseServiceAccount(t *testing.T) { + convey.Convey("Test ParseServiceAccount", t, func() { + // Setup test cases + validServiceAccount := &types.ServiceAccount{ + PrivateKey: "test-PrivateKey", + ClientId: 111, + } + validJSON, _ := jsoniter.Marshal(validServiceAccount) + + convey.Convey("when input is empty", func() { + _, err := ParseServiceAccount("") + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "serviceAccountKeyStr is empty") + }) + + convey.Convey("when decryption fails", func() { + patches := gomonkey.ApplyFunc(stsgoapi.DecryptSensitiveConfig, func(string) ([]byte, error) { + return nil, fmt.Errorf("decryption error") + }) + defer patches.Reset() + + _, err := ParseServiceAccount("encrypted-string") + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "decrypt service account key failed") + }) + + convey.Convey("when decryption succeeds but unmarshal fails", func() { + patches := gomonkey.ApplyFunc(stsgoapi.DecryptSensitiveConfig, func(string) ([]byte, error) { + return []byte("invalid-json"), nil + }) + defer patches.Reset() + + _, err := ParseServiceAccount("encrypted-string") + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "unmarshal service account key failed") + }) + + convey.Convey("when everything works correctly", func() { + patches := gomonkey.ApplyFunc(stsgoapi.DecryptSensitiveConfig, func(string) ([]byte, error) { + return validJSON, nil + }) + defer patches.Reset() + + result, err := ParseServiceAccount("encrypted-string") + convey.So(err, convey.ShouldBeNil) + convey.So(result.PrivateKey, convey.ShouldEqual, validServiceAccount.PrivateKey) + convey.So(result.ClientId, convey.ShouldEqual, validServiceAccount.ClientId) + }) + }) +} diff --git a/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/token.go b/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/token.go new file mode 100644 index 0000000..5992f7b --- /dev/null +++ b/yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount/token.go @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package serviceaccount sign http request by jwttoken +package serviceaccount + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "strings" +) + +var methodHash = crypto.SHA256 + +// Token - +type Token struct { + Header map[string]interface{} + Claims map[string]interface{} +} + +// Sign sign jwt token string +func (t *Token) Sign(key interface{}) (string, error) { + jsonHeader, err := json.Marshal(t.Header) + if err != nil { + return "", err + } + header := base64.RawURLEncoding.EncodeToString(jsonHeader) + + jsonClaims, err := json.Marshal(t.Claims) + if err != nil { + return "", err + } + claim := base64.RawURLEncoding.EncodeToString(jsonClaims) + + stringToBeSign := strings.Join([]string{header, claim}, ".") + + sig, err := t.getSig(stringToBeSign, key) + if err != nil { + return "", err + } + return strings.Join([]string{stringToBeSign, sig}, "."), nil +} + +func (t *Token) getSig(signingString string, key interface{}) (string, error) { + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return "", errors.New("key is invalid") + } + if !methodHash.Available() { + return "", errors.New("the requested hash function is unavailable") + } + hasher := methodHash.New() + _, err := hasher.Write([]byte(signingString)) + if err != nil { + return "", errors.New("hash write failed") + } + + sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, methodHash, hasher.Sum(nil)) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(sigBytes), nil +} diff --git a/yuanrong/pkg/common/faas_common/wisecloudtool/types/types.go b/yuanrong/pkg/common/faas_common/wisecloudtool/types/types.go new file mode 100644 index 0000000..06286a5 --- /dev/null +++ b/yuanrong/pkg/common/faas_common/wisecloudtool/types/types.go @@ -0,0 +1,74 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package types - +package types + +// ServiceAccountJwt service account config +type ServiceAccountJwt struct { + NuwaRuntimeAddr string `json:"nuwaRuntimeAddr,omitempty"` + NuwaGatewayAddr string `json:"nuwaGatewayAddr,omitempty"` + OauthTokenUrl string `json:"oauthTokenUrl"` + ServiceAccountKeyStr string `json:"serviceAccountKey"` + *ServiceAccount `json:"-"` + TlsConfig *TLSConfig `json:"tlsConfig"` +} + +// TLSConfig tls config +type TLSConfig struct { + HttpsInsecureSkipVerify bool `json:"httpsInsecureSkipVerify"` + TlsCipherSuitesStr []string `json:"tlsCipherSuites"` + TlsCipherSuites []uint16 `json:"-"` +} + +// ServiceAccount service account config +type ServiceAccount struct { + PrivateKey string `json:"privateKey"` + ClientId int64 `json:"clientId"` + KeyId string `json:"keyId"` + PublicKey string `json:"publicKey"` + UserId int64 `json:"userId"` + Version int32 `json:"version"` +} + +// NuwaRuntimeInfo contains ers workload info for function +type NuwaRuntimeInfo struct { + WisecloudRuntimeId string `json:"wisecloudRuntimeId"` + WisecloudSite string `json:"wisecloudSite"` + WisecloudTenantId string `json:"wisecloudTenantId"` + WisecloudApplicationId string `json:"wisecloudApplicationId"` + WisecloudServiceId string `json:"wisecloudServiceId"` + WisecloudEnvironmentId string `json:"wisecloudEnvironmentId"` + EnvLabel string `json:"envLabel"` +} + +// NuwaColdCreateInstanceReq request to nuwa +type NuwaColdCreateInstanceReq struct { + RuntimeId string `json:"runtimeId"` + RuntimeType string `json:"type"` // function/microservice + PoolType string `json:"poolType"` // java1.8/nodejs/python3 + EnvLabel string `json:"envLabel"` + Memory int64 `json:"memory"` + CPU int64 `json:"cpu"` +} + +// NuwaDestroyInstanceReq request to nuwa +type NuwaDestroyInstanceReq struct { + RuntimeId string `json:"runtimeId"` + RuntimeType string `json:"type"` + InstanceId string `json:"instanceId"` // podNamespace:podName + WorkLoadName string `json:"workLoadName"` +} diff --git a/yuanrong/pkg/common/go.mod b/yuanrong/pkg/common/go.mod new file mode 100644 index 0000000..b04e2c0 --- /dev/null +++ b/yuanrong/pkg/common/go.mod @@ -0,0 +1,156 @@ +module yuanrong/pkg/common + +go 1.24.1 + +require ( + github.com/agiledragon/gomonkey v2.0.1+incompatible + github.com/agiledragon/gomonkey/v2 v2.11.0 + github.com/asaskevich/govalidator/v11 v11.0.1-0.20250122183457-e11347878e23 + github.com/fsnotify/fsnotify v1.7.0 + github.com/gin-gonic/gin v1.10.0 + github.com/huaweicloud/huaweicloud-sdk-go-obs v3.23.12+incompatible + github.com/json-iterator/go v1.1.12 + github.com/magiconair/properties v1.8.7 + github.com/panjf2000/ants/v2 v2.10.0 + github.com/pborman/uuid v1.2.1 + github.com/prometheus/client_golang v1.16.0 + github.com/redis/go-redis/v9 v9.0.5 + github.com/smartystreets/goconvey v1.8.1 + github.com/stretchr/testify v1.9.0 + github.com/valyala/fasthttp v1.58.0 + go.etcd.io/etcd/api/v3 v3.5.11 + go.etcd.io/etcd/client/v3 v3.5.11 + go.opentelemetry.io/otel v1.24.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.24.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0 + go.opentelemetry.io/otel/sdk v1.24.0 + go.opentelemetry.io/otel/trace v1.24.0 + go.uber.org/zap v1.27.0 + golang.org/x/crypto v0.29.0 + golang.org/x/net v0.31.0 + golang.org/x/time v0.10.0 + google.golang.org/grpc v1.67.0 + google.golang.org/protobuf v1.36.6 + gopkg.in/yaml.v3 v3.0.1 + gotest.tools v2.3.0+incompatible + huawei.com/wisesecurity/sts-sdk v1.0.1-20250319171100-c6b279f3bac + yuanrong.org/kernel/runtime v1.0.0 + huaweicloud.com/containers/security/cbb_adapt v1.0.7 + k8s.io/api v0.31.2 + k8s.io/apimachinery v0.31.2 + k8s.io/client-go v0.31.2 +) + +require ( + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cenkalti/backoff/v4 v4.2.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/coreos/go-semver v0.3.0 // indirect + github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/emicklei/go-restful/v3 v3.11.0 // indirect + github.com/fxamacker/cbor/v2 v2.7.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-openapi/jsonpointer v0.19.6 // indirect + github.com/go-openapi/jsonreference v0.20.2 // indirect + github.com/go-openapi/swag v0.22.4 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.20.0 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v4 v4.5.0 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/gnostic-models v0.6.8 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/google/gofuzz v1.2.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/klauspost/compress v1.17.11 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/prometheus/client_model v0.6.0 // indirect + github.com/prometheus/common v0.42.0 // indirect + github.com/prometheus/procfs v0.10.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect + github.com/tjfoc/gmsm v1.4.1 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/x448/float16 v0.8.4 // indirect + go.etcd.io/etcd/client/pkg/v3 v3.5.11 // indirect + go.opentelemetry.io/otel/metric v1.24.0 // indirect + go.opentelemetry.io/proto/otlp v1.1.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + golang.org/x/oauth2 v0.22.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.27.0 // indirect + golang.org/x/term v0.21.0 // indirect + golang.org/x/text v0.20.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect + gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect + gopkg.in/inf.v0 v0.9.1 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + k8s.io/klog/v2 v2.130.1 // indirect + k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 // indirect + k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 // indirect + sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect + sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect + sigs.k8s.io/yaml v1.4.0 // indirect +) + +replace ( + github.com/agiledragon/gomonkey => github.com/agiledragon/gomonkey v2.0.1+incompatible + github.com/asaskevich/govalidator/v11 => github.com/asaskevich/govalidator/v11 v11.0.1-0.20250122183457-e11347878e23 + github.com/fsnotify/fsnotify => github.com/fsnotify/fsnotify v1.7.0 + // for test or internal use + github.com/gin-gonic/gin => github.com/gin-gonic/gin v1.10.0 + github.com/olekukonko/tablewriter => github.com/olekukonko/tablewriter v0.0.5 + github.com/operator-framework/operator-lib => github.com/operator-framework/operator-lib v0.4.0 + github.com/prashantv/gostub => github.com/prashantv/gostub v1.0.0 + github.com/robfig/cron/v3 => github.com/robfig/cron/v3 v3.0.1 + github.com/smartystreets/goconvey => github.com/smartystreets/goconvey v1.6.4 + github.com/spf13/cobra => github.com/spf13/cobra v1.8.1 + github.com/stretchr/testify => github.com/stretchr/testify v1.5.1 + github.com/valyala/fasthttp => github.com/valyala/fasthttp v1.58.0 + go.etcd.io/etcd/api/v3 => go.etcd.io/etcd/api/v3 v3.5.11 + go.etcd.io/etcd/client/v3 => go.etcd.io/etcd/client/v3 v3.5.11 + go.opentelemetry.io/otel => go.opentelemetry.io/otel v1.24.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace => go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.24.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc => go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0 + go.opentelemetry.io/otel/sdk => go.opentelemetry.io/otel/sdk v1.24.0 + go.opentelemetry.io/otel/trace => go.opentelemetry.io/otel/trace v1.24.0 + go.uber.org/zap => go.uber.org/zap v1.27.0 + golang.org/x/crypto => golang.org/x/crypto v0.24.0 + // affects VPC plugin building, will cause error if not pinned + golang.org/x/net => golang.org/x/net v0.26.0 + golang.org/x/sync => golang.org/x/sync v0.0.0-20190423024810-112230192c58 + golang.org/x/sys => golang.org/x/sys v0.21.0 + golang.org/x/text => golang.org/x/text v0.16.0 + golang.org/x/time => golang.org/x/time v0.10.0 + google.golang.org/genproto => google.golang.org/genproto v0.0.0-20230526203410-71b5a4ffd15e + google.golang.org/genproto/googleapis/rpc => google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d + google.golang.org/grpc => google.golang.org/grpc v1.67.0 + google.golang.org/protobuf => google.golang.org/protobuf v1.36.6 + gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1 + yuanrong.org/kernel/runtime => ../../../api/go + k8s.io/api => k8s.io/api v0.31.2 + k8s.io/apimachinery => k8s.io/apimachinery v0.31.2 + k8s.io/client-go => k8s.io/client-go v0.31.2 +) diff --git a/yuanrong/pkg/common/httputil/config/adminconfig.go b/yuanrong/pkg/common/httputil/config/adminconfig.go new file mode 100644 index 0000000..5f6d837 --- /dev/null +++ b/yuanrong/pkg/common/httputil/config/adminconfig.go @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package admin config +package config + +var AdminConf *CliConfig + +// CliConfig config parameter structure +type CliConfig struct { + AdminHost string `json:"adminHost"` +} + +func InitAdminConf(adminHost string) error { + AdminConf = &CliConfig{ + AdminHost: adminHost, + } + return nil +} diff --git a/yuanrong/pkg/common/httputil/http/client/client.go b/yuanrong/pkg/common/httputil/http/client/client.go new file mode 100644 index 0000000..5dcac65 --- /dev/null +++ b/yuanrong/pkg/common/httputil/http/client/client.go @@ -0,0 +1,116 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package client is define interface of client +package client + +import ( + "crypto/tls" + "net" + "net/http" + "sync" + "time" + + fhttp "github.com/valyala/fasthttp" + + "yuanrong/pkg/common/faas_common/logger/log" + shttp "yuanrong/pkg/common/httputil/http" + "yuanrong/pkg/common/httputil/http/client/fast" +) + +const ( + // 默认最大重试次数 + defaultMaxRetryTimes = 3 + + // MaxClientConcurrency is the max concurrency of fast http client + MaxClientConcurrency = 1000 + + // DialTimeOut - + DialTimeOut = 10 + + // TCPKeepAlivePeriod - + TCPKeepAlivePeriod = 10 +) + +var tcpDialer = fhttp.TCPDialer{Concurrency: MaxClientConcurrency} + +var globalTLSConf *tls.Config + +// Client 客户端接口 +type Client interface { + PostMultipart(url string, params map[string]string, + headers map[string]string, filePath string) (*shttp.SuccessResponse, error) + Get(url string, headers map[string]string) (*shttp.SuccessResponse, error) + PutMultipart(url string, params map[string]string, + headers map[string]string, filePath string) (*shttp.SuccessResponse, error) +} + +func adminDial(addr string) (net.Conn, error) { + conn, err := tcpDialer.DialTimeout(addr, DialTimeOut*time.Second) + if err != nil { + log.GetLogger().Errorf("failed to dial %s, error: %s ", addr, err.Error()) + return nil, err + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + log.GetLogger().Errorf("failed to dial %s", addr) + return nil, nil + } + err = tcpConn.SetKeepAlive(true) + if err != nil { + log.GetLogger().Errorf("failed to set connection keepalive %s, error: %s", addr, err.Error()) + return nil, err + } + + err = tcpConn.SetKeepAlivePeriod(TCPKeepAlivePeriod * time.Second) + if err != nil { + log.GetLogger().Errorf("failed to set connection keepalive period %s, error: %s", + addr, err.Error()) + return nil, err + } + + return tcpConn, nil +} + +// newClient 创建client +func newClient(tlsConf *tls.Config) Client { + cli := &fast.FastClient{ + Client: &fhttp.Client{ + TLSConfig: tlsConf, + MaxIdemponentCallAttempts: defaultMaxRetryTimes, + ReadBufferSize: http.DefaultMaxHeaderBytes, + Dial: adminDial, + }} + return cli +} + +var once sync.Once +var client Client + +// GetInstance get client instance +func GetInstance() Client { + once.Do(func() { + client = newClient(globalTLSConf) + }) + + return client +} + +// InitTlsConf init tls conf +func InitTlsConf(tlsConf *tls.Config) { + globalTLSConf = tlsConf +} diff --git a/yuanrong/pkg/common/httputil/http/client/fast/client.go b/yuanrong/pkg/common/httputil/http/client/fast/client.go new file mode 100644 index 0000000..2482003 --- /dev/null +++ b/yuanrong/pkg/common/httputil/http/client/fast/client.go @@ -0,0 +1,188 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package fast is fasthttp implementation of client +package fast + +import ( + "bytes" + "errors" + "io" + "mime/multipart" + "os" + "path" + "strconv" + "time" + + fhttp "github.com/valyala/fasthttp" + + "yuanrong/pkg/common/constants" + "yuanrong/pkg/common/httputil/http" + "yuanrong/pkg/common/httputil/utils" + "yuanrong/pkg/common/snerror" + "yuanrong/pkg/common/uuid" +) + +// FastClient fasthttp implement +type FastClient struct { + Client *fhttp.Client +} + +const ( + DefaultResponseHeadersSize = 16 + DeployTimeout = 90 + Base = 10 +) + +func setRequestHeaders(request *fhttp.Request, headers map[string]string) { + request.Header.Set(constants.HeaderTraceID, uuid.New().String()) + for key, value := range headers { + request.Header.Set(key, value) + } +} + +// ParseFastResponse parse fhttp Response +func ParseFastResponse(response *fhttp.Response) (*http.SuccessResponse, error) { + if response.StatusCode() == fhttp.StatusInternalServerError { + // The call fails and the returned status code is 500, and the body contains the returned error message + return nil, snerror.ConvertBadResponse(response.Body()) + } + if response.StatusCode() == fhttp.StatusOK { + // The call is successful and the returned status code is 200 The body contains the returned information + successResponse := &http.SuccessResponse{ + Body: response.Body(), + Headers: getResponseHeaders(response), + } + return successResponse, nil + } + // Other error codes return error information + return nil, errors.New(fhttp.StatusMessage(response.StatusCode())) +} + +func getResponseHeaders(response *fhttp.Response) map[string]string { + headers := make(map[string]string, DefaultResponseHeadersSize) + response.Header.VisitAll(func(key, value []byte) { + headers[string(key)] = string(value) + }) + return headers +} + +// ProcessMultipartRequestParams process multipart request params into fhttp request +func ProcessMultipartRequestParams(request *fhttp.Request, params map[string]string, + bodyWriter *multipart.Writer, bodyBuffer *bytes.Buffer) (*fhttp.Request, error) { + for key, val := range params { + if err := bodyWriter.WriteField(key, val); err != nil { + return nil, err + } + } + if err := bodyWriter.Close(); err != nil { + return nil, err + } + contentType := bodyWriter.FormDataContentType() + request.Header.SetContentType(contentType) + request.SetBody(bodyBuffer.Bytes()) + return request, nil +} + +func (fast *FastClient) processMultipartRequest(request *fhttp.Request, params map[string]string, + filePath string) (*fhttp.Request, error) { + fileSize := utils.GetFileSize(filePath) + request.Header.Set(http.HeaderContentType, http.Multipart) + request.Header.Set(http.HeaderFileDigest, strconv.FormatInt(fileSize, Base)) + request.SetBodyString(strconv.FormatInt(fileSize, Base)) + + bodyBuffer := &bytes.Buffer{} + bodyWriter := multipart.NewWriter(bodyBuffer) + if err := writeFile(bodyWriter, filePath); err != nil { + return nil, err + } + return ProcessMultipartRequestParams(request, params, bodyWriter, bodyBuffer) +} + +func writeFile(bodyWriter *multipart.Writer, filePath string) error { + var ( + fileWriter io.Writer + err error + ) + + fileWriter, err = bodyWriter.CreateFormFile("file", path.Base(filePath)) + if err != nil { + return err + } + file, err := os.Open(filePath) + if err != nil { + return err + } + defer file.Close() + _, err = io.Copy(fileWriter, file) + if err != nil { + return err + } + return nil +} + +// PostMultipart PostMultipart request +func (fast *FastClient) PostMultipart(url string, params map[string]string, + headers map[string]string, filePath string) (*http.SuccessResponse, error) { + request := fhttp.AcquireRequest() + response := fhttp.AcquireResponse() + setRequestHeaders(request, headers) + request.SetRequestURI(url) + request.Header.SetMethod(fhttp.MethodPost) + request, err := fast.processMultipartRequest(request, params, filePath) + if err != nil { + return nil, err + } + fast.Client.ReadTimeout = DeployTimeout * time.Second + if err := fast.Client.DoTimeout(request, response, DeployTimeout*time.Second); err != nil { + return nil, err + } + return ParseFastResponse(response) +} + +// Get Get request +func (fast *FastClient) Get(url string, headers map[string]string) (*http.SuccessResponse, error) { + request := fhttp.AcquireRequest() + response := fhttp.AcquireResponse() + setRequestHeaders(request, headers) + request.Header.Set(http.HeaderContentType, http.ApplicationJSONUTF8) + request.Header.SetMethod(fhttp.MethodGet) + request.SetRequestURI(url) + + if err := fast.Client.Do(request, response); err != nil { + return nil, err + } + return ParseFastResponse(response) +} + +// PutMultipart PutMultipart request +func (fast *FastClient) PutMultipart(url string, params map[string]string, headers map[string]string, + filePath string) (*http.SuccessResponse, error) { + request := fhttp.AcquireRequest() + response := fhttp.AcquireResponse() + setRequestHeaders(request, headers) + request.SetRequestURI(url) + request.Header.SetMethod(fhttp.MethodPut) + request, err := fast.processMultipartRequest(request, params, filePath) + if err != nil { + return nil, err + } + fast.Client.ReadTimeout = DeployTimeout * time.Second + if err := fast.Client.DoTimeout(request, response, DeployTimeout*time.Second); err != nil { + return nil, err + } + return ParseFastResponse(response) +} diff --git a/yuanrong/pkg/common/httputil/http/client/fast/client_test.go b/yuanrong/pkg/common/httputil/http/client/fast/client_test.go new file mode 100644 index 0000000..f9977bd --- /dev/null +++ b/yuanrong/pkg/common/httputil/http/client/fast/client_test.go @@ -0,0 +1,56 @@ +package fast + +import ( + "encoding/json" + "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" + "testing" + + "yuanrong/pkg/common/httputil/http" + "yuanrong/pkg/common/snerror" +) + +func Test_parseFastResponse(t *testing.T) { + response1 := &fasthttp.Response{} + badResponse := snerror.BadResponse{ + Code: 0, + Message: "500 error", + } + bytes, _ := json.Marshal(badResponse) + response1.SetStatusCode(fasthttp.StatusInternalServerError) + response1.SetBody(bytes) + + response2 := &fasthttp.Response{} + + response3 := &fasthttp.Response{} + response3.SetStatusCode(fasthttp.StatusBadRequest) + + tests := []struct { + name string + response *fasthttp.Response + want *http.SuccessResponse + wantErr bool + }{ + { + name: "test 500", + response: response1, + wantErr: true, + }, + { + name: "test 200", + response: response2, + wantErr: false, + }, + { + name: "test 400", + response: response3, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseFastResponse(tt.response) + assert.Equal(t, err != nil, tt.wantErr) + }) + } +} diff --git a/yuanrong/pkg/common/httputil/http/const.go b/yuanrong/pkg/common/httputil/http/const.go new file mode 100644 index 0000000..8764a87 --- /dev/null +++ b/yuanrong/pkg/common/httputil/http/const.go @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package http is http api +package http + +// content +const ( + Multipart = "multipart/form-data" + ApplicationJSONUTF8 = "application/json;charset=UTF-8" +) + +// Header +const ( + HeaderFileDigest = "x-file-digest" + HeaderStorageType = "x-storage-type" + HeaderAuthorization = "authorization" + HeaderContentType = "Content-Type" +) diff --git a/yuanrong/pkg/common/httputil/http/type.go b/yuanrong/pkg/common/httputil/http/type.go new file mode 100644 index 0000000..e2563a2 --- /dev/null +++ b/yuanrong/pkg/common/httputil/http/type.go @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package http is http +package http + +// SuccessResponse httpCode 200 成功响应结构体 +type SuccessResponse struct { + Body []byte + Headers map[string]string +} + +// Req 登陆请求 +type Req struct { + UserName string `json:"username"` + Password string `json:"password"` +} + +// Response login响应结构 +type Response struct { + Token string `json:"token"` +} diff --git a/yuanrong/pkg/common/httputil/utils/file.go b/yuanrong/pkg/common/httputil/utils/file.go new file mode 100644 index 0000000..95f8dca --- /dev/null +++ b/yuanrong/pkg/common/httputil/utils/file.go @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "os" + "path/filepath" + + "yuanrong/pkg/common/reader" +) + +// Exists exists Whether the path exists +func Exists(path string) bool { + if _, err := filepath.Abs(path); err != nil { + return false + } + + if _, err := reader.ReadFileInfoWithTimeout(path); err != nil { + if os.IsExist(err) { + return true + } + return false + } + + return true +} + +// GetFileSize 获取文件大小 +func GetFileSize(path string) int64 { + if !Exists(path) { + return 0 + } + fileInfo, err := reader.ReadFileInfoWithTimeout(path) + if err != nil { + return 0 + } + return fileInfo.Size() +} diff --git a/yuanrong/pkg/common/httputil/utils/file_test.go b/yuanrong/pkg/common/httputil/utils/file_test.go new file mode 100644 index 0000000..bc4712a --- /dev/null +++ b/yuanrong/pkg/common/httputil/utils/file_test.go @@ -0,0 +1,44 @@ +package utils + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExists(t *testing.T) { + var cases = []struct { + in string // input + expected bool // expected result + }{ + {"", false}, + {"./file.go", true}, + {"./notexists", false}, + {"/%$&*", false}, + } + for _, c := range cases { + actual := Exists(c.in) + if actual != c.expected { + t.Errorf("Exists(%s) = %v; expected %v", c.in, actual, c.expected) + } + } +} + +func TestGetFileSize(t *testing.T) { + ioutil.WriteFile("./test.txt", []byte("test"), 0666) + var cases = []struct { + in string // input + expectSize int64 // expected result + }{ + {"./test.txt", 4}, + {"./test1.txt", 0}, + } + for _, c := range cases { + + size := GetFileSize(c.in) + assert.Equal(t, size, c.expectSize) + } + os.Remove("./test.txt") +} diff --git a/yuanrong/pkg/common/httputil/utils/utils.go b/yuanrong/pkg/common/httputil/utils/utils.go new file mode 100644 index 0000000..2dfca80 --- /dev/null +++ b/yuanrong/pkg/common/httputil/utils/utils.go @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import "github.com/gin-gonic/gin" + +// ParseHeader 解析请求头 +func ParseHeader(ctx *gin.Context) map[string]string { + if ctx == nil || ctx.Request == nil || len(ctx.Request.Header) == 0 { + return map[string]string{} + } + headers := make(map[string]string) + for key, values := range ctx.Request.Header { + if len(values) > 0 { + headers[key] = values[0] + } + } + return headers +} diff --git a/yuanrong/pkg/common/httputil/utils/utils_test.go b/yuanrong/pkg/common/httputil/utils/utils_test.go new file mode 100644 index 0000000..1d2b988 --- /dev/null +++ b/yuanrong/pkg/common/httputil/utils/utils_test.go @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "net/http" + "testing" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" +) + +func TestParseHeader(t *testing.T) { + convey.Convey("test ParseHeader", t, func() { + convey.Convey("when header is empty", func() { + var ctx *gin.Context + result := ParseHeader(ctx) + convey.So(result, convey.ShouldBeEmpty) + }) + + convey.Convey("when header is not empty", func() { + ctx := &gin.Context{ + Request: &http.Request{ + Header: map[string][]string{ + "aa": {"bb", "cc"}, + }, + }, + } + result := ParseHeader(ctx) + convey.So(result, convey.ShouldNotBeEmpty) + convey.So(len(result), convey.ShouldEqual, 1) + }) + }) +} diff --git a/yuanrong/pkg/common/job/config.go b/yuanrong/pkg/common/job/config.go new file mode 100644 index 0000000..16be891 --- /dev/null +++ b/yuanrong/pkg/common/job/config.go @@ -0,0 +1,63 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package job - +package job + +import "yuanrong/pkg/common/constants" + +// 处理job的对外接口 +const ( + PathParamSubmissionId = "submissionId" + PathGroupJobs = "/api/jobs" + PathGetJobs = constants.DynamicRouterParamPrefix + PathParamSubmissionId + PathDeleteJobs = constants.DynamicRouterParamPrefix + PathParamSubmissionId + PathStopJobs = constants.DynamicRouterParamPrefix + PathParamSubmissionId + "/stop" +) + +const ( + submissionIdPattern = "^[a-z0-9-]{1,64}$" + jobIDPrefix = "app-" + tenantIdKey = "tenantId" +) + +// Response - +type Response struct { + Code int `form:"code" json:"code"` + Message string `form:"message" json:"message"` + Data []byte `form:"data" json:"data"` +} + +// SubmitRequest is SubmitRequest struct +type SubmitRequest struct { + Entrypoint string `form:"entrypoint" json:"entrypoint"` + SubmissionId string `form:"submission_id" json:"submission_id"` + RuntimeEnv *RuntimeEnv `form:"runtime_env" json:"runtime_env" valid:"optional"` + Metadata map[string]string `form:"metadata" json:"metadata" valid:"optional"` + Labels string `form:"labels" json:"labels" valid:"optional"` + CreateOptions map[string]string `form:"createOptions" json:"createOptions" valid:"optional"` + EntrypointResources map[string]float64 `form:"entrypoint_resources" json:"entrypoint_resources" valid:"optional"` + EntrypointNumCpus float64 `form:"entrypoint_num_cpus" json:"entrypoint_num_cpus" valid:"optional"` + EntrypointNumGpus float64 `form:"entrypoint_num_gpus" json:"entrypoint_num_gpus" valid:"optional"` + EntrypointMemory int `form:"entrypoint_memory" json:"entrypoint_memory" valid:"optional"` +} + +// RuntimeEnv args of invoking create_app +type RuntimeEnv struct { + WorkingDir string `form:"working_dir" json:"working_dir" valid:"optional"` + Pip []string `form:"pip" json:"pip" valid:"optional" ` + EnvVars map[string]string `form:"env_vars" json:"env_vars" valid:"optional"` +} diff --git a/yuanrong/pkg/common/job/handler.go b/yuanrong/pkg/common/job/handler.go new file mode 100644 index 0000000..a519a81 --- /dev/null +++ b/yuanrong/pkg/common/job/handler.go @@ -0,0 +1,297 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package job - +package job + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "net/http" + "regexp" + "strings" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + + "yuanrong/pkg/common/constants" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/httputil/utils" + "yuanrong/pkg/common/uuid" +) + +// SubmitJobHandleReq - +func SubmitJobHandleReq(ctx *gin.Context) *SubmitRequest { + traceID := ctx.Request.Header.Get(constant.HeaderTraceID) + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + var req SubmitRequest + if err := ctx.ShouldBind(&req); err != nil { + logger.Errorf("shouldBind SubmitJob request failed, err: %s", err) + ctx.JSON(http.StatusBadRequest, fmt.Sprintf("shouldBind SubmitJob request failed, err: %v", err)) + return nil + } + err := req.CheckField() + if err != nil { + ctx.JSON(http.StatusBadRequest, err.Error()) + return nil + } + req.EntrypointNumCpus = math.Ceil(req.EntrypointNumCpus * constants.CpuUnitConvert) + req.EntrypointMemory = + int(math.Ceil(float64(req.EntrypointMemory) / constants.MemoryUnitConvert / constants.MemoryUnitConvert)) + reqHeader := utils.ParseHeader(ctx) + if tenantId, ok := reqHeader[constants.HeaderTenantId]; ok { + req.AddCreateOptions(tenantIdKey, tenantId) + } + if labels, ok := reqHeader[constants.HeaderPoolLabel]; ok { + req.Labels = labels + } + logger.Debugf("SubmitJob createApp start, req:%#v", req) + return &req +} + +// SubmitJobHandleRes - +// SubmitJob godoc +// @Summary submit job +// @Description submit a new job +// @Accept json +// @Produce json +// @Router /api/jobs [POST] +// @Param SubmitRequest body SubmitRequest true "提交job时定义的job信息。" +// @Success 200 {object} map[string]string "提交job成功,返回该job的submission_id" +// @Failure 400 {string} string "用户请求错误,包含错误信息" +// @Failure 404 {string} string "该job已经存在" +// @Failure 500 {string} string "服务器处理错误,包含错误信息" +func SubmitJobHandleRes(ctx *gin.Context, resp Response) { + if resp.Code != http.StatusOK || resp.Message != "" { + ctx.JSON(resp.Code, resp.Message) + return + } + var result map[string]string + err := json.Unmarshal(resp.Data, &result) + if err != nil { + ctx.JSON(http.StatusBadRequest, + fmt.Sprintf("unmarshal response data failed, data: %v", resp.Data)) + return + } + ctx.JSON(http.StatusOK, result) + log.GetLogger().Debugf("SubmitJobHandleRes succeed, submission_id: %s", result) +} + +// ListJobsHandleRes - +// ListJobs godoc +// @Summary List Jobs +// @Description list jobs with jobInfo +// @Accept json +// @Produce json +// @Router /api/jobs [GET] +// @Success 200 {array} constant.AppInfo "返回所有jobs的信息" +// @Failure 500 {string} string "服务器处理错误,包含错误信息" +func ListJobsHandleRes(ctx *gin.Context, resp Response) { + traceID := ctx.Request.Header.Get(constant.HeaderTraceID) + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + if resp.Code != http.StatusOK || resp.Message != "" { + ctx.JSON(resp.Code, resp.Message) + return + } + var result []*constant.AppInfo + err := json.Unmarshal(resp.Data, &result) + if err != nil { + ctx.JSON(http.StatusBadRequest, + fmt.Sprintf("unmarshal response data failed, data: %v", resp.Data)) + return + } + ctx.JSON(http.StatusOK, result) + logger.Debugf("ListJobsHandleRes succeed") +} + +// GetJobInfoHandleRes - +// GetJobInfo godoc +// @Summary Get JobInfo +// @Description get jobInfo by submission_id +// @Accept json +// @Produce json +// @Router /api/jobs/{submissionId} [GET] +// @Param submissionId path string true "job的submission_id,以'app-'开头" +// @Success 200 {object} constant.AppInfo "返回submission_id对应的job信息" +// @Failure 404 {string} string "该job不存在" +// @Failure 500 {string} string "服务器处理错误,包含错误信息" +func GetJobInfoHandleRes(ctx *gin.Context, resp Response) { + submissionId := ctx.Param(PathParamSubmissionId) + logger := log.GetLogger().With(zap.Any("SubmissionId", submissionId)) + if resp.Code != http.StatusOK || resp.Message != "" { + ctx.JSON(resp.Code, resp.Message) + return + } + var result *constant.AppInfo + err := json.Unmarshal(resp.Data, &result) + if err != nil { + ctx.JSON(http.StatusBadRequest, + fmt.Sprintf("unmarshal response data failed, data: %v", resp.Data)) + return + } + ctx.JSON(http.StatusOK, result) + logger.Debugf("GetJobInfoHandleRes succeed") +} + +// DeleteJobHandleRes - +// DeleteJob godoc +// @Summary Delete Job +// @Description delete job by submission_id +// @Accept json +// @Produce json +// @Router /api/jobs/{submissionId} [DELETE] +// @Param submissionId path string true "job的submission_id,以'app-'开头" +// @Success 200 {boolean} bool "返回true则说明可以删除对应的job,返回false则说明无法删除job" +// @Failure 403 {string} string "禁止删除job,包含错误信息和job运行状态" +// @Failure 404 {string} string "该job不存在" +// @Failure 500 {string} string "服务器处理错误,包含错误信息" +func DeleteJobHandleRes(ctx *gin.Context, resp Response) { + submissionId := ctx.Param(PathParamSubmissionId) + logger := log.GetLogger().With(zap.Any("SubmissionId", submissionId)) + if resp.Code == http.StatusForbidden { + log.GetLogger().Errorf("forbidden to delete, status: %s", resp.Data) + ctx.JSON(http.StatusOK, false) + return + } + if resp.Code != http.StatusOK || resp.Message != "" { + ctx.JSON(resp.Code, resp.Message) + return + } + ctx.JSON(http.StatusOK, true) + logger.Debugf("DeleteJobHandleRes succeed") +} + +// StopJobHandleRes - +// StopJob godoc +// @Summary Stop Job +// @Description stop job by submission_id +// @Accept json +// @Produce json +// @Router /api/jobs/{submissionId}/stop [POST] +// @Param submissionId path string true "job的submission_id,以'app-'开头" +// @Success 200 {boolean} bool "返回true表示可以停止运行对应的job,返回false表示job当前状态不能被停止" +// @Failure 403 {string} string "禁止删除job,包含错误信息和job运行状态" +// @Failure 404 {string} string "该job不存在" +// @Failure 500 {string} string "服务器处理错误,包含错误信息" +func StopJobHandleRes(ctx *gin.Context, resp Response) { + submissionId := ctx.Param(PathParamSubmissionId) + logger := log.GetLogger().With(zap.Any("SubmissionId", submissionId)) + if resp.Code == http.StatusForbidden { + log.GetLogger().Errorf("forbidden to stop job, status: %s", resp.Data) + ctx.JSON(http.StatusOK, false) + return + } + if resp.Code != http.StatusOK || resp.Message != "" { + ctx.JSON(resp.Code, resp.Message) + return + } + ctx.JSON(http.StatusOK, true) + logger.Debugf("StopJobHandleRes succeed") +} + +// CheckField - +func (req *SubmitRequest) CheckField() error { + if req.Entrypoint == "" { + log.GetLogger().Errorf("entrypoint should not be empty") + return fmt.Errorf("entrypoint should not be empty") + } + if req.RuntimeEnv == nil || req.RuntimeEnv.WorkingDir == "" { + log.GetLogger().Errorf("runtime_env.working_dir should not be empty") + return fmt.Errorf("runtime_env.working_dir should not be empty") + } + if err := req.ValidateResources(); err != nil { + log.GetLogger().Errorf("validateResources error: %s", err.Error()) + return err + } + if err := req.CheckSubmissionId(); err != nil { + log.GetLogger().Errorf("chechk submission_id: %s, error: %s", req.SubmissionId, err.Error()) + return err + } + return nil +} + +// ValidateResources - +func (req *SubmitRequest) ValidateResources() error { + if req.EntrypointNumCpus < 0 { + return errors.New("entrypoint_num_cpus should not be less than 0") + } + if req.EntrypointNumGpus < 0 { + return errors.New("entrypoint_num_gpus should not be less than 0") + } + if req.EntrypointMemory < 0 { + return errors.New("entrypoint_memory should not be less than 0") + } + return nil +} + +// CheckSubmissionId - +func (req *SubmitRequest) CheckSubmissionId() error { + if req.SubmissionId == "" { + return nil + } + if strings.Contains(req.SubmissionId, "driver") { + return errors.New("submission_id should not contain 'driver'") + } + if !strings.HasPrefix(req.SubmissionId, jobIDPrefix) { + req.SubmissionId = jobIDPrefix + req.SubmissionId + } + isMatch, err := regexp.MatchString(submissionIdPattern, req.SubmissionId) + if err != nil || !isMatch { + return fmt.Errorf("regular expression validation error, submissionId: %s, pattern: %s, err: %v", + req.SubmissionId, submissionIdPattern, err) + } + return nil +} + +// NewSubmissionID - +func (req *SubmitRequest) NewSubmissionID() { + if req.SubmissionId == "" { + req.SubmissionId = jobIDPrefix + uuid.New().String() + } +} + +// AddCreateOptions - +func (req *SubmitRequest) AddCreateOptions(key, value string) { + if req.CreateOptions == nil { + req.CreateOptions = map[string]string{} + } + if key != "" { + req.CreateOptions[key] = value + } +} + +// BuildJobResponse - +func BuildJobResponse(data any, code int, err error) Response { + dataBytes, jsonErr := json.Marshal(data) + if jsonErr != nil { + return Response{ + Code: http.StatusInternalServerError, + Message: fmt.Sprintf("marshal job response failed, err: %v", jsonErr), + } + } + var resp Response + resp.Code = code + if data != nil { + resp.Data = dataBytes + } + if err != nil { + resp.Message = err.Error() + } + return resp +} diff --git a/yuanrong/pkg/common/job/handler_test.go b/yuanrong/pkg/common/job/handler_test.go new file mode 100644 index 0000000..eb5e637 --- /dev/null +++ b/yuanrong/pkg/common/job/handler_test.go @@ -0,0 +1,587 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package job + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/constants" + "yuanrong/pkg/common/faas_common/constant" +) + +func TestSubmitJobHandleReq(t *testing.T) { + convey.Convey("test DeleteJobHandleRes", t, func() { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + bodyBytes, _ := json.Marshal(SubmitRequest{ + Entrypoint: "", + SubmissionId: "", + RuntimeEnv: &RuntimeEnv{ + WorkingDir: "", + Pip: []string{""}, + EnvVars: map[string]string{}, + }, + Metadata: map[string]string{}, + EntrypointResources: map[string]float64{}, + EntrypointNumCpus: 0, + EntrypointNumGpus: 0, + EntrypointMemory: 0, + }) + reader := bytes.NewBuffer(bodyBytes) + c.Request = &http.Request{ + Method: "POST", + URL: &url.URL{Path: PathGroupJobs}, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + constants.HeaderTenantID: []string{"123456"}, + constants.HeaderPoolLabel: []string{"abc"}, + }, + Body: io.NopCloser(reader), // 使用 io.NopCloser 包装 reader,使其满足 io.ReadCloser 接口 + } + convey.Convey("when process success", func() { + defer gomonkey.ApplyMethodFunc(&SubmitRequest{}, "CheckField", func() error { + return nil + }).Reset() + expectedResult := &SubmitRequest{ + Entrypoint: "", + SubmissionId: "", + RuntimeEnv: &RuntimeEnv{ + WorkingDir: "", + Pip: []string{""}, + EnvVars: map[string]string{}, + }, + Metadata: map[string]string{}, + Labels: "abc", + CreateOptions: map[string]string{ + "tenantId": "123456", + }, + EntrypointResources: map[string]float64{}, + EntrypointNumCpus: 0, + EntrypointNumGpus: 0, + EntrypointMemory: 0, + } + result := SubmitJobHandleReq(c) + convey.So(result, convey.ShouldResemble, expectedResult) + }) + convey.Convey("when CheckField failed", func() { + defer gomonkey.ApplyMethodFunc(&SubmitRequest{}, "CheckField", func() error { + return errors.New("failed CheckField") + }).Reset() + result := SubmitJobHandleReq(c) + convey.So(result, convey.ShouldBeNil) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldEqual, "\"failed CheckField\"") + }) + }) +} + +func TestSubmitJobHandleRes(t *testing.T) { + convey.Convey("test SubmitJobHandleRes", t, func() { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + resp := Response{ + Code: http.StatusOK, + Message: "", + Data: []byte("app-123"), + } + convey.Convey("when statusCode is "+strconv.Itoa(http.StatusNotFound), func() { + resp.Code = http.StatusNotFound + resp.Message = fmt.Sprintf("not found job") + SubmitJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusNotFound) + convey.So(w.Body.String(), convey.ShouldEqual, "\"not found job\"") + }) + convey.Convey("when statusCode is "+strconv.Itoa(http.StatusInternalServerError), func() { + resp.Code = http.StatusInternalServerError + resp.Message = fmt.Sprintf("failed get job") + SubmitJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(w.Body.String(), convey.ShouldEqual, "\"failed get job\"") + }) + convey.Convey("when response data is nil", func() { + resp.Data = nil + SubmitJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldStartWith, "\"unmarshal response data failed, data:") + }) + convey.Convey("when process success", func() { + marshal, err := json.Marshal(map[string]string{ + "submission_id": "app-123", + }) + resp.Data = marshal + convey.So(err, convey.ShouldBeNil) + SubmitJobHandleRes(c, resp) + expectedResult, err := json.Marshal(map[string]string{ + "submission_id": "app-123", + }) + if err != nil { + t.Errorf("marshal expected result failed, err: %v", err) + } + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(w.Body.String(), convey.ShouldResemble, string(expectedResult)) + }) + }) +} + +func TestListJobsHandleRes(t *testing.T) { + convey.Convey("test ListJobsHandleRes", t, func() { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = &http.Request{ + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + } + dataBytes, err := json.Marshal([]*constant.AppInfo{ + { + Type: "SUBMISSION", + Entrypoint: "python script.py", + SubmissionID: "app-123", + }, + }) + if err != nil { + t.Errorf("marshal expected result failed, err: %v", err) + } + resp := Response{ + Code: http.StatusOK, + Message: "", + Data: dataBytes, + } + convey.Convey("when statusCode is "+strconv.Itoa(http.StatusInternalServerError), func() { + resp.Code = http.StatusInternalServerError + resp.Message = fmt.Sprintf("failed get job") + ListJobsHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(w.Body.String(), convey.ShouldEqual, "\"failed get job\"") + }) + convey.Convey("when unmarshal response data failed", func() { + resp.Data = []byte(",aa,") + ListJobsHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldStartWith, "\"unmarshal response data failed") + }) + convey.Convey("when response data is nil", func() { + resp.Data = []byte("[]") + ListJobsHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(w.Body.String(), convey.ShouldEqual, "[]") + }) + convey.Convey("when process success", func() { + ListJobsHandleRes(c, resp) + expectedResult, err := json.Marshal([]*constant.AppInfo{ + { + Type: "SUBMISSION", + Entrypoint: "python script.py", + SubmissionID: "app-123", + }, + }) + if err != nil { + t.Errorf("marshal expected result failed, err: %v", err) + } + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(w.Body.String(), convey.ShouldResemble, string(expectedResult)) + }) + }) +} + +func TestGetJobInfoHandleRes(t *testing.T) { + convey.Convey("test GetJobInfoHandleRes", t, func() { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + dataBytes, err := json.Marshal(&constant.AppInfo{ + Type: "SUBMISSION", + Entrypoint: "python script.py", + SubmissionID: "app-123", + }) + if err != nil { + t.Errorf("marshal expected result failed, err: %v", err) + } + resp := Response{ + Code: http.StatusOK, + Message: "", + Data: dataBytes, + } + convey.Convey("when statusCode is "+strconv.Itoa(http.StatusNotFound), func() { + resp.Code = http.StatusNotFound + resp.Message = fmt.Sprintf("not found job") + GetJobInfoHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusNotFound) + convey.So(w.Body.String(), convey.ShouldEqual, "\"not found job\"") + }) + convey.Convey("when statusCode is "+strconv.Itoa(http.StatusInternalServerError), func() { + resp.Code = http.StatusInternalServerError + resp.Message = fmt.Sprintf("failed get job") + GetJobInfoHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(w.Body.String(), convey.ShouldEqual, "\"failed get job\"") + }) + convey.Convey("when unmarshal response data failed", func() { + resp.Data = []byte(",aa,") + GetJobInfoHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldStartWith, "\"unmarshal response data failed") + }) + convey.Convey("when response data is nil", func() { + resp.Data = nil + GetJobInfoHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldStartWith, "\"unmarshal response data failed") + }) + convey.Convey("when process success", func() { + GetJobInfoHandleRes(c, resp) + expectedResult, err := json.Marshal(&constant.AppInfo{ + Type: "SUBMISSION", + Entrypoint: "python script.py", + SubmissionID: "app-123", + }) + if err != nil { + t.Errorf("marshal expected result failed, err: %v", err) + } + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(w.Body.String(), convey.ShouldResemble, string(expectedResult)) + }) + }) +} + +func TestDeleteJobHandleRes(t *testing.T) { + convey.Convey("test DeleteJobHandleRes", t, func() { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + resp := Response{ + Code: http.StatusOK, + Message: "", + Data: []byte("SUCCEEDED"), + } + convey.Convey("when statusCode is "+strconv.Itoa(http.StatusForbidden), func() { + resp.Code = http.StatusForbidden + DeleteJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(w.Body.String(), convey.ShouldEqual, "false") + }) + convey.Convey("when statusCode is "+strconv.Itoa(http.StatusBadRequest), func() { + resp.Code = http.StatusBadRequest + resp.Message = fmt.Sprintf("failed delete job") + DeleteJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldEqual, "\"failed delete job\"") + }) + convey.Convey("when response data is nil", func() { + resp.Data = nil + DeleteJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(w.Body.String(), convey.ShouldEqual, "true") + }) + convey.Convey("when process success", func() { + DeleteJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(w.Body.String(), convey.ShouldEqual, "true") + }) + }) +} + +func TestStopJobHandleRes(t *testing.T) { + convey.Convey("test StopJobHandleRes", t, func() { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + resp := Response{ + Code: http.StatusOK, + Message: "", + Data: []byte(`SUCCEEDED`), + } + convey.Convey("when statusCode is "+strconv.Itoa(http.StatusForbidden), func() { + resp.Code = http.StatusForbidden + StopJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(w.Body.String(), convey.ShouldEqual, "false") + }) + convey.Convey("when statusCode is "+strconv.Itoa(http.StatusBadRequest), func() { + resp.Code = http.StatusBadRequest + resp.Message = fmt.Sprintf("failed stop job") + StopJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldEqual, "\"failed stop job\"") + }) + convey.Convey("when response data is nil", func() { + resp.Data = nil + StopJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(w.Body.String(), convey.ShouldEqual, "true") + }) + convey.Convey("when process success", func() { + StopJobHandleRes(c, resp) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(w.Body.String(), convey.ShouldEqual, "true") + }) + }) +} + +func TestSubmitRequest_CheckField(t *testing.T) { + convey.Convey("test (req *SubmitRequest) CheckField", t, func() { + req := &SubmitRequest{ + Entrypoint: "python script.py", + SubmissionId: "", + RuntimeEnv: &RuntimeEnv{ + WorkingDir: "file:///home/disk/tk/file.zip", + Pip: []string{"numpy==1.24", "scipy==1.11.0"}, + EnvVars: map[string]string{ + "SOURCE_REGION": "suzhou_std", + }, + }, + Metadata: map[string]string{ + "autoscenes_ids": "auto_1-test", + "task_type": "task_1", + "ttl": "1250", + }, + EntrypointResources: map[string]float64{ + "NPU": 0, + }, + EntrypointNumCpus: 0, + EntrypointNumGpus: 0, + EntrypointMemory: 0, + } + convey.Convey("when req.Entrypoint is empty", func() { + req.Entrypoint = "" + err := req.CheckField() + convey.So(err, convey.ShouldBeError, errors.New("entrypoint should not be empty")) + }) + convey.Convey("when req.RuntimeEnv is empty", func() { + req.RuntimeEnv = nil + err := req.CheckField() + convey.So(err, convey.ShouldBeError, errors.New("runtime_env.working_dir should not be empty")) + }) + convey.Convey("when req.RuntimeEnv.WorkingDir is empty", func() { + req.RuntimeEnv.WorkingDir = "" + err := req.CheckField() + convey.So(err, convey.ShouldBeError, errors.New("runtime_env.working_dir should not be empty")) + }) + convey.Convey("when ValidateResources failed", func() { + defer gomonkey.ApplyMethodFunc(&SubmitRequest{}, "ValidateResources", func() error { + return errors.New("failed ValidateResources") + }).Reset() + defer gomonkey.ApplyMethodFunc(&SubmitRequest{}, "CheckSubmissionId", func() error { + return nil + }).Reset() + err := req.CheckField() + convey.So(err, convey.ShouldBeError, errors.New("failed ValidateResources")) + }) + convey.Convey("when CheckSubmissionId failed", func() { + defer gomonkey.ApplyMethodFunc(&SubmitRequest{}, "ValidateResources", func() error { + return nil + }).Reset() + defer gomonkey.ApplyMethodFunc(&SubmitRequest{}, "CheckSubmissionId", func() error { + return errors.New("failed CheckSubmissionId") + }).Reset() + err := req.CheckField() + convey.So(err, convey.ShouldBeError, errors.New("failed CheckSubmissionId")) + }) + convey.Convey("when process success", func() { + defer gomonkey.ApplyMethodFunc(&SubmitRequest{}, "ValidateResources", func() error { + return nil + }).Reset() + defer gomonkey.ApplyMethodFunc(&SubmitRequest{}, "CheckSubmissionId", func() error { + return nil + }).Reset() + err := req.CheckField() + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestSubmitRequest_ValidateResources(t *testing.T) { + convey.Convey("test (req *SubmitRequest) ValidateResources()", t, func() { + req := &SubmitRequest{ + Entrypoint: "python script.py", + SubmissionId: "", + EntrypointResources: map[string]float64{ + "NPU": 0, + }, + EntrypointNumCpus: 0, + EntrypointNumGpus: 0, + EntrypointMemory: 0, + } + convey.Convey("when req.EntrypointNumCpus < 0", func() { + req.EntrypointNumCpus = -0.1 + err := req.ValidateResources() + convey.So(err.Error(), convey.ShouldEqual, "entrypoint_num_cpus should not be less than 0") + }) + convey.Convey("when req.EntrypointNumGpus < 0", func() { + req.EntrypointNumGpus = -0.1 + err := req.ValidateResources() + convey.So(err.Error(), convey.ShouldEqual, "entrypoint_num_gpus should not be less than 0") + }) + convey.Convey("when req.EntrypointMemory < 0", func() { + req.EntrypointMemory = -1 + err := req.ValidateResources() + convey.So(err.Error(), convey.ShouldEqual, "entrypoint_memory should not be less than 0") + }) + convey.Convey("when process success", func() { + err := req.ValidateResources() + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestSubmitRequest_CheckSubmissionId(t *testing.T) { + convey.Convey("test (req *SubmitRequest) CheckSubmissionId()", t, func() { + req := &SubmitRequest{ + Entrypoint: "python script.py", + SubmissionId: "123", + } + convey.Convey("when req.SubmissionId is empty", func() { + req.SubmissionId = "" + err := req.CheckSubmissionId() + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("when req.SubmissionId start with driver", func() { + req.SubmissionId = "driver-123" + err := req.CheckSubmissionId() + convey.So(err.Error(), convey.ShouldEqual, "submission_id should not contain 'driver'") + }) + convey.Convey("when req.SubmissionId doesn't start with 'app-'", func() { + err := req.CheckSubmissionId() + convey.So(err, convey.ShouldBeNil) + convey.So(req.SubmissionId, convey.ShouldStartWith, jobIDPrefix) + }) + convey.Convey("when req.SubmissionId length is 60 without 'app-'", func() { + req.SubmissionId = "023456781234567822345678323456784234567852345678623456787234" + err := req.CheckSubmissionId() + convey.So(err, convey.ShouldBeNil) + convey.So(req.SubmissionId, convey.ShouldStartWith, jobIDPrefix) + }) + convey.Convey("when req.SubmissionId length is more than 60 without 'app-'", func() { + req.SubmissionId = "0234567812345678223456783234567842345678523456786234567872345" + err := req.CheckSubmissionId() + convey.So(err.Error(), convey.ShouldStartWith, "regular expression validation error,") + convey.So(req.SubmissionId, convey.ShouldStartWith, jobIDPrefix) + }) + convey.Convey("when req.SubmissionId length is 64 with 'app-'", func() { + req.SubmissionId = "app-023456781234567822345678323456784234567852345678623456787234" + err := req.CheckSubmissionId() + convey.So(err, convey.ShouldBeNil) + convey.So(req.SubmissionId, convey.ShouldStartWith, jobIDPrefix) + }) + convey.Convey("when req.SubmissionId length is more than 64 with 'app-'", func() { + req.SubmissionId = "app-0234567812345678223456783234567842345678523456786234567872345" + err := req.CheckSubmissionId() + convey.So(err.Error(), convey.ShouldStartWith, "regular expression validation error,") + convey.So(req.SubmissionId, convey.ShouldStartWith, jobIDPrefix) + }) + convey.Convey("when process success", func() { + err := req.CheckSubmissionId() + convey.So(err, convey.ShouldBeNil) + convey.So(req.SubmissionId, convey.ShouldStartWith, jobIDPrefix) + }) + }) +} + +func TestSubmitRequest_NewSubmissionID(t *testing.T) { + convey.Convey("test (req *SubmitRequest) NewSubmissionID()", t, func() { + req := &SubmitRequest{ + Entrypoint: "python script.py", + SubmissionId: "", + } + convey.Convey("when req.SubmissionId is empty", func() { + req.NewSubmissionID() + convey.So(req.SubmissionId, convey.ShouldNotBeEmpty) + convey.So(req.SubmissionId, convey.ShouldStartWith, jobIDPrefix) + }) + }) +} + +func TestSubmitRequest_AddCreateOptions(t *testing.T) { + convey.Convey("test (req *SubmitRequest) AddCreateOptions()", t, func() { + req := &SubmitRequest{ + Entrypoint: "python script.py", + SubmissionId: "123", + } + convey.Convey("when req.CreateOptions is empty", func() { + req.AddCreateOptions("key", "value") + convey.So(len(req.CreateOptions), convey.ShouldEqual, 1) + }) + convey.Convey("when key is empty", func() { + req.AddCreateOptions("", "value") + convey.So(len(req.CreateOptions), convey.ShouldEqual, 0) + }) + convey.Convey("when key is not empty", func() { + req.AddCreateOptions("key", "value") + convey.So(len(req.CreateOptions), convey.ShouldEqual, 1) + }) + }) +} + +func TestBuildJobResponse(t *testing.T) { + convey.Convey("test BuildJobResponse", t, func() { + convey.Convey("when process success", func() { + expectedResult := Response{ + Code: 0, + Message: "", + Data: []byte("test"), + } + result := BuildJobResponse("test", 0, nil) + convey.So(result.Code, convey.ShouldEqual, expectedResult.Code) + convey.So(result.Message, convey.ShouldEqual, expectedResult.Message) + convey.So(string(result.Data), convey.ShouldEqual, "\""+string(expectedResult.Data)+"\"") + }) + convey.Convey("when data is nil", func() { + expectedResult := Response{ + Code: http.StatusOK, + Message: "", + Data: nil, + } + result := BuildJobResponse(nil, http.StatusOK, nil) + convey.So(result, convey.ShouldResemble, expectedResult) + }) + convey.Convey("when response status is "+strconv.Itoa(http.StatusBadRequest), func() { + expectedResult := Response{ + Code: http.StatusBadRequest, + Message: "error request", + Data: nil, + } + result := BuildJobResponse(nil, http.StatusBadRequest, errors.New("error request")) + convey.So(result, convey.ShouldResemble, expectedResult) + }) + convey.Convey("when data marshal failed", func() { + expectedResult := Response{ + Code: http.StatusInternalServerError, + Message: "marshal job response failed, err:", + } + result := BuildJobResponse(func() {}, http.StatusOK, nil) + convey.So(result.Code, convey.ShouldEqual, expectedResult.Code) + convey.So(result.Message, convey.ShouldStartWith, expectedResult.Message) + convey.So(result.Data, convey.ShouldBeNil) + }) + }) +} diff --git a/yuanrong/pkg/common/protobuf/adaptor.proto b/yuanrong/pkg/common/protobuf/adaptor.proto new file mode 100644 index 0000000..d70b2ab --- /dev/null +++ b/yuanrong/pkg/common/protobuf/adaptor.proto @@ -0,0 +1,65 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +import "invoke.proto"; +import "get.proto"; +import "filter.proto"; +import "deadlock.proto"; +import "readstate.proto"; +import "savestate.proto"; +import "settimeout.proto"; +import "terminate.proto"; +import "wait.proto"; + +package adaptor; + +option go_package = "./adaptor"; + +option java_package = "com.sn.msgstruct"; +option java_outer_classname = "CallRequest"; + +message CallRequestPb { + int32 method = 1; + int32 source = 2; + repeated string reqIDList = 3; + oneof data { + invoke.RequestPb invokeReq = 4; + invoke.TaskRequestPb invokeTaskReq = 5; + invoke.TaskCrateStateQueueRequestPb invokeTaskCreateStateReq = 6; + + invoke.OnCompleteRequestPb onCompleteReq = 7; + + filter.RequestPb filterReq = 8; + + getfuture.RequestPb getReq = 9; + getfuture.TaskRequestPb getTaskReq = 10; + + readstate.RequestPb readReq = 11; + + savestate.RequestPb saveReq = 12; + + settimeout.RequestPb setTimeoutReq = 13; + + terminate.RequestPb terminateReq = 14; + + waitfuture.RequestPb waitReq = 16; + + deadlock.StartRequestPb deadlockStartReq = 17; + deadlock.EndRequestPb deadlockEndReq = 18; + } +} diff --git a/yuanrong/pkg/common/protobuf/bus.proto b/yuanrong/pkg/common/protobuf/bus.proto new file mode 100644 index 0000000..a7ba6fd --- /dev/null +++ b/yuanrong/pkg/common/protobuf/bus.proto @@ -0,0 +1,132 @@ +syntax = "proto3"; + +option go_package = "./;pb"; + +message FunctionProperty { + string TenantID = 1; + string FunctionName = 2; + string FunctionVersion = 3; + string InstanceID = 4; + string StateID = 5; + string futureID = 7; +} + +message CallRequest { + // log type + string LogType = 1; + // request data + bytes RawData = 2; + // request ID + uint64 RequestID = 3; + // userData + map UserData = 4; + // Priority + int32 Priority = 5; + // cpu size + string CPU = 6; + // memory size + string Memory = 7; + // affinity info + AffinityInfoPb AffinityInfoPb = 8; + // group info + GroupInfo groupInfo = 9; + // resource MetaData + map ResourceMetaData = 10; +} + +message CallResponse { + // Success: 0, Failed for others + uint32 ErrorCode = 1; + // Message for error + string ErrorMessage = 2; + // response data + bytes RawData = 3; + // logs + string Logs = 4; + // request ID + uint64 RequestID = 5; + // summary + string Summary = 6; +} + +message Message { + enum MessageType { + REGISTER = 0; + CALL_REQUEST = 1; + CALL_RESPONSE = 2; + CANCEL_REQUEST = 3; + CANCEL_RESPONSE = 4; + GROUP_REQUEST = 5; + GROUP_RESPONSE = 6; + CHECKPOINT_REQUEST = 7; + CHECKPOINT_RESPONSE = 8; + } + message CallMessage { + string traceID = 1; + FunctionProperty dst = 2; + CallRequest req = 3; + } + message CancelMessage { + string traceID = 1; + uint64 RequestID = 2; + string futureID = 3; + bool force = 4; + bool recursive = 5; + string instanceID = 6; + } + message CancelResponse { + // Success: 0, Failed for others + uint32 ErrorCode = 1; + // Message for error + string ErrorMessage = 2; + } + message GroupRequest { + string traceID = 1; + string businessID = 2; + string tenantID = 3; + string groupID = 4; + string operation = 5; + bytes data = 6; + } + message GroupResponse { + string traceID = 1; + uint32 errorCode = 2; + string errorMessage = 3; + bytes result = 4; + } + message CheckpointRequest { + string traceID = 1; + string stateID = 2; + string instanceID = 3; + } + message CheckpointResponse { + uint32 errorCode = 1; + string errorMessage = 2; + bytes rawData = 3; + } + MessageType type = 1; + FunctionProperty registerMessage = 2; + CallMessage callMessage = 3; + CallResponse callResponse = 4; + CancelMessage cancelMessage = 5; + CancelResponse cancelResponse = 6; + GroupRequest groupRequest = 7; + GroupResponse groupResponse = 8; + CheckpointRequest checkpointRequest = 9; + CheckpointResponse checkpointResponse = 10; +} + +message AffinityInfoPb { + AffinityRequestPb AffinityRequest = 1; + string AffinityNode = 2; +} + +message AffinityRequestPb { + repeated string ObjectIDs = 1; + string Strategy = 2; +} + +message GroupInfo { + string groupID = 1; + string stackID = 2; +} diff --git a/yuanrong/pkg/common/protobuf/callMessage.proto b/yuanrong/pkg/common/protobuf/callMessage.proto new file mode 100644 index 0000000..2dcdb63 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/callMessage.proto @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package common; + +option go_package = "./common"; + +message CallRequest { + // The native event from user + bytes event = 1; + // "state" can be null + State state = 2; + string property = 4; +} + +message CallResult { + bytes result = 1; + // "state" can be null + State state = 2; + string property = 4; +} + +message State { + // State id + string id = 1; + // State content + string content = 2; +} diff --git a/yuanrong/pkg/common/protobuf/deadlock.proto b/yuanrong/pkg/common/protobuf/deadlock.proto new file mode 100644 index 0000000..0b30df9 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/deadlock.proto @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package deadlock; + +option go_package = "api/deadlock"; + +option java_package = "com.sn.msgstruct"; +option java_outer_classname = "Deadlock"; + +message ReqIDListPb { + repeated string reqIDs = 1; +} + +message StartRequestPb { + string stateID = 1; + string state = 2; + repeated ReqIDListPb futureReqIDLists = 3; + string tenantID = 4; + string serviceID = 5; +} + +message StartResponsePb { + int32 code = 1; + string message = 2; + string deadlockID = 3; +} + +message EndRequestPb { + string stateID = 1; + string deadlockID = 2; + string tenantID = 4; + string serviceID = 5; +} + +message EndResponsePb { + int32 code = 1; + string message = 2; + string state = 3; +} diff --git a/yuanrong/pkg/common/protobuf/error.proto b/yuanrong/pkg/common/protobuf/error.proto new file mode 100644 index 0000000..5888a6d --- /dev/null +++ b/yuanrong/pkg/common/protobuf/error.proto @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package common; + +option go_package = "./common"; + +message ErrorResponsePb { + int32 code = 1; + string message = 2; +} diff --git a/yuanrong/pkg/common/protobuf/filter.proto b/yuanrong/pkg/common/protobuf/filter.proto new file mode 100644 index 0000000..88ef39c --- /dev/null +++ b/yuanrong/pkg/common/protobuf/filter.proto @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package filter; + +option go_package = "api/filter"; + +option java_package = "com.sn.msgstruct"; +option java_outer_classname = "Filter"; + +message RequestPb { + string key = 1; +} + +message ResponsePb { + int32 code = 1; + string message = 2; + string stateID = 3; +} diff --git a/yuanrong/pkg/common/protobuf/get.proto b/yuanrong/pkg/common/protobuf/get.proto new file mode 100644 index 0000000..9bd8ff1 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/get.proto @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package getfuture; + +option go_package = "api/getfuture"; + +option java_package = "com.sn.msgstruct"; +option java_outer_classname = "Get"; + +message RequestPb { + string futureID = 1; + string stateID = 2; + string state = 3; +} + +message ResponsePb { + int32 code = 1; + string message = 2; + string result = 3; + string state = 4; + bool updateState = 5; +} + +message TaskRequestPb { + string futureID = 1; + string stateID = 2; +} + +message TaskResponsePb { + int32 code = 1; + string message = 2; + string result = 3; + string log = 4; +} diff --git a/yuanrong/pkg/common/protobuf/health/health_service.proto b/yuanrong/pkg/common/protobuf/health/health_service.proto new file mode 100644 index 0000000..e9dfce4 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/health/health_service.proto @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package health_service; + +option go_package = "grpc/pb/health;health"; + +// common service provides APIs +service HealthService { + // health check + rpc Readiness (ReadinessRequest) returns (ReadinessResponse) {} +} + +message ReadinessRequest { +} + +message ReadinessResponse { + string nodeID=1; +} diff --git a/yuanrong/pkg/common/protobuf/invoke.proto b/yuanrong/pkg/common/protobuf/invoke.proto new file mode 100644 index 0000000..cea1d14 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/invoke.proto @@ -0,0 +1,84 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package invoke; + +option go_package = "api/invoke"; + +option java_package = "com.sn.msgstruct"; +option java_outer_classname = "Invoke"; + +message RequestPb { + string funcName = 1; + string version = 2; + string payload = 3; + bool byKey = 4; + string key = 5; + string stateID = 6; + int32 countOfFuncName = 7; + map requestAttributes = 9; + map resourceMetaData = 10; + AffinityRequestPb affinityRequestPb = 11; + GroupInfo groupInfo = 12; + string invokeType = 13; +} + +message ResponsePb { + int32 code = 1; + string message = 2; + string stateID = 3; + string futureID = 4; +} + +message TaskRequestPb { + string funcName = 1; + string version = 2; + string payload = 3; + string stateID = 4; + string futureID = 5; + bool createStateQueue = 6; + string stateKey = 7; + string invokeType = 8; +} + +message TaskResponsePb { + int32 code = 1; + string message = 2; +} + +message TaskCrateStateQueueRequestPb { + string stateID = 1; +} + +message OnCompleteRequestPb { + string funcName = 1; + string version = 2; + string stateID = 3; + string futureID = 4; + int32 countOfFuncName = 5; +} + +message AffinityRequestPb { + repeated string objectIDs = 1; + string strategy = 2; +} + +message GroupInfo { + string groupID = 1; + string stackID = 2; +} diff --git a/yuanrong/pkg/common/protobuf/readstate.proto b/yuanrong/pkg/common/protobuf/readstate.proto new file mode 100644 index 0000000..94b7073 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/readstate.proto @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package readstate; + +option go_package = "api/readstate"; + +option java_package = "com.sn.msgstruct"; +option java_outer_classname = "ReadState"; + +message RequestPb { + bool byKey = 1; + string key = 2; + string stateID = 3; +} + +message ResponsePb { + int32 code = 1; + string message = 2; + string state = 3; +} diff --git a/yuanrong/pkg/common/protobuf/rpc/bus_service.proto b/yuanrong/pkg/common/protobuf/rpc/bus_service.proto new file mode 100644 index 0000000..2d1a37e --- /dev/null +++ b/yuanrong/pkg/common/protobuf/rpc/bus_service.proto @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package bus_service; + +import "common.proto"; + +option go_package = "grpc/pb/bus;bus"; + +// bus service provides APIs to runtime, +service BusService { + // notify bus to connect frontend + rpc DiscoverFrontend (DiscoverFrontendRequest) returns (DiscoverFrontendResponse) {} + // query instance info from frontend + rpc QueryInstance (QueryInstanceRequest) returns (QueryInstanceResponse) {} + // notify bus to connect driver + rpc DiscoverDriver (DiscoverDriverRequest) returns (DiscoverDriverResponse) {} +} + +message DiscoverDriverRequest { + string driverIP = 1; + string driverPort = 2; + string jobID = 3; +} + +message DiscoverDriverResponse {} + +message DiscoverFrontendRequest { + string frontendIP = 1; + string frontendPort = 2; +} + +message DiscoverFrontendResponse {} + +message QueryInstanceRequest { + string instanceID = 1; +} + +message QueryInstanceResponse { + common.ErrorCode code = 1; + string message = 2; + string status = 3; +} diff --git a/yuanrong/pkg/common/protobuf/rpc/common.proto b/yuanrong/pkg/common/protobuf/rpc/common.proto new file mode 100644 index 0000000..e72d4f5 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/rpc/common.proto @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package common; + +option go_package = "grpc/pb/common;common"; + +message Arg { + enum ArgType { + VALUE = 0; + OBJECT_REF = 1; + } + ArgType type = 1; + bytes value = 2; + repeated string nested_refs = 3; +} + +enum ErrorCode { + ERR_NONE = 0; + ERR_PARAM_INVALID = 1001; + ERR_RESOURCE_NOT_ENOUGH = 1002; + ERR_INSTANCE_NOT_FOUND = 1003; + ERR_INSTANCE_DUPLICATED = 1004; + ERR_INVOKE_RATE_LIMITED = 1005; + ERR_RESOURCE_CONFIG_ERROR = 1006; + ERR_INSTANCE_EXITED = 1007; + ERR_EXTENSION_META_ERROR = 1008; + ERR_USER_CODE_LOAD = 2001; + ERR_USER_FUNCTION_EXCEPTION = 2002; + ERR_REQUEST_BETWEEN_RUNTIME_BUS = 3001; + ERR_INNER_COMMUNICATION = 3002; + ERR_INNER_SYSTEM_ERROR = 3003; + ERR_DISCONNECT_FRONTEND_BUS = 3004; + ERR_ETCD_OPERATION_ERROR = 3005; + ERR_BUS_DISCONNECTION = 3006; + ERR_REDIS_OPERATION_ERROR = 3007; + ERR_NPU_FAULT_ERROR = 3016; +} diff --git a/yuanrong/pkg/common/protobuf/rpc/core_service.proto b/yuanrong/pkg/common/protobuf/rpc/core_service.proto new file mode 100644 index 0000000..3d872d9 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/rpc/core_service.proto @@ -0,0 +1,158 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package core_service; + +import "common.proto"; + +option go_package = "grpc/pb/core;core"; + +// Core service provides APIs to runtime, +service CoreService { + // Create an instance for specify function + rpc Create (CreateRequest) returns (CreateResponse) {} + // invoke the created instance + rpc Invoke (InvokeRequest) returns (InvokeResponse) {} + // terminate the created instance + rpc Terminate (TerminateRequest) returns (TerminateResponse) {} + // exit the created instance + rpc Exit (ExitRequest) returns (ExitResponse) {} + // save state of the created instance + rpc SaveState (StateSaveRequest) returns (StateSaveResponse) {} + // load state of the created instance + rpc LoadState (StateLoadRequest) returns (StateLoadResponse) {} + // Kill the signal to instance + rpc Kill (KillRequest) returns (KillResponse) {} +} + +enum AffinityType { + PreferredAffinity = 0; + PreferredAntiAffinity = 1; + RequiredAffinity = 2; + RequiredAntiAffinity = 3; +} + +message SchedulingOptions { + int32 priority = 1; + map resources = 2; + map extension = 3; + map affinity = 4; +} + +message CreateRequest { + string function = 1; + repeated common.Arg args = 2; + SchedulingOptions schedulingOps = 3; + string requestID = 4; + string traceID = 5; + repeated string labels = 6; + // optional. if designated instanceID is not empty, the created instance id will be assigned designatedInstanceID + string designatedInstanceID = 7; + map createOptions = 8; +} + +message CreateResponse { + common.ErrorCode code = 1; + string message = 2; + string instanceID = 3; +} + +message InvokeRequest { + string function = 1; + repeated common.Arg args = 2; + string instanceID = 3; + string requestID = 4; + string traceID = 5; + repeated string returnObjectIDs = 6; +} + +message InvokeResponse { + common.ErrorCode code = 1; + string message = 2; + string returnObjectID = 3; +} + +message CallResult { + common.ErrorCode code = 1; + string message = 2; + string instanceID = 3; + string requestID = 4; + HttpTriggerResponse triggerResponse = 5; +} + +message CallResultAck { + common.ErrorCode code = 1; + string message = 2; +} + +message TerminateRequest { + string instanceID = 1; +} + +message TerminateResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message ExitRequest { + common.ErrorCode code = 1; + string message = 2; +} + +message ExitResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message StateSaveRequest { + bytes state = 1; +} + +message StateSaveResponse { + common.ErrorCode code = 1; + string message = 2; + string checkpointID = 3; +} + +message StateLoadRequest { + string checkpointID = 1; +} + +message StateLoadResponse { + common.ErrorCode code = 1; + string message = 2; + bytes state = 3; +} + +message KillRequest { + string instanceID = 1; + int32 signal = 2; + bytes payload = 3; +} + +message KillResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message HttpTriggerResponse { + int32 statusCode = 1; + string contentType = 2; + bytes body = 3; + map headers = 4; +} \ No newline at end of file diff --git a/yuanrong/pkg/common/protobuf/rpc/inner_service.proto b/yuanrong/pkg/common/protobuf/rpc/inner_service.proto new file mode 100644 index 0000000..cdadadb --- /dev/null +++ b/yuanrong/pkg/common/protobuf/rpc/inner_service.proto @@ -0,0 +1,119 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package inner_service; + +import "common.proto"; +import "core_service.proto"; +import "bus_service.proto"; +import "runtime_service.proto"; + +option go_package = "grpc/pb/inner;inner"; + +// Inner service provides APIs for bus to bus interaction +service InnerService { + // forward a recovery request to rebuild instance owned by other bus-proxy + rpc ForwardRecover (ForwardRecoverRequest) returns (ForwardRecoverResponse) {} + // notify the result of forward by other proxy request + rpc NotifyResult (NotifyRequest) returns (NotifyResponse) {} + // forward a killing request to signal instance owned by other bus-proxy + rpc ForwardKill (ForwardKillRequest) returns (ForwardKillResponse) {} + // forward a calling result request to other bus-proxy + rpc ForwardCallResult (ForwardCallResultRequest) returns (ForwardCallResultResponse) {} + // forward a invoke request to other bus-proxy + rpc ForwardInvoke (ForwardInvokeRequest) returns (ForwardInvokeResponse) {} + // forward a queryInstance request to other bus-proxy + rpc QueryInstance (bus_service.QueryInstanceRequest) returns (bus_service.QueryInstanceResponse) {} + // to check bus-proxy liveliness + rpc LiveProbe (Probe) returns (Probe) {} + // forward a initCall request to other bus-proxy + rpc ForwardInitCall (ForwardInitCallRequest) returns (ForwardInitCallResponse) {} +} + +message Probe {} + +message NotifyRequest { + string requestID = 1; + common.ErrorCode code = 2; + string message = 3; +} + +message NotifyResponse {} + +message ForwardRecoverRequest { + string instanceID = 1; + string runtimeIP = 2; + string runtimePort = 3; + string runtimeID = 4; + string function = 5; +} + +message ForwardRecoverResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message ForwardKillRequest { + core_service.KillRequest req = 1; + string instanceID = 2; + string runtimeIP = 3; + string runtimePort = 4; + string runtimeID = 5; +} + +message ForwardKillResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message ForwardCallResultRequest { + core_service.CallResult req = 1; + string instanceID = 2; + string runtimeID = 3; +} + +message ForwardCallResultResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message ForwardInvokeRequest { + core_service.InvokeRequest req = 1; + string srcInstanceID = 2; + string srcIP = 3; + string srcNode = 4; +} + +message ForwardInvokeResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message ForwardInitCallRequest { + runtime_service.CallRequest req = 1; + string runtimeID = 2; + string instanceID = 3; + string runtimeIP = 4; + string runtimePort = 5; + string tenantID = 6; +} + +message ForwardInitCallResponse { + common.ErrorCode code = 1; + string message = 2; +} \ No newline at end of file diff --git a/yuanrong/pkg/common/protobuf/rpc/runtime_rpc.proto b/yuanrong/pkg/common/protobuf/rpc/runtime_rpc.proto new file mode 100644 index 0000000..9abba35 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/rpc/runtime_rpc.proto @@ -0,0 +1,112 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package runtime_rpc; + +import "core_service.proto"; +import "runtime_service.proto"; + +option go_package = "grpc/pb;api"; + +// RuntimeRPC provide bidirectional streaming RPC interface +service RuntimeRPC { + // build bidirection grpc communication channel, different message body type specify different api handler + rpc MessageStream (stream StreamingMessage) returns (stream StreamingMessage) {} +} + +message StreamingMessage { + string messageID = 1; + oneof body { + + // Create an instance for specify function + // handle by core + core_service.CreateRequest createReq = 2; + core_service.CreateResponse createRsp = 3; + + // invoke the created instance + // handle by core + core_service.InvokeRequest invokeReq = 4; + core_service.InvokeResponse invokeRsp = 5; + + // exit the created instance + // only support to be called by instance itself + // handle by core + core_service.ExitRequest exitReq = 6; + core_service.ExitResponse exitRsp = 7; + + // save state of the created instance + // handle by core + core_service.StateSaveRequest saveReq = 8; + core_service.StateSaveResponse saveRsp = 9; + + // load state of the created instance + // handle by core + core_service.StateLoadRequest loadReq = 10; + core_service.StateLoadResponse loadRsp = 11; + + // send the signal to instance or core + // 1 ~ 63: core defined signal + // 64 ~ 1024: custom runtime defined signal + // handle by core + core_service.KillRequest killReq = 12; + core_service.KillResponse killRsp = 13; + + // send call request result to sender + // handle by core + core_service.CallResult callResultReq = 14; + core_service.CallResultAck callResultAck = 15; + + // Call a method or init state of instance + // handle by runtime + runtime_service.CallRequest callReq = 16; + runtime_service.CallResponse callRsp = 17; + + // NotifyResult is applied to async notify result of create or invoke request invoked by runtime + // handle by runtime + runtime_service.NotifyRequest notifyReq = 18; + runtime_service.NotifyResponse notifyRsp = 19; + + // Checkpoint request a state to save for failure recovery and state migration + // handle by runtime + runtime_service.CheckpointRequest checkpointReq = 20; + runtime_service.CheckpointResponse checkpointRsp = 21; + + // Recover state + // handle by runtime + runtime_service.RecoverRequest recoverReq = 22; + runtime_service.RecoverResponse recoverRsp = 23; + + // request an instance to shutdown + // handle by runtime + runtime_service.ShutdownRequest shutdownReq = 24; + runtime_service.ShutdownResponse shutdownRsp = 25; + + // receive the signal send by other runtime or driver + // handle by runtime + runtime_service.SignalRequest signalReq = 26; + runtime_service.SignalResponse signalRsp = 27; + + // check whether the runtime is alive + // handle by runtime + runtime_service.HeartbeatRequest heartbeatReq = 28; + runtime_service.HeartbeatResponse heartbeatRsp = 29; + } + + // message is sent from functiontask(0), or frontend, this is only used in runtime call request for now + int32 messageFrom = 30; +} \ No newline at end of file diff --git a/yuanrong/pkg/common/protobuf/rpc/runtime_service.proto b/yuanrong/pkg/common/protobuf/rpc/runtime_service.proto new file mode 100644 index 0000000..142316d --- /dev/null +++ b/yuanrong/pkg/common/protobuf/rpc/runtime_service.proto @@ -0,0 +1,134 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package runtime_service; + +import "common.proto"; + +option go_package = "grpc/pb/runtime;runtime"; + +// Runtime service provides APIs to core, +service RuntimeService { + // Call a method or init state of instance + rpc Call (CallRequest) returns (CallResponse) {} + // NotifyResult is applied to async notify result of create or invoke request invoked by runtime + rpc NotifyResult (NotifyRequest) returns (NotifyResponse) {} + // Checkpoint request a state to save for failure recovery and state migration + rpc Checkpoint (CheckpointRequest) returns (CheckpointResponse) {} + // Recover state + rpc Recover (RecoverRequest) returns (RecoverResponse) {} + // GracefulExit request an instance graceful exit + rpc GracefulExit (GracefulExitRequest) returns (GracefulExitResponse) {} + // Shutdown request an instance shutdown + rpc Shutdown (ShutdownRequest) returns (ShutdownResponse) {} + // check whether the runtime is alive + rpc Heartbeat (HeartbeatRequest) returns (HeartbeatResponse) {} + // Signal the signal to instance + rpc Signal (SignalRequest) returns (SignalResponse) {} +} + +message CallRequest { + string function = 1; + repeated common.Arg args = 2; + string traceID = 3; + string returnObjectID = 4; + // isCreate specify the request whether initialization or runtime invoke + bool isCreate = 5; + // senderID specify the caller identity + // while process done, it should be send back to core by CallResult.instanceID + string senderID = 6; + // while process done, it should be send back to core by CallResult.requestID + string requestID = 7; + repeated string returnObjectIDs = 8; + map createOptions = 9; + HttpTriggerEvent triggerRequest = 10; +} + +message CallResponse { + common.ErrorCode code = 1; + string message = 2; + +} + +message CheckpointRequest { + string checkpointID = 1; +} + +message CheckpointResponse { + common.ErrorCode code = 1; + string message = 2; + bytes state = 3; +} + +message RecoverRequest { + bytes state = 1; + map createOptions = 2; +} + +message RecoverResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message GracefulExitRequest { + uint64 gracePeriodSecond = 1; +} + +message GracefulExitResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message ShutdownRequest { + uint64 gracePeriodSecond = 1; +} + +message ShutdownResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message NotifyRequest { + string requestID = 1; + common.ErrorCode code = 2; + string message = 3; +} + +message NotifyResponse {} + +message HeartbeatRequest {} + +message HeartbeatResponse {} + +message SignalRequest { + int32 signal = 1; + bytes payload = 2; +} + +message SignalResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message HttpTriggerEvent { + string path = 1; + string queryStringParameters = 2; + string httpMethod = 3; + bytes body = 4; + map headers = 5; +} \ No newline at end of file diff --git a/yuanrong/pkg/common/protobuf/savestate.proto b/yuanrong/pkg/common/protobuf/savestate.proto new file mode 100644 index 0000000..8a0b59d --- /dev/null +++ b/yuanrong/pkg/common/protobuf/savestate.proto @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package savestate; + +option go_package = "api/savestate"; + +option java_package = "com.sn.msgstruct"; +option java_outer_classname = "SaveState"; + +message RequestPb { + string stateID = 1; + string state = 2; +} + +message ResponsePb { + int32 code = 1; + string message = 2; +} diff --git a/yuanrong/pkg/common/protobuf/scheduler/domainscheduler_service.proto b/yuanrong/pkg/common/protobuf/scheduler/domainscheduler_service.proto new file mode 100644 index 0000000..8bfa6dd --- /dev/null +++ b/yuanrong/pkg/common/protobuf/scheduler/domainscheduler_service.proto @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package scheduler; + +import "scheduler_common.proto"; + +option go_package = "grpc/pb/scheduler;scheduler"; + +service DomainSchedulerService { + // domain scheduler forwards scheduler request to local scheduler + rpc ForwardSchedule (scheduler.ForwardScheduleRequest) returns (scheduler.ForwardScheduleResponse); + // PullMetrics + rpc PullMetrics (scheduler.PullMetricsRequest) returns (scheduler.PullMetricsResponse); + // QuerySchedulersMetrics + rpc QuerySchedulersMetrics (QuerySchedulersMetricsRequest) returns (QuerySchedulersMetricsResponse); + // NotifySchedulerLineage + rpc NotifySchedulerLineage (NotifySchedulerLineageRequest) returns (NotifySchedulerLineageResponse); + // GetNodesResource + rpc GetNodesResource (GetNodesResourceRequest) returns (GetNodesResourceResponse); + // ReBalance + rpc ReBalance (ReBalanceRequest) returns (ReBalanceResponse); +} + +message QuerySchedulersMetricsRequest { +} + +message QuerySchedulersMetricsResponse { + scheduler.SchedulerMetrics self = 1; + map subdomainSchedulers = 2; +} + + +message GetNodesResourceRequest { +} + +message GetNodesResourceResponse { +} + +message NotifySchedulerLineageRequest { + enum type { + ADD = 0; + DELETE = 1; + UPDATE = 2; + UPDATE_MEMBER = 3; + } + type operationType = 1; + string Address = 2; + string Name = 3; +} + +message NotifySchedulerLineageResponse { +} + +message ReBalanceRequest { + string dstNode = 1; + map resource = 2; + string requestID = 3; +} + +message ReBalanceResponse { +} + diff --git a/yuanrong/pkg/common/protobuf/scheduler/globalscheduler_service.proto b/yuanrong/pkg/common/protobuf/scheduler/globalscheduler_service.proto new file mode 100644 index 0000000..bc49265 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/scheduler/globalscheduler_service.proto @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package scheduler; + +option go_package = "grpc/pb/scheduler;scheduler"; + +service GlobalSchedulerService { + // local scheduler register to global scheduler. + rpc RegisterLocalScheduler (RegisterLocalSchedulerRequest) returns (RegisterLocalSchedulerResponse); + // domain scheduler register to global scheduler. + rpc RegisterDomainScheduler (RegisterDomainSchedulerRequest) returns (RegisterDomainSchedulerResponse); +} + +message RegisterLocalSchedulerRequest { + string hostname = 1; + string address = 2; +} + +message RegisterLocalSchedulerResponse { + string leaderAddress = 2; +} + +message RegisterDomainSchedulerRequest { + string name = 1; +} + +message RegisterDomainSchedulerResponse { + message member { + string name = 1; + string address = 2; + } + string leaderAddress = 1; + repeated member members = 2; +} \ No newline at end of file diff --git a/yuanrong/pkg/common/protobuf/scheduler/localscheduler_service.proto b/yuanrong/pkg/common/protobuf/scheduler/localscheduler_service.proto new file mode 100644 index 0000000..0e1ee05 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/scheduler/localscheduler_service.proto @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package scheduler; + +import "scheduler_common.proto"; + +option go_package = "grpc/pb/scheduler;scheduler"; + +service LocalSchedulerService { + // domain scheduler forwards scheduler request to local scheduler + rpc ForwardSchedule (scheduler.ForwardScheduleRequest) returns (scheduler.ForwardScheduleResponse); + // PullMetrics + rpc PullMetrics (scheduler.PullMetricsRequest) returns (scheduler.PullMetricsResponse); + // NotifyLocalScheduler notify leader address to local scheduler + rpc NotifyLocalScheduler (NotifyLocalSchedulerRequest) returns (NotifyLocalSchedulerResponse); + // RegisterWorker register worker to local scheduler + rpc RegisterWorker (scheduler.RegisterWorkerRequest) returns (scheduler.RegisterWorkerResponse); + // UnregisterWorker unregister worker from local scheduler + rpc UnregisterWorker (scheduler.UnregisterWorkerRequest) returns (scheduler.UnregisterWorkerResponse); +} + +message NotifyLocalSchedulerRequest { + string Address = 1; +} + +message NotifyLocalSchedulerResponse { +} diff --git a/yuanrong/pkg/common/protobuf/scheduler/scheduler_common.proto b/yuanrong/pkg/common/protobuf/scheduler/scheduler_common.proto new file mode 100644 index 0000000..1e5ddce --- /dev/null +++ b/yuanrong/pkg/common/protobuf/scheduler/scheduler_common.proto @@ -0,0 +1,109 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package scheduler; + +import "common.proto"; +import "core_service.proto"; +import "worker_agent_service.proto"; + +option go_package = "grpc/pb/scheduler;scheduler"; + +message ForwardScheduleRequest { + core_service.CreateRequest CreateReq = 1; + int32 SourceLevel = 2; + string InstanceID = 3; + SchedulerMetrics ConfirmedMetrics = 4; + string SourceName = 5; + bool WakeUp = 6; + int32 OriginalSourceLevel = 7; +} + +message InstanceInfo { + string RuntimeID = 1; + string RuntimeIP = 2; + string RuntimePort = 3; + string DeployNode = 4; +} + +message ForwardScheduleResponse { + common.ErrorCode code = 1; + string message = 2; + string instanceID = 3; + string deployID = 4; + SchedulerMetrics confirmedMetrics = 5; + DeployInstanceResponse response = 6; +} + +message ResourceCapacity { + double Total = 1; + double RequestAvailable = 2; + double ActualUsed = 3; + double PreAllocated = 4; +} + +message PullMetricsRequest { +} + +message SchedulerMetrics { + map maxResources = 1; + map resources = 2; + map labelMap = 3; + int32 managedNodeNumber = 4; + map instanceMap = 5; +} + +message PullMetricsResponse { + SchedulerMetrics metrics = 1; +} + +message RegisterWorkerRequest { + string IP = 1; + string Port = 2; + string NodeIP = 3; + string P2pPort = 4; + string NodeName = 5; + string NodeID = 6; + string WorkerAgentID = 7; + int64 AllocatableCPU = 8; + int64 AllocatableMemory = 9; + map AllocatableCustomResource = 10; +} + +message RegisterWorkerResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message UnregisterWorkerRequest { + string WorkerAgentID = 1; +} + +message UnregisterWorkerResponse { +} + +message ScheduleCacheInstanceReq { + int32 SourceLevel = 1; + string FunctionUrn = 2; + int32 CacheNum = 3; +} + +message ScheduleCacheInstanceResponse { + common.ErrorCode code = 1; + string message = 2; +} diff --git a/yuanrong/pkg/common/protobuf/scheduler/worker_agent_service.proto b/yuanrong/pkg/common/protobuf/scheduler/worker_agent_service.proto new file mode 100644 index 0000000..68bba89 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/scheduler/worker_agent_service.proto @@ -0,0 +1,178 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package scheduler; + +import "core_service.proto"; + +option go_package = "grpc/pb/scheduler;scheduler"; + +service WorkerAgentService { + rpc DeployInstance (DeployInstanceRequest) returns (DeployInstanceResponse); + rpc KillInstance (KillInstanceRequest) returns (KillInstanceResponse); + // GetMetrics local scheduler get metrics from worker + rpc GetMetrics (GetWorkerMetricsRequest) returns (GetWorkerMetricsResponse); + // while bus detected the heartbeat of runtime disconnected, QueryRuntimeException will be invoked + rpc QueryRuntimeException(QueryExceptionRequest) returns (Exception); + // when a user uses a cache instance, the system invokes the + rpc DeployCacheInstance (DeployCacheInstanceRequest) returns (DeployCacheInstanceResponse); +} + +message DeployInstanceRequest { + string LogicInstanceID = 1; + string BusAddress = 2; + string Urn = 3; + string Handler = 4; + string EntryFile = 5; + int64 Timeout = 6; + string EnvKey = 7; + string Environment = 8; + string EncryptedUserData = 9; + string Language = 10; + string CodeSha256 = 11; + string DomainID = 12; + ResourceMetaData ResourceMetadata = 13; + Initializer Initializer = 14; + LogTankService LogTankService = 15; + TraceService TraceService = 16; + FuncMountConfig FuncMountConfig = 17; + FuncDeploySpec FuncDeploySpec = 18; + repeated string labels = 19; + map affinity = 20; + map HookHandler = 21; + string TraceID = 22; +} + +message ResourceMetaData { + int64 CPU = 1; + int64 Memory = 2; + map CustomResource = 3; +} + +message Initializer { + string Handler = 1; + int64 Timeout = 2; +} + +message LogTankService { + string GroupID = 1; + string StreamID = 2; +} + +message TraceService { + string TraceAK = 1; + string TraceSK = 2; + string ProjectName = 3; +} + +message FuncMountConfig { + int32 UserID = 1; + int32 GroupID = 2; +} + +message FuncDeploySpec { + string BucketID = 1; + string ObjectID = 2; + string Layers = 3; + string DeployDir = 4; + string StorageType = 5; +} + +message DeployInstanceResponse { + string TimeInfo = 1; + string RuntimeID = 2; + string RuntimeIP = 3; + string GrpcPort = 4; + string HostIP = 5; + string NodeID = 6; +} + +message KillInstanceRequest { + string TenantID = 1; + string BusinessID = 2; + string FuncName = 3; + string FuncVersion = 4; + string RuntimeID = 5; + ResourceMetaData ResourceMetaData = 6; + repeated string Labels = 7; + string JobID = 8; +} + +message KillInstanceResponse { + bool Success = 1; +} + +message GetWorkerMetricsRequest { +} + +message MetricsData { + double Total = 1; + double Inuse = 2; +} + +message ResourceMetrics { + map Resources = 1; +} + +message FunctionMetrics { + map InstanceResources = 1; +} + +message GetWorkerMetricsResponse { + ResourceMetrics SystemResources = 1; + // Function instances resource on this node. + map FunctionResources = 2; +} + +message QueryExceptionRequest { + string RuntimeID = 1; +} + +message Exception { + string body = 1; + int64 type = 2; +} + +message DeployCacheInstanceRequest { + string BusAddress = 1; + string Urn = 2; + string Handler = 3; + string EntryFile = 4; + int64 Timeout = 5; + string EnvKey = 6; + string Environment = 7; + string EncryptedUserData = 8; + string Language = 9; + string CodeSha256 = 10; + string DomainID = 11; + ResourceMetaData ResourceMetadata = 12; + Initializer Initializer = 13; + LogTankService LogTankService = 14; + TraceService TraceService = 15; + FuncMountConfig FuncMountConfig = 16; + FuncDeploySpec FuncDeploySpec = 17; + repeated string labels = 18; + map affinity = 19; + map HookHandler = 20; + string TraceID = 21; + int32 CacheNum = 22; +} + +message DeployCacheInstanceResponse { + bool Success = 1; +} \ No newline at end of file diff --git a/yuanrong/pkg/common/protobuf/settimeout.proto b/yuanrong/pkg/common/protobuf/settimeout.proto new file mode 100644 index 0000000..2fec51e --- /dev/null +++ b/yuanrong/pkg/common/protobuf/settimeout.proto @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package settimeout; + +option go_package = "api/settimeout"; + +option java_package = "com.sn.msgstruct"; +option java_outer_classname = "SetTimeout"; + +message RequestPb { + string stateID = 1; + string futureID = 2; + int32 timeout = 3; + int32 countOfFuncName = 4; +} + +message ResponsePb { + int32 code = 1; + string message = 2; + string futureID = 3; +} diff --git a/yuanrong/pkg/common/protobuf/specialize.proto b/yuanrong/pkg/common/protobuf/specialize.proto new file mode 100644 index 0000000..80a1a1b --- /dev/null +++ b/yuanrong/pkg/common/protobuf/specialize.proto @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package specialize; + +option java_multiple_files = true; +option java_package = "com.sn.runtime.rpc.generate"; +option go_package = "./;pb"; + +service Specialize { + rpc InitRuntime (InitRequest) returns (InitResponse) {} +} + +message InitRequest { + string logicInstanceID = 1; + string customHandler = 2; + map envs = 3; + repeated string storageInfo = 4; + string busAddress = 5; + map hookHandler = 6; +} + +message InitResponse { + int32 code = 1; + string message = 2; +} \ No newline at end of file diff --git a/yuanrong/pkg/common/protobuf/terminate.proto b/yuanrong/pkg/common/protobuf/terminate.proto new file mode 100644 index 0000000..7ac4d92 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/terminate.proto @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package terminate; + +option go_package = "api/terminate"; + +option java_package = "com.sn.msgstruct"; +option java_outer_classname = "Terminate"; + +message RequestPb { + string stateID = 1; + string futureID = 2; + int32 countOfFuncName = 3; +} + +message ResponsePb { + int32 code = 1; + string message = 2; + string futureID = 3; +} diff --git a/yuanrong/pkg/common/protobuf/wait.proto b/yuanrong/pkg/common/protobuf/wait.proto new file mode 100644 index 0000000..4376e36 --- /dev/null +++ b/yuanrong/pkg/common/protobuf/wait.proto @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package waitfuture; + +option go_package = "api/waitfuture"; + +option java_package = "com.sn.msgstruct"; +option java_outer_classname = "Wait"; + +message RequestPb { + int32 waitNum = 1; + int32 timeout = 2; + string stateID = 3; + string state = 4; + repeated string futureIDs = 5; +} + +message ResponsePb { + int32 code = 1; + string message = 2; + string state = 3; + bool updateState = 4; + repeated string futureIDs = 5; +} diff --git a/yuanrong/pkg/common/reader/reader.go b/yuanrong/pkg/common/reader/reader.go new file mode 100644 index 0000000..c4c74ca --- /dev/null +++ b/yuanrong/pkg/common/reader/reader.go @@ -0,0 +1,69 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package reader provides ReadFile with timeConsumption +package reader + +import ( + "fmt" + "io/ioutil" + "os" + "time" +) + +// MaxReadFileTime elapsed time allowed to read config file from disk +const MaxReadFileTime = 10 + +// ReadFileWithTimeout is to ReadFile and count timeConsumption at same time +func ReadFileWithTimeout(configFile string) ([]byte, error) { + stopCh := make(chan struct{}) + go printTimeOut(stopCh) + data, err := ioutil.ReadFile(configFile) + close(stopCh) + return data, err +} + +// ReadFileInfoWithTimeout is to Read FileInfo and count timeConsumption at same time +func ReadFileInfoWithTimeout(filePath string) (os.FileInfo, error) { + stopCh := make(chan struct{}) + go printTimeOut(stopCh) + fileInfo, err := os.Stat(filePath) + close(stopCh) + return fileInfo, err +} + +// printTimeOut print error info every 10s after timeout +func printTimeOut(stopCh <-chan struct{}) { + if stopCh == nil { + os.Exit(0) + return + } + timer := time.NewTicker(time.Second * MaxReadFileTime) + count := 0 + for { + <-timer.C + select { + case _, ok := <-stopCh: + if !ok { + timer.Stop() + return + } + default: + count += MaxReadFileTime + fmt.Printf("ReadFile Timeout: elapsed time %ds\n", count) + } + } +} diff --git a/yuanrong/pkg/common/reader/reader_test.go b/yuanrong/pkg/common/reader/reader_test.go new file mode 100644 index 0000000..e08a57c --- /dev/null +++ b/yuanrong/pkg/common/reader/reader_test.go @@ -0,0 +1,63 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package reader provides ReadFile with timeConsumption +package reader + +import ( + "io/ioutil" + "os" + "testing" + "time" + + "github.com/agiledragon/gomonkey" + "github.com/stretchr/testify/assert" +) + +func TestReadFileWithTimeout(t *testing.T) { + patch := gomonkey.ApplyFunc(ioutil.ReadFile, func(string) ([]byte, error) { + return nil, nil + }) + data, _ := ReadFileWithTimeout("/sn/home") + assert.Nil(t, data) + patch.Reset() +} + +func TestReadFileInfoWithTimeout(t *testing.T) { + patch := gomonkey.ApplyFunc(os.Stat, func(string) (os.FileInfo, error) { + return nil, nil + }) + fileInfo, _ := ReadFileInfoWithTimeout("/sn/home") + assert.Nil(t, fileInfo) + patch.Reset() +} + +func TestPrintTimeout(t *testing.T) { + stopCh := make(chan struct{}) + go printTimeOut(stopCh) + time.Sleep(time.Second * 15) + close(stopCh) +} + +func TestPrintTimeoutErr(t *testing.T) { + test := 0 + patch := gomonkey.ApplyFunc(os.Exit, func(code int) { + test++ + }) + printTimeOut(nil) + assert.EqualValues(t, test, 1) + patch.Reset() +} diff --git a/yuanrong/pkg/common/tls/https.go b/yuanrong/pkg/common/tls/https.go new file mode 100644 index 0000000..a32c512 --- /dev/null +++ b/yuanrong/pkg/common/tls/https.go @@ -0,0 +1,369 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tls + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + + "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/reader" +) + +// HTTPSConfig is for needed HTTPS config +type HTTPSConfig struct { + CipherSuite []uint16 + MinVers uint16 + MaxVers uint16 + CACertFile string + CertFile string + SecretKeyFile string + PwdFilePath string + KeyPassPhase string +} + +// InternalHTTPSConfig is for input config +type InternalHTTPSConfig struct { + HTTPSEnable bool `json:"httpsEnable" yaml:"httpsEnable" valid:"optional"` + TLSProtocol string `json:"tlsProtocol" yaml:"tlsProtocol" valid:"optional"` + TLSCiphers string `json:"tlsCiphers" yaml:"tlsCiphers" valid:"optional"` +} + +var ( + // HTTPSConfigs is a global variable of HTTPS config + HTTPSConfigs = &HTTPSConfig{} + // tlsConfig is a global variable of TLS config + tlsConfig *tls.Config + once sync.Once + + // tlsVersionMap is a set of TLS versions + tlsVersionMap = map[string]uint16{ + "TLSv1.2": tls.VersionTLS12, + } +) + +// tlsCipherSuiteMap is a set of supported TLS algorithms +var tlsCipherSuiteMap = map[string]uint16{ + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, +} + +// GetURLScheme returns "http" or "https" +func GetURLScheme(https bool) string { + if https { + return "https" + } + return "http" +} + +// HTTPListenAndServeTLS listens and serves by TLS in HTTP +func HTTPListenAndServeTLS(addr string, server *http.Server) error { + listener, err := net.Listen("tcp4", addr) + if err != nil { + return err + } + + tlsListener := tls.NewListener(listener, tlsConfig) + + if err = server.Serve(tlsListener); err != nil { + return err + } + return nil +} + +// GetClientTLSConfig returns the config of TLS +func GetClientTLSConfig() *tls.Config { + return tlsConfig +} + +// GetHTTPTransport get http transport +func GetHTTPTransport() *http.Transport { + tr, ok := http.DefaultTransport.(*http.Transport) + if !ok { + return nil + } + tr.TLSClientConfig = GetClientTLSConfig() + + return tr +} + +func loadCerts(path string) string { + env := os.Getenv("SSL_ROOT") + if len(env) == 0 { + log.GetLogger().Errorf("failed to get SSL_ROOT") + return "" + } + certPath, err := filepath.Abs(filepath.Join(env, path)) + if err != nil { + log.GetLogger().Errorf("failed to return an absolute representation of path: %s", path) + return "" + } + ok := utils.FileExists(certPath) + if !ok { + log.GetLogger().Errorf("failed to load the cert file: %s", certPath) + return "" + } + return certPath +} + +func loadTLSConfig() (err error) { + clientAuthMode := tls.NoClientCert + var pool *x509.CertPool + + pool, err = GetX509CACertPool(HTTPSConfigs.CACertFile) + if err != nil { + log.GetLogger().Errorf("failed to GetX509CACertPool: %s", err.Error()) + return err + } + + var certs []tls.Certificate + certs, err = loadServerTLSCertificate() + if err != nil { + log.GetLogger().Errorf("failed to loadServerTLSCertificate: %s", err.Error()) + return err + } + + tlsConfig = &tls.Config{ + ClientCAs: pool, + Certificates: certs, + CipherSuites: HTTPSConfigs.CipherSuite, + PreferServerCipherSuites: true, + ClientAuth: clientAuthMode, + InsecureSkipVerify: true, + MinVersion: HTTPSConfigs.MinVers, + MaxVersion: HTTPSConfigs.MaxVers, + Renegotiation: tls.RenegotiateNever, + } + + return nil +} + +// loadHTTPSConfig loads the protocol and ciphers of TLS +func loadHTTPSConfig(tlsProtocols string, tlsCiphers []byte) error { + HTTPSConfigs = &HTTPSConfig{ + MinVers: tls.VersionTLS12, + MaxVers: tls.VersionTLS12, + CipherSuite: nil, + CACertFile: loadCerts("trust.cer"), + CertFile: loadCerts("server.cer"), + SecretKeyFile: loadCerts("server_key.pem"), + PwdFilePath: loadCerts("cert_pwd"), + KeyPassPhase: "", + } + + minVersion := parseSSLProtocol(tlsProtocols) + if HTTPSConfigs.MinVers == 0 { + return errors.New("invalid TLS protocol") + } + HTTPSConfigs.MinVers = minVersion + cipherSuites := parseSSLCipherSuites(tlsCiphers) + if len(cipherSuites) == 0 { + return errors.New("invalid TLS ciphers") + } + HTTPSConfigs.CipherSuite = cipherSuites + + keyPassPhase, err := reader.ReadFileWithTimeout(HTTPSConfigs.PwdFilePath) + if err != nil { + log.GetLogger().Errorf("failed to read file cert_pwd: %s", err.Error()) + return err + } + HTTPSConfigs.KeyPassPhase = string(keyPassPhase) + + return nil +} + +// InitTLSConfig inits config of HTTPS +func InitTLSConfig(tlsProtocols string, tlsCiphers []byte) (err error) { + once.Do(func() { + err = loadHTTPSConfig(tlsProtocols, tlsCiphers) + if err != nil { + err = errors.New("failed to load HTTPS config") + return + } + err = loadTLSConfig() + if err != nil { + return + } + }) + return err +} + +// GetX509CACertPool get ca cert pool +func GetX509CACertPool(caCertFilePath string) (caCertPool *x509.CertPool, err error) { + pool := x509.NewCertPool() + caCertContent, err := LoadCACertBytes(caCertFilePath) + if err != nil { + return nil, err + } + + pool.AppendCertsFromPEM(caCertContent) + return pool, nil +} + +func loadServerTLSCertificate() (tlsCert []tls.Certificate, err error) { + certContent, keyContent, err := LoadCertAndKeyBytes(HTTPSConfigs.CertFile, HTTPSConfigs.SecretKeyFile, + HTTPSConfigs.KeyPassPhase) + if err != nil { + return nil, err + } + + cert, err := tls.X509KeyPair(certContent, keyContent) + if err != nil { + log.GetLogger().Errorf("failed to load the X509 key pair from cert file %s with key file %s: %s", + HTTPSConfigs.CertFile, HTTPSConfigs.SecretKeyFile, err.Error()) + return nil, err + } + + var certs []tls.Certificate + certs = append(certs, cert) + + return certs, nil +} + +// LoadServerTLSCertificate generates tls certificate by certfile and keyfile +func LoadServerTLSCertificate(cerfile, keyfile string) (tlsCert []tls.Certificate, err error) { + certContent, keyContent, err := LoadCertAndKeyBytes(cerfile, keyfile, "") + + if err != nil { + return nil, err + } + cert, err := tls.X509KeyPair(certContent, keyContent) + if err != nil { + log.GetLogger().Errorf("failed to load the X509 key pair from cert file %s with key file %s: %s", + cerfile, keyfile, err.Error()) + return nil, err + } + var certs []tls.Certificate + certs = append(certs, cert) + return certs, nil +} + +// LoadCertAndKeyBytes load cert and key bytes +func LoadCertAndKeyBytes(certFilePath, keyFilePath, passPhase string) (certPEMBlock, keyPEMBlock []byte, err error) { + certContent, err := reader.ReadFileWithTimeout(certFilePath) + if err != nil { + log.GetLogger().Errorf("failed to read cert file %s", err.Error()) + return nil, nil, err + } + + keyContent, err := reader.ReadFileWithTimeout(keyFilePath) + if err != nil { + log.GetLogger().Errorf("failed to read key file %s", err.Error()) + return nil, nil, err + } + + keyContent, err = crypto.DecryptByte(keyContent, crypto.GetRootKey()) + if err != nil { + log.GetLogger().Errorf("failed to decrypt key content, err: %s", err.Error()) + return nil, nil, err + } + keyBlock, _ := pem.Decode(keyContent) + if keyBlock == nil { + log.GetLogger().Errorf("failed to decode key file") + return nil, nil, errors.New("failed to decode key file") + } + + if crypto.IsEncryptedPEMBlock(keyBlock) { + var plainPassPhase []byte + if len(passPhase) > 0 { + plainPassPhase, err = localauth.Decrypt(passPhase) + if err != nil { + log.GetLogger().Errorf("failed to decrypt the ssl passPhase(%d): %s", len(passPhase), + err.Error()) + return nil, nil, err + } + } + + keyData, err := crypto.DecryptPEMBlock(keyBlock, plainPassPhase) + clearByteMemory(plainPassPhase) + if err != nil { + log.GetLogger().Errorf("failed to decrypt key file, error: %s", err.Error()) + return nil, nil, err + } + + // The decryption is successful, then the file is re-encoded to a PEM file + plainKeyBlock := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: keyData, + } + + keyContent = pem.EncodeToMemory(plainKeyBlock) + } + + return certContent, keyContent, nil +} + +func clearByteMemory(src []byte) { + for idx := 0; idx < len(src)&32; idx++ { + src[idx] = 0 + } +} + +// LoadCACertBytes Load CA Cert Content +func LoadCACertBytes(caCertFilePath string) ([]byte, error) { + caCertContent, err := reader.ReadFileWithTimeout(caCertFilePath) + if err != nil { + log.GetLogger().Errorf("failed to read ca cert file %s: %s", caCertFilePath, err.Error()) + return nil, err + } + + return caCertContent, nil +} + +func parseSSLProtocol(rawProtocol string) uint16 { + if protocol, ok := tlsVersionMap[rawProtocol]; ok { + return protocol + } + log.GetLogger().Errorf("invalid SSL version %s, use the default protocol version", rawProtocol) + return 0 +} + +func parseSSLCipherSuites(ciphers []byte) []uint16 { + cipherSuiteNameList := strings.Split(string(ciphers), ",") + if len(cipherSuiteNameList) == 0 { + log.GetLogger().Errorf("no input cipher suite") + return nil + } + cipherSuiteList := make([]uint16, 0, len(cipherSuiteNameList)) + for _, cipherSuiteItem := range cipherSuiteNameList { + cipherSuiteItem = strings.TrimSpace(cipherSuiteItem) + if len(cipherSuiteItem) == 0 { + continue + } + + if cipherSuite, ok := tlsCipherSuiteMap[cipherSuiteItem]; ok { + cipherSuiteList = append(cipherSuiteList, cipherSuite) + } else { + log.GetLogger().Errorf("cipher %s does not exist", cipherSuiteItem) + } + } + + return cipherSuiteList +} diff --git a/yuanrong/pkg/common/tls/https_test.go b/yuanrong/pkg/common/tls/https_test.go new file mode 100644 index 0000000..c11435e --- /dev/null +++ b/yuanrong/pkg/common/tls/https_test.go @@ -0,0 +1,312 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tls + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "io/ioutil" + "net" + "net/http" + "os" + "testing" + + "github.com/agiledragon/gomonkey" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/localauth" +) + +func TestGetURLScheme(t *testing.T) { + if "https" != GetURLScheme(true) { + t.Error("GetURLScheme failed") + } + if "http" != GetURLScheme(false) { + t.Error("GetURLScheme failed") + } +} + +func TestInitTLSConfig1(t *testing.T) { + os.Setenv("SSL_ROOT", "/home/sn/resource/https") + tlsCiphers := []byte("TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, TLS_TEST") + err := InitTLSConfig("TLSv1.2", tlsCiphers) + assert.NotEqual(t, nil, err) + + patch := gomonkey.ApplyFunc(loadHTTPSConfig, func(string, []byte) error { + return nil + }) + InitTLSConfig("TLSv1.2", tlsCiphers) + patch.Reset() +} + +func TestInitTLSConfig2(t *testing.T) { + os.Setenv("SSL_ROOT", "/home/sn/resource/https") + tlsCiphers := []byte("TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, TLS_TEST") + patch := gomonkey.ApplyFunc(loadHTTPSConfig, func(string, []byte) error { + return nil + }) + InitTLSConfig("TLSv1.2", tlsCiphers) + patch.Reset() +} + +func Test_loadServerTLSCertificate(t *testing.T) { + HTTPSConfigs.CertFile = "/home/snuser" + _, err := loadServerTLSCertificate() + assert.NotNil(t, err) + + convey.Convey("test loadServerTLSCertificate", t, func() { + convey.Convey("LoadCertAndKeyBytes success", func() { + patch1 := gomonkey.ApplyFunc(LoadCertAndKeyBytes, func(certFilePath, keyFilePath, passPhase string) (certPEMBlock, + keyPEMBlock []byte, err error) { + return nil, nil, nil + }) + convey.Convey("X509KeyPair success", func() { + patch2 := gomonkey.ApplyFunc(tls.X509KeyPair, func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) { + return tls.Certificate{}, nil + }) + _, err := loadServerTLSCertificate() + convey.So(err, convey.ShouldBeNil) + _, err = LoadServerTLSCertificate("", "") + convey.So(err, convey.ShouldBeNil) + defer patch2.Reset() + }) + convey.Convey("X509KeyPair fail", func() { + patch3 := gomonkey.ApplyFunc(tls.X509KeyPair, func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) { + return tls.Certificate{}, errors.New("fail to load X509KeyPair") + }) + _, err := loadServerTLSCertificate() + convey.So(err, convey.ShouldNotBeNil) + _, err = LoadServerTLSCertificate("", "") + convey.So(err, convey.ShouldNotBeNil) + defer patch3.Reset() + }) + defer patch1.Reset() + }) + convey.Convey("LoadCertAndKeyBytes fail", func() { + patch4 := gomonkey.ApplyFunc(LoadCertAndKeyBytes, func(certFilePath, keyFilePath, passPhase string) (certPEMBlock, + keyPEMBlock []byte, err error) { + return nil, nil, errors.New("fail to LoadCertAndKeyBytes") + }) + _, err := loadServerTLSCertificate() + convey.So(err, convey.ShouldNotBeNil) + _, err = LoadServerTLSCertificate("", "") + convey.So(err, convey.ShouldNotBeNil) + defer patch4.Reset() + }) + }) +} + +func TestHTTPListenAndServeTLS(t *testing.T) { + server := &http.Server{} + err := HTTPListenAndServeTLS("127.0.0.1", server) + assert.NotNil(t, err) +} + +func Test_loadTLSConfig(t *testing.T) { + HTTPSConfigs = &HTTPSConfig{} + err := loadTLSConfig() + assert.NotNil(t, err) + + convey.Convey("test loadTLSConfig", t, func() { + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(GetX509CACertPool, func(caCertFilePath string) (caCertPool *x509.CertPool, err error) { + return x509.NewCertPool(), nil + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + err = loadTLSConfig() + assert.NotNil(t, err) + }) +} + +func Test_LoadCertAndKeyBytes(t *testing.T) { + convey.Convey("test LoadCertAndKeyBytes", t, func() { + convey.Convey("ReadFile fail", func() { + _, _, err := LoadCertAndKeyBytes("certPath", "keyPath", "pass") + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("ReadFile success", func() { + patches := [...]*gomonkey.Patches{ + gomonkey.ApplyFunc(ioutil.ReadFile, func(filename string) ([]byte, error) { + return nil, nil + }), + gomonkey.ApplyFunc(crypto.GetRootKey, func() []byte { + return nil + }), + } + defer func() { + for idx := range patches { + patches[idx].Reset() + } + }() + + convey.Convey("DecryptByte fail", func() { + patch := gomonkey.ApplyFunc(crypto.DecryptByte, func(cipherText []byte, secret []byte) ([]byte, error) { + return []byte{}, errors.New("DecryptByte fail") + }) + defer patch.Reset() + + _, _, err := LoadCertAndKeyBytes("certPath", "keyPath", "pass") + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("DecryptByte success", func() { + patch := gomonkey.ApplyFunc(crypto.DecryptByte, func(cipherText []byte, secret []byte) ([]byte, error) { + return []byte{}, nil + }) + defer patch.Reset() + + convey.Convey("Decode fail", func() { + patch := gomonkey.ApplyFunc(pem.Decode, func(data []byte) (p *pem.Block, rest []byte) { + return nil, nil + }) + defer patch.Reset() + + _, _, err := LoadCertAndKeyBytes("certPath", "keyPath", "pass") + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("Decode success", func() { + patch := gomonkey.ApplyFunc(pem.Decode, func(data []byte) (p *pem.Block, rest []byte) { + return &pem.Block{}, nil + }) + defer patch.Reset() + + convey.Convey("crypto.IsEncryptedPEMBlock fail", func() { + patch := gomonkey.ApplyFunc(crypto.IsEncryptedPEMBlock, func(b *pem.Block) bool { + return false + }) + defer patch.Reset() + + _, _, err := LoadCertAndKeyBytes("certPath", "keyPath", "pass") + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("crypto.IsEncryptedPEMBlock success", func() { + patch := gomonkey.ApplyFunc(crypto.IsEncryptedPEMBlock, func(b *pem.Block) bool { + return true + }) + defer patch.Reset() + + convey.Convey("localauth.Decrypt fail", func() { + patch := gomonkey.ApplyFunc(localauth.Decrypt, func(src string) ([]byte, error) { + return nil, errors.New("localauth.Decrypt fail") + }) + defer patch.Reset() + + _, _, err := LoadCertAndKeyBytes("certPath", "keyPath", "pass") + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("localauth.Decrypt success", func() { + patch := gomonkey.ApplyFunc(localauth.Decrypt, func(src string) ([]byte, error) { + return []byte{}, nil + }) + defer patch.Reset() + + convey.Convey("crypto.DecryptPEMBlock fail", func() { + patch := gomonkey.ApplyFunc(crypto.DecryptPEMBlock, + func(b *pem.Block, password []byte) ([]byte, error) { + return nil, errors.New("crypto.DecryptPEMBlock fail") + }) + defer patch.Reset() + + _, _, err := LoadCertAndKeyBytes("certPath", "keyPath", "pass") + convey.So(err, convey.ShouldNotBeNil) + }) + }) + }) + }) + }) + }) + }) +} + +func Test_loadCerts(t *testing.T) { + convey.Convey("env length 0", t, func() { + patch := gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return "" + }) + defer patch.Reset() + result := loadCerts("") + convey.So(result, convey.ShouldEqual, "") + }) +} + +func Test_GetClientTLSConfig(t *testing.T) { + convey.Convey("test GetClientTLSConfig", t, func() { + convey.So(GetClientTLSConfig(), convey.ShouldBeNil) + }) +} + +func Test_GetHTTPTransport(t *testing.T) { + convey.Convey("test GetHTTPTransport", t, func() { + convey.So(GetHTTPTransport(), convey.ShouldNotBeNil) + }) +} + +func Test_BuildClientTLSConfOpts(t *testing.T) { + BuildClientTLSConfOpts(MutualTLSConfig{}) +} + +func Test_BuildServerTLSConfOpts(t *testing.T) { + BuildServerTLSConfOpts(MutualTLSConfig{}) +} + +func Test_ClearByteMemory(t *testing.T) { + convey.Convey("test clearByteMemory", t, func() { + s := make([]byte, 33) + s = append(s, 'A') + clearByteMemory() + convey.So(s[0], convey.ShouldEqual, 0) + }) +} + +func Test_parseSSLProtocol(t *testing.T) { + parseSSLProtocol("") +} + +func Test_HTTPListenAndServeTLS(t *testing.T) { + patch := gomonkey.ApplyFunc(net.Listen, func(string, string) (net.Listener, error) { + return nil, nil + }) + defer patch.Reset() + + patch1 := gomonkey.ApplyFunc((*http.Server).Serve, func(*http.Server, net.Listener) error { + return nil + }) + HTTPListenAndServeTLS("", &http.Server{}) + patch1.Reset() + + patch2 := gomonkey.ApplyFunc((*http.Server).Serve, func(*http.Server, net.Listener) error { + return errors.New("test") + }) + HTTPListenAndServeTLS("", &http.Server{}) + patch2.Reset() +} + +func Test_GetX509CACertPool(t *testing.T) { + patch := gomonkey.ApplyFunc(LoadCACertBytes, func(string) ([]byte, error) { + return []byte{'a'}, nil + }) + GetX509CACertPool("") + patch.Reset() +} diff --git a/yuanrong/pkg/common/tls/option.go b/yuanrong/pkg/common/tls/option.go new file mode 100644 index 0000000..329ebf4 --- /dev/null +++ b/yuanrong/pkg/common/tls/option.go @@ -0,0 +1,238 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tls + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/reader" +) + +// NewTLSConfig returns tls.Config with given options +func NewTLSConfig(opts ...Option) *tls.Config { + config := &tls.Config{ + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + // for TLS1.2 + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + // for TLS1.3 + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, + }, + PreferServerCipherSuites: true, + Renegotiation: tls.RenegotiateNever, + } + for _, opt := range opts { + opt.apply(config) + } + return config +} + +// Option is optional argument for tls.Config +type Option interface { + apply(*tls.Config) +} + +type rootCAOption struct { + cas *x509.CertPool +} + +func (r *rootCAOption) apply(config *tls.Config) { + config.RootCAs = r.cas +} + +// WithRootCAs returns Option that applies root CAs to tls.Config +func WithRootCAs(caFiles ...string) Option { + rootCAs, err := LoadRootCAs(caFiles...) + if err != nil { + log.GetLogger().Warnf("failed to load root ca, err: %s", err.Error()) + rootCAs = nil + } + return &rootCAOption{ + cas: rootCAs, + } +} + +type certsOption struct { + certs []tls.Certificate +} + +func (c *certsOption) apply(config *tls.Config) { + config.Certificates = c.certs +} + +// WithCerts returns Option that applies cert file and key file to tls.Config +func WithCerts(certFile, keyFile string) Option { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.GetLogger().Warnf("load cert.pem and key.pem error: %s", err) + cert = tls.Certificate{} + } + return &certsOption{ + certs: []tls.Certificate{cert}, + } +} + +// WithCertsByEncryptedKey returns Option that applies cert file and encrypted key file to tls.Config +func WithCertsByEncryptedKey(certFile, keyFile, passPhase string) Option { + cert := tls.Certificate{} + certPEM, err := reader.ReadFileWithTimeout(certFile) + if err != nil { + return &certsOption{certs: []tls.Certificate{cert}} + } + keyPEMBlock, err := getKeyContent(keyFile) + if err != nil { + return &certsOption{certs: []tls.Certificate{cert}} + } + keyBlock, _ := pem.Decode(keyPEMBlock) + if keyBlock == nil { + log.GetLogger().Errorf("failed to decode key file ") + return &certsOption{certs: []tls.Certificate{cert}} + } + + if crypto.IsEncryptedPEMBlock(keyBlock) { + var plainPassPhase []byte + var err error + var decrypted string + if len(passPhase) > 0 { + decrypted, err = crypto.Decrypt(keyPEMBlock, crypto.GetRootKey()) + plainPassPhase = []byte(decrypted) + if err != nil { + log.GetLogger().Errorf("failed to decrypt the ssl passPhase(%d): %s", + len(passPhase), err.Error()) + return &certsOption{certs: []tls.Certificate{cert}} + } + keyData, err := crypto.DecryptPEMBlock(keyBlock, plainPassPhase) + clearByteMemory(plainPassPhase) + utils.ClearStringMemory(decrypted) + + if err != nil { + return &certsOption{certs: []tls.Certificate{cert}} + } + // The decryption is successful, then the file is re-encoded to a PEM file + plainKeyBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyData} + keyPEMBlock = pem.EncodeToMemory(plainKeyBlock) + } + } + + if err != nil { + return &certsOption{certs: []tls.Certificate{cert}} + } + + cert, err = tls.X509KeyPair(certPEM, keyPEMBlock) + if err != nil { + cert = tls.Certificate{} + } + utils.ClearByteMemory(keyPEMBlock) + return &certsOption{certs: []tls.Certificate{cert}} +} + +func getKeyContent(keyFile string) ([]byte, error) { + var err error + keyPEMBlock, err := reader.ReadFileWithTimeout(keyFile) + + if err != nil { + log.GetLogger().Errorf("fialed to SCCDecrypt, err: %s", err.Error()) + return nil, err + } + return keyPEMBlock, nil +} + +// WithCertsContent returns Option that applies cert content and key content to tls.Config +func WithCertsContent(certContent, keyContent []byte) Option { + cert, err := tls.X509KeyPair(certContent, keyContent) + utils.ClearByteMemory(keyContent) + if err != nil { + log.GetLogger().Warnf("load cert.pem and key.pem error: %s", err) + cert = tls.Certificate{} + } + return &certsOption{ + certs: []tls.Certificate{cert}, + } +} + +type skipVerifyOption struct { +} + +func (s *skipVerifyOption) apply(config *tls.Config) { + config.InsecureSkipVerify = true +} + +// WithSkipVerify returns Option that skips to verify certificates +func WithSkipVerify() Option { + return &skipVerifyOption{} +} + +type clientAuthOption struct { + clientAuthType tls.ClientAuthType +} + +func (a *clientAuthOption) apply(config *tls.Config) { + config.ClientAuth = a.clientAuthType +} + +// WithClientAuthType returns Option with client auth strategy +func WithClientAuthType(t tls.ClientAuthType) Option { + return &clientAuthOption{ + clientAuthType: t, + } +} + +type clientCAOption struct { + clientCAs *x509.CertPool +} + +func (r *clientCAOption) apply(config *tls.Config) { + config.ClientCAs = r.clientCAs +} + +// WithClientCAs returns Option that applies client CAs to tls.Config +func WithClientCAs(caFiles ...string) Option { + clientCAs, err := LoadRootCAs(caFiles...) + if err != nil { + log.GetLogger().Warnf("failed to load client ca, err: %s", err.Error()) + clientCAs = nil + } + return &clientCAOption{ + clientCAs: clientCAs, + } +} + +type serverNameOption struct { + serverName string +} + +func (sn *serverNameOption) apply(config *tls.Config) { + config.ServerName = sn.serverName +} + +// WithServerName returns Option that applies server name to tls.Config +func WithServerName(name string) Option { + return &serverNameOption{ + serverName: name, + } +} diff --git a/yuanrong/pkg/common/tls/option_scc.go b/yuanrong/pkg/common/tls/option_scc.go new file mode 100644 index 0000000..39c7f12 --- /dev/null +++ b/yuanrong/pkg/common/tls/option_scc.go @@ -0,0 +1,93 @@ +//go:build cryptoapi +// +build cryptoapi + +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tls + +import ( + "crypto/tls" + "encoding/pem" + + "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/reader" +) + +const ( + rsaPrivateKey = "RSA PRIVATE KEY" +) + +// WithCertsByEncryptedKeyScc returns Option that applies cert file and encrypted key file to tls.Config +func WithCertsByEncryptedKeyScc(certFile, keyFile, passPhaseStr string) Option { + cert := tls.Certificate{} + certPEM, err := reader.ReadFileWithTimeout(certFile) + + if err != nil { + log.GetLogger().Errorf("failed to read file with timeout, err: %s", err.Error()) + return &certsOption{certs: []tls.Certificate{cert}} + } + keyPEMBlock, err := reader.ReadFileWithTimeout(keyFile) + + if err != nil { + return &certsOption{certs: []tls.Certificate{cert}} + } + keyBlock, _ := pem.Decode(keyPEMBlock) + if keyBlock == nil { + log.GetLogger().Errorf("failed to decode key file.") + return &certsOption{certs: []tls.Certificate{cert}} + } + + if crypto.IsEncryptedPEMBlock(keyBlock) { + var plainPassPhase []byte + var err error + var decrypted string + if len(passPhaseStr) > 0 { + decrypted, err = crypto.SCCDecrypt([]byte(passPhaseStr)) + plainPassPhase = []byte(decrypted) + if err != nil { + log.GetLogger().Errorf("failed to decrypt the ssl passPhase(%d), err: %s", + len(passPhaseStr), err.Error()) + return &certsOption{certs: []tls.Certificate{cert}} + } + keyData, err := crypto.DecryptPEMBlock(keyBlock, plainPassPhase) + clearByteMemory(plainPassPhase) + utils.ClearStringMemory(decrypted) + + if err != nil { + log.GetLogger().Errorf("failed to decrypt PEM Block, err: %s", err.Error()) + return &certsOption{certs: []tls.Certificate{cert}} + } + // The decryption is successful, then the file is re-encoded to a PEM file + plainKeyBlock := &pem.Block{Type: rsaPrivateKey, Bytes: keyData} + keyPEMBlock = pem.EncodeToMemory(plainKeyBlock) + } + } + + if err != nil { + return &certsOption{certs: []tls.Certificate{cert}} + } + + cert, err = tls.X509KeyPair(certPEM, keyPEMBlock) + if err != nil { + log.GetLogger().Errorf("failed to build X509KeyPair, err: %s", err.Error()) + cert = tls.Certificate{} + } + utils.ClearByteMemory(keyPEMBlock) + return &certsOption{certs: []tls.Certificate{cert}} +} diff --git a/yuanrong/pkg/common/tls/option_scc_fake.go b/yuanrong/pkg/common/tls/option_scc_fake.go new file mode 100644 index 0000000..cf6f321 --- /dev/null +++ b/yuanrong/pkg/common/tls/option_scc_fake.go @@ -0,0 +1,86 @@ +//go:build !cryptoapi +// +build !cryptoapi + +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tls + +import ( + "crypto/tls" + "encoding/pem" + + "yuanrong/pkg/common/crypto" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/reader" +) + +// WithCertsByEncryptedKeyScc returns Option that applies cert file and encrypted key file to tls.Config +func WithCertsByEncryptedKeyScc(certFile, keyFile, passPhase string) Option { + cert := tls.Certificate{} + fakeCertPEM, err := reader.ReadFileWithTimeout(certFile) + + if err != nil { + return &certsOption{certs: []tls.Certificate{cert}} + } + fakeKeyPEMBlock, err := reader.ReadFileWithTimeout(keyFile) + + if err != nil { + return &certsOption{certs: []tls.Certificate{cert}} + } + fakeKeyBlock, _ := pem.Decode(fakeKeyPEMBlock) + if fakeKeyBlock == nil { + log.GetLogger().Errorf("failed to decode key file ") + return &certsOption{certs: []tls.Certificate{cert}} + } + + if crypto.IsEncryptedPEMBlock(fakeKeyBlock) { + var plainPassPhase []byte + var err error + var decrypted string + if len(passPhase) > 0 { + decrypted, err = crypto.SCCDecrypt([]byte(passPhase)) + plainPassPhase = []byte(decrypted) + if err != nil { + log.GetLogger().Errorf("failed to decrypt the ssl passPhase(%d): %s", + len(passPhase), err.Error()) + return &certsOption{certs: []tls.Certificate{cert}} + } + keyData, err := crypto.DecryptPEMBlock(fakeKeyBlock, plainPassPhase) + clearByteMemory(plainPassPhase) + utils.ClearStringMemory(decrypted) + + if err != nil { + return &certsOption{certs: []tls.Certificate{cert}} + } + // The decryption is successful, then the file is re-encoded to a PEM file + plainKeyBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyData} + fakeKeyPEMBlock = pem.EncodeToMemory(plainKeyBlock) + } + } + + if err != nil { + return &certsOption{certs: []tls.Certificate{cert}} + } + + cert, err = tls.X509KeyPair(fakeCertPEM, fakeKeyPEMBlock) + if err != nil { + cert = tls.Certificate{} + } + utils.ClearByteMemory(fakeKeyPEMBlock) + return &certsOption{certs: []tls.Certificate{cert}} +} diff --git a/yuanrong/pkg/common/tls/option_test.go b/yuanrong/pkg/common/tls/option_test.go new file mode 100644 index 0000000..9eb6803 --- /dev/null +++ b/yuanrong/pkg/common/tls/option_test.go @@ -0,0 +1,157 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package tls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + "net/http" + "os" + "testing" + + "github.com/agiledragon/gomonkey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type TestSuite struct { + suite.Suite + server http.Server + rootKEY string + rootPEM string + rootSRL string + serverKEY string + serverPEM string + serverCSR string +} + +func (s *TestSuite) SetupSuite() { + certificatePath, err := os.Getwd() + if err != nil { + s.T().Errorf("failed to get current working dictionary: %s", err.Error()) + return + } + + certificatePath += "/../../../test/" + s.rootKEY = certificatePath + "ca.key" + s.rootPEM = certificatePath + "ca.crt" + s.rootSRL = certificatePath + "ca.srl" + s.serverKEY = certificatePath + "server.key" + s.serverPEM = certificatePath + "server.crt" + s.serverCSR = certificatePath + "server.csr" + + body := "Hello" + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, body) + }) + + s.server = http.Server{ + Addr: "127.0.0.1:6061", + Handler: handler, + } + + go func() { + if err := s.server.ListenAndServeTLS(s.serverPEM, s.serverKEY); err != nil { + s.T().Logf("failed to start server: %s", err.Error()) + } + }() +} + +func (s *TestSuite) TearDownSuite() { + s.server.Shutdown(context.Background()) + + os.Remove(s.serverKEY) + os.Remove(s.serverPEM) + os.Remove(s.serverCSR) + os.Remove(s.rootKEY) + os.Remove(s.rootPEM) + os.Remove(s.rootSRL) +} + +// This is test for no verify client +func (s *TestSuite) TestNewTLSConfig() { + // no verify client + _, err := http.Get("https://127.0.0.1:6061") + assert.NotNil(s.T(), err) + // client skip server certificate verify + tr := &http.Transport{ + TLSClientConfig: NewTLSConfig(WithSkipVerify()), + } + client := &http.Client{Transport: tr} + resp, err := client.Get("https://127.0.0.1:6061") + assert.Nil(s.T(), err) + defer resp.Body.Close() + res, err := ioutil.ReadAll(resp.Body) + assert.Equal(s.T(), string(res), "Hello") +} + +// This is test for verify client +func (s *TestSuite) TestNewTLSConfig2() { + tr := &http.Transport{ + TLSClientConfig: NewTLSConfig(WithRootCAs(s.rootPEM), + WithCertsByEncryptedKey(s.serverPEM, s.serverKEY, ""), WithSkipVerify()), + } + client := &http.Client{Transport: tr} + resp, _ := client.Get("https://127.0.0.1:6061") + defer resp.Body.Close() + res, _ := ioutil.ReadAll(resp.Body) + assert.Equal(s.T(), string(res), "Hello") +} + +func TestOptionTestSuite(t *testing.T) { + suite.Run(t, new(TestSuite)) +} + +func TestWithFunc(t *testing.T) { + WithCertsContent(nil, nil) + WithCerts("", "") + WithClientCAs() + patch := gomonkey.ApplyFunc(LoadRootCAs, func(caFiles ...string) (*x509.CertPool, error) { + return nil, errors.New("LoadRootCAs failed") + }) + WithClientCAs("") + patch.Reset() +} + +func TestVerifyCert(t *testing.T) { + var raw [][]byte + tlsConfig = &tls.Config{} + tlsConfig.ClientCAs = x509.NewCertPool() + err := VerifyCert(raw, nil) + assert.NotNil(t, err) + + raw = [][]byte{ + []byte("0"), + []byte("1"), + } + err = VerifyCert(raw, nil) + assert.NotNil(t, err) + + patch1 := gomonkey.ApplyFunc(x509.ParseCertificate, func([]byte) (*x509.Certificate, error) { + return &x509.Certificate{}, nil + }) + VerifyCert(raw, nil) + patch1.Reset() +} + +func TestApply(t *testing.T) { + cli := clientCAOption{} + cli.apply(&tls.Config{}) +} diff --git a/yuanrong/pkg/common/tls/tls.go b/yuanrong/pkg/common/tls/tls.go new file mode 100644 index 0000000..b7e5276 --- /dev/null +++ b/yuanrong/pkg/common/tls/tls.go @@ -0,0 +1,130 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package tls provides tls utils +package tls + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "time" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/reader" +) + +// MutualTLSConfig indicates tls config +type MutualTLSConfig struct { + TLSEnable bool `json:"tlsEnable" yaml:"tlsEnable" valid:"optional"` + RootCAFile string `json:"rootCAFile" yaml:"rootCAFile" valid:"optional"` + ModuleCertFile string `json:"moduleCertFile" yaml:"moduleCertFile" valid:"optional"` + ModuleKeyFile string `json:"moduleKeyFile" yaml:"moduleKeyFile" valid:"optional"` + ServerName string `json:"serverName" yaml:"serverName" valid:"optional"` + SecretName string `json:"secretName" yaml:"secretName" valid:"optional"` + PwdFile string `json:"pwdFile" yaml:"pwdFile" valid:"optional"` + DecryptTool string `json:"sslDecryptTool" yaml:"sslDecryptTool" valid:"optional"` +} + +// MutualSSLConfig indicates ssl config +type MutualSSLConfig struct { + SSLEnable bool `json:"sslEnable" yaml:"sslEnable" valid:"optional"` + RootCAFile string `json:"rootCAFile" yaml:"rootCAFile" valid:"optional"` + ModuleCertFile string `json:"moduleCertFile" yaml:"moduleCertFile" valid:"optional"` + ModuleKeyFile string `json:"moduleKeyFile" yaml:"moduleKeyFile" valid:"optional"` + ServerName string `json:"serverName" yaml:"serverName" valid:"optional"` + PwdFile string `json:"pwdFile" yaml:"pwdFile" valid:"optional"` + DecryptTool string `json:"sslDecryptTool" yaml:"sslDecryptTool" valid:"optional"` +} + +// BuildClientTLSConfOpts is to build an option array for mostly used client tlsConf +func BuildClientTLSConfOpts(mutualConf MutualTLSConfig) []Option { + var opts []Option + passPhase, err := reader.ReadFileWithTimeout(mutualConf.PwdFile) + if err != nil { + log.GetLogger().Errorf("failed to read file PwdFile: %s", err.Error()) + return opts + } + opts = append(opts, WithRootCAs(mutualConf.RootCAFile), + WithCertsByEncryptedKey(mutualConf.ModuleCertFile, mutualConf.ModuleKeyFile, + string(passPhase)), + WithServerName(mutualConf.ServerName)) + return opts +} + +// BuildServerTLSConfOpts is to build an option array for mostly used server tlsConf +func BuildServerTLSConfOpts(mutualConf MutualTLSConfig) []Option { + var opts []Option + var passPhase []byte + var err error + if mutualConf.PwdFile != "" { + passPhase, err = reader.ReadFileWithTimeout(mutualConf.PwdFile) + if err != nil { + log.GetLogger().Errorf("failed to read file PwdFile: %s", err.Error()) + return opts + } + } + opts = append(opts, WithRootCAs(mutualConf.RootCAFile), + WithCertsByEncryptedKey(mutualConf.ModuleCertFile, mutualConf.ModuleKeyFile, + string(passPhase)), + WithClientCAs(mutualConf.RootCAFile), + WithClientAuthType(tls.RequireAndVerifyClientCert)) + return opts +} + +// LoadRootCAs returns system cert pool with caFiles added +func LoadRootCAs(caFiles ...string) (*x509.CertPool, error) { + rootCAs := x509.NewCertPool() + for _, file := range caFiles { + cert, err := reader.ReadFileWithTimeout(file) + if err != nil { + return nil, err + } + if !rootCAs.AppendCertsFromPEM(cert) { + return nil, err + } + } + return rootCAs, nil +} + +// VerifyCert Used to verity the server certificate +func VerifyCert(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, len(rawCerts)) + if len(certs) == 0 { + log.GetLogger().Errorf("cert number is 0") + return errors.New("cert number is 0") + } + opts := x509.VerifyOptions{ + Roots: tlsConfig.ClientCAs, + CurrentTime: time.Now(), + DNSName: "", + Intermediates: x509.NewCertPool(), + } + for i, asn1Data := range rawCerts { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + log.GetLogger().Errorf("failed to parse certificate from server: %s", err.Error()) + return err + } + certs[i] = cert + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + _, err := certs[0].Verify(opts) + return err +} diff --git a/yuanrong/pkg/common/uuid/uuid.go b/yuanrong/pkg/common/uuid/uuid.go new file mode 100644 index 0000000..2f04e2f --- /dev/null +++ b/yuanrong/pkg/common/uuid/uuid.go @@ -0,0 +1,153 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package uuid for common functions +package uuid + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/hex" + "errors" + "hash" + "io" + "strings" +) + +const ( + defaultByteNum = 16 + indexFour = 4 + indexSix = 6 + indexEight = 8 + indexNine = 9 + indexTen = 10 + indexThirteen = 13 + indexFourteen = 14 + indexEighteen = 18 + indexNineteen = 19 + indexTwentyThree = 23 + indexTwentyFour = 24 + indexThirtySix = 36 + defaultSHA1BaseVersion = 5 +) + +// RandomUUID - +type RandomUUID [defaultByteNum]byte + +var ( + rander = rand.Reader // random function + // NameSpaceURL Well known namespace IDs and UUIDs + NameSpaceURL, _ = parseUUID("6ba7b811-9dad-11d1-80b4-00c04fd430c8") +) + +// New - +func New() RandomUUID { + return mustUUID(newRandom()) +} + +func mustUUID(uuid RandomUUID, err error) RandomUUID { + if err != nil { + return RandomUUID{} + } + return uuid +} + +func newRandom() (RandomUUID, error) { + return newRandomFromReader(rander) +} + +func newRandomFromReader(r io.Reader) (RandomUUID, error) { + var randomUUID RandomUUID + _, err := io.ReadFull(r, randomUUID[:]) + if err != nil { + return RandomUUID{}, err + } + randomUUID[indexSix] = (randomUUID[indexSix] & 0x0f) | 0x40 // Version 4 + randomUUID[indexEight] = (randomUUID[indexEight] & 0x3f) | 0x80 // Variant is 10 + return randomUUID, nil +} + +// String- +func (uuid RandomUUID) String() string { + var buf [indexThirtySix]byte + encodeHex(buf[:], uuid) + return string(buf[:]) +} + +func encodeHex(dstBuf []byte, uuid RandomUUID) { + hex.Encode(dstBuf, uuid[:indexFour]) + dstBuf[indexEight] = '-' + hex.Encode(dstBuf[indexNine:indexThirteen], uuid[indexFour:indexSix]) + dstBuf[indexThirteen] = '-' + hex.Encode(dstBuf[indexFourteen:indexEighteen], uuid[indexSix:indexEight]) + dstBuf[indexEighteen] = '-' + hex.Encode(dstBuf[indexNineteen:indexTwentyThree], uuid[indexEight:indexTen]) + dstBuf[indexTwentyThree] = '-' + hex.Encode(dstBuf[indexTwentyFour:], uuid[indexTen:]) +} + +// NewSHA1 - +func NewSHA1(space RandomUUID, data []byte) RandomUUID { + return NewHash(sha1.New(), space, data, defaultSHA1BaseVersion) +} + +// NewHash returns a new RandomUUID derived from the hash of space concatenated with +// data generated by h. The hash should be at least 16 byte in length. The +// first 16 bytes of the hash are used to form the RandomUUID. +func NewHash(sha1Hash hash.Hash, space RandomUUID, data []byte, version int) RandomUUID { + sha1Hash.Reset() + if _, err := sha1Hash.Write(space[:]); err != nil { + return RandomUUID{} + } + if _, err := sha1Hash.Write(data); err != nil { + return RandomUUID{} + } + s := sha1Hash.Sum(nil) + var uuid RandomUUID + copy(uuid[:], s) + // Set the version bits in the RandomUUID. + uuid[6] = (uuid[6] & 0x0f) | uint8((version&0xf)<<4) // The version bits are located at positions 13-15. + // Set the variant bits in the RandomUUID. + uuid[8] = (uuid[8] & 0x3f) | 0x80 // The variant bits are located at positions 8-11 (counting from 0). + return uuid +} + +func parseUUID(uuidStr string) (RandomUUID, error) { + const separator = "-" + + uuidStr = strings.ReplaceAll(uuidStr, separator, "") + + if len(uuidStr) != 32 { // Check if the length of the RandomUUID string is exactly 32 characters (16 bytes). + return RandomUUID{}, errors.New("invalid RandomUUID length") + } + + part1, part2 := uuidStr[:16], uuidStr[16:] // Split the RandomUUID string into two parts, each representing 8 bytes. + + b1, err := hex.DecodeString(part1) + if err != nil { + return RandomUUID{}, err + } + b2, err := hex.DecodeString(part2) + if err != nil { + return RandomUUID{}, err + } + + var uuid RandomUUID + copy(uuid[:8], b1) // Copy the first 8 bytes into the RandomUUID variable. + copy(uuid[8:], b2) // Copy the remaining 8 bytes into the RandomUUID variable. + + return uuid, nil +} diff --git a/yuanrong/pkg/common/uuid/uuid_test.go b/yuanrong/pkg/common/uuid/uuid_test.go new file mode 100644 index 0000000..1b0951d --- /dev/null +++ b/yuanrong/pkg/common/uuid/uuid_test.go @@ -0,0 +1,132 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package uuid for common functions +package uuid + +import ( + "testing" +) + +func TestNew(t *testing.T) { + m := make(map[RandomUUID]bool) + for x := 1; x < 32; x++ { + s := New() + if m[s] { + t.Errorf("New returned duplicated RandomUUID %s", s) + } + m[s] = true + } +} + +func TestSHA1(t *testing.T) { + uuid := NewSHA1(NameSpaceURL, []byte("python.org")).String() + want := "7af94e2b-4dd9-50f0-9c9a-8a48519bdef0" + if uuid != want { + t.Errorf("SHA1: got %q expected %q", uuid, want) + } +} + +func Test_parseRandomUUID(t *testing.T) { + type args struct { + uuidStr string + } + validRandomUUID := "6ba7b811-9dad-11d1-80b4-00c04fd430c8" + invalidFormatRandomUUID := "6ba7b8119dad11d180b400c04fd430c8" + illegalCharRandomUUID := "6ba7b811-9dad-11d1-80b4-00c04fd430cG" + shortRandomUUID := "6ba7b811-9dad-11d1-80b4" + longRandomUUID := "6ba7b811-9dad-11d1-80b4-00c04fd430c8-extra" + emptyRandomUUID := "" + tests := []struct { + name string + args args + want bool + wantErr bool + }{ + { + name: "Test_parseRandomUUID_with_validRandomUUID", + args: args{ + uuidStr: validRandomUUID, + }, + want: true, + wantErr: false, + }, + { + name: "Test_parseRandomUUID_with_invalidFormatRandomUUID", + args: args{ + uuidStr: invalidFormatRandomUUID, + }, + want: false, + wantErr: false, + }, + { + name: "Test_parseRandomUUID_with_illegalCharRandomUUID", + args: args{ + uuidStr: illegalCharRandomUUID, + }, + want: false, + wantErr: true, + }, + { + name: "Test_parseRandomUUID_with_shortRandomUUID", + args: args{ + uuidStr: shortRandomUUID, + }, + want: false, + wantErr: true, + }, + { + name: "Test_parseRandomUUID_with_longRandomUUID", + args: args{ + uuidStr: longRandomUUID, + }, + want: false, + wantErr: true, + }, + { + name: "Test_parseRandomUUID_with_emptyRandomUUID", + args: args{ + uuidStr: emptyRandomUUID, + }, + want: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseUUID(tt.args.uuidStr) + if (err != nil) != tt.wantErr { + t.Errorf("parseRandomUUID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if (got.String() != tt.args.uuidStr) == tt.want { + t.Errorf("parseRandomUUID() got = %v, want %v", got.String(), tt.args.uuidStr) + } + }) + } +} + +func BenchmarkParseRandomUUID(b *testing.B) { + uuidStr := "6ba7b811-9dad-11d1-80b4-00c04fd430c8" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := parseUUID(uuidStr) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/yuanrong/pkg/dashboard/client/index.html b/yuanrong/pkg/dashboard/client/index.html new file mode 100644 index 0000000..79f38c2 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/index.html @@ -0,0 +1,29 @@ + + + + + + + + + YuanRong + + +
+ + + diff --git a/yuanrong/pkg/dashboard/client/package.json b/yuanrong/pkg/dashboard/client/package.json new file mode 100644 index 0000000..dd8f3dc --- /dev/null +++ b/yuanrong/pkg/dashboard/client/package.json @@ -0,0 +1,37 @@ +{ + "name": "client", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "preview": "vite preview", + "test": "vitest", + "coverage": "vitest run --coverage" + }, + "dependencies": { + "@opentiny/vue": "3.23.0", + "@opentiny/vue-directive": "3.23.0", + "@opentiny/vue-huicharts": "3.23.0", + "@opentiny/vue-vite-import": "1.2.0", + "axios": "1.12.0", + "dayjs": "1.11.13", + "vue": "3.5.13", + "vue-i18n": "10.0.8", + "vue-request": "2.0.4", + "vue-router": "4.5.0", + "vue-virtual-scroller": "2.0.0-beta.8" + }, + "devDependencies": { + "@vitejs/plugin-vue": "5.1.2", + "@vitest/coverage-v8": "2.0.5", + "@vue/test-utils": "2.4.5", + "jsdom": "24.0.0", + "vite": "5.4.19", + "vitest": "2.0.5" + }, + "optionalDependencies": { + "rollup": "4.50.0" + } +} diff --git a/yuanrong/pkg/dashboard/client/public/logo.png b/yuanrong/pkg/dashboard/client/public/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..ccc4ebd398f2fb8aaaa36ee97144bf3d5992ada4 GIT binary patch literal 1691 zcmV;M24wk(P)WpDJT*@D=69m3If^&{wP!t426PJpoyUc zAyG&Jf(QbED<4KMl@GBtQpyFY5dzW_NZY&jcAVB;dmr!KeULcG=Jw6b&i;00c4qew z?vgSW>hBZc8%R`f`?@H$m18r6NeE$5=0c}cl32^#lny4L9Y_{_U#ylv7!)#?r{_CH zBB+HxG{HpaZy-eW-c?uyAyCj@GUnasEuhOMKrkcNSx8VW(Xn*@ny&Vs-`SMC6Iai83SfYkWLU`Dwxw{NmOuf zy;|;1Ko~J6zN2MDHOtc^>x z6e}a;a#k=6LLtPnLdWbMHsM)!_M{xL=$gKS1+oLG5U@sT7Db1*;GZ~ty1tS@%p!np z5GD~psX$S9a1;LJ&QG5~qc|CE^fGF1~x>_PgXNEzOI$USpFSvYFJb8GtI#ueFNfK-P z$#aX51jV8+)r@_`nc=?Rp)zHIL@^39+SDO0p@G|qCRN0dhL<#=BRKYDX zZ5K9pk62&TXM@Ci-en&!BciWJ5sPfwu-=`WHjM@eA*Vo-_yW5?8vE#DXfG)^<|SV+ zNOHPuhyQ$#l||7vBjr+-oCAK+Lz^CiJM8Br^|xSM?LKVj*w9qRFhX?;cZayk8@fr? z!IXBf#VWTebJ2lupwo80uGZjQcJIoEx>5Cr$<)&|pQO5#$A@^!8-i&|Qb7>A(Oy}h zOhYG*V^DSjT^i`P2f#7u)K0w7W|b_ma6Aq+TG!Yz4hV@VcjJ+xyOdbYIkS2GmHz@1 z50R!%;C!>~XhoL`E?KtvWS`CJXh)rLj4cx;Rqny{q6JgAPCd&TDw0xJ22y;T`b_6_ zgRPR~z7}e<<~|LE9^Ipqs#5b+iUeU3XRQZi3w_8*%*sc%HIHzYHu6tg+C++VYLw=& zR)i?rw-*~*`zjb%LWTimq_xQcA~H`(%JCj!bpukz(Hw_UMxyX2hsTZ|z?-fQRT_W^ z_4+r$s^{?$Ud)#0?r$GNR7gT_JatM4#4p`tRHGP?p~j6aawJf9GZBW+gH7JUV;aJ- zfM7z!5?;{b!JmJ`OYR)Q0EQ4_#_$jb{ERK}7^_M*n#0-oTW_ABWkLmG=xP-<2q#7n z&}sfuyvOwj*z3&`ZO6~61)G68z!>Rh0;sCOUpe_L^B&`vJbC6Peq;H7VH0~d8Da)D z(gE%7N+Do1e_!($)8Ubn6i3I^_R(a@9F|Nc4BbSi42_ScdunStE3GLQlk}k9fgU7$ zUKf!9)J4gO&SFI!Sy2?GX@8;}NP;JLEu0eyENFuQh6!3yG`6|&H%rFsDYC60Xp~4{uv;ihX zf%x-R@M2OgN~-=FtG&nl7U-nj)Q^L&kh?*hs6fOUJN5fjG<`M)ye%a5A`55a)m$Ma zLpVgxz%DCka>6QmgYVNPBrq;Yh@tU{gtZJRDHwT{YvFrTv1oyn)s6U1Fmpow8!e$H zy~!po0FxMDAZ1+1KfTz9wGaY@^#|j?2dF#Om<8OO&kMU>s%-4i{0ohE68>*Wj3!G$ l6qUzE(ecD6%7DA1e*v=3?*4*tcyj;%002ovPDHLkV1ib{5kLR{ literal 0 HcmV?d00001 diff --git a/yuanrong/pkg/dashboard/client/src/api/api.ts b/yuanrong/pkg/dashboard/client/src/api/api.ts new file mode 100644 index 0000000..d465a9a --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/api/api.ts @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import request from './index'; + +const apiV1Path = '/api/v1'; + +export const GetRsrcSummAPI = (): Promise => + request.get(apiV1Path + '/logical-resources/summary'); + +export const GetCompAPI = (): Promise => request.get(apiV1Path + '/components'); + +export const GetInstAPI = (): Promise => request.get(apiV1Path + '/instances'); + +export const GetInstSummAPI = (): Promise => request.get(apiV1Path + '/instances/summary'); + +export const GetInstInstIDAPI = (instanceID: string): Promise => + request.get(apiV1Path + '/instances/' + instanceID); + +export const GetInstParentIDAPI = (parentID: string): Promise => + request.get(apiV1Path + '/instances?parent_id=' + parentID); + +const jobPath = '/api/jobs'; + +export const ListJobsAPI = (): Promise => request.get(jobPath); + +export const GetJobInfoAPI = (submissionID: string): Promise => + request.get(jobPath + '/' + submissionID); + +export const ListLogsAPI = (instanceID: string): Promise => + request.get('/api/logs/list?instance_id=' + instanceID); + +export const GetLogByFilenameAPI = (filename: string, start: number, end: number): Promise => + request.get('/api/logs?filename=' + filename + '&start_line=' + start + '&end_line=' + end); + +export const GetPromQueryAPI = (query: string): Promise => + request.get('/api/v1/prometheus/query?query=' + query); \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/api/index.ts b/yuanrong/pkg/dashboard/client/src/api/index.ts new file mode 100644 index 0000000..d331919 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/api/index.ts @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import axios from 'axios'; + +const instance = axios.create({ + baseURL: '/', + timeout: 3000, +}) + +instance.interceptors.request.use(config => { + return config +}, err => { + return Promise.reject(err) +}) + +instance.interceptors.response.use(res => { + return res.data +}) + +export default instance \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/components/breadcrumb-component.vue b/yuanrong/pkg/dashboard/client/src/components/breadcrumb-component.vue new file mode 100644 index 0000000..952bf7a --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/components/breadcrumb-component.vue @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/components/chart-config.ts b/yuanrong/pkg/dashboard/client/src/components/chart-config.ts new file mode 100644 index 0000000..10d3e8b --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/components/chart-config.ts @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { ref } from 'vue'; + +export const pagerConfig = ref({ + attrs: { + currentPage: 1, + pageSize: 10, + pageSizes: [10, 20, 50], + total: 0, + align: 'right', + layout: 'total, sizes, prev, pager, next, jumper', + } +}); + +export const statusFilter = ref({ + layout: 'simple', + multi: true, + simpleFilter: { + selectAll: true + }, +}); \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/components/common-card.vue b/yuanrong/pkg/dashboard/client/src/components/common-card.vue new file mode 100644 index 0000000..3c0ac17 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/components/common-card.vue @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/components/log-content-template.vue b/yuanrong/pkg/dashboard/client/src/components/log-content-template.vue new file mode 100644 index 0000000..7a6350d --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/components/log-content-template.vue @@ -0,0 +1,134 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/components/progress-bar-template.ts b/yuanrong/pkg/dashboard/client/src/components/progress-bar-template.ts new file mode 100644 index 0000000..ed9d5fe --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/components/progress-bar-template.ts @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Float2 } from '../utils/handleNum'; + +const percentageBarStyle = { + background: "#ffffff", + border: "1px solid #e6f4ff", + minHeight: "14px", + lineHeight: "14px", + position: "relative", + boxSizing: "border-box", + borderRadius: "2px", +} + +const progressBarStyle = (percentage: number) => ({ + background: "#e6f4ff", + position: "absolute", + left: 0, + minHeight: "14px", + transition: "0.5s width", + boxSizing: "border-box", + width: `${Math.min(Math.max(0, percentage), 100)}%`, +}) + +const textStyle = { + fontSize: 14, + position: "relative", + width: "100%", + textAlign: "center", + whiteSpace: "nowrap", +} + +export const ProgressBar = (h:any, used: number, total: number) => { + const percent = total ? Float2(used / total * 100) : 0; + return h('div', {style: percentageBarStyle}, [ + h('div', {style: progressBarStyle(percent)}), + h('div', {style: textStyle}, used + '/' + total + '(' + percent + '%)'), + ]); +} + +export const SimpleProgressBar = (h:any, percent: number) => { + percent = Float2(percent); + return h('div', {style: percentageBarStyle}, [ + h('div', {style: progressBarStyle(percent)}), + h('div', {style: textStyle}, percent + '%'), + ]); +} + +export const MultiProgressBar = (h:any, data: object) => { + let children = []; + Object.keys(data).forEach(key => { + const percent = data[key].capacity ? Float2(data[key].used / data[key].capacity * 100) : 0; + const child = h('div', [ + h('span', ["[" + key + "]:"]), + h('div', + {style: {...percentageBarStyle, marginLeft: '4px', marginBottom: '4px', + display: 'inline-block', width: '80%',}}, + [ + h('div', {style: progressBarStyle(percent)}), + h('div', {style: textStyle}, percent + '%'), + ] + ), + ]); + children.push(child); + }) + return h('div', children); +} \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/components/warning-notify.ts b/yuanrong/pkg/dashboard/client/src/components/warning-notify.ts new file mode 100644 index 0000000..ae25dff --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/components/warning-notify.ts @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { TinyNotify } from '@opentiny/vue'; + +export const WarningNotify = (tip: string, err: any) => { + let message = ''; + if (typeof err === 'string' && err != '') { + message = err; + }else { + err = 'backend error'; + message = 'Unable to access backend service.'; + } + console.error(tip, 'error:', err); + TinyNotify({ + type: 'warning', + message: message, + position: 'top-right', + duration: 5000, + }); +}; \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/i18n/index.ts b/yuanrong/pkg/dashboard/client/src/i18n/index.ts new file mode 100644 index 0000000..a8ee768 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/i18n/index.ts @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { createI18n } from 'vue-i18n'; +import locale from '@opentiny/vue-locale'; + +const initI18n = (i18n) => + locale.initI18n({ + i18n, + createI18n, + messages: { + zhCN: { + test: '中文', + }, + enUS: { + test: 'English', + }, + }, + }); + +export const i18n = initI18n({ locale: 'enUS' }); \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/index.css b/yuanrong/pkg/dashboard/client/src/index.css new file mode 100644 index 0000000..2c06229 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/index.css @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +:root { + --card--space: 8px; +} + +.content .tiny-col { + padding: 0; +} + +.slot-logo { + padding: 0 !important; +} + +.tiny-nav-menu .menu>li>a.selected, .tiny-nav-menu .menu>li>a.selected:hover{ + color: var(--tv-color-text-active); +} + +.font-size24 { + font-size: 24px !important; +} + +.font-size20 { + font-size: 18px; +} + +.font-size18 { + font-size: 18px; +} + +.font-size16 { + font-size: 16px; +} + +.font-size14 { + font-size: 14px; +} + +.margin-top16 { + margin-top: 16px; +} + +.margin-top10 { + margin-top: 10px; +} + +.margin-left20 { + margin-left: 20px; +} + +.color-blue { + color: #1476ff; +} \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/main.ts b/yuanrong/pkg/dashboard/client/src/main.ts new file mode 100644 index 0000000..a059427 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/main.ts @@ -0,0 +1,82 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { createApp } from 'vue'; +import { createRouter, createWebHashHistory } from 'vue-router'; +import VueVirtualScroller from 'vue-virtual-scroller'; +import 'vue-virtual-scroller/dist/vue-virtual-scroller.css'; +import '@opentiny/vue-theme/index.css'; +import { i18n } from '@/i18n'; +import '@/index.css'; +import App from '@/pages/layout.vue'; + +const routes = [ + { + path:'/overview', + component: () => import(/* webpackChunkName: "overview" */ '@/pages/overview/overview-layout.vue'), + }, + { + path:'/cluster', + component : () => import(/* webpackChunkName: "cluster" */ '@/pages/cluster/cluster-layout.vue'), + }, + { + path:'/jobs', + component : () => import(/* webpackChunkName: "jobs" */ '@/pages/jobs/jobs-chart.vue'), + }, + { + path:'/jobs/:jobID', + component : () => import(/* webpackChunkName: "jobDetails" */ '@/pages/job-details/job-details-layout.vue'), + }, + { + path:'/instances', + component : () => import(/* webpackChunkName: "instances" */ '@/pages/instances/instances-chart.vue'), + }, + { + path:'/instances/:instanceID', + component : () => + import(/* webpackChunkName: "instanceDetails" */ '@/pages/instance-details/instance-details-layout.vue'), + }, + { + path:'/logs', + component : () => + import(/* webpackChunkName: "logs" */ '@/pages/log-pages/logs-nodes.vue'), + }, + { + path:'/logs/:nodeID', + component : () => + import(/* webpackChunkName: "logsNodeID" */ '@/pages/log-pages/logs-files.vue'), + }, + { + path:'/logs/:nodeID/:filename', + component : () => + import(/* webpackChunkName: "logsNodeIDFilename" */ '@/pages/log-pages/logs-content.vue'), + }, +]; + +const router = createRouter({ + history: createWebHashHistory(), + routes: [ + {path: '/', redirect: '/overview'}, + ...routes, + ], +}); + +const app = createApp(App); +app.use(i18n); +app.use(router); +app.use(VueVirtualScroller); + +app.mount('#app'); \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/pages/cluster/cluster-chart.vue b/yuanrong/pkg/dashboard/client/src/pages/cluster/cluster-chart.vue new file mode 100644 index 0000000..4716095 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/cluster/cluster-chart.vue @@ -0,0 +1,304 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/pages/cluster/cluster-layout.vue b/yuanrong/pkg/dashboard/client/src/pages/cluster/cluster-layout.vue new file mode 100644 index 0000000..11e9d22 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/cluster/cluster-layout.vue @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/pages/instance-details/components/empty-log-card.vue b/yuanrong/pkg/dashboard/client/src/pages/instance-details/components/empty-log-card.vue new file mode 100644 index 0000000..5cd9b1d --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/instance-details/components/empty-log-card.vue @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/pages/instance-details/components/instance-info.vue b/yuanrong/pkg/dashboard/client/src/pages/instance-details/components/instance-info.vue new file mode 100644 index 0000000..e7e250d --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/instance-details/components/instance-info.vue @@ -0,0 +1,185 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/pages/instance-details/instance-details-layout.vue b/yuanrong/pkg/dashboard/client/src/pages/instance-details/instance-details-layout.vue new file mode 100644 index 0000000..c23b4e6 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/instance-details/instance-details-layout.vue @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/pages/instances/instances-chart.vue b/yuanrong/pkg/dashboard/client/src/pages/instances/instances-chart.vue new file mode 100644 index 0000000..40d2b71 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/instances/instances-chart.vue @@ -0,0 +1,189 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/pages/job-details/components/job-info.vue b/yuanrong/pkg/dashboard/client/src/pages/job-details/components/job-info.vue new file mode 100644 index 0000000..e630e4e --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/job-details/components/job-info.vue @@ -0,0 +1,165 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/pages/job-details/job-details-layout.vue b/yuanrong/pkg/dashboard/client/src/pages/job-details/job-details-layout.vue new file mode 100644 index 0000000..3120a84 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/job-details/job-details-layout.vue @@ -0,0 +1,60 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/pages/jobs/jobs-chart.vue b/yuanrong/pkg/dashboard/client/src/pages/jobs/jobs-chart.vue new file mode 100644 index 0000000..0dc22e6 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/jobs/jobs-chart.vue @@ -0,0 +1,122 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/pages/layout.vue b/yuanrong/pkg/dashboard/client/src/pages/layout.vue new file mode 100644 index 0000000..6dd5313 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/layout.vue @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-content.vue b/yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-content.vue new file mode 100644 index 0000000..72d10f3 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-content.vue @@ -0,0 +1,30 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-files.vue b/yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-files.vue new file mode 100644 index 0000000..e8668da --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-files.vue @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-nodes.vue b/yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-nodes.vue new file mode 100644 index 0000000..42d6b30 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/log-pages/logs-nodes.vue @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/pages/overview/components/cluster-card.vue b/yuanrong/pkg/dashboard/client/src/pages/overview/components/cluster-card.vue new file mode 100644 index 0000000..62c3d86 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/overview/components/cluster-card.vue @@ -0,0 +1,70 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + + + \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/pages/overview/components/instances-card.vue b/yuanrong/pkg/dashboard/client/src/pages/overview/components/instances-card.vue new file mode 100644 index 0000000..d64384d --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/overview/components/instances-card.vue @@ -0,0 +1,74 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/pages/overview/components/resources-card.vue b/yuanrong/pkg/dashboard/client/src/pages/overview/components/resources-card.vue new file mode 100644 index 0000000..dde8dee --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/overview/components/resources-card.vue @@ -0,0 +1,145 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/pages/overview/overview-layout.vue b/yuanrong/pkg/dashboard/client/src/pages/overview/overview-layout.vue new file mode 100644 index 0000000..7e80cc9 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/pages/overview/overview-layout.vue @@ -0,0 +1,45 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + + + + diff --git a/yuanrong/pkg/dashboard/client/src/types/api.d.ts b/yuanrong/pkg/dashboard/client/src/types/api.d.ts new file mode 100644 index 0000000..d5b3019 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/types/api.d.ts @@ -0,0 +1,150 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +interface RsrcSummData { + cap_cpu: number; + cap_mem: number; + alloc_cpu: number; + alloc_mem: number; + proxy_num: number; +} + +interface GetRsrcSummAPIRes { + code: number; + data: RsrcSummData; + msg: string; +} + +interface UsageInfo { + cap_cpu: number; + cap_mem: number; + alloc_cpu: number; + alloc_mem: number; + alloc_npu: number; +} + +interface CompInfo extends UsageInfo { + hostname: string; + status: string; + address: string; +} + +interface CompsObj { + [key: string]: CompInfo[]; +} + +interface CompsData extends UsageInfo { + nodes: CompInfo[]; + components: CompsObj; +} + +interface GetCompAPIRes { + code: number; + data: CompsData; + msg: string; +} + +interface InstInfo { + id: string; + status: string; + create_time: string; + job_id: string; + pid: string; + ip: string; + node_id: string; + agent_id: string; + parent_id: string; + required_cpu: number; + required_mem: number; + required_gpu: number; + required_npu: number; + restarted: number; + exit_detail: string; +} + +interface GetInstAPIRes { + code: number; + data: InstInfo[]; + msg: string; +} + +interface InstSummData { + total: number; + running: number; + exited: number; + fatal: number; +} + +interface GetInstSummAPIRes { + code: number; + data: InstSummData; + msg: string; +} + +interface GetInstInstIDAPIRes { + code: number; + data: string; + msg: string; +} + +interface DriverDetail { + id: string; + node_ip_address: string; + pid: string; +} + +interface StringObj { + [key: string]: string; +} + +interface UnknownObj { + [key: string]: unknown; +} + +interface JobInfo { + type: string; + entrypoint: string; + submission_id: string; + driver_info: DriverDetail; + status: string; + message: string; + error_type: string; + start_time: string; + end_time: string; + metadata: object; + runtime_env: UnknownObj; + driver_agent_http_address: string; + driver_node_id: string; + driver_exit_code: number; +} + +type GetListJobsAPIRes = JobInfo[] + +interface ListLogsObj { + [key: string]: string[]; +} + +interface ListLogsRes { + message: string; + data: ListLogsObj; +} + +interface MetricData { + metric: StringObj; + value: [number, string]; +} + +type PromData = MetricData[] \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/utils/dayFormat.ts b/yuanrong/pkg/dashboard/client/src/utils/dayFormat.ts new file mode 100644 index 0000000..e6829ed --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/utils/dayFormat.ts @@ -0,0 +1,22 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import dayjs from 'dayjs'; + +// input param time is timestamp, unit is seconds +export const DayFormat = (time: string) => { + return time == '' ? time : dayjs(parseInt(time + '000')).format('YYYY-MM-DD HH:mm:ss'); +} \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/utils/handleNum.ts b/yuanrong/pkg/dashboard/client/src/utils/handleNum.ts new file mode 100644 index 0000000..f38fcee --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/utils/handleNum.ts @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export const Float2 = (num: number) => { + return Math.round(num * 100) / 100; +} + +export const MBToGB = (num: number) => { + return Float2(num / 1024); +} + +export const CPUConvert = (num: number) => { + return num / 1000; +} \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/utils/sort.ts b/yuanrong/pkg/dashboard/client/src/utils/sort.ts new file mode 100644 index 0000000..29e6290 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/utils/sort.ts @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export const ChartSort = (data: any, primaryField: string, secondaryField: string) => { + if (!data) { + return + } + data.sort((a, b) => { + return b[primaryField].localeCompare(a[primaryField]) === 0 ? + a[secondaryField].localeCompare(b[secondaryField]) : b[primaryField].localeCompare(a[primaryField]); + }) +} \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/src/utils/swr.ts b/yuanrong/pkg/dashboard/client/src/utils/swr.ts new file mode 100644 index 0000000..ac2ec67 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/src/utils/swr.ts @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { useRequest } from 'vue-request'; + +const POLLING_INTERVAL = 5000; +const LOG_POLLING_INTERVAL = 60000; + +const defaultSWR = (func: any, interval: number) => { + useRequest(func, { + pollingInterval: interval, + }); +}; + +export const SWR = (func: any) => { + defaultSWR(func, POLLING_INTERVAL); +}; + +export const LogSWR = (func: any) => { + defaultSWR(func, LOG_POLLING_INTERVAL); +}; \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/components/progress-bar-template.spec.ts b/yuanrong/pkg/dashboard/client/tests/components/progress-bar-template.spec.ts new file mode 100644 index 0000000..05a0816 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/components/progress-bar-template.spec.ts @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { h } from 'vue'; +import { describe, it, expect } from 'vitest'; +import { ProgressBar, SimpleProgressBar, MultiProgressBar } from '@/components/progress-bar-template'; + +describe('ProgressBarTemplate', () => { + it('renders ProgressBar correctly', () => { + const progressBar = ProgressBar(h, 2.93, 38); + expect(JSON.stringify(progressBar)).toContain('2.93/38(7.71%)'); + }); + it('renders SimpleProgressBar correctly', () => { + const simpleProgressBar = SimpleProgressBar(h, 10.68); + expect(JSON.stringify(simpleProgressBar)).toContain('10.68%'); + }); + it('renders MultiProgressBar correctly', () => { + const multiProgressBar = MultiProgressBar(h, {1: {used : 11.8, capacity: 100}}); + expect(JSON.stringify(multiProgressBar)).toContain('11.8%'); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/main.spec.ts b/yuanrong/pkg/dashboard/client/tests/main.spec.ts new file mode 100644 index 0000000..dab4401 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/main.spec.ts @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { describe, it, expect } from 'vitest'; +import MainApp from '@/main'; + +describe('MainApp', () => { + it('renders the main structure correctly', () => { + expect(MainApp).toBeUndefined(); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/pages/cluster/cluster-chart.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/cluster/cluster-chart.spec.ts new file mode 100644 index 0000000..d8ed953 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/cluster/cluster-chart.spec.ts @@ -0,0 +1,100 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount, flushPromises } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import CommonCard from '@/components/common-card.vue'; +import ClusterChart from '@/pages/cluster/cluster-chart.vue'; + +describe('ClusterChart', () => { + const wrapper = mount(ClusterChart, { + global: { + stubs: { + CommonCard: true, + TinyGrid: true, + TinyGridColumn: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(CommonCard).exists()).toBe(true); + }); +}) + +describe('ClusterChart.initChartWithProm', () => { + vi.mock('@/api/api', () => ({ + GetCompAPI: vi.fn().mockResolvedValue({ + data: { + 'nodes': [ + { + 'hostname': 'dggphis232340-2846342', + 'status': 'healthy', + 'address': '7.185.104.157', + 'cap_cpu': 10000, + 'cap_mem': 38912, + 'alloc_cpu': 10000, + 'alloc_mem': 38912, + 'alloc_npu': 0 + } + ], + 'components': { + 'dggphis232340-2846342': [ + { + 'hostname': 'function-agent-7.185.104.157-58866', + 'status': 'alive', + 'address': '7.185.104.157:58866', + 'cap_cpu': 10000, + 'cap_mem': 38912, + 'alloc_cpu': 10000, + 'alloc_mem': 38912, + 'alloc_npu': 0 + } + ] + } + }, + }), + GetInstAPI: vi.fn().mockResolvedValue({ + data: [ + { + 'id': 'app-6e334abe-3554-454f-94d0-493f24118292', + 'status': 'fatal', + 'create_time': '1762428525', + 'job_id': 'job-his232340-2846342', + 'pid': '2855872', + 'ip': '7.185.104.157:22773', + 'node_id': 'dggphis232340-2846342', + 'agent_id': 'function-agent-7.185.104.157-58866', + 'parent_id': 'driver-faas-frontend-dggphis232340-2846342', + 'required_cpu': 500, + 'required_mem': 500, + 'required_gpu': 0, + 'required_npu': 0, + 'restarted': 0, + 'exit_detail': 'Instance(app-6e334abe-3554-454f-94d0-493f24118292) exitStatus:1' + }]}), + GetPromQueryAPI: vi.fn().mockResolvedValue([]), + })); + const wrapper = mount(ClusterChart); + + it('renders initChartWithProm correctly', async () => { + wrapper.vm.initChartWithProm(); + await flushPromises(); + expect(wrapper.text()).toContain('Components HostnameStatusIP/Address' + + 'CPUMemory(GB)NPUDisk(GB)Logical Resources'); + expect(wrapper.text()).toContain('dggphis232340-2846342healthy'); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/pages/cluster/cluster-layout.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/cluster/cluster-layout.spec.ts new file mode 100644 index 0000000..80d54bb --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/cluster/cluster-layout.spec.ts @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import ClusterChart from '@/pages/cluster/cluster-chart.vue'; +import ClusterLayout from '@/pages/cluster/cluster-layout.vue'; + +describe('ClusterLayout', () => { + vi.mock('@opentiny/vue-huicharts', () => ({ + default: { + name: 'TinyHuichartsGauge', + template: '
Canvas Mock
', + }, + })); + + const wrapper = mount(ClusterLayout, { + global: { + stubs: { + ResourcesCard: true, + ClusterChart: true, + }, + }, + }); + + it('renders the main structure correctly', async () => { + expect(wrapper.findComponent(ClusterChart).exists()).toBe(true); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/pages/instance-details/components/empty-log-card.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/instance-details/components/empty-log-card.spec.ts new file mode 100644 index 0000000..2663471 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/instance-details/components/empty-log-card.spec.ts @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount } from '@vue/test-utils'; +import { describe, it, expect } from 'vitest'; +import CommonCard from '@/components/common-card.vue'; +import EmptyLogCard from '@/pages/instance-details/components/empty-log-card.vue'; + +describe('EmptyLogCard', () => { + const wrapper = mount(EmptyLogCard, { + global: { + stubs: { + CommonCard: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(CommonCard).exists()).toBe(true); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/pages/instance-details/components/instance-info.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/instance-details/components/instance-info.spec.ts new file mode 100644 index 0000000..362dce1 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/instance-details/components/instance-info.spec.ts @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount, flushPromises } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import { GetInstInstIDAPI } from '@/api/api'; +import CommonCard from '@/components/common-card.vue'; +import InstanceInfo from '@/pages/instance-details/components/instance-info.vue'; + +describe('InstanceInfo', () => { + const wrapper = mount(InstanceInfo, { + global: { + stubs: { + CommonCard: true, + TinyCol: true, + TinyLayout: true, + TinyRow: true, + AutoTip: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(CommonCard).exists()).toBe(true); + }); +}) + +describe('InstanceInfo.initInstanceInfo', () => { + vi.mock('@/api/api', () => ({ + GetInstInstIDAPI: vi.fn(), + })); + + const wrapper = mount(InstanceInfo); + const vm = wrapper.vm; + it('renders when initInstanceInfo error or empty', async () => { + GetInstInstIDAPI.mockRejectedValue('error'); + vm.initInstanceInfo(); + await flushPromises(); + expect(wrapper.text()).toBe('InstanceInfos IDStatusJobIDPIDIPNodeIDParentIDCreateTime' + + 'RequiredCPURequiredMemory(MB)RequiredGPURequiredNPURestartedExitDetail'); + }); + it('renders when initInstanceInfo correctly', async () => { + GetInstInstIDAPI.mockResolvedValue({ + 'id': '10050000-0000-4000-b00f-8374b8dd2508', + 'status': 'fatal', + 'create_time': '1762222864', + 'job_id': 'job-febb4a18', + 'pid': '249465', + 'ip': '7.185.105.138:22773', + 'node_id': 'dggphis232339-189755', + 'agent_id': 'function-agent-7.185.105.138-58866', + 'parent_id': 'app-ab00977c-682e-4b5e-9cb3-f928c55a7d27', + 'required_cpu': 3000, + 'required_mem': 500, + 'required_gpu': 0, + 'required_npu': 0, + 'restarted': 0, + 'exit_detail': 'ancestor instance(app-ab00977c-682e-4b5e-9cb3-f928c55a7d27) is abnormal' + }); + vm.initInstanceInfo(); + await flushPromises(); + expect(wrapper.text()).toContain('InstanceInfos IDStatusJobID10050000-0000-4000-b00f-8374b8dd2508fatal'); + expect(wrapper.text()).toContain('2025-11-04 10:21:04'); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/pages/instance-details/instance-details-layout.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/instance-details/instance-details-layout.spec.ts new file mode 100644 index 0000000..f3efe3e --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/instance-details/instance-details-layout.spec.ts @@ -0,0 +1,63 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount, flushPromises } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import BreadcrumbComponent from '@/components/breadcrumb-component.vue'; +import EmptyLogCard from '@/pages/instance-details/components/empty-log-card.vue'; +import InstanceInfo from '@/pages/instance-details/components/instance-info.vue'; +import InstanceDetailsLayout from '@/pages/instance-details/instance-details-layout.vue'; + +describe('InstanceDetailsLayout', () => { + const wrapper = mount(InstanceDetailsLayout, { + global: { + stubs: { + BreadcrumbComponent: true, + LogContentTemplate: true, + EmptyLogCard: true, + InstanceInfo: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(BreadcrumbComponent).exists()).toBe(true); + expect(wrapper.findComponent(EmptyLogCard).exists()).toBe(true); + expect(wrapper.findComponent(InstanceInfo).exists()).toBe(true); + }); +}) + +describe('InstanceDetailsLayout.ListLogsAPI', () => { + vi.mock('@/api/api', () => ({ + ListLogsAPI: vi.fn().mockResolvedValue({ + 'message': '', + 'data': { + 'dggphis232340-2846342': [] + } + }), + })); + + const wrapper = mount(InstanceDetailsLayout); + it('renders when initScrollerItem error or empty', async () => { + wrapper.vm.initScrollerItem(); + await flushPromises(); + expect(wrapper.text()).toContain('InstanceInfos IDStatusJobIDPIDIPNodeIDParentID' + + 'CreateTimeRequiredCPURequiredMemory(MB)RequiredGPURequiredNPURestartedExitDetailLog'); + expect(wrapper.text()).toContain('Driver logs are only available when submitting jobs ' + + 'via the Job Submission API, SDK.'); + }); +}) + diff --git a/yuanrong/pkg/dashboard/client/tests/pages/instances/instances-chart.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/instances/instances-chart.spec.ts new file mode 100644 index 0000000..4026c8d --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/instances/instances-chart.spec.ts @@ -0,0 +1,108 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount, flushPromises } from '@vue/test-utils'; +import { describe, it, expect, vi, afterAll } from 'vitest'; +import { GetInstAPI, GetInstParentIDAPI } from '@/api/api'; +import CommonCard from '@/components/common-card.vue'; +import InstancesChart from '@/pages/instances/instances-chart.vue'; + +describe('InstancesChart', () => { + const wrapper = mount(InstancesChart, { + global: { + stubs: { + CommonCard: true, + TinyGrid: true, + TinyGridColumn: true, + TinyLink: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(CommonCard).exists()).toBe(true); + }); +}) + +describe('InstancesChart.initChart', () => { + vi.mock('@/api/api', () => ({ + GetInstAPI: vi.fn(), + GetInstParentIDAPI: vi.fn(), + })); + + const wrapper = mount(InstancesChart); + const vm = wrapper.vm; + const data = [ + { + 'id': '1025e641-f911-4500-8000-000000f8dbc2', + 'status': 'fatal', + 'create_time': '1762222864', + 'job_id': 'job-febb4a18', + 'pid': '249466', + 'ip': '7.185.105.138:22773', + 'node_id': 'dggphis232339-189755', + 'agent_id': 'function-agent-7.185.105.138-58866', + 'parent_id': 'app-ab00977c-682e-4b5e-9cb3-f928c55a7d27', + 'required_cpu': 3000, + 'required_mem': 500, + 'required_gpu': 0, + 'required_npu': 0, + 'restarted': 0, + 'exit_detail': 'ancestor instance(app-ab00977c-682e-4b5e-9cb3-f928c55a7d27) is abnormal' + }, + { + 'id': 'app-ab00977c-682e-4b5e-9cb3-f928c55a7d27', + 'status': 'fatal', + 'create_time': '1762222864', + 'job_id': 'job-his232339-189755', + 'pid': '249351', + 'ip': '7.185.105.138:22773', + 'node_id': 'dggphis232339-189755', + 'agent_id': 'function-agent-7.185.105.138-58866', + 'parent_id': 'driver-faas-frontend-dggphis232339-189755', + 'required_cpu': 500, + 'required_mem': 500, + 'required_gpu': 0, + 'required_npu': 0, + 'restarted': 0, + 'exit_detail': 'Instance(app-ab00977c-682e-4b5e-9cb3-f928c55a7d27) exitStatus:0' + }, + ]; + + it('renders when initChart.GetInstAPI correctly', async () => { + GetInstAPI.mockResolvedValue({data}); + + vm.initChart(); + await flushPromises(); + expect(wrapper.text()).toContain('IDStatusJobIDPIDIPNodeIDParentIDCreateTimeRequired ' + + 'CPURequired Memory(MB)Required GPURequired NPURestartedExitDetailLog'); + expect(wrapper.text()).toContain('app-ab00977c-682e-4b5e-9cb3-f928c55a7d27fatal'); + expect(wrapper.text()).toContain('1025e641-f911-4500-8000-000000f8dbc2fatal'); + }); + + const originalHash = window.location.hash; + afterAll(()=>{ window.location.hash = originalHash }); + + it('renders when initChart.GetInstParentIDAPI correctly', async () => { + GetInstParentIDAPI.mockResolvedValue(data); + window.location.hash = '#/jobs/123'; + + vm.initChart(); + await flushPromises(); + expect(wrapper.text()).toContain('app-ab00977c-682e-4b5e-9cb3-f928c55a7d27fatal'); + expect(wrapper.text()).toContain('1025e641-f911-4500-8000-000000f8dbc2fatal'); + }); +}) diff --git a/yuanrong/pkg/dashboard/client/tests/pages/job-details/components/job-info.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/job-details/components/job-info.spec.ts new file mode 100644 index 0000000..04582ab --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/job-details/components/job-info.spec.ts @@ -0,0 +1,84 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount, flushPromises } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import { GetJobInfoAPI } from '@/api/api'; +import CommonCard from '@/components/common-card.vue'; +import JobInfo from '@/pages/job-details/components/job-info.vue'; + +describe('JobInfo', () => { + const wrapper = mount(JobInfo, { + global: { + stubs: { + CommonCard: true, + TinyCol: true, + TinyLayout: true, + TinyRow: true, + AutoTip: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(CommonCard).exists()).toBe(true); + }); +}) + +describe('JobInfo.initJobInfo', () => { + vi.mock('@/api/api', () => ({ + GetJobInfoAPI: vi.fn(), + })); + + const wrapper = mount(JobInfo); + const vm = wrapper.vm; + it('renders when initJobInfo empty', async () => { + expect(wrapper.text()).toBe('JobInfos EntrypointRuntimeEnvSubmissionIDStatusStartTime' + + 'EndTimeMessageErrorTypeDriverNodeIDDriverNodeIPDriverPID'); + }) + + it('renders when initJobInfo correctly', async () => { + GetJobInfoAPI.mockResolvedValue({ + 'key': '/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function' + + '/0-system-faasExecutorPosixCustom/version/$latest/defaultaz/0ebe2d84ad28eed000/' + + 'app-ab00977c-682e-4b5e-9cb3-f928c55a7d27', + 'type': 'SUBMISSION', + 'entrypoint': 'python3 ajobsample/three_actor.py', + 'submission_id': 'app-ab00977c-682e-4b5e-9cb3-f928c55a7d27', + 'driver_info': { + 'id': 'app-ab00977c-682e-4b5e-9cb3-f928c55a7d27', + 'node_ip_address': '7.185.105.138', + 'pid': '249351', + }, + 'status': 'FAILED', + 'start_time': '1762222864', + 'end_time': '1762222866', + 'metadata': null, + 'runtime_env': { + 'envVars': '{\'ENABLE_SERVER_MODE\':\'true\'}', + 'pip': '', + 'working_dir': 'file:///root/wxq/ajobsample/ajobsample.zip', + }, + 'driver_agent_http_address': '', + 'driver_node_id': 'dggphis232339-189755', + 'driver_exit_code': 139, + 'error_type': 'Instance(app-ab00977c-682e-4b5e-9cb3-f928c55a7d27) exitStatus:0', + }); + vm.initJobInfo(); + await flushPromises(); + expect(wrapper.text()).toContain('app-ab00977c-682e-4b5e-9cb3-f928c55a7d27FAILED'); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/pages/job-details/job-details-layout.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/job-details/job-details-layout.spec.ts new file mode 100644 index 0000000..2424f68 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/job-details/job-details-layout.spec.ts @@ -0,0 +1,61 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount, flushPromises } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import BreadcrumbComponent from '@/components/breadcrumb-component.vue'; +import InstancesChart from '@/pages/instances/instances-chart.vue'; +import JobInfo from '@/pages/job-details/components/job-info.vue'; +import JobDetailsLayout from '@/pages/job-details/job-details-layout.vue'; + +describe('JobDetailsLayout', () => { + const wrapper = mount(JobDetailsLayout, { + global: { + stubs: { + BreadcrumbComponent: true, + LogContentTemplate: true, + InstancesChart: true, + JobInfo: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(BreadcrumbComponent).exists()).toBe(true); + expect(wrapper.findComponent(InstancesChart).exists()).toBe(true); + expect(wrapper.findComponent(JobInfo).exists()).toBe(true); + }); +}) + +describe('JobDetailsLayout.initScrollerItem', () => { + vi.mock('@/api/api', () => ({ + ListLogsAPI: vi.fn().mockResolvedValue({ + 'message': '', + 'data': { + 'dggphis232340-2846342': [] + } + }), + })); + const wrapper = mount(JobDetailsLayout); + const vm = wrapper.vm; + + it('renders when initScrollerItem correctly', async () => { + vm.initScrollerItem(); + await flushPromises(); + expect(wrapper.text()).toContain('JobInfos EntrypointRuntimeEnvSubmissionIDStatusStartTime' + + 'EndTimeMessageErrorTypeDriverNodeIDDriverNodeIPDriverPID'); + }); +}) diff --git a/yuanrong/pkg/dashboard/client/tests/pages/jobs/jobs-chart.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/jobs/jobs-chart.spec.ts new file mode 100644 index 0000000..ade904f --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/jobs/jobs-chart.spec.ts @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount, flushPromises } from '@vue/test-utils'; +import { describe, it, expect } from 'vitest'; +import CommonCard from '@/components/common-card.vue'; +import JobsChart from '@/pages/jobs/jobs-chart.vue'; + +describe('JobsChart', () => { + const wrapper = mount(JobsChart, { + global: { + stubs: { + CommonCard: true, + TinyGrid: true, + TinyGridColumn: true, + TinyLink: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(CommonCard).exists()).toBe(true); + }); +}) + +describe('JobsChart.initChart', () => { + const wrapper = mount(JobsChart); + const vm = wrapper.vm; + + it('renders when initChart correctly', async () => { + vm.initChart(); + await flushPromises(); + expect(wrapper.text()).toContain('SubmissionIDEntrypointStatusMessageStartTimeEndTimeDriverPIDLog'); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/pages/layout.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/layout.spec.ts new file mode 100644 index 0000000..9d147bd --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/layout.spec.ts @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount } from '@vue/test-utils'; +import { describe, it, expect } from 'vitest'; +import { TinyNavMenu } from '@opentiny/vue'; +import Layout from '@/pages/layout.vue'; + +describe('Layout', () => { + const wrapper = mount(Layout, { + global: { + stubs: { + TinyNavMenu: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(TinyNavMenu).exists()).toBe(true); + }); +}) + +describe('Layout.text', () => { + const wrapper = mount(Layout); + + it('renders text correctly', () => { + expect(wrapper.text()).toContain('OverviewClusterJobsInstancesLogs'); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-content.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-content.spec.ts new file mode 100644 index 0000000..2f61dfe --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-content.spec.ts @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount } from '@vue/test-utils'; +import { describe, it, expect } from 'vitest'; +import LogContentTemplate from '@/components/log-content-template.vue'; +import LogsContent from '@/pages/log-pages/logs-content.vue'; + +describe('LogsContent', () => { + const wrapper = mount(LogsContent, { + global: { + stubs: { + BreadcrumbComponent: true, + LogContentTemplate: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(LogContentTemplate).exists()).toBe(true); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-files.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-files.spec.ts new file mode 100644 index 0000000..ea4b070 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-files.spec.ts @@ -0,0 +1,59 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount, flushPromises } from '@vue/test-utils'; +import {describe, it, expect, vi } from 'vitest'; +import CommonCard from '@/components/common-card.vue'; +import LogsFiles from '@/pages/log-pages/logs-files.vue'; + +describe('LogsFiles', () => { + const wrapper = mount(LogsFiles, { + global: { + stubs: { + CommonCard: true, + TinyLink: true, + TinySearch: true, + BreadcrumbComponent: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(CommonCard).exists()).toBe(true); + }); +}) + +describe('LogsFiles.initScrollerItem', () => { + vi.mock('@/api/api', () => { + return { + ListLogsAPI: vi.fn().mockResolvedValue({ + 'message': '', + 'data': { + 'dggphis232339-189755': [ + 'runtime-10050000-0000-4000-b00f-8374b8dd2508-00000000009d.out', + 'runtime-10050000-0000-4000-b00f-8374b8dd2508-00000000009d.err', + ]}}), + } + }) + + const wrapper = mount(LogsFiles); + const vm = wrapper.vm; + it('renders when initScrollerItem correctly', async () => { + vm.initScrollerItem(); + await flushPromises(); + expect(wrapper.text()).toContain('Log Files'); + }); +}) diff --git a/yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-nodes.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-nodes.spec.ts new file mode 100644 index 0000000..45b6bce --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/log-pages/logs-nodes.spec.ts @@ -0,0 +1,69 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import CommonCard from '@/components/common-card.vue'; +import LogsNodes from '@/pages/log-pages/logs-nodes.vue'; + +describe('LogsNodes', () => { + const wrapper = mount(LogsNodes, { + global: { + stubs: { + CommonCard: true, + TinyLink: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(CommonCard).exists()).toBe(true); + }); +}) + +describe('LogsNodes.initScrollerItem', () => { + vi.mock('@/api/api', () => { + return { + GetCompAPI: vi.fn().mockResolvedValue({ + 'data': { + 'nodes': [{ + 'hostname': 'dggphis232339-189755', + 'status': 'healthy', + 'address': '7.185.105.138', + 'cap_cpu': 10000, + 'cap_mem': 38912, + 'alloc_cpu': 10000, + 'alloc_mem': 38912, + 'alloc_npu': 0 + }]}}), + ListLogsAPI: vi.fn().mockResolvedValue({ + 'data': { + 'dggphis232339-189755': [ + 'runtime-10050000-0000-4000-b00f-8374b8dd2508-00000000009d.out', + 'runtime-10050000-0000-4000-b00f-8374b8dd2508-00000000009d.err', + ]}}), + } + }) + + const wrapper = mount(LogsNodes); + const vm = wrapper.vm; + it('renders when initScrollerItem correctly', async () => { + await vm.initNodesObj(); + await vm.initScrollerItem(); + expect(wrapper.text()).toBe('Logs Select a node to view logs:' + + 'dggphis232339-189755(IP:7.185.105.138)dggphis232339-189755(IP:7.185.105.138)'); + }); +}) diff --git a/yuanrong/pkg/dashboard/client/tests/pages/overview/components/cluster-card.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/overview/components/cluster-card.spec.ts new file mode 100644 index 0000000..173bbbc --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/overview/components/cluster-card.spec.ts @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount, flushPromises } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import { TinyLink } from '@opentiny/vue'; +import { GetRsrcSummAPI } from '@/api/api'; +import CommonCard from '@/components/common-card.vue'; +import ClusterCard from '@/pages/overview/components/cluster-card.vue'; + +describe('ClusterCard', () => { + const wrapper = mount(ClusterCard, { + global: { + stubs: { + CommonCard: true, + TinyLink: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(CommonCard).exists()).toBe(true); + }); +}) + +describe('ClusterCard.TinyLink', async () => { + const wrapper = mount(TinyLink, { + props: { href: '#/cluster' }, + slots: { + default: 'View All Cluster Status >', + }, + }); + await wrapper.trigger('click'); + it('renders TinyLink correctly', () => { + expect(wrapper.emitted('click')).toHaveLength(1); + expect(wrapper.text()).toBe('View All Cluster Status >'); + }); +}) + +describe('ClusterCard.initStatus', () => { + vi.mock('@/api/api', () => ({ + GetRsrcSummAPI: vi.fn(), + })); + + const wrapper = mount(ClusterCard); + const vm = wrapper.vm; + it('renders when initStatus error', async () => { + GetRsrcSummAPI.mockRejectedValue('error'); + vm.initStatus(); + await flushPromises(); + expect(wrapper.text()).toBe('Cluster Status Total: 0 nodeAlive: 0 nodeView All Cluster Status >'); + }); + it('renders when initStatus return empty', async () => { + GetRsrcSummAPI.mockResolvedValue(''); + vm.initStatus(); + await flushPromises(); + expect(wrapper.text()).toBe('Cluster Status Total: 0 nodeAlive: 0 nodeView All Cluster Status >'); + }); + it('renders when initStatus correctly', async () => { + GetRsrcSummAPI.mockResolvedValue({data: {proxy_num: 3}}); + vm.initStatus(); + await flushPromises(); + expect(wrapper.text()).toBe('Cluster Status Total: 3 nodeAlive: 3 nodeView All Cluster Status >'); + }); +}) diff --git a/yuanrong/pkg/dashboard/client/tests/pages/overview/components/instances-card.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/overview/components/instances-card.spec.ts new file mode 100644 index 0000000..d990915 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/overview/components/instances-card.spec.ts @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { mount, flushPromises } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import { TinyLink } from '@opentiny/vue'; +import { GetInstSummAPI } from '@/api/api'; +import CommonCard from '@/components/common-card.vue'; +import InstancesCard from '@/pages/overview/components/instances-card.vue'; + +describe('InstancesCard', () => { + const wrapper = mount(InstancesCard, { + global: { + stubs: { + CommonCard: true, + TinyLink: true, + }, + }, + }); + + it('renders the main structure correctly', () => { + expect(wrapper.findComponent(CommonCard).exists()).toBe(true); + }); +}) + +describe('InstancesCard.TinyLink', async () => { + const wrapper = mount(TinyLink, { + props: { href: '#/instances' }, + slots: { + default: 'View All Instances >', + }, + }); + await wrapper.trigger('click'); + it('renders TinyLink correctly', () => { + expect(wrapper.emitted('click')).toHaveLength(1); + expect(wrapper.text()).toBe('View All Instances >'); + }); +}) + +describe('InstancesCard.initStatus', () => { + vi.mock('@/api/api', () => ({ + GetInstSummAPI: vi.fn(), + })); + + const wrapper = mount(InstancesCard); + const vm = wrapper.vm; + it('renders when initStatus error', async () => { + GetInstSummAPI.mockResolvedValue(''); + vm.initStatus(); + await flushPromises(); + expect(wrapper.text()).toBe('Instances Total: 0Running: 0Exited: 0Fatal: 0Others: 0View All Instances >'); + }); + it('renders when initStatus correctly', async () => { + GetInstSummAPI.mockResolvedValue({ + data: { + total: 5, + running: 2, + exited: 1, + fatal: 0, + }}); + vm.initStatus(); + await flushPromises(); + expect(wrapper.text()).toBe('Instances Total: 5Running: 2Exited: 1Fatal: 0Others: 2View All Instances >'); + }); +}) diff --git a/yuanrong/pkg/dashboard/client/tests/pages/overview/components/resources-card.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/overview/components/resources-card.spec.ts new file mode 100644 index 0000000..ee02514 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/overview/components/resources-card.spec.ts @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { shallowMount } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import ResourcesCard from '@/pages/overview/components/resources-card.vue'; + +describe('ResourcesCard', () => { + const wrapper = shallowMount(ResourcesCard); + vi.mock('@opentiny/vue-huicharts', () => ({ + default: { + name: 'TinyHuichartsGauge', + template: '
Canvas Mock
', + }, + })); + + vi.mock('@/api/api', () => ({ + GetRsrcSummAPI: vi.fn().mockResolvedValue({ + 'data': { + 'cap_cpu': 10000, + 'cap_mem': 38912, + 'alloc_cpu': 10000, + 'alloc_mem': 38912, + }}), + })); + + it('renders the main structure correctly', async () => { + expect(wrapper.text()).toBe(''); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/tests/pages/overview/overview-layout.spec.ts b/yuanrong/pkg/dashboard/client/tests/pages/overview/overview-layout.spec.ts new file mode 100644 index 0000000..5a1ba59 --- /dev/null +++ b/yuanrong/pkg/dashboard/client/tests/pages/overview/overview-layout.spec.ts @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { shallowMount } from '@vue/test-utils'; +import { describe, it, expect, vi } from 'vitest'; +import OverviewLayout from '@/pages/overview/overview-layout.vue'; + +describe('OverviewLayout', () => { + const wrapper = shallowMount(OverviewLayout); + vi.mock('@opentiny/vue-huicharts', () => ({ + default: { + name: 'TinyHuichartsGauge', + template: '
Canvas Mock
', + }, + })); + + it('renders the main structure correctly', async () => { + expect(wrapper.text()).toBe(''); + }); +}) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/client/vite.config.js b/yuanrong/pkg/dashboard/client/vite.config.js new file mode 100644 index 0000000..76571fa --- /dev/null +++ b/yuanrong/pkg/dashboard/client/vite.config.js @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { fileURLToPath } from 'node:url'; +import path from 'path'; +import vue from '@vitejs/plugin-vue'; +import { mergeConfig } from 'vite'; +import { configDefaults, defineConfig } from 'vitest/config'; +import viteConfig from './vite.config'; +import importPlugin from '@opentiny/vue-vite-import'; + +export default mergeConfig( + viteConfig, + defineConfig({ + plugins: [ + vue(), + importPlugin( + [ + { + libraryName: '@opentiny/vue' + }, + { + libraryName: `@opentiny/vue-icon`, + customName: (name) => { + return `@opentiny/vue-icon/lib/${name.replace(/^icon-/, '')}.js` + } + } + ], + 'pc' + ) + ], + resolve: { + alias: { + '@': path.resolve(__dirname, './src'), + }, + }, + test: { + environment: 'jsdom', + exclude: [...configDefaults.exclude, 'e2e/*'], + root: fileURLToPath(new URL('./', import.meta.url)), + transformMode: { + web: [/\.[jt]sx$/] + }, + coverage: { + provider: 'v8', + reporter: ['text', 'json', 'html'], + }, + deps: { + inline: [ + '@opentiny/vue', + '@opentiny/vue-renderless', + '@opentiny/vue-common', + '@opentiny/vue-icon', + '@opentiny/vue-theme', + ] + }, + threads: false, + environmentOptions: { + jsdom: { + resources: 'usable' + } + } + } + }) +) \ No newline at end of file diff --git a/yuanrong/pkg/dashboard/etcdcache/instance_cache.go b/yuanrong/pkg/dashboard/etcdcache/instance_cache.go new file mode 100644 index 0000000..beb2588 --- /dev/null +++ b/yuanrong/pkg/dashboard/etcdcache/instance_cache.go @@ -0,0 +1,211 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package etcdcache caches all etcd events listened from etcd +package etcdcache + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/etcdkey" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" +) + +var ( + // InstanceCache is the cache struct instance for function system instances + InstanceCache instanceCache +) + +func init() { + InstanceCache = instanceCache{ + id2Instance: map[string]*types.InstanceSpecification{}, + jobID2Instances: map[string]map[string]*types.InstanceSpecification{}, + runtimeID2Instance: map[string]*types.InstanceSpecification{}, + } +} + +type instanceCache struct { + // instanceID => instanceInfo + id2Instance map[string]*types.InstanceSpecification + jobID2Instances map[string]map[string]*types.InstanceSpecification + runtimeID2Instance map[string]*types.InstanceSpecification + iMtx sync.RWMutex + + instanceExitHandler func(instance *types.InstanceSpecification) + instanceStartHandler func(instance *types.InstanceSpecification) +} + +// Put an instance +func (c *instanceCache) Put(instance *types.InstanceSpecification) { + log.GetLogger().Infof("Put instance: %s %#v with runtime id %s", instance.InstanceID, + instance.InstanceStatus.Code, instance.RuntimeID) + c.iMtx.Lock() + defer c.iMtx.Unlock() + c.id2Instance[instance.InstanceID] = instance + if _, ok := c.jobID2Instances[instance.JobID]; ok { + c.jobID2Instances[instance.JobID][instance.InstanceID] = instance + } else { + c.jobID2Instances[instance.JobID] = map[string]*types.InstanceSpecification{instance.InstanceID: instance} + } + c.runtimeID2Instance[instance.RuntimeID] = instance + + if c.instanceStartHandler != nil && instance.InstanceStatus.Code == int32(constant.KernelInstanceStatusRunning) { + log.GetLogger().Infof("instance %s started", instance.InstanceID) + go c.instanceStartHandler(instance) + } +} + +// Remove an instance +func (c *instanceCache) Remove(instanceID string) { + c.iMtx.Lock() + defer c.iMtx.Unlock() + inst, ok := c.id2Instance[instanceID] + if ok { + delete(c.jobID2Instances, inst.JobID) + delete(c.runtimeID2Instance, inst.RuntimeID) // Assume no restart... or it will be left here + } + delete(c.id2Instance, instanceID) + if c.instanceExitHandler != nil && inst != nil { + go c.instanceExitHandler(inst) + } +} + +// RegisterInstanceExitHandler an instance +func (c *instanceCache) RegisterInstanceExitHandler(handler func(instance *types.InstanceSpecification)) { + c.iMtx.Lock() + defer c.iMtx.Unlock() + if c.instanceExitHandler == nil { // in case someone re-register the handler + log.GetLogger().Infof("instance exit handler registered") + c.instanceExitHandler = handler + } +} + +// RegisterInstanceStartHandler an instance +func (c *instanceCache) RegisterInstanceStartHandler(handler func(instance *types.InstanceSpecification)) { + c.iMtx.Lock() + defer c.iMtx.Unlock() + if c.instanceStartHandler == nil { // in case someone re-register the handler + log.GetLogger().Infof("instance start handler registered") + c.instanceStartHandler = handler + } +} + +// Get an instance by instanceID +func (c *instanceCache) Get(instanceID string) *types.InstanceSpecification { + c.iMtx.RLock() + defer c.iMtx.RUnlock() + log.GetLogger().Infof("Get an instance by id: %s", instanceID) + return c.id2Instance[instanceID] +} + +// GetByJobID a map of instance +func (c *instanceCache) GetByJobID(jobID string) map[string]*types.InstanceSpecification { + c.iMtx.RLock() + defer c.iMtx.RUnlock() + return c.jobID2Instances[jobID] +} + +// GetByParentID a map of instance +func (c *instanceCache) GetByParentID(ParentInstanceID string) map[string]*types.InstanceSpecification { + c.iMtx.RLock() + defer c.iMtx.RUnlock() + for _, inst := range c.id2Instance { + if inst.ParentID == ParentInstanceID { + return c.jobID2Instances[inst.JobID] + } + } + return map[string]*types.InstanceSpecification{} +} + +// GetByRuntimeID a instance +func (c *instanceCache) GetByRuntimeID(runtimeID string) *types.InstanceSpecification { + c.iMtx.RLock() + defer c.iMtx.RUnlock() + return c.runtimeID2Instance[runtimeID] +} + +// String - +func (c *instanceCache) String() string { + c.iMtx.RLock() + defer c.iMtx.RUnlock() + return fmt.Sprintf("cache:{%#v}", c.id2Instance) +} + +// StartWatchInstance to watch the instance faas schedulers by the etcd +func StartWatchInstance(stopCh <-chan struct{}) { + etcdClient := etcd3.GetRouterEtcdClient() + // no filter, always return false + watcher := etcd3.NewEtcdWatcher(constant.InstancePathPrefix, func(_ *etcd3.Event) bool { return false }, + instanceHandler, stopCh, etcdClient) + watcher.StartWatch() +} + +// SyncAllInstances will get all instances from etcd +func SyncAllInstances() error { + etcdClient := etcd3.GetRouterEtcdClient() + log.GetLogger().Infof("etcdclient: %v", etcdClient) + getResponse, err := etcdClient.Get( + etcd3.CreateEtcdCtxInfoWithTimeout(context.Background(), etcd3.DurationContextTimeout), + constant.InstancePathPrefix, + clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("failed to get all instance info, %s", err.Error()) + return err + } + for _, kv := range getResponse.Kvs { + instance := &types.InstanceSpecification{} + err := json.Unmarshal(kv.Value, instance) + if err != nil { + log.GetLogger().Warnf("failed to unmarshal synced kv's value (key: %s)", kv.Key) + return err + } + InstanceCache.Put(instance) + } + return nil +} + +func instanceHandler(event *etcd3.Event) { + log.GetLogger().Debugf("handling instance event type %d, key:%s", event.Type, event.Key) + + switch event.Type { + case etcd3.PUT: + instance := &types.InstanceSpecification{} + err := json.Unmarshal(event.Value, instance) + if err != nil { + log.GetLogger().Warnf("failed to unmarshal watched event's value (key: %s)", event.Key) + return + } + InstanceCache.Put(instance) + case etcd3.DELETE: + etcdInstanceKey := etcdkey.FunctionInstanceKey{} + err := etcdInstanceKey.ParseFrom(event.Key) + if err != nil { + log.GetLogger().Warnf("failed to unmarshal watched event's key: %s, err: %s", event.Key, err) + return + } + InstanceCache.Remove(etcdInstanceKey.InstanceID) + default: + log.GetLogger().Debugf("unsupported event: %#v", event) + } +} diff --git a/yuanrong/pkg/dashboard/etcdcache/instance_cache_test.go b/yuanrong/pkg/dashboard/etcdcache/instance_cache_test.go new file mode 100644 index 0000000..ea7ea14 --- /dev/null +++ b/yuanrong/pkg/dashboard/etcdcache/instance_cache_test.go @@ -0,0 +1,49 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package etcdcache + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/types" +) + +func TestGetJobIDByParentID(t *testing.T) { + convey.Convey("Test GetJobIDByParentID:", t, func() { + convey.Convey("Test InstancesByInstanceID when instances is empty", func() { + instancesMap := InstanceCache.GetByParentID("app-123") + convey.So(len(instancesMap), convey.ShouldEqual, 0) + }) + convey.Convey("Test InstancesByInstanceID success", func() { + instance := &types.InstanceSpecification{InstanceID: "id-123", ParentID: "app-123", JobID: "job-123"} + InstanceCache.id2Instance[instance.InstanceID] = instance + InstanceCache.jobID2Instances[instance.JobID] = map[string]*types.InstanceSpecification{ + instance.InstanceID: instance} + instancesMap := InstanceCache.GetByParentID("app-123") + convey.So(instancesMap[instance.InstanceID].InstanceID, convey.ShouldEqual, instance.InstanceID) + instancesMap = InstanceCache.GetByJobID(instance.JobID) + convey.So(instancesMap[instance.InstanceID].InstanceID, convey.ShouldEqual, instance.InstanceID) + newInstance := InstanceCache.Get(instance.InstanceID) + convey.So(newInstance.InstanceID, convey.ShouldEqual, instance.InstanceID) + str := InstanceCache.String() + convey.So(str, convey.ShouldContainSubstring, instance.InstanceID) + }) + }) +} diff --git a/yuanrong/pkg/dashboard/flags/flags.go b/yuanrong/pkg/dashboard/flags/flags.go new file mode 100644 index 0000000..11b59c2 --- /dev/null +++ b/yuanrong/pkg/dashboard/flags/flags.go @@ -0,0 +1,229 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package flags for obtain command line params +package flags + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/asaskevich/govalidator/v11" + "github.com/spf13/cobra" + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/reader" +) + +const ( + dashboardLogFileName = "dashboard" + defaultDashboardPort = 9080 + + defaultEtcdOpTimeout = 60 * time.Second + dashboardRegisterKeyKeepAliveInterval = 10 * time.Second + dashboardRegisterKeyTTL = 60 + dashboardRegisterKey = "/yr/dashboard" +) + +type dashboardConfig struct { + StaticPath string `json:"staticPath"` + Ip string `json:"ip" valid:"ip"` + GrpcIP string `json:"grpcIP" valid:"ip"` + Port int `json:"port" valid:"port"` + GrpcPort int `json:"grpcPort" valid:"port"` + FunctionMasterAddr string `json:"functionMasterAddr" valid:"url"` + FrontendAddr string `json:"frontendAddr" valid:"url"` + PrometheusAddr string `json:"prometheusAddr" valid:"url"` + RouterEtcdConfig etcd3.EtcdConfig `json:"routerEtcdConfig"` + MetaEtcdConfig etcd3.EtcdConfig `json:"metaEtcdConfig"` + + ServerAddr string +} + +const ( + // frontend jobs instance url + defaultListenIP = "0.0.0.0" + appBasePath = "/app/v1" +) + +var ( + // DashboardConfig is the global config struct + DashboardConfig dashboardConfig + + dashboardConfigPath string + dashboardLogConfigPath string +) + +func addHTTPPrefix(url string) string { + if !strings.Contains(url, "://") { + url = "http://" + url + } + return url +} + +func initLog() error { + if err := log.InitRunLog(dashboardLogFileName, true); err != nil { + log.GetLogger().Errorf("failed to init dashboard log, err: %s", err.Error()) + } + return nil +} + +func initConfig(configFilePath string) error { + // InitConfig get config info from configPath + data, err := reader.ReadFileWithTimeout(configFilePath) + if err != nil { + log.GetLogger().Errorf("failed to read config, filename: %s, error: %s", configFilePath, err.Error()) + return err + } + + err = json.Unmarshal(data, &DashboardConfig) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal config, configPath: %s, error: %s", configFilePath, err.Error()) + return err + } + + _, err = govalidator.ValidateStruct(DashboardConfig) + if err != nil { + log.GetLogger().Errorf("failed to validate config, err: %s", err.Error()) + return err + } + return nil +} + +func init() { + cCmd := cobraCmd() + initWithConfigFiles(cCmd) + initWithParams(cCmd) + + log.GetLogger().Infof("init done") +} + +// InitEtcdClient - +func InitEtcdClient() error { + stopCh := make(chan struct{}, 1) + log.GetLogger().Infof("init etcd client...") + return etcd3.InitParam(). + WithRouteEtcdConfig(DashboardConfig.RouterEtcdConfig). + WithMetaEtcdConfig(DashboardConfig.MetaEtcdConfig). + WithStopCh(stopCh). + InitClient() +} + +// RegisterSelfToEtcd - +func RegisterSelfToEtcd(stopCh <-chan struct{}) error { + log.GetLogger().Infof("registering self to etcd") + if stopCh == nil { + return errors.New("etcd stopCh should not be nil") + } + leaseID, err := etcd3.GetRouterEtcdClient().Grant(etcd3.CreateEtcdCtxInfoWithTimeout(context.TODO(), + defaultEtcdOpTimeout), dashboardRegisterKeyTTL) + if err != nil { + log.GetLogger().Infof("lease: %d", leaseID) + return err + } + go func() { + for { + select { + case <-stopCh: + log.GetLogger().Warnf("stopCh signal recevied, exit the keep alive channel") + return + default: + err := etcd3.GetRouterEtcdClient().KeepAliveOnce(etcd3.CreateEtcdCtxInfoWithTimeout(context.TODO(), + dashboardRegisterKeyKeepAliveInterval), leaseID) + if err != nil { + log.GetLogger().Errorf("failed to revoke self registry in etcd: %s", err) + } + time.Sleep(dashboardRegisterKeyKeepAliveInterval) + } + } + }() + log.GetLogger().Infof("registering self to etcd with key and v: %s with lease: %d", DashboardConfig.ServerAddr, + leaseID) + err = etcd3.GetRouterEtcdClient().Put(etcd3.CreateEtcdCtxInfoWithTimeout(context.TODO(), defaultEtcdOpTimeout), + dashboardRegisterKey, fmt.Sprintf("%s:%d", DashboardConfig.GrpcIP, DashboardConfig.GrpcPort), + clientv3.WithLease(leaseID)) + if err != nil { + return err + } + return nil +} + +func initWithConfigFiles(cCmd *cobra.Command) { + cmdErr(cCmd.Execute()) + if err := initLog(); err != nil { + log.GetLogger().Errorf("failed to init dashboard log, err: %s", err.Error()) + } + if err := initConfig(dashboardConfigPath); err != nil { + log.GetLogger().Errorf("failed to init dashboard config: %s", err) + } +} + +func initWithParams(cCmd *cobra.Command) { + cmdErr(cCmd.Execute()) + DashboardConfig.FunctionMasterAddr = addHTTPPrefix(DashboardConfig.FunctionMasterAddr) + DashboardConfig.FrontendAddr = addHTTPPrefix(DashboardConfig.FrontendAddr) + appBasePath + DashboardConfig.PrometheusAddr = addHTTPPrefix(DashboardConfig.PrometheusAddr) + ip := DashboardConfig.Ip + if len(ip) == 0 { + ip = defaultListenIP + } + DashboardConfig.ServerAddr = fmt.Sprintf("%s:%d", ip, DashboardConfig.Port) +} + +func cobraCmd() *cobra.Command { + cCmd := &cobra.Command{ + Long: "This params for starting dashboard server.", + Run: func(cmd *cobra.Command, args []string) {}, + } + + // if config path, use config path + cCmd.Flags().StringVarP(&dashboardConfigPath, "config_path", "", "", "config file path") + cCmd.Flags().StringVarP(&dashboardLogConfigPath, "log_config_path", "", "", "log config file path") + + // if command line, use command line + cCmd.Flags().StringVarP(&DashboardConfig.FunctionMasterAddr, "function_master_addr", "", "", + "FunctionMasterURL format is :") + cCmd.Flags().StringVarP(&DashboardConfig.FrontendAddr, "frontend_addr", "", "", + "FrontendURL format is :") + cCmd.Flags().StringVarP(&DashboardConfig.PrometheusAddr, "prometheus_addr", "", "", + "PrometheusURL format is :") + cCmd.Flags().StringVarP(&DashboardConfig.Ip, "ip", "i", "0.0.0.0", "this service listening with this ip") + cCmd.Flags().IntVarP(&DashboardConfig.Port, "port", "p", defaultDashboardPort, "this service listening with this port") + cCmd.Flags().StringVarP(&DashboardConfig.StaticPath, "static_path", "s", "./client/dist", + "this is client static resources path") + return cCmd +} + +func cmdErr(err error) { + if err != nil { + log.GetLogger().Fatal(err.Error()) + } +} + +func loadConfig(configPath string, loadFunc func(configPath string) error) { + if loadFunc == nil { + return + } + if err := loadFunc(configPath); err != nil { + log.GetLogger().Fatal(err.Error()) + } +} diff --git a/yuanrong/pkg/dashboard/flags/flags_test.go b/yuanrong/pkg/dashboard/flags/flags_test.go new file mode 100644 index 0000000..fa2c5ec --- /dev/null +++ b/yuanrong/pkg/dashboard/flags/flags_test.go @@ -0,0 +1,114 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package flags for obtain command line params +package flags + +import ( + "os" + "testing" + + "github.com/agiledragon/gomonkey" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/reader" +) + +func TestInitConfigNoFile(t *testing.T) { + convey.Convey("Test initConfig:", t, func() { + dashboardConfigPath = "config.json" + convey.Convey("init config when no file", func() { + err := initConfig(dashboardConfigPath) + convey.So(os.IsNotExist(err), convey.ShouldBeTrue) + }) + convey.Convey("init config when format error", func() { + cofPatches := gomonkey.ApplyFunc(reader.ReadFileWithTimeout, func(configFile string) ([]byte, error) { + configStr := `{ + "ip": "0.0.0.0", + "port": "9080" + }` + return []byte(configStr), nil + }) + defer cofPatches.Reset() + err := initConfig(dashboardConfigPath) + convey.So(err.Error(), convey.ShouldContainSubstring, "cannot unmarshal") + }) + convey.Convey("init config when validate error", func() { + cofPatches := gomonkey.ApplyFunc(reader.ReadFileWithTimeout, func(configFile string) ([]byte, error) { + configStr := `{ + "ip": "0.0.0.0", + "port": 9080, + "functionMasterAddr": "0.0.0.0:1234" + }` + return []byte(configStr), nil + }) + defer cofPatches.Reset() + err := initConfig(dashboardConfigPath) + convey.So(err.Error(), convey.ShouldContainSubstring, "0.0.0.0:1234 does not validate as url") + }) + }) +} + +func TestLoadConfig(t *testing.T) { + convey.Convey("Test loadConfig:", t, func() { + convey.Convey("load config success", func() { + cofPatches := gomonkey.ApplyFunc(reader.ReadFileWithTimeout, func(configFile string) ([]byte, error) { + configStr := `{ + "ip": "0.0.0.0", + "port": 9080, + "functionMasterAddr": "127.0.0.1:1234", + "frontendAddr": "127.0.0.1:8888", + "prometheusAddr": "127.0.0.1:9090", + "routerEtcdConfig": { + "servers": ["127.0.0.1:5678"], + "sslEnable": false, + "authType": "NOAUTH" + }, + "metaEtcdConfig": { + "servers": ["127.0.0.1:5678"], + "sslEnable": false, + "authType": "NOAUTH" + } + }` + return []byte(configStr), nil + }) + defer cofPatches.Reset() + convey.So(func() { + loadConfig(dashboardConfigPath, initConfig) + }, convey.ShouldNotPanic) + }) + convey.Convey("load config when loadFunc is nil", func() { + convey.So(func() { + loadConfig(dashboardConfigPath, nil) + }, convey.ShouldNotPanic) + }) + }) +} + +func TestLoadLogConfig(t *testing.T) { + convey.Convey("Test loadLogConfig:", t, func() { + dashboardLogConfigPath = "log.json" + logPatches := gomonkey.ApplyFunc(log.InitRunLog, func(fileName string, async bool) error { + return nil + }) + defer logPatches.Reset() + + convey.Convey("load log config success", func() { + convey.So(initLog(), convey.ShouldBeNil) + }) + }) +} diff --git a/yuanrong/pkg/dashboard/getinfo/client_pool.go b/yuanrong/pkg/dashboard/getinfo/client_pool.go new file mode 100644 index 0000000..f7c51a9 --- /dev/null +++ b/yuanrong/pkg/dashboard/getinfo/client_pool.go @@ -0,0 +1,95 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package getinfo for get function-master info +package getinfo + +import ( + "io" + "net/http" + "time" + + "github.com/prometheus/client_golang/api" + prometheusv1 "github.com/prometheus/client_golang/api/prometheus/v1" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/dashboard/flags" +) + +// HttpClient is client pool +var HttpClient *http.Client + +// PromClient is prometheus client pool +var PromClient prometheusv1.API + +const reqType = "protobuf" + +func init() { + HttpClient = &http.Client{ + Timeout: 30 * time.Minute, // 连接超时时间 + Transport: &http.Transport{ + MaxIdleConns: 10, // 最大空闲连接数 + MaxIdleConnsPerHost: 5, // 每个主机最大空闲连接数 + MaxConnsPerHost: 10, // 每个主机最大连接数 + IdleConnTimeout: 30 * time.Second, // 空闲连接的超时时间 + TLSHandshakeTimeout: 30 * time.Minute, // 限制TLS握手的时间 + }, + } + InitPromClient() +} + +// InitPromClient - +func InitPromClient() { + apiClient, promClientErr := api.NewClient(api.Config{ + Address: flags.DashboardConfig.PrometheusAddr, + Client: HttpClient, + }) + if promClientErr != nil { + log.GetLogger().Errorf("failed to connect prometheus, error: %s", promClientErr.Error()) + } else { + PromClient = prometheusv1.NewAPI(apiClient) + } +} + +func requestFunctionMaster(path string) ([]byte, error) { + req, err := http.NewRequest("GET", flags.DashboardConfig.FunctionMasterAddr+path, nil) + if err != nil { + return nil, err + } + req.Header.Set("Type", reqType) + return handleRes(req) +} + +func requestFrontend(method string, path string, reqBody io.Reader) ([]byte, error) { + req, err := http.NewRequest(method, flags.DashboardConfig.FrontendAddr+path, reqBody) + if err != nil { + return nil, err + } + return handleRes(req) +} + +func handleRes(req *http.Request) ([]byte, error) { + resp, err := HttpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + return body, nil +} diff --git a/yuanrong/pkg/dashboard/getinfo/frontend_app.go b/yuanrong/pkg/dashboard/getinfo/frontend_app.go new file mode 100644 index 0000000..7574701 --- /dev/null +++ b/yuanrong/pkg/dashboard/getinfo/frontend_app.go @@ -0,0 +1,79 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package getinfo for get/post frontend +package getinfo + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/job" +) + +const ( + creatAppPath = "/posix/instance/create" + deleteAppPath = "/delete/" + getAppInfoPath = "/getappinfo/" + listAppsPath = "/list" + stopAppPath = "/posix/kill/" +) + +// CreateApp function for create app +func CreateApp(reqBytes []byte) job.Response { + return reqForJob("POST", creatAppPath, bytes.NewBuffer(reqBytes)) +} + +// DeleteApp function for delete app +func DeleteApp(submissionID string) job.Response { + return reqForJob("DELETE", deleteAppPath+submissionID, nil) +} + +// GetAppInfo function for get app info by submissionID +func GetAppInfo(submissionID string) job.Response { + return reqForJob("GET", getAppInfoPath+submissionID, nil) +} + +// ListApps function for list apps info +func ListApps() job.Response { + return reqForJob("GET", listAppsPath, nil) +} + +// StopApp function for stop app +func StopApp(submissionID string) job.Response { + return reqForJob("POST", stopAppPath+submissionID, nil) +} + +func reqForJob(method string, path string, reqBody io.Reader) job.Response { + var result job.Response + respBytes, err := requestFrontend(method, path, reqBody) + if err != nil { + log.GetLogger().Errorf("request to fronted failed, err: %v", err) + return job.BuildJobResponse(nil, http.StatusBadRequest, + fmt.Errorf("request to fronted failed, err: %v", err)) + } + err = json.Unmarshal(respBytes, &result) + if err != nil { + log.GetLogger().Errorf("unmarshal response from fronted failed, err: %v", err) + return job.BuildJobResponse(nil, http.StatusBadRequest, + fmt.Errorf("unmarshal response from fronted failed, err: %v", err)) + } + return result +} diff --git a/yuanrong/pkg/dashboard/getinfo/get_instances.go b/yuanrong/pkg/dashboard/getinfo/get_instances.go new file mode 100644 index 0000000..3d46606 --- /dev/null +++ b/yuanrong/pkg/dashboard/getinfo/get_instances.go @@ -0,0 +1,39 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package getinfo for get function-master info +package getinfo + +import ( + "google.golang.org/protobuf/proto" + + "yuanrong/pkg/common/faas_common/grpc/pb/message" + "yuanrong/pkg/common/faas_common/grpc/pb/resource" +) + +// GetInstances function for get instances +func GetInstances(instPath string) ([]*resource.InstanceInfo, error) { + body, err := requestFunctionMaster(instPath) + if err != nil { + return nil, err + } + var inst message.QueryInstancesInfoResponse + err = proto.Unmarshal(body, &inst) + if err != nil { + return nil, err + } + return inst.InstanceInfos, nil +} diff --git a/yuanrong/pkg/dashboard/getinfo/get_resources.go b/yuanrong/pkg/dashboard/getinfo/get_resources.go new file mode 100644 index 0000000..aa66afa --- /dev/null +++ b/yuanrong/pkg/dashboard/getinfo/get_resources.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package getinfo for get function-master info +package getinfo + +import ( + "google.golang.org/protobuf/proto" + + "yuanrong/pkg/common/faas_common/grpc/pb/message" + "yuanrong/pkg/common/faas_common/grpc/pb/resource" +) + +const rsrcPath = "/global-scheduler/resources" + +// GetResources function for get resources +func GetResources() (*resource.ResourceUnit, error) { + body, err := requestFunctionMaster(rsrcPath) + if err != nil { + return nil, err + } + var resources message.ResourceInfo + err = proto.Unmarshal(body, &resources) + if err != nil { + return nil, err + } + return resources.GetResource(), nil +} diff --git a/yuanrong/pkg/dashboard/handlers/cluster_status_handler.go b/yuanrong/pkg/dashboard/handlers/cluster_status_handler.go new file mode 100644 index 0000000..a7a74ba --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/cluster_status_handler.go @@ -0,0 +1,121 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "fmt" + "math" + "net/http" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/constants" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/dashboard/getinfo" +) + +const pendingInstPath = "/global-scheduler/scheduling_queue" + +var resourcesName = []string{"CPU", "NPU", "GPU", "Memory"} // enum + +var decimal = 100.0 + +// ClusterStatusHandler function for /api/cluster_status route +func ClusterStatusHandler(ctx *gin.Context) { + format := ctx.Query("format") + if format != "1" { + ctx.JSON(http.StatusBadRequest, gin.H{ + "result": false, + "msg": "Please check request params.", + "data": "Resources:\nDemands:", + }) + log.GetLogger().Errorf("/api/cluster_status failed, format=%s", format) + return + } + pbInfos, err := getinfo.GetInstances(pendingInstPath) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "result": false, + "msg": "Failed to get formatted cluster status. Error: " + err.Error(), + "data": "Resources:\nDemands:", + }) + log.GetLogger().Errorf("/api/cluster_status failed, %d %v", errGetInstances, err) + return + } + + resourcesMap := initResourcesMap() + for _, pbInfo := range pbInfos { + if pbInfo.Resources == nil { + continue + } + for name, rsrc := range pbInfo.Resources.Resources { + if rsrc.Scalar.Value == 0 || !validateName(name) { + continue + } + resourcesMap[name][rsrc.Scalar.Value]++ + } + } + resStr := formatResourcesMap(resourcesMap) + + ctx.JSON(http.StatusOK, gin.H{ + "result": true, + "msg": "Got formatted cluster status.", + "data": map[string]string{ + "clusterStatus": "Resources:\nDemands:" + resStr, + }, + }) + log.GetLogger().Debugf("/api/cluster_status succeed") +} + +func initResourcesMap() map[string]map[float64]int { + resourcesMap := make(map[string]map[float64]int) + for _, name := range resourcesName { + resourcesMap[name] = make(map[float64]int) + } + return resourcesMap +} + +func validateName(newName string) bool { + for _, name := range resourcesName { + if newName == name { + return true + } + } + return false +} + +func formatResourcesMap(resourcesMap map[string]map[float64]int) string { + var resStr string + for name, m := range resourcesMap { + if name == "Memory" { + name = "memory" + } + for value, count := range m { + if name == "memory" { + value = value * constants.MemoryUnitConvert * constants.MemoryUnitConvert + } + if name == "CPU" { + value /= constants.CpuUnitConvert + } + + value = math.Ceil(value*decimal) / decimal + resStr += fmt.Sprintf("\n{'%s':%.2f}: %d+ pending tasks/actors", name, value, count) + } + } + return resStr +} diff --git a/yuanrong/pkg/dashboard/handlers/cluster_status_handler_test.go b/yuanrong/pkg/dashboard/handlers/cluster_status_handler_test.go new file mode 100644 index 0000000..56c3f0a --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/cluster_status_handler_test.go @@ -0,0 +1,165 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + "google.golang.org/protobuf/proto" + + "yuanrong/pkg/common/faas_common/grpc/pb/message" + "yuanrong/pkg/common/faas_common/grpc/pb/resource" + "yuanrong/pkg/dashboard/flags" +) + +func TestClusterStatusHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + convey.Convey("Test ClusterStatusHandler:", t, func() { + resources1 := resource.Resources{ + Resources: map[string]*resource.Resource{ + "CPU": { + Name: "CPU", + Scalar: &resource.Value_Scalar{ + Value: 1.57, + }, + }, + "Memory": { + Name: "Memory", + Scalar: &resource.Value_Scalar{ + Value: 300, + }, + }, + "GPU": { + Name: "GPU", + Scalar: &resource.Value_Scalar{ + Value: 0, + }, + }, + "NPU": { + Name: "NPU", + Scalar: &resource.Value_Scalar{ + Value: 0, + }, + }, + }, + } + resources2 := resource.Resources{ + Resources: map[string]*resource.Resource{ + "CPU": { + Name: "CPU", + Scalar: &resource.Value_Scalar{ + Value: 0.23, + }, + }, + "Memory": { + Name: "Memory", + Scalar: &resource.Value_Scalar{ + Value: 300, + }, + }, + "memory": { + Name: "memory", + Scalar: &resource.Value_Scalar{ + Value: 300, + }, + }, + }, + } + instances := []*resource.InstanceInfo{ + &resource.InstanceInfo{Resources: &resources1}, + &resource.InstanceInfo{Resources: &resources2}, + &resource.InstanceInfo{}, + } + + instancesInfo := message.QueryInstancesInfoResponse{ + InstanceInfos: instances, + } + //message.QueryInstancesInfoResponse + instancesServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + instancesInfoBytes, _ := proto.Marshal(&instancesInfo) + w.WriteHeader(http.StatusOK) + w.Write(instancesInfoBytes) + })) + defer instancesServer.Close() + flags.DashboardConfig.FunctionMasterAddr = instancesServer.URL + + r := gin.Default() + r.GET("/api/cluster_status", ClusterStatusHandler) + req, err := http.NewRequest("GET", "/api/cluster_status?format=1", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + newResources := map[string]float64{ + "CPU": 0.0, + "GPU": 0.0, + "memory": 0.0, + "NPU": 0.0, + } + + convey.Convey("It should return status 200 and appropriate message", func() { + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + var res map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &res) + convey.So(err, convey.ShouldBeNil) + clusterStatus := res["data"].(map[string]interface{})["clusterStatus"].(string) + demands := strings.Split(clusterStatus, "Resources")[1] + infos := strings.Split(strings.Split(demands, "Demands")[1], "\n") + for _, info := range infos { + if !strings.Contains(info, "pending tasks/actors") { + continue + } + n := strings.LastIndexByte(info, ':') + var subInfo map[string]float64 + newInfo := strings.ReplaceAll(info, `'`, `"`) + err = json.Unmarshal([]byte(newInfo[:n]), &subInfo) + convey.So(err, convey.ShouldBeNil) + replicas, err := strconv.Atoi(strings.Split(info[n+2:], "+")[0]) + convey.So(err, convey.ShouldBeNil) + for k, v := range subInfo { + newResources[k] += float64(replicas) * v + } + } + convey.So(newResources["CPU"], convey.ShouldNotEqual, 0) + convey.So(newResources["GPU"], convey.ShouldEqual, 0) + }) + }) +} + +func TestClusterStatusHandlerError(t *testing.T) { + convey.Convey("Test ClusterStatusHandler error:", t, func() { + r := gin.Default() + r.GET("/api/cluster_status", ClusterStatusHandler) + req, err := http.NewRequest("GET", "/api/cluster_status?format=1", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.Convey("test ClusterStatusHandler when function master error", func() { + convey.So(w.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(w.Body.String(), convey.ShouldContainSubstring, "Failed to get formatted cluster status.") + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/components_componentid_handler.go b/yuanrong/pkg/dashboard/handlers/components_componentid_handler.go new file mode 100644 index 0000000..fee3a5b --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/components_componentid_handler.go @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// ComponentsByComponentIDHandler function for /components/:component-id route +func ComponentsByComponentIDHandler(ctx *gin.Context) { + componentId := ctx.Param("component-id") + ctx.JSON(http.StatusOK, gin.H{ + "code": 0, + "msg": "succeed", + "data": "Components by ComponentID:" + componentId + " json", + }) + log.GetLogger().Debugf("/components/%s succeed", componentId) +} diff --git a/yuanrong/pkg/dashboard/handlers/components_componentid_handler_test.go b/yuanrong/pkg/dashboard/handlers/components_componentid_handler_test.go new file mode 100644 index 0000000..36173df --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/components_componentid_handler_test.go @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" +) + +func TestComponentsByComponentIDHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + convey.Convey("Test ComponentsByComponentIDHandler:", t, func() { + r := gin.Default() + r.GET("/components/:component-id", ComponentsByComponentIDHandler) + req, err := http.NewRequest("GET", "/components/id-123", nil) + convey.So(err, convey.ShouldBeNil) + w := httptest.NewRecorder() + + convey.Convey("Test ComponentsByComponentID success", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "id-123") + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/components_handler.go b/yuanrong/pkg/dashboard/handlers/components_handler.go new file mode 100644 index 0000000..f17034f --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/components_handler.go @@ -0,0 +1,132 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/grpc/pb/resource" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/dashboard/getinfo" +) + +// ComponentsAddition for response +type ComponentsAddition struct { + Usage + Nodes []Component `form:"nodes" json:"nodes"` + Components `form:"components" json:"components"` +} + +// Components defines all Component +type Components map[string][]Component + +// Component ComponentInfo +type Component struct { + Hostname string `form:"hostname" json:"hostname"` + Status string `form:"status" json:"status"` + Address string `form:"address" json:"address"` + Usage +} + +// ComponentsHandler function for /components route +func ComponentsHandler(ctx *gin.Context) { + resource, err := getinfo.GetResources() + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "code": errGetResources, + "msg": "fail", + "data": "", + }) + log.GetLogger().Errorf("/components GetResources, %d, %s", errGetResources, err) + return + } + compAdd, err := PBToCompAdd(resource) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "code": errPBToCompAdd, + "msg": "fail", + "data": "", + }) + log.GetLogger().Errorf("/components PBToComponents %d %s", errPBToCompAdd, err) + return + } + ctx.JSON(http.StatusOK, gin.H{ + "code": 0, + "msg": "succeed", + "data": compAdd, + }) + log.GetLogger().Debugf("/components succeed, data: %#v", compAdd) +} + +// PBToCompAdd pb data switch to ComponentsAddition struct +func PBToCompAdd(resource *resource.ResourceUnit) (ComponentsAddition, error) { + var compAdd ComponentsAddition + var err error + compAdd.Usage, err = PBToUsage(resource) + if err != nil { + return ComponentsAddition{}, err + } + comps := Components{} + for _, f := range resource.Fragment { + comp := Component{ + Hostname: f.Id, + Status: "alive", + Address: parseAddress(f.Id), + } + comp.Usage, err = PBToUsage(f) + if err != nil { + return ComponentsAddition{}, err + } + comps[f.OwnerId] = append(comps[f.OwnerId], comp) + } + compAdd.Components = comps + var nodes []Component + for k, compUnit := range comps { + node := Component{ + Hostname: k, + Status: "healthy", + Address: parseIP(compUnit[0].Address), + } + node.AllocNPU = compUnit[0].AllocNPU + for _, component := range compUnit { + node.CapCPU += component.CapCPU + node.CapMem += component.CapMem + node.AllocCPU += component.AllocCPU + node.AllocMem += component.AllocMem + } + nodes = append(nodes, node) + } + compAdd.Nodes = nodes + return compAdd, nil +} + +func parseAddress(hostname string) string { + pathArr := strings.Split(hostname, "-") + length := len(pathArr) + if length < 2 { // pathArr at least has two member: ip and port + return "" + } + return pathArr[length-2] + ":" + pathArr[length-1] // pathArr[length-2] is ip, pathArr[length-1] is port +} + +func parseIP(address string) string { + return strings.Split(address, ":")[0] +} diff --git a/yuanrong/pkg/dashboard/handlers/components_handler_test.go b/yuanrong/pkg/dashboard/handlers/components_handler_test.go new file mode 100644 index 0000000..c9d99ad --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/components_handler_test.go @@ -0,0 +1,112 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/agiledragon/gomonkey" + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + "google.golang.org/protobuf/proto" + + "yuanrong/pkg/common/faas_common/grpc/pb/message" + "yuanrong/pkg/common/faas_common/grpc/pb/resource" + "yuanrong/pkg/dashboard/flags" +) + +func TestComponentsHandler(t *testing.T) { + convey.Convey("Test ComponentsHandler:", t, func() { + r := gin.Default() + r.GET("/components", ComponentsHandler) + req, err := http.NewRequest("GET", "/components", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + + convey.Convey("Test Components when function master error", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "2001") + }) + + resources := &resource.Resources{ + Resources: map[string]*resource.Resource{ + "Memory": { + Name: "Memory", + Scalar: &resource.Value_Scalar{Value: 38912}, + }, + "CPU": { + Name: "CPU", + Scalar: &resource.Value_Scalar{Value: 10000}, + }, + }, + } + nodeLabels := map[string]*resource.Value_Counter{ + "NODE_ID": { + Items: map[string]uint64{"dggphis232340-1936114": 1}, + }, + } + fragment1 := &resource.ResourceUnit{ + Id: "function-agent-7.185.104.157-31630", + Capacity: resources, + Allocatable: resources, + NodeLabels: nodeLabels, + OwnerId: "dggphis232340-1936114", + } + resourceInfo := message.ResourceInfo{ + RequestID: "145aa8bc-d616-4000-8000-000000734df7", + Resource: &resource.ResourceUnit{ + Id: "InnerDomainScheduler", + Capacity: resources, + Allocatable: resources, + Fragment: map[string]*resource.ResourceUnit{ + "function-agent-7.185.104.157-31630": fragment1, + }, + NodeLabels: nodeLabels, + Revision: 19, + ViewInitTime: "542b3f0f-0000-4000-8000-0089a640b071", + }, + } + resourcesServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resourceInfoBytes, _ := proto.Marshal(&resourceInfo) + w.WriteHeader(http.StatusOK) + w.Write(resourceInfoBytes) + })) + defer resourcesServer.Close() + flags.DashboardConfig.FunctionMasterAddr = resourcesServer.URL + + convey.Convey("Test Components success", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, `"msg":"succeed"`) + }) + convey.Convey("Test Components when PBToUsage error", func() { + patches := gomonkey.ApplyFunc(PBToUsage, func(resource *resource.ResourceUnit) (Usage, error) { + return Usage{}, errors.New("PBToUsage error") + }) + defer patches.Reset() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "1001") + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/err_code.go b/yuanrong/pkg/dashboard/handlers/err_code.go new file mode 100644 index 0000000..626d894 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/err_code.go @@ -0,0 +1,31 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +const ( + // PB format convert error + errPBToCompAdd = 1001 + errPBToRsrcSummary = 1002 + + // function-master request error + errGetResources = 2001 + errGetInstances = 2002 + + // frontend request error + errFrontend = 3000 +) diff --git a/yuanrong/pkg/dashboard/handlers/instances_handler.go b/yuanrong/pkg/dashboard/handlers/instances_handler.go new file mode 100644 index 0000000..110ebe9 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/instances_handler.go @@ -0,0 +1,146 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/grpc/pb/resource" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/dashboard/getinfo" +) + +// InstanceInfo is instance base info +type InstanceInfo struct { + ID string `form:"id" json:"id"` + Status string `form:"status" json:"status"` + CreateTime string `form:"create_time" json:"create_time"` + JobId string `form:"job_id" json:"job_id"` + PID string `form:"pid" json:"pid"` + IP string `form:"ip" json:"ip"` + NodeID string `form:"node_id" json:"node_id"` + AgentID string `form:"agent_id" json:"agent_id"` + ParentID string `form:"parent_id" json:"parent_id"` + RequiredCPU float64 `form:"required_cpu" json:"required_cpu"` + RequiredMem float64 `form:"required_mem" json:"required_mem"` + RequiredGPU float64 `form:"required_gpu" json:"required_gpu"` + RequiredNPU float64 `form:"required_npu" json:"required_npu"` + Restarted int32 `form:"restarted" json:"restarted"` + ExitDetail string `form:"exit_detail" json:"exit_detail"` +} + +// InstanceInfos is all instance info +type InstanceInfos []InstanceInfo + +const listInstPath = "/instance-manager/queryinstances" + +var instanceStatusMap = map[constant.InstanceStatus]string{ + constant.KernelInstanceStatusExited: "exited", + constant.KernelInstanceStatusNew: "new", + constant.KernelInstanceStatusScheduling: "scheduling", + constant.KernelInstanceStatusCreating: "creating", + constant.KernelInstanceStatusRunning: "running", + constant.KernelInstanceStatusFailed: "failed", + constant.KernelInstanceStatusExiting: "exiting", + constant.KernelInstanceStatusFatal: "fatal", + constant.KernelInstanceStatusScheduleFailed: "schedule_failed", + constant.KernelInstanceStatusEvicting: "evicting", + constant.KernelInstanceStatusEvicted: "evicted", + constant.KernelInstanceStatusSubHealth: "sub_health", +} + +// InstancesHandler function for /instances route +func InstancesHandler(ctx *gin.Context) { + if ctx.Query("parent_id") != "" { + InstancesByParentIDHandler(ctx) + return + } + pbInfos, err := getinfo.GetInstances(listInstPath) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "code": errGetInstances, + "msg": "fail", + "data": "", + }) + log.GetLogger().Errorf("/instances GetInst, %d %v", errGetInstances, err) + return + } + + infos := PBToInstanceInfos(pbInfos) + ctx.JSON(http.StatusOK, gin.H{ + "code": 0, + "msg": "succeed", + "data": infos, + }) + log.GetLogger().Debugf("/instances succeed, infos:%#v", infos) +} + +// PBToInstanceInfos function switch pb struct to InstanceInfos struct +func PBToInstanceInfos(pbInfos []*resource.InstanceInfo) InstanceInfos { + infos := make(InstanceInfos, 0) + for _, pbInfo := range pbInfos { + info := InstanceInfo{ + ID: pbInfo.InstanceID, + Status: getInstanceStatus(pbInfo.InstanceStatus.Code), + JobId: pbInfo.JobID, + IP: pbInfo.RuntimeAddress, + NodeID: pbInfo.FunctionProxyID, + AgentID: pbInfo.FunctionAgentID, + ParentID: pbInfo.ParentID, + ExitDetail: pbInfo.InstanceStatus.Msg, + Restarted: pbInfo.DeployTimes - 1, + } + info.setFields(pbInfo) + infos = append(infos, info) + } + return infos +} + +func getInstanceStatus(code int32) string { + if status, ok := instanceStatusMap[constant.InstanceStatus(code)]; ok { + return status + } + return "" +} + +func (info *InstanceInfo) setFields(pbInfo *resource.InstanceInfo) { + if createTime, ok := pbInfo.Extensions["createTimestamp"]; ok { + info.CreateTime = createTime + } + if pid, ok := pbInfo.Extensions["pid"]; ok { + info.PID = pid + } + if pbInfo.Resources == nil || pbInfo.Resources.Resources == nil { + return + } + resourcesMap := pbInfo.Resources.Resources + info.RequiredCPU = getResourceValue("CPU", resourcesMap) + info.RequiredMem = getResourceValue("Memory", resourcesMap) + info.RequiredGPU = getResourceValue("GPU", resourcesMap) + info.RequiredNPU = getResourceValue("NPU/.+/count", resourcesMap) +} + +func getResourceValue(resourceName string, resourcesMap map[string]*resource.Resource) float64 { + if _, ok := resourcesMap[resourceName]; ok { + return resourcesMap[resourceName].Scalar.Value + } + return 0 +} diff --git a/yuanrong/pkg/dashboard/handlers/instances_handler_test.go b/yuanrong/pkg/dashboard/handlers/instances_handler_test.go new file mode 100644 index 0000000..fab565e --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/instances_handler_test.go @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + "google.golang.org/protobuf/proto" + + "yuanrong/pkg/common/faas_common/grpc/pb/message" + "yuanrong/pkg/common/faas_common/grpc/pb/resource" + "yuanrong/pkg/dashboard/flags" +) + +func TestInstancesHandler(t *testing.T) { + convey.Convey("Test InstancesHandler:", t, func() { + r := gin.Default() + r.GET("/instances", InstancesHandler) + req, err := http.NewRequest("GET", "/instances", nil) + convey.So(err, convey.ShouldBeNil) + w := httptest.NewRecorder() + + convey.Convey("Test InstancesHandler when function master error", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "2002") + }) + + resources := resource.Resources{ + Resources: map[string]*resource.Resource{ + "CPU": { + Name: "CPU", + Scalar: &resource.Value_Scalar{ + Value: 500, + }, + }, + "Memory": { + Name: "Memory", + Scalar: &resource.Value_Scalar{ + Value: 300, + }, + }, + }, + } + + instance := &resource.InstanceInfo{ + InstanceID: "app-dfb5ed67-9342-4b50-82ff-8fa8f055a9f4", + InstanceStatus: &resource.InstanceStatus{Code: 3}, + FunctionAgentID: "function-agent-7.185.104.157-31630", + FunctionProxyID: "phish232340-1936114", + ParentID: "driver-faas-frontend-dggphis232340-1936114", + Resources: &resources, + } + instances := []*resource.InstanceInfo{instance} + + instancesInfo := message.QueryInstancesInfoResponse{ + InstanceInfos: instances, + } + //message.QueryInstancesInfoResponse + instancesServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + instancesInfoBytes, _ := proto.Marshal(&instancesInfo) + w.WriteHeader(http.StatusOK) + w.Write(instancesInfoBytes) + })) + defer instancesServer.Close() + flags.DashboardConfig.FunctionMasterAddr = instancesServer.URL + + convey.Convey("Test Instances success", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, `"msg":"succeed"`) + }) + }) +} + +func TestGetInstanceStatus(t *testing.T) { + convey.Convey("Test getInstanceStatus:", t, func() { + convey.Convey("Test getInstanceStatus when code not in range", func() { + convey.So(getInstanceStatus(-2), convey.ShouldEqual, "") + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/instances_instanceid_handler.go b/yuanrong/pkg/dashboard/handlers/instances_instanceid_handler.go new file mode 100644 index 0000000..cd928a3 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/instances_instanceid_handler.go @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/dashboard/etcdcache" +) + +// InstancesByInstanceIDHandler function for /api/v1/instances/:instance-id route +func InstancesByInstanceIDHandler(ctx *gin.Context) { + instanceId := ctx.Param("instance-id") + instanceSpec := etcdcache.InstanceCache.Get(instanceId) + instance := instanceSpec2Instance(instanceSpec) + ctx.JSON(http.StatusOK, instance) + log.GetLogger().Debugf("/instances/%s succeed, instanceInfo: %#v", instanceId, instance) +} diff --git a/yuanrong/pkg/dashboard/handlers/instances_instanceid_handler_test.go b/yuanrong/pkg/dashboard/handlers/instances_instanceid_handler_test.go new file mode 100644 index 0000000..a5a6711 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/instances_instanceid_handler_test.go @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/dashboard/etcdcache" +) + +func TestInstancesByInstanceIDHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + convey.Convey("Test InstancesByInstanceIDHandler:", t, func() { + patches := gomonkey.ApplyMethodReturn(&etcdcache.InstanceCache, "Get", &types.InstanceSpecification{ + InstanceID: "id-123", + InstanceStatus: types.InstanceStatus{}, + Extensions: types.Extensions{}, + Resources: types.Resources{Resources: map[string]types.Resource{}}, + }) + defer patches.Reset() + + r := gin.Default() + r.GET("/instances/:instance-id", InstancesByInstanceIDHandler) + req, err := http.NewRequest("GET", "/instances/id-123", nil) + convey.So(err, convey.ShouldBeNil) + w := httptest.NewRecorder() + + convey.Convey("Test InstancesByInstanceID success", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "id-123") + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/instances_parentid_handler.go b/yuanrong/pkg/dashboard/handlers/instances_parentid_handler.go new file mode 100644 index 0000000..7bce19b --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/instances_parentid_handler.go @@ -0,0 +1,79 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/dashboard/etcdcache" +) + +// InstancesByParentIDHandler function for /api/v1/instances?parent_id= route +func InstancesByParentIDHandler(ctx *gin.Context) { + parentId := ctx.Query("parent_id") + instancesMap := etcdcache.InstanceCache.GetByParentID(parentId) + instances := make([]*InstanceInfo, 0, len(instancesMap)) + for _, instanceSpec := range instancesMap { + instances = append(instances, instanceSpec2Instance(instanceSpec)) + } + ctx.JSON(http.StatusOK, instances) + log.GetLogger().Debugf("/instances?parent_id=%s succeed, instances:%s", parentId, instances) +} + +func instanceSpec2Instance(instanceSpec *types.InstanceSpecification) *InstanceInfo { + if instanceSpec == nil { + return &InstanceInfo{} + } + instance := &InstanceInfo{ + ID: instanceSpec.InstanceID, + Status: getInstanceStatus(instanceSpec.InstanceStatus.Code), + CreateTime: instanceSpec.Extensions.CreateTimestamp, + JobId: instanceSpec.JobID, + PID: instanceSpec.Extensions.PID, + IP: instanceSpec.RuntimeAddress, + NodeID: instanceSpec.FunctionProxyID, + AgentID: instanceSpec.FunctionAgentID, + ParentID: instanceSpec.ParentID, + RequiredCPU: 0, + RequiredMem: 0, + RequiredGPU: 0, + RequiredNPU: 0, + Restarted: instanceSpec.DeployTimes - 1, + ExitDetail: instanceSpec.InstanceStatus.Msg, + } + instance.setResources(instanceSpec.Resources.Resources) + return instance +} + +func (info *InstanceInfo) setResources(resourcesMap map[string]types.Resource) { + info.RequiredCPU = getTypesResourceValue("CPU", resourcesMap) + info.RequiredMem = getTypesResourceValue("Memory", resourcesMap) + info.RequiredGPU = getTypesResourceValue("GPU", resourcesMap) + info.RequiredNPU = getTypesResourceValue("NPU/.+/count", resourcesMap) +} + +func getTypesResourceValue(resourceName string, resourcesMap map[string]types.Resource) float64 { + if _, ok := resourcesMap[resourceName]; ok { + return resourcesMap[resourceName].Scalar.Value + } + return 0 +} diff --git a/yuanrong/pkg/dashboard/handlers/instances_parentid_handler_test.go b/yuanrong/pkg/dashboard/handlers/instances_parentid_handler_test.go new file mode 100644 index 0000000..448b1b0 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/instances_parentid_handler_test.go @@ -0,0 +1,73 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/dashboard/etcdcache" +) + +func TestInstancesByParentIDHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + convey.Convey("Test InstancesByParentIDHandler:", t, func() { + instance := &types.InstanceSpecification{ + InstanceID: "id-123", + InstanceStatus: types.InstanceStatus{}, + Extensions: types.Extensions{}, + Resources: types.Resources{Resources: map[string]types.Resource{ + "CPU": types.Resource{Scalar: types.ValueScalar{}}, + "Memory": types.Resource{Scalar: types.ValueScalar{}}, + "GPU": types.Resource{Scalar: types.ValueScalar{}}, + "NPU": types.Resource{Scalar: types.ValueScalar{}}, + }}, + } + + r := gin.Default() + r.GET("/api/v1/instances", InstancesHandler) + req, err := http.NewRequest("GET", "/api/v1/instances?parent_id="+instance.InstanceID, nil) + convey.So(err, convey.ShouldBeNil) + w := httptest.NewRecorder() + + convey.Convey("Test InstancesByParentID success", func() { + patches := gomonkey.ApplyMethodReturn(&etcdcache.InstanceCache, "GetByParentID", + map[string]*types.InstanceSpecification{instance.InstanceID: instance}) + defer patches.Reset() + + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, instance.InstanceID) + }) + }) +} + +func TestInstanceSpec2Instance(t *testing.T) { + convey.Convey("Test instanceSpec2Instance:", t, func() { + convey.Convey("Test instanceSpec2Instance success", func() { + instance := instanceSpec2Instance(nil) + convey.So(instance.ID, convey.ShouldBeEmpty) + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/instances_summary_handler.go b/yuanrong/pkg/dashboard/handlers/instances_summary_handler.go new file mode 100644 index 0000000..5dafc0f --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/instances_summary_handler.go @@ -0,0 +1,85 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/grpc/pb/resource" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/dashboard/getinfo" +) + +// InstSummary for overviews page show instances summary +type InstSummary struct { + Total int `form:"total" json:"total"` + Running int `form:"running" json:"running"` + Exited int `form:"exited" json:"exited"` + Fatal int `form:"fatal" json:"fatal"` +} + +const ( + // RUNNING code + RUNNING = 3 + // EXITED code + EXITED = 8 + // FATAL code + FATAL = 6 +) + +// InstancesSummaryHandler function for /instances/summary route +func InstancesSummaryHandler(ctx *gin.Context) { + pbInfos, err := getinfo.GetInstances(listInstPath) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "code": errGetInstances, + "msg": "fail", + "data": "", + }) + log.GetLogger().Errorf("/instances/summary GetInst %d %v", errGetInstances, err) + return + } + + instSumm := PBToInstSummary(pbInfos) + ctx.JSON(http.StatusOK, gin.H{ + "code": 0, + "msg": "succeed", + "data": instSumm, + }) + log.GetLogger().Debugf("/instances/summary succeed, data:%#v", instSumm) +} + +// PBToInstSummary function for switch pb struct to InstanceInfos struct +func PBToInstSummary(infos []*resource.InstanceInfo) InstSummary { + var instSumm InstSummary + instSumm.Total = len(infos) + for _, info := range infos { + switch info.InstanceStatus.Code { + case RUNNING: + instSumm.Running++ + case EXITED: + instSumm.Exited++ + case FATAL: + instSumm.Fatal++ + default: + } + } + return instSumm +} diff --git a/yuanrong/pkg/dashboard/handlers/instances_summary_handler_test.go b/yuanrong/pkg/dashboard/handlers/instances_summary_handler_test.go new file mode 100644 index 0000000..0ed2cbc --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/instances_summary_handler_test.go @@ -0,0 +1,98 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + "google.golang.org/protobuf/proto" + + "yuanrong/pkg/common/faas_common/grpc/pb/message" + "yuanrong/pkg/common/faas_common/grpc/pb/resource" + "yuanrong/pkg/dashboard/flags" +) + +func TestInstancesSummaryHandler(t *testing.T) { + convey.Convey("Test InstancesSummaryHandler:", t, func() { + r := gin.Default() + r.GET("/instances/summary", InstancesSummaryHandler) + req, err := http.NewRequest("GET", "/instances/summary", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + + convey.Convey("Test InstancesSummary when function master error", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "2002") + }) + + resources := resource.Resources{ + Resources: map[string]*resource.Resource{ + "CPU": { + Name: "CPU", + Scalar: &resource.Value_Scalar{ + Value: 500, + }, + }, + "Memory": { + Name: "Memory", + Scalar: &resource.Value_Scalar{ + Value: 300, + }, + }, + }, + } + + instance1 := resource.InstanceInfo{ + InstanceID: "app-dfb5ed67-9342-4b50-82ff-8fa8f055a9f4", + InstanceStatus: &resource.InstanceStatus{Code: 3}, + FunctionAgentID: "function-agent-7.185.104.157-31630", + FunctionProxyID: "phish232340-1936114", + ParentID: "driver-faas-frontend-dggphis232340-1936114", + Resources: &resources, + } + instance2 := instance1 + instance2.InstanceStatus = &resource.InstanceStatus{Code: 8} + instance3 := instance1 + instance3.InstanceStatus = &resource.InstanceStatus{Code: 6} + instances := []*resource.InstanceInfo{&instance1, &instance2, &instance3} + + instancesInfo := message.QueryInstancesInfoResponse{ + InstanceInfos: instances, + } + //message.QueryInstancesInfoResponse + instancesServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + instancesInfoBytes, _ := proto.Marshal(&instancesInfo) + w.WriteHeader(http.StatusOK) + w.Write(instancesInfoBytes) + })) + defer instancesServer.Close() + flags.DashboardConfig.FunctionMasterAddr = instancesServer.URL + + convey.Convey("Test InstancesSummary success", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, `"msg":"succeed"`) + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/job_handler.go b/yuanrong/pkg/dashboard/handlers/job_handler.go new file mode 100644 index 0000000..7eb2be1 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/job_handler.go @@ -0,0 +1,99 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/job" + "yuanrong/pkg/dashboard/getinfo" +) + +// SubmitJobHandler - +func SubmitJobHandler(ctx *gin.Context) { + traceID := ctx.Request.Header.Get(constant.HeaderTraceID) + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + req := job.SubmitJobHandleReq(ctx) + if req == nil { + return + } + if req.SubmissionId == "" { + req.NewSubmissionID() + } else { + resp := getinfo.GetAppInfo(req.SubmissionId) + if resp.Message != "" { + if resp.Code != http.StatusNotFound { + logger.Errorf("failed GetAppInfo, submissionId: %s, code: %d, err: %s", + req.SubmissionId, errFrontend, resp.Message) + ctx.JSON(http.StatusInternalServerError, + fmt.Sprintf("failed GetAppInfo, submissionId: %s, code: %d, err: %s", + req.SubmissionId, errFrontend, resp.Message)) + return + } + } + if resp.Code == http.StatusOK { + logger.Errorf("submit job has already exist, submissionId: %s, code: %d", + req.SubmissionId, errFrontend) + ctx.JSON(http.StatusBadRequest, + fmt.Sprintf("submit job has already exist, submissionId: %s, code: %d", + req.SubmissionId, errFrontend)) + return + } + } + logger.Debugf("start to CreateApp, req:%#v", req) + reqBytes, err := json.Marshal(req) + if err != nil { + logger.Errorf("marshal CreateApp request failed, err:%s", err.Error()) + ctx.JSON(http.StatusBadRequest, fmt.Sprintf("marshal CreateApp request failed, err:%v", err)) + return + } + job.SubmitJobHandleRes(ctx, getinfo.CreateApp(reqBytes)) +} + +// ListJobsHandler - +func ListJobsHandler(ctx *gin.Context) { + job.ListJobsHandleRes(ctx, getinfo.ListApps()) +} + +// GetJobInfoHandler - +func GetJobInfoHandler(ctx *gin.Context) { + submissionId := ctx.Param(job.PathParamSubmissionId) + log.GetLogger().Debugf("start to GetJobInfoHandler, submissionId: %s", submissionId) + job.GetJobInfoHandleRes(ctx, getinfo.GetAppInfo(submissionId)) +} + +// DeleteJobHandler - +func DeleteJobHandler(ctx *gin.Context) { + submissionId := ctx.Param(job.PathParamSubmissionId) + log.GetLogger().Debugf("start to DeleteJobHandler, submissionId: %s", submissionId) + job.DeleteJobHandleRes(ctx, getinfo.DeleteApp(submissionId)) +} + +// StopJobHandler - +func StopJobHandler(ctx *gin.Context) { + submissionId := ctx.Param(job.PathParamSubmissionId) + log.GetLogger().Debugf("start to StopJobHandler, submissionId: %s", submissionId) + job.StopJobHandleRes(ctx, getinfo.StopApp(submissionId)) +} diff --git a/yuanrong/pkg/dashboard/handlers/job_handler_test.go b/yuanrong/pkg/dashboard/handlers/job_handler_test.go new file mode 100644 index 0000000..44bbc2b --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/job_handler_test.go @@ -0,0 +1,363 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/constants" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/job" + "yuanrong/pkg/dashboard/flags" + "yuanrong/pkg/dashboard/getinfo" +) + +type TC struct { + FrontendResStatus int + DashboardResStatus int + FrontendResBody string + DashboardResBody string +} + +func addApps(submissionId string) []*constant.AppInfo { + var result []*constant.AppInfo + return append(result, buildAppInfo(submissionId)) +} + +func buildAppInfo(submissionId string) *constant.AppInfo { + return &constant.AppInfo{ + Key: submissionId, + Type: "SUBMISSION", + SubmissionID: submissionId, + RuntimeEnv: map[string]interface{}{ + "working_dir": "", + "pip": "", + "envVars": "", + }, + DriverInfo: constant.DriverInfo{ + ID: submissionId, + }, + Status: "RUNNING", + } +} + +func TestSubmitJobHandler(t *testing.T) { + convey.Convey("test SubmitJobHandler", t, func() { + submissionId := "app-frontend-job-submit1" + gin.SetMode(gin.TestMode) + rw := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rw) + bodyBytes, _ := json.Marshal(job.SubmitRequest{ + Entrypoint: "sleep 200", + SubmissionId: submissionId, + RuntimeEnv: &job.RuntimeEnv{ + WorkingDir: "file:///usr1/deploy/file.zip", + Pip: []string{"numpy==1.24", "scipy==1.25"}, + EnvVars: map[string]string{ + "SOURCE_REGION": "suzhou_std", + "DEPLOY_REGION": "suzhou_std", + }, + }, + Metadata: map[string]string{ + "autoscenes_ids": "auto_1-test", + "task_type": "task_1", + "ttl": "1250", + }, + EntrypointResources: map[string]float64{ + "NPU": 0, + }, + EntrypointNumCpus: 300, + EntrypointNumGpus: 0, + EntrypointMemory: 0, + }) + reader := bytes.NewBuffer(bodyBytes) + c.Request = &http.Request{ + Method: http.MethodPost, + URL: &url.URL{Path: job.PathGroupJobs}, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + constants.HeaderTenantID: []string{"123456"}, + constants.HeaderPoolLabel: []string{"abc"}, + }, + Body: io.NopCloser(reader), // 使用 io.NopCloser 包装 reader,使其满足 io.ReadCloser 接口 + } + convey.Convey("when job is exist", func() { + defer gomonkey.ApplyFunc(getinfo.GetAppInfo, func(submissionID string) job.Response { + return job.Response{ + Data: nil, + Code: http.StatusOK, + Message: "", + } + }).Reset() + SubmitJobHandler(c) + convey.So(rw.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(rw.Body.String(), convey.ShouldStartWith, "\"submit job has already exist, submissionId") + }) + convey.Convey("when get job failed", func() { + defer gomonkey.ApplyFunc(getinfo.GetAppInfo, func(string) job.Response { + return job.Response{ + Data: nil, + Code: http.StatusInternalServerError, + Message: "get job failed", + } + }).Reset() + SubmitJobHandler(c) + convey.So(rw.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(rw.Body.String(), convey.ShouldStartWith, "\"failed GetAppInfo, submissionId:") + }) + convey.Convey("when not found job", func() { + defer gomonkey.ApplyFunc(getinfo.GetAppInfo, func(string) job.Response { + return job.Response{ + Data: nil, + Code: http.StatusNotFound, + Message: "not found job", + } + }).Reset() + SubmitJobHandler(c) + convey.So(rw.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(rw.Body.String(), convey.ShouldStartWith, "\"request to fronted failed, err:") + }) + + convey.Convey("when create app response status: "+strconv.Itoa(http.StatusOK), func() { + defer gomonkey.ApplyFunc(getinfo.GetAppInfo, func(string) job.Response { + return job.Response{ + Data: nil, + Code: http.StatusNotFound, + Message: "not found job", + } + }).Reset() + defer gomonkey.ApplyFunc(getinfo.CreateApp, func([]byte) job.Response { + return job.Response{ + Code: http.StatusOK, + Message: "", + Data: []byte(`{"submission_id":"app-123"}`), + } + }).Reset() + SubmitJobHandler(c) + convey.So(rw.Code, convey.ShouldEqual, http.StatusOK) + convey.So(rw.Body.String(), convey.ShouldStartWith, `{"submission_id":"app-123"}`) + }) + convey.Convey("when create app response status: "+strconv.Itoa(http.StatusInternalServerError), func() { + defer gomonkey.ApplyFunc(getinfo.GetAppInfo, func(string) job.Response { + return job.Response{ + Data: nil, + Code: http.StatusNotFound, + Message: "not found job", + } + }).Reset() + defer gomonkey.ApplyFunc(getinfo.CreateApp, func([]byte) job.Response { + return job.Response{ + Code: http.StatusInternalServerError, + Message: "failed submit job", + } + }).Reset() + SubmitJobHandler(c) + convey.So(rw.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(rw.Body.String(), convey.ShouldStartWith, "\"failed submit job") + }) + }) +} + +func TestListJobsHandler(t *testing.T) { + convey.Convey("Test ListJobsHandler:", t, func() { + dataBytes, err := json.Marshal(addApps("app-123")) + convey.So(err, convey.ShouldBeNil) + appInfoByte, err := json.Marshal(job.Response{ + Code: http.StatusOK, + Data: dataBytes, + }) + tt := []TC{ + {http.StatusOK, http.StatusOK, string(appInfoByte), string(dataBytes)}, + {http.StatusInternalServerError, http.StatusInternalServerError, `{"code": 500, "message": "failed get job"}`, `"failed get job"`}, + } + for _, tc := range tt { + convey.Convey("when frontend return status: "+strconv.Itoa(tc.FrontendResStatus), func() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.FrontendResStatus) + w.Write([]byte(tc.FrontendResBody)) + })) + defer server.Close() + flags.DashboardConfig.FrontendAddr = server.URL + "/app/v1" + + r := gin.Default() + r.GET("/jobs", ListJobsHandler) + req, err := http.NewRequest("GET", "/jobs", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, tc.DashboardResStatus) + convey.So(string(w.Body.Bytes()), convey.ShouldEqual, tc.DashboardResBody) + }) + } + convey.Convey("when request frontend failed "+strconv.Itoa(http.StatusInternalServerError), func() { + r := gin.Default() + r.GET("/jobs", ListJobsHandler) + req, err := http.NewRequest("GET", "/jobs", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldStartWith, "\"request to fronted failed") + }) + }) +} + +func TestGetJobInfoHandler(t *testing.T) { + convey.Convey("Test GetJobInfoHandler:", t, func() { + dataBytes, err := json.Marshal(buildAppInfo("app-123")) + convey.So(err, convey.ShouldBeNil) + appInfoByte, err := json.Marshal(job.Response{ + Code: http.StatusOK, + Data: dataBytes, + }) + convey.So(err, convey.ShouldBeNil) + tt := []TC{ + {http.StatusOK, http.StatusOK, string(appInfoByte), string(dataBytes)}, + {http.StatusNotFound, http.StatusNotFound, `{"code": 404, "message": "the job does not exist"}`, `"the job does not exist"`}, + {http.StatusInternalServerError, http.StatusInternalServerError, `{"code": 500, "message": "failed get job"}`, `"failed get job"`}, + } + for _, tc := range tt { + convey.Convey("when frontend return status: "+strconv.Itoa(tc.FrontendResStatus), func() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.FrontendResStatus) + w.Write([]byte(tc.FrontendResBody)) + })) + defer server.Close() + flags.DashboardConfig.FrontendAddr = server.URL + "/app/v1" + + r := gin.Default() + r.GET("/jobs/:submission_id", GetJobInfoHandler) + req, err := http.NewRequest("GET", "/jobs/123", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, tc.DashboardResStatus) + convey.So(w.Body.String(), convey.ShouldEqual, tc.DashboardResBody) + }) + } + convey.Convey("when request frontend failed ", func() { + r := gin.Default() + r.GET("/jobs/:submission_id", GetJobInfoHandler) + req, err := http.NewRequest("GET", "/jobs/123", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldStartWith, "\"request to fronted failed") + }) + }) +} + +func TestDeleteJobHandler(t *testing.T) { + convey.Convey("Test DeleteJobHandler:", t, func() { + tt := []TC{ + {http.StatusOK, http.StatusOK, `{"code": 200}`, "true"}, + {http.StatusForbidden, http.StatusOK, `{"code": 403}`, "false"}, + {http.StatusNotFound, http.StatusNotFound, `{"code": 404, "message": "the job does not exist"}`, `"the job does not exist"`}, + {http.StatusInternalServerError, http.StatusInternalServerError, `{"code": 500, "message": "failed delete job"}`, `"failed delete job"`}, + } + for _, tc := range tt { + convey.Convey("when frontend return status: "+strconv.Itoa(tc.FrontendResStatus), func() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.FrontendResStatus) + w.Write([]byte(tc.FrontendResBody)) + })) + defer server.Close() + flags.DashboardConfig.FrontendAddr = server.URL + "/app/v1" + + r := gin.Default() + r.DELETE("/jobs/:submission_id", DeleteJobHandler) + req, err := http.NewRequest("DELETE", "/jobs/123", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, tc.DashboardResStatus) + convey.So(string(w.Body.Bytes()), convey.ShouldEqual, tc.DashboardResBody) + }) + } + convey.Convey("when request frontend failed", func() { + r := gin.Default() + r.DELETE("/jobs/:submission_id", DeleteJobHandler) + req, err := http.NewRequest("DELETE", "/jobs/123", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldStartWith, "\"request to fronted failed") + }) + }) +} + +func TestStopJobHandler(t *testing.T) { + convey.Convey("Test StopJobHandler", t, func() { + tt := []TC{ + {http.StatusOK, http.StatusOK, `{"code": 200}`, "true"}, + {http.StatusForbidden, http.StatusOK, `{"code": 403}`, "false"}, + {http.StatusNotFound, http.StatusNotFound, `{"code": 404, "message": "the job does not exist"}`, `"the job does not exist"`}, + {http.StatusInternalServerError, http.StatusInternalServerError, `{"code": 500, "message": "failed stop job"}`, `"failed stop job"`}, + } + for _, tc := range tt { + convey.Convey("when frontend return status: "+strconv.Itoa(tc.FrontendResStatus), func() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.FrontendResStatus) + w.Write([]byte(tc.FrontendResBody)) + })) + defer server.Close() + flags.DashboardConfig.FrontendAddr = server.URL + "/app/v1" + + r := gin.Default() + r.POST("/jobs/:submission_id/stop", StopJobHandler) + req, err := http.NewRequest("POST", "/jobs/123/stop", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, tc.DashboardResStatus) + convey.So(string(w.Body.Bytes()), convey.ShouldEqual, tc.DashboardResBody) + }) + } + convey.Convey("when request frontend failed", func() { + r := gin.Default() + r.POST("/jobs/:submission_id/stop", StopJobHandler) + req, err := http.NewRequest("POST", "/jobs/123/stop", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(w.Body.String(), convey.ShouldStartWith, "\"request to fronted failed") + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/prometheus_handler.go b/yuanrong/pkg/dashboard/handlers/prometheus_handler.go new file mode 100644 index 0000000..2763496 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/prometheus_handler.go @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "fmt" + "net/http" + "time" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/dashboard/flags" + "yuanrong/pkg/dashboard/getinfo" +) + +// PrometheusHandler - +func PrometheusHandler(ctx *gin.Context) { + // reconnect + if getinfo.PromClient == nil { + getinfo.InitPromClient() + } + + if getinfo.PromClient == nil { + log.GetLogger().Errorf("failed to connect prometheus, prometheus address: %v", flags.DashboardConfig.PrometheusAddr) + ctx.JSON(http.StatusBadRequest, + fmt.Sprintf("failed to connect prometheus, prometheus address: %v", flags.DashboardConfig.PrometheusAddr)) + return + } + + query := ctx.Query("query") + res, warnings, err := getinfo.PromClient.Query(ctx, query, time.Time{}) + if err != nil { + log.GetLogger().Errorf("prometheus query failed, error: %v", err.Error()) + ctx.JSON(http.StatusBadRequest, fmt.Sprintf("prometheus query failed, error: %v", err.Error())) + return + } + + if len(warnings) > 0 { + log.GetLogger().Errorf("prometheus query return warnings: %v", warnings) + ctx.JSON(http.StatusBadRequest, fmt.Sprintf("prometheus query return warnings: %v", warnings)) + return + } + + ctx.JSON(http.StatusOK, res) + log.GetLogger().Debugf("/prometheus/query succeed, infos:%#v", res) +} diff --git a/yuanrong/pkg/dashboard/handlers/prometheus_handler_test.go b/yuanrong/pkg/dashboard/handlers/prometheus_handler_test.go new file mode 100644 index 0000000..4391eaf --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/prometheus_handler_test.go @@ -0,0 +1,51 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/dashboard/getinfo" +) + +type PromTC struct { + ResStatus int + ResBody string +} + +func TestPrometheusHandler(t *testing.T) { + convey.Convey("Test PrometheusHandler", t, func() { + r := gin.Default() + r.GET("/api/v1/prometheus/query", PrometheusHandler) + convey.Convey("when prometheus query failed", func() { + getinfo.PromClient = nil + req, err := http.NewRequest("GET", "/api/v1/prometheus/query?query=up", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "prometheus query failed") + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/resources_handler.go b/yuanrong/pkg/dashboard/handlers/resources_handler.go new file mode 100644 index 0000000..9fc30b9 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/resources_handler.go @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// ResourcesHandler function for /logical-resources route +func ResourcesHandler(ctx *gin.Context) { + ctx.JSON(http.StatusOK, gin.H{ + "code": 0, + "msg": "succeed", + "data": "Resources json", + }) + log.GetLogger().Debugf("/logical-resources succeed") +} diff --git a/yuanrong/pkg/dashboard/handlers/resources_handler_test.go b/yuanrong/pkg/dashboard/handlers/resources_handler_test.go new file mode 100644 index 0000000..643a505 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/resources_handler_test.go @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" +) + +func TestResourcesHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + convey.Convey("Test ResourcesHandler:", t, func() { + r := gin.Default() + r.GET("/logical-resources", ResourcesHandler) + req, err := http.NewRequest("GET", "/logical-resources", nil) + convey.So(err, convey.ShouldBeNil) + w := httptest.NewRecorder() + + convey.Convey("Test Resources success", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "Resources json") + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/resources_summary_handler.go b/yuanrong/pkg/dashboard/handlers/resources_summary_handler.go new file mode 100644 index 0000000..73e9c98 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/resources_summary_handler.go @@ -0,0 +1,129 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "errors" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/grpc/pb/resource" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/dashboard/getinfo" +) + +// Usage for check cpu and mem usage +type Usage struct { + CapCPU float64 `form:"cap_cpu" json:"cap_cpu"` + CapMem float64 `form:"cap_mem" json:"cap_mem"` + AllocCPU float64 `form:"alloc_cpu" json:"alloc_cpu"` + AllocMem float64 `form:"alloc_mem" json:"alloc_mem"` + AllocNPU float64 `form:"alloc_npu" json:"alloc_npu"` +} + +// RsrcSummary for check summary +type RsrcSummary struct { + Usage + ProxyNum int `form:"proxy_num" json:"proxy_num"` +} + +// ResourcesSummaryHandler function for /logical-resources/summary route +func ResourcesSummaryHandler(ctx *gin.Context) { + resource, err := getinfo.GetResources() + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "code": errGetResources, + "msg": "fail", + "data": "", + }) + log.GetLogger().Errorf("/logical-resources/summary GetResources %d %v", errGetResources, err) + return + } + + rsrcSummary, err := PBToRsrcSummary(resource) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "code": errPBToRsrcSummary, + "msg": "fail", + "data": "", + }) + log.GetLogger().Warnf("/logical-resources/summary PBToSummary %d %v", errPBToRsrcSummary, err) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "code": 0, + "msg": "succeed", + "data": rsrcSummary, + }) + log.GetLogger().Debugf("/logical-resources/summary succeed, data:%#v", rsrcSummary) +} + +// PBToUsage function for pb data switch to Usage struct +func PBToUsage(resource *resource.ResourceUnit) (Usage, error) { + var usage Usage + capCPU, ok := resource.Capacity.Resources["CPU"] + if !ok { + return Usage{}, errors.New(`no resource.Capacity.Resources["CPU"]`) + } + if capCPU != nil && capCPU.Scalar != nil { + usage.CapCPU = capCPU.Scalar.Value + } + capMem, ok := resource.Capacity.Resources["Memory"] + if !ok { + return Usage{}, errors.New(`no resource.Capacity.Resources["Memory"]`) + } + if capMem != nil && capMem.Scalar != nil { + usage.CapMem = capMem.Scalar.Value + } + allocResources := resource.Allocatable.Resources + allocCPU, ok := allocResources["CPU"] + if !ok { + return Usage{}, errors.New(`no resource.Allocatable.Resources["CPU"]`) + } + if allocCPU != nil && allocCPU.Scalar != nil { + usage.AllocCPU = allocCPU.Scalar.Value + } + allocMem, ok := allocResources["Memory"] + if !ok { + return Usage{}, errors.New(`no resource.Allocatable.Resources["Memory"]`) + } + if allocMem != nil && allocMem.Scalar != nil { + usage.AllocMem = allocMem.Scalar.Value + } + for resourceName, resourceData := range allocResources { + if strings.Contains(resourceName, "NPU") { + usage.AllocNPU = resourceData.Scalar.Value + } + } + return usage, nil +} + +// PBToRsrcSummary function for pb data switch to Summary struct +func PBToRsrcSummary(resource *resource.ResourceUnit) (RsrcSummary, error) { + var summary RsrcSummary + var err error + summary.Usage, err = PBToUsage(resource) + if err != nil { + return RsrcSummary{}, err + } + summary.ProxyNum = len(resource.NodeLabels["NODE_ID"].Items) + return summary, nil +} diff --git a/yuanrong/pkg/dashboard/handlers/resources_summary_handler_test.go b/yuanrong/pkg/dashboard/handlers/resources_summary_handler_test.go new file mode 100644 index 0000000..b0fc5c9 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/resources_summary_handler_test.go @@ -0,0 +1,161 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/agiledragon/gomonkey" + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + "google.golang.org/protobuf/proto" + + "yuanrong/pkg/common/faas_common/grpc/pb/message" + "yuanrong/pkg/common/faas_common/grpc/pb/resource" + "yuanrong/pkg/dashboard/flags" +) + +func TestResourcesSummaryHandler(t *testing.T) { + convey.Convey("Test ResourcesSummaryHandler:", t, func() { + r := gin.Default() + r.GET("/logical-resources/summary", ResourcesSummaryHandler) + req, err := http.NewRequest("GET", "/logical-resources/summary", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + + convey.Convey("Test ResourcesSummary when function master error", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "2001") + }) + + resources := &resource.Resources{ + Resources: map[string]*resource.Resource{ + "Memory": { + Name: "Memory", + Scalar: &resource.Value_Scalar{Value: 38912}, + }, + "CPU": { + Name: "CPU", + Scalar: &resource.Value_Scalar{Value: 10000}, + }, + }, + } + nodeLabels := map[string]*resource.Value_Counter{ + "NODE_ID": { + Items: map[string]uint64{"dggphis232340-1936114": 1}, + }, + } + fragment1 := &resource.ResourceUnit{ + Id: "function-agent-7.185.104.157-31630", + Capacity: resources, + Allocatable: resources, + NodeLabels: nodeLabels, + OwnerId: "dggphis232340-1936114", + } + resourceInfo := message.ResourceInfo{ + RequestID: "145aa8bc-d616-4000-8000-000000734df7", + Resource: &resource.ResourceUnit{ + Id: "InnerDomainScheduler", + Capacity: resources, + Allocatable: resources, + Fragment: map[string]*resource.ResourceUnit{ + "function-agent-7.185.104.157-31630": fragment1, + }, + NodeLabels: nodeLabels, + Revision: 19, + ViewInitTime: "542b3f0f-0000-4000-8000-0089a640b071", + }, + } + resourcesServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resourceInfoBytes, _ := proto.Marshal(&resourceInfo) + w.WriteHeader(http.StatusOK) + w.Write(resourceInfoBytes) + })) + defer resourcesServer.Close() + flags.DashboardConfig.FunctionMasterAddr = resourcesServer.URL + + convey.Convey("Test ResourcesSummary success", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, `"msg":"succeed"`) + }) + convey.Convey("Test ResourcesSummary when PBToUsage error", func() { + patches := gomonkey.ApplyFunc(PBToUsage, func(resource *resource.ResourceUnit) (Usage, error) { + return Usage{}, errors.New("PBToUsage error") + }) + defer patches.Reset() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "1002") + }) + }) +} + +func TestPBToUsage(t *testing.T) { + convey.Convey("Test PBToUsage:", t, func() { + + resources := &resource.Resources{Resources: map[string]*resource.Resource{}} + resources1 := &resource.Resources{ + Resources: map[string]*resource.Resource{ + "CPU": { + Name: "CPU", + Scalar: &resource.Value_Scalar{Value: 10000}, + }, + }, + } + resources2 := &resource.Resources{ + Resources: map[string]*resource.Resource{ + "Memory": { + Name: "Memory", + Scalar: &resource.Value_Scalar{Value: 38912}, + }, + "CPU": { + Name: "CPU", + Scalar: &resource.Value_Scalar{Value: 10000}, + }, + }, + } + tt := []struct { + Label string + CapRsrc *resource.Resources + AllocRsrc *resource.Resources + Res string + }{ + {"when no cap cpu", resources, resources, `no resource.Capacity.Resources["CPU"]`}, + {"when no cap mem", resources1, resources, `no resource.Capacity.Resources["Memory"]`}, + {"when no alloc cpu", resources2, resources, `no resource.Allocatable.Resources["CPU"]`}, + {"when no alloc mem", resources2, resources1, `no resource.Allocatable.Resources["Memory"]`}, + } + for _, tc := range tt { + convey.Convey("Test PBToUsage "+tc.Label, func() { + resourceUnit := &resource.ResourceUnit{ + Capacity: tc.CapRsrc, + Allocatable: tc.AllocRsrc, + } + usage, err := PBToUsage(resourceUnit) + convey.So(usage.CapCPU, convey.ShouldEqual, 0) + convey.So(err.Error(), convey.ShouldEqual, tc.Res) + }) + } + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/resources_unitid_handler.go b/yuanrong/pkg/dashboard/handlers/resources_unitid_handler.go new file mode 100644 index 0000000..d1b1661 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/resources_unitid_handler.go @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// ResourcesByUnitIDHandler function for /logical-resources/:unit-id route +func ResourcesByUnitIDHandler(ctx *gin.Context) { + unitId := ctx.Param("unit-id") + ctx.JSON(http.StatusOK, gin.H{ + "code": 0, + "msg": "succeed", + "data": "ResourcesUnit by UnitID:" + unitId + " json", + }) + log.GetLogger().Debugf("/logical-resources/%s succeed", unitId) +} diff --git a/yuanrong/pkg/dashboard/handlers/resources_unitid_handler_test.go b/yuanrong/pkg/dashboard/handlers/resources_unitid_handler_test.go new file mode 100644 index 0000000..5891088 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/resources_unitid_handler_test.go @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" +) + +func TestResourcesByUnitIDHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + convey.Convey("Test ResourcesByUnitIDHandler:", t, func() { + r := gin.Default() + r.GET("/logical-resources/:unit-id", ResourcesByUnitIDHandler) + req, err := http.NewRequest("GET", "/logical-resources/id-123", nil) + convey.So(err, convey.ShouldBeNil) + w := httptest.NewRecorder() + + convey.Convey("Test ResourcesByUnitID success", func() { + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(string(w.Body.Bytes()), convey.ShouldContainSubstring, "id-123") + }) + }) +} diff --git a/yuanrong/pkg/dashboard/handlers/serve_handler.go b/yuanrong/pkg/dashboard/handlers/serve_handler.go new file mode 100644 index 0000000..1f69767 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/serve_handler.go @@ -0,0 +1,328 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/gin-gonic/gin" + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/dashboard/models" +) + +const ( + // from python/ray/serve/_private/common.py DeploymentStatus + serveDeploymentStatusHealthy = "HEALTHY" // for normal + serveDeploymentStatusUnhealthy = "UNHEALTHY" // for abnormal cases + serveDeploymentStatusUpScaling = "UPSCALING" // for scaling + + // from python/ray/serve/_private/common.py ApplicationStatus + serveAppStatusDeploying = "DEPLOYING" // for scaling + serveAppStatusRunning = "RUNNING" // for normal + + defaultEtcdRequestTimeout = 30 * time.Second +) + +// ServeApplicationStatus - +type ServeApplicationStatus struct { + Deployments map[string]ServeDeploymentStatus `json:"deployments"` + Name string `json:"name,omitempty"` + Status string `json:"status"` + Message string `json:"message,omitempty"` +} + +// ServeDeploymentStatus - +type ServeDeploymentStatus struct { + Name string `json:"name,omitempty"` + Status string `json:"status,omitempty"` + Message string `json:"message,omitempty"` +} + +func setServeGetErrorRsp(ctx *gin.Context, err error, msgPrefix string) { + if err != nil { + ctx.JSON(http.StatusServiceUnavailable, gin.H{ + "message": msgPrefix + err.Error(), + }) + } +} + +// ServeGetHandler - +// RayService use this to check if the service is ready or not (some level a healthy probe) +// RayService needs ( len(Apps) > 0 and all apps.STATUS == "RUNNING" ) +// Ray returns each app's status by +// ==> python/ray/serve/_private/controller.py:ServeController.get_serve_instance_details +// ==> python/ray/serve/_private/application_state.py:ApplicationStateManager.list_app_statuses +// ==> python/ray/serve/_private/application_state.py:ApplicationStateManager._determine_app_status +// ==> python/ray/serve/_private/deployment_state.py:DeploymentState.check_curr_status +// it is determined by every deployment, it any deployment is either not running or not have enough replicas, the app is +// not running. +// So, in yuanrong, we have some conditions +// 1. get all serve func by fetch /sn/functions, and fill the responses' `Applications` part +// 2. get all serve instances by fetch /sn/instances, and check each instances' status ( RUNNING and others ) +func ServeGetHandler(ctx *gin.Context) { + // get without request options + serveFuncMetaInfos, err := getAllServeFunctions() + if err != nil { + setServeGetErrorRsp(ctx, err, "") + return + } + serveInstances, err := getAllServeRunningInstances() + if err != nil { + setServeGetErrorRsp(ctx, err, "") + return + } + serveApps, err := convertServeDeploymentFaasFunctionsToServeDetails(serveFuncMetaInfos, serveInstances) + if err != nil { + setServeGetErrorRsp(ctx, err, "") + return + } + ctx.JSON(http.StatusOK, serveApps) + return +} + +func getAllServeRunningInstances() ([]*types.InstanceSpecification, error) { + resp, err := etcd3.GetRouterEtcdClient().Get( + etcd3.CreateEtcdCtxInfoWithTimeout(context.Background(), defaultEtcdRequestTimeout), + "/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-serveExecutor", + clientv3.WithPrefix()) + if err != nil { + return nil, fmt.Errorf("invalid serve params: %s", err.Error()) + } + var serveInstances []*types.InstanceSpecification + for _, kv := range resp.Kvs { + instSpec := &types.InstanceSpecification{} + err := json.Unmarshal(kv.Value, instSpec) + if err != nil { + log.GetLogger().Warnf("failed to marshal instance spec: %s", err.Error()) + } + if instSpec.InstanceStatus.Code == int32(constant.KernelInstanceStatusRunning) { + serveInstances = append(serveInstances, instSpec) + } + } + return serveInstances, nil +} + +func getAllServeFunctions() ([]*types.FunctionMetaInfo, error) { + // Step 1. get all serve functions + resp, err := etcd3.GetMetaEtcdClient().Get( + etcd3.CreateEtcdCtxInfoWithTimeout(context.Background(), defaultEtcdRequestTimeout), + "/sn/functions", clientv3.WithPrefix()) + if err != nil { + return nil, fmt.Errorf("invalid serve params: %s", err.Error()) + } + var allServeFunc []*types.FunctionMetaInfo + for _, kv := range resp.Kvs { + functionInfo := types.FunctionMetaInfo{} + err := json.Unmarshal(kv.Value, &functionInfo) + if err != nil { + return nil, fmt.Errorf("invalid serve params: %s", err.Error()) + } + // yes, this is a serve func + if len(functionInfo.ExtendedMetaData.ServeDeploySchema.Applications) > 0 { + allServeFunc = append(allServeFunc, &functionInfo) + } + } + return allServeFunc, nil +} + +func getInstancesOfServeFunc(allRunningInstances []*types.InstanceSpecification, + serveFunc *types.FunctionMetaInfo) []*types.InstanceSpecification { + var matchInstances []*types.InstanceSpecification + var faasKey string + var keyNoteExists bool + for _, inst := range allRunningInstances { + if faasKey, keyNoteExists = inst.CreateOptions["FUNCTION_KEY_NOTE"]; !keyNoteExists { + continue + } + meta := types.ServeFunctionKey{} + _ = meta.FromFaasFunctionKey(faasKey) + if meta.ToFaasFunctionVersionUrn() == serveFunc.FuncMetaData.FunctionVersionURN { + matchInstances = append(matchInstances, inst) + } + } + return matchInstances +} + +func getServeDeploymentStatus(serveInstances []*types.InstanceSpecification, + serveFunc *types.FunctionMetaInfo) (string, string) { + if len(serveFunc.ExtendedMetaData.ServeDeploySchema.Applications) == 0 { + log.GetLogger().Errorf("there is no application in %v", serveFunc.ExtendedMetaData.ServeDeploySchema) + return serveDeploymentStatusUnhealthy, "no application info found in func meta info" + } + if len(serveFunc.ExtendedMetaData.ServeDeploySchema.Applications[0].Deployments) == 0 { + log.GetLogger().Errorf("there is no deployments in %v", serveFunc.ExtendedMetaData.ServeDeploySchema) + return serveDeploymentStatusUnhealthy, "no deployment info found in func meta info" + } + expectedReplicas := serveFunc.ExtendedMetaData.ServeDeploySchema.Applications[0].Deployments[0].NumReplicas + if expectedReplicas == int64(len(serveInstances)) { + return serveDeploymentStatusHealthy, "healthy" + } + return serveDeploymentStatusUpScaling, fmt.Sprintf("now: %d expect: %d", len(serveInstances), expectedReplicas) +} + +// DeploymentStatus: UPDATING, HEALTHY, UNHEALTHY, UPSCALING, DOWNSCALING +// ApplicationStatus: NOT_STARTED, DEPLOYING, DEPLOY_FAILED, RUNNING, UNHEALTHY, DELETING +func convertServeDeploymentFaasFunctionsToServeDetails(allServeFunc []*types.FunctionMetaInfo, + allRunningInstances []*types.InstanceSpecification) (models.ServeDetails, + error) { + serveDetails := models.ServeDetails{ + Applications: make(map[string]*models.ServeApplicationDetails), + } + for _, serveFunc := range allServeFunc { + theCorrespondingServeDeploySchema := serveFunc.ExtendedMetaData.ServeDeploySchema + if len(theCorrespondingServeDeploySchema.Applications) < 1 { + log.GetLogger().Errorf("failed to validate app info from serve deploy schema, contains no app") + continue + } + theCorrespondingServeAppSchema := theCorrespondingServeDeploySchema.Applications[0] + if len(theCorrespondingServeAppSchema.Deployments) < 1 { + log.GetLogger().Errorf("failed to validate app info from serve deploy schema, contains no deployment") + continue + } + theCorrespondingServeDeploymentSchema := theCorrespondingServeAppSchema.Deployments[0] + serveInstances := getInstancesOfServeFunc(allRunningInstances, serveFunc) + + status, statusMsg := getServeDeploymentStatus(serveInstances, serveFunc) + dpDetails := models.ServeDeploymentDetails{ + ServeDeploymentStatus: models.ServeDeploymentStatus{ + Name: theCorrespondingServeDeploymentSchema.Name, + Status: status, + Message: statusMsg, + }, + RoutePrefix: theCorrespondingServeAppSchema.RoutePrefix, + } + + // no app, add one + if _, ok := serveDetails.Applications[theCorrespondingServeAppSchema.Name]; !ok { + serveDetails.Applications[theCorrespondingServeAppSchema.Name] = &models.ServeApplicationDetails{ + // `deployments` seems not really useful, just make it always ready and healthy + Deployments: map[string]models.ServeDeploymentDetails{ + theCorrespondingServeDeploymentSchema.Name: dpDetails, + }, + RoutePrefix: theCorrespondingServeDeploySchema.Applications[0].RoutePrefix, + ServeApplicationStatus: models.ServeApplicationStatus{ + Name: theCorrespondingServeAppSchema.Name, + Status: serveAppStatusRunning, + }, + } + } else { + // there is an existing app, add deployment into that one + serveDetails.Applications[theCorrespondingServeAppSchema.Name]. + Deployments[theCorrespondingServeDeploymentSchema.Name] = dpDetails + } + } + + modifyAppStatusByDeploymentStatus(&serveDetails) + return serveDetails, nil +} + +func modifyAppStatusByDeploymentStatus(serveDetails *models.ServeDetails) { + for i, app := range serveDetails.Applications { + isRunning := true + for _, dp := range app.Deployments { + if dp.Status != serveDeploymentStatusHealthy { + isRunning = false + } + } + if !isRunning { + serveDetails.Applications[i].Status = serveAppStatusDeploying + } + } +} + +// ServeDelHandler - +func ServeDelHandler(ctx *gin.Context) { + // get without request options + ctx.JSON(http.StatusOK, gin.H{ + "message": "ok, but yuanrong doesn't support this right now", + }) +} + +// ServePutHandler function for /serve routes +func ServePutHandler(ctx *gin.Context) { + var serveDeploySchema types.ServeDeploySchema + err := ctx.ShouldBindJSON(&serveDeploySchema) + if err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Invalid serve params: %s", err.Error()), + }) + return + } + + log.GetLogger().Infof("allIncomingDeploySchema: %#v\n", serveDeploySchema) + if err = serveDeploySchema.Validate(); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Invalid serve params: %s", err.Error()), + }) + return + } + allServeFuncMetas := serveDeploySchema.ToFaaSFuncMetas() + err = putServeAsFunctionMetaInfoIntoEtcd(allServeFuncMetas) + if err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to publish serve functions: %s", err.Error()), + }) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "message": "succeed", + }) + return +} + +type instanceConfiguration struct { + InstanceMetaData types.InstanceMetaData `json:"instanceMetaData" valid:",optional"` +} + +// use transaction to avoid partially success +func putServeAsFunctionMetaInfoIntoEtcd(allFuncMetas []*types.ServeFuncWithKeysAndFunctionMetaInfo) error { + txn := etcd3.GetMetaEtcdClient().Client.Txn(context.Background()) + var ops []clientv3.Op + for _, value := range allFuncMetas { + funcMetaValue, err := json.Marshal(value.FuncMetaInfo) + if err != nil { + return err + } + instanceMetaValue, err := json.Marshal(instanceConfiguration{ + InstanceMetaData: value.FuncMetaInfo.InstanceMetaData, + }) + if err != nil { + return err + } + ops = append(ops, clientv3.OpPut(value.FuncMetaKey, string(funcMetaValue))) + ops = append(ops, clientv3.OpPut(value.InstanceMetaKey, string(instanceMetaValue))) + } + commit, err := txn.Then(ops...).Commit() + if err != nil { + return err + } + if !commit.Succeeded { + return fmt.Errorf("failed to put function meta into etcd") + } + return nil +} diff --git a/yuanrong/pkg/dashboard/handlers/serve_handler_test.go b/yuanrong/pkg/dashboard/handlers/serve_handler_test.go new file mode 100644 index 0000000..00e34e0 --- /dev/null +++ b/yuanrong/pkg/dashboard/handlers/serve_handler_test.go @@ -0,0 +1,163 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package handlers for handle request +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/types" +) + +func TestServeDelHandler(t *testing.T) { + convey.Convey("Given a ServeDelHandler", t, func() { + // Setup gin router and the mock context + r := gin.Default() + r.DELETE("/serve", ServeDelHandler) + + req, err := http.NewRequest("DELETE", "/serve", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + convey.Convey("It should return status 200 and appropriate message", func() { + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + + var response map[string]string + err := json.Unmarshal(w.Body.Bytes(), &response) + convey.So(err, convey.ShouldBeNil) + + convey.So(response["message"], convey.ShouldEqual, "ok, but yuanrong doesn't support this right now") + }) + }) +} + +func TestServePutHandler(t *testing.T) { + convey.Convey("Given a ServePutHandler", t, func() { + // Setup gin router and the mock context + r := gin.Default() + r.PUT("/serve", ServePutHandler) + patches := gomonkey.ApplyFunc(putServeAsFunctionMetaInfoIntoEtcd, + func(allFuncMetas []*types.ServeFuncWithKeysAndFunctionMetaInfo) error { + return nil + }) + defer patches.Reset() + // Create a mock request payload + serveDeploySchema := types.ServeDeploySchema{ + Applications: []types.ServeApplicationSchema{ + { + Name: "testApp", + RoutePrefix: "/testRoute", + ImportPath: "testImportPath", + RuntimeEnv: types.ServeRuntimeEnvSchema{ + Pip: []string{"pip1", "pip2"}, + WorkingDir: "/test/dir", + EnvVars: map[string]any{"env1": "value1"}, + }, + Deployments: []types.ServeDeploymentSchema{ + { + Name: "testDeployment", + NumReplicas: 3, + HealthCheckPeriodS: 10, + HealthCheckTimeoutS: 5, + }, + }, + }, + }, + } + payload, err := json.Marshal(serveDeploySchema) + convey.So(err, convey.ShouldBeNil) + + req, err := http.NewRequest("PUT", "/serve", bytes.NewReader(payload)) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + convey.Convey("It should return status 200 and message 'succeed' when no errors occur", func() { + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + + var response map[string]string + err := json.Unmarshal(w.Body.Bytes(), &response) + convey.So(err, convey.ShouldBeNil) + + convey.So(response["message"], convey.ShouldEqual, "succeed") + }) + + convey.Convey("It should return a 400 status and an error message when invalid JSON is provided", func() { + // Test invalid JSON + req, err := http.NewRequest("PUT", "/serve", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + convey.So(w.Code, convey.ShouldEqual, http.StatusBadRequest) + + var response map[string]string + err = json.Unmarshal(w.Body.Bytes(), &response) + convey.So(err, convey.ShouldBeNil) + + convey.So(response["error"], convey.ShouldContainSubstring, "Invalid serve params") + }) + }) +} + +func TestPutServeAsFunctionMetaInfoIntoEtcd(t *testing.T) { + convey.Convey("Given a putServeAsFunctionMetaInfoIntoEtcd function", t, func() { + // Mock the GetMetaEtcdClient and the client methods + patches := gomonkey.ApplyFunc(putServeAsFunctionMetaInfoIntoEtcd, + func(allFuncMetas []*types.ServeFuncWithKeysAndFunctionMetaInfo) error { + return nil + }) + defer patches.Reset() + // Prepare mock data + funcMeta := &types.ServeFuncWithKeysAndFunctionMetaInfo{ + FuncMetaKey: "testKey", + InstanceMetaKey: "testInstanceKey", + FuncMetaInfo: &types.FunctionMetaInfo{ + InstanceMetaData: types.InstanceMetaData{ + MaxInstance: 5, + MinInstance: 5, + }, + }, + } + funcMetas := []*types.ServeFuncWithKeysAndFunctionMetaInfo{funcMeta} + + // Mock the response of Etcd operations + patches.ApplyMethod(reflect.TypeOf(&clientv3.Client{}), "Txn", func(client *clientv3.Client, ctx context.Context) clientv3.Txn { + return client.Txn(ctx) + }) + + // Test the function + convey.Convey("It should successfully put data into etcd", func() { + err := putServeAsFunctionMetaInfoIntoEtcd(funcMetas) + convey.So(err, convey.ShouldBeNil) + }) + }) +} diff --git a/yuanrong/pkg/dashboard/logmanager/collector_client.go b/yuanrong/pkg/dashboard/logmanager/collector_client.go new file mode 100644 index 0000000..2b2d09e --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/collector_client.go @@ -0,0 +1,104 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "context" + "io" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" +) + +type collectorClientInfo struct { + ID string + Address string +} + +type collectorClient struct { + collectorClientInfo + grpcConn *grpc.ClientConn + logClient logservice.LogCollectorServiceClient +} + +// Connect will connect the collector and store the connection +func (c *collectorClient) Connect() error { + // connect it + conn, err := grpc.Dial(c.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return err + } + c.grpcConn = conn + c.logClient = logservice.NewLogCollectorServiceClient(c.grpcConn) + return nil +} + +// Healthcheck should be used in a goroutine +func (c *collectorClient) Healthcheck(shutdownCallback func()) { + if c.grpcConn == nil { + return + } + for { + state := c.grpcConn.GetState() + // TransientFailure means recoverable failure (maybe), but Shutdown means the connection has been closed + if state == connectivity.Shutdown { + log.GetLogger().Errorf("connection to collector %s at %s lost or shutting down...", c.ID, c.Address) + shutdownCallback() + return + } + time.Sleep(time.Second) // Wait before checking again + } +} + +// CollectLog will collect the +func (c *collectorClient) CollectLog(ctx context.Context, readLogReq *logservice.ReadLogRequest, + outLog chan<- *logservice.ReadLogResponse) error { + stream, err := c.logClient.ReadLog(ctx, readLogReq) + if err != nil { + log.GetLogger().Warnf("failed to read log %s from %s, err: %s", readLogReq.Item.Filename, + readLogReq.Item.CollectorID, err.Error()) + close(outLog) + return err + } + return redirectLog(readLogReq, stream, outLog) +} + +func redirectLog(readLogReq *logservice.ReadLogRequest, stream logservice.LogCollectorService_ReadLogClient, + outLog chan<- *logservice.ReadLogResponse) error { + for { + response, err := stream.Recv() + if err == nil { + outLog <- response + continue + } + if err == io.EOF { + log.GetLogger().Infof("read log stream stopped for %s from %s", readLogReq.Item.Filename, + readLogReq.Item.CollectorID) + close(outLog) + return nil + } + log.GetLogger().Warnf("failed to receive log %s from %s, err: %s", readLogReq.Item.Filename, + readLogReq.Item.CollectorID, err.Error()) + close(outLog) + return err + } +} diff --git a/yuanrong/pkg/dashboard/logmanager/collector_client_test.go b/yuanrong/pkg/dashboard/logmanager/collector_client_test.go new file mode 100644 index 0000000..1a96f34 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/collector_client_test.go @@ -0,0 +1,194 @@ +package logmanager + +import ( + "context" + "errors" + "io" + "reflect" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + . "github.com/smartystreets/goconvey/convey" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" +) + +func TestMockDial(t *testing.T) { + Convey("Given a mock dial function", t, func() { + patches := gomonkey.ApplyFunc(grpc.Dial, func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + return nil, errors.New("dial failed") + }) + defer patches.Reset() + + _, err := grpc.Dial("127.0.0.1:50051", grpc.WithTransportCredentials(insecure.NewCredentials())) + So(err, ShouldNotBeNil) + }) +} + +func TestCollectorClient_Connect(t *testing.T) { + Convey("Given a collectorClient instance", t, func() { + client := &collectorClient{ + collectorClientInfo: collectorClientInfo{ + ID: "collector-1", + Address: "127.0.0.1:50051", + }, + } + + Convey("When connecting to the collector successfully", func() { + // 模拟 grpc.Dial 成功 + patches := gomonkey.ApplyFunc(grpc.Dial, func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + return &grpc.ClientConn{}, nil + }) + defer patches.Reset() + + err := client.Connect() + + Convey("Then the connection should be established", func() { + So(err, ShouldBeNil) + So(client.grpcConn, ShouldNotBeNil) + So(client.logClient, ShouldNotBeNil) + }) + }) + + Convey("When connecting to the collector fails", func() { + // 模拟 grpc.Dial 失败 + patches := gomonkey.ApplyFunc(grpc.Dial, func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + return nil, errors.New("dial failed") + }) + defer patches.Reset() + err := client.Connect() + + Convey("Then the connection should fail", func() { + So(err, ShouldNotBeNil) + So(client.grpcConn, ShouldBeNil) + So(client.logClient, ShouldBeNil) + }) + }) + }) +} + +func TestCollectorClient_Healthcheck(t *testing.T) { + Convey("Given a collectorClient instance", t, func() { + client := &collectorClient{ + collectorClientInfo: collectorClientInfo{ + ID: "collector-1", + Address: "127.0.0.1:50051", + }, + grpcConn: &grpc.ClientConn{}, + } + + Convey("When the connection state is Shutdown", func() { + // 模拟 grpcConn.GetState 返回 Shutdown + patches := gomonkey.ApplyMethodFunc(reflect.TypeOf(&grpc.ClientConn{}), "GetState", + func() connectivity.State { + log.GetLogger().Infof("calling get state...") + return connectivity.Shutdown + }) + defer patches.Reset() + + // 用于记录 shutdownCallback 是否被调用 + shutdownCalled := make(chan string, 1) + shutdownCallback := func() { + shutdownCalled <- "closed" + } + + // 启动 Healthcheck + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + client.Healthcheck(shutdownCallback) + }() + // 等待 Healthcheck 检测到 Shutdown + select { + case req := <-shutdownCalled: + So(req, ShouldEqual, "closed") + case <-time.After(15 * time.Second): + t.Fatalf("not receive any start log stream request in 15 seconds") + } + + // 停止 Healthcheck + wg.Wait() + }) + }) +} + +func TestCollectorClient_CollectLog(t *testing.T) { + Convey("Given a collectorClient instance", t, func() { + client := &collectorClient{ + collectorClientInfo: collectorClientInfo{ + ID: "collector-1", + Address: "127.0.0.1:50051", + }, + grpcConn: &grpc.ClientConn{}, + logClient: &mockLogCollectorClient{}, + } + + Convey("When collecting logs successfully", func() { + // 模拟 logClient.ReadLog 成功 + patches := gomonkey.ApplyMethodFunc(&mockLogCollectorClient{}, "ReadLog", func(ctx context.Context, req *logservice.ReadLogRequest, opts ...grpc.CallOption) (logservice.LogCollectorService_ReadLogClient, error) { + return &fakeReadLogClient{}, nil + }) + defer patches.Reset() + + outLog := make(chan *logservice.ReadLogResponse, 1) + err := client.CollectLog(context.Background(), &logservice.ReadLogRequest{ + Item: &logservice.LogItem{ + Filename: "target-1", + CollectorID: "collector-1", + Target: logservice.LogTarget_USER_STD, + RuntimeID: "runtime-1", + }, + StartLine: 0, + EndLine: 0, + }, outLog) + + Convey("Then the log collection should start successfully", func() { + So(err, ShouldBeNil) + }) + }) + + Convey("When collecting logs fails", func() { + // 模拟 logClient.ReadLog 失败 + patches := gomonkey.ApplyMethodFunc(&mockLogCollectorClient{}, "ReadLog", func(ctx context.Context, req *logservice.ReadLogRequest, opts ...grpc.CallOption) (logservice.LogCollectorService_ReadLogClient, error) { + return nil, errors.New("read log failed") + }) + defer patches.Reset() + + outLog := make(chan *logservice.ReadLogResponse, 1) + err := client.CollectLog(context.Background(), &logservice.ReadLogRequest{ + Item: &logservice.LogItem{ + Filename: "target-1", + CollectorID: "collector-1", + Target: logservice.LogTarget_USER_STD, + RuntimeID: "runtime-1", + }, + StartLine: 0, + EndLine: 0, + }, outLog) + + Convey("Then the log collection should fail", func() { + So(err, ShouldNotBeNil) + }) + }) + }) +} + +// // 用于模拟 logservice.LogCollectorService_ReadLogClient +type fakeReadLogClient struct { + grpc.ClientStream +} + +func (c *fakeReadLogClient) Recv() (*logservice.ReadLogResponse, error) { + return &logservice.ReadLogResponse{Content: []byte("log content")}, io.EOF +} + +func (c *fakeReadLogClient) CloseSend() error { + return nil +} diff --git a/yuanrong/pkg/dashboard/logmanager/http_handlers.go b/yuanrong/pkg/dashboard/logmanager/http_handlers.go new file mode 100644 index 0000000..65148a4 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/http_handlers.go @@ -0,0 +1,188 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "fmt" + "net/http" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/dashboard/models" +) + +type listLogQuery struct { + SubmissionID string `form:"submission_id" json:"submission_id"` + InstanceID string `form:"instance_id" json:"instance_id"` +} + +type readLogQuery struct { + Filename string `form:"filename" json:"filename" validate:"required" example:"runtime-01234.log"` + StartLine uint32 `form:"start_line" json:"start_line" example:"0"` + EndLine uint32 `form:"end_line" json:"end_line" example:"5000"` +} + +type listLogsResponse = map[string][]string + +// ListLogsHandler function for get /logs, will list all logs by filename +// +// @Summary list all log files (INTERNAL ONLY) +// @Description return a map use collectorID (like nodeID) as key, filename(s) as value +// @Produce json +// @Tags logs,internal +// @Param params query listLogQuery false "params" +// @Success 200 {object} models.DashboardCommonResponse{data=listLogsResponse} "success" +// @Failure 400 {object} models.DashboardErrorResponse "invalid parameter" +// @Failure 500 {object} models.DashboardErrorResponse "some internal error, need to check log" +// @Router /api/logs/list [get] +func ListLogsHandler(ctx *gin.Context) { + if managerSingleton == nil { + log.GetLogger().Errorf("manager is nullptr!!") + ctx.JSON(http.StatusInternalServerError, gin.H{ + "message": "unexpected error", + }) + return + } + query := listLogQuery{} + err := ctx.BindQuery(&query) + if err != nil { + log.GetLogger().Errorf("bind query failed: %s", err.Error()) + ctx.JSON(http.StatusBadRequest, models.DashboardCommonResponse{ + Message: fmt.Sprintf("query binding failed: %s", err.Error()), + Data: nil, + }) + return + } + logEntries := managerSingleton.LogDB.Query(logDBQuery{ + SubmissionID: query.SubmissionID, + InstanceID: query.InstanceID, + }) + res := listLogsResponse{} + logEntries.Range(func(e *LogEntry) { + res[e.CollectorID] = append(res[e.CollectorID], e.Filename) + }) + ctx.JSON(http.StatusOK, models.SuccessDashboardCommonResponse(res)) +} + +// ReadLogHandler function for read logs +// +// @Summary read log file content (INTERNAL ONLY) +// @Description return log file (use a stream) +// @Produce octet-stream +// @Tags logs,internal +// @Param request query readLogQuery false "query params" +// @Success 200 {object} string "success" +// @Failure 400 {object} models.DashboardErrorResponse "invalid parameter" +// @Failure 500 {object} models.DashboardErrorResponse "some internal error, need to check log" +// @Router /api/logs [get] +func ReadLogHandler(ctx *gin.Context) { + log.GetLogger().Infof("receive read log request %#v", ctx.Request) + if managerSingleton == nil { + log.GetLogger().Errorf("!! manager is nil !! this should never happened !!") + ctx.JSON(http.StatusInternalServerError, models.DashboardCommonResponse{ + Message: "unexpected error", + Data: nil, + }) + return + } + query := readLogQuery{} + err := ctx.BindQuery(&query) + if err != nil { + log.GetLogger().Errorf("bind query failed: %s", err.Error()) + ctx.JSON(http.StatusInternalServerError, models.DashboardCommonResponse{ + Message: fmt.Sprintf("query binding failed: %s", err.Error()), + Data: nil, + }) + return + } + log.GetLogger().Infof("readlog handler get log query: %#v", query.Filename) + log.GetLogger().Infof("DB right now: %#v", managerSingleton.GetLogEntries()) + logEntries := managerSingleton.LogDB.Query(logDBQuery{ + Filename: query.Filename, + }) + log.GetLogger().Infof("found Entries: %#v", logEntries) + if logEntries.Len() != 1 { + msg := fmt.Sprintf("can not find log file or duplicate log files found, get %d files", logEntries.Len()) + log.GetLogger().Warn(msg) + ctx.JSON(http.StatusInternalServerError, models.DashboardCommonResponse{ + Message: msg, + Data: nil, + }) + return + } + + outLog := make(chan *logservice.ReadLogResponse, 100) + entry := logEntries.FindFirst() + + wait := make(chan struct{}) + go streamResponse(ctx, outLog, wait) + publishCollectRequest(ctx, entry, outLog, query) + <-wait +} + +func publishCollectRequest(ctx *gin.Context, entry *LogEntry, outLog chan<- *logservice.ReadLogResponse, + query readLogQuery) { + req := &logservice.ReadLogRequest{ + Item: entry.LogItem, + StartLine: query.StartLine, + EndLine: query.EndLine, + } + c := managerSingleton.GetCollector(entry.CollectorID) + if c == nil { + msg := fmt.Sprintf("can not find collector %s for file %s", entry.CollectorID, entry.Filename) + log.GetLogger().Warn(msg) + ctx.JSON(http.StatusInternalServerError, models.DashboardCommonResponse{ + Message: msg, + Data: nil, + }) + close(outLog) + return + } + err := c.CollectLog(ctx, req, outLog) + if err != nil { + log.GetLogger().Warnf("failed to collect log, err: %s", err.Error()) + ctx.JSON(http.StatusInternalServerError, models.DashboardCommonResponse{ + Message: err.Error(), + Data: nil, + }) + return + } +} + +func streamResponse(ctx *gin.Context, outLog <-chan *logservice.ReadLogResponse, wait chan struct{}) { + defer close(wait) + for { + select { + case resp, ok := <-outLog: + if !ok { + log.GetLogger().Info("end to read log.") + return + } + if resp.Code != 0 { + log.GetLogger().Warnf("failed to read log, errCode: %d, errMsg: %s", resp.Code, resp.Message) + return + } + _, err := ctx.Writer.Write(resp.Content) + if err != nil { + log.GetLogger().Warnf("failed to write resp, err: %s", err.Error()) + return + } + } + } +} diff --git a/yuanrong/pkg/dashboard/logmanager/http_handlers_test.go b/yuanrong/pkg/dashboard/logmanager/http_handlers_test.go new file mode 100644 index 0000000..6756a37 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/http_handlers_test.go @@ -0,0 +1,254 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/dashboard/etcdcache" +) + +var logManagerAddress = "127.0.0.1:55555" +var logCollectAddress = "127.0.0.1:55556" +var logManagerServer *grpc.Server +var logCollectorServer *grpc.Server + +// MockLogCollectorServiceServer - +type MockLogCollectorServiceServer struct { + logservice.UnimplementedLogCollectorServiceServer +} + +func (MockLogCollectorServiceServer) ReadLog(req *logservice.ReadLogRequest, s logservice.LogCollectorService_ReadLogServer) error { + if req.Item.Filename == "123" { + resp := &logservice.ReadLogResponse{Content: []byte("123")} + err := s.Send(resp) + if err != nil { + return err + } + return nil + } + return status.Error(codes.Unknown, "file not found") +} + +func startLogManager(t *testing.T) { + lis, err := net.Listen("tcp", logManagerAddress) + if err != nil { + fmt.Printf("%s\n", err.Error()) + } + assert.Nil(t, err) + logManagerServer = grpc.NewServer() + logservice.RegisterLogManagerServiceServer(logManagerServer, &Server{}) + + err = logManagerServer.Serve(lis) + if err != nil { + fmt.Printf("%s\n", err.Error()) + } + assert.Nil(t, err) +} + +func startLogCollector(t *testing.T) { + lis, err := net.Listen("tcp", logCollectAddress) + if err != nil { + fmt.Printf("%s\n", err.Error()) + } + assert.Nil(t, err) + logCollectorServer = grpc.NewServer() + logservice.RegisterLogCollectorServiceServer(logCollectorServer, &MockLogCollectorServiceServer{}) + + err = logCollectorServer.Serve(lis) + if err != nil { + fmt.Printf("%s\n", err.Error()) + } + assert.Nil(t, err) +} + +func register(t *testing.T, collectorID string) { + conn, err := grpc.Dial(logManagerAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + fmt.Printf("%s\n", err.Error()) + } + assert.Nil(t, err) + defer func(conn *grpc.ClientConn) { + _ = conn.Close() + }(conn) + logClient := logservice.NewLogManagerServiceClient(conn) + req := logservice.RegisterRequest{CollectorID: collectorID, Address: logCollectAddress} + _, err = logClient.Register(context.Background(), &req) + if err != nil { + fmt.Printf("%s\n", err.Error()) + } + assert.Nil(t, err) +} + +func reportLog(t *testing.T, fileName, runtimeID, collectorID string) *logservice.ReportLogResponse { + conn, err := grpc.Dial(logManagerAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + fmt.Printf("%s\n", err.Error()) + } + assert.Nil(t, err) + defer func(conn *grpc.ClientConn) { + _ = conn.Close() + }(conn) + logClient := logservice.NewLogManagerServiceClient(conn) + logItem := &logservice.LogItem{ + Filename: fileName, + CollectorID: collectorID, + RuntimeID: runtimeID, + Target: logservice.LogTarget_USER_STD, + } + req := logservice.ReportLogRequest{ + Items: make([]*logservice.LogItem, 0), + } + req.Items = append(req.Items, logItem) + resp, err := logClient.ReportLog(context.Background(), &req) + if err != nil { + fmt.Printf("%s\n", err.Error()) + } + assert.Nil(t, err) + return resp +} + +func closeServer() { + if logManagerServer != nil { + logManagerServer.Stop() + } + if logCollectorServer != nil { + logCollectorServer.Stop() + } + +} + +func TestReadLogHandler(t *testing.T) { + + convey.Convey("ReadLog succeed", t, func() { + go startLogManager(t) + go startLogCollector(t) + defer closeServer() + time.Sleep(100 * time.Millisecond) + register(t, "123") + etcdcache.InstanceCache.Put(&types.InstanceSpecification{ + InstanceID: "123", + JobID: "123", + RuntimeID: "123"}) + defer etcdcache.InstanceCache.Remove("123") + reportLog(t, "123", "123", "123") + defer managerSingleton.LogDB.Remove(&LogEntry{LogItem: &logservice.LogItem{ + Filename: "123", + RuntimeID: "123", + CollectorID: "123", + }}) + + rw := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rw) + ctx.Request, _ = http.NewRequest("GET", "/?filename=123", nil) + ReadLogHandler(ctx) + convey.So(rw.Body.String(), convey.ShouldEqual, "123") + }) + + convey.Convey("ReadLog failed with duplicate log file", t, func() { + go startLogManager(t) + go startLogCollector(t) + defer closeServer() + time.Sleep(100 * time.Millisecond) + register(t, "123") + + reportLog(t, "123", "123", "123") + defer managerSingleton.LogDB.Remove(&LogEntry{LogItem: &logservice.LogItem{ + Filename: "123", + RuntimeID: "123", + CollectorID: "123", + }}) + + reportLog(t, "123", "234", "123") + defer managerSingleton.LogDB.Remove(&LogEntry{LogItem: &logservice.LogItem{ + Filename: "123", + RuntimeID: "234", + CollectorID: "123", + }}) + + rw := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rw) + ctx.Request, _ = http.NewRequest("GET", "/?filename=123", nil) + ReadLogHandler(ctx) + convey.So(rw.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(rw.Body.String(), convey.ShouldContainSubstring, "duplicate") + }) + + convey.Convey("ReadLog failed with no collector", t, func() { + go startLogManager(t) + defer closeServer() + time.Sleep(100 * time.Millisecond) + reportLog(t, "123", "123", "234") + defer managerSingleton.LogDB.Remove(&LogEntry{LogItem: &logservice.LogItem{ + Filename: "123", + RuntimeID: "123", + CollectorID: "234", + }}) + rw := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rw) + ctx.Request, _ = http.NewRequest("GET", "/?filename=123", nil) + ReadLogHandler(ctx) + convey.So(rw.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(rw.Body.String(), convey.ShouldContainSubstring, "can not find collector") + }) + + convey.Convey("ReadLog failed with connect collector failed", t, func() { + go startLogManager(t) + defer closeServer() + time.Sleep(100 * time.Millisecond) + register(t, "123") + reportLog(t, "123", "123", "123") + defer managerSingleton.LogDB.Remove(&LogEntry{LogItem: &logservice.LogItem{ + Filename: "123", + RuntimeID: "123", + CollectorID: "123", + }}) + rw := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rw) + ctx.Request, _ = http.NewRequest("GET", "/?filename=123", nil) + ReadLogHandler(ctx) + convey.So(rw.Code, convey.ShouldEqual, http.StatusInternalServerError) + convey.So(rw.Body.String(), convey.ShouldContainSubstring, "connection error") + }) + +} + +func TestListLogsHandler(t *testing.T) { + convey.Convey("ReadLog succeed with empty logs", t, func() { + rw := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rw) + ctx.Request, _ = http.NewRequest("GET", "/?instance_id=123", nil) + ListLogsHandler(ctx) + convey.So(rw.Code, convey.ShouldEqual, http.StatusOK) + }) +} diff --git a/yuanrong/pkg/dashboard/logmanager/log_db.go b/yuanrong/pkg/dashboard/logmanager/log_db.go new file mode 100644 index 0000000..aef182a --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/log_db.go @@ -0,0 +1,135 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package logmanager - this file is about log db +package logmanager + +import ( + "encoding/json" + "sync" +) + +type logDBQuery struct { + Filename string `form:"filename" json:"filename"` + SubmissionID string `form:"submission_id" json:"submission_id"` + InstanceID string `form:"instance_id" json:"instance_id"` + RuntimeID string + CollectorID string +} + +type putOptionIfExists int32 + +const ( + putOptionIfExistsReplace putOptionIfExists = iota + putOptionIfExistsNoop +) + +// LogDB interface +type LogDB interface { + Put(item *LogEntry, replaceIfExists putOptionIfExists) + Remove(item *LogEntry) + GetLogEntries() *LogEntries + Serialize() ([]byte, error) + Deserialize([]byte) error + + // Query by union mode by default, will join all Entries satisfy at least one condition, treat empty query condition + // as NO SELECTION : i.e. return all results + Query(query logDBQuery) *LogEntries +} + +type generalLogDBImpl struct { + Entries *LogEntries `json:"items"` + allIndex []LogIndex + + mtx sync.RWMutex // To protect Entries +} + +func newGeneralLogDBImpl() *generalLogDBImpl { + return &generalLogDBImpl{ + Entries: NewLogEntries(), + allIndex: []LogIndex{ + NewLogIndex(func(item *LogEntry) string { return item.InstanceID }, + func(query logDBQuery) string { return query.InstanceID }), + NewLogIndex(func(item *LogEntry) string { return item.Filename }, + func(query logDBQuery) string { return query.Filename }), + NewLogIndex(func(item *LogEntry) string { return item.JobID }, + func(query logDBQuery) string { return query.SubmissionID }), + NewLogIndex(func(item *LogEntry) string { return item.RuntimeID }, + func(query logDBQuery) string { return query.RuntimeID }), + }, + } +} + +// Put an item into the db, should contain corresponding jobs/instance-id +func (g *generalLogDBImpl) Put(item *LogEntry, replaceIfExists putOptionIfExists) { + if item == nil { + return + } + g.mtx.Lock() + defer g.mtx.Unlock() + g.Entries.Put(item, replaceIfExists) + for k, _ := range g.allIndex { + g.allIndex[k].Put(item) + } +} + +// Remove an entry +func (g *generalLogDBImpl) Remove(item *LogEntry) { + g.mtx.Lock() + g.Entries.Delete(item.ID()) + g.mtx.Unlock() + + for k, _ := range g.allIndex { + g.allIndex[k].Remove(item) + } +} + +// GetLogEntries - +func (g *generalLogDBImpl) GetLogEntries() *LogEntries { + return g.Entries +} + +// Serialize to byte +func (g *generalLogDBImpl) Serialize() ([]byte, error) { + g.mtx.RLock() + defer g.mtx.RUnlock() + return json.Marshal(g.Entries) +} + +// Deserialize from binary +func (g *generalLogDBImpl) Deserialize(data []byte) error { + entries := NewLogEntries() + if err := json.Unmarshal(data, &entries); err != nil { + return err + } + // re-put to reconstruct the index + entries.Range(func(entry *LogEntry) { + g.Put(entry, putOptionIfExistsReplace) + }) + return nil +} + +// Query make sure no nil will be returns, only empty possible +func (g *generalLogDBImpl) Query(query logDBQuery) *LogEntries { + allMatches := []*LogEntries{g.Entries} + for k, _ := range g.allIndex { + if queryString := g.allIndex[k].indexKeyByQuery(query); queryString != "" { + allMatches = append(allMatches, g.allIndex[k].Query(queryString)) + } + } + allSatisfiedEntries := logEntriesIntersection(allMatches...) + return allSatisfiedEntries +} diff --git a/yuanrong/pkg/dashboard/logmanager/log_db_test.go b/yuanrong/pkg/dashboard/logmanager/log_db_test.go new file mode 100644 index 0000000..129a0e0 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/log_db_test.go @@ -0,0 +1,163 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "sync" + "testing" + + . "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" +) + +func TestGeneralLogDBImpl(t *testing.T) { + Convey("Given a generalLogDBImpl instance", t, func() { + db := newGeneralLogDBImpl() + + // 测试数据 + logItem := &logservice.LogItem{ + Filename: "test.log", + CollectorID: "collector-1", + Target: logservice.LogTarget_USER_STD, + RuntimeID: "runtime-1", + } + logEntry := &LogEntry{ + LogItem: logItem, + JobID: "job-1", + InstanceID: "instance-1", + } + + Convey("When putting a LogEntry into the DB", func() { + db.Put(logEntry, putOptionIfExistsReplace) + Convey("Then the LogEntry should be added to the Entries and indexes", func() { + + So(db.Entries.Get(logEntry.ID()).Equal(logEntry), ShouldBeTrue) + So(db.Query(logDBQuery{InstanceID: "instance-1"}).Get(logEntry.ID()), ShouldEqual, logEntry) + So(db.Query(logDBQuery{SubmissionID: "job-1"}).Get(logEntry.ID()), ShouldEqual, logEntry) + So(db.Query(logDBQuery{Filename: "test.log"}).Get(logEntry.ID()), ShouldEqual, logEntry) + }) + }) + + Convey("When removing a LogEntry from the DB", func() { + db.Put(logEntry, putOptionIfExistsReplace) // 先添加 + db.Remove(logEntry) + + Convey("Then the LogEntry should be removed from the Entries and indexes", func() { + So(db.Entries.Get(logEntry.ID()), ShouldBeNil) + So(db.Query(logDBQuery{InstanceID: "instance-1"}).Get(logEntry.ID()), ShouldNotEqual, logEntry) + So(db.Query(logDBQuery{SubmissionID: "job-1"}).Get(logEntry.ID()), ShouldNotEqual, logEntry) + So(db.Query(logDBQuery{Filename: "test.log"}).Get(logEntry.ID()), ShouldNotEqual, logEntry) + }) + }) + + Convey("When getting all LogEntries", func() { + db.Put(logEntry, putOptionIfExistsReplace) + entries := db.GetLogEntries() + + Convey("Then the returned Entries should match the DB's Entries", func() { + So(entries.Get(logEntry.ID()), ShouldEqual, logEntry) + }) + }) + + Convey("When serializing && deserializing the DB", func() { + db.Put(logEntry, putOptionIfExistsReplace) + log.GetLogger().Infof("Entries: %#v", db.Entries) + + data, err := db.Serialize() + log.GetLogger().Infof("Entries: %s", data) + So(err, ShouldBeNil) + + newDB := newGeneralLogDBImpl() + err = newDB.Deserialize(data) + So(err, ShouldBeNil) + + So(newDB.Query(logDBQuery{InstanceID: "instance-1"}).Get(logEntry.ID()).Equal(logEntry), ShouldBeTrue) + So(newDB.Query(logDBQuery{SubmissionID: "job-1"}).Get(logEntry.ID()).Equal(logEntry), ShouldBeTrue) + So(newDB.Query(logDBQuery{Filename: "test.log"}).Get(logEntry.ID()).Equal(logEntry), ShouldBeTrue) + }) + + Convey("When querying the DB with a valid query", func() { + db.Put(logEntry, putOptionIfExistsReplace) + query := logDBQuery{ + Filename: "test.log", + InstanceID: "instance-1", + SubmissionID: "job-1", + } + result := db.Query(query) + Convey("Then the result should contain the matching LogEntry", func() { + So(result.Get(logEntry.ID()).Equal(logEntry), ShouldBeTrue) + }) + }) + + Convey("When querying the DB with an invalid query", func() { + db.Put(logEntry, putOptionIfExistsReplace) + query := logDBQuery{ + Filename: "non-existent.log", + InstanceID: "non-existent-instance", + SubmissionID: "non-existent-job", + } + result := db.Query(query) + + log.GetLogger().Infof("result: %v", result) + Convey("Then the result should be an empty LogEntries map", func() { + So(result.Len(), ShouldEqual, 0) + }) + }) + + Convey("When putting a LogEntry concurrently", func() { + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + db.Put(logEntry, putOptionIfExistsReplace) + }() + } + wg.Wait() + + Convey("Then the LogEntry should be added correctly without race conditions", func() { + So(db.Entries.Get(logEntry.ID()).Equal(logEntry), ShouldBeTrue) + + So(db.Query(logDBQuery{InstanceID: "instance-1"}).Get(logEntry.ID()).Equal(logEntry), ShouldBeTrue) + So(db.Query(logDBQuery{SubmissionID: "job-1"}).Get(logEntry.ID()).Equal(logEntry), ShouldBeTrue) + So(db.Query(logDBQuery{Filename: "test.log"}).Get(logEntry.ID()).Equal(logEntry), ShouldBeTrue) + }) + }) + + Convey("When removing a LogEntry concurrently", func() { + db.Put(logEntry, putOptionIfExistsReplace) // 先添加 + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + db.Remove(logEntry) + }() + } + wg.Wait() + + Convey("Then the LogEntry should be removed correctly without race conditions", func() { + So(db.Entries.Get(logEntry.ID()), ShouldBeNil) + So(db.Query(logDBQuery{InstanceID: "instance-1"}).Get(logEntry.ID()), ShouldNotEqual, logEntry) + So(db.Query(logDBQuery{SubmissionID: "job-1"}).Get(logEntry.ID()), ShouldNotEqual, logEntry) + So(db.Query(logDBQuery{Filename: "test.log"}).Get(logEntry.ID()), ShouldNotEqual, logEntry) + }) + }) + }) +} diff --git a/yuanrong/pkg/dashboard/logmanager/log_entry.go b/yuanrong/pkg/dashboard/logmanager/log_entry.go new file mode 100644 index 0000000..a5c5dc8 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/log_entry.go @@ -0,0 +1,166 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "fmt" + "sync" + + "google.golang.org/protobuf/proto" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" +) + +// LogEntry represents a log file, of one log exporter +type LogEntry struct { + // reported by collectorClient + *logservice.LogItem + + // matched by manager + JobID string + InstanceID string +} + +// Equal - +func (li *LogEntry) Equal(that *LogEntry) bool { + if li == nil || that == nil { + return li == that + } + return proto.Equal(li.LogItem, that.LogItem) && li.JobID == that.JobID && li.InstanceID == li.InstanceID +} + +// String - +func (li *LogEntry) String() string { + return li.LogItem.String() + fmt.Sprintf(` instanceID:"%s" jobID:"%s"`, li.InstanceID, li.JobID) +} + +// NewLogEntry returns a log item pointer +func NewLogEntry(item *logservice.LogItem) *LogEntry { + if item == nil { + return nil + } + return &LogEntry{ + LogItem: item, + } +} + +// ID returns the key +func (li *LogEntry) ID() string { + return fmt.Sprintf("%s//%s//%s", li.Filename, li.CollectorID, li.RuntimeID) +} + +// LogEntries is a key-value map, key is item.ID(). This structure is to avoid complexity when delete some item from +// memory storage. +type LogEntries struct { + Entries map[string]*LogEntry `json:"Entries"` + mtx sync.RWMutex +} + +// NewLogEntries - +func NewLogEntries() *LogEntries { + return &LogEntries{Entries: map[string]*LogEntry{}} +} + +// FindFirst entry +func (l *LogEntries) FindFirst() *LogEntry { + l.mtx.Lock() + defer l.mtx.Unlock() + for _, v := range l.Entries { + return v + } + return nil +} + +// Get an entry by key +func (l *LogEntries) Get(k string) *LogEntry { + l.mtx.RLock() + defer l.mtx.RUnlock() + if e, ok := l.Entries[k]; ok { + return e + } + return nil +} + +// Put an entry +func (l *LogEntries) Put(item *LogEntry, replaceIfExists putOptionIfExists) { + l.mtx.Lock() + defer l.mtx.Unlock() + if _, ok := l.Entries[item.ID()]; ok && replaceIfExists == putOptionIfExistsNoop { + return + } + l.Entries[item.ID()] = item +} + +// Delete an entry +func (l *LogEntries) Delete(key string) { + l.mtx.Lock() + defer l.mtx.Unlock() + delete(l.Entries, key) +} + +// Range over all Entries, true to continue +func (l *LogEntries) Range(processor func(entry *LogEntry)) { + l.mtx.Lock() + defer l.mtx.Unlock() + for _, e := range l.Entries { + processor(e) + } +} + +// Len return length +func (l *LogEntries) Len() int { + l.mtx.RLock() + defer l.mtx.RUnlock() + return len(l.Entries) +} + +// String return string +func (l *LogEntries) String() string { + l.mtx.RLock() + defer l.mtx.RUnlock() + var result string + for k, e := range l.Entries { + result += fmt.Sprintf(";k(%s),v(%#v)", k, e.String()) + } + return result +} + +func logEntriesIntersection(allEntries ...*LogEntries) *LogEntries { + le := NewLogEntries() + if len(allEntries) == 0 { + return le + } + + allEntries[0].mtx.RLock() + defer allEntries[0].mtx.RUnlock() + + for k, e := range allEntries[0].Entries { // for each key + matchedCnt := 0 + for _, each := range allEntries[1:] { // try match for each entry + each.mtx.RLock() + if _, ok := each.Entries[k]; ok { + matchedCnt = matchedCnt + 1 + } + each.mtx.RUnlock() + } + if matchedCnt == len(allEntries)-1 { // must match at least len(allEntries) - 1 + le.Entries[k] = e + } + } + + return le +} diff --git a/yuanrong/pkg/dashboard/logmanager/log_entry_test.go b/yuanrong/pkg/dashboard/logmanager/log_entry_test.go new file mode 100644 index 0000000..2ad58c6 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/log_entry_test.go @@ -0,0 +1,150 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" +) + +func TestLogEntry(t *testing.T) { + Convey("Given a LogEntry instance", t, func() { + // 正常场景 + Convey("When creating a new LogEntry with valid LogItem", func() { + logItem := &logservice.LogItem{ + Filename: "test.log", + CollectorID: "collector-1", + Target: logservice.LogTarget_USER_STD, + RuntimeID: "runtime-1", + } + logEntry := NewLogEntry(logItem) + + Convey("Then the LogEntry should be initialized correctly", func() { + So(logEntry.LogItem, ShouldEqual, logItem) + So(logEntry.JobID, ShouldBeEmpty) + So(logEntry.InstanceID, ShouldBeEmpty) + }) + + Convey("Then the ID method should return the correct key", func() { + expectedID := "test.log//collector-1//runtime-1" + So(logEntry.ID(), ShouldEqual, expectedID) + }) + }) + }) +} + +func TestLogEntriesIntersection(t *testing.T) { + Convey("Given multiple LogEntries maps", t, func() { + // 正常场景 + Convey("When computing the intersection of multiple LogEntries", func() { + entry1 := &LogEntry{ + LogItem: &logservice.LogItem{ + Filename: "file1.log", + CollectorID: "collector-1", + RuntimeID: "runtime-1", + }, + } + entry2 := &LogEntry{ + LogItem: &logservice.LogItem{ + Filename: "file2.log", + CollectorID: "collector-2", + RuntimeID: "runtime-2", + }, + } + entry3 := &LogEntry{ + LogItem: &logservice.LogItem{ + Filename: "file1.log", + CollectorID: "collector-1", + RuntimeID: "runtime-1", + }, + } + + entries1 := NewLogEntries() + entries1.Put(entry1, putOptionIfExistsReplace) + entries1.Put(entry2, putOptionIfExistsReplace) + entries2 := NewLogEntries() + entries2.Put(entry1, putOptionIfExistsReplace) + entries2.Put(entry3, putOptionIfExistsReplace) + + intersection := logEntriesIntersection(entries1, entries2) + + Convey("Then the intersection should contain only common Entries", func() { + So(intersection.Len(), ShouldEqual, 1) + So(intersection.Get(entry1.ID()), ShouldEqual, entry1) + }) + }) + + // 异常场景 + Convey("When computing the intersection with no input maps", func() { + intersection := logEntriesIntersection() + + Convey("Then the intersection should be an empty map", func() { + So(intersection.Len(), ShouldEqual, 0) + }) + }) + + Convey("When computing the intersection with one input map", func() { + entry := &LogEntry{ + LogItem: &logservice.LogItem{ + Filename: "file1.log", + CollectorID: "collector-1", + RuntimeID: "runtime-1", + }, + } + entries := NewLogEntries() + entries.Put(entry, putOptionIfExistsReplace) + + intersection := logEntriesIntersection(entries) + + Convey("Then the intersection should be the same as the input map", func() { + So(intersection.Len(), ShouldEqual, 1) + So(intersection.Get(entry.ID()), ShouldEqual, entry) + }) + }) + + Convey("When computing the intersection with no common Entries", func() { + entry1 := &LogEntry{ + LogItem: &logservice.LogItem{ + Filename: "file1.log", + CollectorID: "collector-1", + RuntimeID: "runtime-1", + }, + } + entry2 := &LogEntry{ + LogItem: &logservice.LogItem{ + Filename: "file2.log", + CollectorID: "collector-2", + RuntimeID: "runtime-2", + }, + } + + entries1 := NewLogEntries() + entries1.Put(entry1, putOptionIfExistsReplace) + entries2 := NewLogEntries() + entries2.Put(entry2, putOptionIfExistsReplace) + + intersection := logEntriesIntersection(entries1, entries2) + + Convey("Then the intersection should be an empty map", func() { + So(intersection.Len(), ShouldEqual, 0) + }) + }) + }) +} diff --git a/yuanrong/pkg/dashboard/logmanager/log_index.go b/yuanrong/pkg/dashboard/logmanager/log_index.go new file mode 100644 index 0000000..b8a28b0 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/log_index.go @@ -0,0 +1,75 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "sync" +) + +// LogIndex is an index to speed up query (by some key) +type LogIndex struct { + indexKeyByItem func(entry *LogEntry) string + index map[string]*LogEntries + indexKeyByQuery func(query logDBQuery) string + mtx sync.RWMutex // TO protect index +} + +// NewLogIndex - +func NewLogIndex(indexKeyByItem func(entry *LogEntry) string, indexKeyByQuery func(query logDBQuery) string) LogIndex { + return LogIndex{ + indexKeyByItem: indexKeyByItem, + indexKeyByQuery: indexKeyByQuery, + index: map[string]*LogEntries{}, + } +} + +// Put an entry in the index +func (g *LogIndex) Put(item *LogEntry) { + if item == nil || item.LogItem == nil { + return + } + key := g.indexKeyByItem(item) + g.mtx.Lock() + defer g.mtx.Unlock() + if _, ok := g.index[key]; !ok { + g.index[key] = NewLogEntries() + } + g.index[key].Put(item, putOptionIfExistsReplace) +} + +// Remove an entry from index +func (g *LogIndex) Remove(item *LogEntry) { + if item == nil || item.LogItem == nil { + return + } + key := g.indexKeyByItem(item) + g.mtx.Lock() + defer g.mtx.Unlock() + if _, ok := g.index[key]; ok { + g.index[key].Delete(item.ID()) + } +} + +// Query by key +func (g *LogIndex) Query(queryIndexKey string) *LogEntries { + g.mtx.RLock() + defer g.mtx.RUnlock() + if e, ok := g.index[queryIndexKey]; ok { + return e + } + return NewLogEntries() +} diff --git a/yuanrong/pkg/dashboard/logmanager/log_index_test.go b/yuanrong/pkg/dashboard/logmanager/log_index_test.go new file mode 100644 index 0000000..a162118 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/log_index_test.go @@ -0,0 +1,136 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "sync" + "testing" + + . "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" +) + +func TestLogIndex(t *testing.T) { + Convey("Given a LogIndex instance", t, func() { + // 初始化 LogIndex + logIndex := NewLogIndex(func(entry *LogEntry) string { return entry.CollectorID }, + func(query logDBQuery) string { return query.CollectorID }) + // 测试数据 + logItem := &logservice.LogItem{ + Filename: "test.log", + CollectorID: "collector-1", + Target: logservice.LogTarget_USER_STD, + RuntimeID: "runtime-1", + } + logEntry := NewLogEntry(logItem) + + Convey("When putting a LogEntry into the index", func() { + logIndex.Put(logEntry) + + Convey("Then the LogEntry should be added to the correct index keys", func() { + So(logIndex.index["collector-1"].Get(logEntry.ID()), ShouldEqual, logEntry) + }) + }) + + Convey("When removing a LogEntry from the index", func() { + logIndex.Put(logEntry) // 先添加 + logIndex.Remove(logEntry) + + Convey("Then the LogEntry should be removed from the index keys", func() { + So(logIndex.index["collector-1"].Get(logEntry.ID()), ShouldBeNil) + }) + }) + + Convey("When querying the index with a valid key", func() { + logIndex.Put(logEntry) + result := logIndex.Query("collector-1") + + Convey("Then the result should contain the LogEntry", func() { + So(result.Get(logEntry.ID()), ShouldEqual, logEntry) + }) + }) + + Convey("When querying the index with an invalid key", func() { + result := logIndex.Query("invalid-key") + + Convey("Then the result should be an empty LogEntries map", func() { + So(result.Len(), ShouldBeZeroValue) + }) + }) + + Convey("When putting a LogEntry with nil LogItem", func() { + nilLogEntry := NewLogEntry(nil) + logIndex.Put(nilLogEntry) + + Convey("Then the LogEntry should not be added to the index", func() { + So(logIndex.index[""], ShouldBeNil) + }) + }) + + Convey("When removing a LogEntry that does not exist", func() { + lenBefore := len(logIndex.index) + log.GetLogger().Infof("before, should be %d", lenBefore) + nonExistentEntry := &LogEntry{ + LogItem: &logservice.LogItem{ + Filename: "non-existent.log", + CollectorID: "non-existent-collector", + RuntimeID: "non-existent-runtime", + }, + } + logIndex.Remove(nonExistentEntry) + + Convey("Then the index should remain unchanged", func() { + So(len(logIndex.index), ShouldEqual, lenBefore) + }) + }) + + Convey("When putting a LogEntry concurrently", func() { + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + logIndex.Put(logEntry) + }() + } + wg.Wait() + + Convey("Then the LogEntry should be added correctly without race conditions", func() { + So(logIndex.index["collector-1"].Get(logEntry.ID()), ShouldEqual, logEntry) + }) + }) + + Convey("When removing a LogEntry concurrently", func() { + logIndex.Put(logEntry) // 先添加 + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + logIndex.Remove(logEntry) + }() + } + wg.Wait() + + Convey("Then the LogEntry should be removed correctly without race conditions", func() { + So(logIndex.index["collector-1"].Get(logEntry.ID()), ShouldBeNil) + }) + }) + }) +} diff --git a/yuanrong/pkg/dashboard/logmanager/log_manager.go b/yuanrong/pkg/dashboard/logmanager/log_manager.go new file mode 100644 index 0000000..a3bdc20 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/log_manager.go @@ -0,0 +1,152 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/dashboard/etcdcache" +) + +const userStdStreamPrefix = "/log/runtime/std/" + +type manager struct { + Collectors map[string]collectorClient + pendingItems chan *logservice.LogItem + LogDB + + mtx sync.RWMutex // TO protect Collectors +} + +var ( + managerSingleton *manager + logManagerOnce sync.Once +) + +func init() { + logManagerOnce.Do(func() { + managerSingleton = &manager{ + Collectors: map[string]collectorClient{}, + LogDB: newGeneralLogDBImpl(), + } + etcdcache.InstanceCache.RegisterInstanceStartHandler(managerSingleton.OnInstanceStart) + etcdcache.InstanceCache.RegisterInstanceExitHandler(managerSingleton.OnInstanceExit) + }) +} + +// RegisterLogCollector - +func (m *manager) RegisterLogCollector(collectorInfo collectorClientInfo) error { + client := collectorClient{ + collectorClientInfo: collectorInfo, + } + err := client.Connect() + if err != nil { + return err + } + go client.Healthcheck(func() { + // unregister self if shutdown + m.UnregisterLogCollector(collectorInfo.ID) + }) + m.mtx.Lock() + defer m.mtx.Unlock() + m.Collectors[collectorInfo.ID] = client + return nil +} + +// UnregisterLogCollector - +func (m *manager) UnregisterLogCollector(id string) { + m.mtx.RLock() + c, ok := m.Collectors[id] + m.mtx.RUnlock() + if !ok { + return + } + + m.mtx.Lock() + delete(m.Collectors, id) + m.mtx.Unlock() + + if err := c.grpcConn.Close(); err != nil { + log.GetLogger().Warnf("failed to close connection to collector %s at %s: %s", id, c.Address, err.Error()) + } +} + +// GetCollector - +func (m *manager) GetCollector(id string) *collectorClient { + m.mtx.RLock() + defer m.mtx.RUnlock() + if c, ok := m.Collectors[id]; ok { + return &c + } + return nil +} + +// OnInstanceStart handles event when instance is running +func (m *manager) OnInstanceStart(instance *types.InstanceSpecification) { + log.GetLogger().Infof("running instance %s started cb", instance.InstanceID) + if isDriverInstance(instance) { + // do nothing to driver + log.GetLogger().Debugf("skip driver instance event of %s", instance.InstanceID) + return + } + // then try to fulfill the log entry + m.fulfillLogEntryInLogDB(instance) +} + +// ReportLogItem handles the report request +func (m *manager) ReportLogItem(item *logservice.LogItem) { + // match and check the component id (runtime id) + m.LogDB.Put(NewLogEntry(item), putOptionIfExistsNoop) + if item.RuntimeID != "" { + // if not empty, this is a runtime, try match the runtime and try to fulfill the log entry + m.fulfillLogEntryInLogDB(etcdcache.InstanceCache.GetByRuntimeID(item.RuntimeID)) + } +} + +func (m *manager) fulfillLogEntryInLogDB(instance *types.InstanceSpecification) { + if instance == nil { + log.GetLogger().Debugf("try fulfill a log entry with nil instance ptr") + return + } + result := m.LogDB.Query(logDBQuery{RuntimeID: instance.RuntimeID}) + result.Range(func(entry *LogEntry) { + if entry.InstanceID == instance.InstanceID { + log.GetLogger().Debugf("entry of instance(%s) with filename(%s) on collector(%s) is already been set, "+ + "no update need", instance.InstanceID, entry.Filename, entry.CollectorID) + return + } + entry.InstanceID = instance.InstanceID + entry.JobID = instance.JobID + // actually performs an in-place modification, but still put it to avoid some unexpected problem + m.LogDB.Put(entry, putOptionIfExistsReplace) + }) +} + +// OnInstanceExit - +func (m *manager) OnInstanceExit(instance *types.InstanceSpecification) { + if !isDriverInstance(instance) { + return + } +} + +func isDriverInstance(instance *types.InstanceSpecification) bool { + return strings.HasPrefix(instance.InstanceID, "driver") && instance.ParentID == "" +} diff --git a/yuanrong/pkg/dashboard/logmanager/log_manager_test.go b/yuanrong/pkg/dashboard/logmanager/log_manager_test.go new file mode 100644 index 0000000..09e3112 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/log_manager_test.go @@ -0,0 +1,135 @@ +package logmanager + +import ( + "context" + "google.golang.org/grpc" + "reflect" + "testing" + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/dashboard/etcdcache" +) + +type mockLogCollectorClient struct { +} + +func newMockLogCollectorClient() *mockLogCollectorClient { + m := &mockLogCollectorClient{} + return m +} + +func (m *mockLogCollectorClient) Reset() { +} + +func (m *mockLogCollectorClient) ReadLog(ctx context.Context, in *logservice.ReadLogRequest, + opts ...grpc.CallOption) (logservice.LogCollectorService_ReadLogClient, error) { + //m.ReadLogRequestCh <- in + return nil, nil +} + +func TestManager(t *testing.T) { + Convey("Given a manager instance", t, func() { + // 初始化 manager + m := managerSingleton + + // 测试数据 + fakeLogCollectorClient := newMockLogCollectorClient() + collectorInfo := collectorClientInfo{ + ID: "collector-1", + Address: "127.0.0.1:50051", + } + logItem := &logservice.LogItem{ + Filename: "test.log", + CollectorID: "collector-1", + Target: logservice.LogTarget_USER_STD, + RuntimeID: "runtime-1", + } + instance := &types.InstanceSpecification{ + InstanceID: "instance-1", + JobID: "job-1", + RuntimeID: "runtime-1", + } + + Convey("When registering a log collector", func() { + // 模拟 collectorClient.Connect 成功 + patches := gomonkey.ApplyMethodFunc(reflect.TypeOf(&grpc.ClientConn{}), "Connect", func() error { + return nil + }) + defer patches.Reset() + + err := m.RegisterLogCollector(collectorInfo) + Convey("Then the collector should be registered successfully", func() { + So(err, ShouldBeNil) + c := m.GetCollector(collectorInfo.ID) + So(c, ShouldNotBeNil) + So(c.ID, ShouldEqual, collectorInfo.ID) + }) + }) + + Convey("When unregistering a log collector", func() { + // 先注册一个 collector + m.Collectors[collectorInfo.ID] = collectorClient{ + collectorClientInfo: collectorInfo, + } + + // 模拟 grpcConn.Close 成功 + patches := gomonkey.ApplyMethodFunc(reflect.TypeOf(&grpc.ClientConn{}), "Close", func() error { + return nil + }) + defer patches.Reset() + + m.UnregisterLogCollector(collectorInfo.ID) + + Convey("Then the collector should be unregistered successfully", func() { + So(m.GetCollector(collectorInfo.ID), ShouldBeNil) + }) + }) + + Convey("When getting a collector by ID", func() { + // 先注册一个 collector + m.Collectors[collectorInfo.ID] = collectorClient{ + collectorClientInfo: collectorInfo, + logClient: fakeLogCollectorClient, + } + + collector := m.GetCollector(collectorInfo.ID) + + Convey("Then the collector should be returned", func() { + So(collector, ShouldNotBeNil) + So(collector.ID, ShouldEqual, collectorInfo.ID) + }) + }) + + Convey("When reporting a log item with empty RuntimeID", func() { + logItem := &logservice.LogItem{ + Filename: "test.log", + CollectorID: "collector-1", + Target: logservice.LogTarget_USER_STD, + RuntimeID: "", // 空 RuntimeID + } + + m.ReportLogItem(logItem) + + Convey("Then the log item should be added to the LogDB", func() { + entry := m.LogDB.GetLogEntries().Get(logItem.Filename + "//" + logItem.CollectorID + "//" + logItem.RuntimeID) + So(entry, ShouldNotBeNil) + So(entry.LogItem, ShouldResemble, logItem) + }) + }) + + Convey("When reporting a log item with non-empty RuntimeID", func() { + // 模拟 etcdcache.InstanceCache.GetByRuntimeID 返回实例 + patches := gomonkey.ApplyMethodReturn(&etcdcache.InstanceCache, "GetByRuntimeID", instance) + defer patches.Reset() + + m.ReportLogItem(logItem) + + Convey("Then the log item should be added to the LogDB with JobID and InstanceID", func() { + entry := m.LogDB.GetLogEntries().Get(logItem.Filename + "//" + logItem.CollectorID + "//" + logItem.RuntimeID) + So(entry, ShouldNotBeNil) + So(entry.JobID, ShouldEqual, instance.JobID) + So(entry.InstanceID, ShouldEqual, instance.InstanceID) + }) + }) + }) +} diff --git a/yuanrong/pkg/dashboard/logmanager/service.go b/yuanrong/pkg/dashboard/logmanager/service.go new file mode 100644 index 0000000..9e26e10 --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/service.go @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "context" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" + "yuanrong/pkg/common/faas_common/logger/log" +) + +// Server 实现了 LogManagerService 服务 +type Server struct { + logservice.UnimplementedLogManagerServiceServer +} + +// Register 处理 Register RPC 请求 +func (s *Server) Register(_ context.Context, req *logservice.RegisterRequest) (*logservice.RegisterResponse, error) { + log.GetLogger().Infof("receive register request: %v", req) + err := managerSingleton.RegisterLogCollector(collectorClientInfo{ + ID: req.CollectorID, + Address: req.Address, + }) + if err != nil { + return nil, err + } + return &logservice.RegisterResponse{ + Code: 0, + }, nil +} + +// ReportLog 处理 ReportLog RPC 请求 +func (s *Server) ReportLog(_ context.Context, req *logservice.ReportLogRequest) (*logservice.ReportLogResponse, error) { + log.GetLogger().Infof("receive report request: %v", req) + for _, item := range req.GetItems() { + managerSingleton.ReportLogItem(item) + } + return &logservice.ReportLogResponse{ + Code: 0, + Message: "Log reported successfully", + }, nil +} diff --git a/yuanrong/pkg/dashboard/logmanager/service_test.go b/yuanrong/pkg/dashboard/logmanager/service_test.go new file mode 100644 index 0000000..aae563b --- /dev/null +++ b/yuanrong/pkg/dashboard/logmanager/service_test.go @@ -0,0 +1,109 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logmanager + +import ( + "context" + "errors" + "testing" + + "github.com/agiledragon/gomonkey/v2" + . "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/grpc/pb/logservice" +) + +func TestServer(t *testing.T) { + Convey("Given a Server instance", t, func() { + server := &Server{} + + // 测试数据 + registerReq := &logservice.RegisterRequest{ + CollectorID: "collector-1", + Address: "127.0.0.1:50051", + } + reportLogReq := &logservice.ReportLogRequest{ + Items: []*logservice.LogItem{ + { + Filename: "test.log", + CollectorID: "collector-1", + Target: logservice.LogTarget_USER_STD, + RuntimeID: "runtime-1", + }, + }, + } + + Convey("When handling a Register RPC request", func() { + // 模拟 managerSingleton.RegisterLogCollector 成功 + patches := gomonkey.ApplyMethodReturn(managerSingleton, "RegisterLogCollector", nil) + defer patches.Reset() + + resp, err := server.Register(context.Background(), registerReq) + + Convey("Then the request should be handled successfully", func() { + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + }) + }) + + Convey("When handling a Register RPC request with an error", func() { + // 模拟 managerSingleton.RegisterLogCollector 返回错误 + patches := gomonkey.ApplyMethodReturn(managerSingleton, "RegisterLogCollector", errors.New("register failed")) + defer patches.Reset() + + resp, err := server.Register(context.Background(), registerReq) + + Convey("Then the request should return an error", func() { + So(err, ShouldNotBeNil) + So(resp, ShouldBeNil) + }) + }) + + Convey("When handling a ReportLog RPC request", func() { + // 模拟 managerSingleton.ReportLogItem + patches := gomonkey.ApplyMethodFunc(managerSingleton, "ReportLogItem", func(item *logservice.LogItem) {}) + defer patches.Reset() + + resp, err := server.ReportLog(context.Background(), reportLogReq) + + Convey("Then the request should be handled successfully", func() { + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.Code, ShouldEqual, 0) + So(resp.Message, ShouldEqual, "Log reported successfully") + }) + }) + + Convey("When handling a ReportLog RPC request with empty items", func() { + // 模拟 managerSingleton.ReportLogItem + patches := gomonkey.ApplyMethodFunc(managerSingleton, "ReportLogItem", func(item *logservice.LogItem) {}) + defer patches.Reset() + + emptyReq := &logservice.ReportLogRequest{ + Items: []*logservice.LogItem{}, + } + resp, err := server.ReportLog(context.Background(), emptyReq) + + Convey("Then the request should be handled successfully", func() { + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.Code, ShouldEqual, 0) + So(resp.Message, ShouldEqual, "Log reported successfully") + }) + }) + }) +} diff --git a/yuanrong/pkg/dashboard/models/common_response.go b/yuanrong/pkg/dashboard/models/common_response.go new file mode 100644 index 0000000..3d2ad50 --- /dev/null +++ b/yuanrong/pkg/dashboard/models/common_response.go @@ -0,0 +1,31 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package models + +// DashboardCommonResponse is the common response, so all dashboard response will share a same structure +type DashboardCommonResponse struct { + Message string `json:"message"` + Data any `json:"data"` +} + +// SuccessDashboardCommonResponse fill data field only, message will be automatically filled by empty string +func SuccessDashboardCommonResponse(data any) DashboardCommonResponse { + return DashboardCommonResponse{ + Message: "", + Data: data, + } +} diff --git a/yuanrong/pkg/dashboard/models/serve_api_models.go b/yuanrong/pkg/dashboard/models/serve_api_models.go new file mode 100644 index 0000000..64ed365 --- /dev/null +++ b/yuanrong/pkg/dashboard/models/serve_api_models.go @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package models serve_api_models Refers to ray-operator/controllers/ray/utils/serve_api_models.go +package models + +// ServeGetResponse is the ServeDetails in kube-ray +type ServeGetResponse struct { + ServeDetails +} + +// ServeDeploymentStatus - +type ServeDeploymentStatus struct { + Name string `json:"name,omitempty"` + Status string `json:"status,omitempty"` + Message string `json:"message,omitempty"` +} + +// ServeApplicationStatus - +type ServeApplicationStatus struct { + Deployments map[string]ServeDeploymentStatus `json:"deployments"` + Name string `json:"name,omitempty"` + Status string `json:"status"` + Message string `json:"message,omitempty"` +} + +// ServeDeploymentDetails - +type ServeDeploymentDetails struct { + ServeDeploymentStatus + RoutePrefix string `json:"route_prefix,omitempty"` +} + +// ServeApplicationDetails - +type ServeApplicationDetails struct { + Deployments map[string]ServeDeploymentDetails `json:"deployments"` + ServeApplicationStatus + RoutePrefix string `json:"route_prefix,omitempty"` + DocsPath string `json:"docs_path,omitempty"` +} + +// ServeDetails - +type ServeDetails struct { + Applications map[string]*ServeApplicationDetails `json:"applications"` + DeployMode string `json:"deploy_mode,omitempty"` +} diff --git a/yuanrong/pkg/dashboard/routers/cors.go b/yuanrong/pkg/dashboard/routers/cors.go new file mode 100644 index 0000000..454e6ac --- /dev/null +++ b/yuanrong/pkg/dashboard/routers/cors.go @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package routers for load routs +package routers + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +// Cors function for allow cors +func Cors() gin.HandlerFunc { + return func(c *gin.Context) { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + c.Writer.Header().Set("Access-Control-Allow-Headers", + "Type, Content-Type, Access-Control-Allow-Headers, Authorization, X-Requested-With") + + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(http.StatusOK) + } + + c.Next() + } +} diff --git a/yuanrong/pkg/dashboard/routers/router.go b/yuanrong/pkg/dashboard/routers/router.go new file mode 100644 index 0000000..d21ffa0 --- /dev/null +++ b/yuanrong/pkg/dashboard/routers/router.go @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package routers for load routs +package routers + +import ( + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/job" + "yuanrong/pkg/dashboard/flags" + "yuanrong/pkg/dashboard/handlers" + "yuanrong/pkg/dashboard/logmanager" +) + +const ( + basePath = "/api/v1" + apiPath = "/api" + rsrcSummPath = "/logical-resources/summary" + rsrcPath = "/logical-resources" + rsrcUnitIDPath = "/logical-resources/:unit-id" + compPath = "/components" + compCompIDPath = "/components/:component-id" + instPath = "/instances" + instSummPath = "/instances/summary" + instInstIDPath = "/instances/:instance-id" + promPath = "/prometheus/query" + + // compatible with ray serve + servePath = "/api/serve" + serveAppPath = "/applications" + + clusterStatusPath = "/cluster_status" + + logsListPath = "/logs/list" + logsPath = "/logs" +) + +// SetRouter function for set routs +func SetRouter() *gin.Engine { + r := gin.New() + r.Use(gin.Recovery()) + r.Use(Cors()) + + r.StaticFile("/", flags.DashboardConfig.StaticPath+"/index.html") + r.StaticFile("/logo.png", flags.DashboardConfig.StaticPath+"/logo.png") + r.Static("/assets", flags.DashboardConfig.StaticPath+"/assets") + + v1Group := r.Group(basePath) + { + // resources + v1Group.GET(rsrcSummPath, handlers.ResourcesSummaryHandler) + v1Group.GET(rsrcPath, handlers.ResourcesHandler) + v1Group.GET(rsrcUnitIDPath, handlers.ResourcesByUnitIDHandler) + // components + v1Group.GET(compPath, handlers.ComponentsHandler) + v1Group.GET(compCompIDPath, handlers.ComponentsByComponentIDHandler) + // instances + v1Group.GET(instPath, handlers.InstancesHandler) + v1Group.GET(instSummPath, handlers.InstancesSummaryHandler) + v1Group.GET(instInstIDPath, handlers.InstancesByInstanceIDHandler) + // prom + v1Group.GET(promPath, handlers.PrometheusHandler) + } + + rayServeGroup := r.Group(servePath) + { + // serve + rayServeGroup.GET(serveAppPath, handlers.ServeGetHandler) + rayServeGroup.PUT(serveAppPath, handlers.ServePutHandler) + rayServeGroup.DELETE(serveAppPath, handlers.ServeDelHandler) + } + + jobGroup := r.Group(job.PathGroupJobs) + { + // jobs + jobGroup.POST("", handlers.SubmitJobHandler) + jobGroup.GET("", handlers.ListJobsHandler) + jobGroup.GET(job.PathGetJobs, handlers.GetJobInfoHandler) + jobGroup.DELETE(job.PathDeleteJobs, handlers.DeleteJobHandler) + jobGroup.POST(job.PathStopJobs, handlers.StopJobHandler) + } + + apiGroup := r.Group(apiPath) + { + // pending queue resources + apiGroup.GET(clusterStatusPath, handlers.ClusterStatusHandler) + + apiGroup.GET(logsListPath, logmanager.ListLogsHandler) + apiGroup.GET(logsPath, logmanager.ReadLogHandler) + } + + return r +} diff --git a/yuanrong/pkg/dashboard/routers/router_test.go b/yuanrong/pkg/dashboard/routers/router_test.go new file mode 100644 index 0000000..fec5276 --- /dev/null +++ b/yuanrong/pkg/dashboard/routers/router_test.go @@ -0,0 +1,45 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package routers for load routs +package routers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/smartystreets/goconvey/convey" +) + +func TestSetRouter(t *testing.T) { + convey.Convey("Test SetRouter:", t, func() { + convey.Convey("SetRouter success", func() { + r := SetRouter() + r.GET("/healthy", func(ctx *gin.Context) { + ctx.JSON(http.StatusOK, "hello") + }) + req, err := http.NewRequest("OPTIONS", "/healthy", nil) + convey.So(err, convey.ShouldBeNil) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + convey.So(w.Code, convey.ShouldEqual, http.StatusOK) + convey.So(string(w.Body.Bytes()), convey.ShouldEqual, ``) + }) + }) +} diff --git a/yuanrong/pkg/functionmanager/config/config.go b/yuanrong/pkg/functionmanager/config/config.go new file mode 100644 index 0000000..7849cdc --- /dev/null +++ b/yuanrong/pkg/functionmanager/config/config.go @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config - +package config + +import ( + "encoding/json" + "fmt" + + "github.com/asaskevich/govalidator/v11" + + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/functionmanager/types" +) + +var cfg types.ManagerConfig + +// GetConfig return the current config +func GetConfig() types.ManagerConfig { + return cfg +} + +// InitConfig is used to initialize the config +func InitConfig(data []byte) error { + err := json.Unmarshal(data, &cfg) + if err != nil { + return fmt.Errorf("failed to parse the config data: %s", err.Error()) + } + if cfg.RouterEtcd.UseSecret { + etcd3.SetETCDTLSConfig(&cfg.RouterEtcd) + } + if cfg.MetaEtcd.UseSecret { + etcd3.SetETCDTLSConfig(&cfg.MetaEtcd) + } + _, err = govalidator.ValidateStruct(cfg) + if err != nil { + return fmt.Errorf("invalid config: %s", err.Error()) + } + if cfg.SccConfig.Enable && crypto.InitializeSCC(cfg.SccConfig) != nil { + return fmt.Errorf("failed to initialize scc") + } + return nil +} + +// InitEtcd - init router etcd and meta etcd +func InitEtcd(stopCh <-chan struct{}) error { + if &cfg == nil { + return fmt.Errorf("config is not initialized") + } + if err := etcd3.InitParam(). + WithRouteEtcdConfig(cfg.RouterEtcd). + WithStopCh(stopCh). + WithAlarmSwitch(cfg.AlarmConfig.EnableAlarm). + InitClient(); err != nil { + return fmt.Errorf("faaSManager failed to init route etcd: %s", err.Error()) + } + if err := etcd3.InitParam(). + WithMetaEtcdConfig(cfg.MetaEtcd). + WithStopCh(stopCh). + WithAlarmSwitch(cfg.AlarmConfig.EnableAlarm). + InitClient(); err != nil { + return fmt.Errorf("faaSManager failed to init metadata etcd: %s", err.Error()) + } + return nil +} diff --git a/yuanrong/pkg/functionmanager/config/config_test.go b/yuanrong/pkg/functionmanager/config/config_test.go new file mode 100644 index 0000000..424fb89 --- /dev/null +++ b/yuanrong/pkg/functionmanager/config/config_test.go @@ -0,0 +1,133 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package config + +import ( + "errors" + "fmt" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/asaskevich/govalidator/v11" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/etcd3" + mockUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionmanager/state" + "yuanrong/pkg/functionmanager/types" +) + +var ( + configString = `{ + "httpsEnable": true, + "functionCapability": 100, + "authenticationEnable": true + } + ` +) + +func TestGetConfig(t *testing.T) { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + state.InitState() + tests := []struct { + name string + want types.ManagerConfig + }{ + {"case1", types.ManagerConfig{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetConfig(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetConfig() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestInitConfig(t *testing.T) { + state.InitState() + type args struct { + data []byte + } + a := args{} + b := args{ + data: []byte(configString), + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"case1", a, true}, + {"case2", b, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := InitConfig(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("InitConfig() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } + + convey.Convey("ValidateStruct error", t, func() { + defer gomonkey.ApplyFunc(govalidator.ValidateStruct, func(s interface{}) (bool, error) { + return false, fmt.Errorf("check error") + }).Reset() + err := InitConfig([]byte(configString)) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestInitEtcd(t *testing.T) { + type args struct { + stopCh <-chan struct{} + } + tests := []struct { + name string + args args + wantErr assert.ErrorAssertionFunc + patchesFunc mockUtils.PatchesFunc + }{ + {"case1 succeed to init etcd", args{stopCh: make(<-chan struct{})}, assert.NoError, + func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdInitParam{}), "InitClient", + func(_ *etcd3.EtcdInitParam) error { return nil }), + }) + return patches + }}, + {"case2 failed to init etcd", args{stopCh: make(<-chan struct{})}, assert.Error, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdInitParam{}), "InitClient", + func(_ *etcd3.EtcdInitParam) error { return errors.New("e") }), + }) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + tt.wantErr(t, InitEtcd(tt.args.stopCh), fmt.Sprintf("InitEtcd(%v)", tt.args.stopCh)) + patches.ResetAll() + }) + } +} diff --git a/yuanrong/pkg/functionmanager/constant/constant.go b/yuanrong/pkg/functionmanager/constant/constant.go new file mode 100644 index 0000000..47c3421 --- /dev/null +++ b/yuanrong/pkg/functionmanager/constant/constant.go @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constant - +package constant + +const ( + // PluginFile is the path of vpccontroller.so + PluginFile = "vpc-controller/vpccontroller.so" + // FunctionLibPath - + FunctionLibPath = "FUNCTION_LIB_PATH" +) diff --git a/yuanrong/pkg/functionmanager/faasmanager.go b/yuanrong/pkg/functionmanager/faasmanager.go new file mode 100644 index 0000000..305eb81 --- /dev/null +++ b/yuanrong/pkg/functionmanager/faasmanager.go @@ -0,0 +1,720 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package functionmanager - +package functionmanager + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "plugin" + "strings" + "sync" + "time" + + k8serror "k8s.io/apimachinery/pkg/api/errors" + + "yuanrong.org/kernel/runtime/libruntime/api" + + commonconstant "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + commonType "yuanrong/pkg/common/faas_common/types" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionmanager/config" + "yuanrong/pkg/functionmanager/constant" + "yuanrong/pkg/functionmanager/state" + "yuanrong/pkg/functionmanager/types" + "yuanrong/pkg/functionmanager/utils" + "yuanrong/pkg/functionmanager/vpcmanager" +) + +// Manager manages functions for faas pattern +type Manager struct { + patKeyList map[string]map[string]struct{} + VpcPlugin vpcmanager.PluginVPC + vpcMutex sync.Mutex + remoteClientLease map[string]*leaseTimer + remoteClientList map[string]struct{} + clientMutex sync.Mutex + stopCh chan struct{} + queue *queue + leaseRenewMinute time.Duration +} + +type leaseTimer struct { + Timer *time.Timer + Cancel context.CancelFunc + Ctx context.Context +} + +// RequestOperation defines request operations +type RequestOperation string + +var ( + // requestOpCreate stands for create pat-service operation + requestOpCreate RequestOperation = "CreatePATService" + // requestOpCreateTrigger stands for create pullTrigger operation + requestOpCreateTrigger RequestOperation = "CreatePullTrigger" + // requestOpDeleteTrigger stands for delete pullTrigger operation + requestOpDeleteTrigger RequestOperation = "DeletePullTrigger" + // requestOpReport stands for instance report operation + requestOpReport RequestOperation = "ReportInstanceID" + // requestOpDelete stands for instance delete operation + requestOpDelete RequestOperation = "DeleteInstanceID" + // requestOpUnknown stands for unknown operation + requestOpUnknown RequestOperation = "Unknown" + // requestOpNewLease stands for create lease operation + requestOpNewLease RequestOperation = commonconstant.NewLease + // requestOpKeepAlive stands for keep-alive lease operation + requestOpKeepAlive RequestOperation = commonconstant.KeepAlive + // requestOpDelLease stands for delete lease operation + requestOpDelLease RequestOperation = commonconstant.DelLease + libruntimeClient api.LibruntimeAPI +) + +const ( + validArgsNum = 3 + validArgsNumLibruntime = 4 + faasManagerOpIndex = 0 + faasManagerOpDataIndex = 1 + faasManagerOpIndexLibruntime = 1 + faasManagerOpDataIndexLibruntime = 2 + + leaseEtcdKeyLen = 3 + + defaultLeaseRenewMinute = 5 +) + +// NewFaaSManagerLibruntime will create a new faas functions manager +func NewFaaSManagerLibruntime(libruntimeAPI api.LibruntimeAPI, stopCh chan struct{}) (*Manager, error) { + libruntimeClient = libruntimeAPI + return MakeFaasManager(stopCh) +} + +// MakeFaasManager will create a new faas functions manager +func MakeFaasManager(stopCh chan struct{}) (*Manager, error) { + functionLibPath := os.Getenv(constant.FunctionLibPath) + filePath := filepath.Join(functionLibPath, constant.PluginFile) + log.GetLogger().Infof("plugin file path is %s", filePath) + pluginFile, err := plugin.Open(filePath) + if err != nil { + log.GetLogger().Errorf("failed to open vpc plugin file: %s", err.Error()) + return nil, err + } + cfg := config.GetConfig() + leaseRenewMinute := cfg.LeaseRenewMinute + if leaseRenewMinute == 0 { + leaseRenewMinute = defaultLeaseRenewMinute + } + faaSManager := &Manager{ + patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, + remoteClientLease: make(map[string]*leaseTimer), + remoteClientList: map[string]struct{}{}, + clientMutex: sync.Mutex{}, + stopCh: stopCh, + queue: &queue{ + cond: sync.NewCond(&sync.RWMutex{}), + dirty: map[interface{}]struct{}{}, + processing: map[interface{}]struct{}{}, + }, + leaseRenewMinute: time.Duration(cfg.LeaseRenewMinute), + } + faaSManager.VpcPlugin.Plugin = pluginFile + err = faaSManager.VpcPlugin.InitVpcPlugin() + if err != nil { + return nil, err + } + err = utils.InitKubeClient() + if err != nil { + return nil, err + } + go faaSManager.saveStateLoop() + return faaSManager, nil +} + +// RecoverData - +func (m *Manager) RecoverData() { + patKeyList := state.GetState().PatKeyList + log.GetLogger().Infof("now recover faaSManager patKeyList") + m.vpcMutex.Lock() + for patPodName, val := range patKeyList { + instanceIDMap := val + m.patKeyList[patPodName] = instanceIDMap + } + log.GetLogger().Infof("recovered patKeyList: %v", m.patKeyList) + m.vpcMutex.Unlock() + rcm := map[string]struct{}{} + err := commonUtils.DeepCopyObj(state.GetState().RemoteClientList, &rcm) + if err != nil { + log.GetLogger().Errorf("recovered remoteClientList error: %v", err) + return + } + log.GetLogger().Infof("now recover faaSManager remoteClientList") + m.clientMutex.Lock() + for client := range rcm { + lease := newLeaseTimer(m.leaseRenewMinute * time.Minute) + m.remoteClientLease[client] = lease + m.remoteClientList[client] = struct{}{} + go startLeaseTimeOutWatcher(lease.Ctx, m, lease.Timer, client, client) + } + log.GetLogger().Infof("recovered remoteClientList: %v", m.remoteClientList) + m.clientMutex.Unlock() +} +func newLeaseTimer(timeout time.Duration) *leaseTimer { + timer := time.NewTimer(timeout) + ctx, cancel := context.WithCancel(context.Background()) + return &leaseTimer{ + Timer: timer, + Cancel: cancel, + Ctx: ctx, + } +} + +func startLeaseTimeOutWatcher(ctx context.Context, fm *Manager, timer *time.Timer, clientID string, traceID string) { + select { + case <-timer.C: + log.GetLogger().Infof("lease timeout, traceID: %s, remoteClientID: %s", traceID, clientID) + if fm != nil { + fm.clearLease(clientID, traceID) + } + case <-ctx.Done(): + log.GetLogger().Infof("lease stopped before timeout, traceID: %s, remoteClientID: %s", traceID, + clientID) + timer.Stop() + } +} + +// HandlerRequest - +func (m *Manager) HandlerRequest(requestOp RequestOperation, requestData []byte, + traceID string) *commonType.CallHandlerResponse { + switch requestOp { + case requestOpCreate: + return m.handleRequestOpCreate(requestData, traceID) + case requestOpCreateTrigger: + return m.handleRequestOpCreateTrigger(requestData, traceID) + case requestOpReport: + return m.handleRequestOpReport(requestData, traceID) + case requestOpDelete: + return m.handleRequestOpDelete(requestData, traceID) + case requestOpDeleteTrigger: + return m.handleRequestOpDeleteTrigger(requestData, traceID) + case requestOpNewLease: + return m.handleNewLease(requestData, traceID, 0) + case requestOpKeepAlive: + return m.handleKeepAlive(requestData, traceID) + case requestOpDelLease: + return m.handleDelLease(requestData, traceID) + default: + log.GetLogger().Warnf("unknown request operation %s, traceID %s", requestOp, traceID) + } + return utils.GenerateErrorResponse(commonconstant.UnsupportedOperationErrorCode, + commonconstant.UnsupportedOperationErrorMessage) +} + +// ProcessSchedulerRequestLibruntime will handle create, report and delete of instance +func (m *Manager) ProcessSchedulerRequestLibruntime(args []api.Arg, traceID string) *commonType.CallHandlerResponse { + requestOp, requestData := parseRequestOperationLibruntime(args) + return m.HandlerRequest(requestOp, requestData, traceID) +} + +// ProcessSchedulerRequest will handle create, report and delete of instance +func (m *Manager) ProcessSchedulerRequest(args []*api.Arg, traceID string) *commonType.CallHandlerResponse { + requestOp, requestData := parseRequestOperation(args) + return m.HandlerRequest(requestOp, requestData, traceID) +} + +func (m *Manager) handleRequestOpCreate(requestData []byte, traceID string) *commonType.CallHandlerResponse { + newPatPod, err := m.VpcPlugin.CreateVpcResource(requestData) + if err != nil { + log.GetLogger().Errorf("failed to create vpc pat pod, traceID: %s, error: %s", traceID, err.Error()) + return nil + } + if !m.checkExists(newPatPod.PatPodName) { + m.newPatListKey(newPatPod.PatPodName) + } + log.GetLogger().Infof("succeed to create pat service pod: %s, ip: %s, traceID %s", + newPatPod.PatPodName, newPatPod.PatContainerIP, traceID) + patInfo, err := json.Marshal(newPatPod) + if err != nil { + log.GetLogger().Errorf("failed to marshal vpc pat pod info, traceID: %s, error: %s", traceID, err.Error()) + return nil + } + return utils.GenerateSuccessResponse(commonconstant.InsReqSuccessCode, string(patInfo)) +} + +func (m *Manager) handleRequestOpCreateTrigger(requestData []byte, traceID string) *commonType.CallHandlerResponse { + log.GetLogger().Infof("succeed receive pullTrigger create request, traceID: %s", traceID) + requestInfo := types.PullTriggerRequestInfo{} + err := json.Unmarshal(requestData, &requestInfo) + if err != nil { + log.GetLogger().Errorf("failed to Unmarshal pullTrigger requestInfo, traceID: %s, err: %s", + traceID, err.Error()) + return nil + } + podName := utils.HandlePullTriggerName(requestInfo.PodName) + if podName == "" { + log.GetLogger().Errorf("invalid pull trigger name, traceID: %s", traceID) + return nil + } + requestInfo.PodName = podName + newPatPod, err := m.VpcPlugin.CreateVpcResource(requestData) + if err != nil { + log.GetLogger().Errorf("failed to create pullTrigger vpc pat pod, traceID: %s, error: %s", + traceID, err.Error()) + return nil + } + if !m.checkExists(newPatPod.PatPodName) { + m.newPatListKey(newPatPod.PatPodName) + } + log.GetLogger().Infof("succeed to create pullTrigger pat pod: %s, ip: %s, traceID: %s", + newPatPod.PatPodName, newPatPod.PatContainerIP, traceID) + vpcNatConfByte, err := json.Marshal(newPatPod) + if err != nil { + log.GetLogger().Errorf("failed to Marshal PatPod info for %s, traceID: %s", newPatPod.PatPodName, traceID) + return nil + } + triggerInfo := vpcmanager.ParseFunctionMeta(requestInfo) + if triggerInfo == nil { + return nil + } + // get pull trigger deploy + triggerDeploy, errs := utils.GetDeployByK8S(utils.GetKubeClient(), triggerInfo.PodName) + if errs != nil { + log.GetLogger().Errorf("failed to get deploy %s, traceID %s", errs.Error(), traceID) + } + if k8serror.IsNotFound(errs) { + // if trigger deploy is not found, it will create pull trigger deploy + triggerDeploy = vpcmanager.MakePullTriggerDeploy(triggerInfo, vpcNatConfByte) + if err = utils.CreateDeployByK8S(utils.GetKubeClient(), triggerDeploy); err != nil { + log.GetLogger().Errorf("failed to create deploy %s by k8s, traceID: %s, error: %s", + triggerDeploy.Name, traceID, err.Error()) + return nil + } + m.addPatList(newPatPod.PatPodName, requestInfo.PodName) + log.GetLogger().Infof("succeed add Trigger to patList, patPod %s, name %s, traceID: %s", + newPatPod.PatPodName, requestInfo.PodName, traceID) + return utils.GenerateSuccessResponse(commonconstant.InsReqSuccessCode, "succeed create pull-trigger") + } + log.GetLogger().Infof("trigger %s has already exist, traceID %s", triggerDeploy.Name, traceID) + return utils.GenerateSuccessResponse(commonconstant.InsReqSuccessCode, "skip create pull-trigger") +} + +func (m *Manager) handleRequestOpReport(requestData []byte, traceID string) *commonType.CallHandlerResponse { + reportInfo := types.ReportInfo{} + err := json.Unmarshal(requestData, &reportInfo) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal report info, traceID: %s, error: %s", traceID, err.Error()) + return nil + } + if m.checkExists(reportInfo.PatPodName) { + m.addPatList(reportInfo.PatPodName, reportInfo.InstanceID) + log.GetLogger().Infof("succeed add pat list, patPodName %s, instanceID %s, traceID: %s", + reportInfo.PatPodName, reportInfo.InstanceID, traceID) + return utils.GenerateSuccessResponse(commonconstant.InsReqSuccessCode, "succeed add pat list") + } + log.GetLogger().Infof("patPodName %s is not exist, skip, traceID: %s", reportInfo.PatPodName, traceID) + return utils.GenerateSuccessResponse(commonconstant.InsReqSuccessCode, "skip add pat list") +} + +func (m *Manager) handleRequestOpDelete(requestData []byte, traceID string) *commonType.CallHandlerResponse { + deleteInfo := types.DeleteInfo{} + err := json.Unmarshal(requestData, &deleteInfo) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal delete info, traceID: %s, error: %s", traceID, err.Error()) + return nil + } + patPodName, needDeletePat := m.deletePatList(deleteInfo.InstanceID) + if needDeletePat { + err := m.VpcPlugin.DeleteVpcResource(patPodName) + if err != nil { + log.GetLogger().Errorf("failed to delete vpc pat pod, traceID: %s, error: %s", traceID, err.Error()) + return nil + } + log.GetLogger().Infof("succeed to delete pat-service %s, traceID: %s", patPodName, traceID) + } + log.GetLogger().Infof("succeed to delete instance %s, traceID: %s", deleteInfo.InstanceID, traceID) + return utils.GenerateSuccessResponse(commonconstant.InsReqSuccessCode, "succeed delete instance list") +} + +func (m *Manager) handleRequestOpDeleteTrigger(requestData []byte, traceID string) *commonType.CallHandlerResponse { + deleteInfo := types.PullTriggerDeleteInfo{} + err := json.Unmarshal(requestData, &deleteInfo) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal pullTrigger delete info, traceID: %s, error: %s", traceID, + err.Error()) + return nil + } + podName := utils.HandlePullTriggerName(deleteInfo.PodName) + if podName == "" { + log.GetLogger().Errorf("invalid pull trigger name, traceID: %s", traceID) + return nil + } + deleteInfo.PodName = podName + // get pull trigger deploy + triggerDeploy, errs := utils.GetDeployByK8S(utils.GetKubeClient(), deleteInfo.PodName) + if errs != nil { + log.GetLogger().Errorf("failed to get deploy, traceID: %s error: %s", traceID, errs.Error()) + } + if k8serror.IsNotFound(errs) { + log.GetLogger().Infof("trigger %s has already deleted, traceID: %s", triggerDeploy.Name, traceID) + return utils.GenerateSuccessResponse(commonconstant.InsReqSuccessCode, "skip delete pull-trigger") + } + if err = utils.DeleteDeployByK8S(utils.GetKubeClient(), triggerDeploy.Name); err != nil { + log.GetLogger().Errorf("failed to delete deploy %s by k8s, traceID: %s, error: %s", + triggerDeploy.Name, traceID, err.Error()) + return nil + } + patPodName, needDeletePat := m.deletePatList(deleteInfo.PodName) + if needDeletePat { + err := m.VpcPlugin.DeleteVpcResource(patPodName) + if err != nil { + log.GetLogger().Errorf("failed to delete vpc pat pod, traceID: %s, error: %s", traceID, err.Error()) + return nil + } + log.GetLogger().Infof("succeed to delete pat-service %s, traceID: %s", patPodName, traceID) + } + log.GetLogger().Infof("succeed delete Trigger, patPod %s, name %s, traceID: %s", + patPodName, deleteInfo.PodName, traceID) + return utils.GenerateSuccessResponse(commonconstant.InsReqSuccessCode, "succeed delete pull-trigger") +} + +func (m *Manager) handleNewLease(requestData []byte, traceID string, timeout int64) *commonType.CallHandlerResponse { + remoteClientID := string(requestData) + log.GetLogger().Infof("receive new lease from faas-frontend, traceID: %s", traceID) + m.clientMutex.Lock() + if _, value := m.remoteClientLease[remoteClientID]; value { + m.clientMutex.Unlock() + log.GetLogger().Infof("lease already existed, traceID: %s", traceID) + return &commonType.CallHandlerResponse{Code: commonconstant.InsReqSuccessCode, Message: "lease existed"} + } + t := time.Minute * m.leaseRenewMinute + if timeout > 0 { + t = time.Duration(timeout) * time.Second + } + lease := newLeaseTimer(t) + m.remoteClientLease[remoteClientID] = lease + m.remoteClientList[remoteClientID] = struct{}{} + m.clientMutex.Unlock() + m.saveStateData(traceID) + + go startLeaseTimeOutWatcher(lease.Ctx, m, lease.Timer, remoteClientID, traceID) + + log.GetLogger().Infof("succeed to create lease, traceID: %s", traceID) + return &commonType.CallHandlerResponse{Code: commonconstant.InsReqSuccessCode, Message: "Succeed to create lease"} +} + +func (m *Manager) clearLease(clientID string, traceID string) { + m.clientMutex.Lock() + lease, ok := m.remoteClientLease[clientID] + if !ok { + log.GetLogger().Warnf("client id not existed, %s %s ", clientID, traceID) + m.clientMutex.Unlock() + killInstanceOuter(clientID, traceID) + return + } + lease.Timer.Stop() + lease.Cancel() + delete(m.remoteClientLease, clientID) + delete(m.remoteClientList, clientID) + m.clientMutex.Unlock() + m.saveStateData(traceID) + killInstanceOuter(clientID, traceID) +} + +type stateItem struct { + ch chan error + traceID string +} + +// String - +func (s *stateItem) String() string { + return s.traceID +} + +// Done - +func (s *stateItem) Done(err error) { + s.ch <- err +} + +func (m *Manager) saveStateData(traceID string) { + i := &stateItem{ + ch: make(chan error, 1), + traceID: traceID, + } + if !m.queue.add(i) { + log.GetLogger().Warnf("add save state req to queue failed, traceID: %s", traceID) + return + } + err := <-i.ch + if err != nil { + log.GetLogger().Warnf("save state data failed, err:%v, traceID:%s", err, traceID) + } +} + +func (m *Manager) saveStateLoop() { + for { + all, shutdown := m.queue.getAll() + if shutdown { + return + } + m.clientMutex.Lock() + managerState := state.GetState() + state.ManagerStateLock.Lock() + managerState.RemoteClientList = m.remoteClientList + marshal, err := json.Marshal(managerState) + state.ManagerStateLock.Unlock() + if err != nil { + m.queue.doneAll(all) + for _, it := range all { + it.Done(err) + } + continue + } + m.clientMutex.Unlock() + client := etcd3.GetRouterEtcdClient() + if client != nil { + ctx, cancel := context.WithTimeout(context.Background(), etcd3.DurationContextTimeout) + _, err = client.Client.Put(ctx, "/faas/state/recover/faasmanager", string(marshal)) + cancel() + } else { + err = fmt.Errorf("router etcd client is nil") + } + m.queue.doneAll(all) + for _, it := range all { + it.Done(err) + } + } +} + +func killInstanceOuter(clientID string, traceID string) { + log.GetLogger().Infof("start to kill instance outer, clientID: %s, traceID:%s", clientID, traceID) + if err := libruntimeClient.Kill(clientID, types.KillSignalVal, []byte{}); err != nil { + log.GetLogger().Warnf("failed to clean instances when delete lease, traceID: %s, "+ + "remoteClientID:%s, status:%s", traceID, clientID, err.Error()) + } + libruntimeClient.SetTraceID(traceID) + if err := libruntimeClient.ReleaseGRefs(clientID); err != nil { + log.GetLogger().Warnf("failed to release refs when delete lease, traceID: %s, "+ + "remoteClientID:%s, status:%s", traceID, clientID, err.Error()) + } +} + +func (m *Manager) handleKeepAlive(requestData []byte, traceID string) *commonType.CallHandlerResponse { + remoteClientID := string(requestData) + log.GetLogger().Infof("receive keep-alive lease from faas-frontend, traceID: %s", traceID) + m.clientMutex.Lock() + if _, value := m.remoteClientLease[remoteClientID]; value { + lease := m.remoteClientLease[remoteClientID] + lease.Timer.Reset(m.leaseRenewMinute * time.Minute) + m.clientMutex.Unlock() + log.GetLogger().Infof("succeed to renew lease, traceID: %s", traceID) + return &commonType.CallHandlerResponse{ + Code: commonconstant.InsReqSuccessCode, + Message: "Succeed to renew lease", + } + } + m.clientMutex.Unlock() + log.GetLogger().Infof("failed to renew unknown lease, traceID: %s", traceID) + return &commonType.CallHandlerResponse{ + Code: commonconstant.UnsupportedOperationErrorCode, + Message: fmt.Sprintf("%s remote client id not exist", remoteClientID), + } +} + +func (m *Manager) handleDelLease(requestData []byte, traceID string) *commonType.CallHandlerResponse { + remoteClientID := string(requestData) + log.GetLogger().Infof("receive delete lease from faas-frontend, traceID: %s", traceID) + m.clearLease(remoteClientID, traceID) + + return &commonType.CallHandlerResponse{ + Code: commonconstant.InsReqSuccessCode, + Message: "Succeed to delete lease", + } +} + +func (m *Manager) checkExists(patPodName string) bool { + m.vpcMutex.Lock() + if m.patKeyList == nil { + m.vpcMutex.Unlock() + return false + } + _, ok := m.patKeyList[patPodName] + m.vpcMutex.Unlock() + return ok +} + +func (m *Manager) newPatListKey(patPodName string) { + m.vpcMutex.Lock() + m.patKeyList[patPodName] = map[string]struct{}{} + savePatKeyList(m.patKeyList) + m.vpcMutex.Unlock() + return +} + +func (m *Manager) addPatList(patPodName, instanceID string) { + m.vpcMutex.Lock() + if m.patKeyList == nil { + m.vpcMutex.Unlock() + return + } + m.patKeyList[patPodName][instanceID] = struct{}{} + savePatKeyList(m.patKeyList) + m.vpcMutex.Unlock() + return +} + +func (m *Manager) deletePatList(instanceID string) (string, bool) { + patPodName := "" + needDeletePat := false + m.vpcMutex.Lock() + if m.patKeyList == nil { + m.vpcMutex.Unlock() + return "", false + } + check := false + for patName, instanceList := range m.patKeyList { + for id := range instanceList { + if id == instanceID { + delete(m.patKeyList[patName], instanceID) + savePatKeyList(m.patKeyList) + check = true + break + } + } + if check { + if len(m.patKeyList[patName]) == 0 { + delete(m.patKeyList, patName) + savePatKeyList(m.patKeyList) + patPodName = patName + needDeletePat = true + } + break + } + } + m.vpcMutex.Unlock() + return patPodName, needDeletePat +} + +func savePatKeyList(m map[string]map[string]struct{}) { + dst := map[string]map[string]struct{}{} + err := commonUtils.DeepCopyObj(m, &dst) + if err != nil { + log.GetLogger().Errorf("deep copy patKeyList failed, err: %v", err) + return + } + state.Update(dst) +} + +func parseRequestOperationLibruntime(args []api.Arg) (RequestOperation, []byte) { + requestOp := requestOpUnknown + if len(args) != validArgsNumLibruntime { + log.GetLogger().Errorf("invalid argument number") + return requestOp, []byte{} + } + requestOp = RequestOperation(args[faasManagerOpIndexLibruntime].Data) + requestData := args[faasManagerOpDataIndexLibruntime].Data + return requestOp, requestData +} + +func parseRequestOperation(args []*api.Arg) (RequestOperation, []byte) { + requestOp := requestOpUnknown + if len(args) != validArgsNum { + log.GetLogger().Errorf("invalid argument number") + return requestOp, []byte{} + } + requestOp = RequestOperation(args[faasManagerOpIndex].Data) + requestData := args[faasManagerOpDataIndex].Data + return requestOp, requestData +} + +// WatchLeaseEvent - +func (m *Manager) WatchLeaseEvent() { + etcdClient := etcd3.GetRouterEtcdClient() + watcher := etcd3.NewEtcdWatcher(commonconstant.LeasePrefix, func(event *etcd3.Event) bool { + etcdKey := event.Key + keyParts := strings.Split(etcdKey, commonconstant.ETCDEventKeySeparator) + return len(keyParts) == leaseEtcdKeyLen + }, func(event *etcd3.Event) { + log.GetLogger().Infof("handling lease event type %d, key:%s, remoteClientLease in use len: %d", event.Type, + event.Key, len(m.remoteClientLease)) + switch event.Type { + case etcd3.PUT: + m.handleLeaseEvent(event) + go func() { + _, err := etcdClient.Client.Delete(context.Background(), event.Key) + if err != nil { + log.GetLogger().Errorf("delete lease event failed, key: %s, err: %v", event.Key, err) + } + }() + case etcd3.DELETE: + case etcd3.ERROR: + log.GetLogger().Warnf("etcd error event: %s", event.Value) + default: + log.GetLogger().Warnf("unsupported event, key: %s", event.Key) + } + }, m.stopCh, etcdClient) + watcher.StartWatch() + select { + case <-m.stopCh: + m.queue.shutDown() + return + } +} + +func (m *Manager) handleLeaseEvent(event *etcd3.Event) { + if event == nil || len(event.Value) == 0 { + log.GetLogger().Warnf("event is nil or value is empty") + return + } + log.GetLogger().Infof("handling lease put event key:%s", event.Key) + e := &commonType.LeaseEvent{} + err := json.Unmarshal(event.Value, e) + if err != nil { + log.GetLogger().Errorf("error unmarshalling lease event: %s, err: %v", string(event.Value), err) + return + } + switch e.Type { + case commonconstant.NewLease: + m.handleNewLease([]byte(e.RemoteClientID), e.TraceID, 0) + case commonconstant.DelLease: + m.handleDelLease([]byte(e.RemoteClientID), e.TraceID) + case commonconstant.KeepAlive: + // 判断是否是过期心跳 + timeout := int64((m.leaseRenewMinute * time.Minute).Seconds()) - (time.Now().Unix() - e.Timestamp) + if timeout <= 0 { + log.GetLogger().Infof("lease is expired, traceID: %s", e.TraceID) + m.handleDelLease([]byte(e.RemoteClientID), e.TraceID) + return + } + m.clientMutex.Lock() + if lease, value := m.remoteClientLease[e.RemoteClientID]; value { + lease.Timer.Reset(time.Duration(timeout) * time.Second) + m.remoteClientLease[e.RemoteClientID] = lease + m.clientMutex.Unlock() + log.GetLogger().Infof("succeed to renew lease, traceID: %s", e.TraceID) + return + } + m.clientMutex.Unlock() + m.handleNewLease([]byte(e.RemoteClientID), e.TraceID, timeout) + default: + log.GetLogger().Errorf("unexpected lease type: %s, traceId: %s", e.Type, e.TraceID) + } +} diff --git a/yuanrong/pkg/functionmanager/faasmanager_test.go b/yuanrong/pkg/functionmanager/faasmanager_test.go new file mode 100644 index 0000000..f62333a --- /dev/null +++ b/yuanrong/pkg/functionmanager/faasmanager_test.go @@ -0,0 +1,945 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package functionmanager + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "plugin" + "reflect" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + clientv3 "go.etcd.io/etcd/client/v3" + v1 "k8s.io/api/apps/v1" + k8serror "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/kubernetes" + + "yuanrong.org/kernel/runtime/libruntime/api" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + commonType "yuanrong/pkg/common/faas_common/types" + mockUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionmanager/state" + "yuanrong/pkg/functionmanager/types" + "yuanrong/pkg/functionmanager/utils" + "yuanrong/pkg/functionmanager/vpcmanager" +) + +type KvMock struct { +} + +func (k *KvMock) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + // TODO implement me + return nil, nil +} + +func (k *KvMock) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + response := &clientv3.GetResponse{} + response.Count = 10 + return response, nil +} + +func (k *KvMock) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + return nil, fmt.Errorf("error") +} + +func (k *KvMock) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, + error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Txn(ctx context.Context) clientv3.Txn { + // TODO implement me + panic("implement me") +} + +func TestNewFaaSManager(t *testing.T) { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + type args struct { + sdkClient api.LibruntimeAPI + stopCh chan struct{} + } + tests := []struct { + name string + args args + wantErr bool + patchesFunc mockUtils.PatchesFunc + }{ + {"case1 failed to plugin open", args{}, true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(plugin.Open, func(path string) (*plugin.Plugin, error) { + return nil, errors.New("plugin open error") + }), + }) + return patches + }}, + {"case2 InitVpcPlugin error", args{}, true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(plugin.Open, func(path string) (*plugin.Plugin, error) { + return &plugin.Plugin{}, nil + }), + gomonkey.ApplyFunc((*vpcmanager.PluginVPC).InitVpcPlugin, func(_ *vpcmanager.PluginVPC) error { + return errors.New("InitVpcPlugin error") + }), + }) + return patches + }}, + {"case3 InitKubeClient error", args{}, true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(plugin.Open, func(path string) (*plugin.Plugin, error) { + return &plugin.Plugin{}, nil + }), + gomonkey.ApplyFunc((*vpcmanager.PluginVPC).InitVpcPlugin, func(_ *vpcmanager.PluginVPC) error { + return nil + }), + gomonkey.ApplyFunc(utils.InitKubeClient, func() error { + return errors.New("InitKubeClient error") + }), + }) + return patches + }}, + {"case3 InitKubeClient error", args{}, true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(plugin.Open, func(path string) (*plugin.Plugin, error) { + return &plugin.Plugin{}, nil + }), + gomonkey.ApplyFunc((*vpcmanager.PluginVPC).InitVpcPlugin, func(_ *vpcmanager.PluginVPC) error { + return nil + }), + gomonkey.ApplyFunc(utils.InitKubeClient, func() error { + return errors.New("InitKubeClient error") + }), + }) + return patches + }}, + {"case4 succeed to NewFaaSManager", args{}, false, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(plugin.Open, func(path string) (*plugin.Plugin, error) { + return &plugin.Plugin{}, nil + }), + gomonkey.ApplyFunc((*vpcmanager.PluginVPC).InitVpcPlugin, func(_ *vpcmanager.PluginVPC) error { + return nil + }), + gomonkey.ApplyFunc(utils.InitKubeClient, func() error { + return nil + }), + }) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + _, err := NewFaaSManagerLibruntime(tt.args.sdkClient, tt.args.stopCh) + if (err != nil) != tt.wantErr { + t.Errorf("NewFaaSManager() error = %v, wantErr %v", err, tt.wantErr) + return + } + patches.ResetAll() + }) + } +} + +func TestManager_ProcessSchedulerRequest(t *testing.T) { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + pullTriggerBytes, _ := json.Marshal(types.PullTriggerRequestInfo{PodName: "sn/trigger/test"}) + reportBytes, _ := json.Marshal(types.ReportInfo{PatPodName: "pat1", InstanceID: "abc"}) + deleteBytes, _ := json.Marshal(types.DeleteInfo{InstanceID: "abc"}) + deleteTriggerBytes, _ := json.Marshal(types.PullTriggerDeleteInfo{PodName: "sn/trigger/test"}) + type fields struct { + patKeyList map[string]map[string]struct{} + VpcPlugin vpcmanager.PluginVPC + vpcMutex sync.Mutex + remoteClientLease map[string]*leaseTimer + remoteClientList map[string]struct{} + clientMutex sync.Mutex + } + type args struct { + args []*api.Arg + } + tests := []struct { + name string + fields fields + args args + want int + patchesFunc mockUtils.PatchesFunc + }{ + {"case1 succeed to requestOpCreate", fields{patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}}, + args{args: []*api.Arg{&api.Arg{Data: []byte(requestOpCreate)}, &api.Arg{Data: []byte("")}, + &api.Arg{Data: []byte("")}}}, + constant.InsReqSuccessCode, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc((*vpcmanager.PluginVPC).CreateVpcResource, + func(_ *vpcmanager.PluginVPC, requestData []byte) (types.NATConfigure, error) { + return types.NATConfigure{PatPodName: "pat1"}, nil + }), + }) + return patches + }}, + {"case2 succeed to requestOpCreateTrigger", fields{patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}}, + args{args: []*api.Arg{&api.Arg{Data: []byte(requestOpCreateTrigger)}, &api.Arg{Data: pullTriggerBytes}, + &api.Arg{Data: []byte("")}}}, + constant.InsReqSuccessCode, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc((*vpcmanager.PluginVPC).CreateVpcResource, + func(_ *vpcmanager.PluginVPC, requestData []byte) (types.NATConfigure, error) { + return types.NATConfigure{PatPodName: "pat1"}, nil + }), + gomonkey.ApplyFunc(utils.GetDeployByK8S, + func(k8sClient kubernetes.Interface, deployName string) (*v1.Deployment, error) { + return nil, errors.New("errors") + }), + gomonkey.ApplyFunc(k8serror.IsNotFound, func(err error) bool { + return true + }), + gomonkey.ApplyFunc(utils.GetKubeClient, func() kubernetes.Interface { + return &kubernetes.Clientset{} + }), + gomonkey.ApplyFunc(utils.CreateDeployByK8S, + func(k8sClient kubernetes.Interface, deploy *v1.Deployment) error { + return nil + }), + }) + return patches + }}, + {"case3 succeed to requestOpReport", fields{patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}}, + args{args: []*api.Arg{&api.Arg{Data: []byte(requestOpReport)}, &api.Arg{Data: reportBytes}, + &api.Arg{Data: []byte("")}}}, + constant.InsReqSuccessCode, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{}) + return patches + }}, + {"case4 succeed to requestOpDelete", fields{patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}}, + args{args: []*api.Arg{&api.Arg{Data: []byte(requestOpDelete)}, &api.Arg{Data: deleteBytes}, + &api.Arg{Data: []byte("")}}}, + constant.InsReqSuccessCode, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{}) + return patches + }}, + {"case5 succeed to requestOpDeleteTrigger", fields{patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}}, + args{args: []*api.Arg{&api.Arg{Data: []byte(requestOpDeleteTrigger)}, &api.Arg{Data: deleteTriggerBytes}, + &api.Arg{Data: []byte("")}}}, + constant.InsReqSuccessCode, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(utils.GetKubeClient, func() kubernetes.Interface { + return &kubernetes.Clientset{} + }), + gomonkey.ApplyFunc(utils.GetDeployByK8S, + func(k8sClient kubernetes.Interface, deployName string) (*v1.Deployment, error) { + return &v1.Deployment{}, nil + }), + gomonkey.ApplyFunc(k8serror.IsNotFound, func(err error) bool { + return false + }), + gomonkey.ApplyFunc(utils.DeleteDeployByK8S, + func(k8sClient kubernetes.Interface, name string) error { + return nil + }), + }) + return patches + }}, + {"case6 unknow opt", fields{patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}}, + args{args: []*api.Arg{&api.Arg{Data: []byte(requestOpDeleteTrigger)}, &api.Arg{Data: []byte("")}}}, + constant.UnsupportedOperationErrorCode, + func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + return patches + }}, + {"case7 failed to requestOpNewLease", fields{patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, remoteClientLease: map[string]*leaseTimer{"test-client-ID": nil}, + remoteClientList: map[string]struct{}{}, clientMutex: sync.Mutex{}}, + args{args: []*api.Arg{&api.Arg{Data: []byte(requestOpNewLease)}, &api.Arg{Data: []byte("")}}}, + constant.UnsupportedOperationErrorCode, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{}) + return patches + }}, + {"case8 success to requestOpNewLease", fields{patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, remoteClientLease: make(map[string]*leaseTimer, 1), + remoteClientList: map[string]struct{}{}, clientMutex: sync.Mutex{}}, + args{args: []*api.Arg{&api.Arg{Data: []byte(requestOpNewLease)}, &api.Arg{Data: []byte("test-client-ID")}, + &api.Arg{Data: []byte("")}}}, + constant.InsReqSuccessCode, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{}) + return patches + }}, + {"case9 failed to requestOpKeepAlive", fields{patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, remoteClientLease: make(map[string]*leaseTimer, 1), + remoteClientList: map[string]struct{}{}, clientMutex: sync.Mutex{}}, + args{args: []*api.Arg{&api.Arg{Data: []byte(requestOpKeepAlive)}, &api.Arg{Data: []byte("test-client-ID")}, + &api.Arg{Data: []byte("")}}}, + constant.UnsupportedOperationErrorCode, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{}) + return patches + }}, + {"case10 success to requestOpKeepAlive", fields{patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, remoteClientLease: map[string]*leaseTimer{"test-client-ID": newLeaseTimer(1000)}, + remoteClientList: map[string]struct{}{}, clientMutex: sync.Mutex{}}, + args{args: []*api.Arg{&api.Arg{Data: []byte(requestOpKeepAlive)}, &api.Arg{Data: []byte("test-client-ID")}, + &api.Arg{Data: []byte("")}}}, + constant.InsReqSuccessCode, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{}) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + fm := &Manager{ + patKeyList: tt.fields.patKeyList, + VpcPlugin: tt.fields.VpcPlugin, + vpcMutex: tt.fields.vpcMutex, + remoteClientLease: tt.fields.remoteClientLease, + remoteClientList: tt.fields.remoteClientList, + clientMutex: tt.fields.clientMutex, + queue: &queue{ + cond: sync.NewCond(&sync.RWMutex{}), + dirty: map[interface{}]struct{}{}, + processing: map[interface{}]struct{}{}, + }, + leaseRenewMinute: 5, + } + go fm.saveStateLoop() + libruntimeClient = &mockUtils.FakeLibruntimeSdkClient{} + if got := fm.ProcessSchedulerRequest(tt.args.args, ""); !reflect.DeepEqual(got.Code, tt.want) { + t.Errorf("ProcessSchedulerRequest() = %v, want %v", got, tt.want) + } + patches.ResetAll() + libruntimeClient = nil + }) + } +} + +func Test_handleRequestOpCreate(t *testing.T) { + convey.Convey("ProcessSchedulerRequest-handleRequestOpCreate", t, func() { + m := &Manager{ + patKeyList: make(map[string]map[string]struct{}), + VpcPlugin: vpcmanager.PluginVPC{}, + leaseRenewMinute: 5, + } + args := []*api.Arg{ + { + Type: 0, + Data: []byte(requestOpCreate), + }, + { + Type: 0, + Data: []byte{}, + }, + { + Type: 0, + Data: []byte("trace-id"), + }, + } + convey.Convey("failed to create vpc pat pod, error", func() { + + request := m.ProcessSchedulerRequest(args, "") + convey.So(request, convey.ShouldBeNil) + }) + + convey.Convey("json Marshal error", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&m.VpcPlugin), "CreateVpcResource", + func(_ *vpcmanager.PluginVPC, requestData []byte) (types.NATConfigure, error) { + return types.NATConfigure{}, nil + }).Reset() + defer gomonkey.ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, errors.New("json marshal error") + }).Reset() + request := m.ProcessSchedulerRequest(args, "") + convey.So(request, convey.ShouldBeNil) + }) + }) +} + +func Test_handleRequestOpCreateTrigger(t *testing.T) { + convey.Convey("ProcessSchedulerRequest-handleRequestOpCreateTrigger", t, func() { + m := &Manager{ + patKeyList: make(map[string]map[string]struct{}), + VpcPlugin: vpcmanager.PluginVPC{}, + leaseRenewMinute: 5, + } + args := []*api.Arg{ + { + Type: 0, + Data: []byte(requestOpCreateTrigger), + }, + { + Type: 0, + Data: []byte{}, + }, + { + Type: 0, + Data: []byte("trace-id"), + }, + } + convey.Convey("json.Unmarshal error", func() { + defer gomonkey.ApplyFunc(json.Unmarshal, func(data []byte, v interface{}) error { + return errors.New("json unmarshal error") + }).Reset() + request := m.ProcessSchedulerRequest(args, "") + convey.So(request, convey.ShouldBeNil) + }) + + convey.Convey("invalid pull trigger name", func() { + info := &types.PullTriggerRequestInfo{PodName: "podname/invalid"} + bytes, _ := json.Marshal(info) + args[1].Data = bytes + request := m.ProcessSchedulerRequest(args, "") + convey.So(request, convey.ShouldBeNil) + }) + + convey.Convey("failed to create pullTrigger vpc pat pod, error", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&m.VpcPlugin), "CreateVpcResource", + func(_ *vpcmanager.PluginVPC, requestData []byte) (types.NATConfigure, error) { + return types.NATConfigure{}, errors.New("failed to create pullTrigger vpc pat pod") + }).Reset() + info := &types.PullTriggerRequestInfo{PodName: "podname/runtime-manager/valid"} + bytes, _ := json.Marshal(info) + args[1].Data = bytes + request := m.ProcessSchedulerRequest(args, "") + convey.So(request, convey.ShouldBeNil) + }) + + convey.Convey("failed to Marshal PatPod", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&m.VpcPlugin), "CreateVpcResource", + func(_ *vpcmanager.PluginVPC, requestData []byte) (types.NATConfigure, error) { + return types.NATConfigure{}, nil + }).Reset() + info := &types.PullTriggerRequestInfo{PodName: "podname/runtime-manager/valid"} + bytes, _ := json.Marshal(info) + args[1].Data = bytes + defer gomonkey.ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, errors.New("json marshal error") + }).Reset() + request := m.ProcessSchedulerRequest(args, "") + convey.So(request, convey.ShouldBeNil) + }) + + convey.Convey("failed to create deploy %s by k8s, error: %s", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&m.VpcPlugin), "CreateVpcResource", + func(_ *vpcmanager.PluginVPC, requestData []byte) (types.NATConfigure, error) { + return types.NATConfigure{}, nil + }).Reset() + defer gomonkey.ApplyFunc(utils.GetDeployByK8S, + func(k8sClient kubernetes.Interface, deployName string) (*v1.Deployment, error) { + return &v1.Deployment{}, k8serror.NewNotFound(schema.GroupResource{}, "runtime-manager") + }).Reset() + defer gomonkey.ApplyFunc(utils.CreateDeployByK8S, + func(k8sClient kubernetes.Interface, deploy *v1.Deployment) error { + return errors.New("create deploy error") + }).Reset() + info := &types.PullTriggerRequestInfo{PodName: "podname/runtime-manager/valid"} + bytes, _ := json.Marshal(info) + args[1].Data = bytes + request := m.ProcessSchedulerRequest(args, "") + convey.So(request, convey.ShouldBeNil) + }) + + convey.Convey("success", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&m.VpcPlugin), "CreateVpcResource", + func(_ *vpcmanager.PluginVPC, requestData []byte) (types.NATConfigure, error) { + return types.NATConfigure{}, nil + }).Reset() + defer gomonkey.ApplyFunc(utils.GetDeployByK8S, + func(k8sClient kubernetes.Interface, deployName string) (*v1.Deployment, error) { + return &v1.Deployment{}, nil + }).Reset() + info := &types.PullTriggerRequestInfo{PodName: "podname/runtime-manager/valid"} + bytes, _ := json.Marshal(info) + args[1].Data = bytes + request := m.ProcessSchedulerRequest(args, "") + convey.So(request.Code, convey.ShouldEqual, 6030) + }) + }) +} + +func Test_handleRequestOpReport(t *testing.T) { + convey.Convey("handleRequestOpReport", t, func() { + m := &Manager{ + patKeyList: make(map[string]map[string]struct{}), + VpcPlugin: vpcmanager.PluginVPC{}, + leaseRenewMinute: 5, + } + args := []*api.Arg{ + { + Type: 0, + Data: []byte(requestOpReport), + }, + { + Type: 0, + Data: []byte{}, + }, + { + Type: 0, + Data: []byte("trace-id"), + }, + } + convey.Convey("failed to unmarshal report info", func() { + defer gomonkey.ApplyFunc(json.Unmarshal, func(data []byte, v interface{}) error { + return errors.New("json unmarshal error") + }).Reset() + request := m.ProcessSchedulerRequest(args, "") + convey.So(request, convey.ShouldBeNil) + }) + + convey.Convey("", func() { + report := &types.ReportInfo{PatPodName: "test"} + bytes, _ := json.Marshal(report) + args[1].Data = bytes + m.patKeyList["test"] = make(map[string]struct{}) + request := m.ProcessSchedulerRequest(args, "") + convey.So(request.Message, convey.ShouldEqual, "succeed add pat list") + }) + }) +} + +func TestMiscellaneous(t *testing.T) { + convey.Convey(" patKeyList is nil", t, func() { + m := &Manager{ + patKeyList: nil, + leaseRenewMinute: 5, + } + convey.Convey("addPatList", func() { + m.addPatList("123", "456") + convey.So(m.patKeyList, convey.ShouldBeNil) + }) + convey.Convey("checkExists", func() { + res := m.checkExists("123") + convey.So(res, convey.ShouldBeFalse) + }) + convey.Convey("deletePatList", func() { + podName, needDeletePat := m.deletePatList("123") + convey.So(podName, convey.ShouldEqual, "") + convey.So(needDeletePat, convey.ShouldBeFalse) + }) + }) + convey.Convey("deletePatList", t, func() { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + m := &Manager{ + patKeyList: map[string]map[string]struct{}{ + "patName123": map[string]struct{}{ + "instanceID123": struct{}{}, + }, + }, + leaseRenewMinute: 5, + } + convey.Convey("deletePatList", func() { + podName, needDeletePat := m.deletePatList("instanceID123") + convey.So(podName, convey.ShouldEqual, "patName123") + convey.So(needDeletePat, convey.ShouldBeTrue) + }) + }) +} + +func Test_handleLeaseTimeout(t *testing.T) { + convey.Convey("handleLeaseTimeout", t, func() { + m := &Manager{ + patKeyList: make(map[string]map[string]struct{}), + VpcPlugin: vpcmanager.PluginVPC{}, + vpcMutex: sync.Mutex{}, + remoteClientLease: map[string]*leaseTimer{}, + remoteClientList: map[string]struct{}{}, + clientMutex: sync.Mutex{}, + queue: &queue{ + cond: sync.NewCond(&sync.RWMutex{}), + dirty: map[interface{}]struct{}{}, + processing: map[interface{}]struct{}{}, + }, + leaseRenewMinute: 5, + } + go m.saveStateLoop() + args := []*api.Arg{ + { + Type: 0, + Data: []byte(requestOpNewLease), + }, + { + Type: 0, + Data: []byte("test-client-ID"), + }, + { + Type: 0, + Data: []byte("trace-id"), + }, + } + convey.Convey("handle client lease timeout", func() { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) {}).Reset() + libruntimeClient = &mockUtils.FakeLibruntimeSdkClient{} + request := m.ProcessSchedulerRequest(args, "test-trace-ID") + convey.So(request.Message, convey.ShouldEqual, "Succeed to create lease") + convey.So(len(m.remoteClientList), convey.ShouldEqual, 1) + convey.So(m.remoteClientLease["test-client-ID"], convey.ShouldNotBeNil) + m.remoteClientLease["test-client-ID"].Timer.Reset(1 * time.Second) + time.Sleep(2 * time.Second) + convey.So(len(m.remoteClientList), convey.ShouldEqual, 0) + convey.So(len(m.remoteClientLease), convey.ShouldEqual, 0) + libruntimeClient = nil + }) + }) +} + +func Test_handleDelLease(t *testing.T) { + convey.Convey("handleDelLease", t, func() { + m := &Manager{ + patKeyList: make(map[string]map[string]struct{}), + VpcPlugin: vpcmanager.PluginVPC{}, + vpcMutex: sync.Mutex{}, + remoteClientLease: map[string]*leaseTimer{"test-client-ID": newLeaseTimer(1000)}, + remoteClientList: map[string]struct{}{"test-client-ID": {}}, + clientMutex: sync.Mutex{}, + queue: &queue{ + cond: sync.NewCond(&sync.RWMutex{}), + dirty: map[interface{}]struct{}{}, + processing: map[interface{}]struct{}{}, + }, + leaseRenewMinute: 5, + } + go m.saveStateLoop() + args := []*api.Arg{ + { + Type: 0, + Data: []byte(requestOpDelLease), + }, + { + Type: 0, + Data: []byte("test-client-ID"), + }, + { + Type: 0, + Data: []byte("trace-id"), + }, + } + libruntimeClient = &mockUtils.FakeLibruntimeSdkClient{} + convey.Convey("delete no exited lease", func() { + args[1].Data = []byte("test-client-ID-error") + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) {}).Reset() + request := m.ProcessSchedulerRequest(args, "test-trace-ID") + convey.So(request.Message, convey.ShouldEqual, "Succeed to delete lease") + convey.So(len(m.remoteClientList), convey.ShouldEqual, 1) + convey.So(len(m.remoteClientLease), convey.ShouldEqual, 1) + }) + convey.Convey("success to delete lease", func() { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) {}).Reset() + request := m.ProcessSchedulerRequest(args, "test-trace-ID") + convey.So(request.Message, convey.ShouldEqual, "Succeed to delete lease") + convey.So(len(m.remoteClientList), convey.ShouldEqual, 0) + convey.So(len(m.remoteClientLease), convey.ShouldEqual, 0) + }) + libruntimeClient = nil + }) +} + +func TestManager_handleLeaseEvent(t *testing.T) { + libruntimeClient = &mockUtils.FakeLibruntimeSdkClient{} + convey.Convey("test handleLeaseEvent", t, func() { + convey.Convey("baseline", func() { + p := gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{Client: &clientv3.Client{KV: &KvMock{}}} + }) + defer p.Reset() + ch := make(chan struct{}) + fm := &Manager{ + patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, + remoteClientLease: make(map[string]*leaseTimer), + remoteClientList: map[string]struct{}{}, + clientMutex: sync.Mutex{}, + stopCh: ch, + queue: &queue{ + cond: sync.NewCond(&sync.RWMutex{}), + dirty: map[interface{}]struct{}{}, + processing: map[interface{}]struct{}{}, + }, + leaseRenewMinute: 5, + } + go fm.saveStateLoop() + e := &commonType.LeaseEvent{ + Type: constant.NewLease, + RemoteClientID: "123456", + Timestamp: time.Now().Unix(), + TraceID: "abc", + } + marshal, _ := json.Marshal(e) + event := &etcd3.Event{ + Type: 0, + Key: "/sn/lease/123456", + Value: marshal, + PrevValue: nil, + Rev: 0, + } + fm.handleLeaseEvent(event) + convey.So(fm.remoteClientLease["123456"], convey.ShouldNotBeNil) + + e = &commonType.LeaseEvent{ + Type: constant.DelLease, + RemoteClientID: "123456", + Timestamp: time.Now().Unix(), + TraceID: "abcd", + } + marshal, _ = json.Marshal(e) + event = &etcd3.Event{ + Type: 0, + Key: "/sn/lease/123456", + Value: marshal, + PrevValue: nil, + Rev: 0, + } + fm.handleLeaseEvent(event) + convey.So(fm.remoteClientLease["123456"], convey.ShouldBeNil) + + e = &commonType.LeaseEvent{ + Type: constant.KeepAlive, + RemoteClientID: "123456", + Timestamp: time.Now().Unix(), + TraceID: "abcde", + } + marshal, _ = json.Marshal(e) + event = &etcd3.Event{ + Type: 0, + Key: "/sn/lease/123456", + Value: marshal, + PrevValue: nil, + Rev: 0, + } + fm.handleLeaseEvent(event) + time.Sleep(100 * time.Millisecond) + convey.So(fm.remoteClientLease["123456"], convey.ShouldNotBeNil) + + e = &commonType.LeaseEvent{ + Type: constant.KeepAlive, + RemoteClientID: "123456", + Timestamp: time.Now().Unix(), + TraceID: "abcdef", + } + marshal, _ = json.Marshal(e) + event = &etcd3.Event{ + Type: 0, + Key: "/sn/lease/123456", + Value: marshal, + PrevValue: nil, + Rev: 0, + } + fm.handleLeaseEvent(event) + convey.So(fm.remoteClientLease["123456"], convey.ShouldNotBeNil) + }) + + convey.Convey("event is nil", func() { + ch := make(chan struct{}) + fm := &Manager{ + patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, + remoteClientLease: make(map[string]*leaseTimer), + remoteClientList: map[string]struct{}{}, + clientMutex: sync.Mutex{}, + stopCh: ch, + queue: &queue{ + cond: sync.NewCond(&sync.RWMutex{}), + dirty: map[interface{}]struct{}{}, + processing: map[interface{}]struct{}{}, + }, + leaseRenewMinute: 5, + } + go fm.saveStateLoop() + fm.handleLeaseEvent(nil) + convey.So(len(fm.remoteClientLease), convey.ShouldEqual, 0) + }) + + convey.Convey("unmarshal failed", func() { + ch := make(chan struct{}) + fm := &Manager{ + patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, + remoteClientLease: make(map[string]*leaseTimer), + remoteClientList: map[string]struct{}{}, + clientMutex: sync.Mutex{}, + stopCh: ch, + queue: &queue{ + cond: sync.NewCond(&sync.RWMutex{}), + dirty: map[interface{}]struct{}{}, + processing: map[interface{}]struct{}{}, + }, + leaseRenewMinute: 5, + } + go fm.saveStateLoop() + event := &etcd3.Event{ + Type: 0, + Key: "/sn/lease/123456", + Value: []byte("aaa"), + PrevValue: nil, + Rev: 0, + } + fm.handleLeaseEvent(event) + convey.So(len(fm.remoteClientLease), convey.ShouldEqual, 0) + }) + + convey.Convey("error type", func() { + ch := make(chan struct{}) + fm := &Manager{ + patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, + remoteClientLease: make(map[string]*leaseTimer), + remoteClientList: map[string]struct{}{}, + clientMutex: sync.Mutex{}, + stopCh: ch, + queue: &queue{ + cond: sync.NewCond(&sync.RWMutex{}), + dirty: map[interface{}]struct{}{}, + processing: map[interface{}]struct{}{}, + }, + leaseRenewMinute: 5, + } + go fm.saveStateLoop() + e := &commonType.LeaseEvent{ + Type: "aaa", + RemoteClientID: "123456", + Timestamp: time.Now().Unix(), + TraceID: "abcdef", + } + marshal, _ := json.Marshal(e) + event := &etcd3.Event{ + Type: 0, + Key: "/sn/lease/123456", + Value: marshal, + PrevValue: nil, + Rev: 0, + } + fm.handleLeaseEvent(event) + convey.So(len(fm.remoteClientLease), convey.ShouldEqual, 0) + }) + + }) + libruntimeClient = nil +} + +func TestManager_WatchLeaseEvent(t *testing.T) { + libruntimeClient = &mockUtils.FakeLibruntimeSdkClient{} + convey.Convey("test watch lease event", t, func() { + convey.Convey("baseline", func() { + p := gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{Client: &clientv3.Client{KV: &KvMock{}}} + }) + defer p.Reset() + p2 := gomonkey.ApplyFunc((*etcd3.EtcdWatcher).StartWatch, func(ew *etcd3.EtcdWatcher) { + ew.ResultChan <- &etcd3.Event{ + Type: 0, + Key: "/sn/lease/123456", + Value: nil, + PrevValue: nil, + Rev: 0, + } + }) + defer p2.Reset() + flag := false + p3 := gomonkey.ApplyFunc((*Manager).handleLeaseEvent, func(_ *Manager) { + flag = true + }) + defer p3.Reset() + ch := make(chan struct{}) + fm := &Manager{ + patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, + remoteClientLease: make(map[string]*leaseTimer), + remoteClientList: map[string]struct{}{}, + clientMutex: sync.Mutex{}, + stopCh: ch, + queue: &queue{ + cond: sync.NewCond(&sync.RWMutex{}), + dirty: map[interface{}]struct{}{}, + processing: map[interface{}]struct{}{}, + }, + leaseRenewMinute: 5, + } + go fm.saveStateLoop() + go fm.WatchLeaseEvent() + time.Sleep(100 * time.Millisecond) + convey.So(flag, convey.ShouldBeTrue) + ch <- struct{}{} + }) + }) + libruntimeClient = nil +} + +func TestManager_RecoverData(t *testing.T) { + libruntimeClient = &mockUtils.FakeLibruntimeSdkClient{} + convey.Convey("test recover data", t, func() { + convey.Convey("baseline", func() { + p := gomonkey.ApplyFunc(state.GetState, func() *state.ManagerState { + return &state.ManagerState{ + PatKeyList: map[string]map[string]struct{}{ + "a": {"a": struct{}{}}, + }, + RemoteClientList: map[string]struct{}{"aaa": {}, "bbb": {}}, + } + }) + defer p.Reset() + fm := &Manager{ + patKeyList: make(map[string]map[string]struct{}, 1), + vpcMutex: sync.Mutex{}, + remoteClientLease: make(map[string]*leaseTimer), + remoteClientList: map[string]struct{}{}, + clientMutex: sync.Mutex{}, + stopCh: nil, + queue: &queue{ + cond: sync.NewCond(&sync.RWMutex{}), + dirty: map[interface{}]struct{}{}, + processing: map[interface{}]struct{}{}, + }, + leaseRenewMinute: 5, + } + go fm.saveStateLoop() + fm.RecoverData() + convey.So(len(fm.remoteClientLease), convey.ShouldEqual, 2) + convey.So(len(fm.patKeyList), convey.ShouldEqual, 1) + }) + }) + libruntimeClient = nil +} diff --git a/yuanrong/pkg/functionmanager/queue.go b/yuanrong/pkg/functionmanager/queue.go new file mode 100644 index 0000000..1425323 --- /dev/null +++ b/yuanrong/pkg/functionmanager/queue.go @@ -0,0 +1,120 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faasmanager - +package functionmanager + +import ( + "sync" +) + +// item item +type item interface { + String() string + Done(err error) +} + +type set map[interface{}]struct{} + +type queue struct { + cond *sync.Cond + + // Waiting for the execution of the set + dirty set + + // Executing set + processing set + + queue []item + + shutDownFlag bool +} + +func (s set) has(item string) bool { + _, exists := s[item] + return exists +} + +func (s set) insert(item string) { + s[item] = struct{}{} +} + +func (s set) delete(item string) { + delete(s, item) +} + +func (q *queue) add(item item) bool { + q.cond.L.Lock() + defer q.cond.L.Unlock() + if q.shutDownFlag { + return false + } + if q.dirty.has(item.String()) { + return false + } + q.dirty.insert(item.String()) + if q.processing.has(item.String()) { + return false + } + q.queue = append(q.queue, item) + q.cond.Signal() + return true +} + +func (q *queue) len() int { + q.cond.L.Lock() + l := len(q.queue) + q.cond.L.Unlock() + return l +} + +func (q *queue) getAll() ([]item, bool) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + for len(q.queue) == 0 && !q.shutDownFlag { + q.cond.Wait() + } + if len(q.queue) == 0 { + // We must be shutting down. + return nil, true + } + items := q.queue + q.queue = []item{} + for _, v := range items { + q.processing.insert(v.String()) + q.dirty.delete(v.String()) + } + return items, false +} + +func (q *queue) shutDown() { + q.cond.L.Lock() + defer q.cond.L.Unlock() + q.shutDownFlag = true + q.cond.Broadcast() +} + +func (q *queue) doneAll(items []item) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + for _, v := range items { + q.processing.delete(v.String()) + if q.dirty.has(v.String()) { + q.queue = append(q.queue, v) + } + } + q.cond.Signal() +} diff --git a/yuanrong/pkg/functionmanager/state/state.go b/yuanrong/pkg/functionmanager/state/state.go new file mode 100644 index 0000000..a043fb1 --- /dev/null +++ b/yuanrong/pkg/functionmanager/state/state.go @@ -0,0 +1,133 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package state - +package state + +import ( + "encoding/json" + "fmt" + "sync" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/state" +) + +// ManagerState add the status to be saved here. +type ManagerState struct { + PatKeyList map[string]map[string]struct{} `json:"PatKeyList" valid:"optional"` + RemoteClientList map[string]struct{} `json:"RemoteClientList" valid:"optional"` +} + +const ( + defaultHandlerQueueSize = 2000 + stateKey = "/faas/state/recover/faasmanager" +) + +var ( + // managerState - + managerState = &ManagerState{ + PatKeyList: make(map[string]map[string]struct{}), + RemoteClientList: map[string]struct{}{}, + } + // ManagerStateLock - + ManagerStateLock sync.RWMutex + managerHandlerQueue *state.Queue +) + +// InitState - +func InitState() { + if managerHandlerQueue != nil { + return + } + managerHandlerQueue = state.NewStateQueue(defaultHandlerQueueSize) + if managerHandlerQueue == nil { + return + } + go managerHandlerQueue.Run(updateState) +} + +// SetState - +func SetState(byte []byte) error { + return json.Unmarshal(byte, managerState) +} + +// GetState - +func GetState() *ManagerState { + ManagerStateLock.RLock() + defer ManagerStateLock.RUnlock() + return managerState +} + +// GetStateByte is used to obtain the local state +func GetStateByte() ([]byte, error) { + if managerHandlerQueue == nil { + return nil, fmt.Errorf("managerHandlerQueue is not initialized") + } + ManagerStateLock.RLock() + defer ManagerStateLock.RUnlock() + stateBytes, err := managerHandlerQueue.GetState(stateKey) + if err != nil { + return nil, err + } + if err = json.Unmarshal(stateBytes, managerState); err != nil { + log.GetLogger().Errorf("update managerState error :%s", err.Error()) + } + log.GetLogger().Debugf("get state from etcd managerState: %v", string(stateBytes)) + return stateBytes, nil +} + +func updateState(value interface{}, tags ...string) { + if managerHandlerQueue == nil { + log.GetLogger().Errorf("manager state managerHandlerQueue is nil") + return + } + ManagerStateLock.Lock() + defer ManagerStateLock.Unlock() + switch v := value.(type) { + case map[string]map[string]struct{}: + managerState.PatKeyList = v + log.GetLogger().Infof("update manager state for PatKeyList") + case map[string]struct{}: + managerState.RemoteClientList = v + log.GetLogger().Infof("update manager state for RemoteClientList") + default: + log.GetLogger().Warnf("unknown data type for ManagerState") + return + } + + state, err := json.Marshal(managerState) + if err != nil { + log.GetLogger().Errorf("get manager state error %s", err.Error()) + return + } + if err = managerHandlerQueue.SaveState(state, stateKey); err != nil { + log.GetLogger().Errorf("save manager state error: %s", err.Error()) + return + } + log.GetLogger().Infof("update manager state: %v", string(state)) +} + +// Update is used to write manger state to the cache queue +func Update(value interface{}, tags ...string) { + if managerHandlerQueue == nil { + log.GetLogger().Errorf("manager state managerHandlerQueue is nil") + return + } + if err := managerHandlerQueue.Push(value, tags...); err != nil { + log.GetLogger().Errorf("failed to push state to state queue: %s", err.Error()) + } +} diff --git a/yuanrong/pkg/functionmanager/state/state_test.go b/yuanrong/pkg/functionmanager/state/state_test.go new file mode 100644 index 0000000..9f21712 --- /dev/null +++ b/yuanrong/pkg/functionmanager/state/state_test.go @@ -0,0 +1,146 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package state + +import ( + "encoding/json" + "errors" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/state" + commonstate "yuanrong/pkg/common/faas_common/state" + "yuanrong/pkg/functionmanager/types" +) + +func TestInitState(t *testing.T) { + convey.Convey("InitState success", t, func() { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + InitState() + }) +} + +func TestOptState(t *testing.T) { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "Put", func(_ *etcd3.EtcdClient, + ctxInfo etcd3.EtcdCtxInfo, key string, value string, opts ...clientv3.OpOption) error { + return nil + }).Reset() + InitState() + managerState = &ManagerState{ + PatKeyList: make(map[string]map[string]struct{}), + } + stateByte, _ := json.Marshal(managerState) + + convey.Convey("set state", t, func() { + err := SetState(stateByte) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("get state byte", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(managerHandlerQueue), "GetState", + func(q *state.Queue, key string) ([]byte, error) { + return stateByte, nil + }).Reset() + msByte, err := GetStateByte() + outPut := &ManagerState{} + json.Unmarshal(msByte, outPut) + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestUpdateState(t *testing.T) { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "Put", func(_ *etcd3.EtcdClient, + ctxInfo etcd3.EtcdCtxInfo, key string, value string, opts ...clientv3.OpOption) error { + return nil + }).Reset() + InitState() + managerState = &ManagerState{ + PatKeyList: make(map[string]map[string]struct{}), + } + stateByte, _ := json.Marshal(managerState) + SetState(stateByte) + + convey.Convey("update PatKeyList success", t, func() { + Update(map[string]map[string]struct{}{ + "123": map[string]struct{}{ + "356": struct{}{}, + }, + }) + time.Sleep(100 * time.Millisecond) + convey.So(GetState().PatKeyList["123"], convey.ShouldContainKey, "356") + }) + convey.Convey("type is error", t, func() { + type custom struct{} + temp1 := *GetState() + Update(&custom{}) + time.Sleep(100 * time.Millisecond) + temp2 := *GetState() + convey.So(temp1, convey.ShouldResemble, temp2) + }) + convey.Convey("managerHandlerQueue is nil", t, func() { + managerHandlerQueue = nil + temp1 := *GetState() + Update(&types.ManagerConfig{ + FunctionCapability: 3, + }) + time.Sleep(100 * time.Millisecond) + temp2 := *GetState() + convey.So(temp1, convey.ShouldResemble, temp2) + }) + +} + +func Test_updateState(t *testing.T) { + convey.Convey("updateState", t, func() { + convey.Convey("manager state managerHandlerQueue is nil", func() { + managerHandlerQueue = nil + updateState("state1", "add") + }) + + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + convey.Convey("json.Marshal error", func() { + defer gomonkey.ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, errors.New("json marshal error") + }).Reset() + managerHandlerQueue = state.NewStateQueue(defaultHandlerQueueSize) + updateState(&types.ManagerConfig{}, "add") + }) + + convey.Convey("save manager state error", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(managerHandlerQueue), "SaveState", + func(q *commonstate.Queue, state []byte) error { + return errors.New("save error") + }).Reset() + updateState(&types.ManagerConfig{}, "add") + }) + }) +} diff --git a/yuanrong/pkg/functionmanager/types/types.go b/yuanrong/pkg/functionmanager/types/types.go new file mode 100644 index 0000000..bdb3955 --- /dev/null +++ b/yuanrong/pkg/functionmanager/types/types.go @@ -0,0 +1,135 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package types - +package types + +import ( + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/etcd3" +) + +const ( + // KillSignalVal Kill instances of job + KillSignalVal = 2 // Kill instances of job +) + +// ManagerConfig is the config used by faas frontend function +type ManagerConfig struct { + HTTPSEnable bool `json:"httpsEnable" valid:"optional"` + FunctionCapability int `json:"functionCapability" valid:"optional"` + AuthenticationEnable bool `json:"authenticationEnable" valid:"optional"` + LeaseRenewMinute int `json:"leaseRenewMinute" valid:"optional"` + RouterEtcd etcd3.EtcdConfig `json:"routerEtcd" valid:"optional"` + MetaEtcd etcd3.EtcdConfig `json:"metaEtcd" valid:"optional"` + AlarmConfig alarm.Config `json:"alarmConfig" valid:"optional"` + SccConfig crypto.SccConfig `json:"sccConfig" valid:"optional"` +} + +// NATConfigure include nat configure info for function-agent +type NATConfigure struct { + ContainerCidr string `json:"containerCidr"` + HostVMCidr string `json:"hostVmCidr"` + PatContainerIP string `json:"patContainerIP"` + PatVMIP string `json:"patVmIP"` + PatPortIP string `json:"patPortIP"` + PatMacAddr string `json:"patMacAddr"` + PatGateway string `json:"patGateway"` + PatPodName string `json:"patPodName"` + TenantCidr string `json:"tenantCidr"` + NatSubnetList map[string][]string `json:"natSubnetList"` + IsDeleted bool `json:"isDeleted"` + IsNewCreated bool `json:"isNewCreated"` +} + +// VpcControllerRequester define request message for vpc_controller +type VpcControllerRequester struct { + TraceID string `json:"trace_id"` + Delegate Delegate `json:"delegate"` + Vpc Vpc `json:"vpc"` +} + +// Delegate include Xrole and AppXrole +type Delegate struct { + Xrole string `json:"xrole,omitempty"` + AppXrole string `json:"app_xrole,omitempty"` +} + +// Vpc include info of function vpc +type Vpc struct { + ID string `json:"id,omitempty"` + DomainID string `json:"domain_id,omitempty"` + Namespace string `json:"namespace,omitempty"` + VpcName string `json:"vpc_name,omitempty"` + VpcID string `json:"vpc_id,omitempty"` + SubnetName string `json:"subnet_name,omitempty"` + SubnetID string `json:"subnet_id,omitempty"` + TenantCidr string `json:"tenant_cidr,omitempty"` + HostVMCidr string `json:"host_vm_cidr,omitempty"` + Gateway string `json:"gateway,omitempty"` + Xrole string `json:"xrole,omitempty"` +} + +// RequestInfo include info of request Option Create +type RequestInfo struct { + ID string `json:"id,omitempty"` + DomainID string `json:"domain_id,omitempty"` + Namespace string `json:"namespace,omitempty"` + VpcName string `json:"vpc_name,omitempty"` + VpcID string `json:"vpc_id,omitempty"` + SubnetName string `json:"subnet_name,omitempty"` + SubnetID string `json:"subnet_id,omitempty"` + TenantCidr string `json:"tenant_cidr,omitempty"` + HostVMCidr string `json:"host_vm_cidr,omitempty"` + Gateway string `json:"gateway,omitempty"` + Xrole string `json:"xrole,omitempty"` + AppXrole string `json:"app_xrole,omitempty"` +} + +// PullTriggerRequestInfo include info of pullTrigger Option Create +type PullTriggerRequestInfo struct { + PodName string `json:"pod_name"` + Image string `json:"image"` + DomainID string `json:"domain_id,omitempty"` + Namespace string `json:"namespace,omitempty"` + VpcName string `json:"vpc_name,omitempty"` + VpcID string `json:"vpc_id,omitempty"` + SubnetName string `json:"subnet_name,omitempty"` + SubnetID string `json:"subnet_id,omitempty"` + TenantCidr string `json:"tenant_cidr,omitempty"` + HostVMCidr string `json:"host_vm_cidr,omitempty"` + ContainerCidr string `json:"container_cidr"` + Gateway string `json:"gateway,omitempty"` + Xrole string `json:"xrole,omitempty"` + AppXrole string `json:"app_xrole,omitempty"` +} + +// ReportInfo include info of request Option report +type ReportInfo struct { + PatPodName string `json:"patPodName,omitempty"` + InstanceID string `json:"instanceID,omitempty"` +} + +// DeleteInfo include info of request Option delete +type DeleteInfo struct { + InstanceID string `json:"instanceID,omitempty"` +} + +// PullTriggerDeleteInfo include info of pullTrigger Option delete +type PullTriggerDeleteInfo struct { + PodName string `json:"pod_name,omitempty"` +} diff --git a/yuanrong/pkg/functionmanager/utils/utils.go b/yuanrong/pkg/functionmanager/utils/utils.go new file mode 100644 index 0000000..764fd60 --- /dev/null +++ b/yuanrong/pkg/functionmanager/utils/utils.go @@ -0,0 +1,123 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "context" + "errors" + "strings" + + v1 "k8s.io/api/apps/v1" + k8serror "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" +) + +const ( + triggerNameIndex = 1 + defaultPodNameLen = 3 +) + +var ( + kubeClient kubernetes.Interface + // ErrK8SClientNil - + ErrK8SClientNil = errors.New("kubernetes client is nil") +) + +// GenerateErrorResponse - +func GenerateErrorResponse(errorCode int, errorMessage string) *types.CallHandlerResponse { + return &types.CallHandlerResponse{ + Code: errorCode, + Message: errorMessage, + } +} + +// GenerateSuccessResponse - +func GenerateSuccessResponse(code int, message string) *types.CallHandlerResponse { + return &types.CallHandlerResponse{ + Code: code, + Message: message, + } +} + +// HandlePullTriggerName - +func HandlePullTriggerName(podName string) string { + splits := strings.Split(podName, "/") + if len(splits) != defaultPodNameLen { + return "" + } + return splits[triggerNameIndex] +} + +// InitKubeClient initializes kubernetes client +func InitKubeClient() error { + kubeConfig, err := rest.InClusterConfig() + if err != nil { + log.GetLogger().Errorf("failed to get token and ca for kubernetes, error: %s", err.Error()) + return err + } + + kubeClient, err = kubernetes.NewForConfig(kubeConfig) + if err != nil { + log.GetLogger().Errorf("failed to create kubernetes client, error: %s", err.Error()) + return err + } + log.GetLogger().Infof("succeed to create kubeClient") + return nil +} + +// GetKubeClient return kubernetes client +func GetKubeClient() kubernetes.Interface { + return kubeClient +} + +// GetDeployByK8S get deployment by k8s +func GetDeployByK8S(k8sClient kubernetes.Interface, deployName string) (*v1.Deployment, error) { + if k8sClient == nil { + log.GetLogger().Errorf("failed to get k8sClient") + return nil, ErrK8SClientNil + } + return k8sClient.AppsV1().Deployments("default").Get(context.TODO(), deployName, metav1.GetOptions{}) +} + +// CreateDeployByK8S create deployment by k8s +func CreateDeployByK8S(k8sClient kubernetes.Interface, deploy *v1.Deployment) error { + if k8sClient == nil { + log.GetLogger().Errorf("failed to get k8sClient") + return ErrK8SClientNil + } + _, err := k8sClient.AppsV1().Deployments("default").Create(context.TODO(), deploy, metav1.CreateOptions{}) + if err != nil && !k8serror.IsAlreadyExists(err) { + log.GetLogger().Errorf("failed to create deploy %s", err.Error()) + return err + } + return nil +} + +// DeleteDeployByK8S delete deployment by k8s +func DeleteDeployByK8S(k8sClient kubernetes.Interface, name string) error { + if k8sClient == nil { + log.GetLogger().Errorf("failed to get k8sClient") + return ErrK8SClientNil + } + return k8sClient.AppsV1().Deployments("default").Delete(context.TODO(), name, metav1.DeleteOptions{}) +} diff --git a/yuanrong/pkg/functionmanager/utils/utils_test.go b/yuanrong/pkg/functionmanager/utils/utils_test.go new file mode 100644 index 0000000..030536b --- /dev/null +++ b/yuanrong/pkg/functionmanager/utils/utils_test.go @@ -0,0 +1,212 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package utils + +import ( + "errors" + "fmt" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + appsv1 "k8s.io/api/apps/v1" + v1 "k8s.io/api/apps/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/rest" + clienttesting "k8s.io/client-go/testing" + + "yuanrong/pkg/common/faas_common/types" +) + +func TestGenerateErrorResponse(t *testing.T) { + type args struct { + errorCode int + errorMessage string + } + var a args + a.errorCode = 0 + a.errorMessage = "0" + resp := &types.CallHandlerResponse{ + Code: 0, + Message: "0", + } + tests := []struct { + name string + args args + want *types.CallHandlerResponse + }{ + {"case1", a, resp}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GenerateErrorResponse(tt.args.errorCode, tt.args.errorMessage); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GenerateErrorResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGenerateSuccessResponse(t *testing.T) { + type args struct { + code int + message string + } + var a args + a.code = 0 + a.message = "0" + resp := &types.CallHandlerResponse{ + Code: 0, + Message: "0", + } + tests := []struct { + name string + args args + want *types.CallHandlerResponse + }{ + {"case1", a, resp}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GenerateSuccessResponse(tt.args.code, tt.args.message); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GenerateSuccessResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestInitKubeClient(t *testing.T) { + kubeClient = nil + convey.Convey("failed to get config", t, func() { + defer gomonkey.ApplyFunc(rest.InClusterConfig, func() (*rest.Config, error) { + return nil, fmt.Errorf("get config error") + }).Reset() + InitKubeClient() + convey.So(kubeClient, convey.ShouldBeNil) + }) + defer gomonkey.ApplyFunc(rest.InClusterConfig, func() (*rest.Config, error) { + return &rest.Config{}, nil + }).Reset() + + convey.Convey("failed to get config", t, func() { + defer gomonkey.ApplyFunc(kubernetes.NewForConfig, func(c *rest.Config) (*kubernetes.Clientset, error) { + return nil, fmt.Errorf("get client error") + }).Reset() + InitKubeClient() + convey.So(kubeClient, convey.ShouldBeNil) + }) + defer gomonkey.ApplyFunc(kubernetes.NewForConfig, func(c *rest.Config) (*kubernetes.Clientset, error) { + return &kubernetes.Clientset{}, nil + }).Reset() + convey.Convey("init success", t, func() { + InitKubeClient() + convey.So(GetKubeClient(), convey.ShouldNotBeNil) + }) +} + +func TestGetDeployByK8S(t *testing.T) { + convey.Convey("kubeClient is nil", t, func() { + _, err := GetDeployByK8S(nil, "deployName") + convey.So(err, convey.ShouldEqual, ErrK8SClientNil) + }) + + fakeDeployment := &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-deployment", + Namespace: "default", + }, + } + clientset := fake.NewSimpleClientset(fakeDeployment) + clientset.PrependReactor("get", "deployments", + func(action clienttesting.Action) (handled bool, ret runtime.Object, err error) { + getAction := action.(clienttesting.GetAction) + fakeDeployment.Name = getAction.GetName() + return true, fakeDeployment, nil + }) + + convey.Convey("GetDeployByK8S success", t, func() { + _, err := GetDeployByK8S(clientset, "deployName") + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestCreateDeployByK8S(t *testing.T) { + convey.Convey("kubeClient is nil", t, func() { + err := CreateDeployByK8S(nil, &v1.Deployment{}) + convey.So(err, convey.ShouldEqual, ErrK8SClientNil) + }) + + clientset := fake.NewSimpleClientset() + clientset.PrependReactor("create", "deployments", + func(action clienttesting.Action) (handled bool, ret runtime.Object, err error) { + createAction := action.(clienttesting.CreateAction) + if createAction.GetObject().(*appsv1.Deployment).Name == "my-deployment" { + return true, nil, errors.New("mock error") + } else { + fakeDeployment := createAction.GetObject().(*appsv1.Deployment) + fakeDeployment.Name = "fake-deployment" + return true, fakeDeployment, nil + } + }) + + convey.Convey("create error", t, func() { + deployment := &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-deployment", + Namespace: "default", + }, + } + err := CreateDeployByK8S(clientset, deployment) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("create success", t, func() { + deployment := &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "other-deployment", + Namespace: "default", + }, + } + err := CreateDeployByK8S(clientset, deployment) + convey.So(err, convey.ShouldBeNil) + }) + +} + +func TestDeleteDeployByK8S(t *testing.T) { + convey.Convey("kubeClient is nil", t, func() { + err := DeleteDeployByK8S(nil, "123") + convey.So(err, convey.ShouldEqual, ErrK8SClientNil) + }) + + clientset := fake.NewSimpleClientset() + clientset.PrependReactor("delete", "deployments", + func(action clienttesting.Action) (handled bool, ret runtime.Object, err error) { + deleteAction := action.(clienttesting.DeleteAction) + if deleteAction.GetName() == "my-deployment" { + return true, nil, errors.New("mock error") + } else { + return true, nil, nil + } + }) + convey.Convey("kubeClient is nil", t, func() { + err := DeleteDeployByK8S(clientset, "other-deployment") + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/pkg/functionmanager/vpcmanager/plugin.go b/yuanrong/pkg/functionmanager/vpcmanager/plugin.go new file mode 100644 index 0000000..9817cdb --- /dev/null +++ b/yuanrong/pkg/functionmanager/vpcmanager/plugin.go @@ -0,0 +1,157 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package vpcmanager - +package vpcmanager + +import ( + "encoding/json" + "errors" + "plugin" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionmanager/types" +) + +// PluginVPC - +type PluginVPC struct { + *plugin.Plugin + createResource plugin.Symbol + deleteResource plugin.Symbol +} + +// InitVpcPlugin - +func (p *PluginVPC) InitVpcPlugin() error { + err := p.initController() + if err != nil { + log.GetLogger().Errorf("failed to exec initController, error: %s", err.Error()) + return err + } + log.GetLogger().Infof("succeed to exec initController") + return nil +} + +func (p *PluginVPC) initController() error { + if p.Plugin == nil { + log.GetLogger().Errorf("pluginVpc is nil") + return errors.New("pluginVpc is nil") + } + targetFunc, err := p.Plugin.Lookup("InitController") + if err != nil { + log.GetLogger().Errorf("failed to look up InitController, error: %s", err.Error()) + return err + } + if initController, ok := targetFunc.(func() error); ok { + err = initController() + if err != nil { + log.GetLogger().Errorf("failed to init Controller, error: %s", err.Error()) + return err + } + } + createResource, err := p.Plugin.Lookup("CreateResource") + if err != nil { + log.GetLogger().Errorf("failed to lookup function of CreateResource: %s", err.Error()) + return err + } + p.createResource = createResource + deleteResource, err := p.Plugin.Lookup("DeleteResource") + if err != nil { + log.GetLogger().Errorf("failed to lookup function of DeleteResource: %s", err.Error()) + return err + } + p.deleteResource = deleteResource + return nil +} + +// CreateVpcResource - +func (p *PluginVPC) CreateVpcResource(requestData []byte) (types.NATConfigure, error) { + if p.Plugin == nil { + log.GetLogger().Errorf("pluginVpc is nil") + return types.NATConfigure{}, errors.New("pluginVpc is nil") + } + patPod := types.NATConfigure{} + if createResource, ok := p.createResource.(func(request []byte) ([]byte, error)); ok { + request := describeRequest(requestData) + reqInfo, err := json.Marshal(request) + if err != nil { + log.GetLogger().Errorf("HandleVpcFunctionInfo Marshal error: %s", err.Error()) + return types.NATConfigure{}, err + } + resp, err := createResource(reqInfo) + if err != nil { + log.GetLogger().Errorf("HandleVpcFunctionInfo createResource error: %s", err.Error()) + return types.NATConfigure{}, err + } + err = json.Unmarshal(resp, &patPod) + if err != nil { + log.GetLogger().Errorf("HandleVpcFunctionInfo Unmarshal error: %s", err.Error()) + return types.NATConfigure{}, err + } + } else { + log.GetLogger().Errorf("failed to assert createResource") + return types.NATConfigure{}, errors.New("failed to assert createResource") + } + return patPod, nil +} + +// DeleteVpcResource - +func (p *PluginVPC) DeleteVpcResource(patPodName string) error { + if p.Plugin == nil { + log.GetLogger().Errorf("pluginVpc is nil") + return errors.New("pluginVpc is nil") + } + if deleteResource, ok := p.deleteResource.(func(string2 string) error); ok { + err := deleteResource(patPodName) + if err != nil { + log.GetLogger().Errorf("failed to delete vpc resource: %s", err.Error()) + return err + } + } else { + log.GetLogger().Errorf("failed to assert deleteResource") + return errors.New("failed to assert deleteResource") + } + return nil +} + +// describeRequest get different requestInfo by vpcType +func describeRequest(requestData []byte) types.VpcControllerRequester { + requestInfo := types.RequestInfo{} + err := json.Unmarshal(requestData, &requestInfo) + if err != nil { + return types.VpcControllerRequester{} + } + vpcInfo := types.Vpc{ + Namespace: requestInfo.Namespace, + DomainID: requestInfo.DomainID, + } + delegate := types.Delegate{ + Xrole: requestInfo.Xrole, + AppXrole: requestInfo.AppXrole, + } + vpcInfo.SubnetName = requestInfo.SubnetName + vpcInfo.VpcID = requestInfo.VpcID + vpcInfo.SubnetID = requestInfo.SubnetID + vpcInfo.Gateway = requestInfo.Gateway + vpcInfo.ID = requestInfo.ID + vpcInfo.TenantCidr = requestInfo.TenantCidr + vpcInfo.HostVMCidr = requestInfo.HostVMCidr + vpcInfo.VpcName = requestInfo.VpcName + request := types.VpcControllerRequester{ + Delegate: delegate, + Vpc: vpcInfo, + } + return request +} diff --git a/yuanrong/pkg/functionmanager/vpcmanager/plugin_test.go b/yuanrong/pkg/functionmanager/vpcmanager/plugin_test.go new file mode 100644 index 0000000..d158933 --- /dev/null +++ b/yuanrong/pkg/functionmanager/vpcmanager/plugin_test.go @@ -0,0 +1,330 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package vpcmanager + +import ( + "encoding/json" + "errors" + "plugin" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + mockUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionmanager/types" +) + +func TestPluginVPC_CreateVpcResource(t *testing.T) { + type args struct { + requestData []byte + createResource interface{} + } + tests := []struct { + name string + args args + want types.NATConfigure + wantErr bool + patchesFunc mockUtils.PatchesFunc + }{ + {"case1 Marshal error", + args{ + createResource: func(request []byte) ([]byte, error) { + return []byte{}, nil + }, + }, + types.NATConfigure{}, true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(describeRequest, + func(requestData []byte) types.VpcControllerRequester { + return types.VpcControllerRequester{} + }), + gomonkey.ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, errors.New("failed to marshel json") + }), + }) + return patches + }}, + + {"case2 createResource error", + args{ + createResource: func(request []byte) ([]byte, error) { + return nil, errors.New("test") + }, + }, + types.NATConfigure{}, true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(describeRequest, + func(requestData []byte) types.VpcControllerRequester { + return types.VpcControllerRequester{} + }), + }) + return patches + }}, + + {"case3 UnMarshal error", + args{ + createResource: func(request []byte) ([]byte, error) { + return []byte{}, nil + }, + }, + types.NATConfigure{}, true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(describeRequest, + func(requestData []byte) types.VpcControllerRequester { + return types.VpcControllerRequester{} + }), + }) + return patches + }}, + + {"case4 succeed to createVpcResource", + args{ + createResource: func(request []byte) ([]byte, error) { + tmp, _ := json.Marshal(types.NATConfigure{}) + return tmp, nil + }, + }, + types.NATConfigure{}, false, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + gomonkey.ApplyFunc(describeRequest, + func(requestData []byte) types.VpcControllerRequester { + return types.VpcControllerRequester{} + }), + }) + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pluginVPC := PluginVPC{} + pluginVPC.Plugin = &plugin.Plugin{} + pluginVPC.createResource = tt.args.createResource + patches := tt.patchesFunc() + got, err := pluginVPC.CreateVpcResource(tt.args.requestData) + if (err != nil) != tt.wantErr { + t.Errorf("createVpcResource() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("createVpcResource() got = %v, want %v", got, tt.want) + } + patches.ResetAll() + }) + } +} + +func TestPluginVPC_DeleteVpcResource(t *testing.T) { + type args struct { + patPodName string + deleteResource interface{} + } + tests := []struct { + name string + args args + wantErr bool + patchesFunc mockUtils.PatchesFunc + }{ + {"case1 deleteResource error", + args{ + deleteResource: func(request string) error { + return errors.New("test") + }, + }, true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + return patches + }}, + + {"case2 check deleteResource error", + args{ + deleteResource: func(request []byte) error { + return errors.New("test") + }, + }, true, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + return patches + }}, + + {"case3 succeed to deleteResource", + args{ + deleteResource: func(request string) error { + return nil + }, + }, false, func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + return patches + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pluginVPC := PluginVPC{} + pluginVPC.Plugin = &plugin.Plugin{} + pluginVPC.deleteResource = tt.args.deleteResource + patches := tt.patchesFunc() + err := pluginVPC.DeleteVpcResource(tt.args.patPodName) + if (err != nil) != tt.wantErr { + t.Errorf("createVpcResource() error = %v, wantErr %v", err, tt.wantErr) + return + } + patches.ResetAll() + }) + } +} + +var ( + requestInfo = `{ + "id": "id", + "domain_id": "domain_id", + "namespace": "namespace", + "vpc_name": "vpc_name", + "vpc_id": "vpc_id", + "subnet_name": "subnet_name", + "subnet_id": "subnet_id", + "tenant_cidr": "tenant_cidr", + "host_vm_cidr": "host_vm_cidr", + "gateway": "gateway", + "xrole": "xrole", + "app_xrole": "app_xrole" + }` +) + +func Test_describeRequest(t *testing.T) { + type args struct { + requestData []byte + } + var a args + a.requestData = []byte(requestInfo) + vpcInfo := types.Vpc{ + ID: "id", + DomainID: "domain_id", + Namespace: "namespace", + VpcName: "vpc_name", + VpcID: "vpc_id", + SubnetName: "subnet_name", + SubnetID: "subnet_id", + TenantCidr: "tenant_cidr", + HostVMCidr: "host_vm_cidr", + Gateway: "gateway", + } + delegate := types.Delegate{ + Xrole: "xrole", + AppXrole: "app_xrole", + } + request := types.VpcControllerRequester{ + TraceID: "", + Delegate: delegate, + Vpc: vpcInfo, + } + tests := []struct { + name string + args args + want types.VpcControllerRequester + }{ + {"case1", a, request}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := describeRequest(tt.args.requestData); !reflect.DeepEqual(got, tt.want) { + t.Errorf("describeRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestInitController(t *testing.T) { + p := &PluginVPC{} + convey.Convey(" patKeyList is nil", t, func() { + err := p.InitVpcPlugin() + convey.So(err, convey.ShouldResemble, errors.New("pluginVpc is nil")) + }) + defer gomonkey.ApplyMethod(reflect.TypeOf(&plugin.Plugin{}), "Lookup", + func(_ *plugin.Plugin, symName string) (plugin.Symbol, error) { + if symName == "InitController" { + return func() error { + return nil + }, nil + } + return func() {}, nil + }).Reset() + + p.Plugin = &plugin.Plugin{} + + convey.Convey(" patKeyList is nil", t, func() { + err := p.InitVpcPlugin() + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestPluginVPC_initController(t *testing.T) { + convey.Convey("initController", t, func() { + p := &PluginVPC{Plugin: &plugin.Plugin{}} + convey.Convey("failed to look up InitController", func() { + err := p.initController() + convey.So(err, convey.ShouldBeError) + }) + convey.Convey("failed to init Controller", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&plugin.Plugin{}), "Lookup", + func(_ *plugin.Plugin, symName string) (plugin.Symbol, error) { + f := func() error { + return errors.New("failed to init Controller") + } + return f, nil + }).Reset() + err := p.initController() + convey.So(err, convey.ShouldBeError) + }) + + convey.Convey("failed to lookup function of CreateResource", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&plugin.Plugin{}), "Lookup", + func(_ *plugin.Plugin, symName string) (plugin.Symbol, error) { + if symName == "InitController" { + f := func() error { + return nil + } + return f, nil + } + return nil, errors.New("not found CreateResource") + }).Reset() + err := p.initController() + convey.So(err, convey.ShouldBeError) + }) + + convey.Convey("failed to lookup function of DeleteResource", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&plugin.Plugin{}), "Lookup", + func(_ *plugin.Plugin, symName string) (plugin.Symbol, error) { + if symName == "InitController" { + f := func() error { + return nil + } + return f, nil + } + if symName == "CreateResource" { + f := func() {} + return f, nil + } + return nil, errors.New("not found DeleteResource") + }).Reset() + err := p.initController() + convey.So(err, convey.ShouldBeError) + }) + }) + +} diff --git a/yuanrong/pkg/functionmanager/vpcmanager/pulltrigger.go b/yuanrong/pkg/functionmanager/vpcmanager/pulltrigger.go new file mode 100644 index 0000000..7c65597 --- /dev/null +++ b/yuanrong/pkg/functionmanager/vpcmanager/pulltrigger.go @@ -0,0 +1,400 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package vpcmanager - +package vpcmanager + +import ( + "fmt" + "os" + "strconv" + "strings" + + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionmanager/types" +) + +// CreateVpcTriggerInstruct define create VpcTrigger instruct +type CreateVpcTriggerInstruct struct { + PulltriggerPort string `json:"pulltrigger_port"` + PodName string `json:"pod_name"` + MemoryOption string `json:"memory_option"` + CPUOption string `json:"cpu_option"` + Image string `json:"image"` + DomainID string `json:"domain_id"` + Namespace string `json:"namespace"` + VpcName string `json:"vpc_name"` + VpcID string `json:"vpc_id"` + SubnetName string `json:"subnet_name"` + SubnetID string `json:"subnet_id"` + TenantCidr string `json:"tenant_cidr"` + HostVMCidr string `json:"host_vm_cidr"` + ContainerCidr string `json:"container_cidr"` + Gateway string `json:"gateway"` + AppXrole string `json:"app_xrole,omitempty"` + Xrole string `json:"xrole,omitempty"` +} + +const ( + base10 = 10 + bitSize = 64 + defaultPullTriggerCPU = 500 + defaultPullTriggerMem = 500 + defaultSecretMode = 0400 + pullTriggerHealthCheckPort = 28917 + pullTriggerFailureThreshold = 3 + pullTriggerInitialDelaySeconds = 15 + pullTriggerPeriodSeconds = 30 + pullTriggerTimeoutSeconds = 3 + labelKey = "cff-type" + labelValue = "serverless" +) + +var ( + omsvcAddress = os.Getenv("SERVERLESS_OMSVC_ADDRESS") + enableDataPlaneRdispatcher = os.Getenv("ENABLE_DATA_PLANE_RDISPATCHER") + regionID = os.Getenv("REGION_ID") + tenantTriggerAddress = os.Getenv("TENANT_TRIGGER_ADDRESS") +) + +// MakePullTriggerDeploy return vpc_pullTrigger deployment +func MakePullTriggerDeploy(triggerInfo *CreateVpcTriggerInstruct, vpcNatConfByte []byte) *appsv1.Deployment { + labels := map[string]string{} + labels[labelKey] = labelValue + triggerReplicas := int32(1) + terminationGracePeriodSeconds := int64(corev1.DefaultTerminationGracePeriodSeconds) + return &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: triggerInfo.PodName, + Labels: labels, + }, + Spec: appsv1.DeploymentSpec{ + Replicas: &triggerReplicas, + Selector: &metav1.LabelSelector{ + MatchLabels: labels, + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Name: triggerInfo.PodName, + Labels: labels, + }, + Spec: corev1.PodSpec{ + ImagePullSecrets: genImagePullSecrets(), + RestartPolicy: "Always", + Volumes: describeVolumes(), + // Not enabled temporarily + // InitContainers: describeInitContainers(triggerInfo), + Containers: describeContainers(triggerInfo, vpcNatConfByte), + TerminationGracePeriodSeconds: &terminationGracePeriodSeconds, + }, + }, + }, + } +} + +// ParseFunctionMeta parse PullTriggerRequestInfo to CreateVpcTriggerInstruct +func ParseFunctionMeta(requestInfo types.PullTriggerRequestInfo) *CreateVpcTriggerInstruct { + return &CreateVpcTriggerInstruct{ + PulltriggerPort: "28937", + PodName: requestInfo.PodName, + Image: requestInfo.Image, + DomainID: requestInfo.DomainID, + Namespace: requestInfo.Namespace, + VpcName: requestInfo.VpcName, + VpcID: requestInfo.VpcID, + SubnetName: requestInfo.SubnetName, + SubnetID: requestInfo.SubnetID, + TenantCidr: requestInfo.TenantCidr, + HostVMCidr: requestInfo.HostVMCidr, + Gateway: requestInfo.Gateway, + AppXrole: requestInfo.AppXrole, + Xrole: requestInfo.Xrole, + } +} + +func genImagePullSecrets() []corev1.LocalObjectReference { + return []corev1.LocalObjectReference{ + { + Name: "default-secret", + }, + } +} + +func describeVolumes() []corev1.Volume { + return []corev1.Volume{ + { + Name: "localtime", + VolumeSource: corev1.VolumeSource{ + HostPath: &corev1.HostPathVolumeSource{ + Path: "/etc/localtime", + }}, + }, + { + Name: "etc", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + { + Name: "vpcpulltrigger", + VolumeSource: corev1.VolumeSource{ + HostPath: &corev1.HostPathVolumeSource{ + Path: "/var/paas/sys/log/cff/functiongraph/vpcpulltrigger", + }}, + }, + } +} + +func generateSecretVolumes() map[Type]corev1.Volume { + volumeMap := make(map[Type]corev1.Volume) + for _, volumeType := range getTypes() { + volume := corev1.Volume{ + Name: volumeType.name(), + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: volumeType.secretName(), + Items: volumeType.keyToPath(), + }, + }, + } + volumeMap[volumeType] = volume + } + return volumeMap +} + +func describeContainers(triggerInfo *CreateVpcTriggerInstruct, vpcNatConfByte []byte) []corev1.Container { + cpu, err := strconv.ParseInt(triggerInfo.CPUOption, base10, bitSize) + if err != nil { + log.GetLogger().Warnf("cpu option convert to int failed with error %s", err.Error()) + cpu = defaultPullTriggerCPU + } + mem, err := strconv.ParseInt(triggerInfo.MemoryOption, base10, bitSize) + if err != nil { + log.GetLogger().Warnf("memory option convert to int failed with error %s", err.Error()) + mem = defaultPullTriggerMem + } + ContainerSpec := corev1.Container{ + Name: "vpcpulltrigger", + Image: triggerInfo.Image, + ImagePullPolicy: "IfNotPresent", + Env: buildEnv("ccevpc", triggerInfo, vpcNatConfByte), + SecurityContext: getSecurityContext(), + Lifecycle: getLifecycle(), + // Not enabled temporarily + // LivenessProbe: getLivenessProbe(), + Resources: getResources(cpu, mem), + VolumeMounts: getVolumeMount(), + Command: []string{"tail"}, + Args: []string{"-f", "/dev/null"}, + } + return []corev1.Container{ + ContainerSpec, + } +} + +func describeInitContainers(triggerInfo *CreateVpcTriggerInstruct) []corev1.Container { + cpu, err := strconv.ParseInt(triggerInfo.CPUOption, base10, bitSize) + if err != nil { + log.GetLogger().Warnf("cpu option convert to int failed with error %s", err.Error()) + cpu = defaultPullTriggerCPU + } + mem, err := strconv.ParseInt(triggerInfo.MemoryOption, base10, bitSize) + if err != nil { + log.GetLogger().Warnf("memory option convert to int failed with error %s", err.Error()) + mem = defaultPullTriggerMem + } + initContainerSpec := corev1.Container{ + Args: []string{"cp -r /opt/CFF/etc /tmp/CFF/;find /tmp/CFF/ -type f |xargs chmod 600;"}, + Command: []string{"/bin/sh", "-c"}, + Name: "cff-vpcpulltrigger-bootstrap", + Image: triggerInfo.Image, + ImagePullPolicy: "IfNotPresent", + SecurityContext: getInitSecurityContext(), + Resources: getResources(cpu, mem), + VolumeMounts: getInitVolumeMount(), + TerminationMessagePath: "/dev/termination-log", + TerminationMessagePolicy: corev1.TerminationMessageReadFile, + } + return []corev1.Container{ + initContainerSpec, + } +} + +func buildEnv(podType string, triggerInfo *CreateVpcTriggerInstruct, vpcNatConfByte []byte) []corev1.EnvVar { + Envs := make([]corev1.EnvVar, 0) + + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_SSL_MODE", Value: "0"}) + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_SSL_VERIFY_CLIENT", Value: "1"}) + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_HOST_IP", ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "status.hostIP"}}}) + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_POD_NAME", ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "metadata.name"}}}) + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_POD_IP", ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "status.podIP"}}}) + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_NAMESPACE", Value: triggerInfo.Namespace}) + Envs = append(Envs, corev1.EnvVar{Name: "POD_TYPE", Value: podType}) + Envs = append(Envs, corev1.EnvVar{Name: "REGION_ID", Value: regionID}) + Envs = append(Envs, corev1.EnvVar{Name: "TENANTLB_ADDRESS", Value: tenantTriggerAddress}) + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_WEBSOCKET_HOST", Value: fmt.Sprintf("%s:%s", + tenantTriggerAddress, triggerInfo.PulltriggerPort)}) + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_SUBNETID", Value: triggerInfo.SubnetID}) + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_VPCNATCONF", Value: string(vpcNatConfByte)}) + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_IMAGE_VERSION", Value: triggerInfo.Image}) + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_CLUSTER_ID", Value: "clusterId"}) + index := strings.LastIndex(omsvcAddress, ":") + if index != -1 { + Envs = append(Envs, corev1.EnvVar{Name: "SERVERLESS_HOST", Value: omsvcAddress[:]}) + } + Envs = append(Envs, corev1.EnvVar{Name: "ENABLE_DATA_PLANE_RDISPATCHER", Value: enableDataPlaneRdispatcher}) + return Envs +} + +func getSecurityContext() *corev1.SecurityContext { + return &corev1.SecurityContext{ + Capabilities: &corev1.Capabilities{ + Add: []corev1.Capability{ + "NET_ADMIN", + "NET_RAW", + "DAC_OVERRIDE", + "SETGID", + "SETUID", + "CHOWN", + "FOWNER", + "KILL", + }, + Drop: []corev1.Capability{ + "ALL", + }, + }, + } +} + +func getInitSecurityContext() *corev1.SecurityContext { + user := int64(0) + return &corev1.SecurityContext{ + RunAsUser: &user, + } +} + +func getLifecycle() *corev1.Lifecycle { + return &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + Exec: &corev1.ExecAction{ + Command: []string{"sleep", "3"}, + }, + }, + } +} + +func getLivenessProbe() *corev1.Probe { + return &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: "/v1.0/healthcheck", + Port: intstr.IntOrString{ + IntVal: pullTriggerHealthCheckPort, + }, + Scheme: "HTTPS", + }, + }, + FailureThreshold: pullTriggerFailureThreshold, + InitialDelaySeconds: pullTriggerInitialDelaySeconds, + PeriodSeconds: pullTriggerPeriodSeconds, + TimeoutSeconds: pullTriggerTimeoutSeconds, + } +} + +func getResources(cpu, mem int64) corev1.ResourceRequirements { + if cpu == 0 || mem == 0 { + cpu = defaultPullTriggerCPU + mem = defaultPullTriggerMem + } + resourceReq := corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "cpu": resource.MustParse(fmt.Sprintf("%dm", cpu)), + "memory": resource.MustParse(fmt.Sprintf("%dMi", mem)), + }, + Requests: corev1.ResourceList{ + "cpu": resource.MustParse(fmt.Sprintf("%dm", 0)), + "memory": resource.MustParse(fmt.Sprintf("%dMi", 0)), + }, + } + return resourceReq +} + +func getVolumeMount() []corev1.VolumeMount { + return []corev1.VolumeMount{ + { + Name: "etc", + MountPath: "/opt/CFF/etc", + }, + { + Name: "localtime", + MountPath: "/etc/localtime", + ReadOnly: true, + }, + { + Name: "vpcpulltrigger", + MountPath: "/var/log/CFF/vpcpulltrigger", + }, + } +} + +func getInitVolumeMount() []corev1.VolumeMount { + return []corev1.VolumeMount{ + { + Name: "etc", + MountPath: "/tmp/CFF/etc", + }, + { + Name: "cipher", + MountPath: "/opt/CFF/etc/cipher", + }, + { + Name: "ssl", + MountPath: "/opt/CFF/etc/ssl", + }, + { + Name: "kafkacert", + MountPath: "/opt/CFF/etc/kafka", + }, + { + Name: "auth", + MountPath: "/opt/CFF/etc/auth", + }, + { + Name: "etcd", + MountPath: "/opt/CFF/etc/etcd", + }, + { + Name: "vpcpulltrigger", + MountPath: "/var/log/CFF/vpcpulltrigger", + }, + { + Name: "redisdb", + MountPath: "/opt/CFF/etc/redisdb", + }, + } +} diff --git a/yuanrong/pkg/functionmanager/vpcmanager/pulltrigger_test.go b/yuanrong/pkg/functionmanager/vpcmanager/pulltrigger_test.go new file mode 100644 index 0000000..54cf16b --- /dev/null +++ b/yuanrong/pkg/functionmanager/vpcmanager/pulltrigger_test.go @@ -0,0 +1,732 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package vpcmanager - +package vpcmanager + +import ( + "fmt" + "os" + "reflect" + "strings" + "testing" + + "github.com/smartystreets/goconvey/convey" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/util/intstr" +) + +func Test_buildEnv(t *testing.T) { + type args struct { + options *CreateVpcTriggerInstruct + } + var a args + podType := "podType" + a.options = &CreateVpcTriggerInstruct{ + Image: "image", + SubnetID: "subnetId", + Namespace: "default", + PulltriggerPort: "28937", + } + var vpcNatConfByte []byte = nil + Envs := make([]corev1.EnvVar, 0) + + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_SSL_MODE", Value: "0"}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_SSL_VERIFY_CLIENT", Value: "1"}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_HOST_IP", ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "status.hostIP"}}}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_POD_NAME", ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "metadata.name"}}}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_POD_IP", ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "status.podIP"}}}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_NAMESPACE", Value: "default"}) + Envs = append(Envs, corev1.EnvVar{ + Name: "POD_TYPE", Value: "podType"}) + Envs = append(Envs, corev1.EnvVar{ + Name: "REGION_ID", Value: os.Getenv("REGION_ID")}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_SK", ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: "cff-runtime-secret"}, Key: "serverless_skey"}}}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_AK", ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: "cff-runtime-secret"}, Key: "serverless_akey"}}}) + Envs = append(Envs, corev1.EnvVar{ + Name: "TENANTLB_ADDRESS", Value: os.Getenv("TENANT_TRIGGER_ADDRESS")}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_WEBSOCKET_HOST", Value: fmt.Sprintf("%s:%s", os.Getenv("TENANT_TRIGGER_ADDRESS"), "28937")}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_SUBNETID", Value: "subnetId"}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_VPCNATCONF", Value: string([]byte{})}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_IMAGE_VERSION", Value: "image"}) + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_CLUSTER_ID", Value: "clusterId"}) + omsvcAddress := os.Getenv("SERVERLESS_OMSVC_ADDRESS") + index := strings.LastIndex(omsvcAddress, ":") + if index != -1 { + Envs = append(Envs, corev1.EnvVar{ + Name: "SERVERLESS_HOST", + Value: omsvcAddress[:index], + }) + } + Envs = append(Envs, corev1.EnvVar{ + Name: "ENABLE_DATA_PLANE_RDISPATCHER", Value: os.Getenv("ENABLE_DATA_PLANE_RDISPATCHER")}) + tests := []struct { + name string + args args + want []corev1.EnvVar + }{ + {"case1", a, Envs}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := buildEnv(podType, tt.args.options, vpcNatConfByte); !reflect.DeepEqual(got[0], tt.want[0]) { + t.Errorf("buildEnv() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_describeVolumes(t *testing.T) { + secretMode := int32(defaultSecretMode) + cipherVolume := corev1.Volume{ + Name: "cipher", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-cipher-ssl-secret", + Items: []corev1.KeyToPath{ + { + Key: "root.key", + Path: "root.key", + Mode: &secretMode, + }, + { + Key: "common_shared.key", + Path: "common_shared.key", + Mode: &secretMode, + }, + }, + }, + }, + } + sslVolume := corev1.Volume{ + Name: "ssl", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-cipher-ssl-secret", + Items: []corev1.KeyToPath{ + { + Key: "ca.crt", + Path: "trust.cer", + Mode: &secretMode, + }, + { + Key: "tls.crt", + Path: "server.cer", + Mode: &secretMode, + }, + { + Key: "tls.key.pwd", + Path: "server_key.pem", + Mode: &secretMode, + }, + { + Key: "pwd", + Path: "cert_pwd", + Mode: &secretMode, + }, + }, + }, + }, + } + kafkaVolume := corev1.Volume{ + Name: "kafkacert", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-kafka-secret", + Items: []corev1.KeyToPath{ + { + Key: "tls.crt", + Path: "kafka.cer", + Mode: &secretMode, + }, + { + Key: "tls.key", + Path: "kafka.key", + Mode: &secretMode, + }, + { + Key: "tls.ca", + Path: "kafka.ca", + Mode: &secretMode, + }, + }, + }, + }, + } + authVolume := corev1.Volume{ + Name: "auth", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-common-secret", + Items: []corev1.KeyToPath{ + { + Key: "internal_sign_ak", + Path: "internal_sign_ak", + Mode: &secretMode, + }, + { + Key: "internal_sign_sk", + Path: "internal_sign_sk", + Mode: &secretMode, + }, + }, + }, + }, + } + redisVolume := corev1.Volume{ + Name: "redisdb", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-rdb-secret", + Items: []corev1.KeyToPath{ + { + Key: "appversion", + Path: "appversion.json", + Mode: &secretMode, + }, + }, + }, + }, + } + etcdVolume := corev1.Volume{ + Name: "etcd", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-etcd-secret", + Items: []corev1.KeyToPath{ + { + Key: "appversion", + Path: "appversion.json", + Mode: &secretMode, + }, + }, + }, + }, + } + a := []corev1.Volume{ + { + Name: "localtime", + VolumeSource: corev1.VolumeSource{ + HostPath: &corev1.HostPathVolumeSource{ + Path: "/etc/localtime", + }}, + }, + cipherVolume, + sslVolume, + kafkaVolume, + authVolume, + etcdVolume, + redisVolume, + { + Name: "etc", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + { + Name: "vpcpulltrigger", + VolumeSource: corev1.VolumeSource{ + HostPath: &corev1.HostPathVolumeSource{ + Path: "/var/paas/sys/log/cff/functiongraph/vpcpulltrigger", + }}, + }, + } + tests := []struct { + name string + want []corev1.Volume + }{ + {"case1", a}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := describeVolumes(); !reflect.DeepEqual(got[0], tt.want[0]) { + t.Errorf("describeVolumes() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_genImagePullSecrets(t *testing.T) { + a := []corev1.LocalObjectReference{ + { + Name: "default-secret", + }, + } + tests := []struct { + name string + want []corev1.LocalObjectReference + }{ + {"case1", a}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := genImagePullSecrets(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("genImagePullSecrets() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_generateSecretVolumes(t *testing.T) { + secretMode := int32(defaultSecretMode) + cipherVolume := corev1.Volume{ + Name: "cipher", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-cipher-ssl-secret", + Items: []corev1.KeyToPath{ + { + Key: "root.key", + Path: "root.key", + Mode: &secretMode, + }, + { + Key: "common_shared.key", + Path: "common_shared.key", + Mode: &secretMode, + }, + }, + }, + }, + } + sslVolume := corev1.Volume{ + Name: "ssl", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-cipher-ssl-secret", + Items: []corev1.KeyToPath{ + { + Key: "ca.crt", + Path: "trust.cer", + Mode: &secretMode, + }, + { + Key: "tls.crt", + Path: "server.cer", + Mode: &secretMode, + }, + { + Key: "tls.key.pwd", + Path: "server_key.pem", + Mode: &secretMode, + }, + { + Key: "pwd", + Path: "cert_pwd", + Mode: &secretMode, + }, + }, + }, + }, + } + kafkaVolume := corev1.Volume{ + Name: "kafkacert", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-kafka-secret", + Items: []corev1.KeyToPath{ + { + Key: "tls.crt", + Path: "kafka.cer", + Mode: &secretMode, + }, + { + Key: "tls.key", + Path: "kafka.key", + Mode: &secretMode, + }, + { + Key: "tls.ca", + Path: "kafka.ca", + Mode: &secretMode, + }, + }, + }, + }, + } + authVolume := corev1.Volume{ + Name: "auth", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-common-secret", + Items: []corev1.KeyToPath{ + { + Key: "internal_sign_ak", + Path: "internal_sign_ak", + Mode: &secretMode, + }, + { + Key: "internal_sign_sk", + Path: "internal_sign_sk", + Mode: &secretMode, + }, + }, + }, + }, + } + etcdVolume := corev1.Volume{ + Name: "etcd", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-etcd-secret", + Items: []corev1.KeyToPath{ + { + Key: "appversion", + Path: "appversion.json", + Mode: &secretMode, + }, + }, + }, + }, + } + redisVolume := corev1.Volume{ + Name: "redisdb", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "cff-rdb-secret", + Items: []corev1.KeyToPath{ + { + Key: "appversion", + Path: "appversion.json", + Mode: &secretMode, + }, + }, + }, + }, + } + tests := []struct { + name string + want corev1.Volume + want1 corev1.Volume + want2 corev1.Volume + want3 corev1.Volume + want4 corev1.Volume + want5 corev1.Volume + }{ + {"case1", cipherVolume, sslVolume, kafkaVolume, authVolume, etcdVolume, redisVolume}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + volumeMap := generateSecretVolumes() + got := volumeMap[cipherVolumeType] + got1 := volumeMap[sslVolumeType] + got2 := volumeMap[kafkaVolumeType] + got3 := volumeMap[authVolumeType] + got4 := volumeMap[etcdVolumeType] + got5 := volumeMap[redisVolumeType] + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("generateSecretVolumes() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(got1, tt.want1) { + t.Errorf("generateSecretVolumes() got1 = %v, want %v", got1, tt.want1) + } + if !reflect.DeepEqual(got2, tt.want2) { + t.Errorf("generateSecretVolumes() got2 = %v, want %v", got2, tt.want2) + } + if !reflect.DeepEqual(got3, tt.want3) { + t.Errorf("generateSecretVolumes() got3 = %v, want %v", got3, tt.want3) + } + if !reflect.DeepEqual(got4, tt.want4) { + t.Errorf("generateSecretVolumes() got4 = %v, want %v", got4, tt.want4) + } + if !reflect.DeepEqual(got5, tt.want5) { + t.Errorf("generateSecretVolumes() got4 = %v, want %v", got5, tt.want5) + } + }) + } +} + +func Test_getLifecycle(t *testing.T) { + a := &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + Exec: &corev1.ExecAction{ + Command: []string{"sleep", "3"}, + }, + }, + } + tests := []struct { + name string + want *corev1.Lifecycle + }{ + {"case1", a}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getLifecycle(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getLifecycle() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getLivenessProbe(t *testing.T) { + a := &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: "/v1.0/healthcheck", + Port: intstr.IntOrString{ + IntVal: pullTriggerHealthCheckPort, + }, + Scheme: "HTTPS", + }, + }, + FailureThreshold: pullTriggerFailureThreshold, + InitialDelaySeconds: pullTriggerInitialDelaySeconds, + PeriodSeconds: pullTriggerPeriodSeconds, + TimeoutSeconds: pullTriggerTimeoutSeconds, + } + tests := []struct { + name string + want *corev1.Probe + }{ + {"case1", a}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getLivenessProbe(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getLivenessProbe() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getResources(t *testing.T) { + type args struct { + cpu int64 + mem int64 + } + var a args + a.mem = 10 + a.cpu = 10 + b := corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "cpu": resource.MustParse(fmt.Sprintf("%dm", 10)), + "memory": resource.MustParse(fmt.Sprintf("%dMi", 10)), + }, + Requests: corev1.ResourceList{ + "cpu": resource.MustParse(fmt.Sprintf("%dm", 0)), + "memory": resource.MustParse(fmt.Sprintf("%dMi", 0)), + }, + } + var c args + c.mem = 0 + c.mem = 0 + d := corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "cpu": resource.MustParse(fmt.Sprintf("%dm", 500)), + "memory": resource.MustParse(fmt.Sprintf("%dMi", 500)), + }, + Requests: corev1.ResourceList{ + "cpu": resource.MustParse(fmt.Sprintf("%dm", 0)), + "memory": resource.MustParse(fmt.Sprintf("%dMi", 0)), + }, + } + tests := []struct { + name string + args args + want corev1.ResourceRequirements + }{ + {"case1", a, b}, + {"case2", c, d}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getResources(tt.args.cpu, tt.args.mem); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getResources() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getSecurityContext(t *testing.T) { + a := &corev1.SecurityContext{ + Capabilities: &corev1.Capabilities{ + Add: []corev1.Capability{ + "NET_ADMIN", + "NET_RAW", + "DAC_OVERRIDE", + "SETGID", + "SETUID", + "CHOWN", + "FOWNER", + "KILL", + }, + Drop: []corev1.Capability{ + "ALL", + }, + }, + } + tests := []struct { + name string + want *corev1.SecurityContext + }{ + {"case1", a}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getSecurityContext(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getSecurityContext() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getVolumeMount(t *testing.T) { + a := []corev1.VolumeMount{ + { + Name: "etc", + MountPath: "/opt/CFF/etc", + }, + { + Name: "localtime", + MountPath: "/etc/localtime", + ReadOnly: true, + }, + { + Name: "vpcpulltrigger", + MountPath: "/var/log/CFF/vpcpulltrigger", + }, + } + tests := []struct { + name string + want []corev1.VolumeMount + }{ + {"case1", a}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getVolumeMount(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getVolumeMount() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMakePullTriggerDeploy(t *testing.T) { + convey.Convey("MakePullTriggerDeploy", t, func() { + triggerInfo := &CreateVpcTriggerInstruct{ + PodName: "podName", + } + vpcNatConfByte := []byte("test") + deploy := MakePullTriggerDeploy(triggerInfo, vpcNatConfByte) + convey.So(deploy.ObjectMeta.Name, convey.ShouldEqual, "podName") + }) +} + +func Test_getInitSecurityContext(t *testing.T) { + user := int64(0) + a := &corev1.SecurityContext{ + RunAsUser: &user, + } + tests := []struct { + name string + want *corev1.SecurityContext + }{ + {"case1", a}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getInitSecurityContext(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getInitSecurityContext() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getInitVolumeMount(t *testing.T) { + a := []corev1.VolumeMount{ + { + Name: "etc", + MountPath: "/tmp/CFF/etc", + }, + { + Name: "cipher", + MountPath: "/opt/CFF/etc/cipher", + }, + { + Name: "ssl", + MountPath: "/opt/CFF/etc/ssl", + }, + { + Name: "kafkacert", + MountPath: "/opt/CFF/etc/kafka", + }, + { + Name: "auth", + MountPath: "/opt/CFF/etc/auth", + }, + { + Name: "etcd", + MountPath: "/opt/CFF/etc/etcd", + }, + { + Name: "vpcpulltrigger", + MountPath: "/var/log/CFF/vpcpulltrigger", + }, + { + Name: "redisdb", + MountPath: "/opt/CFF/etc/redisdb", + }, + } + tests := []struct { + name string + want []corev1.VolumeMount + }{ + {"case1", a}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getInitVolumeMount(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getInitVolumeMount() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_describeInitContainers(t *testing.T) { + convey.Convey("describeInitContainers-default", t, func() { + triggerInfo := &CreateVpcTriggerInstruct{CPUOption: "1.5", MemoryOption: "1.5"} + containers := describeInitContainers(triggerInfo) + convey.So(containers[0].Resources.Limits.Memory().String(), convey.ShouldEqual, "500Mi") + convey.So(containers[0].Resources.Limits.Cpu().String(), convey.ShouldEqual, "500m") + }) +} diff --git a/yuanrong/pkg/functionmanager/vpcmanager/types.go b/yuanrong/pkg/functionmanager/vpcmanager/types.go new file mode 100644 index 0000000..a3283c6 --- /dev/null +++ b/yuanrong/pkg/functionmanager/vpcmanager/types.go @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package vpcmanager - +package vpcmanager + +import "k8s.io/api/core/v1" + +// Type volume type +type Type int + +const ( + cipherVolumeType Type = iota + sslVolumeType + kafkaVolumeType + authVolumeType + etcdVolumeType + redisVolumeType +) + +func getTypes() []Type { + return []Type{cipherVolumeType, sslVolumeType, kafkaVolumeType, authVolumeType, etcdVolumeType, redisVolumeType} +} + +func (t Type) keyToPath() []v1.KeyToPath { + return [...][]v1.KeyToPath{cipherKeyToPath, sslKeyToPath, kafkaKeyToPath, authKeyToPath, etcdKeyToPath, + redisKeyToPath}[t] +} + +func (t Type) secretName() string { + return [...]string{cipherSecretName, sslSecretName, kafkaSecretName, authSecretName, etcdSecretName, + redisSecretName}[t] +} + +func (t Type) name() string { + return [...]string{cipherName, sslName, kafkaName, authName, etcdName, redisName}[t] +} diff --git a/yuanrong/pkg/functionmanager/vpcmanager/volumeconstant.go b/yuanrong/pkg/functionmanager/vpcmanager/volumeconstant.go new file mode 100644 index 0000000..289a9b2 --- /dev/null +++ b/yuanrong/pkg/functionmanager/vpcmanager/volumeconstant.go @@ -0,0 +1,119 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package vpcmanager - +package vpcmanager + +import "k8s.io/api/core/v1" + +const ( + cipherName = "cipher" + sslName = "ssl" + kafkaName = "kafkacert" + authName = "auth" + etcdName = "etcd" + redisName = "redisdb" +) + +const ( + cipherSecretName = "cff-cipher-ssl-secret" + sslSecretName = "cff-cipher-ssl-secret" + kafkaSecretName = "cff-kafka-secret" + authSecretName = "cff-common-secret" + etcdSecretName = "cff-etcd-secret" + redisSecretName = "cff-rdb-secret" +) + +var ( + secretMode = int32(defaultSecretMode) + cipherKeyToPath = []v1.KeyToPath{ + { + Key: "root.key", + Path: "root.key", + Mode: &secretMode, + }, + { + Key: "common_shared.key", + Path: "common_shared.key", + Mode: &secretMode, + }, + } + sslKeyToPath = []v1.KeyToPath{ + { + Key: "ca.crt", + Path: "trust.cer", + Mode: &secretMode, + }, + { + Key: "tls.crt", + Path: "server.cer", + Mode: &secretMode, + }, + { + Key: "tls.key.pwd", + Path: "server_key.pem", + Mode: &secretMode, + }, + { + Key: "pwd", + Path: "cert_pwd", + Mode: &secretMode, + }, + } + kafkaKeyToPath = []v1.KeyToPath{ + { + Key: "tls.crt", + Path: "kafka.cer", + Mode: &secretMode, + }, + { + Key: "tls.key", + Path: "kafka.key", + Mode: &secretMode, + }, + { + Key: "tls.ca", + Path: "kafka.ca", + Mode: &secretMode, + }, + } + authKeyToPath = []v1.KeyToPath{ + { + Key: "internal_sign_ak", + Path: "internal_sign_ak", + Mode: &secretMode, + }, + { + Key: "internal_sign_sk", + Path: "internal_sign_sk", + Mode: &secretMode, + }, + } + etcdKeyToPath = []v1.KeyToPath{ + { + Key: "appversion", + Path: "appversion.json", + Mode: &secretMode, + }, + } + redisKeyToPath = []v1.KeyToPath{ + { + Key: "appversion", + Path: "appversion.json", + Mode: &secretMode, + }, + } +) diff --git a/yuanrong/pkg/functionscaler/config/config.go b/yuanrong/pkg/functionscaler/config/config.go new file mode 100644 index 0000000..9e506a3 --- /dev/null +++ b/yuanrong/pkg/functionscaler/config/config.go @@ -0,0 +1,218 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config - +package config + +import ( + "encoding/json" + "fmt" + "os" + + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/sts" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/faas_common/wisecloudtool/serviceaccount" + "yuanrong/pkg/functionscaler/types" +) + +const ( + // MetaEtcdPwdKey - + MetaEtcdPwdKey = "metaEtcdPwd" + // DockerRootPathEnv - + DockerRootPathEnv = "DOCKER_ROOT_DIR" + defaultDockerRootPath = "/var/lib/docker" + defaultFaasschedulerSTScertPath = "/opt/certs/HMSClientCloudAccelerateService/" + + "HMSCaaSYuanRongWorkerManager/HMSCaaSYuanRongWorkerManager.ini" + defaultPredictGroupWindow = 15 * 60 * 1000 +) + +var ( + // GlobalConfig is the global config + GlobalConfig types.Configuration + configEnvKey = "SCHEDULER_CONFIG" + dockerRootPrefix = []byte("Docker Root Dir: ") +) + +// InitModuleConfig initializes config for module +func InitModuleConfig() error { + config, err := loadConfigFromEnv() + if err != nil { + return err + } + GlobalConfig = *config + log.GetLogger().Infof("succeed to init module config") + return nil +} + +func loadConfigFromEnv() (*types.Configuration, error) { + configJSON := os.Getenv(configEnvKey) + config := &types.Configuration{} + err := json.Unmarshal([]byte(configJSON), config) + if err != nil { + return nil, err + } + return config, nil +} + +// InitConfig will initialize global config +func InitConfig(configData []byte) error { + GlobalConfig = types.Configuration{} + err := json.Unmarshal(configData, &GlobalConfig) + if err != nil { + return err + } + return loadFunctionConfig(&GlobalConfig) +} + +func loadFunctionConfig(GlobalConfig *types.Configuration) error { + setETCDConfig(GlobalConfig) + decryptEnvMap, err := localauth.GetDecryptFromEnv() + if err != nil { + log.GetLogger().Errorf("get decrypt from env error: %v", err) + return err + } + setDecryptPwd(decryptEnvMap, GlobalConfig) + + if _, err = govalidator.ValidateStruct(GlobalConfig); err != nil { + return err + } + err = setAlarmEnv() + if err != nil { + return err + } + + if GlobalConfig.DockerRootPath != "" { + if err = os.Setenv(DockerRootPathEnv, GlobalConfig.DockerRootPath); err != nil { + log.GetLogger().Warnf("cannot set env DOCKER_ROOT_DIR") + } + } else { + if err = os.Setenv(DockerRootPathEnv, defaultDockerRootPath); err != nil { + log.GetLogger().Warnf("cannot set default env DOCKER_ROOT_DIR") + } + } + if GlobalConfig.RawStsConfig.StsEnable { + if err := sts.InitStsSDK(GlobalConfig.RawStsConfig.ServerConfig); err != nil { + log.GetLogger().Errorf("failed to init sts sdk, err: %s", err.Error()) + return err + } + if err = os.Setenv(sts.EnvSTSEnable, "true"); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", sts.EnvSTSEnable, err.Error()) + return err + } + } + if len(GlobalConfig.Scenario) == 0 { + GlobalConfig.Scenario = types.ScenarioWiseCloud + } + if GlobalConfig.PredictGroupWindow == 0 { + GlobalConfig.PredictGroupWindow = defaultPredictGroupWindow + } + if GlobalConfig.SccConfig.Enable && crypto.InitializeSCC(GlobalConfig.SccConfig) != nil { + return fmt.Errorf("failed to initialize scc") + } + err = setServiceAccountJwt(GlobalConfig) + if err != nil { + return fmt.Errorf("failed to set serviceaccount jwt config %s", err.Error()) + } + return nil +} + +func setServiceAccountJwt(cfg *types.Configuration) error { + if cfg.RawStsConfig.StsEnable && len(cfg.ServiceAccountJwt.ServiceAccountKeyStr) > 0 { + var err error + cfg.ServiceAccountJwt.ServiceAccount, err = + serviceaccount.ParseServiceAccount(cfg.ServiceAccountJwt.ServiceAccountKeyStr) + if err != nil { + return err + } + } + if cfg.ServiceAccountJwt.TlsConfig != nil && + len(cfg.ServiceAccountJwt.TlsConfig.TlsCipherSuitesStr) > 0 { + var err error + cfg.ServiceAccountJwt.TlsConfig.TlsCipherSuites, err = + serviceaccount.ParseTlsCipherSuites(cfg.ServiceAccountJwt.TlsConfig.TlsCipherSuitesStr) + if err != nil { + return err + } + } + return nil +} + +func setETCDConfig(GlobalConfig *types.Configuration) { + if GlobalConfig == nil { + return + } + if GlobalConfig.RouterETCDConfig.UseSecret { + etcd3.SetETCDTLSConfig(&GlobalConfig.RouterETCDConfig) + } + if GlobalConfig.MetaETCDConfig.UseSecret { + etcd3.SetETCDTLSConfig(&GlobalConfig.MetaETCDConfig) + } +} + +func setDecryptPwd(decryptEnvMap map[string]string, config *types.Configuration) { + _, ok := decryptEnvMap[MetaEtcdPwdKey] + if !ok { + return + } + if decryptEnvMap[MetaEtcdPwdKey] != "" { + config.MetaETCDConfig.Password = decryptEnvMap[MetaEtcdPwdKey] + decryptEnvMap[MetaEtcdPwdKey] = "" + } +} + +// InitEtcd - init router etcd and meta etcd +func InitEtcd(stopCh <-chan struct{}) error { + if &GlobalConfig == nil { + return fmt.Errorf("config is not initialized") + } + if err := etcd3.InitRouterEtcdClient(GlobalConfig.RouterETCDConfig, GlobalConfig.AlarmConfig, stopCh); err != nil { + return fmt.Errorf("faaSScheduler failed to init route etcd: %s", err.Error()) + } + + if err := etcd3.InitMetaEtcdClient(GlobalConfig.MetaETCDConfig, GlobalConfig.AlarmConfig, stopCh); err != nil { + return fmt.Errorf("faaSScheduler failed to init metadata etcd: %s", err.Error()) + } + return nil +} + +// ClearSensitiveInfo - +func ClearSensitiveInfo() { + if &GlobalConfig == nil { + return + } + utils.ClearStringMemory(GlobalConfig.MetaETCDConfig.Password) + utils.ClearStringMemory(GlobalConfig.RouterETCDConfig.Password) +} + +func setAlarmEnv() error { + if &GlobalConfig == nil || !GlobalConfig.AlarmConfig.EnableAlarm { + log.GetLogger().Infof("enable alarm is false") + return nil + } + utils.SetClusterNameEnv(GlobalConfig.ClusterName) + alarm.SetAlarmEnv(GlobalConfig.AlarmConfig.AlarmLogConfig) + alarm.SetXiangYunFourConfigEnv(GlobalConfig.AlarmConfig.XiangYunFourConfig) + err := alarm.SetPodIP() + if err != nil { + return err + } + return nil +} diff --git a/yuanrong/pkg/functionscaler/config/config_test.go b/yuanrong/pkg/functionscaler/config/config_test.go new file mode 100644 index 0000000..2ae79a7 --- /dev/null +++ b/yuanrong/pkg/functionscaler/config/config_test.go @@ -0,0 +1,250 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config - +package config + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/sts" + "yuanrong/pkg/common/faas_common/sts/raw" + "yuanrong/pkg/functionscaler/types" +) + +func TestInitConfig(t *testing.T) { + cfg := &types.Configuration{ + CPU: 999, + Memory: 999, + AutoScaleConfig: types.AutoScaleConfig{ + SLAQuota: 1, + ScaleDownTime: 1, + BurstScaleNum: 1, + }, + LeaseSpan: 1, + RouterETCDConfig: etcd3.EtcdConfig{ + Password: "321", + }, + MetaETCDConfig: etcd3.EtcdConfig{ + Password: "123", + }, + SchedulerNum: 1, + DockerRootPath: "dockerRootPath", + RawStsConfig: raw.StsConfig{ + StsEnable: true, + SensitiveConfigs: raw.SensitiveConfigs{}, + ServerConfig: raw.ServerConfig{}, + MgmtServerConfig: raw.MgmtServerConfig{}, + }, + } + p1 := gomonkey.ApplyFunc(sts.InitStsSDK, func(serverCfg raw.ServerConfig) error { + return nil + }) + cfgByte, _ := json.Marshal(cfg) + convey.Convey("success", t, func() { + err := InitConfig(cfgByte) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("DockerRootPath is empty", t, func() { + cfg := &types.Configuration{ + CPU: 999, + Memory: 999, + AutoScaleConfig: types.AutoScaleConfig{ + SLAQuota: 1, + ScaleDownTime: 1, + BurstScaleNum: 1, + }, + LeaseSpan: 1, + RouterETCDConfig: etcd3.EtcdConfig{ + Password: "321", + }, + MetaETCDConfig: etcd3.EtcdConfig{ + Password: "123", + }, + SchedulerNum: 1, + DockerRootPath: "", + RawStsConfig: raw.StsConfig{ + StsEnable: true, + SensitiveConfigs: raw.SensitiveConfigs{}, + ServerConfig: raw.ServerConfig{}, + MgmtServerConfig: raw.MgmtServerConfig{}, + }, + } + cfgByte, _ := json.Marshal(cfg) + err := InitConfig(cfgByte) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("GetDecryptFromEnv success", t, func() { + defer gomonkey.ApplyFunc(localauth.GetDecryptFromEnv, func() (map[string]string, error) { + return map[string]string{"metaEtcdPwd": "qwerdf"}, nil + }).Reset() + err := InitConfig(cfgByte) + convey.So(err, convey.ShouldBeNil) + convey.So(GlobalConfig.MetaETCDConfig.Password, convey.ShouldEqual, "qwerdf") + }) + convey.Convey("GetDecryptFromEnv error", t, func() { + defer gomonkey.ApplyFunc(localauth.GetDecryptFromEnv, func() (map[string]string, error) { + return nil, fmt.Errorf("get decrypt from env error") + }).Reset() + err := InitConfig(cfgByte) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("Unmarshal error", t, func() { + err := InitConfig(cfgByte[2:20]) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("ValidateStruct error", t, func() { + cfg.MetaETCDConfig = etcd3.EtcdConfig{} + cfgByte, _ := json.Marshal(cfg) + err := InitConfig(cfgByte) + convey.So(err, convey.ShouldNotBeNil) + }) + p1.Reset() + + convey.Convey("sts init error", t, func() { + p2 := gomonkey.ApplyFunc(sts.InitStsSDK, func(serverCfg raw.ServerConfig) error { + return errors.New("init sts error") + }) + err := InitConfig(cfgByte) + convey.So(err, convey.ShouldNotBeNil) + p2.Reset() + }) + + convey.Convey("sts init error", t, func() { + defer gomonkey.ApplyFunc(sts.InitStsSDK, func(serverCfg raw.ServerConfig) error { + return nil + }).Reset() + defer gomonkey.ApplyFunc(os.Setenv, func(key, value string) error { + return errors.New("set env error") + }).Reset() + err := InitConfig(cfgByte) + convey.So(err, convey.ShouldNotBeNil) + }) + +} + +func TestInitModuleConfig(t *testing.T) { + cfg := &types.Configuration{ + CPU: 999, + Memory: 999, + AutoScaleConfig: types.AutoScaleConfig{ + SLAQuota: 1, + ScaleDownTime: 1, + BurstScaleNum: 1, + }, + LeaseSpan: 1, + RouterETCDConfig: etcd3.EtcdConfig{ + Password: "321", + }, + MetaETCDConfig: etcd3.EtcdConfig{ + Password: "123", + }, + SchedulerNum: 1, + } + cfgByte, _ := json.Marshal(cfg) + convey.Convey("success", t, func() { + defer gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return string(cfgByte) + }).Reset() + err := InitModuleConfig() + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("Unmarshal error", t, func() { + defer gomonkey.ApplyFunc(os.Getenv, func(key string) string { + return "{" + }).Reset() + err := InitModuleConfig() + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestInitEtcd(t *testing.T) { + convey.Convey("Test InitEtcd", t, func() { + convey.Convey("global config is nil", func() { + rawGConfig := GlobalConfig + GlobalConfig = types.Configuration{} + stopCh := make(chan struct{}) + err := InitEtcd(stopCh) + GlobalConfig = rawGConfig + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("init route etcd error", func() { + rawGConfig := GlobalConfig + GlobalConfig = types.Configuration{ + RouterETCDConfig: etcd3.EtcdConfig{}, + } + err := InitEtcd(nil) + GlobalConfig = rawGConfig + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("init metadata etcd error", func() { + rawGConfig := GlobalConfig + GlobalConfig = types.Configuration{ + MetaETCDConfig: etcd3.EtcdConfig{}, + } + err := InitEtcd(nil) + GlobalConfig = rawGConfig + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestClearSensitiveInfo(t *testing.T) { + convey.Convey("Test ClearSensitiveInfo", t, func() { + convey.Convey("global config is nil", func() { + rawGConfig := GlobalConfig + GlobalConfig = types.Configuration{} + ClearSensitiveInfo() + GlobalConfig = rawGConfig + }) + convey.Convey("clear sensitive Info", func() { + rawGConfig := GlobalConfig + GlobalConfig = types.Configuration{ + RouterETCDConfig: etcd3.EtcdConfig{}, + MetaETCDConfig: etcd3.EtcdConfig{}, + } + ClearSensitiveInfo() + GlobalConfig = rawGConfig + }) + }) +} + +func TestSetAlarmEnv(t *testing.T) { + convey.Convey("Test SetAlarmEnv", t, func() { + convey.Convey("success", func() { + rawGConfig := GlobalConfig + GlobalConfig = types.Configuration{ + ClusterName: "cluster1", + AlarmConfig: alarm.Config{ + EnableAlarm: true, + }, + } + err := setAlarmEnv() + GlobalConfig = rawGConfig + convey.So(err, convey.ShouldBeNil) + }) + }) +} diff --git a/yuanrong/pkg/functionscaler/config/hotload_config.go b/yuanrong/pkg/functionscaler/config/hotload_config.go new file mode 100644 index 0000000..870eb1e --- /dev/null +++ b/yuanrong/pkg/functionscaler/config/hotload_config.go @@ -0,0 +1,119 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config - +package config + +import ( + "encoding/json" + "io/ioutil" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/monitor" + "yuanrong/pkg/functionscaler/types" +) + +const ( + // ConfigFilePath defines config file path of frontend + configFilePath = "/home/sn/config/config.json" +) + +var ( + configWatcher monitor.FileWatcher + configChangedCallback ChangedCallback +) + +// ChangedCallback config change callback func +type ChangedCallback func() + +// WatchConfig - +func WatchConfig(configPath string, stopCh <-chan struct{}, callback ChangedCallback) error { + + watcher, err := monitor.CreateFileWatcher(stopCh) + if err != nil { + return err + } + configWatcher = watcher + configChangedCallback = callback + configWatcher.RegisterCallback(configPath, hotLoadConfig) + return nil +} + +func hotLoadConfig(filename string, opType monitor.OpType) { + log.GetLogger().Infof("file %s hot load start", filename) + config, err := loadConfig(filename) + if err != nil { + log.GetLogger().Errorf("hotLoadConfig failed file: %s, opType: %d, err: %s", + filename, opType, err.Error()) + return + } + hotLoadMetaEtcdConfig(config) + hotLoadRouterEtcdConfig(config) + hotLoadAutoScaleConfig(config) + if configChangedCallback != nil { + configChangedCallback() + } +} + +func loadConfig(configPath string) (*types.Configuration, error) { + data, err := ioutil.ReadFile(configPath) + if err != nil { + log.GetLogger().Errorf("read file error, file path is %s", configPath) + return nil, err + } + config := &types.Configuration{} + err = json.Unmarshal(data, config) + if err != nil { + log.GetLogger().Errorf("failed to parse the config data: %s", err) + return nil, err + } + err = loadFunctionConfig(config) + if err != nil { + return nil, err + } + return config, err +} + +func hotLoadAutoScaleConfig(newAllConfig *types.Configuration) { + if newAllConfig.AutoScaleConfig.SLAQuota > 0 && newAllConfig.AutoScaleConfig.BurstScaleNum > 0 && + newAllConfig.AutoScaleConfig.ScaleDownTime > 0 { + GlobalConfig.AutoScaleConfig = newAllConfig.AutoScaleConfig + autoScaleConfig := GlobalConfig.AutoScaleConfig + log.GetLogger().Infof("autoScaleConfig update, SLAQuota: %d,BurstScaleNum: %d,ScaleDownTime: %d ", + autoScaleConfig.SLAQuota, autoScaleConfig.BurstScaleNum, autoScaleConfig.ScaleDownTime) + } + return +} + +func hotLoadMetaEtcdConfig(newAllConfig *types.Configuration) { + if newAllConfig.MetaETCDConfig.Servers != nil && len(newAllConfig.MetaETCDConfig.Servers) > 0 { + newConfig := newAllConfig.MetaETCDConfig + oldConfig := GlobalConfig.MetaETCDConfig + oldConfig.Servers = newConfig.Servers + log.GetLogger().Infof("etcd serverList update, new: %v", newConfig.Servers) + } + return +} + +func hotLoadRouterEtcdConfig(newAllConfig *types.Configuration) { + if newAllConfig.RouterETCDConfig.Servers != nil && len(newAllConfig.RouterETCDConfig.Servers) > 0 { + newConfig := newAllConfig.RouterETCDConfig + oldConfig := GlobalConfig.RouterETCDConfig + oldConfig.Servers = newConfig.Servers + log.GetLogger().Infof("etcd serverList update, new: %v", newConfig.Servers) + } + return +} diff --git a/yuanrong/pkg/functionscaler/config/hotload_config_test.go b/yuanrong/pkg/functionscaler/config/hotload_config_test.go new file mode 100644 index 0000000..b98d628 --- /dev/null +++ b/yuanrong/pkg/functionscaler/config/hotload_config_test.go @@ -0,0 +1,225 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config - +package config + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/monitor" + "yuanrong/pkg/functionscaler/types" +) + +const ( + serverPort = "8888" + defaultMetaEtcdCafile = "/home/sn/resource/ca/ca.pem" + defaultMetaEtcdCertfile = "/home/sn/resource/ca/cert.pem" + defaultMetaEtcdKeyfile = "/home/sn/resource/ca/key.pem" + + defaultRouterEtcdCafile = "/home/sn/resource/routerEtcd/ca.pem" + defaultRouterEtcdCertfile = "/home/sn/resource/routerEtcd/cert.pem" + defaultRouterEtcdKeyfile = "/home/sn/resource/routerEtcd/key.pem" +) + +var ( + watcher *monitor.MockFileWatcher + maxTimeout = 100*24*3600 + 1 + testConfig = types.Configuration{ + CPU: 5, + Memory: 100, + AutoScaleConfig: types.AutoScaleConfig{ + SLAQuota: 10, + ScaleDownTime: 1, + BurstScaleNum: 1, + }, + LeaseSpan: 1, + MetaETCDConfig: etcd3.EtcdConfig{ + Servers: []string{"127.0.0.1:2379"}, + User: "root", + Password: "0000", + CaFile: defaultMetaEtcdCafile, + CertFile: defaultMetaEtcdCertfile, + KeyFile: defaultMetaEtcdKeyfile, + SslEnable: true, + }, + RouterETCDConfig: etcd3.EtcdConfig{ + Servers: []string{"127.0.0.2:2379"}, + User: "root", + Password: "1111", + CaFile: defaultRouterEtcdCafile, + CertFile: defaultRouterEtcdCertfile, + KeyFile: defaultRouterEtcdKeyfile, + SslEnable: true, + }, + } + + testConfig2 = &types.Configuration{ + CPU: 5, + Memory: 100, + AutoScaleConfig: types.AutoScaleConfig{ + SLAQuota: 10, + ScaleDownTime: 1, + BurstScaleNum: 1, + }, + LeaseSpan: 1, + MetaETCDConfig: etcd3.EtcdConfig{ + SslEnable: true, + }, + RouterETCDConfig: etcd3.EtcdConfig{ + SslEnable: true, + }, + } +) + +func createMockFileWatcher(stopCh <-chan struct{}) (monitor.FileWatcher, error) { + watcher = &monitor.MockFileWatcher{ + Callbacks: map[string]monitor.FileChangedCallback{}, + StopCh: stopCh, + EventChan: make(chan string, 1), + } + return watcher, nil +} + +func TestWatchConfig(t *testing.T) { + convey.Convey("TestWatchConfig error", t, func() { + defer gomonkey.ApplyFunc(monitor.CreateFileWatcher, func(stopCh <-chan struct{}) (monitor.FileWatcher, error) { + return nil, fmt.Errorf("ioutil.ReadFile error") + }).Reset() + stopCh := make(chan struct{}, 1) + err := WatchConfig(configFilePath, stopCh, nil) + if err != nil { + return + } + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestHotLoadConfig(t *testing.T) { + convey.Convey("TestHotLoadConfig OK", t, func() { + patches := gomonkey.NewPatches() + data, _ := json.Marshal(testConfig) + patches.ApplyFunc(ioutil.ReadFile, func() ([]byte, error) { + fmt.Println(string(data)) + return data, nil + }) + defer func() { + patches.Reset() + }() + + GlobalConfig = testConfig + + monitor.SetCreator(createMockFileWatcher) + + stopCh := make(chan struct{}, 1) + callbackChan := make(chan int, 1) + WatchConfig(configFilePath, stopCh, func() { + callbackChan <- 1 + fmt.Println("do call back") + }) + + watcher.EventChan <- configFilePath + + <-callbackChan + convey.So(GlobalConfig.MetaETCDConfig.SslEnable, convey.ShouldEqual, + testConfig.MetaETCDConfig.SslEnable) + convey.So(GlobalConfig.RouterETCDConfig.SslEnable, convey.ShouldEqual, + testConfig.RouterETCDConfig.SslEnable) + close(stopCh) + }) + + convey.Convey("TestHotLoadConfig OK 2", t, func() { + patches := gomonkey.NewPatches() + data, _ := json.Marshal(testConfig2) + patches.ApplyFunc(ioutil.ReadFile, func() ([]byte, error) { + fmt.Println(string(data)) + return data, nil + }) + defer func() { + patches.Reset() + }() + + GlobalConfig = testConfig + monitor.SetCreator(createMockFileWatcher) + + stopCh := make(chan struct{}, 1) + callbackChan := make(chan int, 1) + WatchConfig(configFilePath, stopCh, func() { + callbackChan <- 1 + fmt.Println("do call back") + }) + + watcher.EventChan <- configFilePath + + <-callbackChan + convey.So(GlobalConfig.MetaETCDConfig.SslEnable, convey.ShouldEqual, + testConfig.MetaETCDConfig.SslEnable) + convey.So(GlobalConfig.RouterETCDConfig.SslEnable, convey.ShouldEqual, + testConfig.RouterETCDConfig.SslEnable) + close(stopCh) + }) +} + +func TestLoadConfig(t *testing.T) { + convey.Convey("TestLoadConfig error 0", t, func() { + defer gomonkey.ApplyFunc(ioutil.ReadFile, func() ([]byte, error) { + return nil, fmt.Errorf("ioutil.ReadFile error") + }).Reset() + _, err := loadConfig(configFilePath) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("TestLoadConfig error 1", t, func() { + patches := gomonkey.NewPatches() + data, _ := json.Marshal(testConfig) + patches.ApplyFunc(ioutil.ReadFile, func() ([]byte, error) { + fmt.Println(string(data)) + return data, nil + }) + patches.ApplyFunc(json.Unmarshal, func(data []byte, v any) error { + return fmt.Errorf("json.Unmarshal error") + }) + defer func() { + patches.Reset() + }() + _, err := loadConfig(configFilePath) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("TestLoadConfig error 2", t, func() { + patches := gomonkey.NewPatches() + data, _ := json.Marshal(testConfig) + patches.ApplyFunc(ioutil.ReadFile, func() ([]byte, error) { + fmt.Println(string(data)) + return data, nil + }) + patches.ApplyFunc(loadFunctionConfig, func(config *types.Configuration) error { + return fmt.Errorf("loadFunctionConfig error") + }) + defer func() { + patches.Reset() + }() + _, err := loadConfig(configFilePath) + convey.So(err, convey.ShouldNotBeNil) + }) +} diff --git a/yuanrong/pkg/functionscaler/dynamicconfigmanager/dynamicconfigmanager.go b/yuanrong/pkg/functionscaler/dynamicconfigmanager/dynamicconfigmanager.go new file mode 100644 index 0000000..6272359 --- /dev/null +++ b/yuanrong/pkg/functionscaler/dynamicconfigmanager/dynamicconfigmanager.go @@ -0,0 +1,113 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package dynamicconfigmanager - +package dynamicconfigmanager + +import ( + "fmt" + "strings" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/k8sclient" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/urnutils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" +) + +const ( + // DynamicConfigSuffix dynamic config suffix + DynamicConfigSuffix = "-dynamic-config" + // DynamicConfigMapName default dynamic configmap name + DynamicConfigMapName = "dynamic-config" + // DefaultDynamicConfigPath default dynamic config path + DefaultDynamicConfigPath = "/opt/dynamic-config" + // DynamicConfigMapNameKey default dynamic configmap key + DynamicConfigMapNameKey = "dynamic-config.properties" +) + +// HandleUpdateFunctionEvent DynamicConfigManager hande function update event +func HandleUpdateFunctionEvent(funcSpec *types.FunctionSpecification) { + if config.GlobalConfig.Scenario == types.ScenarioWiseCloud { + return + } + log.GetLogger().Infof("handling dynamic config update for function %s", funcSpec.FuncKey) + // the dynamic configuration of the latest version function may change from + // enabled to disable. In this case, need to delete the old configmap. + if !funcSpec.ExtendedMetaData.DynamicConfig.Enabled { + deleteConfigMap(funcSpec) + return + } + createOrUpdateConfigMap(funcSpec) +} + +// HandleDeleteFunctionEvent DynamicConfigManager hande function delete event +func HandleDeleteFunctionEvent(funcSpec *types.FunctionSpecification) { + if config.GlobalConfig.Scenario == types.ScenarioWiseCloud { + return + } + log.GetLogger().Infof("handling dynamic config delete for function %s", funcSpec.FuncKey) + deleteConfigMap(funcSpec) +} + +func createOrUpdateConfigMap(funcSpec *types.FunctionSpecification) { + log.GetLogger().Infof("dynamic config is enabled, start to create or update configmap for "+ + "function %s", funcSpec.FuncKey) + cmName := urnutils.CrNameByURN(funcSpec.FuncMetaData.FunctionVersionURN) + "-dynamic-config" + nameSpace := config.GlobalConfig.NameSpace + if nameSpace == "" { + nameSpace = constant.DefaultNameSpace + } + expectedConfigmap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmName, + Namespace: nameSpace, + }, + Data: buildDynamicConfigContent(funcSpec), + } + err := k8sclient.GetkubeClient().CreateOrUpdateConfigMap(expectedConfigmap) + if err != nil { + log.GetLogger().Errorf("Create or update Configmap failed!Configmap.Name: %s, Error: %s", + cmName, err.Error()) + } +} + +func deleteConfigMap(funcSpec *types.FunctionSpecification) { + cmName := urnutils.CrNameByURN(funcSpec.FuncMetaData.FunctionVersionURN) + "-dynamic-config" + nameSpace := config.GlobalConfig.NameSpace + if nameSpace == "" { + nameSpace = constant.DefaultNameSpace + } + err := k8sclient.GetkubeClient().DeleteK8sConfigMap(nameSpace, cmName) + if err != nil { + log.GetLogger().Errorf("Delete Configmap failed!Configmap.Name: %s, Error: %s", + cmName, err.Error()) + } +} + +func buildDynamicConfigContent(funcSpec *types.FunctionSpecification) map[string]string { + var dynamicConfigContent = make(map[string]string) + var builder strings.Builder + for _, configKV := range funcSpec.ExtendedMetaData.DynamicConfig.ConfigContent { + builder.WriteString(fmt.Sprintf("%s=%s\n", configKV.Name, configKV.Value)) + } + dynamicConfigContent[DynamicConfigMapNameKey] = builder.String() + return dynamicConfigContent +} diff --git a/yuanrong/pkg/functionscaler/dynamicconfigmanager/dynamicconfigmanager_test.go b/yuanrong/pkg/functionscaler/dynamicconfigmanager/dynamicconfigmanager_test.go new file mode 100644 index 0000000..355a2df --- /dev/null +++ b/yuanrong/pkg/functionscaler/dynamicconfigmanager/dynamicconfigmanager_test.go @@ -0,0 +1,40 @@ +package dynamicconfigmanager + +import ( + "errors" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + v1 "k8s.io/api/core/v1" + + "yuanrong/pkg/common/faas_common/k8sclient" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/types" +) + +func TestHandleUpdateFunctionEvent(t *testing.T) { + convey.Convey("HandleUpdateFunctionEvent", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&k8sclient.KubeClient{}), "CreateOrUpdateConfigMap", + func(_ *k8sclient.KubeClient, c *v1.ConfigMap) error { + return errors.New("CreateOrUpdateConfigMap configmap error") + }).Reset() + HandleUpdateFunctionEvent(&types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{DynamicConfig: commonTypes.DynamicConfigEvent{ + Enabled: true, + ConfigContent: []commonTypes.KV{{Name: "config1", Value: "value1"}}, + }}, + }) + }) +} + +func TestHandleDeleteFunctionEvent(t *testing.T) { + convey.Convey("HandleDeleteFunctionEvent", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&k8sclient.KubeClient{}), "DeleteK8sConfigMap", + func(_ *k8sclient.KubeClient, namespace string, configMapName string) error { + return errors.New("delete configmap error") + }).Reset() + HandleDeleteFunctionEvent(&types.FunctionSpecification{}) + }) +} diff --git a/yuanrong/pkg/functionscaler/faasscheduler.go b/yuanrong/pkg/functionscaler/faasscheduler.go new file mode 100644 index 0000000..59ad4f1 --- /dev/null +++ b/yuanrong/pkg/functionscaler/faasscheduler.go @@ -0,0 +1,1151 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package functionscaler - +package functionscaler + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + _ "net/http/pprof" + "os" + "strconv" + "strings" + "sync" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/healthlog" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + "yuanrong/pkg/common/faas_common/trafficlimit" + commonTypes "yuanrong/pkg/common/faas_common/types" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/instancepool" + "yuanrong/pkg/functionscaler/lease" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/rollout" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + defaultChanSize = 1000 + minArgsNum = 1 + validArgsNum = 2 + libruntimeValidArgsNum = 4 + validInsOpLen = 2 + waitForETCDList = 10 * time.Millisecond + frontendNodePort = "31222" + logFileName = "faas-scheduler" + stateFuncKeyLen = 2 +) + +var ( + // insOpSeparator stands for separator of instance operation + insOpSeparator = "#" + // insOpCreate stands for instance create operation + insOpCreate InstanceOperation = "create" + // insOpDelete stands for instance delete operation + insOpDelete InstanceOperation = "delete" + // insOpAcquire stands for instance acquire operation + insOpAcquire InstanceOperation = "acquire" + // insOpRetain stands for instance retain operation + insOpRetain InstanceOperation = "retain" + // insOpBatchRetain stands for instance batch retain operation + insOpBatchRetain InstanceOperation = "batchRetain" + // insOpRelease stands for instance release operation + insOpRelease InstanceOperation = "release" + // insOpRelease stands for instance release operation + insOpRollout InstanceOperation = "rollout" + // insOpUnknown stands for unknown instance operation + insOpUnknown InstanceOperation = "unknown" + // stateSplitStr - + stateSplitStr = ";" + + // InstanceRequirementPoolLabel - key of poolLabel + instanceRequirementPoolLabel = "poolLabel" +) + +// InstanceOperation defines instance operations +type InstanceOperation string + +// StateOperation defines state instance operations +type StateOperation string + +// FaaSScheduler manages instances for faas functions +type FaaSScheduler struct { + PoolManager *instancepool.PoolManager + funcSpecCh chan registry.SubEvent + insSpecCh chan registry.SubEvent + insConfigCh chan registry.SubEvent + aliasSpecCh chan registry.SubEvent + schedulerCh chan registry.SubEvent + rolloutConfigCh chan registry.SubEvent + + leaseInterval time.Duration + + allocRecord sync.Map + sync.RWMutex +} + +var globalFaaSScheduler *FaaSScheduler + +// NewFaaSScheduler will create a FaaSScheduler +func NewFaaSScheduler(stopCh <-chan struct{}) *FaaSScheduler { + leaseInterval := time.Duration(config.GlobalConfig.LeaseSpan) * time.Millisecond + if leaseInterval < types.MinLeaseInterval { + leaseInterval = types.MinLeaseInterval + } + go func() { + if config.GlobalConfig.PprofAddr == "" { + return + } + err := http.ListenAndServe(config.GlobalConfig.PprofAddr, nil) + if err != nil { + return + } + }() + faasScheduler := &FaaSScheduler{ + PoolManager: instancepool.NewPoolManager(stopCh), + funcSpecCh: make(chan registry.SubEvent, defaultChanSize), + insSpecCh: make(chan registry.SubEvent, defaultChanSize), + insConfigCh: make(chan registry.SubEvent, defaultChanSize), + aliasSpecCh: make(chan registry.SubEvent, defaultChanSize), + schedulerCh: make(chan registry.SubEvent, defaultChanSize), + rolloutConfigCh: make(chan registry.SubEvent, defaultChanSize), + leaseInterval: leaseInterval, + } + registry.GlobalRegistry.SubscribeFuncSpec(faasScheduler.funcSpecCh) + registry.GlobalRegistry.SubscribeInsSpec(faasScheduler.insSpecCh) + registry.GlobalRegistry.SubscribeInsConfig(faasScheduler.insConfigCh) + registry.GlobalRegistry.SubscribeAliasSpec(faasScheduler.aliasSpecCh) + registry.GlobalRegistry.SubscribeSchedulerProxy(faasScheduler.schedulerCh) + registry.GlobalRegistry.SubscribeRolloutConfig(faasScheduler.rolloutConfigCh) + go faasScheduler.processFunctionSubscription() + go faasScheduler.processInstanceSubscription() + go faasScheduler.processInstanceConfigSubscription() + go faasScheduler.processAliasSpecSubscription() + go faasScheduler.processSchedulerProxySubscription() + go faasScheduler.processRolloutConfigSubscription() + go healthlog.PrintHealthLog(stopCh, printInputLog, logFileName) + if config.GlobalConfig.AlarmConfig.EnableAlarm { + faasScheduler.PoolManager.CheckMinInsAndReport(stopCh) + } + if selfregister.IsRolloutObject { + go faasScheduler.syncAllocRecordDuringRollout() + } + go metrics.InitServerMetric(stopCh) + + return faasScheduler +} + +// InitGlobalScheduler - +func InitGlobalScheduler(stopCh <-chan struct{}) { + globalFaaSScheduler = NewFaaSScheduler(stopCh) +} + +// GetGlobalScheduler - +func GetGlobalScheduler() *FaaSScheduler { + return globalFaaSScheduler +} + +// Recover before recover faaSScheduler, must wait StartList complete +func (fs *FaaSScheduler) Recover() { + // wait for StartList completion + for len(fs.funcSpecCh) != 0 { + time.Sleep(waitForETCDList) + } + time.Sleep(waitForETCDList) + fs.PoolManager.RecoverInstancePool() +} + +func (fs *FaaSScheduler) processFunctionSubscription() { + for { + select { + case event, ok := <-fs.funcSpecCh: + if !ok { + log.GetLogger().Warnf("function channel is closed") + return + } + funcSpec, ok := event.EventMsg.(*types.FunctionSpecification) + if !ok { + log.GetLogger().Warnf("event message doesn't contain function specification") + continue + } + fs.PoolManager.HandleFunctionEvent(event.EventType, funcSpec) + } + } +} + +func (fs *FaaSScheduler) processInstanceSubscription() { + for { + select { + case event, ok := <-fs.insSpecCh: + if !ok { + log.GetLogger().Warnf("instance channel is closed") + return + } + insSpec, ok := event.EventMsg.(*commonTypes.InstanceSpecification) + if !ok { + log.GetLogger().Warnf("event message doesn't contain instance specification") + continue + } + fs.PoolManager.HandleInstanceEvent(event.EventType, insSpec) + } + } +} + +func (fs *FaaSScheduler) processInstanceConfigSubscription() { + for { + select { + case event, ok := <-fs.insConfigCh: + if !ok { + log.GetLogger().Warnf("instances info channel is closed") + return + } + insConfig, ok := event.EventMsg.(*instanceconfig.Configuration) + if !ok { + log.GetLogger().Warnf("event message doesn't contain instance specification") + continue + } + fs.PoolManager.HandleInstanceConfigEvent(event.EventType, insConfig) + } + } +} + +func (fs *FaaSScheduler) processAliasSpecSubscription() { + for { + select { + case event, ok := <-fs.aliasSpecCh: + if !ok { + log.GetLogger().Warnf("instances info channel is closed") + return + } + aliasUrn, ok := event.EventMsg.(string) + if !ok { + log.GetLogger().Warnf("event message doesn't contain instance specification") + continue + } + fs.PoolManager.HandleAliasEvent(event.EventType, aliasUrn) + } + } +} + +func (fs *FaaSScheduler) processSchedulerProxySubscription() { + for { + select { + case event, ok := <-fs.schedulerCh: + if !ok { + log.GetLogger().Warnf("scheduler proxy channel is closed") + return + } + if instanceSpec, assertOK := event.EventMsg.(*commonTypes.InstanceSpecification); assertOK { + fs.PoolManager.HandleSchedulerManaged(event.EventType, instanceSpec) + } else { + log.GetLogger().Warnf("event message doesn't contain scheduler info") + continue + } + } + } +} + +func (fs *FaaSScheduler) processRolloutConfigSubscription() { + for { + select { + case event, ok := <-fs.rolloutConfigCh: + if !ok { + log.GetLogger().Warnf("scheduler proxy channel is closed") + return + } + if ratio, ok := event.EventMsg.(int); ok { + fs.PoolManager.HandleRolloutRatioChange(ratio) + } else { + log.GetLogger().Warnf("event message doesn't contain ratio info") + continue + } + } + } +} + +// ProcessInstanceRequestLibruntime will handle acquire, release and retain of instance based on multi libruntime +func (fs *FaaSScheduler) ProcessInstanceRequestLibruntime(args []api.Arg, traceID string) ([]byte, error) { + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + insOp, targetName, extraData, eventData := parseInstanceOperationLibruntime(args, traceID) + startTime := time.Now() + defer logger.Infof("process of instance operation %s target %s cost %dms", insOp, targetName, + time.Now().Sub(startTime).Milliseconds()) + result, err, shouldReply := fs.HandleRequestForward(insOp, args, traceID) + if shouldReply { + return result, err + } + var response interface{} + switch insOp { + case insOpCreate: + response = fs.handleInstanceCreate(targetName, extraData, eventData, traceID) + case insOpDelete: + response = fs.handleInstanceDelete(targetName, extraData, traceID) + case insOpAcquire: + response = fs.handleInstanceAcquire(targetName, extraData, traceID) + case insOpRelease: + response = fs.handleInstanceRelease(targetName, extraData, traceID) + case insOpRetain: + response = fs.handleInstanceRetain(targetName, extraData, traceID) + case insOpBatchRetain: + response = fs.handleInstanceBatchRetain(targetName, extraData, traceID) + case insOpRollout: + response = fs.handleRollout(targetName, traceID) + default: + logger.Warnf("unknown instance operation %s", insOp) + response = generateInstanceResponse(nil, snerror.New(constant.UnsupportedOperationErrorCode, + constant.UnsupportedOperationErrorMessage), startTime) + } + respData, err := json.Marshal(response) + if err != nil { + logger.Errorf("failed to marshal response of instance operation %s error %s", insOp, err.Error()) + return nil, err + } + return respData, nil +} + +// HandleRequestForward return forward result and shouldReply flag +func (fs *FaaSScheduler) HandleRequestForward(insOp InstanceOperation, args []api.Arg, traceID string) ([]byte, error, + bool) { + if !rollout.GetGlobalRolloutHandler().IsGaryUpdating { + return []byte{}, nil, false + } + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + switch insOp { + case insOpCreate, insOpAcquire: + if rollout.GetGlobalRolloutHandler().ShouldForwardRequest() { + logger.Infof("gray updating forward %s request to instance %s", string(insOp), + rollout.GetGlobalRolloutHandler().ForwardInstance) + result, err := rollout.InvokeByInstanceId(args, rollout.GetGlobalRolloutHandler().ForwardInstance, + traceID) + if err != nil { + // 调用另一个scheduler失败需要兜底 + return result, err, false + } + response := &commonTypes.InstanceResponse{} + err = json.Unmarshal(result, response) + if err != nil { + return []byte{}, err, false + } + if response.ErrorCode == statuscode.NoInstanceAvailableErrCode || + response.ErrorCode == statuscode.InsThdReqTimeoutCode { + logger.Infof("gray updating get no instance available error %s from instance %s", + response.ErrorCode, rollout.GetGlobalRolloutHandler().ForwardInstance) + return []byte{}, nil, false + } + return result, err, true + } + case insOpRelease, insOpRetain, insOpBatchRetain, insOpDelete: + logger.Infof("gray updating forward %s request to instance %s", string(insOp), + rollout.GetGlobalRolloutHandler().ForwardInstance) + _, _ = rollout.InvokeByInstanceId(args, rollout.GetGlobalRolloutHandler().ForwardInstance, traceID) + return []byte{}, nil, false + default: + logger.Warnf("unknown instance operation %s", insOp) + } + return []byte{}, nil, false +} + +func (fs *FaaSScheduler) handleInstanceCreate(funcKey string, extraData, eventData []byte, + traceID string) *commonTypes.InstanceResponse { + startTime := time.Now() + logger := log.GetLogger().With(zap.Any("traceID", traceID), zap.Any("funcKey", funcKey)) + funcSpec := registry.GlobalRegistry.GetFuncSpec(funcKey) + if funcSpec == nil { + logger.Errorf("failed to create instance, function %s doesn't exist", funcKey) + return generateInstanceResponse(nil, snerror.New(statuscode.FuncMetaNotFoundErrCode, + statuscode.FuncMetaNotFoundErrMsg), startTime) + } + dataInfo, err := parseExtraData(extraData) + if err != nil { + logger.Errorf("failed to parse extraData error :%v", err) + return generateInstanceResponse(nil, err, startTime) + } + resSpec, err := getResourceSpecification(dataInfo.resourceData, dataInfo.invokeLabel, funcSpec) + if err != nil { + logger.Errorf("failed get resSpec error %v", err) + return generateInstanceResponse(nil, err, startTime) + } + instance, err := fs.PoolManager.CreateInstance(&types.InstanceCreateRequest{ + FuncSpec: funcSpec, + ResSpec: resSpec, + InstanceName: dataInfo.designateInstanceName, + CreateEvent: eventData, + }) + if err != nil { + logger.Errorf("failed to create instance for function %s, error %s", funcSpec.FuncKey, err.Error()) + return generateInstanceResponse(nil, err, startTime) + } + return generateInstanceResponse(&types.InstanceAllocation{Instance: instance}, nil, startTime) +} + +func (fs *FaaSScheduler) handleInstanceDelete(instanceID string, extraData []byte, + traceID string) *commonTypes.InstanceResponse { + startTime := time.Now() + instance := registry.GlobalRegistry.GetInstance(instanceID) + if instance == nil { + return generateInstanceResponse(nil, snerror.New(statuscode.InstanceNotFoundErrCode, + statuscode.InstanceNotFoundErrMsg), startTime) + } + logger := log.GetLogger().With(zap.Any("traceID", traceID), zap.Any("funcKey", instance.FuncKey)) + err := fs.PoolManager.DeleteInstance(instance) + if err != nil { + logger.Errorf("failed to delete instance for function %s, error %s", instance.FuncKey, err.Error()) + return generateInstanceResponse(nil, err, startTime) + } + return generateInstanceResponse(&types.InstanceAllocation{Instance: instance}, nil, startTime) +} + +func (fs *FaaSScheduler) handleInstanceAcquire(targetName string, extraData []byte, + traceID string) *commonTypes.InstanceResponse { + startTime := time.Now() + funcKey, stateID := parseStateOperation(targetName) + logger := log.GetLogger().With(zap.Any("traceID", traceID), zap.Any("funcKey", funcKey), + zap.Any("stateID", stateID)) + funcSpec := registry.GlobalRegistry.GetFuncSpec(funcKey) + if funcSpec == nil { + logger.Errorf("failed to get instance, function %s doesn't exist", funcKey) + return generateInstanceResponse(nil, snerror.New(statuscode.FuncMetaNotFoundErrCode, + statuscode.FuncMetaNotFoundErrMsg), startTime) + } + + needForward, endpoint, forwardErr := judgeForwardToOtherCluster(funcSpec.FuncMetaData.FunctionURN, logger) + if forwardErr != nil { + return generateInstanceResponse(nil, forwardErr, startTime) + } + if needForward { + logger.Infof("request should forward to %s for %s", endpoint, funcSpec.FuncMetaData.FunctionURN) + return generateInstanceResponse(nil, snerror.New(constant.AcquireLeaseVPCConflictErrorCode, endpoint), + startTime) + } + + if !trafficlimit.FuncTrafficLimit(funcKey) { + logger.Warnf("handle instance acquire limited for function: %s", funcKey) + return generateInstanceResponse(nil, snerror.New(constant.AcquireLeaseTrafficLimitErrorCode, + constant.AcquireLeaseTrafficLimitErrorMessage), startTime) + } + var insAlloc *types.InstanceAllocation + dataInfo, err := parseExtraData(extraData) + if err != nil { + logger.Errorf("failed to parse extraData error :%v", err) + return generateInstanceResponse(nil, err, startTime) + } + resSpec, err := getResourceSpecification(dataInfo.resourceData, dataInfo.invokeLabel, funcSpec) + if err != nil { + logger.Errorf("failed get resSpec error %v", err) + return generateInstanceResponse(nil, err, startTime) + } + logger.Infof("handling instance acquire for resSpec %v instanceID %s instanceSession %v traceID %s", resSpec, + dataInfo.designateInstanceID, dataInfo.instanceSession, traceID) + poolLabel := getPoolLabel(dataInfo.poolLabel, funcSpec.InstanceMetaData.PoolLabel) + insAlloc, err = fs.PoolManager.AcquireInstanceThread(&types.InstanceAcquireRequest{ + FuncSpec: funcSpec, // etcd + ResSpec: resSpec, // args + TraceID: traceID, + StateID: stateID, + PoolLabel: poolLabel, + InstanceName: dataInfo.designateInstanceName, + DesignateInstanceID: dataInfo.designateInstanceID, + CallerPodName: dataInfo.callerPodName, + TrafficLimited: dataInfo.trafficLimited, + InstanceSession: dataInfo.instanceSession, + }) + if err != nil { + logger.Errorf("failed to acquire instance of function %s traceID %s error %s", funcSpec.FuncKey, traceID, + err.Error()) + return generateInstanceResponse(nil, err, startTime) + } + if insAlloc.Lease != nil { + fs.allocRecord.Store(insAlloc.AllocationID, insAlloc) + } + logger.Infof("succeed to acquire instance %s of function %s traceID %s", insAlloc.AllocationID, funcSpec.FuncKey, + traceID) + return generateInstanceResponse(insAlloc, nil, startTime) +} + +func unmarshalExtraData(extraData []byte) (map[string][]byte, error) { + extraDataMap := make(map[string][]byte, utils.DefaultMapSize) + if len(extraData) != 0 { + log.GetLogger().Debugf("acquire libruntime extraData: %s", string(extraData)) + defer func() { + if r := recover(); r != nil { + log.GetLogger().Errorf("acquire libruntime unmarshal extraData err: %v", r) + } + }() + jsonErr := json.Unmarshal(extraData, &extraDataMap) + if jsonErr != nil { + return nil, jsonErr + } + } + return extraDataMap, nil +} + +func parseExtraData(extraData []byte) (*extraDataInfo, snerror.SNError) { + extraDataMap, err := unmarshalExtraData(extraData) + if err != nil { + return nil, snerror.NewWithError(statuscode.StatusInternalServerError, + fmt.Errorf("unmarshal extraData err: %w", err)) + } + dataInfo := &extraDataInfo{} + if instanceName, ok := extraDataMap[constant.RuntimeInstanceName]; ok { + dataInfo.designateInstanceName = string(instanceName) + } + if instanceID, ok := extraDataMap[constant.InstanceRequirementInsIDKey]; ok { + dataInfo.designateInstanceID = string(instanceID) + } + if createEvent, ok := extraDataMap[constant.InstanceCreateEvent]; ok { + dataInfo.createEvent = createEvent + } + if resourceDataByte, ok := extraDataMap[constant.InstanceRequirementResourcesKey]; ok { + dataInfo.resourceData = resourceDataByte + } + if callerPodNameByte, ok := extraDataMap[constant.InstanceCallerPodName]; ok { + dataInfo.callerPodName = string(callerPodNameByte) + } + if poolLabelBytes, ok := extraDataMap[instanceRequirementPoolLabel]; ok { + dataInfo.poolLabel = string(poolLabelBytes) + } + if trafficLimitedByte, ok := extraDataMap[constant.InstanceTrafficLimited]; ok { + if trafficLimited, err := strconv.ParseBool(string(trafficLimitedByte)); err != nil { + dataInfo.trafficLimited = trafficLimited + } + } + if sessionConfigData, ok := extraDataMap[constant.InstanceSessionConfig]; ok { + insSessConfig := commonTypes.InstanceSessionConfig{} + err := json.Unmarshal(sessionConfigData, &insSessConfig) + if err != nil { + return nil, snerror.NewWithError(statuscode.StatusInternalServerError, err) + } + if !utils.CheckInstanceSessionValid(insSessConfig) { + return nil, snerror.New(statuscode.InstanceSessionInvalidErrCode, "session config invalid") + } + dataInfo.instanceSession = insSessConfig + } + if invokeLabel, ok := extraDataMap[constant.InstanceRequirementInvokeLabel]; ok { + dataInfo.invokeLabel = invokeLabel + } + return dataInfo, nil +} + +type extraDataInfo struct { + designateInstanceName string + designateInstanceID string + createEvent []byte + resourceData []byte + callerPodName string + poolLabel string + invokeLabel []byte + trafficLimited bool + instanceSession commonTypes.InstanceSessionConfig +} + +func judgeForwardToOtherCluster(funcURN string, logger api.FormatLogger) (bool, string, snerror.SNError) { + functionAvailableRegistry := registry.GlobalRegistry.FunctionAvailableRegistry + frontendRegistry := registry.GlobalRegistry.FaaSFrontendRegistry + clusters := functionAvailableRegistry.GeClusters(funcURN) + if len(clusters) == 0 { + return false, "", nil + } + + if commonUtils.IsStringInArray(os.Getenv(constant.ClusterIDKey), clusters) { + return false, "", nil + } + + for _, cluster := range clusters { + frontends := frontendRegistry.GetFrontends(cluster) + if len(frontends) == 0 { + continue + } + endpoint := fmt.Sprintf("%s:%s", frontends[0], frontendNodePort) + return true, endpoint, nil + } + logger.Errorf("func:%s need forward to other cluster, but no available frontend found", funcURN) + return false, "", snerror.New(statuscode.StatusInternalServerError, "no available frontend found") +} + +func (fs *FaaSScheduler) handleInstanceRelease(targetName string, metricsData []byte, + traceID string) *commonTypes.InstanceResponse { + startTime := time.Now() + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + items := strings.Split(targetName, stateSplitStr) + if len(items) == stateFuncKeyLen { // funcKey;stateID + targetName = items[0] + stateID := items[1] + if stateID != "" { + return fs.deleteState(stateID, targetName, logger) + } + } + insAlloc, err := fs.loadInsAlloc(targetName, logger) + if err != nil { + return generateInstanceResponse(nil, err, startTime) + } + logger.Infof("handling instance release %s for function %s", insAlloc.AllocationID, insAlloc.Instance.FuncKey) + if strings.Contains(insAlloc.AllocationID, "stateThread") { // %s-stateThread%d + fs.allocRecord.Delete(insAlloc.AllocationID) + err := fs.PoolManager.ReleaseStateThread(insAlloc) + if err != nil { + logger.Errorf("release thread %s fail, err: %v", targetName, err) + return generateInstanceResponse(nil, snerror.New(statuscode.StatusInternalServerError, + statuscode.InternalErrorMessage), startTime) + } + return generateInstanceResponse(insAlloc, nil, startTime) + } + data := fs.getInstanceThreadMetrics(insAlloc.AllocationID, metricsData) + insThdMetrics := fs.buildMetrics(data) + fs.reportMetrics(insAlloc.Instance.FuncKey, insAlloc.Instance.ResKey, insThdMetrics) + fs.allocRecord.Delete(insAlloc.AllocationID) + + // If the arg:isAbnormal that received from fronted is true, the instance of this lease will be unusable + // for user. Then the instance will be removed from instance queue and be clean. + if data.IsAbnormal == true { + fs.PoolManager.ReleaseAbnormalInstance(insAlloc.Instance) + } + if !commonUtils.IsNil(insAlloc.Lease) { + err := insAlloc.Lease.Release() + if err != nil { + // 正常情况下,通过insAlloc.Lease.Release()中的callback完成release + // 这里用来防止实例被删除,pool中的sessionrecord中仍然残留sessioninfo的情况 + if err == lease.ErrInstanceNotFound { + fs.PoolManager.ReleaseInstanceThread(insAlloc) + } + logger.Errorf("failed to release instance %s of function %s traceID %s error %s", insAlloc.AllocationID, + insAlloc.Instance.FuncKey, traceID, err.Error()) + } else { + logger.Infof("succeed to release instance %s of function %s traceID %s", insAlloc.AllocationID, + insAlloc.Instance.FuncKey, traceID) + } + } + return generateInstanceResponse(insAlloc, nil, startTime) +} + +func (fs *FaaSScheduler) loadInsAlloc(targetName string, logger api.FormatLogger) (*types.InstanceAllocation, + snerror.SNError) { + rawData, exist := fs.allocRecord.Load(targetName) + if !exist { + logger.Errorf("allocation of instance thread %s not found", targetName) + return nil, snerror.New(statuscode.InstanceNotFoundErrCode, statuscode.InstanceNotFoundErrMsg) + } + insAlloc, ok := rawData.(*types.InstanceAllocation) + if !ok { + logger.Errorf("instance thread allocation type error") + return nil, snerror.New(statuscode.StatusInternalServerError, statuscode.InternalErrorMessage) + } + return insAlloc, nil +} + +func (fs *FaaSScheduler) deleteState(stateID string, funcKey string, + logger api.FormatLogger) *commonTypes.InstanceResponse { + startTime := time.Now() + funcSpec := registry.GlobalRegistry.GetFuncSpec(funcKey) + if funcSpec == nil { + logger.Errorf("failed to get instance, function %s doesn't exist", funcKey) + return generateInstanceResponse(nil, snerror.New(statuscode.FuncMetaNotFoundErrCode, + statuscode.FuncMetaNotFoundErrMsg), startTime) + } + exist := fs.PoolManager.GetAndDeleteState(stateID, funcKey, funcSpec, logger) + if !exist { + return generateInstanceResponse(nil, snerror.New(statuscode.StateNotExistedErrCode, + statuscode.StateNotExistedErrMsg), startTime) + } + return generateInstanceResponse(&types.InstanceAllocation{Instance: &types.Instance{}}, nil, startTime) +} + +func (fs *FaaSScheduler) handleInstanceBatchRetain(target string, metricsData []byte, + traceID string) *commonTypes.BatchInstanceResponse { + startTime := time.Now() + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + targetNames := strings.Split(target, ",") + insThdMetrics := map[string]*types.InstanceThreadMetrics{} + err := json.Unmarshal(metricsData, &insThdMetrics) + if err != nil { + logger.Errorf("failed to unmarshal metrics from data %s, err %s, trace %s", string(metricsData), + err.Error(), traceID) + } + batchInstanceResp := &commonTypes.BatchInstanceResponse{ + InstanceAllocSucceed: map[string]commonTypes.InstanceAllocationSucceedInfo{}, + InstanceAllocFailed: map[string]commonTypes.InstanceAllocationFailedInfo{}, + LeaseInterval: fs.leaseInterval.Milliseconds(), + } + for _, name := range targetNames { + insAlloc, err := fs.retainInstance(name, traceID, insThdMetrics[name], logger) + if err != nil { + batchInstanceResp.InstanceAllocFailed[name] = commonTypes.InstanceAllocationFailedInfo{ + ErrorCode: err.Code(), + ErrorMessage: err.Error(), + } + continue + } + batchInstanceResp.InstanceAllocSucceed[name] = commonTypes.InstanceAllocationSucceedInfo{ + FuncKey: insAlloc.Instance.FuncKey, + FuncSig: insAlloc.Instance.FuncSig, + InstanceID: insAlloc.Instance.InstanceID, + ThreadID: insAlloc.AllocationID, + } + } + batchInstanceResp.SchedulerTime = time.Now().Sub(startTime).Seconds() + return batchInstanceResp +} + +func (fs *FaaSScheduler) handleInstanceRetain(targetName string, metricsData []byte, + traceID string) *commonTypes.InstanceResponse { + startTime := time.Now() + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + insThdMetrics := &types.InstanceThreadMetrics{} + err := json.Unmarshal(metricsData, insThdMetrics) + if err != nil { + logger.Errorf("failed to unmarshal metrics from data %s for instance %s", string(metricsData), targetName) + } + insAlloc, retainErr := fs.retainInstance(targetName, traceID, insThdMetrics, logger) + return generateInstanceResponse(insAlloc, retainErr, startTime) +} + +func (fs *FaaSScheduler) retainInstance(targetName, traceID string, insThdMetrics *types.InstanceThreadMetrics, + logger api.FormatLogger) (*types.InstanceAllocation, snerror.SNError) { + rawData, exist := fs.allocRecord.Load(targetName) + if !exist && len(insThdMetrics.ReacquireData) == 0 { + logger.Errorf("allocation of instance thread %s not found", targetName) + return nil, snerror.New(statuscode.LeaseIDNotFoundCode, + statuscode.LeaseIDNotFoundMsg) + } + if !exist { + insAlloc, err := fs.reacquireLease(targetName, traceID, insThdMetrics, logger) + if err != nil { + logger.Errorf("reacquire lease failed, %s", err.Error()) + return nil, err + } + return insAlloc, err + } + insAlloc, ok := rawData.(*types.InstanceAllocation) + if !ok { + logger.Errorf("instance thread allocation type error") + return nil, snerror.New(statuscode.StatusInternalServerError, + statuscode.InternalErrorMessage) + } + if strings.Contains(insAlloc.AllocationID, "stateThread") { // %s-stateThread%d + return fs.retainStateInstance(targetName, insAlloc, logger) + } + if insThdMetrics != nil { + insThdMetrics.InsThdID = insAlloc.AllocationID + fs.reportMetrics(insAlloc.Instance.FuncKey, insAlloc.Instance.ResKey, insThdMetrics) + } + if insAlloc.Instance.InstanceStatus.Code == int32(constant.KernelInstanceStatusSubHealth) { + fs.allocRecord.Delete(insAlloc.AllocationID) + if !commonUtils.IsNil(insAlloc.Lease) { + err := insAlloc.Lease.Release() + if err != nil { + logger.Errorf("failed to delete abnormal thread %s of function %s error %s", + insAlloc.AllocationID, insAlloc.Instance.FuncKey, err.Error()) + } + } + return nil, snerror.New(statuscode.InstanceStatusAbnormalCode, constant.LeaseErrorInstanceIsAbnormalMessage) + } + if !commonUtils.IsNil(insAlloc.Lease) { + err := insAlloc.Lease.Extend() + if err != nil { + fs.allocRecord.Delete(insAlloc.AllocationID) + logger.Errorf("failed to retain instance %s of function %s error %s", insAlloc.AllocationID, + insAlloc.Instance.FuncKey, err.Error()) + return nil, snerror.New(constant.LeaseExpireOrDeletedErrorCode, constant.LeaseExpireOrDeletedErrorMessage) + } + logger.Infof("succeed to retain instance %s of function %s ", insAlloc.AllocationID, + insAlloc.Instance.FuncKey) + } + return insAlloc, nil +} + +func (fs *FaaSScheduler) reacquireLease(targetName, traceID string, insThdMetrics *types.InstanceThreadMetrics, + logger api.FormatLogger) (*types.InstanceAllocation, snerror.SNError) { + instanceId, _, parseErr := parseRetainTargetName(targetName) + if parseErr != nil { + return nil, snerror.New(statuscode.LeaseIDIllegalCode, statuscode.LeaseIDIllegalMsg) + } + dataInfo, err := parseExtraData(insThdMetrics.ReacquireData) + if err != nil { + return nil, err + } + funcSpec := registry.GlobalRegistry.GetFuncSpec(insThdMetrics.FunctionKey) + + resSpec, err := getResourceSpecification(dataInfo.resourceData, dataInfo.invokeLabel, funcSpec) + if err != nil { + return nil, err + } + logger.Infof("handling instance reacquire for resSpec %v instanceID %s instanceSession %v", resSpec, + dataInfo.designateInstanceID, dataInfo.instanceSession) + poolLabel := getPoolLabel(dataInfo.poolLabel, funcSpec.InstanceMetaData.PoolLabel) + insAlloc, err := fs.PoolManager.AcquireInstanceThread(&types.InstanceAcquireRequest{ + FuncSpec: funcSpec, // etcd + ResSpec: resSpec, // args + TraceID: traceID, + PoolLabel: poolLabel, + DesignateInstanceID: instanceId, + DesignateThreadID: targetName, + InstanceSession: dataInfo.instanceSession, + }) + if err != nil { + logger.Errorf("failed to reacquire instance of function %s traceID %s error %s", funcSpec.FuncKey, traceID, + err.Error()) + return nil, err + } + if insAlloc.Lease != nil { + fs.allocRecord.Store(insAlloc.AllocationID, insAlloc) + } + logger.Infof("succeed to reacquire instance %s of function %s traceID %s", insAlloc.AllocationID, funcSpec.FuncKey, + traceID) + return insAlloc, nil +} + +func (fs *FaaSScheduler) retainStateInstance(targetName string, insAlloc *types.InstanceAllocation, + logger api.FormatLogger) (*types.InstanceAllocation, snerror.SNError) { + if insAlloc.Instance.InstanceStatus.Code == int32(constant.KernelInstanceStatusSubHealth) { + err := fs.PoolManager.ReleaseStateThread(insAlloc) + if err != nil { + logger.Errorf("release thread %s fail", targetName) + } + return nil, snerror.New(statuscode.InstanceStatusAbnormalCode, + constant.LeaseErrorInstanceIsAbnormalMessage) + } + err := fs.PoolManager.RetainStateThread(insAlloc) + if err != nil { + logger.Errorf("handleInstanceRetain err %v", err) + return nil, snerror.New(constant.LeaseExpireOrDeletedErrorCode, constant.LeaseExpireOrDeletedErrorMessage) + } + return insAlloc, nil +} + +func (fs *FaaSScheduler) handleRollout(targetName, traceID string) *commonTypes.RolloutResponse { + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + logger.Infof("received rollout request from %s, start to gray update", targetName) + if !config.GlobalConfig.EnableRollout { + return generateRolloutErrorResponse("", nil, errors.New("rollout is not enable")) + } + discoveryConfig := config.GlobalConfig.SchedulerDiscovery + if discoveryConfig == nil || discoveryConfig.RegisterMode != types.RegisterTypeContend { + return generateRolloutErrorResponse("", nil, errors.New("incompatible register mode")) + } + if !selfregister.Registered || len(selfregister.RegisterKey) == 0 { + return generateRolloutErrorResponse("", nil, errors.New("scheduler not registered")) + } + rollout.GetGlobalRolloutHandler().IsGaryUpdating = true + selfregister.IsRollingOut = true + rollout.GetGlobalRolloutHandler().UpdateForwardInstance(targetName) + var allocRecord = make(map[string][]string) + fs.allocRecord.Range(func(key, value any) bool { + allocLease, ok := key.(string) + if !ok { + logger.Warnf("allocRecord key is invalid") + return true + } + insAlloc, ok := value.(*types.InstanceAllocation) + if !ok { + logger.Warnf("allocRecord value is invalid") + return true + } + funcKey := insAlloc.Instance.FuncKey + allocRecord[funcKey] = append(allocRecord[funcKey], allocLease) + return true + }) + return generateRolloutErrorResponse(selfregister.RegisterKey, allocRecord, nil) +} + +func (fs *FaaSScheduler) syncAllocRecordDuringRollout() { + syncCh := rollout.GetGlobalRolloutHandler().GetAllocRecordSyncChan() + for { + select { + case allocRecord, ok := <-syncCh: + if !ok { + log.GetLogger().Warnf("stop syncing allocation record") + return + } + fs.syncAllocRecord(allocRecord) + } + } +} + +func (fs *FaaSScheduler) syncAllocRecord(allocRecord map[string][]string) { + log.GetLogger().Infof("start ot sync allocRecord") + for funcKey, record := range allocRecord { + funcSpec := registry.GlobalRegistry.GetFuncSpec(funcKey) + if funcSpec == nil { + log.GetLogger().Errorf("failed to sync allocRecord for function %s, function doesn't exist", funcKey) + continue + } + resSpec := &resspeckey.ResourceSpecification{ + CPU: funcSpec.ResourceMetaData.CPU, + Memory: funcSpec.ResourceMetaData.Memory, + EphemeralStorage: funcSpec.ResourceMetaData.EphemeralStorage, + } + for _, allocation := range record { + items := strings.Split(allocation, "-") + insAcqReq := &types.InstanceAcquireRequest{ + FuncSpec: funcSpec, + ResSpec: resSpec, + InstanceName: items[0], + } + insAlloc, err := fs.PoolManager.AcquireInstanceThread(insAcqReq) + if err != nil { + log.GetLogger().Errorf("failed to sync allocation %s, acquire instance error %s", allocation, + err.Error()) + continue + } + fs.allocRecord.Store(insAlloc.AllocationID, insAlloc) + } + } +} + +func (fs *FaaSScheduler) reportMetrics(funcKey string, resKey resspeckey.ResSpecKey, + insThdMetrics *types.InstanceThreadMetrics) { + if len(funcKey) == 0 { + return + } + fs.PoolManager.ReportMetrics(funcKey, resKey, insThdMetrics) +} + +func (fs *FaaSScheduler) getInstanceThreadMetrics(threadID string, metricsData []byte) *types.InstanceThreadMetrics { + metrics := &types.InstanceThreadMetrics{} + err := json.Unmarshal(metricsData, metrics) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal metrics from data %s for instance %s", string(metricsData), + threadID) + return nil + } + metrics.InsThdID = threadID + return metrics +} + +func (fs *FaaSScheduler) buildMetrics(extraData *types.InstanceThreadMetrics) *types.InstanceThreadMetrics { + if extraData == nil { + return &types.InstanceThreadMetrics{} + } + return &types.InstanceThreadMetrics{ + ProcReqNum: extraData.ProcReqNum, + AvgProcTime: extraData.AvgProcTime, + MaxProcTime: extraData.MaxProcTime, + } +} + +func parseInstanceOperationLibruntime(args []api.Arg, traceID string) (InstanceOperation, string, []byte, []byte) { + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + insOp := insOpUnknown + if len(args) < minArgsNum { + logger.Errorf("argument number is smaller than %d check args %+v", minArgsNum, args) + return insOp, "", nil, nil + } + operationArg := args[0] + if operationArg.Type != api.Value { + logger.Errorf("invalid argument type for args[0]") + return insOp, "", nil, nil + } + items := strings.SplitN(string(operationArg.Data), insOpSeparator, validInsOpLen) + if len(items) != validInsOpLen { + logger.Errorf("failed to parse operation and target from %s", string(operationArg.Data)) + return insOp, "", nil, nil + } + insOp = InstanceOperation(items[0]) + target := items[1] + if len(args) == minArgsNum { + return insOp, target, nil, nil + } + extraDataArg := args[1] + if extraDataArg.Type != api.Value { + logger.Errorf("invalid argument type for args[1]") + return insOp, target, nil, nil + } + eventDataArg := api.Arg{} + // temporary process for forward compatible, remove this in future + if len(args) == libruntimeValidArgsNum { + eventDataArg = args[2] + } + return insOp, target, extraDataArg.Data, eventDataArg.Data +} + +func parseInstanceOperation(args []*api.Arg, traceID string) (InstanceOperation, string, []byte) { + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + if len(args) != validArgsNum { + logger.Errorf("invalid argument number") + return insOpUnknown, "", nil + } + operationArg := args[0] + extraDataArg := args[1] + if operationArg.Type != api.Value || extraDataArg.Type != api.Value { + logger.Errorf("invalid argument type") + return insOpUnknown, "", nil + } + operationCommand := string(operationArg.Data) + items := strings.SplitN(operationCommand, insOpSeparator, validInsOpLen) + if len(items) != validInsOpLen { + logger.Errorf("invalid argument %s", operationCommand) + return insOpUnknown, "", nil + } + insOp := InstanceOperation(items[0]) + targetName := items[1] + extraData := extraDataArg.Data + return insOp, targetName, extraData +} + +func getPoolLabel(poolLabelFromReq, poolLabelFromMeta string) string { + if poolLabelFromReq != "" { + return poolLabelFromReq + } + return poolLabelFromMeta +} + +func parseStateOperation(ops string) (string, string) { + targetName := ops + items := strings.Split(ops, stateSplitStr) + if len(items) != validArgsNum { + return targetName, "" + } + + targetName = items[0] + stateID := items[1] + + return targetName, stateID +} + +func parseRetainTargetName(targetName string) (string, string, error) { + // targetName: f49a9bc8-bddd-4e0c-8000-00000000b90d-thread21 + items := strings.Split(targetName, "-thread") + if len(items) != validArgsNum { + return "", "", fmt.Errorf("target name fmt error. %s", targetName) + } + instanceId := items[0] + threadId := items[1] + return instanceId, threadId, nil +} + +func getResourceSpecification(resData, labelData []byte, funcSpec *types.FunctionSpecification) ( + *resspeckey.ResourceSpecification, snerror.SNError) { + resSpec := &resspeckey.ResourceSpecification{ + CustomResources: make(map[string]int64, constant.DefaultMapSize), + } + resMap := map[string]types.IntOrString{} + if len(resData) != 0 { + err := json.Unmarshal(resData, &resMap) + if err != nil { + return nil, snerror.NewWithError(statuscode.StatusInternalServerError, err) + } + } + for k, v := range resMap { + if v.Type != types.Int { + continue + } + if k == constant.ResourceCPUName { + resSpec.CPU = v.IntVal + continue + } + if k == constant.ResourceMemoryName { + resSpec.Memory = v.IntVal + continue + } + resSpec.CustomResources[k] = v.IntVal + } + if resSpec.CPU == 0 { + resSpec.CPU = funcSpec.ResourceMetaData.CPU + } + if resSpec.Memory == 0 { + resSpec.Memory = funcSpec.ResourceMetaData.Memory + } + if resSpec.EphemeralStorage == 0 { + resSpec.EphemeralStorage = funcSpec.ResourceMetaData.EphemeralStorage + } + if len(labelData) > 0 { + labelMap := map[string]string{} + err := json.Unmarshal(labelData, &labelMap) + if err != nil { + return nil, snerror.NewWithError(statuscode.StatusInternalServerError, err) + } + resSpec.InvokeLabel = labelMap[types.HeaderInstanceLabel] + } + return resSpec, nil +} + +func generateInstanceResponse(insAlloc *types.InstanceAllocation, snErr snerror.SNError, + startTime time.Time) *commonTypes.InstanceResponse { + if snErr != nil { + return &commonTypes.InstanceResponse{ + InstanceAllocationInfo: commonTypes.InstanceAllocationInfo{ + InstanceID: "", + LeaseInterval: 0, + }, + ErrorCode: snErr.Code(), + ErrorMessage: snErr.Error(), + SchedulerTime: time.Now().Sub(startTime).Seconds(), + } + } + leaseInterval := time.Duration(0) + if insAlloc.Lease != nil { + leaseInterval = insAlloc.Lease.GetInterval() + } + forceInvoke := false + if insAlloc.Instance.InstanceStatus.Code == int32(constant.KernelInstanceStatusEvicting) { + forceInvoke = true + } + return &commonTypes.InstanceResponse{ + InstanceAllocationInfo: commonTypes.InstanceAllocationInfo{ + FuncKey: insAlloc.Instance.FuncKey, + FuncSig: insAlloc.Instance.FuncSig, + InstanceID: insAlloc.Instance.InstanceID, + InstanceIP: insAlloc.Instance.InstanceIP, + InstancePort: insAlloc.Instance.InstancePort, + NodeIP: insAlloc.Instance.NodeIP, + NodePort: insAlloc.Instance.NodePort, + ThreadID: insAlloc.AllocationID, + LeaseInterval: leaseInterval.Milliseconds(), + CPU: insAlloc.Instance.ResKey.CPU, + Memory: insAlloc.Instance.ResKey.Memory, + ForceInvoke: forceInvoke, + }, + ErrorCode: constant.InsReqSuccessCode, + ErrorMessage: constant.InsReqSuccessMessage, + SchedulerTime: time.Now().Sub(startTime).Seconds(), + } +} + +func generateRolloutErrorResponse(registerKey string, allocRecord map[string][]string, + err error) *commonTypes.RolloutResponse { + errorCode := constant.InsReqSuccessCode + errorMessage := constant.InsReqSuccessMessage + if err != nil { + errorCode = statuscode.InternalErrorCode + errorMessage = err.Error() + } + return &commonTypes.RolloutResponse{ + RegisterKey: registerKey, + AllocRecord: allocRecord, + ErrorCode: errorCode, + ErrorMessage: errorMessage, + } +} + +func printInputLog() { + log.GetLogger().Infof("%s is alive.", logFileName) +} diff --git a/yuanrong/pkg/functionscaler/faasscheduler_test.go b/yuanrong/pkg/functionscaler/faasscheduler_test.go new file mode 100644 index 0000000..3133ac2 --- /dev/null +++ b/yuanrong/pkg/functionscaler/faasscheduler_test.go @@ -0,0 +1,1794 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faasscheduler - +package functionscaler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey" + . "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commonTypes "yuanrong/pkg/common/faas_common/types" + mockUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/instancepool" + "yuanrong/pkg/functionscaler/lease" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/rollout" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/state" + "yuanrong/pkg/functionscaler/types" +) + +type fakeLease struct{} + +func (l *fakeLease) Extend() error { + return nil +} +func (l *fakeLease) Release() error { + return nil +} +func (l *fakeLease) GetInterval() time.Duration { + return time.Second +} + +var testFuncSpec = &types.FunctionSpecification{ + FuncCtx: context.TODO(), + FuncKey: "TestFuncKey", + FuncMetaData: commonTypes.FuncMetaData{ + Handler: "myHandler", + }, + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 500, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + MinInstance: 0, + MaxInstance: 1000, + ConcurrentNum: 100, + }, + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Initializer: commonTypes.Initializer{ + Handler: "myInitializer", + }, + }, +} + +func TestNewFaaSScheduler(t *testing.T) { + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "SubscribeFuncSpec", func(_ *registry.Registry, + subChan chan registry.SubEvent) { + }), + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "SubscribeInsSpec", func(_ *registry.Registry, + subChan chan registry.SubEvent) { + }), + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "SubscribeInsConfig", func(_ *registry.Registry, + subChan chan registry.SubEvent) { + }), + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "SubscribeAliasSpec", func(_ *registry.Registry, + subChan chan registry.SubEvent) { + }), + ApplyFunc((*etcd3.EtcdWatcher).StartList, func(_ *etcd3.EtcdWatcher) { + }), + ApplyFunc((*etcd3.EtcdClient).AttachAZPrefix, func(_ *etcd3.EtcdClient, key string) string { return key }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + time.Sleep(10 * time.Millisecond) + convey.Convey("test New faasscheduler", t, func() { + stopCh := make(chan struct{}) + faasScheduler := NewFaaSScheduler(stopCh) + assert.Equal(t, true, faasScheduler != nil) + close(stopCh) + }) + convey.Convey("testfaasScheduler funcSpecCh", t, func() { + stopCh := make(chan struct{}) + faasScheduler := NewFaaSScheduler(stopCh) + assert.Equal(t, true, faasScheduler != nil) + faasScheduler.funcSpecCh <- registry.SubEvent{EventMsg: struct{}{}} + time.Sleep(5 * time.Microsecond) + close(stopCh) + }) + convey.Convey("testfaasScheduler insSpecCh", t, func() { + stopCh := make(chan struct{}) + faasScheduler := NewFaaSScheduler(stopCh) + assert.Equal(t, true, faasScheduler != nil) + faasScheduler.insSpecCh <- registry.SubEvent{EventMsg: struct{}{}} + time.Sleep(5 * time.Microsecond) + close(stopCh) + }) + convey.Convey("testfaasScheduler insConfigCh", t, func() { + stopCh := make(chan struct{}) + faasScheduler := NewFaaSScheduler(stopCh) + assert.Equal(t, true, faasScheduler != nil) + faasScheduler.insConfigCh <- registry.SubEvent{EventMsg: struct{}{}} + time.Sleep(5 * time.Microsecond) + close(stopCh) + }) +} + +func initRegistry() { + patches := []*Patches{ + ApplyFunc((*etcd3.EtcdClient).AttachAZPrefix, func(_ *etcd3.EtcdClient, key string) string { return key }), + ApplyFunc((*registry.FaasSchedulerRegistry).WaitForETCDList, func() {}), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + _ = registry.InitRegistry(make(chan struct{})) + registry.GlobalRegistry.FaaSSchedulerRegistry = registry.NewFaasSchedulerRegistry(make(chan struct{})) + selfregister.SelfInstanceID = "schedulerID-1" + selfregister.GlobalSchedulerProxy.Add(&commonTypes.InstanceInfo{ + TenantID: "123456789", + FunctionName: "faasscheduler", + Version: "lastest", + InstanceName: "schedulerID-1", + }, "") +} + +func TestMain(m *testing.M) { + patches := []*Patches{ + ApplyFunc((*etcd3.EtcdWatcher).StartList, func(_ *etcd3.EtcdWatcher) {}), + ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyFunc(etcd3.GetCAEMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyFunc((*etcd3.EtcdClient).AttachAZPrefix, func(_ *etcd3.EtcdClient, key string) string { return key }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + config.GlobalConfig = types.Configuration{ + AutoScaleConfig: types.AutoScaleConfig{ + SLAQuota: 500, + ScaleDownTime: 1000, + BurstScaleNum: 1000, + }, + LeaseSpan: 500, + LocalAuth: localauth.AuthConfig{ + AKey: "ENC(key=servicekek, value=6B6D73763030000101D615B6381ED56AF68123844D047428BDCCBF19957866" + + "CD0D7F53C29438337667A93FB9A06C5ED4A3D925C87655E4C734)", + SKey: "ENC(key=servicekek, value=6B6D73763030000101139308ABBC0C4120F949AC833416D5E6D8CA18D8C69E" + + "4C5E03E553E18733B4119C4B716FF2C8265336BB2979545A24FDC07CDD6A6A02F412D0DE83BD43F2A07DDBC78EB2)", + Duration: 0, + }, + } + initRegistry() + instancepool.SetGlobalSdkClient(&mockUtils.FakeLibruntimeSdkClient{}) + m.Run() +} + +func TestParseInstanceOperation(t *testing.T) { + convey.Convey("success", t, func() { + args := []*api.Arg{ + { + Type: api.Value, + Data: []byte("acquire#TestFuncKey"), + }, + { + Type: api.Value, + Data: []byte("qwerdf"), + }, + } + insOp, targetName, extraData := parseInstanceOperation(args, "") + convey.So(insOp, convey.ShouldEqual, "acquire") + convey.So(targetName, convey.ShouldEqual, "TestFuncKey") + convey.So(extraData, convey.ShouldResemble, []byte("qwerdf")) + }) + convey.Convey("args length error", t, func() { + args := []*api.Arg{ + { + Type: api.Value, + Data: []byte("acquire#TestFuncKey"), + }, + } + insOp, targetName, extraData := parseInstanceOperation(args, "") + convey.So(insOp, convey.ShouldEqual, insOpUnknown) + convey.So(targetName, convey.ShouldEqual, "") + convey.So(extraData, convey.ShouldEqual, nil) + }) + convey.Convey("Type error", t, func() { + args := []*api.Arg{ + { + Type: 1, + Data: []byte("acquire#TestFuncKey"), + }, + { + Type: 1, + Data: []byte("qwerdf"), + }, + } + insOp, targetName, extraData := parseInstanceOperation(args, "") + convey.So(insOp, convey.ShouldEqual, insOpUnknown) + convey.So(targetName, convey.ShouldEqual, "") + convey.So(extraData, convey.ShouldEqual, nil) + }) + convey.Convey("extraData success with multi step", t, func() { + args := []*api.Arg{ + { + Type: api.Value, + Data: []byte("acquire#TestFuncKey#qwe#qwe"), + }, + { + Type: api.Value, + Data: []byte("qwerdf"), + }, + } + insOp, targetName, extraData := parseInstanceOperation(args, "") + convey.So(string(insOp), convey.ShouldEqual, "acquire") + convey.So(targetName, convey.ShouldEqual, "TestFuncKey#qwe#qwe") + convey.So(extraData, convey.ShouldResemble, []byte("qwerdf")) + }) +} + +func TestProcessSubscription(t *testing.T) { + stopCh := make(chan struct{}) + defer close(stopCh) + registry.GlobalRegistry = ®istry.Registry{ + FaaSSchedulerRegistry: registry.NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: registry.NewFunctionRegistry(stopCh), + InstanceRegistry: registry.NewInstanceRegistry(stopCh), + FaaSManagerRegistry: registry.NewFaaSManagerRegistry(stopCh), + InstanceConfigRegistry: registry.NewInstanceConfigRegistry(stopCh), + AliasRegistry: registry.NewAliasRegistry(stopCh), + RolloutRegistry: registry.NewRolloutRegistry(stopCh), + } + faasScheduler := NewFaaSScheduler(stopCh) + convey.Convey("test processFunctionSubscription", t, func() { + faasScheduler.funcSpecCh <- registry.SubEvent{ + EventType: registry.SubEventTypeUpdate, + EventMsg: &types.FunctionSpecification{ + FuncKey: "testFunc", + }, + } + }) + convey.Convey("test processInstanceSubscription", t, func() { + faasScheduler.insSpecCh <- registry.SubEvent{ + EventType: registry.SubEventTypeUpdate, + EventMsg: &commonTypes.InstanceSpecification{ + InstanceID: "testIns", + }, + } + }) + convey.Convey("test processInstanceConfigSubscription", t, func() { + faasScheduler.insConfigCh <- registry.SubEvent{ + EventType: registry.SubEventTypeUpdate, + EventMsg: &instanceconfig.Configuration{ + FuncKey: "testFunc", + }, + } + }) +} + +func TestFaaSScheduler_processRolloutConfigSubscription(t *testing.T) { + convey.Convey("test processRolloutConfigSubscription", t, func() { + convey.Convey("baseline", func() { + count := 0 + p := ApplyFunc((*instancepool.PoolManager).HandleRolloutRatioChange, + func(_ *instancepool.PoolManager, ratio int) { + count++ + }) + defer p.Reset() + rolloutConfigCh := make(chan registry.SubEvent) + faasScheduler := &FaaSScheduler{ + rolloutConfigCh: rolloutConfigCh, + PoolManager: &instancepool.PoolManager{}, + } + go faasScheduler.processRolloutConfigSubscription() + rolloutConfigCh <- registry.SubEvent{ + EventType: "aaa", + EventMsg: 50, + } + time.Sleep(100 * time.Millisecond) + convey.So(count, convey.ShouldEqual, 1) + rolloutConfigCh <- registry.SubEvent{ + EventType: "aaa", + EventMsg: "123", + } + time.Sleep(100 * time.Millisecond) + convey.So(count, convey.ShouldEqual, 1) + close(rolloutConfigCh) + time.Sleep(100 * time.Millisecond) + convey.So(count, convey.ShouldEqual, 1) + }) + }) +} + +func Test_processAliasSpecSubscription(t *testing.T) { + convey.Convey("processAliasSpecSubscription", t, func() { + stopCh := make(chan struct{}) + defer close(stopCh) + registry.GlobalRegistry = ®istry.Registry{ + FaaSSchedulerRegistry: registry.NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: registry.NewFunctionRegistry(stopCh), + InstanceRegistry: registry.NewInstanceRegistry(stopCh), + FaaSManagerRegistry: registry.NewFaaSManagerRegistry(stopCh), + InstanceConfigRegistry: registry.NewInstanceConfigRegistry(stopCh), + AliasRegistry: registry.NewAliasRegistry(stopCh), + RolloutRegistry: registry.NewRolloutRegistry(stopCh), + } + faasScheduler := NewFaaSScheduler(stopCh) + var checkAliasUrn string + defer ApplyMethod(reflect.TypeOf(faasScheduler.PoolManager), "HandleAliasEvent", + func(pm *instancepool.PoolManager, eventType registry.EventType, aliasUrn string) { + checkAliasUrn = aliasUrn + }).Reset() + faasScheduler.aliasSpecCh <- registry.SubEvent{ + EventType: registry.SubEventTypeUpdate, + EventMsg: "aliasUrn", + } + time.Sleep(100 * time.Millisecond) + convey.So(checkAliasUrn, convey.ShouldEqual, "aliasUrn") + checkAliasUrn = "" + faasScheduler.aliasSpecCh <- registry.SubEvent{ + EventType: registry.SubEventTypeUpdate, + EventMsg: 123, + } + time.Sleep(100 * time.Millisecond) + convey.So(checkAliasUrn, convey.ShouldEqual, "") + close(faasScheduler.aliasSpecCh) + }) +} + +func TestParseStateOperation(t *testing.T) { + convey.Convey("success", t, func() { + ops := "funcKey;stateId" + targetName, stateID := parseStateOperation(ops) + convey.So(stateID, convey.ShouldEqual, "stateId") + convey.So(targetName, convey.ShouldEqual, "funcKey") + }) +} + +func TestProcessInstanceRequest(t *testing.T) { + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "GetFuncSpec", func(_ *registry.Registry, + funcKey string) *types.FunctionSpecification { + return testFuncSpec + }), + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "SubscribeFuncSpec", func(_ *registry.Registry, + subChan chan registry.SubEvent) { + }), + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "SubscribeInsSpec", func(_ *registry.Registry, + subChan chan registry.SubEvent) { + }), + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "SubscribeInsConfig", func(_ *registry.Registry, + subChan chan registry.SubEvent) { + }), + ApplyMethod(reflect.TypeOf(®istry.Registry{}), "SubscribeAliasSpec", func(_ *registry.Registry, + subChan chan registry.SubEvent) { + }), + ApplyFunc((*registry.FunctionAvailableRegistry).GeClusters, func(_ *registry.FunctionAvailableRegistry, _ string) []string { + return []string{} + }), + ApplyFunc((*registry.FaaSFrontendRegistry).GetFrontends, func(_ *registry.FaaSFrontendRegistry, _ string) []string { + return []string{} + }), + ApplyFunc((*instancepool.PoolManager).ReleaseAbnormalInstance, func(_ *instancepool.PoolManager, + instance *types.Instance) { + }), + ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }), + ApplyFunc((*etcd3.EtcdWatcher).StartList, func(_ *etcd3.EtcdWatcher) { + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + stopCh := make(chan struct{}) + defer close(stopCh) + faasScheduler := NewFaaSScheduler(stopCh) + time.Sleep(1 * time.Second) + faasScheduler.PoolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, testFuncSpec) + faasScheduler.PoolManager.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, &instanceconfig.Configuration{ + FuncKey: "TestFuncKey", + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 100, + ConcurrentNum: 1, + }, + }) + metrics := &types.InstanceThreadMetrics{} + metricsData, _ := json.Marshal(metrics) + releaseExtraData := &types.InstanceThreadMetrics{ + ProcReqNum: 11, + AvgProcTime: 11, + MaxProcTime: 11, + IsAbnormal: true, + } + releaseExtraRawData, _ := json.Marshal(releaseExtraData) + acquireRsp := &commonTypes.InstanceResponse{} + + convey.Convey("acquire", t, func() { + m := map[string][]byte{"resourcesData": []byte("")} + bytes, _ := json.Marshal(m) + acquireArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte("acquire#TestFuncKey"), + }, + { + Type: api.Value, + Data: bytes, + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + convey.Convey("acquire success", func() { + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(acquireArgs, "") + _ = json.Unmarshal(resData, acquireRsp) + assert.Equal(t, constant.InsReqSuccessCode, acquireRsp.ErrorCode) + }) + resourceRes := &resspeckey.ResourceSpecification{ + CPU: 300, + Memory: 128, + } + resource, _ := json.Marshal(resourceRes) + m[constant.InstanceRequirementResourcesKey] = resource + convey.Convey("acquire set resource success", func() { + acquireArgs[1].Data, _ = json.Marshal(m) + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(acquireArgs, "") + _ = json.Unmarshal(resData, acquireRsp) + assert.Equal(t, constant.InsReqSuccessCode, acquireRsp.ErrorCode) + }) + convey.Convey("acquire set resource error", func() { + m[constant.InstanceRequirementResourcesKey] = resource[1:1] + acquireArgs[1].Data, _ = json.Marshal(m) + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(acquireArgs, "") + _ = json.Unmarshal(resData, acquireRsp) + assert.Equal(t, constant.InsReqSuccessCode, acquireRsp.ErrorCode) + }) + convey.Convey("acquire metrics error", func() { + defer ApplyMethod(reflect.TypeOf(registry.GlobalRegistry), "GetFuncSpec", + func(_ *registry.Registry, funcKey string) *types.FunctionSpecification { + return nil + }).Reset() + defer ApplyMethod(reflect.TypeOf(registry.GlobalRegistry), "FetchSilentFuncSpec", + func(_ *registry.Registry, funcKey string) *types.FunctionSpecification { + return nil + }).Reset() + releaseRsp := &commonTypes.InstanceResponse{} + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(acquireArgs, "") + _ = json.Unmarshal(resData, releaseRsp) + assert.Equal(t, statuscode.FuncMetaNotFoundErrCode, releaseRsp.ErrorCode) + }) + }) + + convey.Convey("retain", t, func() { + retainArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("retain#%s", acquireRsp.ThreadID)), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + retainRsp := &commonTypes.InstanceResponse{} + convey.Convey("retain success", func() { + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(retainArgs, "") + _ = json.Unmarshal(resData, retainRsp) + assert.Equal(t, constant.InsReqSuccessCode, retainRsp.ErrorCode) + }) + convey.Convey("retain metrics error", func() { + retainArgs[0].Data = []byte("retain#000thread111") + retainArgs[1].Data = metricsData + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(retainArgs, "") + _ = json.Unmarshal(resData, retainRsp) + assert.Equal(t, statuscode.LeaseIDNotFoundCode, retainRsp.ErrorCode) + }) + convey.Convey("retain stateThread error instance subHealth", func() { + retainErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("retain#%s", "TestFuncKey1")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + p := &instancepool.PoolManager{} + patch := ApplyMethod(reflect.TypeOf(p), + "ReleaseStateThread", func(db *instancepool.PoolManager, + insAlloc *types.InstanceAllocation) error { + return errors.New("release state thread error") + }) + defer patch.Reset() + insAlloc := &types.InstanceAllocation{ + AllocationID: "testFunc-stateThread1", + Instance: &types.Instance{ + FuncKey: "TestFuncKey1", + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusSubHealth), + }, + ResKey: resspeckey.ResSpecKey{}, + }, + } + faasScheduler.allocRecord.Store("TestFuncKey1", insAlloc) + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(retainErrArgs, "") + _ = json.Unmarshal(resData, retainRsp) + assert.Equal(t, statuscode.InstanceStatusAbnormalCode, retainRsp.ErrorCode) + }) + convey.Convey("retain stateThread error", func() { + retainErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("retain#%s", "TestFuncKey1")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + p := &instancepool.PoolManager{} + patch1 := ApplyMethod(reflect.TypeOf(p), "ReleaseStateThread", func(db *instancepool.PoolManager, + thread *types.InstanceAllocation) error { + return nil + }) + defer patch1.Reset() + patch2 := ApplyMethod(reflect.TypeOf(p), "RetainStateThread", func(db *instancepool.PoolManager, + thread *types.InstanceAllocation) error { + return errors.New("retain state thread error") + }) + defer patch2.Reset() + insAlloc := &types.InstanceAllocation{ + AllocationID: "testFunc-stateThread1", + Instance: &types.Instance{ + FuncKey: "TestFuncKey1", + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + ResKey: resspeckey.ResSpecKey{}, + }, + Lease: &lease.GenericInstanceLease{}, + } + faasScheduler.allocRecord.Store("TestFuncKey1", insAlloc) + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(retainErrArgs, "") + _ = json.Unmarshal(resData, retainRsp) + assert.Equal(t, constant.LeaseExpireOrDeletedErrorCode, retainRsp.ErrorCode) + }) + convey.Convey("retain stateThread success", func() { + retainErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("retain#%s", "TestFuncKey1")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + p := &instancepool.PoolManager{} + patch1 := ApplyMethod(reflect.TypeOf(p), "ReleaseStateThread", func(db *instancepool.PoolManager, + thread *types.InstanceAllocation) error { + return nil + }) + defer patch1.Reset() + patch2 := ApplyMethod(reflect.TypeOf(p), "RetainStateThread", func(db *instancepool.PoolManager, + thread *types.InstanceAllocation) error { + return nil + }) + defer patch2.Reset() + insAlloc := &types.InstanceAllocation{ + AllocationID: "testFunc-stateThread1", + Instance: &types.Instance{ + FuncKey: "TestFuncKey1", + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + ResKey: resspeckey.ResSpecKey{}, + }, + Lease: &lease.GenericInstanceLease{}, + } + faasScheduler.allocRecord.Store("TestFuncKey1", insAlloc) + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(retainErrArgs, "") + _ = json.Unmarshal(resData, retainRsp) + assert.Equal(t, constant.InsReqSuccessCode, retainRsp.ErrorCode) + }) + convey.Convey("retain InsAlloc error release error", func() { + retainErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("retain#%s", "TestFuncKey1")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + p := &instancepool.PoolManager{} + patch1 := ApplyMethod(reflect.TypeOf(p), "ReleaseStateThread", func(db *instancepool.PoolManager, + thread *types.InstanceAllocation) error { + return nil + }) + defer patch1.Reset() + patch2 := ApplyMethod(reflect.TypeOf(p), "RetainStateThread", func(db *instancepool.PoolManager, + thread *types.InstanceAllocation) error { + return nil + }) + defer patch2.Reset() + l := &fakeLease{} + patch3 := ApplyMethod(reflect.TypeOf(l), "Release", func(l *fakeLease) error { + return errors.New("release error") + }) + defer patch3.Reset() + insAlloc := &types.InstanceAllocation{ + AllocationID: "testFunc-Thread1", + Instance: &types.Instance{ + FuncKey: "TestFuncKey1", + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusSubHealth), + }, + ResKey: resspeckey.ResSpecKey{}, + }, + Lease: &lease.GenericInstanceLease{}, + } + faasScheduler.allocRecord.Store("TestFuncKey1", insAlloc) + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(retainErrArgs, "") + _ = json.Unmarshal(resData, retainRsp) + assert.Equal(t, statuscode.InstanceStatusAbnormalCode, retainRsp.ErrorCode) + }) + convey.Convey("retain InsAlloc error extend error", func() { + retainErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("retain#%s", "TestFuncKey1")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + p := &instancepool.PoolManager{} + patch1 := ApplyMethod(reflect.TypeOf(p), + "ReleaseStateThread", func(db *instancepool.PoolManager, + thread *types.InstanceAllocation) error { + return nil + }) + defer patch1.Reset() + patch2 := ApplyMethod(reflect.TypeOf(p), + "RetainStateThread", func(db *instancepool.PoolManager, + thread *types.InstanceAllocation) error { + return nil + }) + defer patch2.Reset() + l := &fakeLease{} + patch3 := ApplyMethod(reflect.TypeOf(l), + "Extend", func(l *fakeLease) error { + return errors.New("extend error") + }) + defer patch3.Reset() + insAlloc := &types.InstanceAllocation{ + AllocationID: "testFunc-Thread1", + Instance: &types.Instance{ + FuncKey: "TestFuncKey1", + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + ResKey: resspeckey.ResSpecKey{}, + }, + Lease: &lease.GenericInstanceLease{}, + } + faasScheduler.allocRecord.Store("TestFuncKey1", insAlloc) + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(retainErrArgs, "") + _ = json.Unmarshal(resData, retainRsp) + assert.Equal(t, constant.LeaseExpireOrDeletedErrorCode, retainRsp.ErrorCode) + }) + convey.Convey("retain insThdAlloc data error", func() { + retainErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("retain#%s", "TestFuncKey1")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + faasScheduler.allocRecord.Store("TestFuncKey1", "") + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(retainErrArgs, "") + _ = json.Unmarshal(resData, retainRsp) + assert.Equal(t, statuscode.StatusInternalServerError, retainRsp.ErrorCode) + }) + }) + + convey.Convey("release", t, func() { + releaseArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("release#%s", acquireRsp.ThreadID)), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + releaseRsp := &commonTypes.InstanceResponse{} + convey.Convey("release success", func() { + releaseArgs[1].Data = releaseExtraRawData + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(releaseArgs, "") + _ = json.Unmarshal(resData, releaseRsp) + assert.Equal(t, constant.InsReqSuccessCode, releaseRsp.ErrorCode) + }) + convey.Convey("release metrics error", func() { + releaseArgs[1].Data = releaseExtraRawData + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(releaseArgs, "") + _ = json.Unmarshal(resData, releaseRsp) + assert.Equal(t, statuscode.InstanceNotFoundErrCode, releaseRsp.ErrorCode) + }) + convey.Convey("release state error func not exist", func() { + releaseErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("release#%s", "TestFuncKeyE;stateID")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + patch := ApplyMethod(reflect.TypeOf(®istry.Registry{}), "GetFuncSpec", func(_ *registry.Registry, + funcKey string) *types.FunctionSpecification { + return nil + }) + defer patch.Reset() + releaseErrArgs[1].Data = releaseExtraRawData + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(releaseErrArgs, "") + _ = json.Unmarshal(resData, releaseRsp) + assert.Equal(t, statuscode.FuncMetaNotFoundErrCode, releaseRsp.ErrorCode) + }) + convey.Convey("release state error delete error", func() { + releaseErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("release#%s", "TestFuncKey;stateID")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + p := &instancepool.PoolManager{} + patchGet := ApplyMethod(reflect.TypeOf(p), "GetAndDeleteState", + func(db *instancepool.PoolManager, stateID string, funcKey string, + funcSpec *types.FunctionSpecification, logger api.FormatLogger) bool { + return false + }) + defer patchGet.Reset() + releaseErrArgs[1].Data = releaseExtraRawData + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(releaseErrArgs, "") + _ = json.Unmarshal(resData, releaseRsp) + assert.Equal(t, statuscode.StateNotExistedErrCode, releaseRsp.ErrorCode) + }) + convey.Convey("release state success", func() { + releaseErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("release#%s", "TestFuncKey;stateID")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + p := &instancepool.PoolManager{} + patchGet := ApplyMethod(reflect.TypeOf(p), "GetAndDeleteState", func(db *instancepool.PoolManager, + stateID string, funcKey string, funcSpec *types.FunctionSpecification, logger api.FormatLogger) bool { + return true + }) + defer patchGet.Reset() + releaseErrArgs[1].Data = releaseExtraRawData + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(releaseErrArgs, "") + _ = json.Unmarshal(resData, releaseRsp) + assert.Equal(t, constant.InsReqSuccessCode, releaseRsp.ErrorCode) + }) + convey.Convey("insThdAlloc data error", func() { + releaseErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("release#%s", "TestFuncKeyE1")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + releaseErrArgs[1].Data = releaseExtraRawData + faasScheduler.allocRecord.Store("TestFuncKeyE1", "") + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(releaseErrArgs, "") + _ = json.Unmarshal(resData, releaseRsp) + assert.Equal(t, statuscode.StatusInternalServerError, releaseRsp.ErrorCode) + }) + convey.Convey("release state thread error", func() { + releaseErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("release#%s", "TestFuncKey1")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + p := &instancepool.PoolManager{} + patchGet := ApplyMethod(reflect.TypeOf(p), "ReleaseStateThread", func(db *instancepool.PoolManager, + thread *types.InstanceAllocation) error { + return errors.New("release state thread error") + }) + defer patchGet.Reset() + insAlloc := &types.InstanceAllocation{ + AllocationID: "testFunck-stateThread1", + Instance: &types.Instance{ + FuncKey: "TestFuncKey1", + ResKey: resspeckey.ResSpecKey{}, + }, + Lease: &lease.GenericInstanceLease{}, + } + faasScheduler.allocRecord.Store("TestFuncKey1", insAlloc) + releaseErrArgs[1].Data = releaseExtraRawData + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(releaseErrArgs, "") + _ = json.Unmarshal(resData, releaseRsp) + assert.Equal(t, statuscode.StatusInternalServerError, releaseRsp.ErrorCode) + }) + convey.Convey("release state thread success", func() { + releaseErrArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("release#%s", "TestFuncKey1")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + p := &instancepool.PoolManager{} + patchGet := ApplyMethod(reflect.TypeOf(p), + "ReleaseStateThread", func(db *instancepool.PoolManager, + thread *types.InstanceAllocation) error { + return nil + }) + defer patchGet.Reset() + insAlloc := &types.InstanceAllocation{ + AllocationID: "testFunc-stateThread1", + Instance: &types.Instance{ + FuncKey: "TestFuncKey1", + ResKey: resspeckey.ResSpecKey{}, + }, + } + faasScheduler.allocRecord.Store("TestFuncKey1", insAlloc) + releaseErrArgs[1].Data = releaseExtraRawData + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(releaseErrArgs, "") + _ = json.Unmarshal(resData, releaseRsp) + assert.Equal(t, constant.InsReqSuccessCode, releaseRsp.ErrorCode) + }) + }) + + convey.Convey("error opt", t, func() { + errorArgs := []api.Arg{ + { + Type: api.Value, + Data: []byte(fmt.Sprintf("xxxxx#")), + }, + { + Type: api.Value, + Data: []byte(""), + }, + { + Type: api.Value, + Data: []byte(""), + }, + } + errorRsp := &commonTypes.InstanceResponse{} + resData, _ := faasScheduler.ProcessInstanceRequestLibruntime(errorArgs, "") + _ = json.Unmarshal(resData, errorRsp) + assert.Equal(t, constant.UnsupportedOperationErrorCode, errorRsp.ErrorCode) + }) +} + +func Test_parseInstanceOperationLibruntime(t *testing.T) { + convey.Convey("test parseInstanceOperationLibruntime", t, func() { + convey.Convey("baseline", func() { + op, name, data, _ := parseInstanceOperationLibruntime([]api.Arg{ + { + Type: 0, + Data: []byte("acquire#aaa"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + }, "") + convey.So(op, convey.ShouldNotBeNil) + convey.So(name, convey.ShouldEqual, "aaa") + convey.So(data, convey.ShouldBeNil) + }) + convey.Convey("error args", func() { + op, name, data, _ := parseInstanceOperationLibruntime([]api.Arg{ + { + Type: 0, + Data: []byte("acquire#aaa"), + ObjectID: "", + NestedObjectIDs: nil, + }, + }, "") + convey.So(op, convey.ShouldEqual, insOpAcquire) + convey.So(name, convey.ShouldEqual, "aaa") + convey.So(data, convey.ShouldBeNil) + }) + convey.Convey("error types", func() { + op, name, data, _ := parseInstanceOperationLibruntime([]api.Arg{ + { + Type: 1, + Data: []byte("acquire#aaa"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 1, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + }, "") + convey.So(op, convey.ShouldEqual, insOpUnknown) + convey.So(name, convey.ShouldEqual, "") + convey.So(data, convey.ShouldBeNil) + }) + convey.Convey("error operationArg", func() { + op, name, data, _ := parseInstanceOperationLibruntime([]api.Arg{ + { + Type: 0, + Data: []byte("acquire"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + }, "") + convey.So(op, convey.ShouldEqual, insOpUnknown) + convey.So(name, convey.ShouldEqual, "") + convey.So(data, convey.ShouldBeNil) + }) + }) +} + +func TestFaaSScheduler_processSchedulerProxySubscription(t *testing.T) { + convey.Convey("test processSchedulerProxySubscription", t, func() { + convey.Convey("baseline", func() { + count := 0 + p := ApplyFunc((*instancepool.PoolManager).HandleSchedulerManaged, + func(_ *instancepool.PoolManager, eventType registry.EventType, + insSpec *commonTypes.InstanceSpecification) { + count++ + }) + defer p.Reset() + schedulerCh := make(chan registry.SubEvent) + faasScheduler := &FaaSScheduler{ + schedulerCh: schedulerCh, + PoolManager: &instancepool.PoolManager{}, + } + go faasScheduler.processSchedulerProxySubscription() + schedulerCh <- registry.SubEvent{ + EventType: "aaa", + EventMsg: &commonTypes.InstanceSpecification{}, + } + time.Sleep(100 * time.Millisecond) + convey.So(count, convey.ShouldEqual, 1) + schedulerCh <- registry.SubEvent{ + EventType: "aaa", + EventMsg: "123", + } + time.Sleep(100 * time.Millisecond) + convey.So(count, convey.ShouldEqual, 1) + close(schedulerCh) + time.Sleep(100 * time.Millisecond) + convey.So(count, convey.ShouldEqual, 1) + }) + }) +} + +func TestFaaSScheduler_ProcessInstanceRequestLibruntime(t *testing.T) { + faasScheduler := &FaaSScheduler{} + convey.Convey("test ProcessInstanceRequestLibruntime", t, func() { + convey.Convey("baseline", func() { + defer ApplyFunc((*FaaSScheduler).handleInstanceAcquire, func(_ *FaaSScheduler, + targetName string, extraData []byte, + traceID string) *commonTypes.InstanceResponse { + return &commonTypes.InstanceResponse{ + ErrorCode: 111, + } + }).Reset() + resData, err := faasScheduler.ProcessInstanceRequestLibruntime([]api.Arg{ + { + Type: 0, + Data: []byte("acquire#aaa"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + }, "") + response := &commonTypes.InstanceResponse{} + _ = json.Unmarshal(resData, response) + convey.So(err, convey.ShouldBeNil) + convey.So(response.ErrorCode, convey.ShouldEqual, 111) + }) + }) + convey.Convey("test create instance", t, func() { + var createErr snerror.SNError + defer ApplyFunc((*instancepool.PoolManager).CreateInstance, func(_ *instancepool.PoolManager, + insCrtReq *types.InstanceCreateRequest) (*types.Instance, snerror.SNError) { + if createErr != nil { + return nil, createErr + } + return &types.Instance{}, nil + }).Reset() + defer ApplyFunc((*registry.Registry).GetFuncSpec, func(_ *registry.Registry, funcKey string) *types.FunctionSpecification { + if funcKey == "testFunc" { + return &types.FunctionSpecification{} + } + return nil + }).Reset() + resData, err := faasScheduler.ProcessInstanceRequestLibruntime([]api.Arg{ + { + Type: 0, + Data: []byte("create#aaa"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: []byte(""), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + }, "") + response := &commonTypes.InstanceResponse{} + _ = json.Unmarshal(resData, response) + convey.So(err, convey.ShouldBeNil) + convey.So(response, convey.ShouldNotBeNil) + convey.So(response.ErrorCode, convey.ShouldEqual, statuscode.FuncMetaNotFoundErrCode) + resData, err = faasScheduler.ProcessInstanceRequestLibruntime([]api.Arg{ + { + Type: 0, + Data: []byte("create#testFunc"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: []byte("wrong data"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + }, "") + _ = json.Unmarshal(resData, response) + convey.So(err, convey.ShouldBeNil) + convey.So(response, convey.ShouldNotBeNil) + convey.So(response.ErrorCode, convey.ShouldEqual, statuscode.StatusInternalServerError) + resData, err = faasScheduler.ProcessInstanceRequestLibruntime([]api.Arg{ + { + Type: 0, + Data: []byte("create#testFunc"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + }, "") + _ = json.Unmarshal(resData, response) + convey.So(err, convey.ShouldBeNil) + convey.So(response, convey.ShouldNotBeNil) + convey.So(response.ErrorCode, convey.ShouldEqual, constant.InsReqSuccessCode) + }) + convey.Convey("test delete instance", t, func() { + var deleteErr snerror.SNError + defer ApplyFunc((*instancepool.PoolManager).DeleteInstance, func(_ *instancepool.PoolManager, + instance *types.Instance) snerror.SNError { + if deleteErr != nil { + return deleteErr + } + return nil + }).Reset() + defer ApplyFunc((*registry.Registry).GetInstance, func(_ *registry.Registry, instanceID string) *types.Instance { + if instanceID == "testIns" { + return &types.Instance{} + } + return nil + }).Reset() + resData, err := faasScheduler.ProcessInstanceRequestLibruntime([]api.Arg{ + { + Type: 0, + Data: []byte("delete#aaa"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: []byte(""), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + }, "") + response := &commonTypes.InstanceResponse{} + _ = json.Unmarshal(resData, response) + convey.So(response, convey.ShouldNotBeNil) + convey.So(response.ErrorCode, convey.ShouldEqual, statuscode.InstanceNotFoundErrCode) + resData, err = faasScheduler.ProcessInstanceRequestLibruntime([]api.Arg{ + { + Type: 0, + Data: []byte("delete#testIns"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: []byte(""), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + }, "") + _ = json.Unmarshal(resData, response) + convey.So(err, convey.ShouldBeNil) + convey.So(response, convey.ShouldNotBeNil) + convey.So(response.ErrorCode, convey.ShouldEqual, constant.InsReqSuccessCode) + }) +} + +func TestHandleInstanceCreate(t *testing.T) { + tests := []struct { + name string + funcKey string + extraData []byte + traceID string + mockFuncSpec *types.FunctionSpecification + mockDataInfo *extraDataInfo + mockResSpec *resspeckey.ResourceSpecification + mockInstance *types.Instance + mockError snerror.SNError + }{ + { + name: "Function does not exist", + funcKey: "nonexistent-func", + extraData: []byte{}, + traceID: "trace1", + mockFuncSpec: nil, + }, + { + name: "Failed to parse extra data", + funcKey: "test-func", + extraData: []byte("invalid-data"), + traceID: "trace2", + mockFuncSpec: &types.FunctionSpecification{}, + mockError: snerror.New(1, "parse error"), + }, + { + name: "Failed to get resource specification", + funcKey: "test-func", + extraData: []byte("valid-data"), + traceID: "trace3", + mockFuncSpec: &types.FunctionSpecification{}, + mockDataInfo: &extraDataInfo{resourceData: []byte("resourceData"), invokeLabel: []byte("invokeLabel")}, + mockError: snerror.New(1, "resSpec error"), + }, + { + name: "Failed to create instance", + funcKey: "test-func", + extraData: []byte("valid-data"), + traceID: "trace4", + mockFuncSpec: &types.FunctionSpecification{}, + mockDataInfo: &extraDataInfo{resourceData: []byte("resourceData"), invokeLabel: []byte("invokeLabel")}, + mockError: snerror.New(1, "create instance error"), + }, + { + name: "Successfully created instance", + funcKey: "test-func", + extraData: []byte("valid-data"), + traceID: "trace5", + mockFuncSpec: &types.FunctionSpecification{}, + mockDataInfo: &extraDataInfo{designateInstanceName: "instanceName", resourceData: []byte("resourceData"), + invokeLabel: []byte("invokeLabel")}, + mockInstance: &types.Instance{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fs := &FaaSScheduler{ + PoolManager: &instancepool.PoolManager{}, + } + + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(registry.GlobalRegistry.GetFuncSpec, + func(funcKey string) *types.FunctionSpecification { + return tt.mockFuncSpec + }) + + patches.ApplyMethod(reflect.TypeOf(fs.PoolManager), "CreateInstance", + func(_ *instancepool.PoolManager, req *types.InstanceCreateRequest) (*types.Instance, snerror.SNError) { + if tt.mockError != nil { + return nil, tt.mockError + } + return tt.mockInstance, nil + }) + + result := fs.handleInstanceCreate(tt.funcKey, tt.extraData, nil, tt.traceID) + + assert.NotNil(t, result) + }) + } +} + +func TestFaaSScheduler_parseExtraData(t *testing.T) { + convey.Convey("test parseExtraData", t, func() { + convey.Convey("baseline", func() { + data := map[string][]byte{ + "instanceName": []byte("testInstanceName"), + "designateInstanceID": []byte("testInstanceID"), + "instanceCreateEvent": []byte("testCreateEvent"), + "resourcesData": []byte("testResourceData"), + "instanceCallerPodName": []byte("testPodName"), + "poolLabel": []byte("testPoolLabel"), + } + dataBytes, _ := json.Marshal(data) + dataInfo, err := parseExtraData(dataBytes) + convey.So(err, convey.ShouldBeNil) + convey.So(dataInfo, convey.ShouldNotBeEmpty) + convey.So(dataInfo.designateInstanceName, convey.ShouldEqual, "testInstanceName") + convey.So(dataInfo.designateInstanceID, convey.ShouldEqual, "testInstanceID") + }) + convey.Convey("invalid session config", func() { + data := map[string][]byte{ + "instanceSessionConfig": []byte(`{"sessionID":"","sessionTTL":10}`), + } + dataBytes, _ := json.Marshal(data) + _, err := parseExtraData(dataBytes) + convey.So(err.Code(), convey.ShouldEqual, statuscode.InstanceSessionInvalidErrCode) + }) + }) +} + +func TestGetResourceSpecification(t *testing.T) { + defaultFuncSpec := &types.FunctionSpecification{ + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 300, + Memory: 128, + EphemeralStorage: 500, + }, + } + + tests := []struct { + name string + resData []byte + labelData []byte + targetCPU int64 + targetMem int64 + targetEphemeralStorage int + targetLabel string + targetErr error + }{ + { + name: "normal", + resData: []byte("{\"CPU\": 500, \"Memory\": 512}"), + labelData: []byte("{\"X-Instance-Label\": \"aaaaa\"}"), + targetCPU: 500, + targetMem: 512, + targetEphemeralStorage: 500, + targetLabel: "aaaaa", + targetErr: nil, + }, + { + name: "no resData", + resData: []byte("{}"), + labelData: []byte("{\"X-Instance-Label\": \"aaaaa\"}"), + targetCPU: 300, + targetMem: 128, + targetEphemeralStorage: 500, + targetLabel: "aaaaa", + targetErr: nil, + }, + { + name: "unmarshal error", + resData: []byte("{\"CPU\": 500, \"Memory\": 512, \"test\": []}"), + labelData: []byte("{\"X-Instance-Label\": \"aaaaa\"}"), + targetCPU: 500, + targetMem: 512, + targetEphemeralStorage: 500, + targetLabel: "aaaaa", + targetErr: fmt.Errorf("expected int or string, but got []"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resSpec, err := getResourceSpecification(tt.resData, tt.labelData, defaultFuncSpec) + if err != nil { + assert.Equal(t, tt.targetErr.Error(), err.Error()) + return + } + assert.Equal(t, tt.targetCPU, resSpec.CPU) + assert.Equal(t, tt.targetMem, resSpec.Memory) + assert.Equal(t, tt.targetEphemeralStorage, resSpec.EphemeralStorage) + assert.Equal(t, tt.targetLabel, resSpec.InvokeLabel) + }) + } +} + +func TestFaaSScheduler_HandleRequestForward(t *testing.T) { + faasScheduler := &FaaSScheduler{ + allocRecord: sync.Map{}, + } + convey.Convey("Test HandleRequestForward", t, func() { + convey.Convey("IsGaryUpdating is false", func() { + rollout.GetGlobalRolloutHandler().IsGaryUpdating = false + _, _, shouldReply := faasScheduler.HandleRequestForward(InstanceOperation("acquire"), []api.Arg{}, "") + convey.So(shouldReply, convey.ShouldBeFalse) + }) + convey.Convey("acquire should forward,invoke failed and reply", func() { + args := []api.Arg{ + { + Type: 0, + Data: []byte("acquire#aaa"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + } + rollout.GetGlobalRolloutHandler().IsGaryUpdating = true + rollout.GetGlobalRolloutHandler().ForwardInstance = "instance" + rolloutRatio := &rollout.RolloutRatio{ + RolloutRatio: "100%", + } + ratio, _ := json.Marshal(rolloutRatio) + _ = rollout.GetGlobalRolloutHandler().ProcessRatioUpdate(ratio) + + defer ApplyFunc(rollout.InvokeByInstanceId, func(args []api.Arg, instanceID string, + traceID string) ([]byte, error) { + return nil, errors.New("invoke instance failed") + }).Reset() + _, err, shouldReply := faasScheduler.HandleRequestForward(InstanceOperation("acquire"), args, "") + convey.So(err.Error(), convey.ShouldContainSubstring, "invoke instance failed") + convey.So(shouldReply, convey.ShouldBeFalse) + }) + convey.Convey("acquire should forward,invoke success and reply", func() { + args := []api.Arg{ + { + Type: 0, + Data: []byte("acquire#aaa"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + } + rollout.GetGlobalRolloutHandler().IsGaryUpdating = true + rolloutRatio := &rollout.RolloutRatio{ + RolloutRatio: "100%", + } + ratio, _ := json.Marshal(rolloutRatio) + _ = rollout.GetGlobalRolloutHandler().ProcessRatioUpdate(ratio) + + defer ApplyFunc(rollout.InvokeByInstanceId, func(args []api.Arg, instanceID string, + traceID string) ([]byte, error) { + response := &commonTypes.InstanceResponse{} + data, _ := json.Marshal(response) + return data, nil + }).Reset() + _, err, shouldReply := faasScheduler.HandleRequestForward(InstanceOperation("acquire"), args, "") + convey.So(err, convey.ShouldBeNil) + convey.So(shouldReply, convey.ShouldBeTrue) + }) + convey.Convey("acquire should forward,invoke success but no instance available not reply", func() { + args := []api.Arg{ + { + Type: 0, + Data: []byte("acquire#aaa"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + } + rollout.GetGlobalRolloutHandler().IsGaryUpdating = true + rolloutRatio := &rollout.RolloutRatio{ + RolloutRatio: "100%", + } + ratio, _ := json.Marshal(rolloutRatio) + _ = rollout.GetGlobalRolloutHandler().ProcessRatioUpdate(ratio) + + defer ApplyFunc(rollout.InvokeByInstanceId, func(args []api.Arg, instanceID string, + traceID string) ([]byte, error) { + response := &commonTypes.InstanceResponse{ + ErrorMessage: "no instance available", + ErrorCode: statuscode.NoInstanceAvailableErrCode, + } + data, _ := json.Marshal(response) + return data, nil + }).Reset() + _, err, shouldReply := faasScheduler.HandleRequestForward(InstanceOperation("acquire"), args, "") + convey.So(err, convey.ShouldBeNil) + convey.So(shouldReply, convey.ShouldBeFalse) + }) + convey.Convey("retain should forward and not reply", func() { + args := []api.Arg{ + { + Type: 0, + Data: []byte("retain#aaa"), + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + { + Type: 0, + Data: nil, + ObjectID: "", + NestedObjectIDs: nil, + }, + } + defer ApplyFunc(rollout.InvokeByInstanceId, func(args []api.Arg, instanceID string, + traceID string) ([]byte, error) { + response := &commonTypes.InstanceResponse{} + data, _ := json.Marshal(response) + return data, nil + }).Reset() + rollout.GetGlobalRolloutHandler().IsGaryUpdating = true + _, _, shouldReply := faasScheduler.HandleRequestForward(InstanceOperation("retain"), args, "") + convey.So(shouldReply, convey.ShouldBeFalse) + }) + }) +} + +func TestFaaSScheduler_handleRollout(t *testing.T) { + faasScheduler := &FaaSScheduler{ + allocRecord: sync.Map{}, + } + defer func() { + selfregister.RegisterKey = "" + }() + convey.Convey("Test handleRollout", t, func() { + alloc1 := &types.InstanceAllocation{ + Instance: &types.Instance{ + FuncKey: "funcA", + }, + } + alloc2 := &types.InstanceAllocation{ + Instance: &types.Instance{ + FuncKey: "funcA", + }, + } + alloc3 := &types.InstanceAllocation{ + Instance: &types.Instance{ + FuncKey: "funcB", + }, + } + faasScheduler.allocRecord.Store("lease1", alloc1) + faasScheduler.allocRecord.Store("lease2", alloc2) + faasScheduler.allocRecord.Store("lease3", alloc3) + rsp := faasScheduler.handleRollout("instance1", "123") + convey.So(rsp.ErrorCode, convey.ShouldEqual, statuscode.InternalErrorCode) + config.GlobalConfig.EnableRollout = true + rsp = faasScheduler.handleRollout("instance1", "123") + convey.So(rsp.ErrorCode, convey.ShouldEqual, statuscode.InternalErrorCode) + config.GlobalConfig.SchedulerDiscovery = &types.SchedulerDiscovery{RegisterMode: types.RegisterTypeContend} + rsp = faasScheduler.handleRollout("instance1", "123") + convey.So(rsp.ErrorCode, convey.ShouldEqual, statuscode.InternalErrorCode) + selfregister.Registered = true + selfregister.RegisterKey = "testKey" + rsp = faasScheduler.handleRollout("instance1", "123") + convey.So(rsp.ErrorCode, convey.ShouldEqual, constant.InsReqSuccessCode) + convey.So(rsp.RegisterKey, convey.ShouldEqual, "testKey") + convey.So(len(rsp.AllocRecord["funcA"]), convey.ShouldEqual, 2) + convey.So(len(rsp.AllocRecord["funcB"]), convey.ShouldEqual, 1) + }) +} +func TestSyncAllocRecord(t *testing.T) { + tests := []struct { + name string + allocRecord map[string][]string + mockFuncSpec *types.FunctionSpecification + mockAcquireErr snerror.SNError + mockAllocResult *types.InstanceAllocation + expectedCount int + }{ + { + name: "Function does not exist", + allocRecord: map[string][]string{ + "nonexistent-func": {"instance1-allocID"}, + }, + mockFuncSpec: nil, + expectedCount: 0, + }, + { + name: "AcquireInstance returns error", + allocRecord: map[string][]string{ + "test-func": {"instance2-allocID"}, + }, + mockFuncSpec: &types.FunctionSpecification{}, + mockAcquireErr: snerror.New(1, "acquire error"), + expectedCount: 0, + }, + { + name: "Successfully acquire instance", + allocRecord: map[string][]string{ + "test-func": {"instance3-allocID"}, + }, + mockFuncSpec: &types.FunctionSpecification{}, + mockAllocResult: &types.InstanceAllocation{AllocationID: "alloc123"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fs := &FaaSScheduler{ + allocRecord: sync.Map{}, + PoolManager: &instancepool.PoolManager{}, + } + + patches := gomonkey.NewPatches() + defer patches.Reset() + defer ApplyMethod(reflect.TypeOf(registry.GlobalRegistry), "GetFuncSpec", + func(_ *registry.Registry, funcKey string) *types.FunctionSpecification { + return tt.mockFuncSpec + }).Reset() + + // Mock PoolManager.AcquireInstanceThread + patches.ApplyMethod( + reflect.TypeOf(fs.PoolManager), + "AcquireInstanceThread", + func(_ *instancepool.PoolManager, req *types.InstanceAcquireRequest) (*types.InstanceAllocation, snerror.SNError) { + return tt.mockAllocResult, tt.mockAcquireErr + }, + ) + + fs.syncAllocRecord(tt.allocRecord) + + actualCount := 0 + fs.allocRecord.Range(func(key, value interface{}) bool { + actualCount++ + return true + }) + assert.Equal(t, tt.expectedCount, actualCount, "allocRecord count mismatch") + }) + } +} + +func TestReacquireLease(t *testing.T) { + patches := gomonkey.NewPatches() + defer patches.Reset() + patches.ApplyFunc((*instancepool.PoolManager).AcquireInstanceThread, func(_ *instancepool.PoolManager, req *types.InstanceAcquireRequest) (*types.InstanceAllocation, snerror.SNError) { + return &types.InstanceAllocation{Instance: &types.Instance{ + FuncKey: "", + }}, nil + }) + patches.ApplyMethod(reflect.TypeOf(registry.GlobalRegistry), "GetFuncSpec", func(_ *registry.Registry, funcKey string) *types.FunctionSpecification { + return &types.FunctionSpecification{ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 0, + Memory: 0, + GpuMemory: 0, + EnableDynamicMemory: false, + CustomResources: "", + EnableTmpExpansion: false, + EphemeralStorage: 0, + CustomResourcesSpec: "", + }} + }) + + fs := &FaaSScheduler{ + allocRecord: sync.Map{}, + PoolManager: &instancepool.PoolManager{}, + } + resp := fs.handleInstanceBatchRetain("e58bd817-1132-4b5b-8000-00000000009c-thread5f0d3377-59", []byte("{\"e58bd817-1132-4b5b-8000-00000000009c-thread5f0d3377-59\":{\"avgProcTime\":307,\"functionKey\":\"12345678901234561234567890123456/0@functest@functest/latest\",\"isAbnormal\":false,\"maxProcTime\":323,\"procReqNum\":217,\"reacquireData\":[123,34,114,101,115,111,117,114,99,101,115,68,97,116,97,34,58,91,49,50,51,44,51,52,44,54,55,44,56,48,44,56,53,44,51,52,44,53,56,44,53,52,44,52,56,44,52,56,44,52,52,44,51,52,44,55,55,44,49,48,49,44,49,48,57,44,49,49,49,44,49,49,52,44,49,50,49,44,51,52,44,53,56,44,53,51,44,52,57,44,53,48,44,49,50,53,93,125]}}"), "aaaaa") + assert.Equal(t, len(resp.InstanceAllocSucceed), 1) +} diff --git a/yuanrong/pkg/functionscaler/healthcheck/healthcheck.go b/yuanrong/pkg/functionscaler/healthcheck/healthcheck.go new file mode 100644 index 0000000..bbfb6c1 --- /dev/null +++ b/yuanrong/pkg/functionscaler/healthcheck/healthcheck.go @@ -0,0 +1,127 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package healthcheck implement health check +package healthcheck + +import ( + "errors" + "net" + "net/http" + "os" + "time" + + "github.com/gin-gonic/gin" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" +) + +const ( + readTimeout = 180 * time.Second + writeTimeout = 180 * time.Second + healthCheckPort = "9994" +) + +// StartHealthCheck - +func StartHealthCheck(errChan chan<- error) error { + if !config.GlobalConfig.EnableHealthCheck { + return nil + } + if errChan == nil { + return errors.New("errChan is nil") + } + + router := createRouter() + + podIP := os.Getenv("POD_IP") + if net.ParseIP(podIP) == nil { + log.GetLogger().Errorf("failed to get pod ip, pod ip is %s", podIP) + return errors.New("failed to get pod ip") + } + addr := podIP + ":" + healthCheckPort + + httpServer := &http.Server{ + Handler: router, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + Addr: addr, + } + + go func() { + err := startServer(httpServer) + if err != nil { + log.GetLogger().Errorf("failed to startServer, err %s", err.Error()) + } + errChan <- err + }() + + return nil +} + +func startServer(httpServer *http.Server) error { + if config.GlobalConfig.HTTPSConfig == nil || !config.GlobalConfig.HTTPSConfig.HTTPSEnable { + err := httpServer.ListenAndServe() + if err != nil { + log.GetLogger().Errorf("failed to http ListenAndServe: %s", err.Error()) + } + return err + } + err := tls.InitTLSConfig(*config.GlobalConfig.HTTPSConfig) + if err != nil { + log.GetLogger().Errorf("failed to init the HTTPS config: %s", err.Error()) + return err + } + httpServer.TLSConfig = tls.GetClientTLSConfig() + err = httpServer.ListenAndServeTLS("", "") + if err != nil { + log.GetLogger().Errorf("failed to HTTPListenAndServeTLS: %s", err.Error()) + return err + } + return nil +} + +func createRouter() *gin.Engine { + gin.SetMode(gin.ReleaseMode) + router := gin.New() + router.Use(gin.Recovery()) + router.GET("/healthcheck", func(c *gin.Context) { + check(c.Writer, c.Request) + }) + return router +} + +func check(w http.ResponseWriter, r *http.Request) { + discoveryConfig := config.GlobalConfig.SchedulerDiscovery + if discoveryConfig != nil && discoveryConfig.RegisterMode == types.RegisterTypeContend { + if !selfregister.Registered { + log.GetLogger().Warnf("health check now, scheduler is not registered") + if config.GlobalConfig.EnableRollout && selfregister.IsRolloutObject { + log.GetLogger().Infof("health check now, scheduler is the rollout object") + w.WriteHeader(http.StatusOK) + } else { + log.GetLogger().Errorf("health check now, scheduler is not rollout object") + w.WriteHeader(http.StatusInternalServerError) + } + return + } + } + w.WriteHeader(http.StatusOK) + return +} diff --git a/yuanrong/pkg/functionscaler/healthcheck/healthcheck_test.go b/yuanrong/pkg/functionscaler/healthcheck/healthcheck_test.go new file mode 100644 index 0000000..04b6a25 --- /dev/null +++ b/yuanrong/pkg/functionscaler/healthcheck/healthcheck_test.go @@ -0,0 +1,158 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package healthcheck + +import ( + "errors" + "net/http" + "net/http/httptest" + "os" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" +) + +func TestStartHealthCheck(t *testing.T) { + config.GlobalConfig.EnableHealthCheck = true + convey.Convey("TestStartHealthCheck", t, func() { + os.Setenv("POD_IP", "127.0.0.1") + config.InitModuleConfig() + config.GlobalConfig.HTTPSConfig = &tls.InternalHTTPSConfig{ + HTTPSEnable: false, + } + errChan := make(chan error, 2) + err := StartHealthCheck(errChan) + convey.So(err, convey.ShouldBeNil) + time.Sleep(1 * time.Second) + }) +} + +func TestStartServer_WithHTTPS(t *testing.T) { + config.GlobalConfig.EnableHealthCheck = true + os.Setenv("POD_IP", "127.0.0.1") + config.InitModuleConfig() + config.GlobalConfig.HTTPSConfig = &tls.InternalHTTPSConfig{ + HTTPSEnable: true, + } + httpServer := &http.Server{ + Addr: ":443", + } + + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(tls.InitTLSConfig, func(httpsConfig tls.InternalHTTPSConfig) error { + return nil + }) + + mockedTLSConfig := &tls.InternalHTTPSConfig{} + patches.ApplyFunc(tls.GetClientTLSConfig, func() *tls.InternalHTTPSConfig { + return mockedTLSConfig + }) + + mockedError := errors.New("mocked ListenAndServeTLS error") + patches.ApplyMethod(reflect.TypeOf(httpServer), "ListenAndServeTLS", + func(_ *http.Server, certFile, keyFile string) error { + return mockedError + }) + + err := startServer(httpServer) + + assert.Equal(t, mockedError, err) +} + +func TestCheck(t *testing.T) { + req, err := http.NewRequest("GET", "/check", nil) + if err != nil { + t.Fatalf("err %v", err) + } + convey.Convey("Test check", t, func() { + convey.Convey("normal case", func() { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(check) + handler.ServeHTTP(rr, req) + expectedBody := "" + convey.So(rr.Code, convey.ShouldEqual, http.StatusOK) + convey.So(rr.Body.String(), convey.ShouldEqual, expectedBody) + }) + convey.Convey("rollout case", func() { + discoveryConfig := config.GlobalConfig.SchedulerDiscovery + rolloutConfig := config.GlobalConfig.EnableRollout + isRolloutObject := selfregister.IsRolloutObject + defer func() { + config.GlobalConfig.SchedulerDiscovery = discoveryConfig + config.GlobalConfig.EnableRollout = rolloutConfig + selfregister.IsRolloutObject = isRolloutObject + }() + config.GlobalConfig.SchedulerDiscovery = &types.SchedulerDiscovery{ + RegisterMode: types.RegisterTypeContend, + } + handler := http.HandlerFunc(check) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + convey.So(rr.Code, convey.ShouldEqual, http.StatusInternalServerError) + config.GlobalConfig.EnableRollout = true + selfregister.IsRolloutObject = true + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + convey.So(rr.Code, convey.ShouldEqual, http.StatusOK) + }) + }) + +} + +func TestStartHealthCheckFail(t *testing.T) { + convey.Convey("TestStartHealthCheckFail", t, func() { + config.InitModuleConfig() + config.GlobalConfig.HTTPSConfig = &tls.InternalHTTPSConfig{ + HTTPSEnable: false, + } + + err := StartHealthCheck(nil) + convey.So(err, convey.ShouldNotBeNil) + + os.Clearenv() + errChan := make(chan error, 1) + err = StartHealthCheck(errChan) + convey.So(err, convey.ShouldNotBeNil) + + os.Setenv("POD_IP", "123") + err = StartHealthCheck(errChan) + convey.So(err, convey.ShouldNotBeNil) + + os.Setenv("POD_IP", "1.1.1.1") + err = StartHealthCheck(errChan) + convey.So(err, convey.ShouldBeNil) + err = <-errChan + convey.So(err, convey.ShouldNotBeNil) + + config.GlobalConfig.HTTPSConfig.HTTPSEnable = true + err = StartHealthCheck(errChan) + convey.So(err, convey.ShouldBeNil) + err = <-errChan + convey.So(err, convey.ShouldNotBeNil) + }) +} diff --git a/yuanrong/pkg/functionscaler/httpserver/httpserver.go b/yuanrong/pkg/functionscaler/httpserver/httpserver.go new file mode 100644 index 0000000..dfbd4fd --- /dev/null +++ b/yuanrong/pkg/functionscaler/httpserver/httpserver.go @@ -0,0 +1,177 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package httpserver - +package httpserver + +import ( + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "os" + "time" + + "github.com/valyala/fasthttp" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/logger/log" + commonTls "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/functionscaler" + "yuanrong/pkg/functionscaler/config" +) + +const ( + defaultReadBufferSize = 1 * 1024 + defaultMaxRequestBodySize = 1 * 1024 * 1024 + defaultServerTimeout = 900 * time.Second + invokePath = "/invoke" +) + +// StartHTTPServer - +func StartHTTPServer(errChan chan<- error) (*fasthttp.Server, error) { + fastServer := &fasthttp.Server{ + Handler: route, + TLSConfig: getTLSConfig(), + ReadBufferSize: defaultReadBufferSize, + ReadTimeout: defaultServerTimeout, + WriteTimeout: defaultServerTimeout, + MaxRequestBodySize: defaultMaxRequestBodySize, + } + if config.GlobalConfig.HTTPSConfig.HTTPSEnable { + if err := commonTls.InitTLSConfig(*config.GlobalConfig.HTTPSConfig); err != nil { + return nil, fmt.Errorf("init HTTPS config error: %s", err.Error()) + } + } + go func() { + err := startServer(fastServer) + if err != nil { + log.GetLogger().Errorf("failed to start http server, err %s", err.Error()) + } + errChan <- err + }() + return fastServer, nil +} +func getTLSConfig() *tls.Config { + if !config.GlobalConfig.HTTPSConfig.HTTPSEnable { + return nil + } + tlsConfig := commonTls.GetClientTLSConfig() + if tlsConfig != nil { + tlsConfig.NextProtos = []string{"http/1.1"} + } + return tlsConfig +} + +func startServer(httpServer *fasthttp.Server) error { + podIP := os.Getenv("POD_IP") + if net.ParseIP(podIP) == nil { + log.GetLogger().Errorf("failed to get pod ip, pod ip is %s", podIP) + return errors.New("failed to get pod ip") + } + serverAddr := fmt.Sprintf("%s:%s", podIP, config.GlobalConfig.ModuleConfig.ServicePort) + if config.GlobalConfig.HTTPSConfig.HTTPSEnable { + log.GetLogger().Infof("start to listen the https request on addr: %s", serverAddr) + if err := fastHTTPListenAndServeTLS(serverAddr, httpServer); err != nil { + log.GetLogger().Errorf("failed to start the HTTPS server: %s", err.Error()) + return err + } + return nil + } + log.GetLogger().Infof("start to listen the http request on addr: %s", serverAddr) + err := httpServer.ListenAndServe(serverAddr) + if err != nil { + log.GetLogger().Errorf("failed to start the HTTP server: %s", err.Error()) + return err + } + return nil +} + +func fastHTTPListenAndServeTLS(addr string, server *fasthttp.Server) error { + listener, err := net.Listen("tcp4", addr) + if err != nil { + return err + } + if server == nil || server.TLSConfig == nil { + return errors.New("server or tls config is nil") + } + tlsListener := tls.NewListener(listener, server.TLSConfig) + if err = server.Serve(tlsListener); err != nil { + return err + } + return nil +} + +func route(ctx *fasthttp.RequestCtx) { + err := auth(ctx) + if err != nil { + ctx.SetStatusCode(http.StatusUnauthorized) + log.GetLogger().Errorf("failed to check auth, error: %s", err.Error()) + return + } + path := string(ctx.Path()) + switch path { + case invokePath: + invokeHandler(ctx) + default: + ctx.SetStatusCode(http.StatusInternalServerError) + log.GetLogger().Errorf("unsupported http request path %s", path) + } + return +} + +func auth(ctx *fasthttp.RequestCtx) error { + if !config.GlobalConfig.AuthenticationEnable { + return nil + } + sign := string(ctx.Request.Header.Peek(constant.HeaderAuthorization)) + timestamp := string(ctx.Request.Header.Peek(constant.HeaderAuthTimestamp)) + err := localauth.AuthCheckLocally(config.GlobalConfig.LocalAuth.AKey, config.GlobalConfig.LocalAuth.SKey, sign, + timestamp, config.GlobalConfig.LocalAuth.Duration) + if err != nil { + return err + } + return nil +} + +func invokeHandler(ctx *fasthttp.RequestCtx) { + traceID := string(ctx.Request.Header.Peek(constant.HeaderTraceID)) + reqBody := ctx.Request.Body() + var args []api.Arg + err := json.Unmarshal(reqBody, &args) + if err != nil { + ctx.SetStatusCode(http.StatusInternalServerError) + log.GetLogger().Errorf("unmarshl request body error, err %s", err.Error()) + return + } + if functionscaler.GetGlobalScheduler() == nil { + ctx.SetStatusCode(http.StatusInternalServerError) + log.GetLogger().Errorf("scheduler is nil") + return + } + respBody, err := functionscaler.GetGlobalScheduler().ProcessInstanceRequestLibruntime(args, traceID) + if err != nil { + ctx.SetStatusCode(http.StatusInternalServerError) + log.GetLogger().Errorf("marshl response body, err %s", err.Error()) + return + } + ctx.SetStatusCode(http.StatusOK) + ctx.Response.SetBody(respBody) +} diff --git a/yuanrong/pkg/functionscaler/httpserver/httpserver_test.go b/yuanrong/pkg/functionscaler/httpserver/httpserver_test.go new file mode 100644 index 0000000..d897a7f --- /dev/null +++ b/yuanrong/pkg/functionscaler/httpserver/httpserver_test.go @@ -0,0 +1,272 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package httpserver + +import ( + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "os" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/localauth" + commtls "yuanrong/pkg/common/faas_common/tls" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/types" +) + +type mockListener struct{} + +func (m *mockListener) Accept() (net.Conn, error) { + return nil, fmt.Errorf("failed to accept") +} + +func (m *mockListener) Close() error { + return nil +} +func (m *mockListener) Addr() net.Addr { + return nil +} + +func TestStartHTTPServer(t *testing.T) { + rawConfig := config.GlobalConfig + defer func() { + config.GlobalConfig = rawConfig + }() + convey.Convey("TestStartHTTPServer", t, func() { + os.Setenv("POD_IP", "127.0.0.1") + config.GlobalConfig = types.Configuration{ + HTTPSConfig: &commtls.InternalHTTPSConfig{ + HTTPSEnable: false}, + ModuleConfig: &types.ModuleConfig{ServicePort: "8889"}, + } + errChan := make(chan error, 1) + _, err := StartHTTPServer(errChan) + convey.So(err, convey.ShouldBeNil) + time.Sleep(1 * time.Second) + }) + + convey.Convey("TestStartHTTPSServer", t, func() { + os.Setenv("POD_IP", "127.0.0.1") + config.GlobalConfig = types.Configuration{ + HTTPSConfig: &commtls.InternalHTTPSConfig{ + HTTPSEnable: true}, + ModuleConfig: &types.ModuleConfig{ServicePort: "8899"}, + } + defer gomonkey.ApplyFunc(commtls.InitTLSConfig, func(config commtls.InternalHTTPSConfig) error { + return nil + }).Reset() + defer gomonkey.ApplyFunc(commtls.GetClientTLSConfig, func() *tls.Config { + return &tls.Config{} + }).Reset() + errChan := make(chan error, 1) + _, err := StartHTTPServer(errChan) + convey.So(err, convey.ShouldBeNil) + time.Sleep(1 * time.Second) + }) + +} + +func TestStartServer_InvalidPodIP(t *testing.T) { + originalPodIP := os.Getenv("POD_IP") + defer os.Setenv("POD_IP", originalPodIP) + + os.Setenv("POD_IP", "invalid_ip") + + httpServer := &fasthttp.Server{} + + patches := gomonkey.NewPatches() + defer patches.Reset() + + err := startServer(httpServer) + + assert.NotNil(t, err) + assert.Equal(t, "failed to get pod ip", err.Error()) +} + +func TestStartServer_FastHTTPListenAndServeTLS_Error(t *testing.T) { + os.Setenv("POD_IP", "127.0.0.1") + defer os.Unsetenv("POD_IP") + + httpServer := &fasthttp.Server{} + config.GlobalConfig = types.Configuration{ + HTTPSConfig: &commtls.InternalHTTPSConfig{ + HTTPSEnable: true}, + ModuleConfig: &types.ModuleConfig{ServicePort: "8080"}, + } + + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(net.Listen, func(network, address string) (net.Listener, error) { + return nil, fmt.Errorf("err") + }) + + err := startServer(httpServer) + + assert.NotNil(t, err) +} + +func TestStartServer_ListenAndServe_Error(t *testing.T) { + os.Setenv("POD_IP", "127.0.0.1") + defer os.Unsetenv("POD_IP") + + httpServer := &fasthttp.Server{} + + config.GlobalConfig = types.Configuration{ + HTTPSConfig: &commtls.InternalHTTPSConfig{ + HTTPSEnable: false}, + ModuleConfig: &types.ModuleConfig{ServicePort: "8080"}, + } + + mockError := errors.New("mocked ListenAndServe error") + + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyMethod(reflect.TypeOf(httpServer), "ListenAndServe", func(_ *fasthttp.Server, addr string) error { + return mockError + }) + + err := startServer(httpServer) + + assert.NotNil(t, err) + assert.Equal(t, mockError, err) +} + +func TestRout(t *testing.T) { + rawConfig := config.GlobalConfig + defer func() { + config.GlobalConfig = rawConfig + }() + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + gomonkey.ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + gomonkey.ApplyFunc(etcd3.GetCAEMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + convey.Convey("TestRout", t, func() { + convey.Convey("auth failed", func() { + config.GlobalConfig = types.Configuration{ + AuthenticationEnable: true, + } + defer gomonkey.ApplyFunc(localauth.AuthCheckLocally, func(ak string, sk string, + requestSign string, timestamp string, duration int) error { + return errors.New("auth failed") + }).Reset() + ctx := &fasthttp.RequestCtx{} + route(ctx) + convey.So(ctx.Response.StatusCode(), convey.ShouldEqual, http.StatusUnauthorized) + }) + convey.Convey("path error", func() { + config.GlobalConfig = types.Configuration{} + ctx := &fasthttp.RequestCtx{} + ctx.Request.URI().SetPath("/acquire") + route(ctx) + convey.So(ctx.Response.StatusCode(), convey.ShouldEqual, http.StatusInternalServerError) + }) + convey.Convey("invoke unmarshal body error ", func() { + config.GlobalConfig = types.Configuration{} + ctx := &fasthttp.RequestCtx{} + ctx.Request.URI().SetPath(invokePath) + ctx.Request.SetBody([]byte("aaa")) + route(ctx) + convey.So(ctx.Response.StatusCode(), convey.ShouldEqual, http.StatusInternalServerError) + }) + convey.Convey("invoke scheduler is nil ", func() { + config.GlobalConfig = types.Configuration{} + ctx := &fasthttp.RequestCtx{} + ctx.Request.URI().SetPath(invokePath) + args := []api.Arg{{Type: 1, Data: []byte("aaa")}} + body, _ := json.Marshal(args) + ctx.Request.SetBody(body) + route(ctx) + convey.So(ctx.Response.StatusCode(), convey.ShouldEqual, http.StatusInternalServerError) + }) + convey.Convey("invoke ProcessInstanceRequestLibruntime success ", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&functionscaler.FaaSScheduler{}), + "ProcessInstanceRequestLibruntime", func(_ *functionscaler.FaaSScheduler, + args []api.Arg, traceID string) ([]byte, error) { + return json.Marshal(&commonTypes.InstanceResponse{}) + }).Reset() + defer gomonkey.ApplyFunc((*etcd3.EtcdWatcher).StartList, func(ew *etcd3.EtcdWatcher) { + ew.ResultChan <- &etcd3.Event{ + Type: etcd3.SYNCED, + Key: "", + Value: nil, + PrevValue: nil, + Rev: 0, + ETCDType: "", + } + }).Reset() + defer gomonkey.ApplyFunc((*registry.FaasSchedulerRegistry).WaitForETCDList, func() {}).Reset() + config.GlobalConfig = types.Configuration{} + ctx := &fasthttp.RequestCtx{} + ctx.Request.URI().SetPath(invokePath) + args := []api.Arg{{Type: 1, Data: []byte("aaa")}} + stopCh := make(chan struct{}) + registry.InitRegistry(stopCh) + functionscaler.InitGlobalScheduler(stopCh) + body, _ := json.Marshal(args) + ctx.Request.SetBody(body) + route(ctx) + convey.So(ctx.Response.StatusCode(), convey.ShouldEqual, http.StatusOK) + }) + }) +} + +func TestFastHTTPListenAndServeTLS(t *testing.T) { + server := fasthttp.Server{ + TLSConfig: &tls.Config{}, + } + convey.Convey("FastHTTPListenAndServeTLS failed", t, func() { + patch := gomonkey.ApplyFunc(net.Listen, func(network, address string) (net.Listener, error) { + return &mockListener{}, errors.New("listen fail") + }) + defer patch.Reset() + err := fastHTTPListenAndServeTLS("123", &server) + convey.So(err.Error(), convey.ShouldEqual, "listen fail") + }) + convey.Convey("FastHTTPListenAndServeTLS server is nil", t, func() { + patch := gomonkey.ApplyFunc(net.Listen, func(network, address string) (net.Listener, error) { + return &mockListener{}, nil + }) + defer patch.Reset() + err := fastHTTPListenAndServeTLS("123", nil) + convey.So(err.Error(), convey.ShouldEqual, "server or tls config is nil") + }) +} diff --git a/yuanrong/pkg/functionscaler/instancepool/componentset.go b/yuanrong/pkg/functionscaler/instancepool/componentset.go new file mode 100644 index 0000000..b19ad83 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/componentset.go @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import "k8s.io/api/core/v1" + +type container string + +const ( + containerDelegate container = "delegate-container" + containerRuntimeManager container = "runtime-manager" + containerFunctionAgent container = "function-agent" + containerFunctionAgentInit container = "function-agent-init" +) + +type volumeBuilder struct { + volumes []v1.Volume + mounts map[container][]v1.VolumeMount +} + +func (vc *volumeBuilder) addVolume(volume v1.Volume) { + vc.volumes = append(vc.volumes, volume) +} + +func (vc *volumeBuilder) addVolumeMount(name container, mount v1.VolumeMount) { + vc.mounts[name] = append(vc.mounts[name], mount) +} + +func newVolumeBuilder() *volumeBuilder { + return &volumeBuilder{ + mounts: make(map[container][]v1.VolumeMount), + } +} + +type envBuilder struct { + envs map[container][]v1.EnvVar +} + +func (b *envBuilder) addEnvVar(name container, env v1.EnvVar) { + b.envs[name] = append(b.envs[name], env) +} + +func newEnvBuilder() *envBuilder { + return &envBuilder{ + envs: make(map[container][]v1.EnvVar), + } +} diff --git a/yuanrong/pkg/functionscaler/instancepool/container_tool.go b/yuanrong/pkg/functionscaler/instancepool/container_tool.go new file mode 100644 index 0000000..f52b230 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/container_tool.go @@ -0,0 +1,77 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "encoding/json" + + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + // DefaultInstanceLabel is empty + DefaultInstanceLabel = "" + // DefaultDataVolumeName default dataVolume name + DefaultDataVolumeName = "data-volume" + + // RaspDefaultInitialDelaySeconds - + RaspDefaultInitialDelaySeconds = 10 + // RaspDefaultTimeoutSeconds - + RaspDefaultTimeoutSeconds = 5 + // RaspDefaultPeriodSeconds - + RaspDefaultPeriodSeconds = 20 + // RaspDefaultSuccessThreshold - + RaspDefaultSuccessThreshold = 1 + // RaspDefaultFailureThreshold - + RaspDefaultFailureThreshold = 3 + // RaspDefaultCPU - + RaspDefaultCPU = 300 + // RaspDefaultMemory - + RaspDefaultMemory = 500 + // RaspInitDefaultCPU - + RaspInitDefaultCPU = 100 + // RaspInitDefaultMemory - + RaspInitDefaultMemory = 100 +) + +func sideCarAdd(funcSpec *types.FunctionSpecification) ([]byte, error) { + var sideCars []types.DelegateContainerSideCarConfig + + if utils.IsNeedRaspSideCar(funcSpec) { + sideCars = append(sideCars, makeRaspContainer(funcSpec)) + } + + configData, err := json.Marshal(sideCars) + if err != nil { + return nil, err + } + return configData, nil +} + +func initContainerAdd(funcSpec *types.FunctionSpecification) ([]byte, error) { + var initContainers []types.DelegateInitContainerConfig + if utils.IsNeedRaspSideCar(funcSpec) { + initContainers = []types.DelegateInitContainerConfig{makeRaspInitContainer(funcSpec)} + } + configData, err := json.Marshal(initContainers) + if err != nil { + return nil, err + } + return configData, nil +} diff --git a/yuanrong/pkg/functionscaler/instancepool/instance_operation_adapter.go b/yuanrong/pkg/functionscaler/instancepool/instance_operation_adapter.go new file mode 100644 index 0000000..b2ac493 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/instance_operation_adapter.go @@ -0,0 +1,98 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package instancepool + +import ( + "strings" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/signalmanager" + "yuanrong/pkg/functionscaler/types" +) + +// CreateInstance - +func CreateInstance(request createInstanceRequest) (*types.Instance, error) { + if config.GlobalConfig.InstanceOperationBackend == constant.BackendTypeFG { + return createInstanceForFG(request) + } + return createInstanceForKernel(request) +} + +// DeleteInstance - +func DeleteInstance(funcSpec *types.FunctionSpecification, faasManagerInfo faasManagerInfo, + instance *types.Instance) error { + if config.GlobalConfig.InstanceOperationBackend == constant.BackendTypeFG { + return deleteInstanceForFG(funcSpec, faasManagerInfo, instance) + } + return deleteInstanceForKernel(funcSpec, faasManagerInfo, instance) +} + +// DeleteInstanceByID - +func DeleteInstanceByID(instanceID, funcKey string) error { + if config.GlobalConfig.InstanceOperationBackend == constant.BackendTypeFG { + return deleteInstanceByIDForFG(instanceID, funcKey) + } + return deleteInstanceByIDForKernel(instanceID, funcKey) +} + +// DeleteUnexpectInstance - +func DeleteUnexpectInstance(parentID, instanceID, funcKey string, logger api.FormatLogger) { + if parentID != selfregister.SelfInstanceID && + selfregister.GlobalSchedulerProxy.Contains(parentID) || + parentID == constant.WorkerManagerApplier || + strings.HasPrefix(parentID, constant.FunctionTaskApplier) || + parentID == constant.ASBResApplier || + parentID == constant.StaticInstanceApplier { + return + } + logger.Warnf("instance is belong to this scheduler, but not found function meta, start to delete instance.") + err := DeleteInstanceByID(instanceID, funcKey) + if err != nil { + logger.Errorf("failed to delete instance, err: %v", err) + } +} + +// SignalInstance - +func SignalInstance(instance *types.Instance, signal int) { + if config.GlobalConfig.InstanceOperationBackend == constant.BackendTypeFG { + signalInstanceForFG(instance, signal) + return + } + signalmanager.GetSignalManager().SignalInstance(instance, signal) +} + +func buildInstance(instanceID string, request createInstanceRequest) *types.Instance { + return &types.Instance{ + InstanceType: request.instanceType, + ResKey: request.resKey, + InstanceID: instanceID, + InstanceName: request.instanceName, + ParentID: selfregister.SelfInstanceID, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + Msg: "", + }, + FuncKey: request.funcSpec.FuncKey, + FuncSig: request.funcSpec.FuncMetaSignature, + ConcurrentNum: request.funcSpec.InstanceMetaData.ConcurrentNum, + } +} diff --git a/yuanrong/pkg/functionscaler/instancepool/instance_operation_adapter_test.go b/yuanrong/pkg/functionscaler/instancepool/instance_operation_adapter_test.go new file mode 100644 index 0000000..cd0fd3e --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/instance_operation_adapter_test.go @@ -0,0 +1,203 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package instancepool + +import ( + "fmt" + + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/wait" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + mockUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/workermanager" +) + +func TestCreateInstance(t *testing.T) { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(createInstanceForFG, func(request createInstanceRequest) (*types.Instance, error) { + return &types.Instance{InstanceID: "instance-fg"}, nil + }), + gomonkey.ApplyFunc(createInstanceForKernel, func(request createInstanceRequest) (*types.Instance, error) { + return &types.Instance{InstanceID: "instance-kernel"}, nil + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + config.GlobalConfig.InstanceOperationBackend = constant.BackendTypeFG + ins1, err := CreateInstance(createInstanceRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance-fg", ins1.InstanceID) + config.GlobalConfig.InstanceOperationBackend = constant.BackendTypeKernel + ins2, err := CreateInstance(createInstanceRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance-kernel", ins2.InstanceID) +} + +func TestDeleteInstance(t *testing.T) { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(deleteInstanceForFG, func(funcSpec *types.FunctionSpecification, faasManagerInfo faasManagerInfo, + instance *types.Instance) error { + return nil + }), + gomonkey.ApplyFunc(deleteInstanceForKernel, func(funcSpec *types.FunctionSpecification, faasManagerInfo faasManagerInfo, + instance *types.Instance) error { + return nil + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + config.GlobalConfig.InstanceOperationBackend = constant.BackendTypeFG + err := DeleteInstance(&types.FunctionSpecification{}, faasManagerInfo{}, &types.Instance{}) + assert.Nil(t, err) + config.GlobalConfig.InstanceOperationBackend = constant.BackendTypeKernel + err = DeleteInstance(&types.FunctionSpecification{}, faasManagerInfo{}, &types.Instance{}) + assert.Nil(t, err) +} + +func TestDeleteInstanceRetry(t *testing.T) { + cnt := 0 + + SetGlobalSdkClient(&mockUtils.FakeLibruntimeSdkClient{}) + defer gomonkey.ApplyMethod(globalSdkClient, "Kill", func(_ api.LibruntimeAPI, instanceID string, signal int, payload []byte) error { + cnt++ + return fmt.Errorf("error kill") + }).Reset() + config.GlobalConfig.InstanceOperationBackend = constant.BackendTypeKernel + coldStartBackoff = wait.Backoff{ + Duration: 10 * time.Millisecond, + Factor: 1, + Jitter: 0.1, + Steps: 2, + Cap: 15 * time.Millisecond, + } + err := DeleteInstance(&types.FunctionSpecification{}, faasManagerInfo{}, &types.Instance{}) + time.Sleep(50 * time.Millisecond) + assert.Equal(t, "error kill", err.Error()) + assert.Equal(t, 3, cnt) +} + +func TestDeleteUnexpectInstance(t *testing.T) { + deleteInstanceCalled := false + var deleteInstanceInstanceID string + var deleteInstanceFuncKey string + + defer gomonkey.ApplyFunc(DeleteInstanceByID, func(instanceID, funcKey string) error { + deleteInstanceCalled = true + deleteInstanceInstanceID = instanceID + deleteInstanceFuncKey = funcKey + return nil + }).Reset() + + originalInstanceIDSelf := selfregister.SelfInstanceID + selfregister.SelfInstanceID = "self-id" + defer func() { + selfregister.SelfInstanceID = originalInstanceIDSelf + }() + + parentID := "parent-id" + instanceID := "test-instance-id" + funcKey := "test-func-key" + + parentIDContains := false + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyMethod(reflect.TypeOf(&selfregister.SchedulerProxy{}), "Contains", + func(_ interface{}, id string) bool { + return parentIDContains + }) + logger := log.GetLogger() + parentIDContains = true + deleteInstanceCalled = false + DeleteUnexpectInstance(parentID, instanceID, funcKey, logger) + + assert.False(t, deleteInstanceCalled) + + parentIDContains = false + deleteInstanceCalled = false + DeleteUnexpectInstance(parentID, instanceID, funcKey, logger) + + assert.True(t, deleteInstanceCalled) + assert.Equal(t, instanceID, deleteInstanceInstanceID) + assert.Equal(t, funcKey, deleteInstanceFuncKey) + + parentID = selfregister.SelfInstanceID + deleteInstanceCalled = false + DeleteUnexpectInstance(parentID, instanceID, funcKey, logger) + + assert.True(t, deleteInstanceCalled) + assert.Equal(t, instanceID, deleteInstanceInstanceID) + assert.Equal(t, funcKey, deleteInstanceFuncKey) + + parentID = constant.WorkerManagerApplier + deleteInstanceCalled = false + DeleteUnexpectInstance(parentID, instanceID, funcKey, logger) + + assert.False(t, deleteInstanceCalled) + + parentID = constant.FunctionTaskApplier + "123" + deleteInstanceCalled = false + DeleteUnexpectInstance(parentID, instanceID, funcKey, logger) + + assert.False(t, deleteInstanceCalled) + + parentID = constant.ASBResApplier + deleteInstanceCalled = false + DeleteUnexpectInstance(parentID, instanceID, funcKey, logger) + + assert.False(t, deleteInstanceCalled) +} + +func TestDeleteInstanceByID(t *testing.T) { + scaleDownNum := 0 + defer gomonkey.ApplyFunc(workermanager.ScaleDownInstance, func(instanceID, functionKey, traceID string) error { + scaleDownNum++ + return nil + }).Reset() + SetGlobalSdkClient(&mockUtils.FakeLibruntimeSdkClient{}) + config.GlobalConfig.InstanceOperationBackend = constant.BackendTypeFG + err := DeleteInstanceByID("testInsID", "testFuncKey") + assert.Nil(t, err) + + config.GlobalConfig.InstanceOperationBackend = constant.BackendTypeKernel + err = DeleteInstanceByID("testInsID", "testFuncKey") + assert.Nil(t, err) + SetGlobalSdkClient(nil) +} + +func TestSignalInstance(t *testing.T) { + config.GlobalConfig.InstanceOperationBackend = constant.BackendTypeFG + SignalInstance(&types.Instance{}, 0) + assert.Equal(t, 0, 0) +} diff --git a/yuanrong/pkg/functionscaler/instancepool/instance_operation_fg.go b/yuanrong/pkg/functionscaler/instancepool/instance_operation_fg.go new file mode 100644 index 0000000..49e643d --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/instance_operation_fg.go @@ -0,0 +1,90 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package instancepool + +import ( + "time" + + "go.uber.org/zap" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/uuid" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/workermanager" +) + +func createInstanceForFG(request createInstanceRequest) (*types.Instance, error) { + createInstanceTraceID := uuid.New().String() + logger := log.GetLogger().With(zap.Any("funcKey", request.funcSpec.FuncKey), + zap.Any("traceID", createInstanceTraceID)) + createBeginTime := time.Now() + ScaleOutParam := &workermanager.ScaleUpParam{ + TraceID: createInstanceTraceID, + FunctionKey: request.funcSpec.FuncKey, + Timeout: request.createTimeout, + CPU: int(request.resKey.CPU), + Memory: int(request.resKey.Memory), + } + wmInstance, err := workermanager.ScaleUpInstance(ScaleOutParam) + if err != nil { + createErr := err + logger.Errorf("createErr is %v , instance %s, cost: %s", createErr, wmInstance, time.Since(createBeginTime)) + return nil, createErr + } + instance := buildInstance(wmInstance.InstanceID, request) + instance.InstanceIP = wmInstance.IP + instance.InstancePort = wmInstance.Port + instance.NodeIP = wmInstance.OwnerIP + instance.ResKey.CPU = wmInstance.Resource.Runtime.CPULimit + instance.ResKey.Memory = wmInstance.Resource.Runtime.MemoryLimit + instance.NodePort = constant.BusProxyHTTPPort + logger.Infof("succeed to create instance %s, cost: %s, instance: %+v", wmInstance.InstanceID, + time.Since(createBeginTime), instance) + return instance, nil +} + +func deleteInstanceForFG(funcSpec *types.FunctionSpecification, faasManagerInfo faasManagerInfo, + instance *types.Instance) error { + log.GetLogger().Debugf("start to delete instance %s for function %s", + instance.InstanceID, funcSpec.FuncKey) + err := workermanager.ScaleDownInstance(instance.InstanceID, funcSpec.FuncKey, "") + if err != nil { + log.GetLogger().Errorf("failed to delete instance %s for function %s,err: %s ", + instance.InstanceID, funcSpec.FuncKey, err.Error()) + return err + } + log.GetLogger().Infof("succeed to delete instance %s for function %s", + instance.InstanceID, funcSpec.FuncKey) + return nil +} + +func deleteInstanceByIDForFG(instanceID, funcKey string) error { + log.GetLogger().Debugf("start to delete instance %s for function %s", + instanceID, funcKey) + err := workermanager.ScaleDownInstance(instanceID, funcKey, "") + if err != nil { + log.GetLogger().Errorf("failed to delete instance %s for function %s,err: %s ", + instanceID, funcKey, err.Error()) + return err + } + log.GetLogger().Infof("succeed to delete instance %s for function %s", + instanceID, funcKey) + return nil +} + +func signalInstanceForFG(instance *types.Instance, signal int) {} diff --git a/yuanrong/pkg/functionscaler/instancepool/instance_operation_fg_test.go b/yuanrong/pkg/functionscaler/instancepool/instance_operation_fg_test.go new file mode 100644 index 0000000..50bbe3c --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/instance_operation_fg_test.go @@ -0,0 +1,129 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package instancepool + +import ( + "context" + "errors" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/resspeckey" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/workermanager" +) + +func TestCreateInstanceForFG(t *testing.T) { + funcSpec := &types.FunctionSpecification{ + FuncKey: "testFunction", + FuncCtx: context.TODO(), + FuncMetaData: commonTypes.FuncMetaData{ + Handler: "myHandler", + }, + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 500, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 100, + ConcurrentNum: 100, + }, + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Initializer: commonTypes.Initializer{ + Handler: "myInitializer", + }, + VpcConfig: &commonTypes.VpcConfig{}, + }, + } + resKey := resspeckey.ConvertToResSpecKey(resspeckey.ConvertResourceMetaDataToResSpec(funcSpec.ResourceMetaData)) + instanceBuilder := func(instanceID string) *types.Instance { + return &types.Instance{ + InstanceID: instanceID, + ResKey: resKey, + } + } + request := createInstanceRequest{ + funcSpec: funcSpec, + resKey: resKey, + instanceBuilder: instanceBuilder, + } + convey.Convey("Test createInstanceForFG", t, func() { + convey.Convey("create success", func() { + defer gomonkey.ApplyFunc(workermanager.ScaleUpInstance, + func(scaleUpParam *workermanager.ScaleUpParam) (*types.WmInstance, error) { + return &types.WmInstance{}, nil + }).Reset() + instance, err := createInstanceForFG(request) + convey.So(instance, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("create failed", func() { + defer gomonkey.ApplyFunc(workermanager.ScaleUpInstance, + func(scaleUpParam *workermanager.ScaleUpParam) (*types.WmInstance, error) { + return nil, errors.New("create failed") + }).Reset() + instance, err := createInstanceForFG(request) + convey.So(instance, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestDeleteInstanceForFG(t *testing.T) { + convey.Convey("Test deleteInstanceForFG", t, func() { + convey.Convey("delete success", func() { + defer gomonkey.ApplyFunc(workermanager.ScaleDownInstance, + func(instanceID, functionKey, traceID string) error { + return nil + }).Reset() + err := deleteInstanceForFG(&types.FunctionSpecification{}, faasManagerInfo{}, &types.Instance{}) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("delete failed", func() { + defer gomonkey.ApplyFunc(workermanager.ScaleDownInstance, + func(instanceID, functionKey, traceID string) error { + return errors.New("delete failed") + }).Reset() + err := deleteInstanceForFG(&types.FunctionSpecification{}, faasManagerInfo{}, &types.Instance{}) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestDeleteInstanceByIDForFG(t *testing.T) { + convey.Convey("Test deleteInstanceByIDForFG", t, func() { + convey.Convey("delete by id success", func() { + defer gomonkey.ApplyFunc(workermanager.ScaleDownInstance, + func(instanceID, functionKey, traceID string) error { + return nil + }).Reset() + err := deleteInstanceByIDForFG("instance", "testFunc") + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("delete by id failed", func() { + defer gomonkey.ApplyFunc(workermanager.ScaleDownInstance, + func(instanceID, functionKey, traceID string) error { + return errors.New("delete failed") + }).Reset() + err := deleteInstanceByIDForFG("instance", "testFunc") + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} diff --git a/yuanrong/pkg/functionscaler/instancepool/instance_operation_kernel.go b/yuanrong/pkg/functionscaler/instancepool/instance_operation_kernel.go new file mode 100644 index 0000000..fe91c83 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/instance_operation_kernel.go @@ -0,0 +1,1691 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package instancepool + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "reflect" + "strconv" + "strings" + "time" + + "go.uber.org/zap" + "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/wait" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + "yuanrong/pkg/common/faas_common/sts/raw" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/urnutils" + commonUtils "yuanrong/pkg/common/faas_common/utils" + wisecloudTypes "yuanrong/pkg/common/faas_common/wisecloudtool/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/signalmanager" + "yuanrong/pkg/functionscaler/tenantquota" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +// 1 + 2 + 4 + 8 + 16 + 32 + 64 + 128 + 256 + 300 + 300... = 3811秒 +const ( + retryDuration = 1 * time.Second // 初始等待时间 + retryFactor = 2 // 倍数因子(每次翻4倍) + retryJitter = 0.1 // 随机抖动系数 + retryTime = 20 + retryCap = 300 * time.Second // 最大等待时间上限 +) + +var ( + globalSdkClient api.LibruntimeAPI + coldStartBackoff = wait.Backoff{ + Duration: retryDuration, + Factor: retryFactor, + Jitter: retryJitter, + Steps: retryTime, + Cap: retryCap, + } +) + +// SetGlobalSdkClient - +func SetGlobalSdkClient(sdkClient api.LibruntimeAPI) { + globalSdkClient = sdkClient +} + +func dealWithVPCError(originErr error) error { + if errCode := statuscode.VpcCode(originErr.Error()); errCode != 0 { + return snerror.New(errCode, statuscode.VpcErMsg(errCode)) + } + return originErr +} + +func createInstanceForKernel(request createInstanceRequest) (instance *types.Instance, createErr error) { + logger := log.GetLogger().With(zap.Any("funcKey", request.funcSpec.FuncKey)) + logger.Infof("start to create instance") + createBeginTime := time.Now() + resSpec := request.resKey.ToResSpec() + createOpt, args, err := generateOptionAndArgsForCreate(request, resSpec) + if err != nil { + createErr = err + return + } + if request.funcSpec.ExtendedMetaData.VpcConfig != nil { + vpcNatConfig, vpcErr := createPATService(request.funcSpec, request.faasManagerInfo, + request.funcSpec.ExtendedMetaData, request.funcSpec.ExtendedMetaData.VpcConfig) + if vpcErr != nil { + createErr = dealWithVPCError(vpcErr) + return + } + setCreateOptionForVPC(createOpt, vpcNatConfig) + defer func() { + if createErr == nil { + go reportInstanceWithVPC(request.funcSpec, request.faasManagerInfo, instance, vpcNatConfig) + } + }() + } + defer func() { + if createErr != nil && config.GlobalConfig.TenantInsNumLimitEnable { + tenantID := urnutils.GetTenantFromFuncKey(instance.FuncKey) + tenantquota.DecreaseTenantInstance(tenantID, instance.InstanceType == types.InstanceTypeReserved) + } + }() + + var instanceID string + schedulingOptions := prepareSchedulingOptions(request.funcSpec, resSpec) + funcMeta := api.FunctionMeta{FuncID: getExecutorFuncKey(request.funcSpec), + Api: commonUtils.GetAPIType(request.funcSpec.FuncMetaData.BusinessType)} + invokeOpts := createInvokeOptions(request.funcSpec, schedulingOptions, createOpt, request.poolLabel) + logger.Debugf("invoke opts cpu is %v, mem is %v\n", invokeOpts.Cpu, invokeOpts.Memory) + delete(invokeOpts.CustomResources, resourcesCPU) + delete(invokeOpts.CustomResources, resourcesMemory) + instanceID, createErr = globalSdkClient.CreateInstance(funcMeta, args, invokeOpts) + if createErr != nil { + logger.Errorf("failed to create instance, error info is %#v", createErr) + createErr = generateSnErrorFromKernelError(createErr) + return + } + instance = buildInstance(instanceID, request) + logger.Infof("succeed to create instance %s, cost: %s, instance: %+v", instanceID, time.Since(createBeginTime), + instance) + return +} + +func deleteInstanceForKernel(funcSpec *types.FunctionSpecification, faasManagerInfo faasManagerInfo, + instance *types.Instance) error { + log.GetLogger().Debugf("start to delete instance %s for function %s", instance.InstanceID, funcSpec.FuncKey) + // maybe we should wait for delete response + var err error + err = globalSdkClient.Kill(instance.InstanceID, killSignalVal, []byte{}) + if funcSpec.ExtendedMetaData.VpcConfig != nil { + go deleteInstanceWithVPC(funcSpec, faasManagerInfo, instance) + } + if err != nil { + log.GetLogger().Errorf("failed to delete instance %s for function %s first, start retry", + instance.InstanceID, funcSpec.FuncKey) + go deleteInstanceForKernelRetry(instance.InstanceID) + return err + } + if config.GlobalConfig.TenantInsNumLimitEnable { + tenantID := urnutils.GetTenantFromFuncKey(instance.FuncKey) + tenantquota.DecreaseTenantInstance(tenantID, instance.InstanceType == types.InstanceTypeReserved) + } + log.GetLogger().Infof("succeed to delete instance %s for function %s", instance.InstanceID, funcSpec.FuncKey) + return nil +} + +func deleteInstanceForKernelRetry(instanceID string) { + var err error + backoffErr := wait.ExponentialBackoff( + coldStartBackoff, func() (bool, error) { + err = globalSdkClient.Kill(instanceID, killSignalVal, []byte{}) + if err != nil { + log.GetLogger().Errorf("failed to delete instance %s, retry", instanceID) + return false, nil + } + return true, nil + }) + if backoffErr != nil { + log.GetLogger().Errorf("[MAJOR] failed to delete instance %s, err is %s", backoffErr.Error()) + } + if err != nil { + log.GetLogger().Errorf("[MAJOR] failed to delete instance %s, err is %s", err.Error()) + } +} + +func deleteInstanceByIDForKernel(instanceID, funcKey string) error { + log.GetLogger().Debugf("start to delete instance %s for function %s", + instanceID, funcKey) + if err := globalSdkClient.Kill(instanceID, killSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to delete instance %s, err: %v, retry", instanceID, err) + go deleteInstanceForKernelRetry(instanceID) + } + return nil +} + +func getExecutorFuncKey(funcSpec *types.FunctionSpecification) string { + if funcSpec.FuncMetaData.BusinessType == constant.BusinessTypeServe { + return serveExecutor + } + switch { + case strings.Contains(funcSpec.FuncMetaData.Runtime, "python3.6"): + return fmt.Sprintf(executorFormat, "Python3.6") + case strings.Contains(funcSpec.FuncMetaData.Runtime, "python3.7"): + return fmt.Sprintf(executorFormat, "Python3.7") + case strings.Contains(funcSpec.FuncMetaData.Runtime, "python3.8"): + return fmt.Sprintf(executorFormat, "Python3.8") + case strings.Contains(funcSpec.FuncMetaData.Runtime, "python3.9"): + return fmt.Sprintf(executorFormat, "Python3.9") + case strings.Contains(funcSpec.FuncMetaData.Runtime, "python3.10"): + return fmt.Sprintf(executorFormat, "Python3.10") + case strings.Contains(funcSpec.FuncMetaData.Runtime, "python3.11"): + return fmt.Sprintf(executorFormat, "Python3.11") + case strings.Contains(funcSpec.FuncMetaData.Runtime, "go"), + strings.Contains(funcSpec.FuncMetaData.Runtime, "http"), + strings.Contains(funcSpec.FuncMetaData.Runtime, "custom image"): + return fmt.Sprintf(executorFormat, "Go1.x") + case strings.Contains(funcSpec.FuncMetaData.Runtime, "java8"): + return fmt.Sprintf(executorFormat, "Java8") + case strings.Contains(funcSpec.FuncMetaData.Runtime, "java11"): + return fmt.Sprintf(executorFormat, "Java11") + case strings.Contains(funcSpec.FuncMetaData.Runtime, "java17"): + return fmt.Sprintf(executorFormat, "Java17") + case strings.Contains(funcSpec.FuncMetaData.Runtime, "java21"): + return fmt.Sprintf(executorFormat, "Java21") + case strings.Contains(funcSpec.FuncMetaData.Runtime, constant.PosixCustomRuntimeType): + return fmt.Sprintf(executorFormat, "PosixCustom") + default: + return funcSpec.FuncKey + } +} + +func generateOptionAndArgsForCreate(request createInstanceRequest, resSpec *resspeckey.ResourceSpecification) ( + map[string]string, []api.Arg, error) { + createOpt, err := prepareCreateOptions(request, resSpec) + if err != nil || createOpt == nil { + return nil, nil, err + } + err = errors.New("failed to create instance, failed to generate option and argument") + args := prepareCreateArguments(request) + if args == nil { + return nil, nil, err + } + return createOpt, args, nil +} + +func prepareSchedulingOptions(funcSpec *types.FunctionSpecification, + resSpec *resspeckey.ResourceSpecification) *types.SchedulingOptions { + schedulingOptions := &types.SchedulingOptions{} + schedulingOptions.Resources = generateResources(resSpec) + if config.GlobalConfig.DeployMode == constant.DeployModeProcesses { + schedulingOptions.Extension = commonUtils.CreateCustomExtensions(schedulingOptions.Extension, + commonUtils.SharedPolicyValue) + } else { + schedulingOptions.Extension = commonUtils.CreateCustomExtensions(schedulingOptions.Extension, + commonUtils.MonopolyPolicyValue) + } + if len(funcSpec.InstanceMetaData.PoolID) != 0 { + schedulingOptions.Extension[constant.AffinityPoolIDKey] = funcSpec.InstanceMetaData.PoolID + } + utils.AddNodeSelector(config.GlobalConfig.NodeSelector, schedulingOptions, resSpec) + if strings.Contains(funcSpec.FuncMetaData.Runtime, types.CustomContainerRuntimeType) { + setEphemeralStorage(defaultEphemeralStorage, config.GlobalConfig.EphemeralStorage, schedulingOptions.Resources) + if npu, _ := utils.GetNpuTypeAndInstanceTypeFromStr(funcSpec.ResourceMetaData.CustomResources, + funcSpec.ResourceMetaData.CustomResourcesSpec); npu != "" && + config.GlobalConfig.NpuEphemeralStorage != 0 { + setEphemeralStorage(defaultEphemeralStorage, config.GlobalConfig.NpuEphemeralStorage, + schedulingOptions.Resources) + } + utils.AddAffinityCPU(strings.TrimSuffix(funcSpec.FuncSecretName, "-sts"), + schedulingOptions, resSpec, api.PreferredAntiAffinity) + } + if config.GlobalConfig.Scenario == types.ScenarioWiseCloud { + labelValue := funcSpec.FuncKey + if resSpec.InvokeLabel != DefaultInstanceLabel { + labelValue = fmt.Sprintf("%s/%s", funcSpec.FuncKey, resSpec.InvokeLabel) + } + agentInfoAffinity := api.Affinity{ + Kind: api.AffinityKindResource, + Affinity: api.PreferredAffinity, + PreferredPriority: true, + PreferredAntiOtherLabels: true, + LabelOps: []api.LabelOperator{ + { + Type: api.LabelOpIn, + LabelKey: "poolname", + LabelValues: []string{labelValue}, + }, { + Type: api.LabelOpIn, + LabelKey: "revisionId", + LabelValues: []string{funcSpec.FuncMetaData.RevisionID}, + }, + }, + } + schedulingOptions.Affinity = append(schedulingOptions.Affinity, agentInfoAffinity) + } + log.GetLogger().Infof("generate scheduling options %v "+ + "for function %s", schedulingOptions, funcSpec.FuncKey) + return schedulingOptions +} + +func createInvokeOptions(funcSpec *types.FunctionSpecification, schedulingOptions *types.SchedulingOptions, + createOpt map[string]string, poolLabel string) api.InvokeOptions { + codeEntrys := []string{funcSpec.ExtendedMetaData.Initializer.Handler, funcSpec.FuncMetaData.Handler} + if funcSpec.ExtendedMetaData.PreStop.Handler != "" { + codeEntrys = append(codeEntrys, funcSpec.ExtendedMetaData.PreStop.Handler) + } + invokeOpts := api.InvokeOptions{ + Cpu: int(schedulingOptions.Resources[resourcesCPU]), + Memory: int(schedulingOptions.Resources[resourcesMemory]), + CustomResources: schedulingOptions.Resources, + CustomExtensions: schedulingOptions.Extension, + CreateOpt: createOpt, + Priority: int(schedulingOptions.Priority), + ScheduleAffinities: generateScheduleAffinity(schedulingOptions.Affinity, poolLabel), + Labels: []string{strings.TrimSuffix(funcSpec.FuncSecretName, "-sts"), "faas"}, + CodePaths: codeEntrys, + Timeout: int(utils.GetCreateTimeout(funcSpec).Seconds()), + ScheduleTimeoutMs: constant.KernelScheduleTimeout * time.Second.Milliseconds(), + } + return invokeOpts +} + +func generateScheduleAffinity(scheduleAffinity []api.Affinity, label string) []api.Affinity { + if label == "" { + return scheduleAffinity + } + labels := strings.Split(label, ",") + for _, poolLabel := range labels { + if strings.TrimSpace(poolLabel) == constant.UnUseAntiOtherLabelsKey { + continue + } + affinity := api.Affinity{ + Kind: api.AffinityKindResource, + Affinity: api.PreferredAffinity, + PreferredPriority: true, + PreferredAntiOtherLabels: !strings.Contains(label, constant.UnUseAntiOtherLabelsKey), + LabelOps: []api.LabelOperator{ + { + Type: api.LabelOpExists, + LabelKey: strings.TrimSpace(poolLabel), + LabelValues: nil, + }, + }, + } + scheduleAffinity = append(scheduleAffinity, affinity) + } + return scheduleAffinity +} + +func createPATService(funcSpec *types.FunctionSpecification, faasManagerInfo faasManagerInfo, + extMetaData commonTypes.ExtendedMetaData, vpcConfig *commonTypes.VpcConfig) (*types.NATConfigure, error) { + createErr := errors.New("failed to create pat service") + faasManagerFuncKey := faasManagerInfo.funcKey + faasManagerInstanceID := faasManagerInfo.instanceID + if len(faasManagerFuncKey) == 0 || len(faasManagerInstanceID) == 0 { + log.GetLogger().Errorf("no faas manager instance info") + return nil, createErr + } + createCh := make(chan error, 1) + args := prepareCreatePATServiceArguments(extMetaData, vpcConfig) + var responseData []byte + traceID := utils.GenerateTraceID() + var invokeErr error + funcMeta := api.FunctionMeta{FuncID: faasManagerFuncKey, Api: api.FaaSApi} + invokeOpts := api.InvokeOptions{TraceID: traceID} + objID, errorInfo := globalSdkClient.InvokeByInstanceId(funcMeta, faasManagerInstanceID, args, invokeOpts) + invokeErr = errorInfo + if invokeErr == nil { + globalSdkClient.GetAsync(objID, func(result []byte, err error) { + createCh <- err + }) + } + if invokeErr != nil { + log.GetLogger().Errorf("failed to send create request of PATService %s, traceID: %s, function: %s, error: %s", + vpcConfig.VpcID, traceID, funcSpec.FuncKey, invokeErr.Error()) + return nil, createErr + } + timer := time.NewTimer(faasManagerRequestTimeout) + select { + case resultErr, ok := <-createCh: + if !ok { + log.GetLogger().Errorf("result channel of PATService request is closed, traceID: %s", traceID) + return nil, createErr + } + timer.Stop() + if resultErr != nil { + log.GetLogger().Errorf("failed to create PATService %s for function %s, traceID: %s, error: %s", + vpcConfig.VpcID, funcSpec.FuncKey, traceID, resultErr.Error()) + return nil, resultErr + } + case <-timer.C: + log.GetLogger().Errorf("time out waiting for PATService creation of function %s, traceID: %s", + funcSpec.FuncKey, traceID) + return nil, createErr + } + response := &patSvcCreateResponse{} + err := json.Unmarshal(responseData, response) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal PATService create response, traceID: %s, error: %s", + traceID, err.Error()) + return nil, createErr + } + natConfig := &types.NATConfigure{} + err = json.Unmarshal([]byte(response.Message), natConfig) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal network config from response, traceID: %s, error: %s", + traceID, err.Error()) + return nil, createErr + } + return natConfig, nil +} + +func reportInstanceWithVPC(funcSpec *types.FunctionSpecification, faasManagerInfo faasManagerInfo, + instance *types.Instance, natConfig *types.NATConfigure) { + faasManagerFuncKey := faasManagerInfo.funcKey + faasManagerInstanceID := faasManagerInfo.instanceID + if len(faasManagerFuncKey) == 0 || len(faasManagerInstanceID) == 0 { + log.GetLogger().Errorf("no faas manager instance info") + return + } + args := prepareReportInstanceArguments(instance, natConfig) + reportCh := make(chan error, 1) + traceID := utils.GenerateTraceID() + var reportErr error + funcMeta := api.FunctionMeta{FuncID: faasManagerFuncKey, Api: api.PosixApi} + invokeOpts := api.InvokeOptions{TraceID: traceID} + objID, errorInfo := globalSdkClient.InvokeByInstanceId(funcMeta, faasManagerInstanceID, args, invokeOpts) + reportErr = errorInfo + if reportErr == nil { + globalSdkClient.GetAsync(objID, func(result []byte, err error) { + reportCh <- err + }) + } + if reportErr != nil { + log.GetLogger().Errorf("failed to send create request of PATService %s, traceID: %s, function: %s, error: %s", + natConfig.PatPodName, traceID, funcSpec.FuncKey, reportErr.Error()) + return + } + timer := time.NewTimer(faasManagerRequestTimeout) + select { + case resultErr, ok := <-reportCh: + if !ok { + log.GetLogger().Errorf("result channel of PATService request is closed, traceID: %s", traceID) + return + } + timer.Stop() + if resultErr != nil { + log.GetLogger().Errorf("failed to create PATService %s for function %s, traceID: %s, error: %s", + natConfig.PatPodName, funcSpec.FuncKey, traceID, resultErr.Error()) + return + } + case <-timer.C: + log.GetLogger().Errorf("time out waiting for PATService creation of function %s, traceID: %s", + funcSpec.FuncKey, traceID) + return + } +} + +func deleteInstanceWithVPC(funcSpec *types.FunctionSpecification, + faasManagerInfo faasManagerInfo, instance *types.Instance) { + faasManagerFuncKey := faasManagerInfo.funcKey + faasManagerInstanceID := faasManagerInfo.instanceID + if len(faasManagerFuncKey) == 0 || len(faasManagerInstanceID) == 0 { + log.GetLogger().Errorf("no faas manager instance info") + return + } + args := prepareDeleteInstanceArguments(instance) + deleteCh := make(chan error, 1) + traceID := utils.GenerateTraceID() + var deleteErr error + funcMeta := api.FunctionMeta{FuncID: faasManagerFuncKey, Api: api.PosixApi} + invokeOpts := api.InvokeOptions{TraceID: traceID} + objID, errorInfo := globalSdkClient.InvokeByInstanceId(funcMeta, faasManagerInstanceID, args, invokeOpts) + deleteErr = errorInfo + if deleteErr == nil { + globalSdkClient.GetAsync(objID, func(result []byte, err error) { + deleteCh <- err + }) + } + if deleteErr != nil { + log.GetLogger().Errorf("failed to send create request of PATService, traceID: %s, function: %s, error: %s", + traceID, funcSpec.FuncKey, deleteErr.Error()) + return + } + timer := time.NewTimer(faasManagerRequestTimeout) + select { + case resultErr, ok := <-deleteCh: + if !ok { + log.GetLogger().Errorf("result channel of PATService request is closed, traceID: %s", traceID) + return + } + timer.Stop() + if resultErr != nil { + log.GetLogger().Errorf("failed to create PATService for function %s, traceID: %s, error %s", + funcSpec.FuncKey, traceID, resultErr.Error()) + return + } + case <-timer.C: + log.GetLogger().Errorf("time out waiting for PATService creation of function %s, traceID: %s", + funcSpec.FuncKey, traceID) + return + } +} + +func generateResources(resSpec *resspeckey.ResourceSpecification) map[string]float64 { + resourcesMap := make(map[string]float64) + if resSpec == nil { + return resourcesMap + } + resourcesMap[resourcesCPU] = float64(resSpec.CPU) + resourcesMap[resourcesMemory] = float64(resSpec.Memory) + if resSpec.CustomResources != nil { + for key, value := range resSpec.CustomResources { + resourcesMap[key] = float64(value) + } + } + return resourcesMap +} + +func setCreateOptionForDelegateBootstrap(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if createOpt == nil { + return errors.New("createOpt is nil") + } + if funcSpec.FuncMetaData.Runtime == constant.PosixCustomRuntimeType { + createOpt[constant.DelegateBootstrapKey] = funcSpec.FuncMetaData.Handler + } + return nil +} + +func setCreateOptionForInvokeTimeout(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if createOpt == nil { + return errors.New("createOpt is nil") + } + if funcSpec.FuncMetaData.Timeout > 0 { + createOpt[constant.FaasInvokeTimeout] = strconv.FormatInt(funcSpec.FuncMetaData.Timeout, base) + } + return nil +} + +// CreateOption contains params for runtime not for user code +func prepareCreateOptions(request createInstanceRequest, resSpec *resspeckey.ResourceSpecification) (map[string]string, + error) { + createOpt := make(map[string]string, constant.DefaultMapSize) + setCreateOptionForFuncSpec(request.funcSpec, createOpt) + setFunctions := []func(*types.FunctionSpecification, map[string]string) error{setCreateOptionForDownloadData, + setCreateOptionForDelegateMount, setCreateOptionForUserAgencyAndEnv, setCreateOptionForDelegateContainer, + setCreateOptionForFileBeat, setCreateOptionForHostAliases, setCreateOptionForRASP, setCustomPodSeccompProfile, + setFunctionAgentInitContainer, setCreateOptionForInitContainerEnv, setCreateOptionForLifeCycleDetached, + setCreateOptionForDelegateBootstrap, setCreateOptionForPostStartExec, setCreateOptionForInvokeTimeout} + for _, f := range setFunctions { + if err := f(request.funcSpec, createOpt); err != nil { + return nil, err + } + } + if err := setCreateOptionForNuwaRuntimeInfo(request.nuwaRuntimeInfo, createOpt); err != nil { + return nil, err + } + if err := setCreateOptionForName(request.instanceName, request.callerPodName, createOpt); err != nil { + return nil, err + } + if err := setCreateOptionForLabel(request.instanceType, request.funcSpec, resSpec, createOpt); err != nil { + return nil, err + } + if err := setCreateOptionForNote(request.instanceType, request.funcSpec, resSpec, createOpt); err != nil { + return nil, err + } + if err := setCreateOptionForAscendNPU(request.funcSpec, resSpec, createOpt); err != nil { + return nil, err + } + return createOpt, nil +} + +func setCreateOptionForPostStartExec(funcSpec *types.FunctionSpecification, + createOpt map[string]string) error { + if createOpt == nil { + return errors.New("createOpt is nil") + } + if funcSpec.FuncMetaData.BusinessType == constant.BusinessTypeServe && + len(funcSpec.ExtendedMetaData.ServeDeploySchema.Applications) != 0 && + len(funcSpec.ExtendedMetaData.ServeDeploySchema.Applications[constant.ApplicationIndex]. + RuntimeEnv.Pip) != 0 { + installCommand := fmt.Sprintf("%s %s && %s", constant.PipInstallPrefix, + strings.Join(funcSpec.ExtendedMetaData.ServeDeploySchema. + Applications[constant.ApplicationIndex].RuntimeEnv.Pip, " "), constant.PipCheckSuffix) + createOpt[constant.PostStartExec] = installCommand + } + return nil +} + +func setCreateOptionForDownloadData(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if createOpt == nil { + return errors.New("createOpt is nil") + } + if funcSpec.FuncMetaData.BusinessType == constant.BusinessTypeServe && + len(funcSpec.ExtendedMetaData.ServeDeploySchema.Applications) != 0 { + codeMetaData := commonTypes.CodeMetaData{ + LocalMetaData: commonTypes.LocalMetaData{ + StorageType: constant.WorkingDirType, + CodePath: funcSpec.ExtendedMetaData.ServeDeploySchema. + Applications[constant.ApplicationIndex].RuntimeEnv.WorkingDir, + }, + } + delegateDownloadData, err := json.Marshal(codeMetaData) + if err != nil { + log.GetLogger().Errorf("failed to marshal delegate download data error %s", err.Error()) + return err + } + createOpt[constant.DelegateDownloadKey] = string(delegateDownloadData) + return nil + } + if !reflect.DeepEqual(funcSpec.S3MetaData, commonTypes.S3MetaData{}) { + funcSpec.CodeMetaData = s3MetaDataConvert2CodeMetaData(funcSpec.S3MetaData) + } + if funcSpec.CodeMetaData.Sha512 == "" && funcSpec.FuncMetaData.CodeSha512 != "" { + funcSpec.CodeMetaData.Sha512 = funcSpec.FuncMetaData.CodeSha512 + } + if funcSpec.FuncMetaData.Runtime != constant.CustomContainerRuntimeType { + delegateDownloadData, err := json.Marshal(funcSpec.CodeMetaData) + if err != nil { + log.GetLogger().Errorf("failed to marshal delegate download data error %s", err.Error()) + return err + } + createOpt[constant.DelegateDownloadKey] = string(delegateDownloadData) + } + if len(funcSpec.FuncMetaData.Layers) != 0 { + delegateLayerDownloadData, err := json.Marshal(funcSpec.FuncMetaData.Layers) + if err != nil { + log.GetLogger().Errorf("failed to marshal delegate layer download data error %s", err.Error()) + return err + } + createOpt[constant.DelegateLayerDownloadKey] = string(delegateLayerDownloadData) + log.GetLogger().Infof("generate delegate download config %s for "+ + "function %s", string(delegateLayerDownloadData), funcSpec.FuncKey) + } + return nil +} + +func s3MetaDataConvert2CodeMetaData(s3MetaData commonTypes.S3MetaData) commonTypes.CodeMetaData { + return commonTypes.CodeMetaData{ + Sha512: "", + LocalMetaData: commonTypes.LocalMetaData{StorageType: "s3"}, + S3MetaData: s3MetaData, + } +} + +func setCreateOptionForLifeCycleDetached(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if createOpt == nil { + return nil + } + createOpt[constant.InstanceLifeCycle] = constant.InstanceLifeCycleDetached + return nil +} + +func setCreateOptionForDelegateMount(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if createOpt == nil { + return errors.New("createOpt is nil") + } + if funcSpec.ExtendedMetaData.FuncMountConfig != nil && + len(funcSpec.ExtendedMetaData.FuncMountConfig.FuncMounts) != 0 { + bytesData, err := json.Marshal(funcSpec.ExtendedMetaData.FuncMountConfig) + if err != nil { + log.GetLogger().Errorf("failed to marshal func mount config error: %s", err.Error()) + return err + } + createOpt[constant.DelegateMountKey] = string(bytesData) + log.GetLogger().Infof("generate delegate mount config %s for function %s", string(bytesData), funcSpec.FuncKey) + } + return nil +} + +func setCreateOptionForUserAgencyAndEnv(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if createOpt == nil { + return errors.New("createOpt is nil") + } + if funcSpec.FuncMetaData.BusinessType == constant.BusinessTypeServe && + len(funcSpec.ExtendedMetaData.ServeDeploySchema.Applications) != 0 && + len(funcSpec.ExtendedMetaData.ServeDeploySchema.Applications[constant.ApplicationIndex]. + RuntimeEnv.EnvVars) != 0 { + envVar, err := json.Marshal(funcSpec.ExtendedMetaData.ServeDeploySchema. + Applications[constant.ApplicationIndex].RuntimeEnv.EnvVars) + if err != nil { + log.GetLogger().Errorf("failed to marshal env var for %s", funcSpec.FuncKey) + return err + } + createOpt[constant.DelegateEnvVar] = string(envVar) + return nil + } + userAgency := funcSpec.ExtendedMetaData.UserAgency + encryptMap := map[string]string{"secretKey": userAgency.SecretKey, + "accessKey": userAgency.AccessKey, "authToken": userAgency.Token, + "securityAk": userAgency.SecurityAk, "securitySk": userAgency.SecuritySk, + "securityToken": userAgency.SecurityToken, + "environment": funcSpec.EnvMetaData.Environment, + "encrypted_user_data": funcSpec.EnvMetaData.EncryptedUserData, + "envKey": funcSpec.EnvMetaData.EnvKey, + "cryptoAlgorithm": funcSpec.EnvMetaData.CryptoAlgorithm, + } + encryptData, err := json.Marshal(encryptMap) + if err != nil { + log.GetLogger().Errorf("encryptData json marshal failed, err:%s", err.Error()) + return err + } + createOpt[constant.DelegateEncryptKey] = string(encryptData) + log.GetLogger().Infof("generate delegate encrypt config %s for function %s", string(encryptData), funcSpec.FuncKey) + return nil +} + +func setCreateOptionForDelegateContainer(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if createOpt == nil { + return errors.New("createOpt is nil") + } + if reflect.DeepEqual(funcSpec.ExtendedMetaData.CustomContainerConfig, commonTypes.CustomContainerConfig{}) { + return nil + } + delegateContainerConfig := types.DelegateContainerConfig{ + Image: funcSpec.ExtendedMetaData.CustomContainerConfig.Image, + Command: funcSpec.ExtendedMetaData.CustomContainerConfig.Command, + Args: funcSpec.ExtendedMetaData.CustomContainerConfig.Args, + UID: funcSpec.ExtendedMetaData.CustomContainerConfig.UID, + GID: funcSpec.ExtendedMetaData.CustomContainerConfig.GID, + CustomGracefulShutdown: funcSpec.ExtendedMetaData.CustomGracefulShutdown, + Env: initCustomContainerEnv(funcSpec), + Lifecycle: v1.Lifecycle{PreStop: &v1.LifecycleHandler{ + Exec: &v1.ExecAction{Command: []string{"/bin/sh", "-c", + "while [ $(netstat -plnut | grep tcp | grep 21005 | wc -l | xargs) -ne 0 ];" + + "do echo 'worker still alive, sleep 1';sleep 1; done; echo 'worker exit' && exit 0"}}, + }}, + } + vb := newVolumeBuilder() + ve := newEnvBuilder() + setVolumeForDelegateContainer(funcSpec, vb) + setEnvForDelegateContainer(funcSpec, ve) + delegateContainerConfig.VolumeMounts = vb.mounts[containerDelegate] + delegateContainerConfig.Env = append(delegateContainerConfig.Env, ve.envs[containerDelegate]...) + delegateRuntimeManager, err := json.Marshal(map[string]interface{}{ + "env": ve.envs[containerDelegate], + }) + if err != nil { + log.GetLogger().Errorf("failed to marshal runtime manager config, error %s", err.Error()) + return err + } + createOpt[constant.DelegateRuntimeManagerTag] = string(delegateRuntimeManager) + configData, err := json.Marshal(delegateContainerConfig) + if err != nil { + log.GetLogger().Errorf("failed to marshal delegate container config, error %s", err.Error()) + return err + } + createOpt[constant.DelegateContainerKey] = string(configData) + if funcSpec.ExtendedMetaData.CustomGracefulShutdown.MaxShutdownTimeout > 0 { + createOpt[types.GracefulShutdownTime] = + strconv.Itoa(funcSpec.ExtendedMetaData.CustomGracefulShutdown.MaxShutdownTimeout) + } else { + createOpt[types.GracefulShutdownTime] = strconv.Itoa(types.MaxShutdownTimeout) + } + + log.GetLogger().Infof("generate delegate container config %s for function %s", string(configData), funcSpec.FuncKey) + setVolumeForAgentAndRuntime(funcSpec, vb) + volumeMountData, err := json.Marshal(vb.mounts[containerRuntimeManager]) + if err != nil { + log.GetLogger().Errorf("failed to marshal runtime manager volume mount, error:%s", err.Error()) + return err + } + createOpt[constant.DelegateVolumeMountKey] = string(volumeMountData) + agentVolumeMountData, err := json.Marshal(vb.mounts[containerFunctionAgent]) + if err != nil { + log.GetLogger().Errorf("failed to marshal function agent volume mount, error:%s", err.Error()) + return err + } + createOpt[constant.DelegateAgentVolumeMountKey] = string(agentVolumeMountData) + volumesData, err := json.Marshal(vb.volumes) + if err != nil { + log.GetLogger().Errorf("failed to marshal runtime manager volume, error:%s", err.Error()) + return err + } + createOpt[constant.DelegateVolumesKey] = string(volumesData) + log.GetLogger().Infof("generate delegate volume: %s volume mount: %s", string(volumesData), + string(volumeMountData)) + return nil +} + +func setVolumeForAgentAndRuntime(funcSpec *types.FunctionSpecification, vb *volumeBuilder) { + (&cgroupMemory{}).configVolume(vb) + (&dockerSocket{}).configVolume(vb) + (&dockerRootDir{}).configVolume(vb) + if config.GlobalConfig.Scenario == types.ScenarioWiseCloud { + if funcSpec.StsMetaData.EnableSts { + log.GetLogger().Infof("start to build sts delegate for function %s", funcSpec.FuncKey) + podRequest := types.PodRequest{ + FunSvcID: funcSpec.FuncSecretName, + NameSpace: constant.DefaultNameSpace, + } + (&stsSecret{ + param: funcSpec, + req: podRequest, + }).configVolume(vb) + } + (&faasAgentSts{ + crName: urnutils.CrNameByURN(funcSpec.FuncMetaData.FunctionVersionURN), + }).configVolume(vb) + } +} + +func setCreateOptionForHostAliases(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if createOpt == nil || len(config.GlobalConfig.HostAliases) == 0 { + return nil + } + hostAliases := generateCustomHostAliases(config.GlobalConfig.HostAliases) + hostAliasesData, err := json.Marshal(hostAliases) + if err != nil { + log.GetLogger().Errorf("failed to marshal host aliases data, error:%s", err.Error()) + return err + } + createOpt[constant.DelegateHostAliases] = string(hostAliasesData) + log.GetLogger().Infof("generate delegate host aliases: %+v "+ + "for function: %s", string(hostAliasesData), funcSpec.FuncKey) + return nil +} + +func setCreateOptionForRASP(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if config.GlobalConfig.Scenario != types.ScenarioWiseCloud { + return nil + } + if createOpt == nil { + return errors.New("createOpt is nil") + } + add, createErr := initContainerAdd(funcSpec) + if createErr != nil { + log.GetLogger().Errorf(fmt.Sprintf("create init container error, %s", createErr.Error())) + } + createOpt[constant.DelegateInitContainers] = string(add) + return nil +} + +func setCreateOptionForFileBeat(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if config.GlobalConfig.Scenario != types.ScenarioWiseCloud { + return nil + } + if createOpt == nil { + return errors.New("createOpt is nil") + } + add, createErr := sideCarAdd(funcSpec) + if createErr != nil { + log.GetLogger().Errorf(fmt.Sprintf("create sideCar error, %s", createErr.Error())) + return createErr + } + createOpt[constant.DelegateContainerSideCars] = string(add) + return nil +} + +func setCustomPodSeccompProfile(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if config.GlobalConfig.Scenario != types.ScenarioWiseCloud { + return nil + } + if createOpt == nil { + return errors.New("createOpt is nil") + } + if funcSpec.FuncMetaData.Runtime != types.CustomContainerRuntimeType { + return nil + } + seccompProfile := &v1.SeccompProfile{Type: v1.SeccompProfileTypeRuntimeDefault} + seccompProfileData, err := json.Marshal(seccompProfile) + if err != nil { + log.GetLogger().Errorf("failed to marshal seccompProfile data, error:%s", err.Error()) + return err + } + createOpt[constant.DelegatePodSeccompProfile] = string(seccompProfileData) + log.GetLogger().Infof("generate delegate seccompProfile: %+v "+ + "for function: %s", string(seccompProfileData), funcSpec.FuncKey) + return nil +} + +func setFunctionAgentInitContainer(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + if config.GlobalConfig.Scenario != types.ScenarioWiseCloud { + return nil + } + if createOpt == nil { + return errors.New("createOpt is nil") + } + if funcSpec.FuncMetaData.Runtime != types.CustomContainerRuntimeType { + return nil + } + vb := newVolumeBuilder() + vb.addVolumeMount(containerFunctionAgentInit, v1.VolumeMount{ + Name: biLogsVolume, + MountPath: biLogVolumeMountPath, + SubPathExpr: biLogVolumeMountSubPathExpr, + }) + volumeMountData, err := json.Marshal(vb.mounts[containerFunctionAgentInit]) + if err != nil { + log.GetLogger().Errorf("failed to marshal function agent init volumeMount data, error:%s", err.Error()) + return err + } + createOpt[constant.DelegateInitVolumeMounts] = string(volumeMountData) + log.GetLogger().Infof("generate function agent init volumeMount: %+v "+ + "for function: %s", string(volumeMountData), funcSpec.FuncKey) + return nil +} + +func setCreateOptionForInitContainerEnv(funcSpec *types.FunctionSpecification, createOpt map[string]string) error { + logger := log.GetLogger().With(zap.Any("function", funcSpec.FuncKey)) + npuType, _ := utils.GetNpuTypeAndInstanceTypeFromStr(funcSpec.ResourceMetaData.CustomResources, + funcSpec.ResourceMetaData.CustomResourcesSpec) + if npuType != types.AscendResourceD910B { + return nil + } + initContainerEnvs := []v1.EnvVar{ + { + Name: "GENERATE_RANKTABLE_FILE", + Value: "true", + }, + } + initContainerEnvsRaw, err := json.Marshal(initContainerEnvs) + if err != nil { + logger.Errorf("failed to marshal function agent init envs data, error:%s", err.Error()) + return err + } + createOpt[constant.DelegateInitEnv] = string(initContainerEnvsRaw) + logger.Infof("generate function agent init envs: %+v ", string(initContainerEnvsRaw)) + return nil +} + +func setCreateOptionForLabel(instanceType types.InstanceType, funcSpec *types.FunctionSpecification, + resSpec *resspeckey.ResourceSpecification, createOpt map[string]string) error { + if createOpt == nil { + return errors.New("createOpt is nil") + } + labels, err := getPodLabel(funcSpec, resSpec, instanceType) + if err != nil { + log.GetLogger().Errorf("get pod labels failed, err:%s", err.Error()) + return err + } + log.GetLogger().Infof("pod labels is: %s", string(labels)) + if resSpec != nil && resSpec.InvokeLabel != "" { + createOpt[types.InstanceLabelNode] = resSpec.InvokeLabel + } + createOpt[constant.DelegatePodLabels] = string(labels) + podInitLabels := map[string]string{ + podLabelSecurityGroup: strings.Split(funcSpec.FuncKey, "/")[0], + } + initLabels, err := json.Marshal(podInitLabels) + if err != nil { + log.GetLogger().Errorf("pod init labels json marshal failed, err:%s", err.Error()) + return err + } + createOpt[constant.DelegatePodInitLabels] = string(initLabels) + return nil +} + +func getPodLabel(funcSpec *types.FunctionSpecification, resSpec *resspeckey.ResourceSpecification, + instanceType types.InstanceType) ([]byte, error) { + version := funcSpec.FuncMetaData.Version + // $ is an illegal character in k8s label + if strings.HasPrefix(version, "$") { + version = strings.TrimPrefix(version, "$") + } + podLabels := map[string]string{ + podLabelInstanceType: string(instanceType), + podLabelFuncName: funcSpec.FuncMetaData.FuncName, + podLabelIsPoolPod: "false", + podLabelServiceID: funcSpec.FuncMetaData.Service, + podLabelTenantID: strings.Split(funcSpec.FuncKey, "/")[0], + podLabelVersion: version, + } + if resSpec != nil { + resSpecString := fmt.Sprintf("%d-%d-fusion", resSpec.CPU, resSpec.Memory) + podLabels[podLabelStandard] = resSpecString + } + if resSpec != nil && resSpec.InvokeLabel != "" { + podLabels[types.HeaderInstanceLabel] = resSpec.InvokeLabel + } + labels, err := json.Marshal(podLabels) + if err != nil { + return nil, err + } + return labels, nil +} + +func setCreateOptionForNote(instanceType types.InstanceType, funcSpec *types.FunctionSpecification, + resSpec *resspeckey.ResourceSpecification, createOpt map[string]string) error { + if createOpt == nil { + return nil + } + resSpecData, err := json.Marshal(resSpec) + if err != nil { + log.GetLogger().Errorf("failed to marshal resourceSpecification error %s", err.Error()) + return err + } + createOpt[types.ResourceSpecNote] = string(resSpecData) + createOpt[types.FunctionKeyNote] = funcSpec.FuncKey + if funcSpec.FuncMetaData.BusinessType == constant.BusinessTypeServe { + createOpt[constant.BusinessTypeTypeNote] = constant.BusinessTypeServe + } + if selfregister.GlobalSchedulerProxy.CheckFuncOwner(funcSpec.FuncKey) { + createOpt[types.SchedulerIDNote] = selfregister.SelfInstanceID + types.TemporaryInstance + } else { + createOpt[types.SchedulerIDNote] = selfregister.SelfInstanceID + types.PermanentInstance + } + createOpt[types.InstanceTypeNote] = string(instanceType) + createOpt[types.TenantID] = strings.Split(funcSpec.FuncKey, "/")[0] + return nil +} + +func buildDelegateNodeAffinity(xpuNodeLabel types.XpuNodeLabel) *v1.NodeAffinity { + return &v1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{NodeSelectorTerms: []v1.NodeSelectorTerm{ + v1.NodeSelectorTerm{ + MatchExpressions: []v1.NodeSelectorRequirement{ + v1.NodeSelectorRequirement{ + Key: xpuNodeLabel.NodeLabelKey, + Operator: "In", + Values: xpuNodeLabel.NodeLabelValues, + }}}}}, + } +} + +func setCreateOptionForAscendNPU(funcSpec *types.FunctionSpecification, resSpec *resspeckey.ResourceSpecification, + createOpt map[string]string) error { + if createOpt == nil || resSpec == nil { + return nil + } + npuCores := 0 + logger := log.GetLogger().With(zap.Any("funcKey", funcSpec.FuncKey)) + + customResourcesSpecStr := funcSpec.ResourceMetaData.CustomResourcesSpec + customResourceStr := funcSpec.ResourceMetaData.CustomResources + npuType, npuInstanceType := utils.GetNpuTypeAndInstanceTypeFromStr(customResourceStr, customResourcesSpecStr) + if npuType == "" { + return nil + } + for _, v := range config.GlobalConfig.XpuNodeLabels { + if v.XpuType == npuType && v.InstanceType == npuInstanceType { + nodeAffinity := buildDelegateNodeAffinity(v) + jsonBytes, err := json.Marshal(nodeAffinity) + if err != nil { + return err + } + createOpt[constant.DelegateNodeAffinity] = string(jsonBytes) + createOpt[constant.DelegateNodeAffinityPolicy] = constant.DelegateNodeAffinityPolicyAggregation + break + } + } + + for k, v := range resSpec.CustomResources { + if strings.Contains(k, types.AscendResourcePrefix) { + npuType = k + npuCores = int(v) + } + } + if npuCores == 0 { + return nil + } + npuMask := uint(0) + for i := 0; i < npuCores; i++ { + npuMask += 1 << i + } + data, err := json.Marshal(map[string]string{cceAscendAnnotation: fmt.Sprintf("%s=0x%x", npuType, npuMask)}) + if err != nil { + logger.Errorf("failed to marshal ascend pod annotation, error %s", err.Error()) + return err + } + createOpt[constant.DelegatePodAnnotations] = string(data) + return nil +} + +func setCreateOptionForName(instanceName, callerPodName string, createOpt map[string]string) error { + if createOpt == nil { + return nil + } + if len(instanceName) != 0 { + createOpt[types.InstanceNameNote] = instanceName + } + if len(callerPodName) != 0 { + labelsData, exist := createOpt[constant.DelegatePodLabels] + data := &map[string]string{} + if exist { + err := json.Unmarshal([]byte(labelsData), data) + if err != nil { + log.GetLogger().Errorf("Unmarshal instance labels error") + return err + } + } + (*data)["callerPodName"] = callerPodName + labels, err := json.Marshal(data) + if err != nil { + log.GetLogger().Errorf("pod labels json marshal failed, err:%s", err.Error()) + return err + } + if createOpt != nil { + createOpt[constant.DelegatePodLabels] = string(labels) + } + } + return nil +} + +func prepareCreateArguments(request createInstanceRequest) []api.Arg { + funcSpecCopy := &types.FunctionSpecification{} + commonUtils.DeepCopyObj(request.funcSpec, funcSpecCopy) + funcSpecCopy.ResourceMetaData.CPU = request.resKey.CPU + funcSpecCopy.ResourceMetaData.Memory = request.resKey.Memory + funcSpecData, err := json.Marshal(funcSpecCopy) + if err != nil { + log.GetLogger().Errorf("failed to marshal create params error %s", err.Error()) + return nil + } + createParamsData, err := prepareCreateParamsData(funcSpecCopy, request.resKey) + if err != nil { + log.GetLogger().Errorf("failed to prepare creatParamsData error %s", err.Error()) + return nil + } + schedulerData, err := signalmanager.PrepareSchedulerArg() + if err != nil { + log.GetLogger().Errorf("failed to prepare scheduler params error %s", err.Error()) + return nil + } + args := []api.Arg{ + { + Type: api.Value, + Data: funcSpecData, + }, + { + Type: api.Value, + Data: createParamsData, + }, + { + Type: api.Value, + Data: schedulerData, + }, + { + Type: api.Value, + Data: request.createEvent, + }, + } + if request.funcSpec.FuncMetaData.Runtime == types.CustomContainerRuntimeType { + customUserData, err := prepareCustomUserArg(funcSpecCopy) + if err != nil { + log.GetLogger().Errorf("failed to prepare custom user data error %s", err.Error()) + return nil + } + args = append(args, api.Arg{ + Type: api.Value, + Data: customUserData, + }) + } + return args +} + +func prepareCreateParamsData(funcSpec *types.FunctionSpecification, resKey resspeckey.ResSpecKey) ([]byte, error) { + var createParamsData []byte + var err error + if strings.Contains(funcSpec.FuncMetaData.Runtime, types.HTTPRuntimeType) || + strings.Contains(funcSpec.FuncMetaData.Runtime, types.CustomContainerRuntimeType) { + createParams := CreateParams{ + InstanceLabel: resKey.InvokeLabel, + HTTPCreateParams: HTTPCreateParams{ + Port: types.HTTPFuncPort, + CallRoute: types.HTTPCallRoute, + }, + } + createParamsData, err = json.Marshal(createParams) + } else { + userInitEntry := funcSpec.ExtendedMetaData.Initializer.Handler + userCallEntry := funcSpec.FuncMetaData.Handler + if len(userCallEntry) == 0 { + log.GetLogger().Warnf("user call entry for function %s is empty", funcSpec.FuncKey) + } + createParams := CreateParams{ + InstanceLabel: resKey.InvokeLabel, + EventCreateParams: EventCreateParams{ + UserInitEntry: userInitEntry, + UserCallEntry: userCallEntry, + }, + } + createParamsData, err = json.Marshal(createParams) + } + if err != nil { + log.GetLogger().Errorf("failed to marshal create params error %s", err.Error()) + return nil, err + } + return createParamsData, nil +} + +func prepareCustomUserArg(funcSpec *types.FunctionSpecification) ([]byte, error) { + faasExecutorStsServerConfig := getStsServerConfig(funcSpec) + localAuth := localauth.AuthConfig{ + AKey: config.GlobalConfig.LocalAuth.AKey, + SKey: config.GlobalConfig.LocalAuth.SKey, + Duration: config.GlobalConfig.LocalAuth.Duration, + } + customUserArgInfo := &types.CustomUserArgs{ + AlarmConfig: config.GlobalConfig.AlarmConfig, + ClusterName: config.GlobalConfig.ClusterName, + StsServerConfig: faasExecutorStsServerConfig, + DiskMonitorEnable: config.GlobalConfig.DiskMonitorEnable, + LocalAuth: localAuth, + } + customUserArgData, err := json.Marshal(customUserArgInfo) + if err != nil { + return nil, err + } + return customUserArgData, nil +} + +func getStsServerConfig(funcSpec *types.FunctionSpecification) raw.ServerConfig { + if !funcSpec.StsMetaData.EnableSts { + return raw.ServerConfig{} + } + domain := config.GlobalConfig.RawStsConfig.ServerConfig.Domain + if config.GlobalConfig.RawStsConfig.StsDomainForRuntime != "" { + domain = config.GlobalConfig.RawStsConfig.StsDomainForRuntime + } + faasExecutorStsServerConfig := raw.ServerConfig{ + Domain: domain, + Path: fmt.Sprintf(faasExecutorStsCertPath, funcSpec.StsMetaData.ServiceName, + funcSpec.StsMetaData.MicroService, funcSpec.StsMetaData.MicroService), + } + return faasExecutorStsServerConfig +} + +func hasD910b(resSpec *resspeckey.ResourceSpecification) bool { + if resSpec == nil { + return false + } + d910bstr, _ := utils.GetNpuTypeAndInstanceType(resSpec.CustomResources, resSpec.CustomResourcesSpec) + if d910bstr == types.AscendResourceD910B { + return true + } + return false +} + +func setEphemeralStorage(defaultES, configES float64, resourcesMap map[string]float64) { + if resourcesMap == nil { + resourcesMap = make(map[string]float64) + } + if configES == 0 { + resourcesMap[resourcesEphemeralStorage] = defaultES + return + } + resourcesMap[resourcesEphemeralStorage] = configES +} + +func setCreateOptionForFuncSpec(funcSpec *types.FunctionSpecification, createOpt map[string]string) { + if createOpt == nil { + return + } + createOpt[types.ConcurrentNumKey] = strconv.Itoa(funcSpec.InstanceMetaData.ConcurrentNum) + createOpt[constant.DelegateDirectoryInfo] = defaultDelegateDirectoryInfo + if funcSpec.InstanceMetaData.DiskLimit == 0 { + createOpt[constant.DelegateDirectoryQuota] = strconv.Itoa(defaultDirectoryLimit) + } else { + createOpt[constant.DelegateDirectoryQuota] = strconv.FormatInt(funcSpec.InstanceMetaData.DiskLimit, base) + } + + createOpt[types.InitCallTimeoutKey] = strconv.Itoa(int(funcSpec.ExtendedMetaData.Initializer.Timeout)) + createOpt[types.CallTimeoutKey] = strconv.Itoa(int(funcSpec.FuncMetaData.Timeout)) + createOpt[types.FunctionSign] = funcSpec.FuncMetaSignature +} + +func initCustomContainerEnv(funcSpec *types.FunctionSpecification) []v1.EnvVar { + parsedURN, err := urnutils.GetFunctionInfo(funcSpec.FuncMetaData.FunctionURN) + if err != nil { + log.GetLogger().Errorf("getFunctionInfo error: %s", err.Error()) + } + + envs := []v1.EnvVar{ + {Name: invokeTypeEnvName, Value: invokeTypeEnvValue}, + // Bi environment variable + {Name: biEvnTenantID, Value: parsedURN.TenantID}, + {Name: biEvnFunctionName, Value: getNoPrefixFuncName(funcSpec.FuncMetaData.Name)}, + {Name: biEvnFunctionVersion, Value: funcSpec.FuncMetaData.Version}, + {Name: biEvnRegion, Value: config.GlobalConfig.RegionName}, + {Name: biEvnClusterID, Value: os.Getenv("CLUSTER_ID")}, + {Name: biEvnNodeIP, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{APIVersion: "v1", FieldPath: "status.hostIP"}, + }}, + {Name: biEvnPodName, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{APIVersion: "v1", FieldPath: "metadata.name"}, + }}, + {Name: podIPEnv, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{APIVersion: "v1", FieldPath: "status.podIP"}, + }}, + {Name: hostIPEnv, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{APIVersion: "v1", FieldPath: "status.hostIP"}, + }}, + {Name: podNameEnv, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{APIVersion: "v1", FieldPath: "metadata.name"}, + }}, + {Name: podIDEnv, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{APIVersion: "v1", FieldPath: "metadata.uid"}, + }}, + {Name: podNameEnvNew, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{APIVersion: "v1", FieldPath: "metadata.name"}, + }}, + } + npuType, npuInstanceType := utils.GetNpuTypeAndInstanceTypeFromStr( + funcSpec.ResourceMetaData.CustomResources, funcSpec.ResourceMetaData.CustomResourcesSpec) + if npuType != "" { + envs = append(envs, v1.EnvVar{ + Name: types.SystemNodeInstanceType, Value: npuInstanceType}) + } + return envs +} + +func getNoPrefixFuncName(name string) string { + lastIndex := strings.LastIndex(name, "@") + if lastIndex > 0 { + return name[lastIndex+1:] + } + return name +} + +func setVolumeForDelegateContainer(funcSpec *types.FunctionSpecification, vb *volumeBuilder) { + if strings.Contains(funcSpec.ResourceMetaData.CustomResources, types.AscendResourcePrefix) { + (&ascendConfig{}).configVolume(vb) + if config.GlobalConfig.EnableNPUDriverMount { + (&npuDriver{}).configVolume(vb) + } + } + if len(config.GlobalConfig.FunctionConfig) != 0 { + for _, fc := range config.GlobalConfig.FunctionConfig { + (&functionDefaultConfig{ + configName: fc.ConfigName, + mount: fc.Mount, + }).configVolume(vb) + } + } +} + +func setEnvForDelegateContainer(funcSpec *types.FunctionSpecification, eb *envBuilder) { + if funcSpec.StsMetaData.EnableSts { + configEnv(eb, funcSpec.StsMetaData.SensitiveConfigs) + } + if strings.Contains(funcSpec.ResourceMetaData.CustomResources, types.AscendResourcePrefix) { + eb.addEnvVar(containerDelegate, v1.EnvVar{Name: types.AscendRankTableFileEnvKey, + Value: types.AscendRankTableFileEnvValue}) + } +} + +func generateCustomHostAliases(hostAliases []v1.HostAlias) map[string][]string { + // key: ip, value: hosts + hostAliasMap := make(map[string][]string, constant.DefaultMapSize) + for _, hostAlias := range hostAliases { + if _, exist := hostAliasMap[hostAlias.IP]; !exist { + hostAliasMap[hostAlias.IP] = make([]string, 0, constant.DefaultHostAliasesSliceSize) + } + hostAliasMap[hostAlias.IP] = append(hostAliasMap[hostAlias.IP], hostAlias.Hostnames...) + } + return hostAliasMap +} + +func setCreateOptionForVPC(createOption map[string]string, natConfig *types.NATConfigure) { + if createOption == nil { + createOption = make(map[string]string, utils.DefaultMapSize) + } + networkConfigs := []types.NetworkConfig{ + { + RouteConfig: types.RouteConfig{ + Gateway: natConfig.PatContainerIP, + Cidr: "0.0.0.0/0", + }, + TunnelConfig: types.TunnelConfig{ + TunnelName: "tunl_fgs_vpc", + RemoteIP: natConfig.PatContainerIP, + Mode: "ipip", + }, + FirewallConfig: types.FirewallConfig{ + Chain: "OUTPUT", + Table: "filter", + Operation: "add", + Target: natConfig.PatContainerIP, + Args: "-j ACCEPT", + }, + }, + } + networkConfigData, err := json.Marshal(networkConfigs) + if err != nil { + log.GetLogger().Errorf("failed to marshal network config for pat service %s", natConfig.PatPodName) + return + } + createOption[types.NetworkConfigKey] = string(networkConfigData) + proberConfigs := []types.ProberConfig{ + { + Protocol: "ICMP", + Address: natConfig.PatContainerIP, + Interval: patProberInterval, + Timeout: patProberTimeout, + FailureThreshold: patProberFailureThreshold, + }, + } + proberConfigData, err := json.Marshal(proberConfigs) + if err != nil { + log.GetLogger().Errorf("failed to marshal prober config for pat service %s", natConfig.PatPodName) + return + } + createOption[types.ProberConfigKey] = string(proberConfigData) +} + +func setCreateOptionForNuwaRuntimeInfo(nuwaRuntimeInfo *wisecloudTypes.NuwaRuntimeInfo, + createOpt map[string]string) error { + if config.GlobalConfig.Scenario != types.ScenarioWiseCloud { + return nil + } + if createOpt == nil { + return errors.New("createOpt is nil") + } + if nuwaRuntimeInfo == nil { + return errors.New("nuwa runtimeinfo is nil") + } + nuwaRuntimeInfoData, err := json.Marshal(nuwaRuntimeInfo) + if err != nil { + log.GetLogger().Errorf("failed to marshal nuwa runtime info for %v", nuwaRuntimeInfo) + return err + } + createOpt[constant.DelegateNuwaRuntimeInfo] = string(nuwaRuntimeInfoData) + return nil +} + +func prepareCreatePATServiceArguments(extMetaData commonTypes.ExtendedMetaData, + vpcConfig *commonTypes.VpcConfig) []api.Arg { + patSvcReq := types.PATServiceRequest{ + ID: vpcConfig.ID, + DomainID: vpcConfig.DomainID, + Namespace: vpcConfig.Namespace, + VpcName: vpcConfig.VpcName, + VpcID: vpcConfig.VpcID, + SubnetName: vpcConfig.SubnetName, + SubnetID: vpcConfig.SubnetID, + TenantCidr: vpcConfig.TenantCidr, + HostVMCidr: vpcConfig.HostVMCidr, + Gateway: vpcConfig.Gateway, + Xrole: extMetaData.Role.XRole, + AppXrole: extMetaData.Role.AppXRole, + } + patSvcReqData, err := json.Marshal(patSvcReq) + if err != nil { + log.GetLogger().Errorf("failed to marshal vpc config, error %s", err.Error()) + } + return []api.Arg{ + { + Type: api.Value, + Data: []byte(vpcOpCreatePATService), + }, + { + Type: api.Value, + Data: patSvcReqData, + }, + } +} + +func prepareReportInstanceArguments(instance *types.Instance, natConfig *types.NATConfigure) []api.Arg { + report := vpcInsCreateReport{ + PatPodName: natConfig.PatPodName, + InstanceID: instance.InstanceID, + } + reportData, err := json.Marshal(report) + if err != nil { + log.GetLogger().Errorf("failed to marshal create report for instance %s vpc %s, error %s", instance.InstanceID, + natConfig.PatPodName, err.Error()) + } + return []api.Arg{ + { + Type: api.Value, + Data: []byte(vpcOpReportInstanceID), + }, + { + Type: api.Value, + Data: reportData, + }, + } +} + +func prepareDeleteInstanceArguments(instance *types.Instance) []api.Arg { + report := vpcInsDeleteReport{ + InstanceID: instance.InstanceID, + } + reportData, err := json.Marshal(report) + if err != nil { + log.GetLogger().Errorf("failed to marshal delete report for instance %s vpc %s, error %s", instance.InstanceID, + err.Error()) + } + return []api.Arg{ + { + Type: api.Value, + Data: []byte(vpcOpDeleteInstanceID), + }, + { + Type: api.Value, + Data: reportData, + }, + } +} + +func generateSnErrorFromKernelError(kernelErr error) snerror.SNError { + var kernelErrCode int + var kernelErrMsg string + if snErr, ok := kernelErr.(api.ErrorInfo); ok { + kernelErrCode = snErr.Code + kernelErrMsg = snErr.Error() + } else { + kernelErrCode = statuscode.GetKernelErrorCode(kernelErr.Error()) + kernelErrMsg = statuscode.GetKernelErrorMessage(kernelErr.Error()) + } + // if user error, return original errorCode + if kernelErrCode < snerror.UserErrorMax && kernelErrCode >= snerror.UserErrorMin { + return snerror.New(kernelErrCode, kernelErrMsg) + } + // posix errorCode from functionsystem + if kernelErrCode == constant.KernelUserCodeLoadErrCode { + return snerror.New(statuscode.UserFuncEntryNotFoundErrCode, kernelErrMsg) + } + if kernelErrCode == constant.KernelCreateLimitErrCode { + return snerror.New(statuscode.CreateLimitErrorCode, kernelErrMsg) + } + if kernelErrCode == constant.KernelWriteEtcdCircuitErrCode { + return snerror.New(statuscode.KernelEtcdWriteFailedCode, kernelErrMsg) + } + if kernelErrCode == constant.KernelResourceNotEnoughErrCode { + return snerror.New(statuscode.KernelResourceNotEnoughErrCode, kernelErrMsg) + } + + if kernelErrCode == constant.KernelRequestErrBetweenRuntimeAndBus && + strings.Contains(kernelErrMsg, "reason(timeout)") { + return snerror.New(statuscode.UserFuncInitTimeoutCode, "runtime initialization timed out") + } + + // faas-executor returns two types of error: inner-system-error and user-function-error, + // which contains message in JSON format + snErr := snerror.New(statuscode.InternalErrorCode, kernelErrMsg) + if kernelErrCode == constant.KernelInnerSystemErrCode || + kernelErrCode == constant.KernelUserFunctionExceptionErrCode { + initRsp := &types.ExecutorInitResponse{} + err := json.Unmarshal([]byte(kernelErrMsg), initRsp) + if err != nil { + log.GetLogger().Errorf("json unMarshal initRsp error: %s", err.Error()) + return snErr + } + errCode, err := strconv.Atoi(initRsp.ErrorCode) + if err != nil { + log.GetLogger().Errorf("initRsp errorCode invalid: %s", err.Error()) + return snErr + } + errMsg, err := initRsp.Message.MarshalJSON() + if err != nil { + log.GetLogger().Errorf("initRsp message marshal error: %s", err.Error()) + return snErr + } + return snerror.New(errCode, string(errMsg)) + } + return snErr +} + +func handlePullTriggerCreate(faasMgrInfo faasManagerInfo, funcSpec *types.FunctionSpecification) { + if len(faasMgrInfo.funcKey) == 0 || len(faasMgrInfo.instanceID) == 0 { + log.GetLogger().Errorf("no faas-manager instance info") + return + } + createCh := make(chan error, 1) + log.GetLogger().Infof("start to prepare arguments %s", funcSpec.FuncKey) + args := prepareCreatePullTriggerArguments(funcSpec) + if args == nil { + return + } + traceID := utils.GenerateTraceID() + log.GetLogger().Infof("start to invoke manager %s, instance %s, traceID: %s", faasMgrInfo.funcKey, + faasMgrInfo.instanceID, traceID) + var invokeErr error + funcMeta := api.FunctionMeta{FuncID: faasMgrInfo.funcKey, Api: api.PosixApi} + invokeOpts := api.InvokeOptions{TraceID: traceID} + objID, errorInfo := globalSdkClient.InvokeByInstanceId(funcMeta, faasMgrInfo.instanceID, args, invokeOpts) + invokeErr = errorInfo + if invokeErr == nil { + globalSdkClient.GetAsync(objID, func(result []byte, err error) { + createCh <- err + }) + } + if invokeErr != nil { + log.GetLogger().Errorf("failed to send create request of vpcPullTrigger %s, traceID: %s, error: %s", + funcSpec.FuncKey, traceID, invokeErr.Error()) + return + } + timer := time.NewTimer(vpcPullTriggerRequestTimeout) + select { + case resultErr, ok := <-createCh: + if !ok { + log.GetLogger().Errorf("result channel of create pullTrigger request is closed, traceID: %s", traceID) + } + timer.Stop() + if resultErr != nil { + log.GetLogger().Errorf("failed to create pullTrigger %s, traceID: %s, error %s", + funcSpec.FuncKey, traceID, resultErr.Error()) + return + } + case <-timer.C: + log.GetLogger().Errorf("time out waiting for pullTrigger create of function %s, traceID %s", + funcSpec.FuncKey, traceID) + return + } +} + +func handlePullTriggerDelete(faasMgrInfo faasManagerInfo, funcSpec *types.FunctionSpecification) { + traceID := utils.GenerateTraceID() + log.GetLogger().Infof("handling vpc pull trigger %s delete, traceID: %s", funcSpec.FuncKey, traceID) + if len(faasMgrInfo.funcKey) == 0 || len(faasMgrInfo.instanceID) == 0 { + log.GetLogger().Errorf("no faas-manager instance info, traceID: %s", traceID) + return + } + args := prepareDeletePullTriggerArguments(funcSpec.FuncKey) + deleteCh := make(chan error, 1) + var deleteErr error + funcMeta := api.FunctionMeta{FuncID: faasMgrInfo.funcKey, Api: api.PosixApi} + invokeOpts := api.InvokeOptions{TraceID: traceID} + objID, errorInfo := globalSdkClient.InvokeByInstanceId(funcMeta, faasMgrInfo.instanceID, args, invokeOpts) + deleteErr = errorInfo + if deleteErr == nil { + globalSdkClient.GetAsync(objID, func(result []byte, err error) { + deleteCh <- err + }) + } + if deleteErr != nil { + log.GetLogger().Errorf("failed to send delete request of vpcPullTrigger %s, traceID: %s, error %s", + funcSpec.FuncKey, traceID, deleteErr.Error()) + return + } + timer := time.NewTimer(vpcPullTriggerRequestTimeout) + select { + case resultErr, ok := <-deleteCh: + if !ok { + log.GetLogger().Errorf("result channel of delete pullTrigger request is closed, traceID: %s", traceID) + return + } + timer.Stop() + if resultErr != nil { + log.GetLogger().Errorf("failed to delete pullTrigger %s, traceID: %s, error %s", + funcSpec.FuncKey, traceID, resultErr.Error()) + return + } + case <-timer.C: + log.GetLogger().Errorf("time out waiting for pullTrigger delete of function %s, traceID: %s", + funcSpec.FuncKey, traceID) + return + } +} + +func prepareCreatePullTriggerArguments(funcSpec *types.FunctionSpecification) []api.Arg { + if funcSpec.ExtendedMetaData.VpcConfig == nil { + log.GetLogger().Errorf("failed to get vpc config") + return nil + } + pullTriggerReq := types.PullTriggerRequestInfo{ + PodName: funcSpec.FuncKey, + Image: funcSpec.FuncMetaData.VPCTriggerImage, + DomainID: funcSpec.ExtendedMetaData.VpcConfig.DomainID, + Namespace: funcSpec.ExtendedMetaData.VpcConfig.Namespace, + VpcName: funcSpec.ExtendedMetaData.VpcConfig.VpcName, + VpcID: funcSpec.ExtendedMetaData.VpcConfig.VpcID, + SubnetName: funcSpec.ExtendedMetaData.VpcConfig.SubnetName, + SubnetID: funcSpec.ExtendedMetaData.VpcConfig.SubnetID, + TenantCidr: funcSpec.ExtendedMetaData.VpcConfig.TenantCidr, + HostVMCidr: funcSpec.ExtendedMetaData.VpcConfig.HostVMCidr, + Gateway: funcSpec.ExtendedMetaData.VpcConfig.Gateway, + Xrole: funcSpec.ExtendedMetaData.Role.XRole, + AppXrole: funcSpec.ExtendedMetaData.Role.AppXRole, + } + pullTriggerReqData, err := json.Marshal(pullTriggerReq) + if err != nil { + log.GetLogger().Errorf("failed to marshal pullTrigger config, error %s", err.Error()) + return nil + } + return []api.Arg{ + { + Type: api.Value, + Data: []byte(vpcOpCreatePullTrigger), + }, + { + Type: api.Value, + Data: pullTriggerReqData, + }, + } +} + +func prepareDeletePullTriggerArguments(funcKey string) []api.Arg { + deleteInfo := types.PullTriggerDeleteInfo{ + PodName: funcKey, + } + deleteData, err := json.Marshal(deleteInfo) + if err != nil { + log.GetLogger().Errorf("failed to marshal delete info for funcKey %s, error %s", funcKey, err.Error()) + } + return []api.Arg{ + { + Type: api.Value, + Data: []byte(vpcOpDeletePullTrigger), + }, + { + Type: api.Value, + Data: deleteData, + }, + } +} diff --git a/yuanrong/pkg/functionscaler/instancepool/instance_operation_kernel_test.go b/yuanrong/pkg/functionscaler/instancepool/instance_operation_kernel_test.go new file mode 100644 index 0000000..c36137d --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/instance_operation_kernel_test.go @@ -0,0 +1,1828 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package instancepool + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/agiledragon/gomonkey" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "yuanrong.org/kernel/runtime/libruntime/api" + + commonconstant "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + "yuanrong/pkg/common/faas_common/sts/raw" + commonTypes "yuanrong/pkg/common/faas_common/types" + mockUtils "yuanrong/pkg/common/faas_common/utils" + wisecloudTypes "yuanrong/pkg/common/faas_common/wisecloudtool/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/signalmanager" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +func TestCreateInstanceForKernel(t *testing.T) { + SetGlobalSdkClient(&mockUtils.FakeLibruntimeSdkClient{}) + funcSpec := &types.FunctionSpecification{ + FuncKey: "12345678901234561234567890123456/0-system-faasExecutor/$latest", + FuncMetaData: commonTypes.FuncMetaData{}, + InstanceMetaData: commonTypes.InstanceMetaData{}, + } + convey.Convey("Test CreateInstanceForKernel", t, func() { + convey.Convey("create success", func() { + request := createInstanceRequest{ + funcSpec: funcSpec, + } + instance, err := createInstanceForKernel(request) + convey.So(err, convey.ShouldBeNil) + convey.So(instance, convey.ShouldNotBeNil) + }) + }) +} + +func TestDealWithVPCError(t *testing.T) { + convey.Convey("TestDealWithVPCError", t, func() { + convey.Convey("VPCXRoleNotFound", func() { + err := errors.New(statuscode.ErrVPCXRoleNotFound.Error()) + err = dealWithVPCError(err) + var snErr snerror.SNError + ok := errors.As(err, &snErr) + convey.So(ok, convey.ShouldBeTrue) + convey.So(snErr.Error(), convey.ShouldEqual, fmt.Sprintf("VPC can't find xrole")) + }) + convey.Convey("ErrNoOperationalPermissionsVpc", func() { + err := errors.New(statuscode.ErrNoOperationalPermissionsVpc.Error()) + err = dealWithVPCError(err) + var snErr snerror.SNError + ok := errors.As(err, &snErr) + convey.So(ok, convey.ShouldBeTrue) + convey.So(snErr.Error(), convey.ShouldEqual, statuscode.ErrNoOperationalPermissionsVpc.Error()) + }) + convey.Convey("NotEnoughNIC", func() { + err := errors.New(statuscode.ErrNoAvailableVpcPatInstance.Error()) + err = dealWithVPCError(err) + var snErr snerror.SNError + ok := errors.As(err, &snErr) + convey.So(ok, convey.ShouldBeTrue) + convey.So(snErr.Error(), convey.ShouldEqual, fmt.Sprintf("not enough network cards")) + }) + convey.Convey("ErrVPCNotFound", func() { + err := errors.New(statuscode.ErrVPCNotFound.Error()) + err = dealWithVPCError(err) + var snErr snerror.SNError + ok := errors.As(err, &snErr) + convey.So(ok, convey.ShouldBeTrue) + convey.So(snErr.Error(), convey.ShouldEqual, fmt.Sprintf("VPC item not found")) + }) + convey.Convey("NotVPCError", func() { + err := errors.New("not a vpc error") + createErr := dealWithVPCError(err) + convey.So(createErr, convey.ShouldEqual, err) + }) + }) +} + +func TestSetCreateOptionForDownloadData(t *testing.T) { + convey.Convey("Test SetCreateOptionForDownloadData", t, func() { + convey.Convey("create Opt is nil", func() { + err := setCreateOptionForDownloadData(nil, nil) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("SetCreateOptionForDownloadData success", func() { + funcSpec := &types.FunctionSpecification{ + S3MetaData: commonTypes.S3MetaData{}, + FuncMetaData: commonTypes.FuncMetaData{ + Layers: []*commonTypes.Layer{{}}, + }, + } + createOpt := map[string]string{} + err := setCreateOptionForDownloadData(funcSpec, createOpt) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey(" marshal delegate download data error", func() { + patch := ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return []byte{}, errors.New("marshal error") + }) + defer patch.Reset() + funcSpec := &types.FunctionSpecification{ + S3MetaData: commonTypes.S3MetaData{}, + } + createOpt := map[string]string{} + err := setCreateOptionForDownloadData(funcSpec, createOpt) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey(" marshal delegate layer download data error", func() { + patch := ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + _, ok := v.(commonTypes.S3MetaData) + if ok { + return []byte{}, nil + } + return []byte{}, errors.New("marshal error") + }) + defer patch.Reset() + funcSpec := &types.FunctionSpecification{ + S3MetaData: commonTypes.S3MetaData{}, + FuncMetaData: commonTypes.FuncMetaData{ + Layers: []*commonTypes.Layer{{}}, + }, + } + createOpt := map[string]string{} + err := setCreateOptionForDownloadData(funcSpec, createOpt) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func Test_setCreateOptionForLifeCycleDetached(t *testing.T) { + createOpt := make(map[string]string) + err := setCreateOptionForLifeCycleDetached(nil, createOpt) + if err != nil { + t.Errorf("do setCreateOptionForLifeCycleDetached failed, err: %s", err.Error()) + return + } + lifeCycle := createOpt[commonconstant.InstanceLifeCycle] + if lifeCycle != "detached" { + t.Errorf("do setCreateOptionForLifeCycleDetached failed, lifeCyle = %s", lifeCycle) + return + } +} + +func Test_setCreateOptionForNuwaRuntimeInfo(t *testing.T) { + rawGConfig := config.GlobalConfig + config.GlobalConfig = types.Configuration{ + ClusterID: "cluster001", + Scenario: types.ScenarioWiseCloud, + } + defer func() { + config.GlobalConfig = rawGConfig + }() + type args struct { + funcSpec *types.FunctionSpecification + nuwaRuntimeInfo *wisecloudTypes.NuwaRuntimeInfo + createOpt map[string]string + } + tests := []struct { + name string + args args + want map[string]string + }{ + {"case1 map is nil", + args{ + funcSpec: &types.FunctionSpecification{}, + nuwaRuntimeInfo: &wisecloudTypes.NuwaRuntimeInfo{}, + createOpt: nil, + }, + nil, + }, + { + "case2 succeeded to marshal ers workload config", + args{ + funcSpec: &types.FunctionSpecification{}, + nuwaRuntimeInfo: &wisecloudTypes.NuwaRuntimeInfo{ + WisecloudRuntimeId: "runtimeId", + WisecloudSite: "site", + WisecloudTenantId: "tenant", + WisecloudApplicationId: "application", + WisecloudServiceId: "serviceid", + WisecloudEnvironmentId: "environment", + EnvLabel: "label", + }, + createOpt: map[string]string{}, + }, + map[string]string{"DELEGATE_NUWA_RUNTIME_INFO": `{"wisecloudRuntimeId":"runtimeId","wisecloudSite":"site","wisecloudTenantId":"tenant","wisecloudApplicationId":"application","wisecloudServiceId":"serviceid","wisecloudEnvironmentId":"environment","envLabel":"label"}`}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if setCreateOptionForNuwaRuntimeInfo(tt.args.nuwaRuntimeInfo, + tt.args.createOpt); !reflect.DeepEqual(tt.args.createOpt, tt.want) { + t.Errorf("setCreateOptionForDelegateMount() = %v, want %v", tt.args.createOpt, tt.want) + } + }) + } +} + +func Test_setCreateOptionForMountVolume(t *testing.T) { + type args struct { + funcSpec *types.FunctionSpecification + createOpt map[string]string + } + tests := []struct { + name string + args args + want map[string]string + }{ + {"case1 map is nil", args{ + funcSpec: &types.FunctionSpecification{}, + createOpt: nil, + }, nil}, + {"case2 succeeded to marshal func mount config", args{ + funcSpec: &types.FunctionSpecification{ExtendedMetaData: commonTypes.ExtendedMetaData{ + FuncMountConfig: &commonTypes.FuncMountConfig{ + FuncMountUser: commonTypes.FuncMountUser{ + UserID: 1004, + GroupID: 1004, + }, + FuncMounts: []commonTypes.FuncMount{commonTypes.FuncMount{ + MountType: "ecs", + MountResource: "eb4ebf7a-db82-4602-82ce-7e1e57a8ef46", + MountSharePath: "1.1.1.1:/sharerdata", + LocalMountPath: "/home/", + Status: "active", + }}, + }}}, + createOpt: map[string]string{"test": "test"}, + }, map[string]string{"test": "test", + "DELEGATE_MOUNT": `{"mount_user":{"user_id":1004,"user_group_id":1004},"func_mounts":[{"mount_type":"ecs","mount_resource":"eb4ebf7a-db82-4602-82ce-7e1e57a8ef46","mount_share_path":"1.1.1.1:/sharerdata","local_mount_path":"/home/","status":"active"}]}`}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if setCreateOptionForDelegateMount(tt.args.funcSpec, + tt.args.createOpt); !reflect.DeepEqual(tt.args.createOpt, tt.want) { + t.Errorf("setCreateOptionForDelegateMount() = %v, want %v", tt.args.createOpt, tt.want) + } + }) + } +} + +func Test_setCreateOptionForRASP(t *testing.T) { + rawGConfig := config.GlobalConfig + config.GlobalConfig = types.Configuration{ + Scenario: types.ScenarioWiseCloud, + } + defer func() { + config.GlobalConfig = rawGConfig + }() + convey.Convey("test setCreateOptionForRASP", t, func() { + convey.Convey("createOpt is nil", func() { + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspImage: "test-image", + RaspServerIP: "127.0.0.1", + RaspServerPort: "8080", + }, + }, + } + err := setCreateOptionForRASP(funcSpec, nil) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("initContainerAdd error", func() { + defer ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, fmt.Errorf("marshal error") + }).Reset() + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspImage: "test-image", + RaspServerIP: "127.0.0.1", + RaspServerPort: "8080", + }, + }, + } + createOpt := make(map[string]string) + err := setCreateOptionForRASP(funcSpec, createOpt) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestSetCreateOptionForContainerSideCar(t *testing.T) { + config.GlobalConfig.Scenario = types.ScenarioWiseCloud + convey.Convey("Test SetCreateOptionForContainerSideCar", t, func() { + convey.Convey("createOpt is nil", func() { + err := setCreateOptionForFileBeat(nil, nil) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("json Marshal CustomFilebeatConfig error", func() { + patch := ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return []byte{}, errors.New("marshal error") + }) + defer patch.Reset() + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{}, + } + createOpt := map[string]string{} + err := setCreateOptionForFileBeat(funcSpec, createOpt) + convey.So(err, convey.ShouldNotBeNil) + + }) + convey.Convey("json Marshal config error", func() { + patch := ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return []byte{}, errors.New("marshal error") + }) + defer patch.Reset() + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + CustomFilebeatConfig: commonTypes.CustomFilebeatConfig{ + ImageAddress: "image", + }, + }, + } + createOpt := map[string]string{} + err := setCreateOptionForFileBeat(funcSpec, createOpt) + convey.So(err, convey.ShouldNotBeNil) + + }) + convey.Convey("success", func() { + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + CustomFilebeatConfig: commonTypes.CustomFilebeatConfig{ + ImageAddress: "image", + }, + }, + } + createOpt := map[string]string{} + err := setCreateOptionForFileBeat(funcSpec, createOpt) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func Test_setCustomPodSeccompProfile(t *testing.T) { + convey.Convey("test setCustomPodSeccompProfile", t, func() { + rawGConfig := config.GlobalConfig + config.GlobalConfig = types.Configuration{ + Scenario: types.ScenarioWiseCloud, + } + defer func() { + config.GlobalConfig = rawGConfig + }() + convey.Convey("creatOpt is nil", func() { + err := setCustomPodSeccompProfile(nil, nil) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("runtime is not CustomContainer", func() { + createOpt := map[string]string{} + funcSpec := &types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{ + Runtime: "java", + }, + } + err := setCustomPodSeccompProfile(funcSpec, createOpt) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("json Marsha data error", func() { + defer ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, fmt.Errorf("marshal error") + }).Reset() + createOpt := map[string]string{} + funcSpec := &types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{ + Runtime: types.CustomContainerRuntimeType, + }, + } + err := setCustomPodSeccompProfile(funcSpec, createOpt) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("set success", func() { + createOpt := map[string]string{} + funcSpec := &types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{ + Runtime: types.CustomContainerRuntimeType, + }, + } + err := setCustomPodSeccompProfile(funcSpec, createOpt) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func Test_setFunctionAgentInitContainer(t *testing.T) { + convey.Convey("test setFunctionAgentInitContainer", t, func() { + rawGConfig := config.GlobalConfig + config.GlobalConfig = types.Configuration{ + Scenario: types.ScenarioWiseCloud, + } + defer func() { + config.GlobalConfig = rawGConfig + }() + convey.Convey("creatOpt is nil", func() { + err := setFunctionAgentInitContainer(nil, nil) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("runtime is not CustomContainer", func() { + createOpt := map[string]string{} + funcSpec := &types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{ + Runtime: "java", + }, + } + err := setFunctionAgentInitContainer(funcSpec, createOpt) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("json Marsha data error", func() { + defer ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, fmt.Errorf("marshal error") + }).Reset() + createOpt := map[string]string{} + funcSpec := &types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{ + Runtime: types.CustomContainerRuntimeType, + }, + } + err := setFunctionAgentInitContainer(funcSpec, createOpt) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("set success", func() { + createOpt := map[string]string{} + funcSpec := &types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{ + Runtime: types.CustomContainerRuntimeType, + }, + } + err := setFunctionAgentInitContainer(funcSpec, createOpt) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func Test_setCreateOptionForInitContainerEnv(t *testing.T) { + convey.Convey("test setCreateOptionForInitContainerEnv", t, func() { + funcSpec := &types.FunctionSpecification{ + ResourceMetaData: commonTypes.ResourceMetaData{ + CustomResources: "{\"huawei.com/ascend-1980\":8}", + CustomResourcesSpec: "{\"instanceType\":\"376T\"}", + }, + } + convey.Convey("d910b generate ranktable file", func() { + createOpt := map[string]string{} + + err := setCreateOptionForInitContainerEnv(funcSpec, createOpt) + convey.So(err, convey.ShouldBeNil) + convey.So(strings.Contains(createOpt[commonconstant.DelegateInitEnv], "GENERATE_RANKTABLE_FILE"), + convey.ShouldEqual, true) + convey.So(strings.Contains(createOpt[commonconstant.DelegateInitEnv], "true"), convey.ShouldEqual, true) + }) + convey.Convey("no generate ranktable file", func() { + createOpt := map[string]string{} + funcSpec.ResourceMetaData = commonTypes.ResourceMetaData{} + err := setCreateOptionForInitContainerEnv(funcSpec, createOpt) + convey.So(err, convey.ShouldBeNil) + convey.So(strings.Contains(createOpt[commonconstant.DelegateInitEnv], "GENERATE_RANKTABLE_FILE"), + convey.ShouldEqual, false) + convey.So(strings.Contains(createOpt[commonconstant.DelegateInitEnv], "true"), convey.ShouldEqual, false) + }) + convey.Convey("no generate ranktable file 1", func() { + createOpt := map[string]string{} + funcSpec.ResourceMetaData = commonTypes.ResourceMetaData{ + CustomResources: "{\"huawei.com/ascend-1980123\":8}", + CustomResourcesSpec: "{\"instanceType\":\"376T\"}", + } + err := setFunctionAgentInitContainer(funcSpec, createOpt) + convey.So(err, convey.ShouldBeNil) + convey.So(strings.Contains(createOpt[commonconstant.DelegateInitEnv], "GENERATE_RANKTABLE_FILE"), + convey.ShouldEqual, false) + convey.So(strings.Contains(createOpt[commonconstant.DelegateInitEnv], "true"), convey.ShouldEqual, false) + }) + }) +} + +func Test_setCreateOptionForLabel(t *testing.T) { + type args struct { + funcSpec *types.FunctionSpecification + createOpt map[string]string + resSpec *resspeckey.ResourceSpecification + instanceType types.InstanceType + } + tests := []struct { + name string + args args + wantNil bool + }{ + {"case1 map is nil", args{ + funcSpec: &types.FunctionSpecification{}, + createOpt: nil, + }, true}, + {"case2 succeeded to set createOption for label", args{ + funcSpec: &types.FunctionSpecification{FuncMetaData: commonTypes.FuncMetaData{FuncName: "test", + TenantID: "tenantID", Service: "serviceID", Version: "$latest"}}, + createOpt: map[string]string{}, + resSpec: &resspeckey.ResourceSpecification{CPU: 500, Memory: 500}, + instanceType: types.InstanceTypeReserved, + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setCreateOptionForLabel(tt.args.instanceType, tt.args.funcSpec, tt.args.resSpec, tt.args.createOpt) + if !reflect.DeepEqual(tt.args.createOpt == nil, tt.wantNil) { + t.Errorf("setCreateOptionForLabel() = %v, want %v", tt.args.createOpt, tt.wantNil) + } + }) + } +} + +func TestSetCreateOptionForNodeAffinity(t *testing.T) { + getMockFuncSpec := func(customResource string, customResourcesSpec string) *types.FunctionSpecification { + return &types.FunctionSpecification{ + ResourceMetaData: commonTypes.ResourceMetaData{ + CustomResources: customResource, + CustomResourcesSpec: customResourcesSpec, + }, + } + } + + config.GlobalConfig.XpuNodeLabels = []types.XpuNodeLabel{ + types.XpuNodeLabel{ + XpuType: "huawei.com/ascend-1980", + InstanceType: "376T", + NodeLabelKey: "node.kubernetes.io/instance-type", + NodeLabelValues: []string{ + "physical.kat2ne.48xlarge.8.376t.ei.c002.ondemand", + "physical.kat2ne.48xlarge.8.ei.pod101.ondemand", + }, + }, + types.XpuNodeLabel{ + XpuType: "huawei.com/ascend-1980", + InstanceType: "", + NodeLabelKey: "node.kubernetes.io/instance-type", + NodeLabelValues: []string{ + "physical.kat2ne.48xlarge.8.376t.ei.c002.ondemand", + "physical.kat2ne.48xlarge.8.ei.pod101.ondemand", + }, + }, + types.XpuNodeLabel{ + XpuType: "huawei.com/ascend-1980", + InstanceType: "280T", + NodeLabelKey: "node.kubernetes.io/instance-type", + NodeLabelValues: []string{ + "physical.kat2e.48xlarge.8.280t.ei.c005.ondemand", + }, + }, + } + + tests := []struct { + name string + customResource string + customResourcesSpec string + delegateNodeAffinity string + delegateNodeAffinityPolicy string + }{ + { + name: "376T", + customResource: "{\"huawei.com/ascend-1980\":8}", + customResourcesSpec: "{\"instanceType\":\"376T\"}", + delegateNodeAffinity: "{" + + "\"requiredDuringSchedulingIgnoredDuringExecution\": {" + + " \"nodeSelectorTerms\": [{" + + " \"matchExpressions\": [{" + + " \"key\": \"node.kubernetes.io/instance-type\"," + + " \"operator\": \"In\"," + + " \"values\": [\"physical.kat2ne.48xlarge.8.376t.ei.c002.ondemand\"," + + " \"physical.kat2ne.48xlarge.8.ei.pod101.ondemand\"]" + + " }]" + + " }]" + + "}" + + "}", + delegateNodeAffinityPolicy: commonconstant.DelegateNodeAffinityPolicyAggregation, + }, + { + name: "280T", + customResource: "{\"huawei.com/ascend-1980\":8}", + customResourcesSpec: "{\"instanceType\":\"280T\"}", + delegateNodeAffinity: "{" + + "\"requiredDuringSchedulingIgnoredDuringExecution\": {" + + " \"nodeSelectorTerms\": [{" + + " \"matchExpressions\": [{" + + " \"key\": \"node.kubernetes.io/instance-type\"," + + " \"operator\": \"In\"," + + " \"values\": [\"physical.kat2e.48xlarge.8.280t.ei.c005.ondemand\"]" + + " }]" + + " }]" + + "}" + + "}", + delegateNodeAffinityPolicy: commonconstant.DelegateNodeAffinityPolicyAggregation, + }, + { + name: "376T_1", + customResource: "{\"huawei.com/ascend-1980\":8}", + delegateNodeAffinity: "{" + + "\"requiredDuringSchedulingIgnoredDuringExecution\": {" + + " \"nodeSelectorTerms\": [{" + + " \"matchExpressions\": [{" + + " \"key\": \"node.kubernetes.io/instance-type\"," + + " \"operator\": \"In\"," + + " \"values\": [\"physical.kat2ne.48xlarge.8.376t.ei.c002.ondemand\"," + + " \"physical.kat2ne.48xlarge.8.ei.pod101.ondemand\"]" + + " }]" + + " }]" + + "}" + + "}", + delegateNodeAffinityPolicy: commonconstant.DelegateNodeAffinityPolicyAggregation, + }, + { + name: "nil", + customResource: "", + delegateNodeAffinity: "", + }, + { + name: "error", + customResource: "{\"instanceType\":\"28\"}", + delegateNodeAffinity: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := make(map[string]string) + e := setCreateOptionForAscendNPU(getMockFuncSpec(tt.customResource, tt.customResourcesSpec), + &resspeckey.ResourceSpecification{CPU: 500, Memory: 500, + CustomResources: map[string]int64{"huawei.com/ascend-1980": 8}}, m) + if e != nil { + t.Errorf("setCreateOptionForAscendNPU failed, err: %v", e) + } + actualNodeAffinity, ok := m["DELEGATE_NODE_AFFINITY"] + if tt.delegateNodeAffinity == "" && !ok { + return + } + if tt.delegateNodeAffinity == "" || !ok { + t.Errorf("actual nodeAffinity is %s, expect Affinity is %s", actualNodeAffinity, + tt.delegateNodeAffinity) + } + actualJson := make(map[string]interface{}) + expectJson := make(map[string]interface{}) + + e1 := json.Unmarshal([]byte(actualNodeAffinity), &actualJson) + e2 := json.Unmarshal([]byte(tt.delegateNodeAffinity), &expectJson) + if e1 != nil || e2 != nil { + t.Errorf("json.Unmarshal failed, err1: %v, actualNodeAffinity: %s, err2: %s, expectNodeAffinity: %s", + e1, actualNodeAffinity, e2, tt.delegateNodeAffinity) + } + + if !reflect.DeepEqual(actualJson, expectJson) { + t.Errorf("deepEqual failed, actualNodeAffinity: %s, expectNodeAffinity: %s", actualNodeAffinity, + tt.delegateNodeAffinity) + } + + if m[commonconstant.DelegateNodeAffinityPolicy] != tt.delegateNodeAffinityPolicy { + t.Errorf("deepEqual failed, actualNodeAffinityPolicy: %s, expectNodeAffinityPolicy: %s", + m[commonconstant.DelegateNodeAffinityPolicy], tt.delegateNodeAffinityPolicy) + } + fmt.Printf("actual NodeAffinity: %s, expect NodeAffinity: %s", actualNodeAffinity, tt.delegateNodeAffinity) + }) + } + +} + +func Test_PrepareCreateArguments(t *testing.T) { + convey.Convey("test Test_PrepareCreateArguments", t, func() { + convey.Convey("json marshal error", func() { + defer ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, fmt.Errorf("marshal error") + }).Reset() + arg := prepareCreateArguments(createInstanceRequest{ + funcSpec: &types.FunctionSpecification{}, + resKey: resspeckey.ResSpecKey{}, + }) + convey.So(arg, convey.ShouldBeNil) + }) + convey.Convey("prepareCreateParamsData error", func() { + defer ApplyFunc(prepareCreateParamsData, func(funcSpec *types.FunctionSpecification) ([]byte, error) { + return nil, fmt.Errorf("prepareCreateParamsData error") + }).Reset() + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspServerIP: "127.0.0.1", + RaspServerPort: "8080", + }, + }, + } + arg := prepareCreateArguments(createInstanceRequest{ + funcSpec: funcSpec, + resKey: resspeckey.ResSpecKey{}, + }) + convey.So(arg, convey.ShouldBeNil) + }) + convey.Convey("prepareSchedulerArg error", func() { + defer ApplyFunc(prepareCreateParamsData, func(funcSpec *types.FunctionSpecification) ([]byte, error) { + return []byte("createParamsData"), nil + }).Reset() + defer ApplyFunc(signalmanager.PrepareSchedulerArg, func() ([]byte, error) { + return nil, fmt.Errorf("prepareSchedulerArg error") + }).Reset() + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspServerIP: "127.0.0.1", + RaspServerPort: "8080", + }, + }, + } + arg := prepareCreateArguments(createInstanceRequest{ + funcSpec: funcSpec, + resKey: resspeckey.ResSpecKey{}, + }) + convey.So(arg, convey.ShouldBeNil) + }) + convey.Convey("prepareCustomUserArg error", func() { + rawGConfig := config.GlobalConfig + config.GlobalConfig = types.Configuration{ + Scenario: types.ScenarioWiseCloud, + } + defer func() { + config.GlobalConfig = rawGConfig + }() + defer ApplyFunc(prepareCreateParamsData, func(funcSpec *types.FunctionSpecification) ([]byte, error) { + return []byte("createParamsData"), nil + }).Reset() + defer ApplyFunc(signalmanager.PrepareSchedulerArg, func() ([]byte, error) { + return []byte("prepareSchedulerArg"), nil + }).Reset() + defer ApplyFunc(prepareCustomUserArg, func(funcSpec *types.FunctionSpecification) ([]byte, error) { + return nil, fmt.Errorf("prepareSchedulerArg error") + }).Reset() + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspServerIP: "127.0.0.1", + RaspServerPort: "8080", + }, + }, + } + arg := prepareCreateArguments(createInstanceRequest{ + funcSpec: funcSpec, + resKey: resspeckey.ResSpecKey{}, + }) + convey.So(arg, convey.ShouldNotBeNil) + }) + convey.Convey("success", func() { + rawGConfig := config.GlobalConfig + config.GlobalConfig = types.Configuration{ + Scenario: types.ScenarioWiseCloud, + } + defer func() { + config.GlobalConfig = rawGConfig + }() + defer ApplyFunc(prepareCreateParamsData, func(funcSpec *types.FunctionSpecification) ([]byte, error) { + return []byte("createParamsData"), nil + }).Reset() + defer ApplyFunc(signalmanager.PrepareSchedulerArg, func() ([]byte, error) { + return []byte("prepareSchedulerArg"), nil + }).Reset() + defer ApplyFunc(prepareCustomUserArg, func(funcSpec *types.FunctionSpecification) ([]byte, error) { + return []byte("prepareCustomUserArg"), nil + }).Reset() + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspServerIP: "127.0.0.1", + RaspServerPort: "8080", + }, + }, + } + arg := prepareCreateArguments(createInstanceRequest{ + funcSpec: funcSpec, + resKey: resspeckey.ResSpecKey{}, + }) + convey.So(len(arg), convey.ShouldEqual, 4) + }) + }) +} + +func Test_PrepareCustomUserArg(t *testing.T) { + rawGConfig := config.GlobalConfig + config.GlobalConfig = types.Configuration{ + RawStsConfig: raw.StsConfig{ + ServerConfig: raw.ServerConfig{}, + }, + LocalAuth: localauth.AuthConfig{}, + } + defer func() { + config.GlobalConfig = rawGConfig + }() + convey.Convey("test PrepareCustomUserArg", t, func() { + convey.Convey("json Marshal error", func() { + defer ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, errors.New("json marshal error") + }).Reset() + _, err := prepareCustomUserArg(&types.FunctionSpecification{}) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("success", func() { + _, err := prepareCustomUserArg(&types.FunctionSpecification{}) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func Test_addInstanceCallerPodName(t *testing.T) { + convey.Convey("test setCreateOptionForName", t, func() { + convey.Convey("Unmarshal labels data error", func() { + createOpt := map[string]string{ + commonconstant.DelegatePodLabels: "", + } + err := setCreateOptionForName("", "callerPodName", createOpt) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("json Marsha data error", func() { + defer ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, fmt.Errorf("marshal error") + }).Reset() + createOpt := map[string]string{ + commonconstant.DelegatePodLabels: "{\"funcName\":\"testcustom1024001\"}", + } + err := setCreateOptionForName("", "callerPodName", createOpt) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("success", func() { + createOpt := map[string]string{ + commonconstant.DelegatePodLabels: "{\"funcName\":\"testcustom1024001\"}", + } + err := setCreateOptionForName("", "callerPodName", createOpt) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func Test_SetEphemeralStorage(t *testing.T) { + convey.Convey("test setEphemeralStorage", t, func() { + convey.Convey("resourcesMap is nil", func() { + setEphemeralStorage(1, 0, nil) + resourcesMap := make(map[string]float64) + setEphemeralStorage(1, 1, resourcesMap) + convey.So(resourcesMap[resourcesEphemeralStorage], convey.ShouldEqual, 1) + }) + }) +} + +func TestGetNpuInstanceType(t *testing.T) { + tests := []struct { + name string + customResource string + customResourcesPec string + npuType string + npuInstanceType string + }{ + { + name: "376T", + customResource: "{\"huawei.com/ascend-1980\":8}", + customResourcesPec: "{\"instanceType\":\"376T\"}", + npuType: "huawei.com/ascend-1980", + npuInstanceType: "376T", + }, + { + name: "280T", + customResource: "{\"huawei.com/ascend-1980\":8}", + customResourcesPec: "{\"instanceType\":\"280T\"}", + npuType: "huawei.com/ascend-1980", + npuInstanceType: "280T", + }, + { + name: "376T_1", + customResource: "{\"huawei.com/ascend-1980\":8}", + customResourcesPec: "", + npuType: "huawei.com/ascend-1980", + npuInstanceType: "376T", + }, + { + name: "nil", + customResource: "", + customResourcesPec: "{\"instanceType\":\"280\"}", + npuType: "", + npuInstanceType: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if npuType, npuInstanceType := utils.GetNpuTypeAndInstanceTypeFromStr(tt.customResource, + tt.customResourcesPec); npuInstanceType != tt.npuInstanceType || npuType != tt.npuType { + t.Errorf("failed, actual npuType: %s npuInstanceType: %s, expect npuType: %s npuInstanceType: %s", + npuType, npuInstanceType, tt.npuType, tt.npuInstanceType) + } + }) + } +} + +func TestGetNpuTypeAndInstanceType(t *testing.T) { + tests := []struct { + name string + customResource map[string]int64 + customResourcesSpec map[string]interface{} + npuType string + npuInstanceType string + }{ + { + name: "376T", + customResource: map[string]int64{"huawei.com/ascend-1980": 8}, + customResourcesSpec: map[string]interface{}{"instanceType": "376T"}, + npuType: "huawei.com/ascend-1980", + npuInstanceType: "376T", + }, + { + name: "280T", + customResource: map[string]int64{"huawei.com/ascend-1980": 8}, + customResourcesSpec: map[string]interface{}{"instanceType": "280T"}, + npuType: "huawei.com/ascend-1980", + npuInstanceType: "280T", + }, + { + name: "376T_1", + customResource: map[string]int64{"huawei.com/ascend-1980": 8}, + npuType: "huawei.com/ascend-1980", + npuInstanceType: "376T", + }, + { + name: "nil", + npuType: "", + npuInstanceType: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if npuType, npuInstanceType := utils.GetNpuTypeAndInstanceType(tt.customResource, + tt.customResourcesSpec); npuType != tt.npuType || npuInstanceType != tt.npuInstanceType { + t.Errorf("failed, actual npuType: %s npuInstanceType: %s, expect npuType: %s npuInstanceType: %s", + npuType, npuInstanceType, tt.npuType, tt.npuInstanceType) + } + }) + } +} + +func TestInitCustomContainerEnvForNpu(t *testing.T) { + getMockFuncSpec := func(customResource string, customResourcesSpec string) *types.FunctionSpecification { + return &types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{ + Name: "0@default@zyrfunction3", + FunctionVersionURN: "sn:cn:yrk:172120022624845016:function:0@default@zyrfunction3:1", + FunctionURN: "sn:cn:yrk:172120022624845016:function:0@default@zyrfunction3", + }, + ResourceMetaData: commonTypes.ResourceMetaData{ + CustomResources: customResource, + CustomResourcesSpec: customResourcesSpec, + }, + } + } + + rawGConfig := config.GlobalConfig + config.GlobalConfig = types.Configuration{ + RegionName: "12324234", + } + defer func() { + config.GlobalConfig = rawGConfig + }() + + tests := []struct { + name string + customResource string + customResourcesSpec string + nodeNpuInstanceType string + }{ + { + name: "376T", + customResource: "{\"huawei.com/ascend-1980\":8}", + customResourcesSpec: "{\"instanceType\":\"376T\"}", + nodeNpuInstanceType: "376T", + }, + { + name: "280T", + customResource: "{\"huawei.com/ascend-1980\":8}", + customResourcesSpec: "{\"instanceType\":\"280T\"}", + nodeNpuInstanceType: "280T", + }, + { + name: "376T_1", + customResource: "{\"huawei.com/ascend-1980\":8}", + customResourcesSpec: "{\"instanceType\":\"376T\"}", + nodeNpuInstanceType: "376T", + }, + { + name: "nil", + customResource: "", + nodeNpuInstanceType: "", + }, + { + name: "error", + customResource: "{\"instanceType\":\"280\"}", + nodeNpuInstanceType: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + envs := initCustomContainerEnv(getMockFuncSpec(tt.customResource, tt.customResourcesSpec)) + actualNodeInstanceType := "" + // actualNodeInstanceTypeFromMeta := "" + for _, env := range envs { + if env.Name == "X_SYSTEM_NODE_INSTANCE_TYPE" { + actualNodeInstanceType = env.Value + } + if env.Name == "X_SYSTEM_NODE_INSTANCE_TYPE_IN_META_CUSTOM_RESOURCE" { + // actualNodeInstanceTypeFromMeta = env.Value + } + } + if actualNodeInstanceType != tt.nodeNpuInstanceType { + t.Errorf("failed, actual nodeInstanceType is %s, expect nodeInstanceType is %s", actualNodeInstanceType, + tt.nodeNpuInstanceType) + } + }) + + } +} + +func TestSetCreateOptionForVPC(t *testing.T) { + convey.Convey("Test TestSetCreateOptionForVPC", t, func() { + convey.Convey(" marshal network config error", func() { + patch := ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return []byte{}, errors.New("marshal error") + }) + defer patch.Reset() + natConfig := &types.NATConfigure{} + setCreateOptionForVPC(nil, natConfig) + }) + convey.Convey(" marshal prober config error", func() { + patch := ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + _, ok := v.([]types.NetworkConfig) + if ok { + return []byte{}, nil + } + return []byte{}, errors.New("marshal error") + }) + defer patch.Reset() + natConfig := &types.NATConfigure{} + setCreateOptionForVPC(nil, natConfig) + }) + }) +} + +func TestGenerateSnErrorFromKernelError(t *testing.T) { + convey.Convey("Test GenerateSnErrorFromKernelError", t, func() { + convey.Convey("json Unmarshal error", func() { + kernelErr := errors.New("code:3003,message: ") + snError := generateSnErrorFromKernelError(kernelErr) + convey.So(snError.Code(), convey.ShouldEqual, statuscode.InternalErrorCode) + + }) + convey.Convey("errorCode error", func() { + initRsp := &types.ExecutorInitResponse{} + data, _ := json.Marshal(initRsp) + kernelErr := errors.New(fmt.Sprintf("code:3003,message: %s", string(data))) + snError := generateSnErrorFromKernelError(kernelErr) + convey.So(snError.Code(), convey.ShouldEqual, statuscode.InternalErrorCode) + + }) + convey.Convey("message MarshalJSON error", func() { + initRsp := &types.ExecutorInitResponse{ + ErrorCode: "1", + Message: json.RawMessage{}, + } + data, _ := json.Marshal(initRsp) + kernelErr := errors.New(fmt.Sprintf("code:3003,message: %s", string(data))) + snError := generateSnErrorFromKernelError(kernelErr) + convey.So(snError.Code(), convey.ShouldEqual, statuscode.InternalErrorCode) + }) + convey.Convey("KernelUserCodeLoadErrCode", func() { + initRsp := &types.ExecutorInitResponse{ + ErrorCode: "2001", + } + data, _ := json.Marshal(initRsp) + kernelErr := errors.New(fmt.Sprintf("code:2001,message: %s", string(data))) + snError := generateSnErrorFromKernelError(kernelErr) + convey.So(snError.Code(), convey.ShouldEqual, 4001) + }) + convey.Convey("success", func() { + initRsp := &types.ExecutorInitResponse{ + ErrorCode: "1", + } + data, _ := json.Marshal(initRsp) + kernelErr := errors.New(fmt.Sprintf("code:3003,message: %s", string(data))) + snError := generateSnErrorFromKernelError(kernelErr) + convey.So(snError.Code(), convey.ShouldEqual, 1) + }) + }) +} + +func Test_setCreateOptionForPodInitLabel(t *testing.T) { + type args struct { + funcSpec *types.FunctionSpecification + resSpec *resspeckey.ResourceSpecification + instanceType types.InstanceType + } + tests := []struct { + name string + args args + podLabels map[string]string + podInitLabels map[string]string + tenantId string + }{ + {"case1 map is nil", + args{ + funcSpec: &types.FunctionSpecification{}, + }, + map[string]string{ + podLabelInstanceType: "", + podLabelFuncName: "", + podLabelIsPoolPod: "false", + podLabelServiceID: "", + podLabelTenantID: "", + podLabelVersion: "", + }, + map[string]string{ + podLabelSecurityGroup: "", + }, + "", + }, + {"case2 succeeded to set createOption for label", + args{ + funcSpec: &types.FunctionSpecification{FuncMetaData: commonTypes.FuncMetaData{FuncName: "test", + TenantID: "tenantID", Service: "serviceID", Version: "$latest"}, + FuncKey: "12345678901234561234567890123456/test/$latest"}, + resSpec: &resspeckey.ResourceSpecification{CPU: 500, Memory: 500}, + instanceType: types.InstanceTypeReserved, + }, + map[string]string{ + podLabelInstanceType: string(types.InstanceTypeReserved), + podLabelFuncName: "test", + podLabelIsPoolPod: "false", + podLabelServiceID: "serviceID", + podLabelTenantID: "12345678901234561234567890123456", + podLabelVersion: "latest", + podLabelStandard: "500-500-fusion", + }, + map[string]string{ + podLabelSecurityGroup: "12345678901234561234567890123456", + }, + "12345678901234561234567890123456", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + createOpt := make(map[string]string) + err := setCreateOptionForLabel(tt.args.instanceType, tt.args.funcSpec, tt.args.resSpec, createOpt) + if err != nil { + t.Errorf("setCreateOptionForLabel failed, err: %s", err.Error()) + return + } + err = setCreateOptionForNote(tt.args.instanceType, tt.args.funcSpec, tt.args.resSpec, createOpt) + if err != nil { + t.Errorf("setCreateOptionForNote failed, err: %s", err.Error()) + return + } + if _, ok := createOpt[commonconstant.DelegatePodLabels]; !ok { + t.Errorf("no deletegate pod labels") + return + } + if _, ok := createOpt[commonconstant.DelegatePodInitLabels]; !ok { + t.Errorf("no deletegate pod init labels") + return + } + podLabels := make(map[string]string) + podInitLabels := make(map[string]string) + err1 := json.Unmarshal([]byte(createOpt[commonconstant.DelegatePodLabels]), &podLabels) + err2 := json.Unmarshal([]byte(createOpt[commonconstant.DelegatePodInitLabels]), &podInitLabels) + if err1 != nil || err2 != nil { + t.Errorf("unmarshal failed") + return + } + if !reflect.DeepEqual(podInitLabels, tt.podInitLabels) { + t.Errorf("setCreateOptionForLabel() = %v, want %v", podInitLabels, tt.podInitLabels) + return + } + if !reflect.DeepEqual(podLabels, tt.podLabels) { + t.Errorf("setCreateOptionForLabel() = %v, want %v", podLabels, tt.podLabels) + return + } + if createOpt[types.TenantID] != tt.tenantId { + t.Errorf("setCreateOptionForLabel failed, tenantId set failed, actual: %s, expect: %s", + createOpt[types.TenantID], tt.tenantId) + } + }) + } +} + +func TestGetExecutorFuncKey(t *testing.T) { + gi := &GenericInstancePool{ + FuncSpec: &types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{}, + }, + } + convey.Convey("python3.6", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "python3.6" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, + "12345678901234561234567890123456/0-system-faasExecutorPython3.6/$latest") + }) + convey.Convey("python3.8", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "python3.8" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, + "12345678901234561234567890123456/0-system-faasExecutorPython3.8/$latest") + }) + convey.Convey("python3.9", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "python3.9" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, + "12345678901234561234567890123456/0-system-faasExecutorPython3.9/$latest") + }) + convey.Convey("python3.10", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "python3.10" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, + "12345678901234561234567890123456/0-system-faasExecutorPython3.10/$latest") + }) + convey.Convey("python3.11", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "python3.11" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, + "12345678901234561234567890123456/0-system-faasExecutorPython3.11/$latest") + }) + convey.Convey("go, http, custom", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "go" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, "12345678901234561234567890123456/0-system-faasExecutorGo1.x/$latest") + gi.FuncSpec.FuncMetaData.Runtime = "http" + funcKey = getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, "12345678901234561234567890123456/0-system-faasExecutorGo1.x/$latest") + gi.FuncSpec.FuncMetaData.Runtime = "custom image" + funcKey = getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, "12345678901234561234567890123456/0-system-faasExecutorGo1.x/$latest") + }) + convey.Convey("java8", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "java8" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, "12345678901234561234567890123456/0-system-faasExecutorJava8/$latest") + }) + convey.Convey("java11", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "java11" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, "12345678901234561234567890123456/0-system-faasExecutorJava11/$latest") + }) + convey.Convey("java17", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "java17" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, "12345678901234561234567890123456/0-system-faasExecutorJava17/$latest") + }) + convey.Convey("java21", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "java21" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, "12345678901234561234567890123456/0-system-faasExecutorJava21/$latest") + }) + convey.Convey("posix-custom-runtime", t, func() { + gi.FuncSpec.FuncMetaData.Runtime = "posix-custom-runtime" + funcKey := getExecutorFuncKey(gi.FuncSpec) + convey.So(funcKey, convey.ShouldEqual, + "12345678901234561234567890123456/0-system-faasExecutorPosixCustom/$latest") + }) +} + +func TestDeleteInstanceForKernel(t *testing.T) { + convey.Convey("Test DeleteInstance", t, func() { + config.GlobalConfig.InstanceOperationBackend = commonconstant.BackendTypeKernel + funcSpec := &types.FunctionSpecification{ + FuncKey: "testFunction", + FuncCtx: context.TODO(), + FuncMetaData: commonTypes.FuncMetaData{ + Handler: "myHandler", + }, + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 500, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 100, + ConcurrentNum: 100, + }, + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Initializer: commonTypes.Initializer{ + Handler: "myInitializer", + }, + }, + } + instance := &types.Instance{ + InstanceID: "testInstance", + } + convey.Convey("delete instance with global client", func() { + SetGlobalSdkClient(&mockUtils.FakeLibruntimeSdkClient{}) + err := deleteInstanceForKernel(funcSpec, faasManagerInfo{}, instance) + convey.So(err, convey.ShouldBeNil) + SetGlobalSdkClient(nil) + }) + }) +} + +func TestCreateInvokeOptions(t *testing.T) { + funcSpec := &types.FunctionSpecification{ + FuncKey: "testVpcFuncKey", + FuncMetaData: commonTypes.FuncMetaData{ + VPCTriggerImage: "vpc trigger image url", + }, + ExtendedMetaData: commonTypes.ExtendedMetaData{}, + } + opt := createInvokeOptions(funcSpec, &types.SchedulingOptions{}, nil, "") + assert.NotNil(t, opt) +} + +func Test_getStsServerConfig(t *testing.T) { + convey.Convey("test getStsServerConfig", t, func() { + convey.Convey("baseline", func() { + funcSpec := &types.FunctionSpecification{ + StsMetaData: commonTypes.StsMetaData{ + EnableSts: true, + ServiceName: "a", + MicroService: "b", + }, + } + config.GlobalConfig.RawStsConfig.ServerConfig.Domain = "12345" + config.GlobalConfig.RawStsConfig.StsDomainForRuntime = "" + serverConfig := getStsServerConfig(funcSpec) + convey.So(serverConfig.Domain, convey.ShouldEqual, "12345") + convey.So(serverConfig.Path, convey.ShouldEqual, "/opt/huawei/certs/a/b/b.ini") + config.GlobalConfig.RawStsConfig.StsDomainForRuntime = "67890" + serverConfig = getStsServerConfig(funcSpec) + convey.So(serverConfig.Domain, convey.ShouldEqual, "67890") + convey.So(serverConfig.Path, convey.ShouldEqual, "/opt/huawei/certs/a/b/b.ini") + }) + }) +} + +func TestCreatePATService_InvokeError(t *testing.T) { + funcSpec := &types.FunctionSpecification{} + + faasManagerInfo := faasManagerInfo{ + funcKey: "faas-manager-func-key", + instanceID: "faas-manager-instance-id", + } + extMetaData := commonTypes.ExtendedMetaData{} + vpcConfig := &commonTypes.VpcConfig{} + + patches := gomonkey.NewPatches() + defer patches.Reset() + sdk := &mockUtils.FakeLibruntimeSdkClient{} + SetGlobalSdkClient(sdk) + patches.ApplyMethod(reflect.TypeOf(sdk), "InvokeByInstanceId", + func(_ *mockUtils.FakeLibruntimeSdkClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpt api.InvokeOptions) (string, error) { + return "", errors.New("invoke error") + }) + natConfig, err := createPATService(funcSpec, faasManagerInfo, extMetaData, vpcConfig) + + assert.NotNil(t, err) + assert.Nil(t, natConfig) + assert.Equal(t, "failed to create pat service", err.Error()) + faasManagerInfo.instanceID = "" + + natConfig, err = createPATService(funcSpec, faasManagerInfo, extMetaData, vpcConfig) + assert.Nil(t, natConfig) + assert.Equal(t, "failed to create pat service", err.Error()) +} + +func TestCreatePATService_Success(t *testing.T) { + funcSpec := &types.FunctionSpecification{ + FuncKey: "test-func-key", + } + faasManagerInfo := faasManagerInfo{ + funcKey: "faas-manager-func-key", + instanceID: "faas-manager-instance-id", + } + extMetaData := commonTypes.ExtendedMetaData{} + vpcConfig := &commonTypes.VpcConfig{ + VpcID: "test-vpc-id", + } + + patches := gomonkey.NewPatches() + defer patches.Reset() + + sdk := &mockUtils.FakeLibruntimeSdkClient{} + SetGlobalSdkClient(sdk) + + patches.ApplyFunc(prepareCreatePATServiceArguments, + func(extMetaData commonTypes.ExtendedMetaData, vpcConfig *commonTypes.VpcConfig) []api.Arg { + return []api.Arg{} + }) + + patches.ApplyFunc(utils.GenerateTraceID, func() string { + return "test-trace-id" + }) + + patches.ApplyMethod(reflect.TypeOf(sdk), "InvokeByInstanceId", + func(_ *mockUtils.FakeLibruntimeSdkClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpts api.InvokeOptions) (string, error) { + return "test-obj-id", nil + }) + + patches.ApplyMethod(reflect.TypeOf(sdk), "GetAsync", + func(_ *mockUtils.FakeLibruntimeSdkClient, objectID string, cb api.GetAsyncCallback) { + response := &patSvcCreateResponse{ + Code: 0, + Message: `{"containerCidr":"10.0.0.0/24"}`, + } + responseData, _ := json.Marshal(response) + cb(responseData, nil) + }) + + _, err := createPATService(funcSpec, faasManagerInfo, extMetaData, vpcConfig) + + assert.Equal(t, err.Error(), "failed to create pat service") +} + +func TestDeleteInstanceWithVPCSuccess(t *testing.T) { + funcSpec := &types.FunctionSpecification{ + FuncKey: "test-func", + } + managerInfo := faasManagerInfo{ + funcKey: "manager-key", + instanceID: "manager-instance", + } + instance := &types.Instance{ + InstanceID: "test-instance", + } + + t.Run("success_case", func(t *testing.T) { + sdk := &mockUtils.FakeLibruntimeSdkClient{} + SetGlobalSdkClient(sdk) + + var getAsyncCalled bool + + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyMethod(reflect.TypeOf(sdk), "InvokeByInstanceId", + func(_ *mockUtils.FakeLibruntimeSdkClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpts api.InvokeOptions) (string, error) { + return "test-obj-id", nil + }) + + patches.ApplyMethod(reflect.TypeOf(sdk), "GetAsync", + func(_ *mockUtils.FakeLibruntimeSdkClient, objectID string, cb api.GetAsyncCallback) { + getAsyncCalled = true + cb([]byte("success"), nil) + }) + + deleteInstanceWithVPC(funcSpec, managerInfo, instance) + + assert.NotNil(t, getAsyncCalled, "GetAsync should be called") + }) +} + +func TestDeleteInstanceWithVPC(t *testing.T) { + funcSpec := &types.FunctionSpecification{ + FuncKey: "test-func", + } + managerInfo := faasManagerInfo{ + funcKey: "manager-key", + instanceID: "manager-instance", + } + instance := &types.Instance{ + InstanceID: "test-instance", + } + + t.Run("get_async_error", func(t *testing.T) { + sdk := &mockUtils.FakeLibruntimeSdkClient{} + SetGlobalSdkClient(sdk) + + getAsyncCalled := false + + defer gomonkey.ApplyMethod(reflect.TypeOf(sdk), "InvokeByInstanceId", + func(_ *mockUtils.FakeLibruntimeSdkClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpt api.InvokeOptions) (string, error) { + return "mock-obj-id", nil + }).Reset() + + defer gomonkey.ApplyMethod(reflect.TypeOf(sdk), "GetAsync", + func(_ *mockUtils.FakeLibruntimeSdkClient, objectID string, cb api.GetAsyncCallback) { + getAsyncCalled = true + cb(nil, errors.New("get async error")) + }).Reset() + + deleteInstanceWithVPC(funcSpec, managerInfo, instance) + assert.NotNil(t, getAsyncCalled, "GetAsync should be called with error") + }) +} + +func TestHandlePullTriggerCreate(t *testing.T) { + funcSpec := &types.FunctionSpecification{ + FuncKey: "test-func", + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Initializer: commonTypes.Initializer{ + Handler: "myInitializer", + }, + VpcConfig: &commonTypes.VpcConfig{}, + }, + } + managerInfo := faasManagerInfo{ + funcKey: "manager-key", + instanceID: "manager-instance", + } + + t.Run("success_case", func(t *testing.T) { + sdk := &mockUtils.FakeLibruntimeSdkClient{} + SetGlobalSdkClient(sdk) + + var getAsyncCalled bool + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyMethod(reflect.TypeOf(sdk), "InvokeByInstanceId", + func(_ *mockUtils.FakeLibruntimeSdkClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpt api.InvokeOptions) (string, error) { + return "mock-obj-id", nil + }) + + patches.ApplyMethod(reflect.TypeOf(sdk), "GetAsync", + func(_ *mockUtils.FakeLibruntimeSdkClient, objectID string, cb api.GetAsyncCallback) { + getAsyncCalled = true + cb([]byte("success"), nil) + }) + + handlePullTriggerCreate(managerInfo, funcSpec) + + assert.NotNil(t, getAsyncCalled, "GetAsync should be called") + }) + + t.Run("invoke_error", func(t *testing.T) { + sdk := &mockUtils.FakeLibruntimeSdkClient{} + SetGlobalSdkClient(sdk) + var getAsyncCalled bool + + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyMethod(reflect.TypeOf(sdk), "InvokeByInstanceId", + func(_ *mockUtils.FakeLibruntimeSdkClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpt api.InvokeOptions) (string, error) { + getAsyncCalled = true + return "", errors.New("invoke error") + }) + + handlePullTriggerCreate(managerInfo, funcSpec) + assert.NotNil(t, getAsyncCalled, "GetAsync should be called") + }) + + t.Run("get_async_error", func(t *testing.T) { + sdk := &mockUtils.FakeLibruntimeSdkClient{} + SetGlobalSdkClient(sdk) + + var getAsyncCalled bool + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyMethod(reflect.TypeOf(sdk), "InvokeByInstanceId", + func(_ *mockUtils.FakeLibruntimeSdkClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpt api.InvokeOptions) (string, error) { + return "mock-obj-id", nil + }) + + patches.ApplyMethod(reflect.TypeOf(sdk), "GetAsync", + func(_ *mockUtils.FakeLibruntimeSdkClient, objectID string, cb api.GetAsyncCallback) { + getAsyncCalled = true + cb(nil, errors.New("get async error")) + }) + + handlePullTriggerCreate(managerInfo, funcSpec) + + assert.NotNil(t, getAsyncCalled, "GetAsync should be called") + }) +} + +func TestReportInstanceWithVPC(t *testing.T) { + tests := []struct { + name string + funcSpec *types.FunctionSpecification + faasManagerInfo faasManagerInfo + instance *types.Instance + natConfig *types.NATConfigure + invokeErr error + getAsyncErr error + expectGetAsync bool + }{ + { + name: "successful_report", + funcSpec: &types.FunctionSpecification{ + FuncKey: "test-func", + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Initializer: commonTypes.Initializer{ + Handler: "myInitializer", + }, + VpcConfig: &commonTypes.VpcConfig{}, + }, + }, + faasManagerInfo: faasManagerInfo{ + funcKey: "faas-key", + instanceID: "instance-1", + }, + instance: &types.Instance{ + InstanceID: "test-instance", + }, + natConfig: &types.NATConfigure{ + PatPodName: "test-pat-pod", + }, + invokeErr: nil, + getAsyncErr: nil, + expectGetAsync: true, + }, + { + name: "empty_faas_manager_info", + funcSpec: &types.FunctionSpecification{ + FuncKey: "test-func", + }, + faasManagerInfo: faasManagerInfo{}, + instance: &types.Instance{}, + natConfig: &types.NATConfigure{}, + expectGetAsync: false, + }, + { + name: "invoke_error", + funcSpec: &types.FunctionSpecification{ + FuncKey: "test-func", + }, + faasManagerInfo: faasManagerInfo{ + funcKey: "faas-key", + instanceID: "instance-1", + }, + instance: &types.Instance{}, + natConfig: &types.NATConfigure{ + PatPodName: "test-pat-pod", + }, + invokeErr: assert.AnError, + expectGetAsync: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sdk := &mockUtils.FakeLibruntimeSdkClient{} + SetGlobalSdkClient(sdk) + + var getAsyncCalled bool + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyMethod(reflect.TypeOf(sdk), "InvokeByInstanceId", + func(_ *mockUtils.FakeLibruntimeSdkClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpt api.InvokeOptions) (string, error) { + if tt.invokeErr != nil { + return "", tt.invokeErr + } + assert.Equal(t, tt.faasManagerInfo.funcKey, funcMeta.FuncID) + assert.Equal(t, tt.faasManagerInfo.instanceID, instanceID) + return "mock-obj-id", nil + }) + + patches.ApplyMethod(reflect.TypeOf(sdk), "GetAsync", + func(_ *mockUtils.FakeLibruntimeSdkClient, objectID string, cb api.GetAsyncCallback) { + getAsyncCalled = true + cb([]byte("success"), tt.getAsyncErr) + }) + + reportInstanceWithVPC(tt.funcSpec, tt.faasManagerInfo, tt.instance, tt.natConfig) + + assert.NotNil(t, tt.expectGetAsync, getAsyncCalled, "GetAsync called status mismatch") + }) + } +} + +func TestHandlePullTriggerDelete(t *testing.T) { + tests := []struct { + name string + funcSpec *types.FunctionSpecification + faasManagerInfo faasManagerInfo + invokeErr error + getAsyncErr error + expectGetAsync bool + }{ + { + name: "successful_delete", + funcSpec: &types.FunctionSpecification{ + FuncKey: "test-func", + }, + faasManagerInfo: faasManagerInfo{ + funcKey: "faas-key", + instanceID: "instance-1", + }, + invokeErr: nil, + getAsyncErr: nil, + expectGetAsync: true, + }, + { + name: "empty_faas_manager_info", + funcSpec: &types.FunctionSpecification{ + FuncKey: "test-func", + }, + faasManagerInfo: faasManagerInfo{}, + expectGetAsync: false, + }, + { + name: "invoke_error", + funcSpec: &types.FunctionSpecification{ + FuncKey: "test-func", + }, + faasManagerInfo: faasManagerInfo{ + funcKey: "faas-key", + instanceID: "instance-1", + }, + invokeErr: assert.AnError, + expectGetAsync: false, + }, + { + name: "get_async_error", + funcSpec: &types.FunctionSpecification{ + FuncKey: "test-func", + }, + faasManagerInfo: faasManagerInfo{ + funcKey: "faas-key", + instanceID: "instance-1", + }, + invokeErr: nil, + getAsyncErr: assert.AnError, + expectGetAsync: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sdk := &mockUtils.FakeLibruntimeSdkClient{} + SetGlobalSdkClient(sdk) + + var getAsyncCalled bool + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(utils.GenerateTraceID, + func() string { + return "mock-trace-id" + }) + + patches.ApplyMethod(reflect.TypeOf(sdk), "InvokeByInstanceId", + func(_ *mockUtils.FakeLibruntimeSdkClient, funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpt api.InvokeOptions) (string, error) { + if tt.invokeErr != nil { + return "", tt.invokeErr + } + assert.Equal(t, tt.faasManagerInfo.funcKey, funcMeta.FuncID) + assert.Equal(t, tt.faasManagerInfo.instanceID, instanceID) + assert.Equal(t, api.PosixApi, funcMeta.Api) + return "mock-obj-id", nil + }) + + patches.ApplyMethod(reflect.TypeOf(sdk), "GetAsync", + func(_ *mockUtils.FakeLibruntimeSdkClient, objectID string, cb api.GetAsyncCallback) { + getAsyncCalled = true + cb([]byte("success"), tt.getAsyncErr) + }) + + handlePullTriggerDelete(tt.faasManagerInfo, tt.funcSpec) + + assert.NotNil(t, tt.expectGetAsync, getAsyncCalled, "GetAsync called status mismatch") + }) + } +} + +func TestHasD910b(t *testing.T) { + var nilResSpec *resspeckey.ResourceSpecification + assert.False(t, hasD910b(nilResSpec), "Expected false when resKey is nil") + + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(utils.GetNpuTypeAndInstanceType, + func(customRes map[string]int64, customResSpec map[string]interface{}) (string, string) { + return types.AscendResourceD910B, "" + }) + + resSpecWithD910B := &resspeckey.ResourceSpecification{ + CustomResources: map[string]int64{"mock-key": 1}, + CustomResourcesSpec: map[string]interface{}{"mock-key": "mock-value"}, + } + assert.True(t, hasD910b(resSpecWithD910B), "Expected true when resKey has D910B resource") + + patches.Reset() + patches.ApplyFunc(utils.GetNpuTypeAndInstanceType, + func(customRes map[string]int64, customResSpec map[string]interface{}) (string, string) { + return "non-D910B-resource", "" + }) + + resSpecWithoutD910B := &resspeckey.ResourceSpecification{ + CustomResources: map[string]int64{"mock-key": 1}, + CustomResourcesSpec: map[string]interface{}{"mock-key": "mock-value"}, + } + assert.False(t, hasD910b(resSpecWithoutD910B), "Expected false when resKey does not have D910B resource") +} diff --git a/yuanrong/pkg/functionscaler/instancepool/instancepool.go b/yuanrong/pkg/functionscaler/instancepool/instancepool.go new file mode 100644 index 0000000..db724db --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/instancepool.go @@ -0,0 +1,1509 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "context" + "fmt" + "reflect" + "sync" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/urnutils" + wisecloudTypes "yuanrong/pkg/common/faas_common/wisecloudtool/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/instancequeue" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/signalmanager" + "yuanrong/pkg/functionscaler/stateinstance" + "yuanrong/pkg/functionscaler/tenantquota" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + defaultSessionReaperInterval = 60 * time.Second + faasManagerRequestTimeout = 30 * time.Second + vpcPullTriggerRequestTimeout = 30 * time.Second + maxInstanceLimit = 1000 + killSignalVal = 1 + vpcOpCreatePATService = "CreatePATService" + vpcOpReportInstanceID = "ReportInstanceID" + vpcOpDeleteInstanceID = "DeleteInstanceID" + vpcOpCreatePullTrigger = "CreatePullTrigger" + vpcOpDeletePullTrigger = "DeletePullTrigger" + patProberInterval = 5 + patProberTimeout = 5 + patProberFailureThreshold = 100 + resourcesCPU = "CPU" + resourcesMemory = "Memory" + defaultEphemeralStorage = 512 + resourcesEphemeralStorage = "ephemeral-storage" + podLabelInstanceType = "instanceType" + podLabelFuncName = "funcName" + podLabelIsPoolPod = "isPoolPod" + podLabelServiceID = "serviceID" + podLabelStandard = "standard" + podLabelTenantID = "tenantID" + podLabelVersion = "version" + podLabelSecurityGroup = "securityGroup" + executorFormat = "12345678901234561234567890123456/0-system-faasExecutor%s/$latest" + serveExecutor = "12345678901234561234567890123456/0-system-serveExecutor/$latest" + resSpecLen = 4 + faasExecutorStsCertPath = "/opt/certs/%s/%s/%s.ini" + defaultDelegateDirectoryInfo = "/tmp" + base = 10 + defaultDirectoryLimit = 512 + annotationFullFuncName = "funcName" +) + +const ( + invokeTypeEnvName = "INVOKE_TYPE" + invokeTypeEnvValue = "faas" + cceAscendAnnotation = "scheduling.cce.io/gpu-topology-placement" +) + +const ( + biEvnTenantID = "x-system-tenantId" + biEvnFunctionName = "x-system-functionName" + biEvnFunctionVersion = "x-system-functionVersion" + biEvnRegion = "x-system-region" + biEvnClusterID = "x-system-clusterID" + biEvnNodeIP = "x-system-NODE_IP" + biEvnPodName = "x-system-podName" + podIPEnv = "POD_IP" + podNameEnvNew = "POD_NAME" + hostIPEnv = "HOST_IP" + podNameEnv = "PodName" + podIDEnv = "POD_ID" +) + +const ( + stateInstanceDelete = "instanceDelete" + stateDelete = "delete" + stateUpdate = "update" +) + +type createOption struct { + callerPodName string +} + +// CreateParams is used to send config to runtime during instance initialization +type CreateParams struct { + InstanceLabel string `json:"instanceLabel,omitempty"` + EventCreateParams `json:",inline"` + HTTPCreateParams `json:",inline"` +} + +// EventCreateParams is used to send config to runtime during instance initialization +type EventCreateParams struct { + UserInitEntry string `json:"userInitEntry,omitempty"` + UserCallEntry string `json:"userCallEntry,omitempty"` + UserStateEntry string `json:"userStateEntry,omitempty"` +} + +// HTTPCreateParams is used to send config of http function during instance initialization +type HTTPCreateParams struct { + Port int `json:"port,omitempty"` + InitRoute string `json:"initRoute,omitempty"` + CallRoute string `json:"callRoute,omitempty"` +} + +type patSvcCreateResponse struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type vpcInsCreateReport struct { + PatPodName string `json:"patPodName"` + InstanceID string `json:"instanceID"` +} + +type vpcInsDeleteReport struct { + InstanceID string `json:"instanceID"` +} + +type createInstanceRequest struct { + createEvent []byte + instanceName string + callerPodName string + poolLabel string + poolID string + funcSpec *types.FunctionSpecification + nuwaRuntimeInfo *wisecloudTypes.NuwaRuntimeInfo + instanceType types.InstanceType + resKey resspeckey.ResSpecKey + instanceBuilder types.InstanceBuilder + faasManagerInfo faasManagerInfo + createTimeout time.Duration +} + +type sessionRecord struct { + instance *types.Instance + sessionCtx context.Context +} + +// InstancePool defines operations of instance pool +type InstancePool interface { + CreateInstance(insCrtReq *types.InstanceCreateRequest) (*types.Instance, snerror.SNError) + DeleteInstance(instance *types.Instance) snerror.SNError + AcquireInstance(insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, snerror.SNError) + ReleaseInstance(instance *types.InstanceAllocation) + HandleFunctionEvent(eventType registry.EventType, funcSpec *types.FunctionSpecification) + HandleAliasEvent(eventType registry.EventType, aliasUrn string) + HandleFaaSSchedulerEvent() + HandleInstanceEvent(eventType registry.EventType, instance *types.Instance) + HandleInstanceConfigEvent(eventType registry.EventType, insConfig *instanceconfig.Configuration) + UpdateInvokeMetrics(resKey resspeckey.ResSpecKey, insMetrics *types.InstanceThreadMetrics) + HandleFaaSManagerUpdate(faasManagerInfo faasManagerInfo) + GetFuncSpec() *types.FunctionSpecification + RecoverInstance(*types.FunctionSpecification, *types.InstancePoolState, bool, *sync.WaitGroup) + GetAndDeleteState(stateID string) bool + DeleteStateInstance(stateID string, instaceID string) + handleManagedChange() + handleRatioChange(ratio int) + CleanOrphansInstanceQueue() +} + +// GenericInstancePool is a generic instance pool to manage instances of a specific function +type GenericInstancePool struct { + FuncSpec *types.FunctionSpecification + defaultResSpec *resspeckey.ResourceSpecification + insConfig map[resspeckey.ResSpecKey]*instanceconfig.Configuration + metricsCollector map[resspeckey.ResSpecKey]metrics.Collector + insAcqReqQueue map[resspeckey.ResSpecKey]*requestqueue.InsAcqReqQueue + onDemandInstanceQueue map[resspeckey.ResSpecKey]*instancequeue.OnDemandInstanceQueue + reservedInstanceQueue map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue + scaledInstanceQueue map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue + sessionRecordMap map[string]sessionRecord + instanceSessionMap map[string]map[string]struct{} + stateInstanceID sync.Map + defaultResKey resspeckey.ResSpecKey + stateRoute StateRoute + faasManagerInfo faasManagerInfo + createTimeout time.Duration + sessionReaperInterval time.Duration + stopCh chan struct{} + waitInsConfigChan chan struct{} + functionSignature string + currentPoolLabel string + defaultPoolLabel string + minScaleUpdatedTime time.Time + pendingInstanceNum map[string]int + minScaleAlarmSign map[string]bool + + synced bool + sync.RWMutex + closeChanOnce sync.Once +} + +// GetAndDeleteState delete state and instance, return whether the state exists +func (gi *GenericInstancePool) GetAndDeleteState(stateID string) bool { + return gi.stateRoute.GetAndDeleteState(stateID) +} + +// DeleteStateInstance called by ReleaseLease +func (gi *GenericInstancePool) DeleteStateInstance(stateID string, instanceID string) { + gi.stateRoute.DeleteStateInstance(stateID, instanceID) +} + +func (gi *GenericInstancePool) recoverStateRouteMap(stateInstanceMap map[string]*types.Instance) { + gi.stateRoute.recover(stateInstanceMap) +} + +// NewGenericInstancePool creates a GenericInstancePool +func NewGenericInstancePool(funcSpec *types.FunctionSpecification, faasManagerInfo faasManagerInfo) (InstancePool, + error) { + log.GetLogger().Infof("create instance pool for function %s", funcSpec.FuncKey) + defaultResSpec := resspeckey.ConvertResourceMetaDataToResSpec(funcSpec.ResourceMetaData) + defaultResKey := resspeckey.ConvertToResSpecKey(defaultResSpec) + pool := &GenericInstancePool{ + FuncSpec: funcSpec, + defaultResSpec: defaultResSpec, + defaultPoolLabel: funcSpec.InstanceMetaData.PoolLabel, + currentPoolLabel: funcSpec.InstanceMetaData.PoolLabel, + insConfig: make(map[resspeckey.ResSpecKey]*instanceconfig.Configuration, utils.DefaultMapSize), + metricsCollector: make(map[resspeckey.ResSpecKey]metrics.Collector, utils.DefaultMapSize), + insAcqReqQueue: make(map[resspeckey.ResSpecKey]*requestqueue.InsAcqReqQueue, utils.DefaultMapSize), + onDemandInstanceQueue: make(map[resspeckey.ResSpecKey]*instancequeue.OnDemandInstanceQueue, utils.DefaultMapSize), + reservedInstanceQueue: make(map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue, utils.DefaultMapSize), + scaledInstanceQueue: make(map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue, utils.DefaultMapSize), + sessionRecordMap: make(map[string]sessionRecord, utils.DefaultMapSize), + instanceSessionMap: make(map[string]map[string]struct{}, utils.DefaultMapSize), + defaultResKey: defaultResKey, + createTimeout: utils.GetCreateTimeout(funcSpec), + sessionReaperInterval: defaultSessionReaperInterval, + faasManagerInfo: faasManagerInfo, + functionSignature: funcSpec.FuncMetaSignature, + minScaleUpdatedTime: time.Now(), + waitInsConfigChan: make(chan struct{}), + pendingInstanceNum: make(map[string]int), + minScaleAlarmSign: make(map[string]bool), + } + pool.stateRoute = StateRoute{ + funcSpec: funcSpec, + stateRoute: make(map[string]*StateInstance, utils.DefaultMapSize), + stateLeaseManager: make(map[string]*stateinstance.Leaser, utils.DefaultMapSize), + stateConfig: funcSpec.FuncMetaData.StateConfig, + resSpec: defaultResSpec, + deleteInstanceFunc: pool.deleteInstance, + createInstanceFunc: pool.createInstanceAndAddCallerPodName, + leaseInterval: time.Duration(config.GlobalConfig.AutoScaleConfig.ScaleDownTime) * time.Millisecond, + RWMutex: sync.RWMutex{}, + logger: log.GetLogger().With(zap.Any("funcKey", funcSpec.FuncKey)), + } + if reservedInstanceQueue, err := pool.createInstanceQueue(types.InstanceTypeReserved, defaultResKey); err == nil { + pool.reservedInstanceQueue[defaultResKey] = reservedInstanceQueue.(*instancequeue.ScaledInstanceQueue) + } else { + return nil, err + } + if scaledInstanceQueue, err := pool.createInstanceQueue(types.InstanceTypeScaled, defaultResKey); err == nil { + pool.scaledInstanceQueue[defaultResKey] = scaledInstanceQueue.(*instancequeue.ScaledInstanceQueue) + } else { + return nil, err + } + if onDemandInstanceQueue, err := pool.createInstanceQueue(types.InstanceTypeOnDemand, defaultResKey); err == nil { + pool.onDemandInstanceQueue[defaultResKey] = onDemandInstanceQueue.(*instancequeue.OnDemandInstanceQueue) + } else { + return nil, err + } + go pool.instanceSessionReaper() + return pool, nil +} + +func (gi *GenericInstancePool) createInstanceQueue(instanceType types.InstanceType, resKey resspeckey.ResSpecKey) ( + instancequeue.InstanceQueue, snerror.SNError) { + metricsCollector, exist := gi.metricsCollector[resKey] + if !exist { + metricsCollector = metrics.NewBucketMetricsCollector(gi.FuncSpec.FuncKey, resKey.String()) + gi.metricsCollector[resKey] = metricsCollector + } + insAcqReqQueue, exist := gi.insAcqReqQueue[resKey] + if !exist { + funcKeyWithRes := utils.GenFuncKeyWithRes(gi.FuncSpec.FuncKey, resKey.String()) + requestTimeout := utils.GetRequestTimeout(gi.FuncSpec) + insAcqReqQueue = requestqueue.NewInsAcqReqQueue(funcKeyWithRes, requestTimeout) + gi.insAcqReqQueue[resKey] = insAcqReqQueue + } + insQueConfig := &instancequeue.InsQueConfig{ + FuncSpec: gi.FuncSpec, + InsThdReqQueue: insAcqReqQueue, + InstanceType: instanceType, + ResKey: resKey, + MetricsCollector: metricsCollector, + CreateInstanceFunc: gi.createInstance, + DeleteInstanceFunc: gi.deleteInstance, + SignalInstanceFunc: gi.handleSignal, + } + instanceQueue, err := instancequeue.BuildInstanceQueue(insQueConfig, insAcqReqQueue, metricsCollector) + if err != nil { + log.GetLogger().Errorf("failed to create %s instance queue for function %s of resource %+v error %s", + instanceType, gi.FuncSpec.FuncKey, resKey, err.Error()) + return nil, snerror.New(statuscode.StatusInternalServerError, err.Error()) + } + log.GetLogger().Infof("create %s instance queue for function %s of resource %+v", instanceType, + gi.FuncSpec.FuncKey, resKey) + return instanceQueue, nil +} + +// RecoverInstance recover instance pool, reserved pool and scaled pool +func (gi *GenericInstancePool) RecoverInstance(funcSpec *types.FunctionSpecification, + instancePoolState *types.InstancePoolState, deleteFunc bool, wg *sync.WaitGroup) { + defer wg.Done() + // existInstanceIdMap got from etcd + instanceIDMapsFromEtcd := registry.GlobalRegistry.InstanceRegistry.GetFunctionInstanceIDMap() + StateInstanceMap := instancePoolState.StateInstance + if len(StateInstanceMap) != 0 { + filterStateInstanceMap(StateInstanceMap, instanceIDMapsFromEtcd, gi.FuncSpec.FuncKey) + gi.recoverStateRouteMap(StateInstanceMap) + } +} + +// filterInstanceIDMap delete keys in filterMap but not in existsMap +func (gi *GenericInstancePool) filterInstanceIDMap(instanceMapFromState map[string]*types.Instance, + instanceMapFromEtcd map[string]map[string]*commonTypes.InstanceSpecification, instanceType string) { + if instanceMapFromEtcd == nil { + for id := range instanceMapFromState { + delete(instanceMapFromState, id) + } + return + } + // delete extra instance in state + for id := range instanceMapFromState { + var etcdIDMap map[string]*commonTypes.InstanceSpecification + var ok bool + if etcdIDMap, ok = instanceMapFromEtcd[gi.FuncSpec.FuncKey]; !ok { + delete(instanceMapFromState, id) + continue + } + if _, exist := etcdIDMap[id]; !exist { + delete(instanceMapFromState, id) + } else { + oldStatus := instanceMapFromState[id].InstanceStatus + instanceMapFromState[id].InstanceStatus = etcdIDMap[id].InstanceStatus + log.GetLogger().Infof("instanceFromState instanceStatus update: oldStatus: %v -> newStatus: %v, "+ + "instanceId: %v", oldStatus, etcdIDMap[id].InstanceStatus, id) + } + } + if etcdIDMap, ok := instanceMapFromEtcd[gi.FuncSpec.FuncKey]; ok { + // kill new instance + for id, insSpec := range etcdIDMap { + if _, exist := instanceMapFromState[id]; !exist && + insSpec.CreateOptions[types.FunctionKeyNote] == gi.FuncSpec.FuncKey && + insSpec.CreateOptions[types.InstanceTypeNote] == instanceType { + if err := gi.deleteInstance(&types.Instance{InstanceID: id}); err != nil { + continue + } + } + } + } +} + +// filterStateInstanceMap delete keys in filterMap but not in existsMap +func filterStateInstanceMap(filterMap map[string]*types.Instance, + existsMap map[string]map[string]*commonTypes.InstanceSpecification, funcKey string) { + if etcdIDMap, ok := existsMap[funcKey]; ok { + for _, instance := range filterMap { + if existInstance, exist := etcdIDMap[instance.InstanceID]; !exist { + instance.InstanceStatus.Code = int32(constant.KernelInstanceStatusExited) + } else { + instance.InstanceStatus = existInstance.InstanceStatus + } + } + } +} + +// CreateInstance will create an instance +func (gi *GenericInstancePool) CreateInstance(insCrtReq *types.InstanceCreateRequest) (*types.Instance, + snerror.SNError) { + log.GetLogger().Infof("start to create instance for function %s with instanceName %s traceID %s", + gi.FuncSpec.FuncKey, insCrtReq.InstanceName, insCrtReq.TraceID) + select { + case <-gi.FuncSpec.FuncCtx.Done(): + log.GetLogger().Errorf("function %s is deleted, can not create instance", gi.FuncSpec.FuncKey) + return nil, snerror.New(statuscode.FuncMetaNotFoundErrCode, statuscode.FuncMetaNotFoundErrMsg) + default: + var ( + resSpec *resspeckey.ResourceSpecification + err snerror.SNError + ) + if utils.IsResSpecEmpty(insCrtReq.ResSpec) { + resSpec = gi.defaultResSpec + } else { + resSpec = insCrtReq.ResSpec + } + resKey := resspeckey.ConvertToResSpecKey(resSpec) + onDemandInstanceQueue, err := gi.acquireOnDemandInstanceQueue(resKey) + if err != nil { + log.GetLogger().Errorf("failed to acquire on-demand instance queue of function %s error %s", + gi.FuncSpec.FuncKey, err.Error()) + return nil, err + } + return onDemandInstanceQueue.CreateInstance(insCrtReq) + } +} + +// DeleteInstance will delete an instance +func (gi *GenericInstancePool) DeleteInstance(instance *types.Instance) snerror.SNError { + log.GetLogger().Infof("start to delete instance for function %s with instanceName %s", gi.FuncSpec.FuncKey, + instance.InstanceName) + select { + case <-gi.FuncSpec.FuncCtx.Done(): + log.GetLogger().Errorf("function %s is deleted, can not delete instance", gi.FuncSpec.FuncKey) + return snerror.New(statuscode.FuncMetaNotFoundErrCode, statuscode.FuncMetaNotFoundErrMsg) + default: + gi.RLock() + onDemandInstanceQueue, exist := gi.onDemandInstanceQueue[instance.ResKey] + if !exist { + gi.RUnlock() + return snerror.New(statuscode.StatusInternalServerError, "on-demand instance queue not exist") + } + gi.RUnlock() + return onDemandInstanceQueue.DeleteInstance(instance) + } +} + +func (gi *GenericInstancePool) acquireStateInstanceThread(insAcqReq *types.InstanceAcquireRequest) ( + *types.InstanceAllocation, snerror.SNError) { + /* stateID + satateful func -> ok + nil stateID + satateful func -> err + stateID + non-satateful func -> err + nil stateID + non-satateful func -> ok + */ + if !gi.FuncSpec.FuncMetaData.IsStatefulFunction || len(insAcqReq.StateID) == 0 { + return nil, snerror.New(statuscode.StateMismatch, statuscode.StateMismatchErrMsg) + } + return gi.stateRoute.acquireStateInstanceThread(insAcqReq) +} + +// AcquireInstance will acquire an instance +func (gi *GenericInstancePool) AcquireInstance(insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, + snerror.SNError) { + logger := log.GetLogger().With(zap.Any("traceId", insAcqReq.TraceID), zap.Any("funcKey", gi.FuncSpec.FuncKey), + zap.Any("designatedInstance", insAcqReq.DesignateInstanceID)) + logger.Debugf("acquire instance") + select { + case <-gi.FuncSpec.FuncCtx.Done(): + logger.Errorf("function is deleted, can not acquire instance") + return nil, snerror.New(statuscode.FuncMetaNotFoundErrCode, statuscode.FuncMetaNotFoundErrMsg) + default: + if insAcqReq.StateID != "" { + return gi.acquireStateInstanceThread(insAcqReq) + } + if utils.IsResSpecEmpty(insAcqReq.ResSpec) { + return nil, snerror.New(statuscode.InternalErrorCode, statuscode.InternalErrorMessage) + } + var ( + insAlloc *types.InstanceAllocation + ) + defer func() { + if insAlloc != nil && len(insAlloc.SessionInfo.SessionID) != 0 { + gi.recordInstanceSession(insAlloc) + } + }() + if len(insAcqReq.InstanceSession.SessionID) != 0 { + gi.processInstanceSession(insAcqReq) + } + gi.currentPoolLabel = insAcqReq.PoolLabel + resKey := resspeckey.ConvertToResSpecKey(insAcqReq.ResSpec) + // 当label有值但是没有对于label的实例配置时,直接返回报错 + gi.RLock() + if insAcqReq.ResSpec.InvokeLabel != DefaultInstanceLabel && gi.insConfig[resKey] == nil { + gi.RUnlock() + return nil, snerror.New(statuscode.InstanceLabelNotFoundErrCode, statuscode.InstanceLabelNotFoundErrMsg) + } + gi.RUnlock() + logger = logger.With(zap.Any("resource", resKey)) + if len(insAcqReq.InstanceName) != 0 { + onDemandInstanceQueue, err := gi.acquireOnDemandInstanceQueue(resKey) + if err != nil { + logger.Errorf("failed to acquire on-demand instance queue of function, error %s", err.Error()) + return nil, err + } + return onDemandInstanceQueue.AcquireInstance(insAcqReq) + } + if !insAcqReq.TrafficLimited { + reservedInstanceQueue, err := gi.acquireReservedInstanceQueue(resKey) + if err != nil { + logger.Errorf("failed to acquire reserved instance queue of function, error %s", err.Error()) + return nil, err + } + insAlloc, err = reservedInstanceQueue.AcquireInstance(insAcqReq) + if insAlloc != nil { + logger.Infof("acquired reserved instance thread %s of function", insAlloc.AllocationID) + return insAlloc, nil + } + if err.Code() != statuscode.NoInstanceAvailableErrCode && err.Code() != statuscode.InstanceNotFoundErrCode { + logger.Errorf("failed to acquire reserved instance of function, error %s", err.Error()) + return nil, err + } + } + return gi.acquireInstanceFromScaleQueueWithBackup(resKey, insAcqReq, logger) + } +} + +func (gi *GenericInstancePool) acquireInstanceFromScaleQueueWithBackup(resKey resspeckey.ResSpecKey, + insAcqReq *types.InstanceAcquireRequest, logger api.FormatLogger) (*types.InstanceAllocation, snerror.SNError) { + var backupQueue []instancequeue.InstanceQueue + gi.RLock() + for key, queue := range gi.reservedInstanceQueue { + if resKey != key && resKey.InvokeLabel == key.InvokeLabel && queue.GetInstanceNumber(false) > 0 { + backupQueue = append(backupQueue, queue) + } + } + for key, queue := range gi.scaledInstanceQueue { + if resKey != key && resKey.InvokeLabel == key.InvokeLabel && queue.GetInstanceNumber(false) > 0 { + backupQueue = append(backupQueue, queue) + } + } + gi.RUnlock() + if len(backupQueue) > 0 { + logger.Infof("has backup queue, will skip cold start") + insAcqReq.SkipWaitPending = true + } + scaledInstanceQueue, err := gi.acquireScaleInstanceQueue(resKey) + if err != nil { + logger.Errorf("failed to acquire scaled instance queue of function, error %s", err.Error()) + return nil, err + } + insAlloc, err := scaledInstanceQueue.AcquireInstance(insAcqReq) + if insAlloc == nil { + for _, queue := range backupQueue { + insAlloc, err = queue.AcquireInstance(insAcqReq) + if insAlloc != nil { + logger.Infof("acquired backup instance thread %s of function", insAlloc.AllocationID) + return insAlloc, nil + } + } + logger.Errorf("failed to acquire scaled instance of function, error %s", err.Error()) + return nil, err + } + logger.Infof("acquired scaled instance insAlloc %s of function", insAlloc.AllocationID) + return insAlloc, nil +} + +// ReleaseInstance will release an instance +func (gi *GenericInstancePool) ReleaseInstance(insAlloc *types.InstanceAllocation) { + instance := insAlloc.Instance + var err snerror.SNError + switch instance.InstanceType { + case types.InstanceTypeReserved: + if gi.reservedInstanceQueue != nil && gi.reservedInstanceQueue[instance.ResKey] != nil { + err = gi.reservedInstanceQueue[instance.ResKey].ReleaseInstance(insAlloc) + } + case types.InstanceTypeScaled: + gi.RLock() + scaledInstanceQueue, exist := gi.scaledInstanceQueue[instance.ResKey] + gi.RUnlock() + if !exist { + err = snerror.New(statuscode.StatusInternalServerError, + "instance queue with this resource doesn't exist") + break + } + err = scaledInstanceQueue.ReleaseInstance(insAlloc) + default: + log.GetLogger().Errorf("unsupported instance type") + } + if err != nil && err.Code() == statuscode.InstanceNotFoundErrCode { + gi.cleanInstanceSession(instance.InstanceID) + } + if err != nil { + log.GetLogger().Errorf("failed to release instance insAlloc %s error %s", insAlloc.AllocationID, err.Error()) + return + } + log.GetLogger().Infof("released instance insAlloc %s of function %s resource %+v", insAlloc.AllocationID, + gi.FuncSpec.FuncKey, instance.ResKey) +} + +// HandleFaaSManagerUpdate handles faas manager update +func (gi *GenericInstancePool) HandleFaaSManagerUpdate(faasManagerInfo faasManagerInfo) { + gi.Lock() + gi.faasManagerInfo = faasManagerInfo + gi.Unlock() +} + +// HandleFunctionEvent handles function event +func (gi *GenericInstancePool) HandleFunctionEvent(eventType registry.EventType, + funcSpec *types.FunctionSpecification) { + log.GetLogger().Infof("handling event type %s for function %s", eventType, gi.FuncSpec.FuncKey) + if eventType == registry.SubEventTypeUpdate { + gi.HandleFunctionUpdateEvent(funcSpec) + } + if eventType == registry.SubEventTypeDelete { + gi.HandleFunctionDeleteEvent() + } +} + +// HandleFunctionDeleteEvent - +func (gi *GenericInstancePool) HandleFunctionDeleteEvent() { + for _, mc := range gi.metricsCollector { + mc.Stop() + } + gi.RLock() + if gi.reservedInstanceQueue != nil { + for _, queue := range gi.reservedInstanceQueue { + queue.Destroy() + } + } + for _, queue := range gi.scaledInstanceQueue { + queue.Destroy() + } + for _, insThdReqQueue := range gi.insAcqReqQueue { + insThdReqQueue.Stop() + } + gi.RUnlock() + gi.stateRoute.Destroy() // 租约怎么办 +} + +// HandleFunctionUpdateEvent - +func (gi *GenericInstancePool) HandleFunctionUpdateEvent(funcSpec *types.FunctionSpecification) { + gi.Lock() + // watch of instance need this funcSpec to set currentNum in instance struct, so this funcSpec needs to be + // updated + preResSpec := gi.defaultResSpec + gi.FuncSpec = funcSpec + gi.defaultResSpec = resspeckey.ConvertResourceMetaDataToResSpec(funcSpec.ResourceMetaData) + gi.defaultResKey = resspeckey.ConvertToResSpecKey(gi.defaultResSpec) + if !reflect.DeepEqual(preResSpec, gi.defaultResSpec) && config.GlobalConfig.Scenario != types.ScenarioWiseCloud { + reservedInstanceQueueMap := make(map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue, utils.DefaultMapSize) + for resKey, configuration := range gi.insConfig { + resSpec := gi.defaultResSpec.DeepCopy() + resSpec.InvokeLabel = resKey.InvokeLabel + newResKey := resspeckey.ConvertToResSpecKey(resSpec) + reservedInstanceQueue, err := gi.createInstanceQueue(types.InstanceTypeReserved, newResKey) + if err != nil { + log.GetLogger().Errorf("failed to create reserved instance queue for function %s during function "+ + "update error %s", gi.FuncSpec.FuncKey, err.Error()) + } else { + if queue, exist := gi.reservedInstanceQueue[resKey]; exist { + log.GetLogger().Debugf("reservedInstanceQueue for function %s destroy start", + gi.FuncSpec.FuncKey) + queue.Destroy() + } + reservedInstanceQueueMap[newResKey] = reservedInstanceQueue.(*instancequeue.ScaledInstanceQueue) + reservedInstanceQueueMap[newResKey].HandleInsConfigUpdate(configuration) + reservedInstanceQueueMap[newResKey].EnableInstanceScale() + } + } + gi.reservedInstanceQueue = reservedInstanceQueueMap + scaledInstanceQueue, err := gi.createInstanceQueue(types.InstanceTypeScaled, gi.defaultResKey) + if err != nil { + log.GetLogger().Errorf("failed to create scaled instance queue for function %s during function "+ + "update error %s", gi.FuncSpec.FuncKey, err.Error()) + } else { + gi.scaledInstanceQueue[gi.defaultResKey] = scaledInstanceQueue.(*instancequeue.ScaledInstanceQueue) + } + } + gi.createTimeout = utils.GetCreateTimeout(funcSpec) + gi.defaultPoolLabel = funcSpec.InstanceMetaData.PoolLabel + gi.currentPoolLabel = funcSpec.InstanceMetaData.PoolLabel + gi.minScaleUpdatedTime = time.Now() + gi.Unlock() + gi.RLock() + if gi.reservedInstanceQueue != nil { + for resKey, queue := range gi.reservedInstanceQueue { + if queue.GetSchedulerPolicy() != gi.FuncSpec.InstanceMetaData.SchedulePolicy { + err := gi.resetInstanceScheduler(queue, resKey) + if err != nil { + log.GetLogger().Errorf("%s failed to reset instance scheduler, from %s to %s, err: %s", + queue.GetSchedulerPolicy(), gi.FuncSpec.FuncKey, gi.FuncSpec.InstanceMetaData.SchedulePolicy, + err.Error()) + } + } + queue.HandleFuncSpecUpdate(funcSpec) + } + } + for resKey, queue := range gi.scaledInstanceQueue { + if queue.GetSchedulerPolicy() != gi.FuncSpec.InstanceMetaData.SchedulePolicy { + err := gi.resetInstanceScheduler(queue, resKey) + if err != nil { + log.GetLogger().Errorf("%s failed to reset instance scheduler, from %s to %s, err: %s", + queue.GetSchedulerPolicy(), gi.FuncSpec.FuncKey, gi.FuncSpec.InstanceMetaData.SchedulePolicy, + err.Error()) + } + } + queue.HandleFuncSpecUpdate(funcSpec) + } + for _, insThdReqQueue := range gi.insAcqReqQueue { + insThdReqQueue.UpdateRequestTimeout(utils.GetRequestTimeout(funcSpec)) + } + gi.RUnlock() + // todo 之后再考虑状态函数 函数元信息变更事件, 考虑两种情况,有状态无状态之间切换、有状态函数其他原信息变更 +} + +func (gi *GenericInstancePool) resetInstanceScheduler(instanceQueue *instancequeue.ScaledInstanceQueue, + resKey resspeckey.ResSpecKey) error { + log.GetLogger().Debugf("%s reset instance scheduler, from %s to %s", gi.FuncSpec.FuncKey, + instanceQueue.GetSchedulerPolicy(), gi.FuncSpec.InstanceMetaData.SchedulePolicy) + insAcqReqQueue, exist := gi.insAcqReqQueue[resKey] + if !exist { + funcKeyWithRes := utils.GenFuncKeyWithRes(gi.FuncSpec.FuncKey, resKey.String()) + requestTimeout := utils.GetRequestTimeout(gi.FuncSpec) + insAcqReqQueue = requestqueue.NewInsAcqReqQueue(funcKeyWithRes, requestTimeout) + gi.insAcqReqQueue[resKey] = insAcqReqQueue + } + instanceQueue.L.Lock() + defer instanceQueue.L.Unlock() + var currentInstance []*types.Instance + oldInstanceScheduler := instanceQueue.GetInstanceScheduler() + for { + ins := oldInstanceScheduler.PopInstance(true) + if ins == nil { + break + } + currentInstance = append(currentInstance, ins) + } + oldInstanceScheduler.Destroy() + err := instancequeue.AssembleScheduler(gi.FuncSpec.InstanceMetaData.SchedulePolicy, instanceQueue, insAcqReqQueue) + if err != nil { + return err + } + instanceQueue.ReconnectWithScaler() + for _, instance := range currentInstance { + instanceQueue.HandleInstanceUpdate(instance) + } + return nil +} + +// HandleAliasEvent handles alias event +func (gi *GenericInstancePool) HandleAliasEvent(eventType registry.EventType, aliasUrn string) { + log.GetLogger().Infof("pool %s handling event type %s for alias,urn:%s", gi.FuncSpec.FuncKey, eventType, aliasUrn) + if eventType == registry.SubEventTypeUpdate || eventType == registry.SubEventTypeDelete { + gi.updateInstanceAliasData() + } +} + +func (gi *GenericInstancePool) updateInstanceAliasData() { + if gi.reservedInstanceQueue != nil { + for _, queue := range gi.reservedInstanceQueue { + queue.HandleAliasUpdate() + } + } + for _, instanceQueue := range gi.scaledInstanceQueue { + instanceQueue.HandleAliasUpdate() + } +} + +// HandleFaaSSchedulerEvent - +func (gi *GenericInstancePool) HandleFaaSSchedulerEvent() { + if gi.reservedInstanceQueue != nil { + for _, queue := range gi.reservedInstanceQueue { + queue.HandleFaaSSchedulerUpdate() + } + } + for _, instanceQueue := range gi.scaledInstanceQueue { + instanceQueue.HandleFaaSSchedulerUpdate() + } +} + +// HandleInstanceEvent handles instance event +func (gi *GenericInstancePool) HandleInstanceEvent(eventType registry.EventType, instance *types.Instance) { + instance.FuncKey = gi.FuncSpec.FuncKey + logger := log.GetLogger().With(zap.Any("", "HandleInstanceEvent")). + With(zap.Any("FuncKey", gi.FuncSpec.FuncKey)). + With(zap.Any("ResKey", instance.ResKey)). + With(zap.Any("InstanceID", instance.InstanceID)).With(zap.Any("eventType", eventType)) + // Reserved instance must have insConfig. + insConfResKey := gi.defaultResKey + insConfResKey.InvokeLabel = instance.ResKey.InvokeLabel + gi.RLock() + if _, ok := gi.insConfig[insConfResKey]; gi.synced && instance.InstanceType == types.InstanceTypeReserved && !ok { + gi.RUnlock() + go DeleteUnexpectInstance(instance.ParentID, instance.InstanceID, instance.FuncKey, logger) + return + } + gi.RUnlock() + logger.Infof("handling instance event") + switch eventType { + case registry.SubEventTypeUpdate: + gi.handleInstanceUpdate(instance, logger) + case registry.SubEventTypeDelete: + gi.handleInstanceDelete(instance, logger) + case registry.SubEventTypeRemove: + gi.removeInstance(instance, logger) + default: + logger.Warnf("unsupported instance event type: %s", eventType) + } +} + +var defaultInstanceNeedProcessInstanceCodeMap = map[int32]struct{}{ + int32(constant.KernelInstanceStatusRunning): {}, + int32(constant.KernelInstanceStatusSubHealth): {}, +} + +var reservedAndScaledInstanceNeedProcessInstanceCodeMap = map[int32]struct{}{ + int32(constant.KernelInstanceStatusRunning): {}, + int32(constant.KernelInstanceStatusSubHealth): {}, + int32(constant.KernelInstanceStatusEvicting): {}, // 为了支持绑定会话的实例,在优雅退出时,依旧能支持会话请求, +} + +func getNeedProcessInstanceCodeMap(instanceType types.InstanceType) map[int32]struct{} { + switch instanceType { + case types.InstanceTypeReserved, types.InstanceTypeScaled: + return reservedAndScaledInstanceNeedProcessInstanceCodeMap + default: + return defaultInstanceNeedProcessInstanceCodeMap + } +} + +var defaultFaultyInstanceStatusMap = map[int32]struct{}{ + int32(constant.KernelInstanceStatusFatal): {}, + int32(constant.KernelInstanceStatusScheduleFailed): {}, + int32(constant.KernelInstanceStatusEvicted): {}, + int32(constant.KernelInstanceStatusEvicting): {}, + int32(constant.KernelInstanceStatusExiting): {}, +} + +var reservedAndScaledFaultyInstanceStatusMap = map[int32]struct{}{ + int32(constant.KernelInstanceStatusFatal): {}, + int32(constant.KernelInstanceStatusScheduleFailed): {}, + int32(constant.KernelInstanceStatusEvicted): {}, + int32(constant.KernelInstanceStatusExiting): {}, +} + +func getFaultyInstanceStatusMap(instanceType types.InstanceType) map[int32]struct{} { + switch instanceType { + case types.InstanceTypeReserved, types.InstanceTypeScaled: + return reservedAndScaledFaultyInstanceStatusMap + default: + return defaultFaultyInstanceStatusMap + } +} + +func (gi *GenericInstancePool) handleInstanceUpdate(instance *types.Instance, logger api.FormatLogger) { + instanceStatusCodeMap := getNeedProcessInstanceCodeMap(instance.InstanceType) + if _, ok := instanceStatusCodeMap[instance.InstanceStatus.Code]; ok { + if config.GlobalConfig.Scenario != types.ScenarioWiseCloud && + instance.FuncSig != gi.FuncSpec.FuncMetaSignature { + logger.Errorf("handle event failed, function signature changes, killing instance now") + gi.removeInstance(instance, logger) + return + } + switch instance.InstanceType { + case types.InstanceTypeOnDemand: + onDemandInstanceQueue, err := gi.acquireOnDemandInstanceQueue(instance.ResKey) + if err != nil { + logger.Errorf("failed to acquire on-demand instance queue error %s", err.Error()) + break + } + onDemandInstanceQueue.HandleInstanceUpdate(instance) + case types.InstanceTypeReserved: + reservedInstanceQueue, err := gi.acquireReservedInstanceQueue(instance.ResKey) + if err != nil { + logger.Errorf("failed to acquire reserved instance queue error %s", err.Error()) + break + } + reservedInstanceQueue.HandleInstanceUpdate(instance) + case types.InstanceTypeScaled: + scaledInstanceQueue, err := gi.acquireScaleInstanceQueue(instance.ResKey) + if err != nil { + logger.Errorf("failed to acquire scaled instance queue error %s", err.Error()) + break + } + scaledInstanceQueue.HandleInstanceUpdate(instance) + case types.InstanceTypeState: + gi.stateRoute.HandleInstanceUpdate(instance) + default: + logger.Warnf("instance type %s update not implemented", instance.InstanceType) + } + gi.judgeExceedInstance(instance.ResKey, logger) + } + + faultyInstanceStatusMap := getFaultyInstanceStatusMap(instance.InstanceType) + if _, ok := faultyInstanceStatusMap[instance.InstanceStatus.Code]; ok { + logger.Warnf("instance status is updated to %d, remove this instance now", instance.InstanceStatus.Code) + gi.removeInstance(instance, logger) + } +} + +func (gi *GenericInstancePool) handleInstanceDelete(instance *types.Instance, logger api.FormatLogger) { + signalmanager.GetSignalManager().RemoveInstance(instance.InstanceID) + gi.cleanInstanceSession(instance.InstanceID) + switch instance.InstanceType { + case types.InstanceTypeOnDemand: + gi.RLock() + onDemandInstanceQueue, exist := gi.onDemandInstanceQueue[instance.ResKey] + gi.RUnlock() + if exist { + onDemandInstanceQueue.HandleInstanceDelete(instance) + } + case types.InstanceTypeReserved: + gi.RLock() + reservedInstanceQueue, exist := gi.reservedInstanceQueue[instance.ResKey] + gi.RUnlock() + if exist { + reservedInstanceQueue.HandleInstanceDelete(instance) + } + case types.InstanceTypeScaled: + gi.RLock() + scaledInstanceQueue, exist := gi.scaledInstanceQueue[instance.ResKey] + gi.RUnlock() + if exist { + scaledInstanceQueue.HandleInstanceDelete(instance) + } + default: + logger.Warnf("instance type %s update not implemented", instance.InstanceType) + } +} + +func (gi *GenericInstancePool) acquireOnDemandInstanceQueue(resKey resspeckey.ResSpecKey) ( + *instancequeue.OnDemandInstanceQueue, snerror.SNError) { + var err snerror.SNError + gi.RLock() + instanceQueue, exist := gi.onDemandInstanceQueue[resKey] + gi.RUnlock() + if exist { + return instanceQueue, err + } + gi.Lock() + // 需要二次判断,防止重复创建 + instanceQueue, exist = gi.onDemandInstanceQueue[resKey] + if !exist { + log.GetLogger().Debugf("createInstanceQueue type OnDemand for function %s destroy start", + gi.FuncSpec.FuncKey) + var queue instancequeue.InstanceQueue + queue, err = gi.createInstanceQueue(types.InstanceTypeOnDemand, resKey) + if err == nil { + instanceQueue = queue.(*instancequeue.OnDemandInstanceQueue) + gi.onDemandInstanceQueue[resKey] = instanceQueue + } + } + gi.Unlock() + return instanceQueue, err +} + +func (gi *GenericInstancePool) acquireReservedInstanceQueue(resKey resspeckey.ResSpecKey) ( + *instancequeue.ScaledInstanceQueue, snerror.SNError) { + var snErr snerror.SNError + gi.RLock() + instanceQueue, exist := gi.reservedInstanceQueue[resKey] + gi.RUnlock() + if exist { + return instanceQueue, snErr + } + // 需要createInstanceQueue时,从读锁升级为写锁 + gi.Lock() + // 需要二次判断,防止重复创建 + instanceQueue, exist = gi.reservedInstanceQueue[resKey] + if !exist { + log.GetLogger().Debugf("createInstanceQueue type reserved for function %s destroy start", + gi.FuncSpec.FuncKey) + insQ, err := gi.createInstanceQueue(types.InstanceTypeReserved, resKey) + if err == nil { + gi.reservedInstanceQueue[resKey] = insQ.(*instancequeue.ScaledInstanceQueue) + instanceQueue = insQ.(*instancequeue.ScaledInstanceQueue) + if gi.insConfig[resKey] != nil { + instanceQueue.HandleInsConfigUpdate(gi.insConfig[resKey]) + } + } + snErr = err + } + gi.Unlock() + return instanceQueue, snErr +} + +func (gi *GenericInstancePool) acquireScaleInstanceQueue(resKey resspeckey.ResSpecKey) ( + *instancequeue.ScaledInstanceQueue, snerror.SNError) { + var snErr snerror.SNError + gi.RLock() + instanceQueue, exist := gi.scaledInstanceQueue[resKey] + gi.RUnlock() + if exist { + return instanceQueue, snErr + } + // 需要createInstanceQueue时,从读锁升级为写锁 + gi.Lock() + // 需要二次判断,防止重复创建 + instanceQueue, exist = gi.scaledInstanceQueue[resKey] + if !exist { + log.GetLogger().Debugf("createInstanceQueue type scaled for function %s destroy start", + gi.FuncSpec.FuncKey) + insQ, err := gi.createInstanceQueue(types.InstanceTypeScaled, resKey) + if err == nil { + gi.scaledInstanceQueue[resKey] = insQ.(*instancequeue.ScaledInstanceQueue) + instanceQueue = insQ.(*instancequeue.ScaledInstanceQueue) + if gi.insConfig[resKey] != nil { + instanceQueue.HandleInsConfigUpdate(gi.insConfig[resKey]) + } + } + snErr = err + } + gi.Unlock() + return instanceQueue, snErr +} + +func (gi *GenericInstancePool) removeInstance(instance *types.Instance, logger api.FormatLogger) { + logger.Infof("start to removed instance") + signalmanager.GetSignalManager().RemoveInstance(instance.InstanceID) + switch instance.InstanceType { + case types.InstanceTypeReserved: + gi.RLock() + reservedInstanceQueue, exist := gi.reservedInstanceQueue[instance.ResKey] + gi.RUnlock() + // label没有了也应该要清理instance + if !exist { + logger.Errorf("reserved queue of function %s resource %+v doesn't exist", gi.FuncSpec.FuncKey, + instance.ResKey) + go DeleteUnexpectInstance(instance.ParentID, instance.InstanceID, instance.FuncKey, logger) + break + } + reservedInstanceQueue.HandleFaultyInstance(instance) + case types.InstanceTypeScaled: + gi.RLock() + scaledInstanceQueue, exist := gi.scaledInstanceQueue[instance.ResKey] + gi.RUnlock() + // label没有了也应该要清理instance + if !exist { + logger.Errorf("scaled queue of function %s resource %+v doesn't exist", gi.FuncSpec.FuncKey, + instance.ResKey) + go DeleteUnexpectInstance(instance.ParentID, instance.InstanceID, instance.FuncKey, logger) + break + } + scaledInstanceQueue.HandleFaultyInstance(instance) + // todo 以后这里要考虑删除对应的租约 + gi.stateRoute.DeleteStateInstanceByInstanceID(instance.InstanceID) + default: + logger.Errorf("unsupported instance type") + } + logger.Infof("succeed to remove instance") +} + +// HandleInstanceConfigEvent updates instance configuration +func (gi *GenericInstancePool) HandleInstanceConfigEvent(eventType registry.EventType, + insConfig *instanceconfig.Configuration) { + logger := log.GetLogger().With(zap.Any("", "HandleInstanceConfigEvent")). + With(zap.Any("FuncKey", gi.FuncSpec.FuncKey)). + With(zap.Any("Label", insConfig.InstanceLabel)) + logger.Infof("handle start") + // currently insConfig isn't stored with resSpec, resKey of all insConfig will be set with default resource + resKey := resspeckey.ConvertToResSpecKey(gi.defaultResSpec) + resKey.InvokeLabel = insConfig.InstanceLabel + switch eventType { + case registry.SubEventTypeUpdate: + gi.Lock() + if _, ok := gi.insConfig[resKey]; !ok || + gi.insConfig[resKey].InstanceMetaData.MinInstance != insConfig.InstanceMetaData.MinInstance { + gi.minScaleUpdatedTime = time.Now() + } + gi.insConfig[resKey] = generateInstanceConfig(insConfig) + gi.Unlock() + // there is a checkpoint of insConfig in create function for no label instance, insConfig of labeled instances + // will be checked in the process of label + if len(insConfig.InstanceLabel) == 0 { + gi.closeChanOnce.Do(func() { close(gi.waitInsConfigChan) }) + } + gi.handleInstanceConfigUpdate(insConfig, resKey, logger) + gi.judgeExceedInstance(resKey, logger) + case registry.SubEventTypeDelete: + gi.Lock() + delete(gi.insConfig, resKey) + gi.Unlock() + gi.handleInstanceConfigDelete(insConfig, resKey, logger) + metrics.ClearMetricsForFunctionInsConfig(gi.FuncSpec, resKey.InvokeLabel) + default: + logger.Warnf("unsupported instance config event type: %s", eventType) + } +} + +func (gi *GenericInstancePool) handleInstanceConfigUpdate(insConfig *instanceconfig.Configuration, + resKey resspeckey.ResSpecKey, logger api.FormatLogger) { + reservedInstanceQueue, err := gi.acquireReservedInstanceQueue(resKey) + if err == nil { + reservedInstanceQueue.HandleInsConfigUpdate(insConfig) + reservedInstanceQueue.EnableInstanceScale() + } else { + logger.Errorf("acquire reserved instance queue failed, err %s", err.Error()) + } + gi.RLock() + for scaleResKey, queue := range gi.scaledInstanceQueue { + if scaleResKey.InvokeLabel != resKey.InvokeLabel { + continue + } + queue.HandleInsConfigUpdate(insConfig) + queue.EnableInstanceScale() + } + gi.RUnlock() +} + +func (gi *GenericInstancePool) handleInstanceConfigDelete(insConfig *instanceconfig.Configuration, + resKey resspeckey.ResSpecKey, logger api.FormatLogger) { + logger.Debugf("handleInstanceConfigDelete") + insConfig.InstanceMetaData.MinInstance = 0 + insConfig.InstanceMetaData.MaxInstance = 0 + gi.RLock() + reservedInstanceQueue, exist := gi.reservedInstanceQueue[resKey] + gi.RUnlock() + if exist { + // labeled instance queue need to be destroyed, instance queue with no label updates with min/max=0/0 + if resKey.InvokeLabel != DefaultInstanceLabel { + gi.Lock() + delete(gi.reservedInstanceQueue, resKey) + gi.Unlock() + reservedInstanceQueue.Destroy() + } else { + reservedInstanceQueue.HandleInsConfigUpdate(insConfig) + reservedInstanceQueue.EnableInstanceScale() + } + } + // labeled instance queue need to be destroyed, instance queue with no label takes no action, minInstance won't + // influence scaled instance, maxInstance will be handled in createInstanceFunc + for res, queue := range gi.scaledInstanceQueue { + if res.InvokeLabel != DefaultInstanceLabel && res.InvokeLabel == resKey.InvokeLabel { + gi.Lock() + delete(gi.scaledInstanceQueue, resKey) + gi.Unlock() + queue.Destroy() + } + } +} + +// CleanOrphansInstanceQueue destroy instance queue without instance config +func (gi *GenericInstancePool) CleanOrphansInstanceQueue() { + gi.Lock() + defer gi.Unlock() + for resKey, queue := range gi.reservedInstanceQueue { + if _, ok := gi.insConfig[resKey]; !ok && resKey.InvokeLabel != DefaultInstanceLabel { + delete(gi.reservedInstanceQueue, resKey) + queue.Destroy() + } + } + for key, queue := range gi.scaledInstanceQueue { + insConfResKey := gi.defaultResKey + insConfResKey.InvokeLabel = key.InvokeLabel + if _, ok := gi.insConfig[insConfResKey]; !ok { + delete(gi.scaledInstanceQueue, key) + queue.Destroy() + } + } + for key, queue := range gi.onDemandInstanceQueue { + insConfResKey := gi.defaultResKey + insConfResKey.InvokeLabel = key.InvokeLabel + if _, ok := gi.insConfig[insConfResKey]; !ok { + delete(gi.onDemandInstanceQueue, key) + queue.Destroy() + } + } + gi.synced = true +} + +func generateInstanceConfig(insConf *instanceconfig.Configuration) *instanceconfig.Configuration { + if insConf.InstanceMetaData.MinInstance < 0 { + insConf.InstanceMetaData.MinInstance = 0 + } + if insConf.InstanceMetaData.MaxInstance < 0 { + insConf.InstanceMetaData.MaxInstance = maxInstanceLimit + } + return insConf +} + +// UpdateInvokeMetrics sends invoke metrics of instance thread to autoScaler +func (gi *GenericInstancePool) UpdateInvokeMetrics(resKey resspeckey.ResSpecKey, + InsThdMetrics *types.InstanceThreadMetrics) { + gi.RLock() + metricsCollector, exist := gi.metricsCollector[resKey] + gi.RUnlock() + if !exist { + log.GetLogger().Errorf("update invoke metrics failed for function %s, resource %s doesn't exist", + gi.FuncSpec.FuncKey, resKey) + return + } + metricsCollector.UpdateInvokeMetrics(InsThdMetrics) +} + +// GetFuncSpec will return the funcSpec of this pool +func (gi *GenericInstancePool) GetFuncSpec() *types.FunctionSpecification { + gi.RLock() + funcSpec := gi.FuncSpec + gi.RUnlock() + return funcSpec +} + +func (gi *GenericInstancePool) getCurrentInstanceNum(resKey resspeckey.ResSpecKey) (int, int, snerror.SNError) { + // insConfig is stored with default resource, labeled instances have their individual limit, all types of instances + // are subject to the limit with default resource and no label + if _, exist := gi.insConfig[resKey]; !exist { + log.GetLogger().Warnf("insConfig of function %s for resource %+v doesn't exist", gi.FuncSpec.FuncKey, resKey) + return 0, 0, snerror.New(statuscode.FunctionIsDisabled, fmt.Sprintf("function is disabled")) + } + var scaledNum, reservedNum int + var scaledNumGlobal, reservedNumGlobal, pendingNumGlobal int + + for res, queue := range gi.reservedInstanceQueue { + reservedNumGlobal += queue.GetInstanceNumber(true) + // reserved instance with label can't be set with dynamic resource + if res == resKey { + reservedNum += queue.GetInstanceNumber(true) + } + } + for res, queue := range gi.scaledInstanceQueue { + scaledNumGlobal += queue.GetInstanceNumber(true) + // scaled instance with label can be set with dynamic resource + if res.InvokeLabel == resKey.InvokeLabel { + scaledNum += queue.GetInstanceNumber(true) + } + } + for _, v := range gi.pendingInstanceNum { + pendingNumGlobal += v + } + sumForCurrentLabel := scaledNum + reservedNum + gi.pendingInstanceNum[resKey.InvokeLabel] + sumForGlobal := scaledNumGlobal + reservedNumGlobal + pendingNumGlobal + return sumForCurrentLabel, sumForGlobal, nil +} + +// 判断实例是否可以继续扩容,需要判断同一个label的实例数之和是否超出label级的 +// 最大实例限制,再判断同一个函数版本的实例数之和是否超出函数级的最大实例限制 +func (gi *GenericInstancePool) checkScaleLimit(resKey resspeckey.ResSpecKey) snerror.SNError { + insConfResKey := gi.defaultResKey + insConfResKey.InvokeLabel = resKey.InvokeLabel + sumForCurrentLabel, sumForGlobal, err := gi.getCurrentInstanceNum(insConfResKey) + if err != nil { + return err + } + limitForLabel := int(gi.insConfig[insConfResKey].InstanceMetaData.MaxInstance) + limitForGlobal := int(gi.insConfig[gi.defaultResKey].InstanceMetaData.MaxInstance) + + reachLimitForLabel := sumForCurrentLabel >= limitForLabel + reachLimitForGlobal := sumForGlobal >= limitForGlobal + if reachLimitForLabel { + log.GetLogger().Errorf("function %s reaches scale limit %d of label %s, pending num %d, current num %d, "+ + "stop creating", gi.FuncSpec.FuncKey, limitForLabel, resKey.InvokeLabel, + gi.pendingInstanceNum[resKey.InvokeLabel], sumForCurrentLabel) + return snerror.New(statuscode.ReachMaxInstancesCode, fmt.Sprintf("%s %d", statuscode.ReachMaxInstancesErrMsg, + limitForLabel)) + } + if reachLimitForGlobal { + log.GetLogger().Errorf("function %s reaches general scale limit %d, current num %d, stop creating", + gi.FuncSpec.FuncKey, sumForGlobal, limitForGlobal) + return snerror.New(statuscode.ReachMaxInstancesCode, fmt.Sprintf("%s %d", statuscode.ReachMaxInstancesErrMsg, + limitForGlobal)) + } + return nil +} + +func (gi *GenericInstancePool) checkTenantLimit(instanceType types.InstanceType) (bool, bool) { + if !config.GlobalConfig.TenantInsNumLimitEnable { + return false, false + } + tenantID := urnutils.GetTenantFromFuncKey(gi.FuncSpec.FuncKey) + return tenantquota.IncreaseTenantInstanceNum(tenantID, instanceType == types.InstanceTypeReserved) +} + +func (gi *GenericInstancePool) createInstanceAndAddCallerPodName(resSpec *resspeckey.ResourceSpecification, + instanceType types.InstanceType, callerPodName string) (*types.Instance, error) { + return gi.createInstanceFunc("", instanceType, gi.defaultResKey, nil, createOption{callerPodName: callerPodName}) +} + +func (gi *GenericInstancePool) createInstance(insName string, instanceType types.InstanceType, + resKey resspeckey.ResSpecKey, createEvent []byte) (*types.Instance, error) { + return gi.createInstanceFunc(insName, instanceType, resKey, createEvent, createOption{}) +} + +func (gi *GenericInstancePool) createInstanceFunc(insName string, instanceType types.InstanceType, + resKey resspeckey.ResSpecKey, createEvent []byte, createOption createOption) (instance *types.Instance, + createErr error) { + logger := log.GetLogger().With(zap.Any("funcKey", gi.FuncSpec.FuncKey)) + // createInstance use insConfig, so gi.insConfig must exist + <-gi.waitInsConfigChan + gi.Lock() + if config.GlobalConfig.InstanceOperationBackend != constant.BackendTypeFG { + reachMaxOnDemandInsNum, reachMaxReversedInsNum := gi.checkTenantLimit(instanceType) + logger.Debugf("checkTenantLimit completed, reachMaxOnDemandInsNum is %v, reachMaxReversedInsNum is %v", + reachMaxOnDemandInsNum, reachMaxReversedInsNum) + maxOnDemandInsNum, maxReversedInsNum := tenantquota.GetTenantCache().GetTenantQuotaNum( + urnutils.GetTenantFromFuncKey(gi.FuncSpec.FuncKey)) + if reachMaxOnDemandInsNum { + gi.Unlock() + return nil, snerror.New(statuscode.ReachMaxOnDemandInstancesPerTenant, fmt.Sprintf("%s, limit %d", + statuscode.ReachMaxInstancesPerTenantErrMsg, maxOnDemandInsNum)) + } + if reachMaxReversedInsNum { + logger.Warnf("reach max reversed instance num per tenant: %s, limit %d", urnutils.Anonymize( + urnutils.GetTenantFromFuncKey(gi.FuncSpec.FuncKey)), maxReversedInsNum) + } + if err := gi.checkScaleLimit(resKey); err != nil { + gi.Unlock() + return nil, err + } + } + gi.pendingInstanceNum[resKey.InvokeLabel]++ + gi.Unlock() + defer func() { + gi.Lock() + if _, ok := gi.pendingInstanceNum[resKey.InvokeLabel]; ok { + gi.pendingInstanceNum[resKey.InvokeLabel]-- + if gi.pendingInstanceNum[resKey.InvokeLabel] == 0 { + delete(gi.pendingInstanceNum, resKey.InvokeLabel) + } + } + gi.Unlock() + }() + + gi.RLock() + createRequest := createInstanceRequest{ + funcSpec: gi.FuncSpec, + poolLabel: gi.currentPoolLabel, + createTimeout: gi.createTimeout, + faasManagerInfo: gi.faasManagerInfo, + resKey: resKey, + instanceName: insName, + callerPodName: createOption.callerPodName, + createEvent: createEvent, + instanceType: instanceType, + } + insConfResKey := gi.defaultResKey + insConfResKey.InvokeLabel = resKey.InvokeLabel + if insConf, exist := gi.insConfig[insConfResKey]; exist { + createRequest.nuwaRuntimeInfo = &insConf.NuwaRuntimeInfo + } + gi.RUnlock() + return CreateInstance(createRequest) +} + +func (gi *GenericInstancePool) deleteInstance(instance *types.Instance) error { + gi.RLock() + currentFaasManagerInfo := gi.faasManagerInfo + gi.RUnlock() + return DeleteInstance(gi.FuncSpec, currentFaasManagerInfo, instance) +} + +func (gi *GenericInstancePool) handleSignal(instance *types.Instance, signal int) { + SignalInstance(instance, signal) +} + +func (gi *GenericInstancePool) handleManagedChange() { + gi.Lock() + log.GetLogger().Debugf("HandleFuncOwnerChange for function %s start", + gi.FuncSpec.FuncKey) + for _, q := range gi.scaledInstanceQueue { + q.HandleFuncOwnerChange() + } + for k, q := range gi.reservedInstanceQueue { + q.HandleFuncOwnerChange() + if _, ok := gi.insConfig[k]; ok { + q.HandleInsConfigUpdate(gi.insConfig[k]) + } + } + gi.Unlock() +} + +func (gi *GenericInstancePool) handleRatioChange(ratio int) { + gi.Lock() + for _, q := range gi.scaledInstanceQueue { + q.HandleRatioUpdate(ratio) + } + for _, q := range gi.reservedInstanceQueue { + q.HandleRatioUpdate(ratio) + } + gi.Unlock() +} + +func (gi *GenericInstancePool) recordInstanceSession(insAlloc *types.InstanceAllocation) { + gi.Lock() + sessions, exist := gi.instanceSessionMap[insAlloc.Instance.InstanceID] + if !exist { + sessions = make(map[string]struct{}, utils.DefaultMapSize) + gi.instanceSessionMap[insAlloc.Instance.InstanceID] = sessions + } + sessions[insAlloc.SessionInfo.SessionID] = struct{}{} + gi.sessionRecordMap[insAlloc.SessionInfo.SessionID] = sessionRecord{ + instance: insAlloc.Instance, + sessionCtx: insAlloc.SessionInfo.SessionCtx, + } + gi.Unlock() +} + +func (gi *GenericInstancePool) cleanInstanceSession(instanceID string) { + gi.RLock() + _, exist := gi.instanceSessionMap[instanceID] + if !exist { + gi.RUnlock() + return + } + gi.RUnlock() + gi.Lock() + sessions, exist := gi.instanceSessionMap[instanceID] + if exist { + delete(gi.instanceSessionMap, instanceID) + for sessionID, _ := range sessions { + delete(gi.sessionRecordMap, sessionID) + } + } + gi.Unlock() +} + +func (gi *GenericInstancePool) processInstanceSession(insAcqReq *types.InstanceAcquireRequest) { + gi.RLock() + record, exist := gi.sessionRecordMap[insAcqReq.InstanceSession.SessionID] + gi.RUnlock() + if exist { + select { + case <-record.sessionCtx.Done(): + gi.Lock() + delete(gi.sessionRecordMap, insAcqReq.InstanceSession.SessionID) + sessions, exist := gi.instanceSessionMap[record.instance.InstanceID] + if exist { + delete(sessions, insAcqReq.InstanceSession.SessionID) + } + gi.Unlock() + default: + insAcqReq.DesignateInstanceID = record.instance.InstanceID + } + } +} + +func (gi *GenericInstancePool) instanceSessionReaper() { + // for LLT convenience + if gi.FuncSpec == nil || gi.FuncSpec.FuncCtx == nil { + return + } + ticker := time.NewTicker(gi.sessionReaperInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + case _, ok := <-gi.FuncSpec.FuncCtx.Done(): + if !ok { + log.GetLogger().Warnf("instance pool stopped, stop instance session reaper now") + return + } + } + gi.Lock() + for sessionID, record := range gi.sessionRecordMap { + select { + case <-record.sessionCtx.Done(): + delete(gi.sessionRecordMap, sessionID) + sessions, exist := gi.instanceSessionMap[record.instance.InstanceID] + if exist { + delete(sessions, sessionID) + } + default: + } + } + gi.Unlock() + } +} + +func (gi *GenericInstancePool) judgeExceedInstance(resKey resspeckey.ResSpecKey, logger api.FormatLogger) { + if config.GlobalConfig.Scenario == types.ScenarioWiseCloud { + return + } + insConfResKey := gi.defaultResKey + insConfResKey.InvokeLabel = resKey.InvokeLabel + gi.RLock() + defer gi.RUnlock() + _, ok1 := gi.insConfig[insConfResKey] + _, ok2 := gi.insConfig[gi.defaultResKey] + if !ok1 || !ok2 { + logger.Errorf("instance config not exist, skip") + return + } + sumForCurrentLabel, sumForGlobal, err := gi.getCurrentInstanceNum(insConfResKey) + if err != nil { + return + } + limitForLabel := int(gi.insConfig[insConfResKey].InstanceMetaData.MaxInstance) + limitForGlobal := int(gi.insConfig[gi.defaultResKey].InstanceMetaData.MaxInstance) + exceedLimitForLabel := sumForCurrentLabel > limitForLabel + exceedLimitForGlobal := sumForGlobal > limitForGlobal + if !exceedLimitForGlobal && !exceedLimitForLabel { + return + } + instanceQueue, exist := gi.scaledInstanceQueue[resKey] + if !exist { + logger.Warnf("scaled instance queue not exist, delete default scale queue") + return + } + scaleDiff := utils.IntMax(sumForCurrentLabel-limitForLabel, sumForGlobal-limitForGlobal) + logger.Infof("start scale down exceed instance, %d", scaleDiff) + instanceQueue.ScaleDownHandler(scaleDiff, func(i int) { + logger.Infof("scale down exceed instance %d succeed", i) + }) +} diff --git a/yuanrong/pkg/functionscaler/instancepool/instancepool_test.go b/yuanrong/pkg/functionscaler/instancepool/instancepool_test.go new file mode 100644 index 0000000..ad0cf5e --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/instancepool_test.go @@ -0,0 +1,1351 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "sync" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commonTypes "yuanrong/pkg/common/faas_common/types" + mockUtils "yuanrong/pkg/common/faas_common/utils" + wisecloudTypes "yuanrong/pkg/common/faas_common/wisecloudtool/types" + "yuanrong/pkg/common/uuid" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/instancequeue" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/scaler" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/state" + "yuanrong/pkg/functionscaler/types" +) + +var ( + pool InstancePool +) + +func initRegistry() { + defer ApplyFunc((*registry.FaasSchedulerRegistry).WaitForETCDList, func() {}).Reset() + _ = registry.InitRegistry(make(chan struct{})) + registry.GlobalRegistry.FaaSSchedulerRegistry = registry.NewFaasSchedulerRegistry(make(chan struct{})) + selfregister.SelfInstanceID = "schedulerID-1" + selfregister.GlobalSchedulerProxy.Add(&commonTypes.InstanceInfo{ + TenantID: "123456789", + FunctionName: "faasscheduler", + Version: "lastest", + InstanceName: "schedulerID-1", + }, "") +} + +func CreateTestInstancePool() InstancePool { + funcSpec := &types.FunctionSpecification{ + FuncKey: "testFunction", + FuncCtx: context.TODO(), + FuncMetaData: commonTypes.FuncMetaData{ + Handler: "myHandler", + }, + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 500, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 100, + MinInstance: 1, + ConcurrentNum: 1, + }, + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Initializer: commonTypes.Initializer{ + Handler: "myInitializer", + }, + }, + } + initRegistry() + insPool, _ := NewGenericInstancePool(funcSpec, faasManagerInfo{}) + return insPool +} + +func CreateTestInstancePoolWithVPC() InstancePool { + funcSpec := &types.FunctionSpecification{ + FuncKey: "testFunction", + FuncCtx: context.TODO(), + FuncMetaData: commonTypes.FuncMetaData{ + Handler: "myHandler", + }, + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 500, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 100, + ConcurrentNum: 100, + }, + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Initializer: commonTypes.Initializer{ + Handler: "myInitializer", + }, + VpcConfig: &commonTypes.VpcConfig{}, + }, + } + initRegistry() + insPool, _ := NewGenericInstancePool(funcSpec, faasManagerInfo{}) + insPool.HandleFaaSManagerUpdate(faasManagerInfo{ + funcKey: "faasManager", + instanceID: "faasManagerInstance", + }) + return insPool +} + +func TestNewGenericInstancePool(t *testing.T) { + funcSpec := &types.FunctionSpecification{ + FuncKey: "testFunction", + FuncCtx: context.TODO(), + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 500, + }, + } + initRegistry() + pool, _ = NewGenericInstancePool(funcSpec, faasManagerInfo{}) + assert.Equal(t, true, pool != nil) +} + +func TestHandleFuncEvent(t *testing.T) { + config.GlobalConfig.Scenario = "" + insPool := CreateTestInstancePool() + funcSpecOld := &types.FunctionSpecification{ + FuncKey: "testFunction", + FuncCtx: context.TODO(), + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 500, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + ConcurrentNum: 100, + MaxInstance: 3, + MinInstance: 1, + }, + } + insPool.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, &instanceconfig.Configuration{ + FuncKey: "testFunction", + InstanceLabel: "", + InstanceMetaData: commonTypes.InstanceMetaData{ + ConcurrentNum: 100, + MaxInstance: 3, + MinInstance: 1, + }, + NuwaRuntimeInfo: wisecloudTypes.NuwaRuntimeInfo{}, + }) + + insPool.HandleFunctionEvent(registry.SubEventTypeUpdate, funcSpecOld) + funcSpecNew := &types.FunctionSpecification{ + FuncKey: "testFunction", + FuncCtx: context.TODO(), + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 1000, + Memory: 1000, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + ConcurrentNum: 100, + MaxInstance: 3, + MinInstance: 1, + }, + } + insPool.HandleFunctionEvent(registry.SubEventTypeUpdate, funcSpecNew) + genInsPool := insPool.(*GenericInstancePool) + oldResKey := resspeckey.ConvertToResSpecKey(resspeckey.ConvertResourceMetaDataToResSpec(funcSpecOld.ResourceMetaData)) + newResKey := resspeckey.ConvertToResSpecKey(resspeckey.ConvertResourceMetaDataToResSpec(funcSpecNew.ResourceMetaData)) + assert.Equal(t, 1, len(genInsPool.reservedInstanceQueue)) + assert.Equal(t, 2, len(genInsPool.scaledInstanceQueue)) + assert.Equal(t, true, genInsPool.reservedInstanceQueue[oldResKey] == nil) + assert.Equal(t, true, genInsPool.reservedInstanceQueue[newResKey] != nil) +} + +func TestHandleInsConfigEvent(t *testing.T) { + insPool := CreateTestInstancePool() + insConfig := &instanceconfig.Configuration{ + FuncKey: "testFunction", + InstanceLabel: "aaaaa", + } + insPool.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, insConfig) + insPool.HandleInstanceConfigEvent(registry.SubEventTypeDelete, insConfig) + insAcqReq := &types.InstanceAcquireRequest{ + ResSpec: &resspeckey.ResourceSpecification{ + CPU: 600, + Memory: 512, + InvokeLabel: "aaaaa", + }, + } + + _, err := insPool.AcquireInstance(insAcqReq) + assert.NotNil(t, err) +} + +func TestGetFuncSpec(t *testing.T) { + insPool := CreateTestInstancePool() + funSpec := insPool.GetFuncSpec() + assert.Equal(t, "testFunction", funSpec.FuncKey) +} + +func TestGenericInstancePool_CreateInstance(t *testing.T) { + insPool := CreateTestInstancePool() + convey.Convey("Test CreateInstance", t, func() { + defer ApplyFunc((*instancequeue.OnDemandInstanceQueue).CreateInstance, func(_ *instancequeue.OnDemandInstanceQueue, + insCrtReq *types.InstanceCreateRequest) (*types.Instance, snerror.SNError) { + return &types.Instance{}, nil + }).Reset() + instance, err := insPool.CreateInstance(&types.InstanceCreateRequest{}) + convey.So(err, convey.ShouldBeNil) + convey.So(instance, convey.ShouldNotBeNil) + }) +} + +func TestGenericInstancePool_DeleteInstance(t *testing.T) { + insPool := CreateTestInstancePool().(*GenericInstancePool) + convey.Convey("Test CreateInstance", t, func() { + defer ApplyFunc((*instancequeue.OnDemandInstanceQueue).DeleteInstance, func(_ *instancequeue.OnDemandInstanceQueue, + instance *types.Instance) snerror.SNError { + return nil + }).Reset() + resKey := resspeckey.ResSpecKey{ + CPU: 100, + Memory: 100, + } + instance := &types.Instance{ + ResKey: resKey, + } + err := insPool.DeleteInstance(instance) + convey.So(err, convey.ShouldNotBeNil) + insPool.onDemandInstanceQueue[resKey] = &instancequeue.OnDemandInstanceQueue{} + err = insPool.DeleteInstance(instance) + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestAcquireInstance(t *testing.T) { + patches := mockInstanceOperation() + defer unMockInstanceOperation(patches) + convey.Convey("Test AcquireInstance", t, func() { + convey.Convey("AcquireInstance success", func() { + defer ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + insPool := CreateTestInstancePool() + insPool.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, &instanceconfig.Configuration{ + FuncKey: "test", + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 100, + ConcurrentNum: 1, + }, + }) + insAcqReq := &types.InstanceAcquireRequest{ + ResSpec: &resspeckey.ResourceSpecification{ + CPU: 600, + Memory: 512, + }, + } + insAlloc, err := insPool.AcquireInstance(insAcqReq) + assert.Equal(t, nil, err) + assert.Equal(t, true, insAlloc != nil) + }) + convey.Convey("InstanceQueue AcquireInstance is not nil", func() { + patchGet := ApplyMethod(reflect.TypeOf(&instancequeue.ScaledInstanceQueue{}), + "AcquireInstance", func(_ *instancequeue.ScaledInstanceQueue) (thread *types.InstanceAllocation, + acquireErr snerror.SNError) { + return nil, snerror.New(0, "AcquireInstance error") + }) + defer patchGet.Reset() + insPool := CreateTestInstancePool() + insThdApp := &types.InstanceAcquireRequest{} + thd, err := insPool.AcquireInstance(insThdApp) + convey.So(thd, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("acquire on-demand instance", func() { + patchGet := ApplyMethod(reflect.TypeOf(&instancequeue.OnDemandInstanceQueue{}), + "AcquireInstance", func(_ *instancequeue.OnDemandInstanceQueue) (thread *types.InstanceAllocation, + acquireErr snerror.SNError) { + return nil, snerror.New(0, "AcquireInstance error") + }) + defer patchGet.Reset() + insPool := CreateTestInstancePool() + insThdApp := &types.InstanceAcquireRequest{InstanceName: "testInstance"} + thd, err := insPool.AcquireInstance(insThdApp) + convey.So(thd, convey.ShouldBeNil) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("acquire state instance", func() { + patch := ApplyGlobalVar(&config.GlobalConfig, types.Configuration{ + AutoScaleConfig: types.AutoScaleConfig{ + SLAQuota: 1000, + ScaleDownTime: 100, + BurstScaleNum: 100000, + }, + }) + defer patch.Reset() + insPool := CreateTestInstancePool().(*GenericInstancePool) + insPool.FuncSpec.FuncMetaData.IsStatefulFunction = true + var createErr snerror.SNError + insPool.stateRoute.createInstanceFunc = func(resSpec *resspeckey.ResourceSpecification, instanceType types.InstanceType, + callerPodName string) (*types.Instance, error) { + if createErr != nil { + return nil, createErr + } + return &types.Instance{InstanceID: "instance1"}, nil + } + insThdApp := &types.InstanceAcquireRequest{StateID: "aaa"} + createErr = snerror.New(statuscode.StatusInternalServerError, "internal error") + insAlloc, err := insPool.AcquireInstance(insThdApp) + convey.So(insAlloc, convey.ShouldBeNil) + convey.So(err.Code(), convey.ShouldEqual, statuscode.NoInstanceAvailableErrCode) + time.Sleep(200 * time.Millisecond) + createErr = nil + insAlloc, err = insPool.AcquireInstance(insThdApp) + convey.So(err, convey.ShouldBeNil) + convey.So(insAlloc.Instance.InstanceID, convey.ShouldEqual, "instance1") + time.Sleep(200 * time.Millisecond) + insAlloc, err = insPool.AcquireInstance(insThdApp) + convey.So(err, convey.ShouldBeNil) + convey.So(insAlloc.Instance.InstanceID, convey.ShouldEqual, "instance1") + }) + convey.Convey("acquire session instance", func() { + insPool := CreateTestInstancePool().(*GenericInstancePool) + insPool.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, &instanceconfig.Configuration{ + InstanceMetaData: commonTypes.InstanceMetaData{ + MinInstance: 1, + MaxInstance: 10, + }, + }) + insAcqReq1 := &types.InstanceAcquireRequest{ + InstanceSession: commonTypes.InstanceSessionConfig{ + SessionID: "123", + SessionTTL: 10, + Concurrency: 1, + }, + ResSpec: &resspeckey.ResourceSpecification{ + CPU: 500, + Memory: 500, + }, + } + insAlloc1, err := insPool.AcquireInstance(insAcqReq1) + convey.So(err, convey.ShouldBeNil) + record, exist := insPool.sessionRecordMap["123"] + convey.So(exist, convey.ShouldBeTrue) + convey.So(record.instance, convey.ShouldEqual, insAlloc1.Instance) + insAlloc2, err := insPool.AcquireInstance(insAcqReq1) + convey.So(err, convey.ShouldBeNil) + convey.So(insAlloc2.Instance.InstanceID, convey.ShouldEqual, insAlloc1.Instance.InstanceID) + }) + }) +} + +func TestReleaseInstanceThread(t *testing.T) { + patches := mockInstanceOperation() + defer unMockInstanceOperation(patches) + defer ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + initRegistry() + insPool := CreateTestInstancePool() + insPool.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, &instanceconfig.Configuration{ + FuncKey: "test", + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 100, + ConcurrentNum: 1, + }, + }) + insThdApp := &types.InstanceAcquireRequest{ + ResSpec: &resspeckey.ResourceSpecification{ + CPU: 600, + Memory: 512, + }, + } + thd, err := insPool.AcquireInstance(insThdApp) + assert.Equal(t, nil, err) + insPool.ReleaseInstance(thd) + + thd.Instance.InstanceType = types.InstanceTypeReserved + insPool.ReleaseInstance(thd) + + thd.Instance.InstanceType = types.InstanceTypeScaled + insPool.ReleaseInstance(thd) +} + +func TestReleaseAbnormalInstance(t *testing.T) { + patches := mockInstanceOperation() + defer unMockInstanceOperation(patches) + convey.Convey("test ReleaseAbnormalInstance", t, func() { + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&instancequeue.ScaledInstanceQueue{}), "HandleFaultyInstance", + func(_ *instancequeue.ScaledInstanceQueue, instance *types.Instance) { + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + registry.GlobalRegistry = ®istry.Registry{FaaSSchedulerRegistry: registry.NewFaasSchedulerRegistry(make(chan struct{}))} + insPool := CreateTestInstancePool() + instance := &types.Instance{ + InstanceType: "reserved", + InstanceID: "instanceID", + ResKey: resspeckey.ResSpecKey{ + CPU: 888, + Memory: 888, + }, + } + insPool.HandleInstanceEvent(registry.SubEventTypeRemove, instance) + instance.InstanceType = "scaled" + insPool.HandleInstanceEvent(registry.SubEventTypeRemove, instance) + instance.InstanceType = "default" + insPool.HandleInstanceEvent(registry.SubEventTypeRemove, instance) + }) +} + +func BenchmarkAcquireInstanceThread(b *testing.B) { + patches := mockInstanceOperation() + defer unMockInstanceOperation(patches) + defer ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + insPool := CreateTestInstancePool() + insThdApp := &types.InstanceAcquireRequest{ + ResSpec: &resspeckey.ResourceSpecification{ + CPU: 600, + Memory: 512, + }, + } + for i := 0; i < b.N; i++ { + _, err := insPool.AcquireInstance(insThdApp) + if err != nil { + b.Errorf("acquire instance thread error %s", err.Error()) + } + } +} + +func TestAcquireInstanceThreadWithVPC(t *testing.T) { + defer ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + insPool := CreateTestInstancePoolWithVPC().(*GenericInstancePool) + insPool.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, &instanceconfig.Configuration{ + FuncKey: "test", + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 100, + ConcurrentNum: 1, + }, + }) + patches := []*Patches{ + ApplyFunc(CreateInstance, + func(request createInstanceRequest) (instance *types.Instance, createErr error) { + createErr = errors.New("failed to create") + return + }), + } + + // test create instance fail + instance, err := insPool.createInstance("", types.InstanceTypeScaled, resspeckey.ResSpecKey{}, nil) + time.Sleep(10 * time.Millisecond) + assert.NotNil(t, err) + + // test create instance fail + instance, err = insPool.createInstance("", types.InstanceTypeScaled, resspeckey.ResSpecKey{CustomResources: `{"npu":1}`}, nil) + time.Sleep(10 * time.Millisecond) + assert.NotNil(t, err) + + // test create instance success + patches[0].Reset() + patche := mockInstanceOperation() + defer unMockInstanceOperation(patche) + instance, err = insPool.createInstance("", types.InstanceTypeScaled, resspeckey.ResSpecKey{}, nil) + time.Sleep(10 * time.Millisecond) + assert.Nil(t, err) + assert.NotNil(t, instance) + + // test createInstanceAndAddCallerPodName success + instance, err = insPool.createInstanceAndAddCallerPodName(nil, types.InstanceTypeScaled, "callerPodName") + time.Sleep(10 * time.Millisecond) + assert.Nil(t, err) + assert.NotNil(t, instance) +} + +func TestHandleInsEvent(t *testing.T) { + patches := mockInstanceOperation() + defer unMockInstanceOperation(patches) + insPool := CreateTestInstancePool() + resKey := resspeckey.ResSpecKey{ + CPU: 500, + Memory: 500, + } + instance := &types.Instance{ + InstanceID: "instanceID", + InstanceType: types.InstanceTypeReserved, + ResKey: resKey, + FuncKey: "testFunction", + ConcurrentNum: 1, + } + convey.Convey("HandleInsEvent", t, func() { + convey.Convey("wait instance config", func() { + start := time.Now() + go func() { + time.Sleep(2 * time.Second) + insPool.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, &instanceconfig.Configuration{ + FuncKey: "testFunc", + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 0, + MinInstance: 0, + ConcurrentNum: 0, + InstanceType: "reserved", + IdleMode: false, + }, + }) + }() + insPool.HandleInstanceEvent(registry.SubEventTypeUpdate, &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: 3}, + InstanceType: "reserved", + InstanceID: "123456789", + FuncKey: "testFunc", + FuncSig: "", + ConcurrentNum: 1, + }) + insPool.HandleInstanceEvent(registry.SubEventTypeUpdate, &types.Instance{ + ResKey: resspeckey.ResSpecKey{CPU: 500, Memory: 500}, + InstanceStatus: commonTypes.InstanceStatus{Code: 3}, + InstanceType: "scaled", + InstanceID: "123456789-1", + FuncKey: "testFunc", + FuncSig: "", + ConcurrentNum: 1, + }) + end := time.Since(start) + convey.So(end, convey.ShouldBeLessThan, 2*time.Second) + insPool.HandleInstanceEvent(registry.SubEventTypeUpdate, &types.Instance{ + ResKey: resspeckey.ResSpecKey{CPU: 500, Memory: 500}, + InstanceStatus: commonTypes.InstanceStatus{Code: 3}, + InstanceType: "scaled", + InstanceID: "123456789-2", + FuncKey: "testFunc", + FuncSig: "", + ConcurrentNum: 1, + }) + + insPool.HandleInstanceEvent(registry.SubEventTypeUpdate, instance) + insPool.HandleInstanceEvent(registry.SubEventTypeDelete, instance) + insPool.HandleInstanceEvent(registry.SubEventTypeUpdate, &types.Instance{ + ResKey: resspeckey.ResSpecKey{CPU: 500, Memory: 500}, + InstanceStatus: commonTypes.InstanceStatus{Code: 7}, + InstanceType: "scaled", + InstanceID: "123456789-1", + FuncKey: "testFunc", + FuncSig: "", + ConcurrentNum: 1, + }) + insPool.HandleInstanceEvent(registry.SubEventTypeUpdate, &types.Instance{ + ResKey: resspeckey.ResSpecKey{CPU: 500, Memory: 500}, + InstanceStatus: commonTypes.InstanceStatus{Code: 10}, + InstanceType: "scaled", + InstanceID: "123456789-2", + FuncKey: "testFunc", + FuncSig: "", + ConcurrentNum: 1, + }) + }) + convey.Convey("delete invalid instance", func() { + var deleteTime int + defer ApplyMethod(reflect.TypeOf(&instancequeue.ScaledInstanceQueue{}), "HandleFaultyInstance", + func(_ *instancequeue.ScaledInstanceQueue, instance *types.Instance) { + deleteTime++ + }).Reset() + insPool.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, &instanceconfig.Configuration{ + FuncKey: "testFunc", + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 10, + MinInstance: 0, + ConcurrentNum: 0, + InstanceType: "reserved", + IdleMode: false, + }, + }) + insPool.HandleInstanceEvent(registry.SubEventTypeUpdate, &types.Instance{ + ResKey: resspeckey.ResSpecKey{CPU: 500, Memory: 500}, + InstanceStatus: commonTypes.InstanceStatus{Code: 5}, + InstanceType: types.InstanceTypeReserved, + InstanceID: "123456789-1", + FuncKey: "testFunc", + FuncSig: "", + ConcurrentNum: 1, + }) + convey.So(deleteTime, convey.ShouldEqual, 1) + }) + convey.Convey("handle on-demand instance", func() { + var ( + updateTime int + deleteTime int + ) + defer ApplyMethod(reflect.TypeOf(&instancequeue.OnDemandInstanceQueue{}), "HandleInstanceUpdate", + func(_ *instancequeue.OnDemandInstanceQueue, instance *types.Instance) { + updateTime++ + }).Reset() + defer ApplyMethod(reflect.TypeOf(&instancequeue.OnDemandInstanceQueue{}), "HandleInstanceDelete", + func(_ *instancequeue.OnDemandInstanceQueue, instance *types.Instance) { + deleteTime++ + }).Reset() + insPool.HandleInstanceEvent(registry.SubEventTypeUpdate, &types.Instance{ + ResKey: resspeckey.ResSpecKey{CPU: 500, Memory: 500}, + InstanceStatus: commonTypes.InstanceStatus{Code: 3}, + InstanceType: types.InstanceTypeOnDemand, + InstanceID: "instance1", + FuncKey: "function1", + FuncSig: "", + ConcurrentNum: 1, + }) + convey.So(updateTime, convey.ShouldEqual, 1) + insPool.HandleInstanceEvent(registry.SubEventTypeDelete, &types.Instance{ + ResKey: resspeckey.ResSpecKey{CPU: 500, Memory: 500}, + InstanceStatus: commonTypes.InstanceStatus{Code: 3}, + InstanceType: types.InstanceTypeOnDemand, + InstanceID: "instance1", + FuncKey: "function1", + FuncSig: "", + ConcurrentNum: 1, + }) + convey.So(deleteTime, convey.ShouldEqual, 1) + }) + }) +} + +func TestFilterInstanceIDMap(t *testing.T) { + patches := mockInstanceOperation() + defer unMockInstanceOperation(patches) + filterMap := map[string]*types.Instance{} + existsMap := map[string]map[string]*commonTypes.InstanceSpecification{} + existsMap["321"] = map[string]*commonTypes.InstanceSpecification{"123": {CreateOptions: map[string]string{types.FunctionKeyNote: "321", types.InstanceTypeNote: "scaled"}}} + gi := &GenericInstancePool{ + FuncSpec: &types.FunctionSpecification{FuncKey: "321"}, + } + convey.Convey("map is nil", t, func() { + gi.filterInstanceIDMap(filterMap, existsMap, "scaled") + convey.So(len(existsMap), convey.ShouldEqual, 1) + }) + filterMap["123"] = &types.Instance{FuncSig: "aaa"} + filterMap["456"] = &types.Instance{FuncSig: "bbb"} + convey.Convey("delete success", t, func() { + gi.filterInstanceIDMap(filterMap, existsMap, "scaled") + convey.So(filterMap, convey.ShouldContainKey, "123") + convey.So(filterMap, convey.ShouldNotContainKey, "456") + }) +} + +func TestMain(m *testing.M) { + patches := []*Patches{ + ApplyFunc((*etcd3.EtcdWatcher).StartList, func(_ *etcd3.EtcdWatcher) {}), + ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyFunc(etcd3.GetCAEMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyFunc((*etcd3.EtcdClient).AttachAZPrefix, func(_ *etcd3.EtcdClient, key string) string { return key }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + config.GlobalConfig = types.Configuration{} + config.GlobalConfig.AutoScaleConfig = types.AutoScaleConfig{ + SLAQuota: 1000, + ScaleDownTime: 1000, + BurstScaleNum: 100000, + } + config.GlobalConfig.LeaseSpan = 500 + config.GlobalConfig.LocalAuth = localauth.AuthConfig{ + AKey: "ENC(key=servicekek, value=6B6D73763030000101D615B6381ED56AF68123844D047428BDCCBF19957866" + + "CD0D7F53C29438337667A93FB9A06C5ED4A3D925C87655E4C734)", + SKey: "ENC(key=servicekek, value=6B6D73763030000101139308ABBC0C4120F949AC833416D5E6D8CA18D8C69E" + + "4C5E03E553E18733B4119C4B716FF2C8265336BB2979545A24FDC07CDD6A6A02F412D0DE83BD43F2A07DDBC78EB2)", + Duration: 0, + } + initRegistry() + m.Run() +} + +func mockInstanceOperation() []*Patches { + SetGlobalSdkClient(&mockUtils.FakeLibruntimeSdkClient{}) + patches := []*Patches{ + ApplyFunc(CreateInstance, func(request createInstanceRequest) (instance *types.Instance, createErr error) { + instance = &types.Instance{ + InstanceID: uuid.New().String(), + ConcurrentNum: 1, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + ResKey: resspeckey.ResSpecKey{InvokeLabel: ""}, + } + return + }), + ApplyFunc(DeleteInstance, func(funcSpec *types.FunctionSpecification, faasManagerInfo faasManagerInfo, + instance *types.Instance) error { + return nil + }), + ApplyFunc(SignalInstance, func(instance *types.Instance, signal int) { + return + }), + } + return patches +} + +func unMockInstanceOperation(patches []*Patches) { + for _, patch := range patches { + patch.Reset() + } +} + +type fakeInstanceQueue struct { + instanceID map[string]struct{} +} + +func (fq *fakeInstanceQueue) HandleFuncManagedChange() { +} + +func (fq *fakeInstanceQueue) SetInstanceScheduler(instanceScheduler scheduler.InstanceScheduler) { +} + +func (fq *fakeInstanceQueue) SetInstanceScaler(instanceScaler scaler.InstanceScaler) { +} + +func (fq *fakeInstanceQueue) ScaleUpHandler(insNum int, callback scaler.ScaleUpCallback) { +} + +func (fq *fakeInstanceQueue) ScaleDownHandler(insNum int, callback scaler.ScaleDownCallback) { +} + +func (fq *fakeInstanceQueue) AcquireInstanceThread(DesignateInstanceID string, isLimiting bool) (*types.InstanceAllocation, snerror.SNError) { + return nil, nil +} +func (fq *fakeInstanceQueue) ReleaseInstance(insThd *types.InstanceAllocation) snerror.SNError { + return nil +} +func (fq *fakeInstanceQueue) HandleInstanceUpdate(instance *types.Instance) {} +func (fq *fakeInstanceQueue) HandleInstanceDelete(instance *types.Instance) {} +func (fq *fakeInstanceQueue) HandleFaultyInstance(instance *types.Instance) {} +func (fq *fakeInstanceQueue) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) {} +func (fq *fakeInstanceQueue) HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) {} +func (fq *fakeInstanceQueue) Destroy() {} +func (fq *fakeInstanceQueue) RecoverInstance(input map[string]*types.Instance) { + for key, _ := range input { + if _, ok := fq.instanceID[key]; !ok { + fq.instanceID[key] = struct{}{} + } + } +} +func (fq *fakeInstanceQueue) HandleAliasUpdate() {} + +func (fq *fakeInstanceQueue) GetInstanceNumber(onlySelf bool) int { return 0 } +func (fq *fakeInstanceQueue) HandleInstanceCreating(instance *types.Instance) { + return +} + +func CreateScaledInstanceQueue(instanceType types.InstanceType) *instancequeue.ScaledInstanceQueue { + basicInsQueConfig := &instancequeue.InsQueConfig{ + InstanceType: instanceType, + FuncSpec: &types.FunctionSpecification{}, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := &metrics.BucketCollector{} + queue := instancequeue.NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("testFunction", 1000*time.Millisecond) + queue.SetInstanceScheduler(concurrencyscheduler.NewScaledConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 1}, + }, resspeckey.ResSpecKey{}, InsThdReqQueue)) + return queue +} + +func TestRecoverInstance(t *testing.T) { + reservedQueue := map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{}: CreateScaledInstanceQueue(types.InstanceTypeReserved), + } + scaledQueue := map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{CPU: 300, Memory: 128}: CreateScaledInstanceQueue(types.InstanceTypeScaled), + } + registry.GlobalRegistry = ®istry.Registry{InstanceRegistry: registry.NewInstanceRegistry(make(chan struct{}))} + gi := &GenericInstancePool{ + faasManagerInfo: faasManagerInfo{ + funcKey: "faasManager-0123", + instanceID: "01234567", + }, + FuncSpec: &types.FunctionSpecification{ + FuncKey: "mock-funcKey-123", + }, + reservedInstanceQueue: reservedQueue, + scaledInstanceQueue: scaledQueue, + stateRoute: StateRoute{stateRoute: map[string]*StateInstance{}, logger: log.GetLogger()}, + } + funcSpec := &types.FunctionSpecification{ + FuncKey: "testFuncKey", + } + instancePoolState := &types.InstancePoolState{ + StateInstance: map[string]*types.Instance{ + "StateInstanceID123": &types.Instance{FuncSig: "22222"}, + }, + } + var instanceRecord map[string]*types.Instance + defer ApplyMethod(reflect.TypeOf(&instancequeue.ScaledInstanceQueue{}), "RecoverInstance", + func(sq *instancequeue.ScaledInstanceQueue, instanceMap map[string]*types.Instance) { + for k, v := range instanceMap { + instanceRecord[k] = v + } + }).Reset() + initRegistry() + convey.Convey("recover success", t, func() { + instanceRecord = make(map[string]*types.Instance) + defer ApplyMethod(reflect.TypeOf(registry.GlobalRegistry.InstanceRegistry), "GetFunctionInstanceIDMap", + func(ir *registry.InstanceRegistry) map[string]map[string]*commonTypes.InstanceSpecification { + functionInstanceIDMap := make(map[string]map[string]*commonTypes.InstanceSpecification) + functionInstanceIDMap["mock-funcKey-123"] = map[string]*commonTypes.InstanceSpecification{ + "ReservedInstanceID123": {CreateOptions: map[string]string{types.FunctionKeyNote: "22222"}}, + "ScaledInstanceID123": {CreateOptions: map[string]string{types.FunctionKeyNote: "11111"}}, + } + return functionInstanceIDMap + }).Reset() + var wg sync.WaitGroup + wg.Add(1) + gi.RecoverInstance(funcSpec, instancePoolState, false, &wg) + wg.Wait() + convey.So(gi.stateRoute.stateRoute["StateInstanceID123"], convey.ShouldNotBeNil) + }) +} + +func Test_generateScheduleAffinity(t *testing.T) { + type args struct { + label string + } + tests := []struct { + name string + args args + want []api.Affinity + }{ + { + name: "label1", + args: args{"label1"}, + want: []api.Affinity{api.Affinity{ + PreferredPriority: true, + PreferredAntiOtherLabels: true, + LabelOps: []api.LabelOperator{ + api.LabelOperator{ + Type: 2, + LabelKey: "label1", + }, + }, + }}, + }, + { + name: "label1, unUseAntiOtherLabels", + args: args{" label1, unUseAntiOtherLabels"}, + want: []api.Affinity{api.Affinity{ + PreferredPriority: true, + PreferredAntiOtherLabels: false, + LabelOps: []api.LabelOperator{ + api.LabelOperator{ + Type: 2, + LabelKey: "label1", + }, + }, + }}, + }, + { + name: "label1,label2", + args: args{label: " label1, label2 "}, + want: []api.Affinity{ + { + PreferredPriority: true, + PreferredAntiOtherLabels: true, + LabelOps: []api.LabelOperator{ + { + Type: 2, + LabelKey: "label1", + }, + }}, + { + PreferredPriority: true, + PreferredAntiOtherLabels: true, + LabelOps: []api.LabelOperator{ + { + Type: 2, + LabelKey: "label2", + }, + }}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, generateScheduleAffinity([]api.Affinity{}, tt.args.label), "generateScheduleAffinity(%v)", tt.args.label) + }) + } +} + +func Test_filterStateInstanceMap(t *testing.T) { + convey.Convey("filterStateInstanceMap", t, func() { + convey.Convey("success", func() { + stateMap := map[string]*types.Instance{ + "state-1": &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + InstanceType: "", + InstanceID: "1", + FuncKey: "test-function", + FuncSig: "", + ConcurrentNum: 2, + }, + "state-2": &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusExited), + }, + InstanceType: "", + InstanceID: "2", + FuncKey: "test-function", + FuncSig: "", + ConcurrentNum: 2, + }, + } + etcdMap := map[string]map[string]*commonTypes.InstanceSpecification{} + filterStateInstanceMap(stateMap, etcdMap, "test-function") + convey.So(len(stateMap), convey.ShouldEqual, 2) + convey.So(stateMap["state-2"].InstanceStatus.Code, convey.ShouldEqual, constant.KernelInstanceStatusExited) + }) + }) +} + +func Test_makeRaspContainer(t *testing.T) { + convey.Convey("test makeRaspContainer", t, func() { + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + RaspImage: "test-image", + Envs: []commonTypes.KV{ + { + Name: "1", + Value: "11", + }, { + Name: "2", + Value: "22", + }, + }, + }, + }, + } + sideCardConfig := makeRaspContainer(funcSpec) + convey.So(sideCardConfig.Image, convey.ShouldEqual, "test-image") + env1 := false + env2 := false + + for _, env := range sideCardConfig.Env { + if env.Name == "1" && env.Value == "11" { + env1 = true + } + if env.Name == "2" && env.Value == "22" { + env2 = true + } + } + convey.So(env1, convey.ShouldBeTrue) + convey.So(env2, convey.ShouldBeTrue) + }) +} + +func Test_initContainerAdd(t *testing.T) { + convey.Convey("test initContainerAdd", t, func() { + convey.Convey("add success", func() { + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspImage: "test-image", + RaspServerIP: "127.0.0.1", + RaspServerPort: "8080", + }, + }, + } + _, err := initContainerAdd(funcSpec) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("parse ip error ", func() { + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspImage: "test-image", + RaspServerIP: "", + RaspServerPort: "8080", + }, + }, + } + configData, err := initContainerAdd(funcSpec) + convey.So(err, convey.ShouldBeNil) + convey.So(string(configData), convey.ShouldEqual, "null") + }) + convey.Convey("invalid port ", func() { + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspImage: "test-image", + RaspServerIP: "127.0.0.1", + RaspServerPort: "a", + }, + }, + } + configData, err := initContainerAdd(funcSpec) + convey.So(err, convey.ShouldBeNil) + convey.So(string(configData), convey.ShouldEqual, "null") + }) + convey.Convey("json Marsha data error", func() { + defer ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, fmt.Errorf("marshal error") + }).Reset() + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspServerIP: "127.0.0.1", + RaspServerPort: "8080", + }, + }, + } + _, err := initContainerAdd(funcSpec) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func Test_SideCarAdd(t *testing.T) { + convey.Convey("test SideCarAdd", t, func() { + convey.Convey("add success", func() { + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspImage: "test-image", + RaspServerIP: "127.0.0.1", + RaspServerPort: "8080", + }, + CustomFilebeatConfig: commonTypes.CustomFilebeatConfig{ + ImageAddress: "test-initImage", + }, + }, + } + _, err := sideCarAdd(funcSpec) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("json Marsha data error", func() { + defer ApplyFunc(json.Marshal, func(v interface{}) ([]byte, error) { + return nil, fmt.Errorf("marshal error") + }).Reset() + funcSpec := &types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "test-initImage", + RaspImage: "test-image", + RaspServerIP: "127.0.0.1", + RaspServerPort: "8080", + }, + CustomFilebeatConfig: commonTypes.CustomFilebeatConfig{ + ImageAddress: "test-initImage", + }, + }, + } + _, err := sideCarAdd(funcSpec) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestGenericInstancePool_handleManagedChange(t *testing.T) { + convey.Convey("HandleSchedulerManaged", t, func() { + defer ApplyFunc((*instancequeue.ScaledInstanceQueue).HandleInsConfigUpdate, func(_ *instancequeue.ScaledInstanceQueue, + insConfig *instanceconfig.Configuration) { + return + }).Reset() + defer ApplyFunc((*instancequeue.ScaledInstanceQueue).EnableInstanceScale, func(_ *instancequeue.ScaledInstanceQueue) { + }).Reset() + reservedQueue := map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{}: CreateScaledInstanceQueue(types.InstanceTypeReserved), + } + scaledQueue := map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{CPU: 300, Memory: 128}: CreateScaledInstanceQueue(types.InstanceTypeScaled), + } + gi := &GenericInstancePool{ + faasManagerInfo: faasManagerInfo{ + funcKey: "faasManager-0123", + instanceID: "01234567", + }, + FuncSpec: &types.FunctionSpecification{ + FuncKey: "mock-funcKey-123", + }, + reservedInstanceQueue: reservedQueue, + scaledInstanceQueue: scaledQueue, + stateRoute: StateRoute{stateRoute: map[string]*StateInstance{}, logger: log.GetLogger()}, + } + gi.handleManagedChange() + }) +} + +func TestGenericInstancePool_handleRatioChange(t *testing.T) { + var patches []*Patches + expectRatio := 0 + patches = append(patches, ApplyFunc( + (*concurrencyscheduler.ScaledConcurrencyScheduler).ReassignInstanceWhenGray, + func(s *concurrencyscheduler.ScaledConcurrencyScheduler, ratio int) { + expectRatio = ratio + }, + )) + defer func() { + for _, p := range patches { + p.Reset() + } + }() + + defer ApplyFunc((*instancequeue.ScaledInstanceQueue).HandleInsConfigUpdate, func(_ *instancequeue.ScaledInstanceQueue, + insConfig *instanceconfig.Configuration) { + return + }).Reset() + defer ApplyFunc((*instancequeue.ScaledInstanceQueue).EnableInstanceScale, func(_ *instancequeue.ScaledInstanceQueue) { + }).Reset() + queue := CreateScaledInstanceQueue(types.InstanceTypeReserved) + reservedQueue := map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{}: queue, + } + + scaledQueue := map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{CPU: 300, Memory: 128}: CreateScaledInstanceQueue(types.InstanceTypeScaled), + } + + gi := &GenericInstancePool{ + faasManagerInfo: faasManagerInfo{ + funcKey: "faasManager-0123", + instanceID: "01234567", + }, + FuncSpec: &types.FunctionSpecification{ + FuncKey: "mock-funcKey-123", + }, + reservedInstanceQueue: reservedQueue, + scaledInstanceQueue: scaledQueue, + stateRoute: StateRoute{stateRoute: map[string]*StateInstance{}, logger: log.GetLogger()}, + } + gi.handleRatioChange(50) + assert.Equal(t, expectRatio, 50) +} + +func TestGenericInstancePool_HandleInstanceConfigEvent(t *testing.T) { + convey.Convey("HandleInstanceConfigEvent", t, func() { + reservedQueue := map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{CPU: 300, Memory: 128}: CreateScaledInstanceQueue(types.InstanceTypeReserved), + resspeckey.ResSpecKey{CPU: 300, Memory: 128, InvokeLabel: "label1"}: CreateScaledInstanceQueue(types.InstanceTypeReserved), + } + scaledQueue := map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{CPU: 300, Memory: 128}: CreateScaledInstanceQueue(types.InstanceTypeScaled), + resspeckey.ResSpecKey{CPU: 300, Memory: 128, InvokeLabel: "label1"}: CreateScaledInstanceQueue(types.InstanceTypeScaled), + } + gi := &GenericInstancePool{ + defaultResSpec: &resspeckey.ResourceSpecification{CPU: 300, Memory: 128}, + faasManagerInfo: faasManagerInfo{ + funcKey: "faasManager-0123", + instanceID: "01234567", + }, + FuncSpec: &types.FunctionSpecification{ + FuncKey: "mock-funcKey-123", + }, + reservedInstanceQueue: reservedQueue, + scaledInstanceQueue: scaledQueue, + } + convey.Convey("delete insConfig without label", func() { + defer ApplyFunc((*instancequeue.ScaledInstanceQueue).HandleInsConfigUpdate, func(_ *instancequeue.ScaledInstanceQueue, + insConfig *instanceconfig.Configuration) { + }).Reset() + defer ApplyFunc((*instancequeue.ScaledInstanceQueue).EnableInstanceScale, func(_ *instancequeue.ScaledInstanceQueue) { + }).Reset() + callDestroy := 0 + defer ApplyFunc((*instancequeue.ScaledInstanceQueue).Destroy, func(_ *instancequeue.ScaledInstanceQueue) { + callDestroy++ + }).Reset() + gi.HandleInstanceConfigEvent(registry.SubEventTypeDelete, &instanceconfig.Configuration{}) + convey.So(callDestroy, convey.ShouldEqual, 0) + }) + }) +} + +func TestGenericInstancePool_CleanOrphansInstanceQueue(t *testing.T) { + callDestroy := 0 + defer ApplyFunc((*instancequeue.ScaledInstanceQueue).Destroy, func(_ *instancequeue.ScaledInstanceQueue) { + callDestroy++ + }).Reset() + reservedQueue := map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{}: CreateScaledInstanceQueue(types.InstanceTypeReserved), + } + + scaledQueue := map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{CPU: 300, Memory: 128}: CreateScaledInstanceQueue(types.InstanceTypeScaled), + } + onDemandInstanceQueue := map[resspeckey.ResSpecKey]*instancequeue.OnDemandInstanceQueue{ + resspeckey.ResSpecKey{CPU: 300, Memory: 128}: &instancequeue.OnDemandInstanceQueue{}, + } + gi := &GenericInstancePool{ + faasManagerInfo: faasManagerInfo{ + funcKey: "faasManager-0123", + instanceID: "01234567", + }, + FuncSpec: &types.FunctionSpecification{ + FuncKey: "mock-funcKey-123", + }, + reservedInstanceQueue: reservedQueue, + scaledInstanceQueue: scaledQueue, + onDemandInstanceQueue: onDemandInstanceQueue, + defaultPoolLabel: "", + stateRoute: StateRoute{stateRoute: map[string]*StateInstance{}, logger: log.GetLogger()}, + } + gi.CleanOrphansInstanceQueue() + assert.Equal(t, 1, callDestroy) +} + +func TestGenericInstancePool_sessionOperations(t *testing.T) { + convey.Convey("test session operations", t, func() { + funcCtx, funcCancel := context.WithCancel(context.TODO()) + gi := &GenericInstancePool{ + FuncSpec: &types.FunctionSpecification{ + FuncKey: "mock-funcKey-123", + FuncCtx: funcCtx, + }, + sessionRecordMap: make(map[string]sessionRecord), + instanceSessionMap: make(map[string]map[string]struct{}), + sessionReaperInterval: 100 * time.Millisecond, + } + convey.Convey("record and process", func() { + ctx, cancel := context.WithCancel(context.TODO()) + insAlloc := &types.InstanceAllocation{ + Instance: &types.Instance{InstanceID: "instance1"}, + SessionInfo: types.SessionInfo{SessionID: "session1", SessionCtx: ctx}, + } + gi.recordInstanceSession(insAlloc) + convey.So(gi.sessionRecordMap["session1"], convey.ShouldNotBeEmpty) + convey.So(len(gi.instanceSessionMap["instance1"]), convey.ShouldEqual, 1) + insAcqReq := &types.InstanceAcquireRequest{ + InstanceSession: commonTypes.InstanceSessionConfig{SessionID: "session1"}, + } + gi.processInstanceSession(insAcqReq) + convey.So(insAcqReq.DesignateInstanceID, convey.ShouldEqual, "instance1") + cancel() + gi.processInstanceSession(insAcqReq) + convey.So(gi.sessionRecordMap, convey.ShouldBeEmpty) + convey.So(len(gi.instanceSessionMap["instance1"]), convey.ShouldEqual, 0) + }) + convey.Convey("record and clean", func() { + insAlloc1 := &types.InstanceAllocation{ + Instance: &types.Instance{InstanceID: "instance1"}, + SessionInfo: types.SessionInfo{SessionID: "session1"}, + } + insAlloc2 := &types.InstanceAllocation{ + Instance: &types.Instance{InstanceID: "instance1"}, + SessionInfo: types.SessionInfo{SessionID: "session2"}, + } + gi.recordInstanceSession(insAlloc1) + gi.recordInstanceSession(insAlloc2) + gi.cleanInstanceSession("instance1") + convey.So(len(gi.instanceSessionMap["instance1"]), convey.ShouldEqual, 0) + }) + convey.Convey("reaper", func() { + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + insAlloc := &types.InstanceAllocation{ + Instance: &types.Instance{InstanceID: "instance1"}, + SessionInfo: types.SessionInfo{SessionID: "session1", SessionCtx: ctx}, + } + gi.recordInstanceSession(insAlloc) + go gi.cleanInstanceSession("instance1") + time.Sleep(200 * time.Millisecond) + funcCancel() + convey.So(gi.sessionRecordMap, convey.ShouldBeEmpty) + convey.So(len(gi.instanceSessionMap["instance1"]), convey.ShouldEqual, 0) + }) + }) +} + +func TestGenericInstancePool_judgeExceedInstance(t *testing.T) { + funcCtx, _ := context.WithCancel(context.TODO()) + logger := log.GetLogger() + resKey := resspeckey.ResSpecKey{ + CPU: 500, + Memory: 600, + EphemeralStorage: 0, + CustomResources: "", + CustomResourcesSpec: "", + InvokeLabel: "", + } + gi := &GenericInstancePool{ + FuncSpec: &types.FunctionSpecification{ + FuncKey: "mock-funcKey-123", + FuncCtx: funcCtx, + }, + defaultResKey: resKey, + insConfig: map[resspeckey.ResSpecKey]*instanceconfig.Configuration{resKey: &instanceconfig.Configuration{ + FuncKey: "mock-funcKey-123", + InstanceLabel: "", + InstanceMetaData: commonTypes.InstanceMetaData{ + MaxInstance: 1, + MinInstance: 0, + }, + }}, + scaledInstanceQueue: map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{resKey: {}}, + } + defer ApplyFunc((*GenericInstancePool).getCurrentInstanceNum, func(_ *GenericInstancePool, resKey resspeckey.ResSpecKey) (int, int, snerror.SNError) { + return 2, 2, nil + }).Reset() + i := 0 + defer ApplyFunc((*instancequeue.ScaledInstanceQueue).ScaleDownHandler, func(_ *instancequeue.ScaledInstanceQueue, insNum int, callback scaler.ScaleDownCallback) { + i++ + }).Reset() + gi.judgeExceedInstance(resKey, logger) + assert.Equal(t, i, 1) + gi.judgeExceedInstance(resspeckey.ResSpecKey{ + CPU: 300, + Memory: 128, + EphemeralStorage: 0, + CustomResources: "", + CustomResourcesSpec: "", + InvokeLabel: "", + }, logger) + assert.Equal(t, i, 1) +} diff --git a/yuanrong/pkg/functionscaler/instancepool/log.go b/yuanrong/pkg/functionscaler/instancepool/log.go new file mode 100644 index 0000000..1681346 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/log.go @@ -0,0 +1,106 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "fmt" + + "k8s.io/api/core/v1" +) + +const ( + aiOpsVolume = "aiops-logs" + biLogsVolume = "bi-logs" +) + +var ( + raspLogVolumeMountPath = "/opt/logs" + raspLogVolumeMountSubPathExpr = fmt.Sprintf("$(%s)_$(%s)/rasp", podNameEnvNew, podIPEnv) + + caasUserVolumeMountPath = "/opt/logs/caas/user" + caasUserVolumeMountSubPathExpr = fmt.Sprintf("$(%s)_$(%s)/caasUser", podNameEnvNew, podIPEnv) + + dataSystemVolumeMountPath = "/opt/logs/caas/dataSystem" + dataSystemVolumeMountSubPathExpr = fmt.Sprintf("$(%s)_$(%s)/dataSystem", podNameEnvNew, podIPEnv) + + customAIOpsMountPath = "/opt/logs" + customAIOpsMountSubPathExpr = fmt.Sprintf("$(%s)_$(%s)/custom", podNameEnvNew, podIPEnv) + + biLogVolumeMountPath = "/opt/logs/caas/bi" + biLogVolumeMountSubPathExpr = fmt.Sprintf("$(%s)_$(%s)/caasUser", podNameEnvNew, podIPEnv) + + aiOpsVolumeHostPath = "/mnt/daemonset/aiops/%s/%s" + biLogVolumeHostPath = "/mnt/daemonset/bi/%s/%s" +) + +type aiHostPathConfig struct { + WorkloadName string + Namespace string +} + +func (ai *aiHostPathConfig) configVolume(vb *volumeBuilder) { + hostPathDirectoryOrCreate := v1.HostPathDirectoryOrCreate + vb.addVolume(v1.Volume{ + Name: aiOpsVolume, + VolumeSource: v1.VolumeSource{ + HostPath: &v1.HostPathVolumeSource{ + Path: fmt.Sprintf(aiOpsVolumeHostPath, ai.Namespace, ai.WorkloadName), + Type: &hostPathDirectoryOrCreate, + }, + }, + }) + + vb.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: aiOpsVolume, + MountPath: caasUserVolumeMountPath, + SubPathExpr: caasUserVolumeMountSubPathExpr, + }) + vb.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: aiOpsVolume, + MountPath: dataSystemVolumeMountPath, + SubPathExpr: dataSystemVolumeMountSubPathExpr, + }) + vb.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: aiOpsVolume, + MountPath: customAIOpsMountPath, + SubPathExpr: customAIOpsMountSubPathExpr, + }) +} + +type biHostPathConfig struct { + WorkloadName string + Namespace string +} + +func (bc *biHostPathConfig) configVolume(vb *volumeBuilder) { + hostPathDirectoryOrCreate := v1.HostPathDirectoryOrCreate + vb.addVolume(v1.Volume{ + Name: biLogsVolume, + VolumeSource: v1.VolumeSource{ + HostPath: &v1.HostPathVolumeSource{ + Path: fmt.Sprintf(biLogVolumeHostPath, bc.Namespace, bc.WorkloadName), + Type: &hostPathDirectoryOrCreate, + }, + }, + }) + vb.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: biLogsVolume, + MountPath: biLogVolumeMountPath, + SubPathExpr: biLogVolumeMountSubPathExpr, + }) +} diff --git a/yuanrong/pkg/functionscaler/instancepool/log_test.go b/yuanrong/pkg/functionscaler/instancepool/log_test.go new file mode 100644 index 0000000..7e870cf --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/log_test.go @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" +) + +func TestAiHostPathConfig(t *testing.T) { + ah := aiHostPathConfig{ + WorkloadName: "testWorkload", + Namespace: "testNamespace", + } + vb := &volumeBuilder{ + volumes: make([]corev1.Volume, 0), + mounts: make(map[container][]corev1.VolumeMount), + } + ah.configVolume(vb) + found := false + for _, v := range vb.volumes { + if v.Name == aiOpsVolume && v.VolumeSource.HostPath.Path == fmt.Sprintf(aiOpsVolumeHostPath, ah.Namespace, + ah.WorkloadName) { + found = true + } + } + assert.Equal(t, true, found) +} + +func TestBiHostPathConfig(t *testing.T) { + bh := biHostPathConfig{ + WorkloadName: "testWorkload", + Namespace: "testNamespace", + } + vb := &volumeBuilder{ + volumes: make([]corev1.Volume, 0), + mounts: make(map[container][]corev1.VolumeMount), + } + bh.configVolume(vb) + found := false + for _, v := range vb.volumes { + if v.Name == biLogsVolume && v.VolumeSource.HostPath.Path == fmt.Sprintf(biLogVolumeHostPath, bh.Namespace, + bh.WorkloadName) { + found = true + } + } + assert.Equal(t, true, found) +} diff --git a/yuanrong/pkg/functionscaler/instancepool/min_instance_alarm.go b/yuanrong/pkg/functionscaler/instancepool/min_instance_alarm.go new file mode 100644 index 0000000..a10ded5 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/min_instance_alarm.go @@ -0,0 +1,197 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "fmt" + "os" + "strconv" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" +) + +var ( + checkInterval time.Duration + defaultStartInterval = 15 + defaultMinInsCheckInterval = 15 + insufficientNumMap = make(map[int64]int64) +) + +const uint64Base = 10 + +func (pm *PoolManager) checkMinInsRegularly(stopCh <-chan struct{}) { + startIntervalNum := config.GlobalConfig.AlarmConfig.MinInsStartInterval + if startIntervalNum == 0 { + startIntervalNum = defaultStartInterval + } + startInterval := time.Duration(startIntervalNum) * time.Minute + checkIntervalNum := config.GlobalConfig.AlarmConfig.MinInsCheckInterval + if checkIntervalNum == 0 { + checkIntervalNum = defaultMinInsCheckInterval + } + checkInterval = time.Duration(checkIntervalNum) * time.Minute + + if stopCh == nil { + log.GetLogger().Errorf("stopCh is nil") + return + } + clusterID := config.GlobalConfig.ClusterID + if clusterID == "" { + log.GetLogger().Warnf("failed to get cluster ID") + } + + log.GetLogger().Infof("start to check min instance with after %v", startInterval) + log.GetLogger().Infof("start to check min instance with interval %v", checkInterval) + + time.AfterFunc(startInterval, func() { + ticker := time.NewTicker(checkInterval) + for { + select { + case <-stopCh: + log.GetLogger().Infof("stop check min instance") + ticker.Stop() + return + case <-ticker.C: + pm.judgeAndReport(clusterID) + } + } + }) +} + +func (pm *PoolManager) judgeAndReport(clusterID string) { + insufficientNum := pm.checkMinScale(clusterID) + if insufficientNum > 0 { + errInfo := fmt.Sprintf("Insufficient number of reserved instances: %d, cluster: %s", + insufficientNum, clusterID) + reportInsufficientAlarm(errInfo, insufficientNum) + log.GetLogger().Errorf(errInfo) + return + } else if insufficientNum == 0 { + log.GetLogger().Infof("clear ins Insufficient alarm") + clearInsufficientAlarm(clusterID) + return + } + log.GetLogger().Warnf("insufficientNum is under zero: %d", insufficientNum) +} + +func (pm *PoolManager) checkMinScale(clusterID string) int64 { + logger := log.GetLogger().With(zap.Any("clusterID", clusterID)) + logger.Infof("start to check min scale") + res := int64(0) + pm.RLock() + defer pm.RUnlock() + for funcKey, pool := range pm.instancePool { + funcLogger := logger.With(zap.Any("funcKey", funcKey)) + gi, ok := pool.(*GenericInstancePool) + if !ok { + logger.Warnf("is not a generic instance pool, skip") + continue + } + if gi.FuncSpec.InstanceMetaData.ScalePolicy == types.InstanceScalePolicyStaticFunction { + continue + } + gi.RLock() + for label, _ := range gi.insConfig { + res += pm.checkMinScaleForLabel(gi, label, funcLogger) + } + gi.RUnlock() + } + logger.Infof("finish to check min scale") + return res +} + +func (pm *PoolManager) checkMinScaleForLabel(gi *GenericInstancePool, resKey resspeckey.ResSpecKey, + logger api.FormatLogger) int64 { + res := int64(0) + if gi.reservedInstanceQueue[resKey] == nil || gi.insConfig[resKey] == nil { + logger.Warnf("resource %+v in this faas scheduler does not have reserved instance queues", resKey) + return res + } + if gi.insConfig[resKey].InstanceMetaData.MinInstance == 0 { + logger.Infof("funcKey: %s reserved ins expectNum is 0", gi.FuncSpec.FuncKey) + return res + } + if time.Now().Sub(gi.minScaleUpdatedTime) > checkInterval { + actualNum := gi.reservedInstanceQueue[resKey].GetInstanceNumber(true) + expectNum := int(gi.insConfig[resKey].InstanceMetaData.MinInstance) + logger.Infof("check min instance, currentNum: %d, expectNum: %d", actualNum, expectNum) + if actualNum < expectNum { + res += int64(expectNum) - int64(actualNum) + logger.Errorf("actual reserved instance num %d, configured minScale num %d", + actualNum, expectNum) + gi.minScaleAlarmSign[resKey.InvokeLabel] = true + } else if gi.minScaleAlarmSign[resKey.InvokeLabel] == true { + logger.Warnf("the number of reserved instances of reaches the set value %d", expectNum) + gi.minScaleAlarmSign[resKey.InvokeLabel] = false + } + } + return res +} + +func reportInsufficientAlarm(errMsg string, insufficientNum int64) { + alarmInfo := &alarm.LogAlarmInfo{ + AlarmID: alarm.InsufficientMinInstance00001, + AlarmName: "InsufficientMinInstance", + AlarmLevel: alarm.Level2, + } + alarmDetail := &alarm.Detail{ + SourceTag: os.Getenv(constant.PodNameEnvKey) + "|" + os.Getenv(constant.PodIPEnvKey) + + "|" + os.Getenv(constant.ClusterName) + "|InsufficientMinInstance" + + strconv.FormatInt(insufficientNum, uint64Base), + OpType: alarm.GenerateAlarmLog, + Details: errMsg, + StartTimestamp: int(time.Now().Unix()), + EndTimestamp: 0, + } + + if _, ok := insufficientNumMap[insufficientNum]; !ok { + insufficientNumMap[insufficientNum] = 1 + } + alarm.ReportOrClearAlarm(alarmInfo, alarmDetail) +} + +func clearInsufficientAlarm(clusterID string) { + alarmInfo := &alarm.LogAlarmInfo{ + AlarmID: alarm.InsufficientMinInstance00001, + AlarmName: "InsufficientMinInstance", + AlarmLevel: alarm.Level2, + } + alarmDetail := &alarm.Detail{ + OpType: alarm.ClearAlarmLog, + Details: fmt.Sprintf("The number of reserved instances is normal, cluster: %s", clusterID), + StartTimestamp: 0, + EndTimestamp: int(time.Now().Unix()), + } + + for k := range insufficientNumMap { + alarmDetail.SourceTag = os.Getenv(constant.PodNameEnvKey) + "|" + os.Getenv(constant.PodIPEnvKey) + + "|" + os.Getenv(constant.ClusterName) + "|InsufficientMinInstance" + strconv.FormatInt(k, uint64Base) + alarm.ReportOrClearAlarm(alarmInfo, alarmDetail) + delete(insufficientNumMap, k) + } +} diff --git a/yuanrong/pkg/functionscaler/instancepool/miscellaneous.go b/yuanrong/pkg/functionscaler/instancepool/miscellaneous.go new file mode 100644 index 0000000..1042fd0 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/miscellaneous.go @@ -0,0 +1,310 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "os" + + "k8s.io/api/core/v1" + + "yuanrong/pkg/common/faas_common/urnutils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/dynamicconfigmanager" +) + +const ( + defaultCaasSysConfigSuffix = "-sys-config" + defaultCaasSysConfigVolumeName = "caas-sys-config" + defaultCaasSysConfigVolumeMountPath = "/opt/config/caas-sys-dynamic-config" + defaultWorkerStsSuffix = "-worker-sts" + defaultWorkerStsVolumeName = "sts-config" + defaultAgentStsSuffix = "-agent-sts" + defaultAgentStsVolumeName = "agent-sts-config" + defaultWorkerStsVolumeMountPath = "/opt/certs/HMSClientCloudAccelerateService/HMSCaaSYuanRongWorker/" + defaultAgentStsVolumeMountPath = "/opt/certs/WiseCloudElasticResourceService/ERSDataSystem/" +) + +type dataSystemSocket struct{} + +func (u *dataSystemSocket) configVolume(b *volumeBuilder) { + b.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: "datasystem-socket", + MountPath: "/home/uds", + }) +} + +type cgroupMemory struct{} + +func (f *cgroupMemory) configVolume(b *volumeBuilder) { + b.addVolume(v1.Volume{ + Name: "cgroup-memory", + VolumeSource: v1.VolumeSource{ + HostPath: &v1.HostPathVolumeSource{ + Path: "/sys/fs/cgroup/memory/kubepods/burstable", + }, + }, + }) + + mount := v1.VolumeMount{ + Name: "cgroup-memory", + MountPath: "/runtime/memory", + SubPathExpr: "pod$(POD_ID)", + } + b.addVolumeMount(containerRuntimeManager, mount) +} + +type dockerSocket struct{} + +func (v *dockerSocket) configVolume(b *volumeBuilder) { + b.addVolume(v1.Volume{ + Name: "docker-socket", + VolumeSource: v1.VolumeSource{ + HostPath: &v1.HostPathVolumeSource{ + Path: "/var/run/docker.sock", + }, + }, + }) + b.addVolumeMount(containerRuntimeManager, v1.VolumeMount{ + Name: "docker-socket", + MountPath: "/var/run/docker.sock", + }) +} + +type dockerRootDir struct{} + +func (v *dockerRootDir) configVolume(b *volumeBuilder) { + dockerRootPath := os.Getenv(config.DockerRootPathEnv) + b.addVolume(v1.Volume{ + Name: "docker-rootdir", + VolumeSource: v1.VolumeSource{ + HostPath: &v1.HostPathVolumeSource{ + Path: dockerRootPath, + }, + }, + }) + + b.addVolumeMount(containerRuntimeManager, v1.VolumeMount{ + Name: "docker-rootdir", + MountPath: "/var/lib/docker", + }) +} + +type dataVolume struct{} + +func (c *dataVolume) configVolume(b *volumeBuilder) { + b.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: DefaultDataVolumeName, + MountPath: "/opt/data", + }) +} + +type faasAgentSts struct { + crName string +} + +func (f *faasAgentSts) configVolume(b *volumeBuilder) { + securityFileMode := int32(urnutils.OwnerReadWrite) + b.addVolume(v1.Volume{ + Name: defaultAgentStsVolumeName, + VolumeSource: v1.VolumeSource{ + Secret: &v1.SecretVolumeSource{ + DefaultMode: &securityFileMode, + SecretName: f.crName + defaultAgentStsSuffix, + Items: []v1.KeyToPath{ + { + Key: "a", + Path: "a", + }, + { + Key: "b", + Path: "b", + }, + { + Key: "c", + Path: "c", + }, + { + Key: "d", + Path: "d", + }, + { + Key: "ERSDataSystem.ini", + Path: "ERSDataSystem.ini", + }, + { + Key: "ERSDataSystem.sts.p12", + Path: "ERSDataSystem.sts.p12", + }, + }, + }, + }, + }) + b.addVolumeMount(containerFunctionAgent, v1.VolumeMount{ + Name: defaultAgentStsVolumeName, + MountPath: defaultAgentStsVolumeMountPath + "ERSDataSystem/apple/a", + SubPath: "a", + }) + b.addVolumeMount(containerFunctionAgent, v1.VolumeMount{ + Name: defaultAgentStsVolumeName, + MountPath: defaultAgentStsVolumeMountPath + "ERSDataSystem/boy/b", + SubPath: "b", + }) + b.addVolumeMount(containerFunctionAgent, v1.VolumeMount{ + Name: defaultAgentStsVolumeName, + MountPath: defaultAgentStsVolumeMountPath + "ERSDataSystem/cat/c", + SubPath: "c", + }) + b.addVolumeMount(containerFunctionAgent, v1.VolumeMount{ + Name: defaultAgentStsVolumeName, + MountPath: defaultAgentStsVolumeMountPath + "ERSDataSystem/dog/d", + SubPath: "d", + }) + b.addVolumeMount(containerFunctionAgent, v1.VolumeMount{ + Name: defaultAgentStsVolumeName, + MountPath: defaultAgentStsVolumeMountPath + "ERSDataSystem.ini", + SubPath: "ERSDataSystem.ini", + }) + b.addVolumeMount(containerFunctionAgent, v1.VolumeMount{ + Name: defaultAgentStsVolumeName, + MountPath: defaultAgentStsVolumeMountPath + "ERSDataSystem.sts.p12", + SubPath: "ERSDataSystem.sts.p12", + }) +} + +type dynamicConfig struct { + crName string + enable bool +} + +func (d *dynamicConfig) configVolume(vb *volumeBuilder) { + if !d.enable { + return + } + securityFileMode := int32(urnutils.OwnerReadWrite) + vb.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: dynamicconfigmanager.DynamicConfigMapName, + MountPath: dynamicconfigmanager.DefaultDynamicConfigPath, + }) + vb.addVolume(v1.Volume{ + Name: dynamicconfigmanager.DynamicConfigMapName, + VolumeSource: v1.VolumeSource{ + ConfigMap: &v1.ConfigMapVolumeSource{ + LocalObjectReference: v1.LocalObjectReference{ + Name: d.crName + dynamicconfigmanager.DynamicConfigSuffix, + }, + DefaultMode: &securityFileMode, + }, + }, + }) +} + +type caasSysConfig struct { + crName string +} + +func (c *caasSysConfig) configVolume(vb *volumeBuilder) { + securityFileMode := int32(urnutils.OwnerReadWrite) + vb.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: defaultCaasSysConfigVolumeName, + MountPath: defaultCaasSysConfigVolumeMountPath, + }) + vb.addVolume(v1.Volume{ + Name: defaultCaasSysConfigVolumeName, + VolumeSource: v1.VolumeSource{ + ConfigMap: &v1.ConfigMapVolumeSource{ + LocalObjectReference: v1.LocalObjectReference{ + Name: c.crName + defaultCaasSysConfigSuffix, + }, + DefaultMode: &securityFileMode, + }, + }, + }) +} + +type functionDefaultConfig struct { + configName string + mount v1.VolumeMount +} + +func (f *functionDefaultConfig) configVolume(vb *volumeBuilder) { + securityFileMode := int32(urnutils.DefaultMode) + vb.addVolumeMount(containerDelegate, f.mount) + vb.addVolume(v1.Volume{ + Name: f.mount.Name, + VolumeSource: v1.VolumeSource{ + ConfigMap: &v1.ConfigMapVolumeSource{ + LocalObjectReference: v1.LocalObjectReference{ + Name: f.configName, + }, + DefaultMode: &securityFileMode, + }, + }, + }) +} + +type npuDriver struct { +} + +func (nh *npuDriver) configVolume(vb *volumeBuilder) { + hostPathDirectoryOrCreate := v1.HostPathDirectoryOrCreate + vb.addVolume(v1.Volume{ + Name: "npu-driver", + VolumeSource: v1.VolumeSource{ + HostPath: &v1.HostPathVolumeSource{ + Path: "/var/queue_schedule", + Type: &hostPathDirectoryOrCreate, + }, + }, + }) + vb.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: "npu-driver", + MountPath: "/var/queue_schedule", + }) +} + +type ascendConfig struct{} + +func (ad *ascendConfig) configVolume(vb *volumeBuilder) { + var ( + ascendDriverPathVolumeName = "ascend-driver-path" + ascendConfigVolumeName = "ascend-config" + ) + hostPathFile := v1.HostPathFile + vb.addVolume(v1.Volume{ + Name: "ascend-npu-smi", + VolumeSource: v1.VolumeSource{ + HostPath: &v1.HostPathVolumeSource{ + Path: "/usr/local/sbin/npu-smi", + Type: &hostPathFile, + }, + }, + }) + vb.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: ascendDriverPathVolumeName, + MountPath: "/usr/local/Ascend/driver", + ReadOnly: true, + }) + vb.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: ascendConfigVolumeName, + MountPath: "/opt/config/ascend_config", + }) + vb.addVolumeMount(containerDelegate, v1.VolumeMount{ + Name: "ascend-npu-smi", + MountPath: "/usr/local/sbin/npu-smi", + }) +} diff --git a/yuanrong/pkg/functionscaler/instancepool/miscellaneous_test.go b/yuanrong/pkg/functionscaler/instancepool/miscellaneous_test.go new file mode 100644 index 0000000..dc8f6d6 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/miscellaneous_test.go @@ -0,0 +1,175 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/dynamicconfigmanager" +) + +func TestUnixSocketMethod(t *testing.T) { + us := dataSystemSocket{} + vb := volumeBuilder{ + volumes: make([]corev1.Volume, 0), + mounts: make(map[container][]corev1.VolumeMount), + } + us.configVolume(&vb) + found := false + for _, vm := range vb.mounts[containerDelegate] { + if vm.Name == "datasystem-socket" && vm.MountPath == "/home/uds" { + found = true + } + } + assert.Equal(t, true, found) +} + +func TestCgroupMemoryMethod(t *testing.T) { + cm := cgroupMemory{} + vb := volumeBuilder{ + volumes: make([]corev1.Volume, 0), + mounts: make(map[container][]corev1.VolumeMount), + } + cm.configVolume(&vb) + found := false + for _, v := range vb.volumes { + if v.Name == "cgroup-memory" && v.VolumeSource.HostPath.Path == "/sys/fs/cgroup/memory/kubepods/burstable" { + found = true + } + } + assert.Equal(t, true, found) + found = false + for _, vm := range vb.mounts[containerRuntimeManager] { + if vm.Name == "cgroup-memory" && vm.MountPath == "/runtime/memory" { + found = true + } + } + assert.Equal(t, true, found) +} + +func TestDockerSocketMethod(t *testing.T) { + ds := dockerSocket{} + vb := volumeBuilder{ + volumes: make([]corev1.Volume, 0), + mounts: make(map[container][]corev1.VolumeMount), + } + ds.configVolume(&vb) + found := false + for _, v := range vb.volumes { + if v.Name == "docker-socket" && v.VolumeSource.HostPath.Path == "/var/run/docker.sock" { + found = true + } + } + assert.Equal(t, true, found) + found = false + for _, vm := range vb.mounts[containerRuntimeManager] { + if vm.Name == "docker-socket" && vm.MountPath == "/var/run/docker.sock" { + found = true + } + } + assert.Equal(t, true, found) +} + +func TestDockerRootDirMethod(t *testing.T) { + dr := dockerRootDir{} + vb := volumeBuilder{ + volumes: make([]corev1.Volume, 0), + mounts: make(map[container][]corev1.VolumeMount), + } + dockerRootPath := "/test/docker/root/path" + os.Setenv(config.DockerRootPathEnv, dockerRootPath) + dr.configVolume(&vb) + found := false + for _, v := range vb.volumes { + if v.Name == "docker-rootdir" && v.VolumeSource.HostPath.Path == dockerRootPath { + found = true + } + } + assert.Equal(t, true, found) + found = false + for _, vm := range vb.mounts[containerRuntimeManager] { + if vm.Name == "docker-rootdir" && vm.MountPath == "/var/lib/docker" { + found = true + } + } + assert.Equal(t, true, found) +} + +func TestDynamicConfigMethod(t *testing.T) { + dc := dynamicConfig{ + crName: "testDynamicConfig", + enable: true, + } + vb := volumeBuilder{ + volumes: make([]corev1.Volume, 0), + mounts: make(map[container][]corev1.VolumeMount), + } + dc.configVolume(&vb) + found := false + for _, v := range vb.volumes { + if v.Name == dynamicconfigmanager.DynamicConfigMapName && + v.VolumeSource.ConfigMap.LocalObjectReference.Name == dc.crName+dynamicconfigmanager.DynamicConfigSuffix { + found = true + } + } + assert.Equal(t, true, found) + found = false + for _, vm := range vb.mounts[containerDelegate] { + if vm.Name == dynamicconfigmanager.DynamicConfigMapName && + vm.MountPath == dynamicconfigmanager.DefaultDynamicConfigPath { + found = true + } + } + assert.Equal(t, true, found) +} + +func TestAscendConfig(t *testing.T) { + sc := ascendConfig{} + vb := volumeBuilder{ + volumes: make([]corev1.Volume, 0), + mounts: make(map[container][]corev1.VolumeMount), + } + sc.configVolume(&vb) + found := false + for _, v := range vb.volumes { + if v.Name == "ascend-npu-smi" && v.VolumeSource.HostPath.Path == "/usr/local/sbin/npu-smi" { + found = true + } + } + assert.Equal(t, true, found) +} + +func TestConfigVolume2(t *testing.T) { + vb := volumeBuilder{ + volumes: make([]corev1.Volume, 0), + mounts: make(map[container][]corev1.VolumeMount), + } + nd := npuDriver{} + nd.configVolume(&vb) + cs := caasSysConfig{} + cs.configVolume(&vb) + dv := dataVolume{} + dv.configVolume(&vb) + + assert.Equal(t, 2, len(vb.volumes)) +} diff --git a/yuanrong/pkg/functionscaler/instancepool/operatekube.go b/yuanrong/pkg/functionscaler/instancepool/operatekube.go new file mode 100644 index 0000000..e012ec3 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/operatekube.go @@ -0,0 +1,98 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "fmt" + + "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + + "yuanrong/pkg/functionscaler/types" +) + +const ( + volumeStsConfig = "runtime-sts-config" + volumeCertVolume = "runtime-certs-volume" + baseCertFilePath = "/opt/certs" + + // CertMode - + CertMode = 384 // 600:rw- --- --- +) + +type stsSecret struct { + param *types.FunctionSpecification + req types.PodRequest +} + +func (v *stsSecret) configVolume(b *volumeBuilder) { + serviceName := v.param.StsMetaData.ServiceName + microServiceName := v.param.StsMetaData.MicroService + certVolumeQuality := resource.MustParse("10Mi") + securityFileMode := int32(CertMode) + b.addVolume(v1.Volume{Name: volumeStsConfig, VolumeSource: v1.VolumeSource{ + Secret: &v1.SecretVolumeSource{ + SecretName: v.req.FunSvcID, + DefaultMode: &securityFileMode, + Items: []v1.KeyToPath{{Key: "a", Path: "a"}, {Key: "b", Path: "b"}, + {Key: "c", Path: "c"}, {Key: "d", Path: "d"}, + {Key: microServiceName + ".ini", Path: microServiceName + ".ini"}, + {Key: microServiceName + ".sts.p12", Path: microServiceName + ".sts.p12"}}}}}) + + b.addVolume(v1.Volume{Name: volumeCertVolume, VolumeSource: v1.VolumeSource{ + EmptyDir: &v1.EmptyDirVolumeSource{SizeLimit: &certVolumeQuality, Medium: v1.StorageMediumMemory}}}) + + // a/b/c/d - four sub-files of the rootKey + b.addVolumeMount(containerRuntimeManager, v1.VolumeMount{Name: volumeCertVolume, + MountPath: fmt.Sprintf("%s/%s/%s/", baseCertFilePath, serviceName, microServiceName)}) + b.addVolumeMount(containerRuntimeManager, v1.VolumeMount{Name: volumeStsConfig, + MountPath: fmt.Sprintf("%s/%s/%s/%s/apple/a", baseCertFilePath, serviceName, microServiceName, + microServiceName), + SubPath: "a", + }) + b.addVolumeMount(containerRuntimeManager, v1.VolumeMount{Name: volumeStsConfig, + MountPath: fmt.Sprintf("%s/%s/%s/%s/boy/b", baseCertFilePath, serviceName, microServiceName, microServiceName), + SubPath: "b", + }) + b.addVolumeMount(containerRuntimeManager, v1.VolumeMount{Name: volumeStsConfig, + MountPath: fmt.Sprintf("%s/%s/%s/%s/cat/c", baseCertFilePath, serviceName, microServiceName, microServiceName), + SubPath: "c", + }) + b.addVolumeMount(containerRuntimeManager, v1.VolumeMount{Name: volumeStsConfig, + MountPath: fmt.Sprintf("%s/%s/%s/%s/dog/d", baseCertFilePath, serviceName, microServiceName, microServiceName), + SubPath: "d", + }) + b.addVolumeMount(containerRuntimeManager, v1.VolumeMount{Name: volumeStsConfig, + MountPath: fmt.Sprintf("%s/%s/%s/%s.ini", baseCertFilePath, serviceName, microServiceName, microServiceName), + SubPath: microServiceName + ".ini", + }) + b.addVolumeMount(containerRuntimeManager, v1.VolumeMount{Name: volumeStsConfig, + MountPath: fmt.Sprintf("%s/%s/%s/%s.sts.p12", baseCertFilePath, serviceName, microServiceName, + microServiceName), + SubPath: microServiceName + ".sts.p12", + }) +} + +func configEnv(b *envBuilder, env map[string]string) { + for k, v := range env { + b.addEnvVar(containerDelegate, v1.EnvVar{ + Name: k, + Value: v, + }) + } +} diff --git a/yuanrong/pkg/functionscaler/instancepool/operatekube_test.go b/yuanrong/pkg/functionscaler/instancepool/operatekube_test.go new file mode 100644 index 0000000..6388c24 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/operatekube_test.go @@ -0,0 +1,75 @@ +package instancepool + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/types" +) + +func TestConfigVolume(t *testing.T) { + serviceName := "testService" + microServiceName := "testMicroService" + ss := stsSecret{ + param: &types.FunctionSpecification{ + StsMetaData: commonTypes.StsMetaData{ + ServiceName: serviceName, + MicroService: microServiceName, + }, + }, + req: types.PodRequest{ + FunSvcID: "testFunc", + }, + } + vb := volumeBuilder{ + volumes: make([]corev1.Volume, 0), + mounts: make(map[container][]corev1.VolumeMount), + } + ss.configVolume(&vb) + foundBits := 0 + for _, v := range vb.volumes { + if v.Name == volumeStsConfig && v.VolumeSource.Secret.SecretName == "testFunc" { + foundBits |= 0b0001 + } + if v.Name == volumeCertVolume { + foundBits |= 0b0010 + } + } + assert.Equal(t, 0b0011, foundBits) + foundBits = 0 + for _, vm := range vb.mounts[containerRuntimeManager] { + if vm.Name == volumeCertVolume && vm.MountPath == fmt.Sprintf("%s/%s/%s/", baseCertFilePath, serviceName, + microServiceName) { + foundBits |= 0b0000001 + } + if vm.Name == volumeStsConfig && vm.MountPath == fmt.Sprintf("%s/%s/%s/%s/apple/a", baseCertFilePath, + serviceName, microServiceName, microServiceName) { + foundBits |= 0b0000010 + } + if vm.Name == volumeStsConfig && vm.MountPath == fmt.Sprintf("%s/%s/%s/%s/boy/b", baseCertFilePath, + serviceName, microServiceName, microServiceName) { + foundBits |= 0b0000100 + } + if vm.Name == volumeStsConfig && vm.MountPath == fmt.Sprintf("%s/%s/%s/%s/cat/c", baseCertFilePath, + serviceName, microServiceName, microServiceName) { + foundBits |= 0b0001000 + } + if vm.Name == volumeStsConfig && vm.MountPath == fmt.Sprintf("%s/%s/%s/%s/dog/d", baseCertFilePath, + serviceName, microServiceName, microServiceName) { + foundBits |= 0b0010000 + } + if vm.Name == volumeStsConfig && vm.MountPath == fmt.Sprintf("%s/%s/%s/%s.ini", baseCertFilePath, + serviceName, microServiceName, microServiceName) { + foundBits |= 0b0100000 + } + if vm.Name == volumeStsConfig && vm.MountPath == fmt.Sprintf("%s/%s/%s/%s.sts.p12", baseCertFilePath, + serviceName, microServiceName, microServiceName) { + foundBits |= 0b1000000 + } + } + assert.Equal(t, 0b1111111, foundBits) +} diff --git a/yuanrong/pkg/functionscaler/instancepool/poolmanager.go b/yuanrong/pkg/functionscaler/instancepool/poolmanager.go new file mode 100644 index 0000000..a10a954 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/poolmanager.go @@ -0,0 +1,573 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/urnutils" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/dynamicconfigmanager" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/state" + "yuanrong/pkg/functionscaler/stateinstance" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + // SLA time should not be shorter than cold start time + minSLATime = time.Duration(500) * time.Millisecond +) + +type faasManagerInfo struct { + funcKey string + instanceID string +} + +// PoolManager manages instance pools of different functions +type PoolManager struct { + faasManagerInfo faasManagerInfo + instancePool map[string]InstancePool + // instanceRecord is used to find instancePool for a specific instance watched from etcd, because router etcd will + instanceConfigRecord map[string]map[string]*instanceconfig.Configuration + stateLeaseManager map[string]*stateinstance.Leaser // key instanceID + leaseInterval time.Duration + stopCh <-chan struct{} + sync.RWMutex +} + +// GetLeaseInterval it's for state instance +func (pm *PoolManager) GetLeaseInterval() time.Duration { + return pm.leaseInterval +} + +func getScaleDownWindow() time.Duration { + scaleUpWindow := time.Duration(config.GlobalConfig.AutoScaleConfig.SLAQuota) * time.Millisecond + if scaleUpWindow < minSLATime { + scaleUpWindow = minSLATime + } + scaleDownTime := time.Duration(config.GlobalConfig.AutoScaleConfig.ScaleDownTime) * time.Millisecond + if scaleDownTime < scaleUpWindow { + scaleDownTime = scaleUpWindow + } + return scaleDownTime +} + +// NewPoolManager creates a PoolManager +func NewPoolManager(stopCh <-chan struct{}) *PoolManager { + leaseInterval := time.Duration(config.GlobalConfig.LeaseSpan) * time.Millisecond + if leaseInterval < types.MinLeaseInterval { + leaseInterval = types.MinLeaseInterval + } + pm := &PoolManager{ + faasManagerInfo: faasManagerInfo{}, + instancePool: make(map[string]InstancePool, utils.DefaultMapSize), + instanceConfigRecord: make(map[string]map[string]*instanceconfig.Configuration, utils.DefaultMapSize), + stateLeaseManager: make(map[string]*stateinstance.Leaser), + leaseInterval: leaseInterval, + stopCh: stopCh, + } + return pm +} + +// RecoverInstancePool recover instancePool data +// if fail to recover, see faaSScheduler as restarting +// precondition: make sure that FunctionRegistry.funcSpecs is up-to-date +func (pm *PoolManager) RecoverInstancePool() { + s := state.GetState() + log.GetLogger().Infof("ready to recovery instance pool") + var wg sync.WaitGroup + registry.GlobalRegistry.InstanceRegistry.WaitForETCDList() + var deleteFunctions []*types.FunctionSpecification + for funcKey, val := range s.InstancePool { + var InstancePoolState *types.InstancePoolState + commonUtils.DeepCopyObj(val, &InstancePoolState) + funcSpec := registry.GlobalRegistry.GetFuncSpec(funcKey) + // guarantee that the recovered data is the same as etcd data + var deleteFuncFlag bool + if funcSpec == nil { + // it means function has been deleted while recovering + // faasscheduler need to receive etcd delete event to delete function + deleteFuncFlag = true + funcSpec = &types.FunctionSpecification{FuncKey: funcKey} + deleteFunctions = append(deleteFunctions, funcSpec) + } + pm.RLock() + _, exist := pm.instancePool[funcSpec.FuncKey] + pm.RUnlock() + if !exist { + if _, err := pm.processInstancePoolCreate(funcSpec); err != nil { + continue + } + } + log.GetLogger().Infof("now recover function :%s", funcSpec.FuncKey) + wg.Add(1) + go func() { + log.GetLogger().Infof("recover func: %s, isStateful: %v, deleteFuncFlag: %v", + funcSpec.FuncKey, funcSpec.FuncMetaData.IsStatefulFunction, deleteFuncFlag) + pm.instancePool[funcSpec.FuncKey].RecoverInstance(funcSpec, InstancePoolState, deleteFuncFlag, &wg) + pm.recoverStateLeaser(InstancePoolState.StateInstance, funcSpec) + }() + } + wg.Wait() + for _, function := range deleteFunctions { + pm.HandleFunctionEvent(registry.SubEventTypeDelete, function) + } +} + +func (pm *PoolManager) recoverStateLeaser(stateInstance map[string]*types.Instance, + funcSpec *types.FunctionSpecification) { + for stateID, instance := range stateInstance { + pool := pm.instancePool[funcSpec.FuncKey] + if pool == nil { + log.GetLogger().Errorf("func %s, stateid %s: pool is nil! can't create leaser manager!", + funcSpec.FuncKey, stateID) + } else { + if instance.InstanceStatus.Code == int32(-1) { + log.GetLogger().Warnf("instance of stateID %s is existed!", stateID) + continue + } + log.GetLogger().Infof("func %s, stateId %s: create leaser manager for recovery, downwin %v", + funcSpec.FuncKey, stateID, getScaleDownWindow()) + leaser := stateinstance.NewLeaser(funcSpec.InstanceMetaData.ConcurrentNum, + pool.DeleteStateInstance, stateID, instance.InstanceID, getScaleDownWindow()) + pm.Lock() + leaser.Recover() + pm.stateLeaseManager[instance.InstanceID] = leaser + pm.Unlock() + } + } +} + +// GetAndDeleteState delete state and return whether the state exists +func (pm *PoolManager) GetAndDeleteState(stateID string, funcKey string, funcSpec *types.FunctionSpecification, + logger api.FormatLogger) bool { + pm.Lock() + pool, exist := pm.instancePool[funcKey] + if !exist { + var err error + pool, err = NewGenericInstancePool(funcSpec, pm.faasManagerInfo) + if err != nil { + pm.Unlock() + return false + } + pm.instancePool[funcKey] = pool + } + pm.Unlock() + + logger.Infof("getAndDeleteStateInstance, stateKey %s", stateID) + return pool.GetAndDeleteState(stateID) +} + +// CreateInstance will create an instance of a specific function +func (pm *PoolManager) CreateInstance(insCrtReq *types.InstanceCreateRequest) (*types.Instance, snerror.SNError) { + pm.RLock() + pool, exist := pm.instancePool[insCrtReq.FuncSpec.FuncKey] + pm.RUnlock() + if !exist { + return nil, snerror.New(statuscode.FuncMetaNotFoundErrCode, "pool not exist") + } + return pool.CreateInstance(insCrtReq) +} + +// DeleteInstance will delete an instance of a specific function +func (pm *PoolManager) DeleteInstance(instance *types.Instance) snerror.SNError { + pm.RLock() + pool, exist := pm.instancePool[instance.FuncKey] + pm.RUnlock() + if !exist { + return snerror.New(statuscode.FuncMetaNotFoundErrCode, "pool not exist") + } + return pool.DeleteInstance(instance) +} + +// AcquireInstanceThread will acquire a instance thread of a specific function +func (pm *PoolManager) AcquireInstanceThread(insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, + snerror.SNError) { + pm.RLock() + pool, exist := pm.instancePool[insAcqReq.FuncSpec.FuncKey] + pm.RUnlock() + if !exist { + return nil, snerror.New(statuscode.FuncMetaNotFoundErrCode, "pool not exist") + } + insAlloc, acquireErr := pool.AcquireInstance(insAcqReq) + if acquireErr != nil { + return nil, acquireErr + } + return insAlloc, nil +} + +// ReleaseStateThread - +func (pm *PoolManager) ReleaseStateThread(thread *types.InstanceAllocation) error { + leaser := pm.stateLeaseManager[thread.Instance.InstanceID] + if leaser == nil { + return errors.New("leaser is nil") + } + leaseID, err := getLeaseIdFromAllocationId(thread.AllocationID) + if err != nil { + return err + } + leaser.ReleaseLease(leaseID) + return nil +} + +// RetainStateThread - +func (pm *PoolManager) RetainStateThread(thread *types.InstanceAllocation) error { + leaser := pm.stateLeaseManager[thread.Instance.InstanceID] + if leaser == nil { + return errors.New("leaser is nil") + } + leaseID, err := getLeaseIdFromAllocationId(thread.AllocationID) + if err != nil { + return err + } + return leaser.RetainLease(leaseID, pm.leaseInterval) +} + +func getLeaseIdFromAllocationId(allocationId string) (int, error) { + parts := strings.Split(allocationId, "-") + if len(parts) < 1 { // %s-stateThread%d, eg:0600a7ba-cfc0-4a00-8000-0000000004a1-stateThread1 + return 0, fmt.Errorf("thread.AllocationID: %s invalid", allocationId) + } + numStr := strings.Replace(parts[len(parts)-1], "stateThread", "", 1) + leaseID, err := strconv.Atoi(numStr) + if err != nil { + return 0, fmt.Errorf("thread.AllocationID: %s invalid, err %v", allocationId, err) + } + return leaseID, nil +} + +// ReleaseInstanceThread will release a instance thread of a specific function +func (pm *PoolManager) ReleaseInstanceThread(insAlloc *types.InstanceAllocation) { + instance := insAlloc.Instance + pm.RLock() + pool, exist := pm.instancePool[instance.FuncKey] + pm.RUnlock() + if !exist { + log.GetLogger().Errorf("instance pool for function %s doesn't exist", instance.FuncKey) + return + } + pool.ReleaseInstance(insAlloc) +} + +// ReleaseAbnormalInstance will release an abnormal instance of a specific function +func (pm *PoolManager) ReleaseAbnormalInstance(instance *types.Instance) { + pm.RLock() + pool, exist := pm.instancePool[instance.FuncKey] + pm.RUnlock() + if !exist { + log.GetLogger().Warnf("instance pool for function %s doesn't exist", instance.FuncKey) + return + } + if config.GlobalConfig.Scenario != types.ScenarioWiseCloud { + pool.HandleInstanceEvent(registry.SubEventTypeRemove, instance) + } +} + +// HandleFunctionEvent handles function event +func (pm *PoolManager) HandleFunctionEvent(eventType registry.EventType, funcSpec *types.FunctionSpecification) { + log.GetLogger().Infof("handling function event type %s for function %s", eventType, funcSpec.FuncKey) + if eventType == registry.SubEventTypeUpdate { + pm.RLock() + pool, poolExist := pm.instancePool[funcSpec.FuncKey] + pm.RUnlock() + var err error + if !poolExist { + if pool, err = pm.processInstancePoolCreate(funcSpec); err != nil { + return + } + } + go handleK8sResourceUpdate(funcSpec) + if poolExist { + pool.HandleFunctionEvent(eventType, funcSpec) + } + // 注意:这里的读锁需要把HandleInstanceConfigEvent包含在内 + pm.RLock() + insConfigs, insConfigExist := pm.instanceConfigRecord[funcSpec.FuncKey] + if insConfigExist { + for _, insConfig := range insConfigs { + pool.HandleInstanceConfigEvent(eventType, insConfig) + } + } + pm.RUnlock() + } + if eventType == registry.SubEventTypeDelete { + pm.RLock() + pool, poolExist := pm.instancePool[funcSpec.FuncKey] + pm.RUnlock() + if poolExist { + pm.processInstancePoolDelete(funcSpec) + pool.HandleFunctionEvent(eventType, funcSpec) + } + go handleK8sResourceDelete(funcSpec) + metrics.ClearMetricsForFunction(funcSpec) + } + + if eventType == registry.SubEventTypeSynced { + log.GetLogger().Infof("send function synced event") + registry.GlobalRegistry.FunctionRegistry.FinishEtcdList() + } +} + +func (pm *PoolManager) processInstancePoolCreate(funcSpec *types.FunctionSpecification) (InstancePool, error) { + if funcSpec.FuncMetaData.VPCTriggerImage != "" && + config.GlobalConfig.InstanceOperationBackend == constant.BackendTypeKernel { + pm.RLock() + faasMgrInfo := pm.faasManagerInfo + pm.RUnlock() + go handlePullTriggerCreate(faasMgrInfo, funcSpec) + } + var pool InstancePool + var err error + pool, err = NewGenericInstancePool(funcSpec, pm.faasManagerInfo) + if err != nil { + log.GetLogger().Errorf("failed to create instance pool of function %s error %s", funcSpec.FuncKey, err.Error()) + return nil, err + } + pm.Lock() + pm.instancePool[funcSpec.FuncKey] = pool + pm.Unlock() + return pool, nil +} + +func (pm *PoolManager) processInstancePoolDelete(funcSpec *types.FunctionSpecification) { + if funcSpec.FuncMetaData.VPCTriggerImage != "" && + config.GlobalConfig.InstanceOperationBackend == constant.BackendTypeKernel { + pm.RLock() + faasMgrInfo := pm.faasManagerInfo + pm.RUnlock() + go handlePullTriggerDelete(faasMgrInfo, funcSpec) + } + pm.Lock() + delete(pm.instancePool, funcSpec.FuncKey) + delete(pm.instanceConfigRecord, funcSpec.FuncKey) + pm.Unlock() + state.Update(&types.InstancePoolStateInput{ + FuncKey: funcSpec.FuncKey, + }, types.StateDelete) +} + +// HandleInstanceEvent handles instance event +func (pm *PoolManager) HandleInstanceEvent(eventType registry.EventType, insSpec *commonTypes.InstanceSpecification) { + items := strings.Split(insSpec.Function, utils.FuncKeyDelimiter) + if len(items) != utils.ValidFuncKeyLen { + return + } + if utils.IsFaaSManager(items[1]) { + pm.faasManagerInfo.funcKey = insSpec.Function + pm.faasManagerInfo.instanceID = insSpec.InstanceID + log.GetLogger().Infof("set faas manager info to %v", pm.faasManagerInfo) + for _, pool := range pm.instancePool { + pool.HandleFaaSManagerUpdate(pm.faasManagerInfo) + } + return + } + logger := log.GetLogger().With(zap.Any("InstanceID", insSpec.InstanceID)) + logger.Infof("handling instance event, type %v, status %+v", eventType, insSpec.InstanceStatus) + pm.Lock() + pool, exist := pm.instancePool[insSpec.CreateOptions[types.FunctionKeyNote]] + if !exist { + logger.Warnf("instance pool for instance doesn't exist") + if eventType != registry.SubEventTypeDelete { + go DeleteUnexpectInstance(insSpec.ParentID, insSpec.InstanceID, + insSpec.CreateOptions[types.FunctionKeyNote], logger) + } + pm.Unlock() + return + } + pm.Unlock() + instance := utils.BuildInstanceFromInsSpec(insSpec, pool.GetFuncSpec()) + switch eventType { + case registry.SubEventTypeUpdate: + metrics.EnsureLeaseRequestTotal(instance.MetricLabelValues) + metrics.EnsureConcurrencyGaugeWithLabel(pool.GetFuncSpec().FuncKey, + instance.ResKey.InvokeLabel, instance.MetricLabelValues) + case registry.SubEventTypeDelete: + metrics.ClearLeaseRequestTotal(instance.MetricLabelValues) + metrics.ClearConcurrencyGaugeWithLabel(instance.MetricLabelValues) + default: + logger.Debugf("no need update prometheus metric") + } + pm.updateStateLeaseMgrForHandleInstanceEvent(eventType, instance) + pool.HandleInstanceEvent(eventType, instance) +} + +func (pm *PoolManager) updateStateLeaseMgrForHandleInstanceEvent(eventType registry.EventType, + instance *types.Instance) { + if eventType == registry.SubEventTypeDelete || + (eventType == registry.SubEventTypeUpdate && instance.InstanceStatus.Code == 6) { // 6 FATAL + if stateLeaseManager, exist := pm.stateLeaseManager[instance.InstanceID]; exist { + log.GetLogger().Infof("terminate state lease manager for instance %s, funckey %s", + instance.InstanceID, instance.FuncKey) + stateLeaseManager.Terminate() + delete(pm.stateLeaseManager, instance.InstanceID) + } + } +} + +// HandleSchedulerManaged current scheduler now is supposed to manage the scheduler's instances +func (pm *PoolManager) HandleSchedulerManaged(eventType registry.EventType, + insSpec *commonTypes.InstanceSpecification) { + log.GetLogger().Infof("handling scheduler managed event type %s, schedulerID:%s", + eventType, insSpec.InstanceID) + pm.Lock() + for _, p := range pm.instancePool { + p.handleManagedChange() + p.HandleFaaSSchedulerEvent() + } + pm.Unlock() +} + +// HandleRolloutRatioChange 监听灰度比例变化 +func (pm *PoolManager) HandleRolloutRatioChange(ratio int) { + log.GetLogger().Infof("handling scheduler ratio change %d", ratio) + pm.Lock() + defer pm.Unlock() + for _, p := range pm.instancePool { + p.handleRatioChange(ratio) + } +} + +// HandleInstanceConfigEvent handles instance config event +func (pm *PoolManager) HandleInstanceConfigEvent(eventType registry.EventType, + insConfig *instanceconfig.Configuration) { + logger := log.GetLogger().With(zap.Any("FuncKey", insConfig.FuncKey)). + With(zap.Any("eventType", eventType)). + With(zap.Any("InstanceLabel", insConfig.InstanceLabel)) + logger.Infof("handling instance config event") + if eventType == registry.SubEventTypeSynced { + for _, pool := range pm.instancePool { + pool.CleanOrphansInstanceQueue() + } + } + pm.RLock() + pool, exist := pm.instancePool[insConfig.FuncKey] + pm.RUnlock() + if eventType == registry.SubEventTypeUpdate { + pm.Lock() + if _, ok := pm.instanceConfigRecord[insConfig.FuncKey]; !ok { + pm.instanceConfigRecord[insConfig.FuncKey] = make(map[string]*instanceconfig.Configuration, + constant.DefaultMapSize) + } + pm.instanceConfigRecord[insConfig.FuncKey][insConfig.InstanceLabel] = insConfig + pm.Unlock() + if !exist { + logger.Warnf("instance pool for function doesn't exist") + return + } + pool.HandleInstanceConfigEvent(eventType, insConfig) + } + if eventType == registry.SubEventTypeDelete { + pm.Lock() + if _, ok := pm.instanceConfigRecord[insConfig.FuncKey]; !ok { + pm.Unlock() + return + } + delete(pm.instanceConfigRecord[insConfig.FuncKey], insConfig.InstanceLabel) + gi, ok := pm.instancePool[insConfig.FuncKey] + pm.Unlock() + if ok { + gi.HandleInstanceConfigEvent(eventType, insConfig) + } + } +} + +// HandleAliasEvent handles instance config event +func (pm *PoolManager) HandleAliasEvent(eventType registry.EventType, aliasUrn string) { + logger := log.GetLogger().With(zap.Any("", "HandleAliasEvent")).With(zap.Any("urn", aliasUrn)) + logger.Infof("start") + var wg sync.WaitGroup + pm.RLock() + for funcKey, pool := range pm.instancePool { + tenantID := urnutils.GetTenantFromFuncKey(funcKey) + logger.Infof("handle alias event for funcKey %s, poolName %s, tenantID %s", funcKey, + pool.GetFuncSpec().FuncKey, tenantID) + if urnutils.CheckAliasUrnTenant(tenantID, aliasUrn) { + wg.Add(1) + go func(p InstancePool) { + defer wg.Done() + p.HandleAliasEvent(eventType, aliasUrn) + }(pool) + } + } + pm.RUnlock() + wg.Wait() + logger.Infof("finish") +} + +// ReportMetrics sends invoke metrics to instance pool of a specific function +func (pm *PoolManager) ReportMetrics(funcKey string, resKey resspeckey.ResSpecKey, + insMetrics *types.InstanceThreadMetrics) { + pm.RLock() + pool, exist := pm.instancePool[funcKey] + pm.RUnlock() + if !exist { + log.GetLogger().Errorf("failed to find pool for function %s", funcKey) + return + } + pool.UpdateInvokeMetrics(resKey, insMetrics) +} + +// CheckMinInsAndReport - +func (pm *PoolManager) CheckMinInsAndReport(stopCh <-chan struct{}) { + go pm.checkMinInsRegularly(stopCh) +} + +func handleK8sResourceUpdate(funcSpec *types.FunctionSpecification) { + dynamicconfigmanager.HandleUpdateFunctionEvent(funcSpec) +} + +func handleK8sResourceDelete(funcSpec *types.FunctionSpecification) { + dynamicconfigmanager.HandleDeleteFunctionEvent(funcSpec) + utils.DeleteConfigMapByFuncInfo(funcSpec) +} + +func getInstanceType(createOptions map[string]string) types.InstanceType { + instanceType, ok := createOptions[types.InstanceTypeNote] + if !ok { + return types.InstanceTypeUnknown + } + return types.InstanceType(instanceType) +} + +func getInstanceLabel(createOptions map[string]string) string { + instanceLabel, ok := createOptions[types.InstanceLabelNode] + if !ok { + return "" + } + return instanceLabel +} diff --git a/yuanrong/pkg/functionscaler/instancepool/poolmanager_test.go b/yuanrong/pkg/functionscaler/instancepool/poolmanager_test.go new file mode 100644 index 0000000..a7056c9 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/poolmanager_test.go @@ -0,0 +1,1181 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "context" + "encoding/json" + "errors" + "reflect" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + commonconstant "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + commonstate "yuanrong/pkg/common/faas_common/state" + "yuanrong/pkg/common/faas_common/statuscode" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/instancequeue" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/state" + "yuanrong/pkg/functionscaler/stateinstance" + "yuanrong/pkg/functionscaler/types" +) + +func TestNewPoolManager(t *testing.T) { + stopCh := make(<-chan struct{}) + poolManager := NewPoolManager(stopCh) + assert.Equal(t, stopCh, poolManager.stopCh) +} + +func mockInsAcqReq() *types.InstanceAcquireRequest { + insAcqReq := &types.InstanceAcquireRequest{} + insAcqReq.FuncSpec = &types.FunctionSpecification{ + FuncKey: "testFunction", + FuncCtx: context.TODO(), + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 500, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + ConcurrentNum: 100, + }, + } + return insAcqReq +} + +func TestPoolManagerAcquireInstanceThread(t *testing.T) { + initRegistry() + convey.Convey("test PoolManagerAcquireInstanceThread", t, func() { + convey.Convey("success", func() { + patches := gomonkey.ApplyFunc((*GenericInstancePool).AcquireInstance, func(_ *GenericInstancePool, + insThdApp *types.InstanceAcquireRequest) (*types.InstanceAllocation, snerror.SNError) { + mockInstance := &types.Instance{ + ResKey: resspeckey.ResSpecKey{ + CPU: 1000, + Memory: 1000, + }, + InstanceType: "", + InstanceID: "", + FuncKey: "", + ConcurrentNum: 0, + } + instanceThread := &types.InstanceAllocation{ + Instance: mockInstance, + AllocationID: "mock-thread-id-123", + } + return instanceThread, nil + }) + patches.ApplyFunc((*GenericInstancePool).ReleaseInstance, func(_ *GenericInstancePool, + instance *types.InstanceAllocation) { + return + }) + defer patches.Reset() + insAcqReq := mockInsAcqReq() + poolManager := NewPoolManager(make(chan struct{})) + poolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, &types.FunctionSpecification{ + FuncKey: "testFunction", + }) + poolManager.instanceConfigRecord["testFunction"] = map[string]*instanceconfig.Configuration{ + DefaultInstanceLabel: { + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{}, + }, + } + _, err := poolManager.AcquireInstanceThread(insAcqReq) + assert.Equal(t, nil, err) + }) + convey.Convey("state miss match", func() { + insAcqReq := mockInsAcqReq() + insAcqReq.StateID = "testStateID" + poolManager := NewPoolManager(make(chan struct{})) + poolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, &types.FunctionSpecification{ + FuncCtx: context.TODO(), + FuncKey: "testFunction", + }) + _, err := poolManager.AcquireInstanceThread(insAcqReq) + convey.So(err.Code(), convey.ShouldEqual, statuscode.StateMismatch) + }) + convey.Convey("newGenericInstancePool error", func() { + patch := gomonkey.ApplyFunc(NewGenericInstancePool, func(funcSpec *types.FunctionSpecification, + faasManagerInfo faasManagerInfo) (InstancePool, error) { + return nil, errors.New("new pool error") + }) + defer patch.Reset() + insAcqReq := mockInsAcqReq() + poolManager := NewPoolManager(make(chan struct{})) + _, err := poolManager.AcquireInstanceThread(insAcqReq) + convey.So(err.Code(), convey.ShouldEqual, statuscode.FuncMetaNotFoundErrCode) + }) + convey.Convey("acquireInstanceThread error", func() { + var mockInstancePool *GenericInstancePool + patches := gomonkey.ApplyMethod(reflect.TypeOf(mockInstancePool), "AcquireInstance", + func(_ *GenericInstancePool, insThdApp *types.InstanceAcquireRequest) (*types.InstanceAllocation, + snerror.SNError) { + return nil, snerror.New(statuscode.StatusInternalServerError, "acquire instance thread error") + }) + defer patches.Reset() + insAcqReq := mockInsAcqReq() + poolManager := NewPoolManager(make(chan struct{})) + poolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, &types.FunctionSpecification{ + FuncKey: "testFunction", + }) + _, err := poolManager.AcquireInstanceThread(insAcqReq) + convey.So(err.Code(), convey.ShouldEqual, statuscode.StatusInternalServerError) + }) + }) +} + +func TestPoolManagerReleaseInstanceThread(t *testing.T) { + initRegistry() + var mockInstancePool *GenericInstancePool + patches := gomonkey.ApplyMethod(reflect.TypeOf(mockInstancePool), "AcquireInstance", + func(_ *GenericInstancePool, insThdApp *types.InstanceAcquireRequest) (*types.InstanceAllocation, + snerror.SNError) { + mockInstance := &types.Instance{ + ResKey: resspeckey.ResSpecKey{ + CPU: 1000, + Memory: 1000, + }, + InstanceType: "", + InstanceID: "", + FuncKey: "", + ConcurrentNum: 0, + } + instanceThread := &types.InstanceAllocation{ + Instance: mockInstance, + AllocationID: "mock-thread-id-123", + } + return instanceThread, nil + }) + patches.ApplyMethod(reflect.TypeOf(mockInstancePool), "ReleaseInstance", + func(_ *GenericInstancePool, instance *types.InstanceAllocation) { + return + }) + + defer patches.Reset() + poolManager := NewPoolManager(make(chan struct{})) + poolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, &types.FunctionSpecification{ + FuncKey: "testFunction", + }) + insAcqReq := mockInsAcqReq() + poolManager.instanceConfigRecord["testFunction"] = map[string]*instanceconfig.Configuration{ + DefaultInstanceLabel: { + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{}, + }, + } + insAlloc, err := poolManager.AcquireInstanceThread(insAcqReq) + assert.Equal(t, nil, err) + + poolManager.ReleaseInstanceThread(insAlloc) +} + +func TestPoolManagerReleaseAbnormalInstance(t *testing.T) { + var mockInstancePool *GenericInstancePool + patches := gomonkey.ApplyMethod(reflect.TypeOf(mockInstancePool), "AcquireInstance", + func(_ *GenericInstancePool, insThdApp *types.InstanceAcquireRequest) (*types.InstanceAllocation, + snerror.SNError) { + mockInstance := &types.Instance{ + ResKey: resspeckey.ResSpecKey{ + CPU: 1000, + Memory: 1000, + }, + InstanceType: "", + InstanceID: "", + FuncKey: "", + ConcurrentNum: 0, + } + instanceThread := &types.InstanceAllocation{ + Instance: mockInstance, + AllocationID: "mock-thread-id-123", + } + return instanceThread, nil + }) + patches.ApplyMethod(reflect.TypeOf(mockInstancePool), "HandleInstanceEvent", + func(_ *GenericInstancePool, eventType registry.EventType, instance *types.Instance) { + return + }) + + defer patches.Reset() + initRegistry() + poolManager := NewPoolManager(make(chan struct{})) + poolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, &types.FunctionSpecification{ + FuncKey: "testFunction", + }) + insAcqReq := mockInsAcqReq() + poolManager.instanceConfigRecord["testFunction"] = map[string]*instanceconfig.Configuration{ + DefaultInstanceLabel: { + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{}, + }, + } + insAlloc, err := poolManager.AcquireInstanceThread(insAcqReq) + assert.Equal(t, nil, err) + poolManager.ReleaseAbnormalInstance(insAlloc.Instance) +} + +func TestHandleFunctionUpdate(t *testing.T) { + var mockInstancePool *GenericInstancePool + patches := gomonkey.ApplyMethod(reflect.TypeOf(mockInstancePool), "AcquireInstance", + func(_ *GenericInstancePool, insThdApp *types.InstanceAcquireRequest) (*types.InstanceAllocation, + snerror.SNError) { + mockInstance := &types.Instance{ + ResKey: resspeckey.ResSpecKey{ + CPU: 1000, + Memory: 1000, + }, + InstanceType: "", + InstanceID: "", + FuncKey: "", + ConcurrentNum: 0, + } + instanceThread := &types.InstanceAllocation{ + Instance: mockInstance, + AllocationID: "mock-thread-id-123", + } + return instanceThread, nil + }) + patches.ApplyMethod(reflect.TypeOf(mockInstancePool), "ReleaseInstance", + func(_ *GenericInstancePool, instance *types.InstanceAllocation) { + return + }) + + defer patches.Reset() + initRegistry() + poolManager := NewPoolManager(make(chan struct{})) + poolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, &types.FunctionSpecification{ + FuncKey: "testFunction", + }) + insAcqReq := mockInsAcqReq() + poolManager.instanceConfigRecord["testFunction"] = map[string]*instanceconfig.Configuration{ + DefaultInstanceLabel: { + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{}, + }, + } + _, err := poolManager.AcquireInstanceThread(insAcqReq) + assert.Equal(t, err, nil) + + funcSpec1 := &types.FunctionSpecification{ + FuncKey: "testFunction123456", + FuncCtx: context.TODO(), + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 700, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + ConcurrentNum: 100, + }, + } + poolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, funcSpec1) + assert.Equal(t, funcSpec1, poolManager.instancePool[funcSpec1.FuncKey].(*GenericInstancePool).FuncSpec) +} + +func TestHandleFunctionDelete(t *testing.T) { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + var mockInstancePool *GenericInstancePool + patches := gomonkey.ApplyMethod(reflect.TypeOf(mockInstancePool), "AcquireInstance", + func(_ *GenericInstancePool, insThdApp *types.InstanceAcquireRequest) (*types.InstanceAllocation, + snerror.SNError) { + mockInstance := &types.Instance{ + ResKey: resspeckey.ResSpecKey{ + CPU: 1000, + Memory: 1000, + }, + InstanceType: "scaled", + InstanceID: "test-instance", + FuncKey: "test-function", + ConcurrentNum: 0, + } + instanceThread := &types.InstanceAllocation{ + Instance: mockInstance, + AllocationID: "mock-thread-id-123", + } + return instanceThread, nil + }) + patches.ApplyMethod(reflect.TypeOf(mockInstancePool), "ReleaseInstance", + func(_ *GenericInstancePool, instance *types.InstanceAllocation) { + return + }) + + defer patches.Reset() + initRegistry() + poolManager := NewPoolManager(make(chan struct{})) + insAcqReq := mockInsAcqReq() + poolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, &types.FunctionSpecification{ + FuncKey: "testFunction", + }) + poolManager.instanceConfigRecord["testFunction"] = map[string]*instanceconfig.Configuration{ + DefaultInstanceLabel: { + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{}, + }, + } + _, err := poolManager.AcquireInstanceThread(insAcqReq) + assert.Equal(t, err, nil) + + funcSpec1 := &types.FunctionSpecification{ + FuncKey: "testFunction", + FuncCtx: context.TODO(), + ResourceMetaData: commonTypes.ResourceMetaData{ + CPU: 500, + Memory: 700, + }, + InstanceMetaData: commonTypes.InstanceMetaData{ + ConcurrentNum: 100, + }, + } + assert.NotEqual(t, nil, poolManager.instancePool[funcSpec1.FuncKey]) + poolManager.HandleFunctionEvent(registry.SubEventTypeDelete, funcSpec1) + assert.Equal(t, nil, poolManager.instancePool[funcSpec1.FuncKey]) +} + +func TestHandleInstanceEvent(t *testing.T) { + patches := mockInstanceOperation() + defer unMockInstanceOperation(patches) + var mockInstancePool *GenericInstancePool = &GenericInstancePool{} + + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "instance1": mockInstancePool, + }, + } + convey.Convey("Func is error", t, func() { + pm.HandleInstanceEvent(registry.SubEventTypeUpdate, &commonTypes.InstanceSpecification{ + Function: "123456/funcName", + }) + }) + convey.Convey("update faasManager", t, func() { + pm.HandleInstanceEvent(registry.SubEventTypeUpdate, &commonTypes.InstanceSpecification{ + Function: "123456/0-system-faasmanager/$latest", + }) + convey.So(mockInstancePool.faasManagerInfo, convey.ShouldResemble, faasManagerInfo{ + funcKey: "123456/0-system-faasmanager/$latest", + }) + }) + testFunc := &commonTypes.InstanceSpecification{ + Function: "123456/0-system-testFunc/$latest", + InstanceID: "testFunc123", + } + convey.Convey("instanceRecord not exist", t, func() { + pm.HandleInstanceEvent(registry.SubEventTypeDelete, testFunc) + convey.So(pm.instancePool, convey.ShouldNotContainKey, testFunc.InstanceID) + }) + + convey.Convey("delete instanceRecord", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&GenericInstancePool{}), "HandleInstanceEvent", + func(_ *GenericInstancePool, eventType registry.EventType, instance *types.Instance) { + return + }).Reset() + + pm.HandleInstanceEvent(registry.SubEventTypeDelete, testFunc) + convey.So(pm.instancePool, convey.ShouldNotContainKey, testFunc.InstanceID) + }) + convey.Convey("delete leaseManager", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&GenericInstancePool{}), "HandleInstanceEvent", + func(_ *GenericInstancePool, eventType registry.EventType, instance *types.Instance) { + return + }).Reset() + pm.HandleInstanceEvent(registry.SubEventTypeDelete, testFunc) + convey.So(pm.instancePool, convey.ShouldNotContainKey, testFunc.InstanceID) + }) + + convey.Convey("labels is not exist", t, func() { + testFunc := &commonTypes.InstanceSpecification{ + Function: "123456/0-system-testFunc/$latest", + InstanceID: "testFunc123", + Labels: []string{"labels1", "labels2"}, + } + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "instance1": mockInstancePool, + }, + } + pm.HandleInstanceEvent(registry.SubEventTypeDelete, testFunc) + convey.So(pm.instancePool, convey.ShouldNotContainKey, testFunc.InstanceID) + }) + + convey.Convey("lease not exist", t, func() { + mockInstancePool := &GenericInstancePool{ + FuncSpec: &types.FunctionSpecification{ + FuncKey: "funcKey", + }, + } + testFunc := &commonTypes.InstanceSpecification{ + Function: "123456/0-system-testFunc/$latest", + InstanceID: "testFunc123", + Labels: []string{"labels1", "instance1"}, + } + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "instance1": mockInstancePool, + }, + } + pm.HandleInstanceEvent(registry.SubEventTypeDelete, testFunc) + convey.So(pm.instancePool, convey.ShouldNotContainKey, testFunc.InstanceID) + }) + + convey.Convey("delete success", t, func() { + mockInstancePool := &GenericInstancePool{ + FuncSpec: &types.FunctionSpecification{ + FuncKey: "funcKey", + }, + } + testFunc := &commonTypes.InstanceSpecification{ + Function: "123456/0-system-testFunc/$latest", + InstanceID: "testFunc123", + Labels: []string{"labels1", "instance1"}, + } + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "instance1": mockInstancePool, + }, + } + pm.HandleInstanceEvent(registry.SubEventTypeDelete, testFunc) + convey.So(pm.instancePool, convey.ShouldNotContainKey, testFunc.InstanceID) + }) + + convey.Convey("success HandleInstanceEvent", t, func() { + insPool := &GenericInstancePool{ + FuncSpec: &types.FunctionSpecification{ + FuncKey: "123456/0-system-testFunc/$latest", + }, + waitInsConfigChan: make(chan struct{}), + } + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "123456/0-system-testFunc/$latest": insPool, + }, + } + close(insPool.waitInsConfigChan) + labels := map[string]string{podLabelInstanceType: string(types.InstanceTypeReserved)} + b, _ := json.Marshal(labels) + testInsSpec := &commonTypes.InstanceSpecification{ + Function: "123456/0-system-testFunc/$latest", + InstanceID: "testIns123", + Labels: []string{"labels1", "instance1"}, + CreateOptions: map[string]string{types.FunctionKeyNote: "123456/0-system-testFunc/$latest", commonconstant.DelegatePodLabels: string(b)}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(3)}, + } + var getIns *types.Instance + defer gomonkey.ApplyMethod(reflect.TypeOf(insPool), "HandleInstanceEvent", + func(gi *GenericInstancePool, eventType registry.EventType, instance *types.Instance) { + getIns = instance + }).Reset() + pm.HandleInstanceEvent(registry.SubEventTypeUpdate, testInsSpec) + convey.So(getIns.InstanceID, convey.ShouldEqual, "testIns123") + }) +} + +func TestReportMetrics(t *testing.T) { + var mockInstancePool *GenericInstancePool + patches := gomonkey.ApplyMethod(reflect.TypeOf(mockInstancePool), "AcquireInstance", + func(_ *GenericInstancePool, insThdApp *types.InstanceAcquireRequest) (*types.InstanceAllocation, + snerror.SNError) { + mockInstance := &types.Instance{ + ResKey: resspeckey.ResSpecKey{ + CPU: 1000, + Memory: 1000, + }, + InstanceType: "", + InstanceID: "", + FuncKey: "", + ConcurrentNum: 0, + } + instanceThread := &types.InstanceAllocation{ + Instance: mockInstance, + AllocationID: "mock-thread-id-123", + } + return instanceThread, nil + }) + patches.ApplyMethod(reflect.TypeOf(mockInstancePool), "ReleaseInstance", + func(_ *GenericInstancePool, instance *types.InstanceAllocation) { + return + }) + + defer patches.Reset() + initRegistry() + poolManager := NewPoolManager(make(chan struct{})) + insAcqReq := mockInsAcqReq() + poolManager.HandleFunctionEvent(registry.SubEventTypeUpdate, &types.FunctionSpecification{ + FuncKey: "testFunction", + }) + poolManager.instanceConfigRecord["testFunction"] = map[string]*instanceconfig.Configuration{ + DefaultInstanceLabel: { + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{}, + }, + } + _, err := poolManager.AcquireInstanceThread(insAcqReq) + assert.Equal(t, err, nil) + funcKey := insAcqReq.FuncSpec.FuncKey + resKey := resspeckey.ResSpecKey{ + CPU: 10, + Memory: 500, + } + insMetrics := &types.InstanceThreadMetrics{ + InsThdID: "Instance-1", + ProcNumPS: 0, + ProcReqNum: 0, + AvgProcTime: 0, + MaxProcTime: 0, + } + + poolManager.ReportMetrics(funcKey, resKey, insMetrics) + + // The instance pool does not exist + delete(poolManager.instancePool, funcKey) + poolManager.ReportMetrics(funcKey, resKey, insMetrics) +} + +func Test_RecoverInstancePool(t *testing.T) { + initRegistry() + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&GenericInstancePool{}), "RecoverInstance", + func(_ *GenericInstancePool, funcSpec *types.FunctionSpecification, + instancePoolState *types.InstancePoolState, deleteFunc bool, wg *sync.WaitGroup) { + wg.Done() + return + }), + gomonkey.ApplyMethod(reflect.TypeOf(®istry.InstanceRegistry{}), "WaitForETCDList", + func(ir *registry.InstanceRegistry) { + return + }), + gomonkey.ApplyMethod(reflect.TypeOf(&commonstate.Queue{}), "SaveState", + func(q *commonstate.Queue, state []byte, key string) error { + return nil + })} + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + schedulerState := state.GetState() + schedulerState.InstancePool["functionKey1"] = &types.InstancePoolState{ + StateInstance: map[string]*types.Instance{}, + } + schedulerState.InstancePool["functionKey2"] = &types.InstancePoolState{ + StateInstance: map[string]*types.Instance{}, + } + schedulerState.InstancePool["functionKey1"].StateInstance["InstanceID1"] = &types.Instance{FuncSig: "11111"} + schedulerState.InstancePool["functionKey1"].StateInstance["InstanceID2"] = &types.Instance{FuncSig: "11111"} + + schedulerState.InstancePool["functionKey2"].StateInstance["InstanceID21"] = &types.Instance{FuncSig: "22222"} + schedulerState.InstancePool["functionKey2"].StateInstance["InstanceID22"] = &types.Instance{FuncSig: "22222"} + + poolManager := NewPoolManager(make(chan struct{})) + convey.Convey("RecoverInstancePool success", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(®istry.Registry{}), "GetFuncSpec", + func(_ *registry.Registry, funcKey string) *types.FunctionSpecification { + return &types.FunctionSpecification{ + FuncKey: funcKey, + } + }).Reset() + poolManager.RecoverInstancePool() + _, exist := poolManager.instancePool["functionKey1"] + convey.So(exist, convey.ShouldEqual, true) + _, exist = poolManager.instancePool["functionKey2"] + convey.So(exist, convey.ShouldEqual, true) + }) + convey.Convey("failed to GetFuncSpec", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(®istry.Registry{}), "GetFuncSpec", + func(_ *registry.Registry, funcKey string) *types.FunctionSpecification { + if funcKey == "functionKey1" { + return nil + } + return &types.FunctionSpecification{ + FuncKey: funcKey, + } + }).Reset() + poolManager.RecoverInstancePool() + _, exist := poolManager.instancePool["functionKey1"] + convey.So(exist, convey.ShouldEqual, false) + _, exist = poolManager.instancePool["functionKey2"] + convey.So(exist, convey.ShouldEqual, true) + }) +} + +func TestHandleInstanceConfigEvent(t *testing.T) { + var mockInstancePool *GenericInstancePool = &GenericInstancePool{} + + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "instance1": mockInstancePool, + }, + instanceConfigRecord: map[string]map[string]*instanceconfig.Configuration{}, + } + insConfig := &instanceconfig.Configuration{ + FuncKey: "testFunc", + } + convey.Convey("instance pool not exist", t, func() { + pm.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, insConfig) + }) + pm.HandleFunctionEvent(registry.SubEventTypeUpdate, &types.FunctionSpecification{ + FuncKey: "testFunc", + }) + convey.Convey("instance pool exist", t, func() { + pm.HandleInstanceConfigEvent(registry.SubEventTypeUpdate, insConfig) + }) + convey.Convey("instance pool exist", t, func() { + pm.HandleInstanceConfigEvent(registry.SubEventTypeDelete, insConfig) + }) +} + +func TestHandleAliasEvent(t *testing.T) { + convey.Convey("Test HandleAliasEvent", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&GenericInstancePool{}), "HandleAliasEvent", + func(_ *GenericInstancePool, eventType registry.EventType, aliasUrn string) { + }).Reset() + mockInstancePool := &GenericInstancePool{ + reservedInstanceQueue: map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{resspeckey.ResSpecKey{}: {}}, + FuncSpec: &types.FunctionSpecification{FuncKey: "TenantID/FunctionName/Version"}, + } + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "TenantID/FunctionName/Version": mockInstancePool, + }, + instanceConfigRecord: map[string]map[string]*instanceconfig.Configuration{}, + } + pm.HandleAliasEvent(registry.SubEventTypeUpdate, "sn:cn:yrk:TenantID:function:FunctionName:Version") + }) +} + +func TestGetInstanceType(t *testing.T) { + convey.Convey("Test getInstanceType", t, func() { + convey.Convey("DelegatePodLabels is not exist", func() { + createOptions := map[string]string{} + instanceType := getInstanceType(createOptions) + convey.So(instanceType, convey.ShouldEqual, types.InstanceTypeUnknown) + }) + convey.Convey("unmarshal labels error", func() { + createOptions := map[string]string{commonconstant.DelegatePodLabels: ""} + instanceType := getInstanceType(createOptions) + convey.So(instanceType, convey.ShouldEqual, types.InstanceTypeUnknown) + }) + convey.Convey("podLabels is nil", func() { + podLabels := map[string]string{} + data, _ := json.Marshal(podLabels) + createOptions := map[string]string{commonconstant.DelegatePodLabels: string(data)} + instanceType := getInstanceType(createOptions) + convey.So(instanceType, convey.ShouldEqual, types.InstanceTypeUnknown) + }) + convey.Convey("podLabelInstanceType is not exist", func() { + podLabels := map[string]string{"a": "b"} + data, _ := json.Marshal(podLabels) + createOptions := map[string]string{commonconstant.DelegatePodLabels: string(data)} + instanceType := getInstanceType(createOptions) + convey.So(instanceType, convey.ShouldEqual, types.InstanceTypeUnknown) + }) + convey.Convey("get instanceType success", func() { + podLabels := map[string]string{podLabelInstanceType: "reserved"} + data, _ := json.Marshal(podLabels) + createOptions := map[string]string{commonconstant.DelegatePodLabels: string(data)} + instanceType := getInstanceType(createOptions) + convey.So(instanceType, convey.ShouldEqual, types.InstanceTypeUnknown) + }) + }) +} + +func TestCheckMinInsAndReport(t *testing.T) { + convey.Convey("Test CheckMinInsAndReport", t, func() { + rawGConfig := config.GlobalConfig + config.GlobalConfig = types.Configuration{} + stopCh := make(chan struct{}) + pm := &PoolManager{} + pm.CheckMinInsAndReport(stopCh) + time.Sleep(10 * time.Millisecond) + close(stopCh) + config.GlobalConfig = rawGConfig + convey.So(pm, convey.ShouldNotBeNil) + }) +} + +func TestJudgeAndReport(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(&instancequeue.ScaledInstanceQueue{}), "GetInstanceNumber", + func(_ *instancequeue.ScaledInstanceQueue) int { + return 0 + }).Reset() + convey.Convey("Test JudgeAndReport", t, func() { + convey.Convey("report success", func() { + mockInstancePool := &GenericInstancePool{ + insConfig: map[resspeckey.ResSpecKey]*instanceconfig.Configuration{ + resspeckey.ResSpecKey{}: { + InstanceMetaData: commonTypes.InstanceMetaData{ + MinInstance: 1, + }, + }, + }, + reservedInstanceQueue: map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{ + resspeckey.ResSpecKey{}: &instancequeue.ScaledInstanceQueue{}, + }, + FuncSpec: &types.FunctionSpecification{ + FuncKey: "testFunc", + }, + minScaleAlarmSign: map[string]bool{}, + pendingInstanceNum: map[string]int{}, + } + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "TenantID/FunctionName/Version": mockInstancePool, + }, + } + pm.judgeAndReport("clusterID") + convey.So(mockInstancePool.minScaleAlarmSign[DefaultInstanceLabel], convey.ShouldBeTrue) + }) + convey.Convey("minInstance is zero", func() { + mockInstancePool := &GenericInstancePool{ + insConfig: map[resspeckey.ResSpecKey]*instanceconfig.Configuration{ + resspeckey.ResSpecKey{}: { + InstanceMetaData: commonTypes.InstanceMetaData{}, + }, + }, + minScaleUpdatedTime: time.Now(), + reservedInstanceQueue: map[resspeckey.ResSpecKey]*instancequeue.ScaledInstanceQueue{resspeckey.ResSpecKey{}: {}}, + FuncSpec: &types.FunctionSpecification{ + FuncKey: "testFunc", + }, + minScaleAlarmSign: map[string]bool{}, + pendingInstanceNum: map[string]int{}, + } + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "TenantID/FunctionName/Version": mockInstancePool, + }, + } + pm.judgeAndReport("clusterID") + convey.So(mockInstancePool.minScaleAlarmSign[DefaultInstanceLabel], convey.ShouldBeFalse) + }) + convey.Convey("reservedInstanceQueue is nill", func() { + mockInstancePool := &GenericInstancePool{ + insConfig: map[resspeckey.ResSpecKey]*instanceconfig.Configuration{ + resspeckey.ResSpecKey{}: { + InstanceMetaData: commonTypes.InstanceMetaData{}, + }, + }, + FuncSpec: &types.FunctionSpecification{ + FuncKey: "testFunc", + }, + } + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "TenantID/FunctionName/Version": mockInstancePool, + }, + instanceConfigRecord: map[string]map[string]*instanceconfig.Configuration{}, + } + pm.judgeAndReport("clusterID") + convey.So(mockInstancePool.minScaleAlarmSign[DefaultInstanceLabel], convey.ShouldBeFalse) + }) + }) +} + +func TestGetAndDeleteState(t *testing.T) { + convey.Convey("Test GetAndDeleteState", t, func() { + convey.Convey("create instance pool error", func() { + patch := gomonkey.ApplyFunc(NewGenericInstancePool, + func(funcSpec *types.FunctionSpecification, + faasManagerInfo faasManagerInfo) (InstancePool, error) { + return &GenericInstancePool{}, errors.New("create instance pool error") + }) + defer patch.Reset() + mockInstancePool := &GenericInstancePool{} + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "TenantID/FunctionName/Version": mockInstancePool, + }, + } + res := pm.GetAndDeleteState("stateID", "testFunc", + &types.FunctionSpecification{}, log.GetLogger()) + convey.So(res, convey.ShouldBeFalse) + }) + convey.Convey("get state id success", func() { + patch := gomonkey.ApplyFunc(NewGenericInstancePool, + func(funcSpec *types.FunctionSpecification, + faasManagerInfo faasManagerInfo) (InstancePool, error) { + return &GenericInstancePool{}, nil + }) + defer patch.Reset() + patch1 := gomonkey.ApplyMethod(reflect.TypeOf(&StateRoute{}), + "GetAndDeleteState", func(s *StateRoute, stateID string) bool { + return true + }) + defer patch1.Reset() + mockInstancePool := &GenericInstancePool{} + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "TenantID/FunctionName/Version": mockInstancePool, + }, + } + res := pm.GetAndDeleteState("stateID", "testFunc", + &types.FunctionSpecification{}, log.GetLogger()) + convey.So(res, convey.ShouldBeTrue) + }) + }) +} + +func TestReleaseStateThread(t *testing.T) { + convey.Convey("Test ReleaseStateThread", t, func() { + convey.Convey("leaser is nil", func() { + pm := &PoolManager{ + stateLeaseManager: make(map[string]*stateinstance.Leaser), + } + thread := &types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceID: "testID", + }, + } + err := pm.ReleaseStateThread(thread) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("thread id invalid", func() { + pm := &PoolManager{ + stateLeaseManager: map[string]*stateinstance.Leaser{"testID": {}}, + } + thread := &types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceID: "testID", + }, + AllocationID: "InsAlloc", + } + err := pm.ReleaseStateThread(thread) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("thread id error", func() { + pm := &PoolManager{ + stateLeaseManager: map[string]*stateinstance.Leaser{"testID": {}}, + } + thread := &types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceID: "testID", + }, + AllocationID: "InsAlloc-stateThreada", + } + err := pm.ReleaseStateThread(thread) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("release state thread success", func() { + pm := &PoolManager{ + stateLeaseManager: map[string]*stateinstance.Leaser{ + "testID": stateinstance.NewLeaser(1, func(stateID string, instanceID string) {}, + "stateID", "instanceID", 1*time.Second), + }, + } + thread := &types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceID: "testID", + }, + AllocationID: "InsAlloc-stateThread1", + } + err := pm.ReleaseStateThread(thread) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestRetainStateThread(t *testing.T) { + convey.Convey("Test RetainStateThread", t, func() { + convey.Convey("leaser is nil", func() { + pm := &PoolManager{ + stateLeaseManager: make(map[string]*stateinstance.Leaser), + } + thread := &types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceID: "testID", + }, + } + err := pm.RetainStateThread(thread) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("thread id invalid", func() { + pm := &PoolManager{ + stateLeaseManager: map[string]*stateinstance.Leaser{"testID": {}}, + } + thread := &types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceID: "testID", + }, + AllocationID: "InsAlloc", + } + err := pm.RetainStateThread(thread) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("thread id error", func() { + pm := &PoolManager{ + stateLeaseManager: map[string]*stateinstance.Leaser{"testID": {}}, + } + thread := &types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceID: "testID", + }, + AllocationID: "InsAlloc-stateThreada", + } + err := pm.RetainStateThread(thread) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("lease not found", func() { + pm := &PoolManager{ + stateLeaseManager: map[string]*stateinstance.Leaser{ + "testID": stateinstance.NewLeaser(1, func(stateID string, instanceID string) {}, + "stateID", "instanceID", 1*time.Second), + }, + } + thread := &types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceID: "testID", + }, + AllocationID: "InsAlloc-stateThread1", + } + err := pm.RetainStateThread(thread) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestRecoverStateLeaser(t *testing.T) { + convey.Convey("Test RetainStateThread", t, func() { + convey.Convey("pool is nil", func() { + mockInstancePool := &GenericInstancePool{} + stateInstance := map[string]*types.Instance{ + "state1": { + InstanceStatus: commonTypes.InstanceStatus{Code: int32(-1)}, + }, + "state2": { + InstanceStatus: commonTypes.InstanceStatus{Code: int32(0)}, + }, + } + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "testFunc": mockInstancePool, + }, + } + funcSpec := &types.FunctionSpecification{ + FuncKey: "testFunc1", + } + pm.recoverStateLeaser(stateInstance, funcSpec) + convey.So(len(pm.stateLeaseManager), convey.ShouldEqual, 0) + }) + convey.Convey("recover success", func() { + mockInstancePool := &GenericInstancePool{} + stateInstance := map[string]*types.Instance{ + "state1": { + InstanceStatus: commonTypes.InstanceStatus{Code: int32(-1)}, + InstanceID: "id1", + }, + "state2": { + InstanceStatus: commonTypes.InstanceStatus{Code: int32(0)}, + InstanceID: "id2", + }, + } + pm := &PoolManager{ + instancePool: map[string]InstancePool{ + "testFunc": mockInstancePool, + }, + stateLeaseManager: make(map[string]*stateinstance.Leaser), + } + funcSpec := &types.FunctionSpecification{ + FuncKey: "testFunc", + InstanceMetaData: commonTypes.InstanceMetaData{ + ConcurrentNum: 1, + }, + } + pm.recoverStateLeaser(stateInstance, funcSpec) + convey.So(len(pm.stateLeaseManager), convey.ShouldNotEqual, 0) + }) + }) +} + +func TestHandleSchedulerManaged(t *testing.T) { + convey.Convey("HandleSchedulerManaged", t, func() { + pm := &PoolManager{ + stateLeaseManager: map[string]*stateinstance.Leaser{ + "testID": stateinstance.NewLeaser(1, func(stateID string, instanceID string) {}, + "stateID", "instanceID", 1*time.Second), + }, + instancePool: map[string]InstancePool{}, + } + pm.HandleSchedulerManaged(registry.SubEventTypeUpdate, &commonTypes.InstanceSpecification{}) + }) +} + +func TestHandleRolloutRatioChange(t *testing.T) { + var patches []*gomonkey.Patches + expectRatio := 0 + patches = append(patches, gomonkey.ApplyFunc( + (*PoolManager).HandleRolloutRatioChange, + func(pm *PoolManager, ratio int) { + expectRatio = ratio + }, + )) + defer func() { + for _, p := range patches { + p.Reset() + } + }() + + var mockInstancePool = &GenericInstancePool{} + pm := &PoolManager{ + stateLeaseManager: map[string]*stateinstance.Leaser{ + "testID": stateinstance.NewLeaser(1, func(stateID string, instanceID string) {}, + "stateID", "instanceID", 1*time.Second), + }, + instancePool: map[string]InstancePool{ + "instance1": mockInstancePool, + }, + } + pm.HandleRolloutRatioChange(50) + assert.Equal(t, expectRatio, 50) +} + +func TestPoolManagerCreateInstance(t *testing.T) { + tests := []struct { + name string + setupPool bool + createReq *types.InstanceCreateRequest + mockInstance *types.Instance + mockError snerror.SNError + expectError bool + }{ + { + name: "pool_exists_success", + setupPool: true, + createReq: &types.InstanceCreateRequest{ + FuncSpec: &types.FunctionSpecification{ + FuncKey: "test-func", + }, + }, + mockInstance: &types.Instance{ + InstanceID: "test-instance", + }, + mockError: nil, + expectError: false, + }, + { + name: "pool_not_exist", + setupPool: false, + createReq: &types.InstanceCreateRequest{ + FuncSpec: &types.FunctionSpecification{ + FuncKey: "test-func", + }, + }, + mockInstance: nil, + mockError: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm := &PoolManager{ + instancePool: make(map[string]InstancePool), + } + + if tt.setupPool { + pool := &GenericInstancePool{} + pm.instancePool[tt.createReq.FuncSpec.FuncKey] = pool + + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyMethod(reflect.TypeOf(pool), "CreateInstance", + func(_ *GenericInstancePool, req *types.InstanceCreateRequest) (*types.Instance, snerror.SNError) { + assert.Equal(t, tt.createReq, req) + return tt.mockInstance, tt.mockError + }) + } + + _, err := pm.CreateInstance(tt.createReq) + + if tt.expectError { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + }) + } +} + +func TestPoolManagerDeleteInstance(t *testing.T) { + tests := []struct { + name string + setupPool bool + instance *types.Instance + mockError snerror.SNError + expectError bool + }{ + { + name: "Pool does not exist", + setupPool: false, + instance: &types.Instance{FuncKey: "test-func-key"}, + expectError: true, + }, + { + name: "Pool exists, delete succeeds", + setupPool: true, + instance: &types.Instance{FuncKey: "test-func-key"}, + mockError: nil, + expectError: false, + }, + { + name: "Pool exists, delete fails", + setupPool: true, + instance: &types.Instance{FuncKey: "test-func-key"}, + mockError: snerror.New(statuscode.InternalErrorCode, "delete error"), + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm := &PoolManager{ + instancePool: make(map[string]InstancePool), + } + + if tt.setupPool { + pool := &GenericInstancePool{} + pm.instancePool[tt.instance.FuncKey] = pool + + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyMethod(reflect.TypeOf(pool), "DeleteInstance", + func(_ *GenericInstancePool, instance *types.Instance) snerror.SNError { + assert.Equal(t, tt.instance, instance) + return tt.mockError + }) + } + + err := pm.DeleteInstance(tt.instance) + + if tt.expectError { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + }) + } +} diff --git a/yuanrong/pkg/functionscaler/instancepool/rasp_sidecar.go b/yuanrong/pkg/functionscaler/instancepool/rasp_sidecar.go new file mode 100644 index 0000000..bee7384 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/rasp_sidecar.go @@ -0,0 +1,162 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancepool - +package instancepool + +import ( + "fmt" + + "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + + "yuanrong/pkg/functionscaler/types" +) + +const ( + containerRaspInit container = "rasp-init" + containerRasp container = "rasp" + raspDataPath = "/opt/data" +) + +func makeRaspContainer(funcSpec *types.FunctionSpecification) types.DelegateContainerSideCarConfig { + raspEnv := []v1.EnvVar{ + { + Name: "RASP_SERVER_IP2", + Value: funcSpec.ExtendedMetaData.RaspConfig.RaspServerIP, + }, { + Name: "RASP_SERVER_PORT2", + Value: funcSpec.ExtendedMetaData.RaspConfig.RaspServerPort, + }, { + Name: "RASP_CONTAINER_DEPLOYTYPE", + Value: "nonnuwaruntime", + }, { + Name: "RUNTIME_HOST_IP", ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "status.hostIP", + }}, + }, { + Name: podNameEnvNew, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "metadata.name", + }}, + }, + {Name: podIPEnv, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "status.podIP", + }}, + }, + } + + for _, env := range funcSpec.ExtendedMetaData.RaspConfig.Envs { + raspEnv = append(raspEnv, v1.EnvVar{ + Name: env.Name, + Value: env.Value, + }) + } + + raspResource := v1.ResourceRequirements{} + raspResource.Requests = map[v1.ResourceName]resource.Quantity{} + raspResource.Requests["cpu"] = resource.MustParse(fmt.Sprintf("%dm", RaspDefaultCPU)) + raspResource.Requests["memory"] = resource.MustParse(fmt.Sprintf("%dMi", RaspDefaultMemory)) + raspResource.Limits = map[v1.ResourceName]resource.Quantity{} + raspResource.Limits = map[v1.ResourceName]resource.Quantity{} + raspResource.Limits["cpu"] = resource.MustParse(fmt.Sprintf("%dm", RaspDefaultCPU)) + raspResource.Limits["memory"] = resource.MustParse(fmt.Sprintf("%dMi", RaspDefaultMemory)) + + raspMount := []v1.VolumeMount{ + { + Name: DefaultDataVolumeName, MountPath: raspDataPath, + }, + { + Name: aiOpsVolume, MountPath: raspLogVolumeMountPath, SubPathExpr: raspLogVolumeMountSubPathExpr, + }, + } + + ReadinessProbe := v1.Probe{ + ProbeHandler: v1.ProbeHandler{ + Exec: &v1.ExecAction{Command: []string{"sh", "/opt/monitor/ready.sh"}}, + }, + InitialDelaySeconds: RaspDefaultInitialDelaySeconds, + TimeoutSeconds: RaspDefaultTimeoutSeconds, + PeriodSeconds: RaspDefaultPeriodSeconds, + SuccessThreshold: RaspDefaultSuccessThreshold, + FailureThreshold: RaspDefaultFailureThreshold, + } + LivenessProbe := v1.Probe{ + ProbeHandler: v1.ProbeHandler{ + Exec: &v1.ExecAction{Command: []string{"sh", "/opt/monitor/health.sh"}}, + }, + InitialDelaySeconds: RaspDefaultInitialDelaySeconds, + TimeoutSeconds: RaspDefaultTimeoutSeconds, + PeriodSeconds: RaspDefaultPeriodSeconds, + SuccessThreshold: RaspDefaultSuccessThreshold, + FailureThreshold: RaspDefaultFailureThreshold, + } + return types.DelegateContainerSideCarConfig{ + Name: string(containerRasp), + Image: funcSpec.ExtendedMetaData.RaspConfig.RaspImage, + Env: raspEnv, + ResourceRequirements: raspResource, + VolumeMounts: raspMount, + LivenessProbe: LivenessProbe, + ReadinessProbe: ReadinessProbe, + } +} + +func makeRaspInitContainer(funcSpec *types.FunctionSpecification) types.DelegateInitContainerConfig { + raspInitMount := []v1.VolumeMount{ + { + Name: DefaultDataVolumeName, MountPath: raspDataPath, + }, + { + Name: aiOpsVolume, MountPath: raspLogVolumeMountPath, SubPathExpr: raspLogVolumeMountSubPathExpr, + }, + } + + raspInitEnv := []v1.EnvVar{ + { + Name: podNameEnvNew, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "metadata.name", + }}, + }, + {Name: podIPEnv, ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{ + APIVersion: "v1", FieldPath: "status.podIP", + }}, + }, + } + + raspInitResource := v1.ResourceRequirements{} + raspInitResource.Requests = map[v1.ResourceName]resource.Quantity{} + raspInitResource.Requests["cpu"] = resource.MustParse(fmt.Sprintf("%dm", RaspInitDefaultCPU)) + raspInitResource.Requests["memory"] = resource.MustParse(fmt.Sprintf("%dMi", RaspInitDefaultMemory)) + raspInitResource.Limits = map[v1.ResourceName]resource.Quantity{} + raspInitResource.Limits = map[v1.ResourceName]resource.Quantity{} + raspInitResource.Limits["cpu"] = resource.MustParse(fmt.Sprintf("%dm", RaspInitDefaultCPU)) + raspInitResource.Limits["memory"] = resource.MustParse(fmt.Sprintf("%dMi", RaspInitDefaultMemory)) + + return types.DelegateInitContainerConfig{ + Name: string(containerRaspInit), + Image: funcSpec.ExtendedMetaData.RaspConfig.InitImage, + Command: []string{"sh"}, + Env: raspInitEnv, + Args: []string{"-c", "/opt/huawei/secRASP/slaveagent_entrypoint.sh"}, + VolumeMounts: raspInitMount, + ResourceRequirements: raspInitResource, + } +} diff --git a/yuanrong/pkg/functionscaler/instancepool/stateroute.go b/yuanrong/pkg/functionscaler/instancepool/stateroute.go new file mode 100644 index 0000000..d2f158a --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/stateroute.go @@ -0,0 +1,382 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package instancepool + +import ( + "fmt" + "strings" + "sync" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/state" + "yuanrong/pkg/functionscaler/stateinstance" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +// StateInstance - +type StateInstance struct { + StateID string + status int // -1 init, 0 ok, 1 instance in sub health, 2 instance exit + instance *types.Instance +} + +const ( + // InstanceExit - + InstanceExit = 2 + // InstanceAbnormal - + InstanceAbnormal = 1 + // InstanceOk - + InstanceOk = 0 +) + +// StateRoute - +type StateRoute struct { + funcSpec *types.FunctionSpecification + stateRoute map[string]*StateInstance + stateLeaseManager map[string]*stateinstance.Leaser + stateConfig commonTypes.StateConfig + resSpec *resspeckey.ResourceSpecification + deleteInstanceFunc func(instance *types.Instance) error + createInstanceFunc func(resSpec *resspeckey.ResourceSpecification, instanceType types.InstanceType, + callerPodName string) (instance *types.Instance, err error) + leaseInterval time.Duration + logger api.FormatLogger + sync.RWMutex + stateLocks sync.Map // key: stateID val: *sync.RWMutex +} + +// Destroy - +func (sr *StateRoute) Destroy() { + sr.Lock() + defer sr.Unlock() + instances := make([]string, 0) + for _, v := range sr.stateRoute { + + if v.status != InstanceExit { + instances = append(instances, fmt.Sprintf("%s:%d", v.instance.InstanceID, v.status)) + go func() { + err := sr.deleteInstanceFunc(v.instance) + if err != nil { + sr.logger.Infof("delete instance:%s failed: %s", v.instance.InstanceID, err.Error()) + } + }() + } + } + sr.stateRoute = make(map[string]*StateInstance) + sr.logger.Infof("destroy, instances: %s", strings.Join(instances, ",")) +} + +// HandleInstanceUpdate - handle etcd instance event +func (sr *StateRoute) HandleInstanceUpdate(instance *types.Instance) { + if instance == nil || instance.InstanceType != types.InstanceTypeState { + return + } + sr.Lock() + + // 通常认为实例已经添加进来了,仅考虑实例更新场景 + for stateID, stateInstance := range sr.stateRoute { + if stateInstance.instance != nil && stateInstance.instance.InstanceID == instance.InstanceID { + stateInstance.instance = instance + oldStatus := stateInstance.status + switch instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusSubHealth): + stateInstance.status = InstanceAbnormal + case int32(constant.KernelInstanceStatusRunning): + stateInstance.status = InstanceOk + default: + stateInstance.status = InstanceAbnormal + } + sr.Unlock() + sr.processStateRoute(stateInstance, stateID, stateUpdate) + sr.logger.Infof("statusInstance:%s status %d->%d", instance.InstanceID, oldStatus, stateInstance.status) + return + } + } + sr.Unlock() + sr.logger.Infof("not found stateInstance: %s, instanceStatus: %d", instance.InstanceID, + instance.InstanceStatus.Code) +} + +// GetAndDeleteState delete state and instance, return whether the state exists +func (sr *StateRoute) GetAndDeleteState(stateID string) bool { + value, _ := sr.stateLocks.LoadOrStore(stateID, &sync.RWMutex{}) + stateLock, ok := value.(*sync.RWMutex) + if !ok { + sr.logger.Errorf("get statelock err, value is %v", value) + return false + } + stateLock.Lock() + stateInstance, exist := sr.stateRoute[stateID] + sr.logger.Infof("stateRoute of %s exist is %v, stateInstance is %v", stateID, exist, stateInstance) + if exist { + sr.processStateRoute(stateInstance, stateID, stateDelete) + } + stateLock.Unlock() + + if stateInstance != nil && stateInstance.status != InstanceExit { + err := sr.deleteInstanceFunc(stateInstance.instance) + if err != nil { + sr.logger.Errorf("failed to delete stateInstance %s error %s, instance status: %d", + stateInstance.instance.InstanceID, err.Error(), stateInstance.status) + } + } + return exist +} + +// DeleteStateInstance called by ReleaseLease +func (sr *StateRoute) DeleteStateInstance(stateID string, instanceID string) { + value, _ := sr.stateLocks.LoadOrStore(stateID, &sync.RWMutex{}) + stateLock, ok := value.(*sync.RWMutex) + if !ok { + sr.logger.Errorf("get statelock err, value is %v", value) + return + } + stateLock.Lock() + stateInstance, exist := sr.getStateInstance(stateID) + if exist { + if stateInstance.instance.InstanceID == instanceID { + sr.processStateRoute(stateInstance, stateID, stateInstanceDelete) + } else { + sr.logger.Warnf("stateInstance is not matched %s:%s, stateID: %s", + stateInstance.instance.InstanceID, instanceID, stateInstance.StateID) + } + + } + stateLock.Unlock() + if stateInstance != nil { + if stateInstance.instance.InstanceID != instanceID { + sr.logger.Warnf("instanceID in lease manager and in state route are different: %s, %s, stateID: %s", + instanceID, stateInstance.instance.InstanceID, stateID) + } else { + err := sr.deleteInstanceFunc(stateInstance.instance) + if err != nil { + sr.logger.Errorf("failed to delete stateInstance %s error %s", stateInstance.instance.InstanceID, + err.Error()) + } else { + sr.logger.Infof("DeleteStateInstance state %s, instance %s over", stateID, instanceID) + } + } + } +} +func (sr *StateRoute) getStateInstance(stateID string) (*StateInstance, bool) { + sr.RLock() + defer sr.RUnlock() + stateInstance, exist := sr.stateRoute[stateID] + return stateInstance, exist +} + +func (sr *StateRoute) setStateInstance(stateID string, stateInstance *StateInstance) { + sr.Lock() + defer sr.Unlock() + sr.stateRoute[stateID] = stateInstance +} + +func (sr *StateRoute) deleteStateInstance(stateID string) { + sr.Lock() + defer sr.Unlock() + delete(sr.stateRoute, stateID) +} + +func (sr *StateRoute) getStateIDByInstanceID(instanceID string) (string, bool) { + sr.RLock() + defer sr.RUnlock() + for stateID, stateInstance := range sr.stateRoute { + if stateInstance.instance != nil && stateInstance.instance.InstanceID == instanceID { + return stateID, true + } + } + return "", false +} + +// DeleteStateInstanceByInstanceID delete state route data by instance when instance is deleted +func (sr *StateRoute) DeleteStateInstanceByInstanceID(instanceID string) { + stateID, exist := sr.getStateIDByInstanceID(instanceID) + if !exist { + sr.logger.Warnf("delete state route failed because instance %s is not in stateroute!", instanceID) + return + } + stateInstance, exist := sr.getStateInstance(stateID) + if !exist { + sr.logger.Warnf("delete state route failed because stateinstance %s is not in stateroute!, stateID: %s", + instanceID, stateID) + return + } + sr.processStateRoute(stateInstance, stateID, stateInstanceDelete) +} + +func (sr *StateRoute) processStateRoute(stateInstance *StateInstance, stateID string, opType string) { + instancePoolStateInput := &types.InstancePoolStateInput{ + StateID: stateID, + } + if stateInstance.instance != nil { + instancePoolStateInput.FuncKey = stateInstance.instance.FuncKey + instancePoolStateInput.FuncSig = stateInstance.instance.FuncSig + instancePoolStateInput.InstanceType = stateInstance.instance.InstanceType + instancePoolStateInput.InstanceID = stateInstance.instance.InstanceID + instancePoolStateInput.InstanceStatusCode = stateInstance.instance.InstanceStatus.Code + + } + stateType := types.StateUpdate + + if opType == stateUpdate { + sr.setStateInstance(stateID, stateInstance) + stateType = types.StateUpdate + } else if opType == stateDelete { + sr.deleteStateInstance(stateID) + sr.stateLocks.Delete(stateID) + stateType = types.StateDelete + } else if opType == stateInstanceDelete { + if sr.stateConfig.LifeCycle == types.InstanceLifeCycleConsistentWithState { + stateType = types.StateUpdate + newInstance := *stateInstance + newInstance.status = InstanceExit + newInstance.instance = nil + sr.setStateInstance(stateID, &newInstance) + instancePoolStateInput.InstanceStatusCode = int32(constant.KernelInstanceStatusExited) + stateType = types.StateUpdate + } else { + sr.deleteStateInstance(stateID) + stateType = types.StateDelete + } + } else { + sr.logger.Warnf("unknown stateType: %s", opType) + } + + sr.logger.Infof("state update: %+v, type %v", instancePoolStateInput, stateType) + state.Update(instancePoolStateInput, stateType) +} + +func (sr *StateRoute) recover(instanceMap map[string]*types.Instance) { + for stateID, instance := range instanceMap { + statusCode := InstanceOk + switch instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusExited): + statusCode = InstanceExit + case int32(constant.KernelInstanceStatusSubHealth): + statusCode = InstanceAbnormal + case int32(constant.KernelInstanceStatusRunning): + statusCode = InstanceOk + default: + statusCode = InstanceAbnormal + } + sr.setStateInstance(stateID, &StateInstance{ + StateID: stateID, + status: statusCode, + instance: instance, + }) + sr.logger.Infof("recover stateID:%s statusCode: %d, instanceID: %s, instanceStatusCode: %d", + stateID, statusCode, instance.InstanceID, instance.InstanceStatus.Code) + } + sr.logger.Infof("recover stateRoute over") +} + +func (sr *StateRoute) acquireStateInstanceThread(insThdApp *types.InstanceAcquireRequest) ( + *types.InstanceAllocation, snerror.SNError) { + logger := sr.logger.With(zap.Any("stateKey", insThdApp.StateID)) + value, _ := sr.stateLocks.LoadOrStore(insThdApp.StateID, &sync.RWMutex{}) + stateLock, ok := value.(*sync.RWMutex) + if !ok { + sr.logger.Errorf("get statelock err, value is %v", value) + return nil, snerror.New(statuscode.FaaSSchedulerInternalErrCode, statuscode.FaaSSchedulerInternalErrMsg) + } + stateLock.Lock() + defer stateLock.Unlock() + stateInstance, exist := sr.getStateInstance(insThdApp.StateID) + + if exist { + if stateInstance != nil && stateInstance.status == InstanceOk { + logger.Infof("state stateInstance existed in stateRoute, instanceID: %s", stateInstance.instance.InstanceID) + lease, err := sr.generateStateLease(stateInstance) + if err != nil { + return nil, err + } + return &types.InstanceAllocation{ + Instance: stateInstance.instance, + AllocationID: fmt.Sprintf("%s-stateThread%d", stateInstance.instance.InstanceID, lease.ID), + }, nil + } + if stateInstance != nil && stateInstance.status == InstanceAbnormal { + logger.Infof("state stateInstance existed in stateRoute, but abnormal, instanceID: %s", + stateInstance.instance.InstanceID) + return nil, snerror.New(statuscode.InstanceStatusAbnormalCode, statuscode.InstanceStatusAbnormalMsg) + } + // The slave function stateInstance is destroyed, no longer repulsed, and the state should be deleted + if sr.stateConfig.LifeCycle == types.InstanceLifeCycleConsistentWithState { + logger.Infof("state stateInstance is destroyed and return 4028") + return nil, snerror.New(statuscode.StateInstanceNotExistedErrCode, statuscode.StateInstanceNotExistedErrMsg) + } + } + + var resSpec *resspeckey.ResourceSpecification + if utils.IsResSpecEmpty(insThdApp.ResSpec) { + resSpec = sr.resSpec + } else { + resSpec = insThdApp.ResSpec + } + + instance, err := sr.createInstanceFunc(resSpec, types.InstanceTypeState, insThdApp.CallerPodName) + if err == nil { + stateInstance = &StateInstance{ + StateID: insThdApp.StateID, + status: InstanceOk, + instance: instance, + } + logger.Infof("new statInstance, instanceId: %s", instance.InstanceID) + sr.processStateRoute(stateInstance, insThdApp.StateID, stateUpdate) + lease, err := sr.generateStateLease(stateInstance) + if err != nil { + return nil, err + } + return &types.InstanceAllocation{ + Instance: instance, + AllocationID: fmt.Sprintf("%s-stateThread%d", stateInstance.instance.InstanceID, lease.ID), + }, nil + } + return nil, snerror.New(statuscode.NoInstanceAvailableErrCode, err.Error()) +} + +func (sr *StateRoute) generateStateLease(stateInstance *StateInstance) (*stateinstance.Lease, snerror.SNError) { + sr.Lock() + leaser := sr.stateLeaseManager[stateInstance.instance.InstanceID] + if leaser == nil { + leaser = stateinstance.NewLeaser(sr.funcSpec.InstanceMetaData.ConcurrentNum, + sr.DeleteStateInstance, stateInstance.StateID, stateInstance.instance.InstanceID, getScaleDownWindow()) + sr.stateLeaseManager[stateInstance.instance.InstanceID] = leaser + } + sr.Unlock() + lease, err := leaser.AcquireLease(sr.leaseInterval) + if err != nil { + log.GetLogger().Errorf("failed to generate state lease for instance %s error %s", + stateInstance.instance.InstanceID, err.Error()) + if snErr, ok := err.(snerror.SNError); ok { + return nil, snErr + } + return nil, snerror.New(statuscode.StatusInternalServerError, err.Error()) + } + return lease, nil +} diff --git a/yuanrong/pkg/functionscaler/instancepool/stateroute_test.go b/yuanrong/pkg/functionscaler/instancepool/stateroute_test.go new file mode 100644 index 0000000..12e009b --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancepool/stateroute_test.go @@ -0,0 +1,334 @@ +package instancepool + +import ( + "reflect" + "sync" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/statuscode" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/state" + "yuanrong/pkg/functionscaler/stateinstance" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +func mockStateRoute() *StateRoute { + stateRoute := &StateRoute{ + stateRoute: make(map[string]*StateInstance), + stateConfig: commonTypes.StateConfig{ + LifeCycle: types.InstanceLifeCycleConsistentWithState, + }, + funcSpec: &types.FunctionSpecification{ + InstanceMetaData: commonTypes.InstanceMetaData{ + ConcurrentNum: 100, + }, + }, + resSpec: &resspeckey.ResourceSpecification{ + CPU: 100, + Memory: 100, + }, + deleteInstanceFunc: func(instance *types.Instance) error { return nil }, + createInstanceFunc: func(resSpec *resspeckey.ResourceSpecification, instanceType types.InstanceType, + callerPodName string) (*types.Instance, error) { + return nil, nil + }, + stateLeaseManager: make(map[string]*stateinstance.Leaser, utils.DefaultMapSize), + logger: log.NewConsoleLogger(), + RWMutex: sync.RWMutex{}, + stateLocks: sync.Map{}, + } + return stateRoute +} + +func TestNewStateRoute(t *testing.T) { + convey.Convey("testNewStateRoute", t, func() { + convey.So(mockStateRoute(), convey.ShouldNotBeNil) + }) +} + +func TestStateRoute_Destroy(t *testing.T) { + convey.Convey("testDestroy", t, func() { + stateRoute := mockStateRoute() + stateRoute.stateRoute["1"] = &StateInstance{ + StateID: "1", + status: 0, + instance: &types.Instance{}, + } + stateRoute.stateRoute["2"] = &StateInstance{ + StateID: "2", + status: 0, + instance: &types.Instance{}, + } + stateRoute.Destroy() + convey.So(len(stateRoute.stateRoute), convey.ShouldEqual, 0) + }) +} + +func TestStateRoute_HandleInstanceUpdate(t *testing.T) { + convey.Convey("TestStateRoute_HandleInstanceUpdate", t, func() { + stateRoute := mockStateRoute() + var targetStateInstance *StateInstance + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(stateRoute), "processStateRoute", + func(_ *StateRoute, instance *StateInstance, stateID string, opType string) { + targetStateInstance = instance + }) + defer patch.Reset() + testInstance := &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{}, + InstanceType: types.InstanceTypeReserved, + InstanceID: "", + FuncKey: "", + FuncSig: "", + ConcurrentNum: 0, + } + stateRoute.HandleInstanceUpdate(testInstance) + convey.So(targetStateInstance, convey.ShouldBeNil) + + testInstance.InstanceType = types.InstanceTypeState + testInstance.InstanceID = "1" + testInstance.InstanceStatus = commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + Msg: "", + } + stateRoute.stateRoute["1"] = &StateInstance{ + StateID: "1", + status: InstanceOk, + instance: testInstance, + } + testInstance.InstanceStatus.Code = int32(constant.KernelInstanceStatusSubHealth) + stateRoute.HandleInstanceUpdate(testInstance) + convey.So(targetStateInstance, convey.ShouldNotBeNil) + convey.So(targetStateInstance.status, convey.ShouldEqual, InstanceAbnormal) + }) +} + +func TestStateRoute_GetAndDeleteState(t *testing.T) { + convey.Convey("TestStateRoute_GetAndDeleteState", t, func() { + stateRoute := mockStateRoute() + var targetStateInstance *StateInstance + patch := gomonkey.ApplyPrivateMethod(reflect.TypeOf(stateRoute), "processStateRoute", + func(_ *StateRoute, instance *StateInstance, stateID string, opType string) { + if opType == stateDelete { + delete(stateRoute.stateRoute, stateID) + } + }) + defer patch.Reset() + testInstance := &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{}, + InstanceType: types.InstanceTypeReserved, + InstanceID: "", + FuncKey: "", + FuncSig: "", + ConcurrentNum: 0, + } + stateRoute.HandleInstanceUpdate(testInstance) + convey.So(targetStateInstance, convey.ShouldBeNil) + + testInstance.InstanceType = types.InstanceTypeState + testInstance.InstanceID = "1" + testInstance.InstanceStatus = commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + Msg: "", + } + stateRoute.stateRoute["1"] = &StateInstance{ + StateID: "1", + status: InstanceOk, + instance: testInstance, + } + exist := stateRoute.GetAndDeleteState("1") + convey.So(exist, convey.ShouldEqual, true) + _, exist = stateRoute.stateRoute["1"] + convey.So(exist, convey.ShouldEqual, false) + }) +} + +func TestStateRoute_DeleteStateInstance(t *testing.T) { + convey.Convey("TestStateRoute_DeleteStateInstance", t, func() { + stateRoute := mockStateRoute() + patch := gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) {}) + defer patch.Reset() + testInstance := &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{}, + InstanceType: types.InstanceTypeReserved, + InstanceID: "1111", + FuncKey: "", + FuncSig: "", + ConcurrentNum: 0, + } + + testInstance.InstanceType = types.InstanceTypeState + testInstance.InstanceID = "1111" + testInstance.InstanceStatus = commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + Msg: "", + } + stateRoute.stateRoute["1111"] = &StateInstance{ + StateID: "1111", + status: InstanceOk, + instance: testInstance, + } + + stateRoute.DeleteStateInstance("1111", "1111") + stateInstance, ok := stateRoute.stateRoute["1111"] + convey.So(ok, convey.ShouldEqual, true) + convey.So(stateInstance.status, convey.ShouldEqual, InstanceExit) + }) +} + +func TestStateRoute_DeleteStateInstanceByInstanceID(t *testing.T) { + convey.Convey("TestStateRoute_DeleteStateInstanceByInstanceID", t, func() { + stateRoute := mockStateRoute() + patch := gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) {}) + defer patch.Reset() + testInstance := &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{}, + InstanceType: types.InstanceTypeState, + InstanceID: "3", + FuncKey: "", + FuncSig: "", + ConcurrentNum: 0, + } + + testInstance.InstanceType = types.InstanceTypeState + testInstance.InstanceID = "3" + testInstance.InstanceStatus = commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + Msg: "", + } + stateRoute.stateRoute["3"] = &StateInstance{ + StateID: "3", + status: InstanceOk, + instance: testInstance, + } + + stateRoute.DeleteStateInstanceByInstanceID("3") + stateInstance, ok := stateRoute.stateRoute["3"] + convey.So(ok, convey.ShouldEqual, true) + convey.So(stateInstance.status, convey.ShouldEqual, InstanceExit) + }) +} + +func TestStateRoute_Recover(t *testing.T) { + convey.Convey("TestStateRoute_recover", t, func() { + instanceMap := map[string]*types.Instance{ + "1": &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + InstanceType: "", + InstanceID: "1", + FuncKey: "", + FuncSig: "sig", + ConcurrentNum: 0, + }, + "2": &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusSubHealth), + }, + InstanceType: "", + InstanceID: "2", + FuncKey: "", + FuncSig: "sig", + ConcurrentNum: 0, + }, + "3": &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusExited), + }, + InstanceType: "", + InstanceID: "3", + FuncKey: "", + FuncSig: "sig", + ConcurrentNum: 0, + }, + } + stateRoute := mockStateRoute() + stateRoute.recover(instanceMap) + stateInstance, ok := stateRoute.stateRoute["1"] + convey.So(ok, convey.ShouldEqual, true) + convey.So(stateInstance.status, convey.ShouldEqual, InstanceOk) + stateInstance, ok = stateRoute.stateRoute["2"] + convey.So(ok, convey.ShouldEqual, true) + convey.So(stateInstance.status, convey.ShouldEqual, InstanceAbnormal) + stateInstance, ok = stateRoute.stateRoute["3"] + convey.So(ok, convey.ShouldEqual, true) + convey.So(stateInstance.status, convey.ShouldEqual, InstanceExit) + }) +} + +func TestStateRoute_acquireStateInstanceThread(t *testing.T) { + convey.Convey("TestStateRoute_acquireStateInstanceThread_1", t, func() { + stateRoute := mockStateRoute() + stateRoute.stateRoute["1"] = &StateInstance{ + StateID: "1", + status: InstanceAbnormal, + instance: &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceID: "1", + }, + } + thread, snerr := stateRoute.acquireStateInstanceThread(&types.InstanceAcquireRequest{ + StateID: "1", + }) + convey.So(thread, convey.ShouldBeNil) + convey.So(snerr.Code(), convey.ShouldEqual, statuscode.InstanceStatusAbnormalCode) + }) + convey.Convey("TestStateRoute_acquireStateInstanceThread_2", t, func() { + stateRoute := mockStateRoute() + patch := gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) {}) + defer patch.Reset() + stateRoute.createInstanceFunc = func(resSpec *resspeckey.ResourceSpecification, instanceType types.InstanceType, + callerPod string) (*types.Instance, error) { + return &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + InstanceID: "2", + FuncKey: "", + FuncSig: "sig", + ConcurrentNum: 0, + }, nil + } + stateRoute.stateRoute["1"] = &StateInstance{ + StateID: "1", + status: InstanceAbnormal, + instance: &types.Instance{ + ResKey: resspeckey.ResSpecKey{}, + InstanceID: "1", + }, + } + thread, _ := stateRoute.acquireStateInstanceThread(&types.InstanceAcquireRequest{ + StateID: "2", + }) + convey.So(thread, convey.ShouldNotBeNil) + convey.So(thread.Instance.InstanceID, convey.ShouldEqual, "2") + }) + convey.Convey("TestStateRoute_acquireStateInstanceThread_3", t, func() { + stateRoute := mockStateRoute() + stateRoute.stateRoute["1"] = &StateInstance{ + StateID: "1", + status: InstanceExit, + instance: nil, + } + thread, snerr := stateRoute.acquireStateInstanceThread(&types.InstanceAcquireRequest{ + StateID: "1", + }) + convey.So(thread, convey.ShouldBeNil) + convey.So(snerr.Code(), convey.ShouldEqual, statuscode.StateInstanceNotExistedErrCode) + }) +} diff --git a/yuanrong/pkg/functionscaler/instancequeue/create_request_queue.go b/yuanrong/pkg/functionscaler/instancequeue/create_request_queue.go new file mode 100644 index 0000000..e829d95 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancequeue/create_request_queue.go @@ -0,0 +1,127 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancequeue - +package instancequeue + +import ( + "container/list" + "sync" + + "yuanrong/pkg/functionscaler/scaler" +) + +// InstanceCreateRequest is the request to create instance +type InstanceCreateRequest struct { + callback scaler.ScaleUpCallback +} + +// InstanceCreateQueue stores instance create requests +type InstanceCreateQueue struct { + queue *list.List + createHead *list.Element + cancelHead *list.Element + cond *sync.Cond + stopped bool +} + +// NewInstanceCreateQueue creates new InstanceCreateQueue +func NewInstanceCreateQueue() *InstanceCreateQueue { + return &InstanceCreateQueue{ + queue: list.New(), + cond: sync.NewCond(new(sync.Mutex)), + } +} + +func (iq *InstanceCreateQueue) push(request *InstanceCreateRequest) { + iq.cond.L.Lock() + newElem := iq.queue.PushBack(request) + if iq.createHead == nil { + iq.createHead = newElem + } + iq.cancelHead = newElem + iq.cond.L.Unlock() + iq.cond.Signal() +} + +func (iq *InstanceCreateQueue) getForCreate() *InstanceCreateRequest { + iq.cond.L.Lock() + if iq.createHead == nil { + iq.cond.Wait() + } + if iq.stopped { + iq.cond.L.Unlock() + return nil + } + curElem := iq.createHead + if curElem == nil { + iq.cond.L.Unlock() + return nil + } + if curElem == iq.createHead { + iq.createHead = iq.createHead.Next() + } + if curElem == iq.cancelHead { + iq.cancelHead = iq.cancelHead.Prev() + } + iq.queue.Remove(curElem) + iq.cond.L.Unlock() + request, ok := curElem.Value.(*InstanceCreateRequest) + if !ok { + return nil + } + return request +} + +func (iq *InstanceCreateQueue) getForCancel() *InstanceCreateRequest { + iq.cond.L.Lock() + if iq.stopped { + iq.cond.L.Unlock() + return nil + } + curElem := iq.cancelHead + if curElem == nil { + iq.cond.L.Unlock() + return nil + } + if curElem == iq.createHead { + iq.createHead = iq.createHead.Next() + } + if curElem == iq.cancelHead { + iq.cancelHead = iq.cancelHead.Prev() + } + iq.queue.Remove(curElem) + iq.cond.L.Unlock() + request, ok := curElem.Value.(*InstanceCreateRequest) + if !ok { + return nil + } + return request +} + +func (iq *InstanceCreateQueue) getQueLen() int { + iq.cond.L.Lock() + queLen := iq.queue.Len() + iq.cond.L.Unlock() + return queLen +} + +func (iq *InstanceCreateQueue) destroy() { + iq.cond.L.Lock() + iq.stopped = true + iq.cond.L.Unlock() + iq.cond.Broadcast() +} diff --git a/yuanrong/pkg/functionscaler/instancequeue/create_request_queue_test.go b/yuanrong/pkg/functionscaler/instancequeue/create_request_queue_test.go new file mode 100644 index 0000000..427134c --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancequeue/create_request_queue_test.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancequeue - +package instancequeue + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInstanceCreateQueue(t *testing.T) { + queue := NewInstanceCreateQueue() + assert.NotNil(t, queue) + queue.push(&InstanceCreateRequest{}) + req1 := queue.getForCreate() + assert.NotNil(t, req1) + req2 := queue.getForCancel() + assert.Nil(t, req2) + queue.push(&InstanceCreateRequest{}) + req3 := queue.getForCancel() + assert.NotNil(t, req3) + req4 := queue.getForCancel() + assert.Nil(t, req4) + queue.push(&InstanceCreateRequest{}) + assert.Equal(t, 1, queue.getQueLen()) +} diff --git a/yuanrong/pkg/functionscaler/instancequeue/instance_queue.go b/yuanrong/pkg/functionscaler/instancequeue/instance_queue.go new file mode 100644 index 0000000..ddcb541 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancequeue/instance_queue.go @@ -0,0 +1,97 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancequeue - +package instancequeue + +import ( + "errors" + + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/types" +) + +var ( + // ErrInsNotExist is the error of instance does not exist + ErrInsNotExist = errors.New("instance does not exist in queue") + // ErrInsSubHealth is the error of instance is not normal + ErrInsSubHealth = errors.New("instance is subHealth") + // ErrNoInsThdAvailable is the error of no instance available + ErrNoInsThdAvailable = errors.New("no instance thread available now") + // ErrFuncSigMismatch is the error of "function signature mismatch" + ErrFuncSigMismatch = errors.New("function signature mismatch") + // ErrFunctionDeleted is the error of "function is deleted" + ErrFunctionDeleted = errors.New("function is deleted") +) + +// InsQueConfig - +type InsQueConfig struct { + FuncSpec *types.FunctionSpecification + InsThdReqQueue *requestqueue.InsAcqReqQueue + InstanceType types.InstanceType + ResKey resspeckey.ResSpecKey + MetricsCollector metrics.Collector + CreateInstanceFunc + DeleteInstanceFunc + SignalInstanceFunc +} + +// InstanceOperationFunc contains functions which operates instance +type InstanceOperationFunc struct{} + +// CreateInstanceFunc - +type CreateInstanceFunc func(string, types.InstanceType, resspeckey.ResSpecKey, []byte) (*types.Instance, error) + +// DeleteInstanceFunc - +type DeleteInstanceFunc func(*types.Instance) error + +// SignalInstanceFunc - +type SignalInstanceFunc func(*types.Instance, int) + +// InstanceQueue stores instances +type InstanceQueue interface { + AcquireInstance(insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, snerror.SNError) + HandleInstanceUpdate(instance *types.Instance) + HandleInstanceDelete(instance *types.Instance) + HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) + Destroy() + GetInstanceNumber(onlySelf bool) int +} + +func buildSnError(err error) snerror.SNError { + if snErr, ok := err.(snerror.SNError); ok { + return snErr + } + switch err { + case nil: + return nil + case scheduler.ErrNoInsAvailable: + return snerror.New(statuscode.NoInstanceAvailableErrCode, err.Error()) + case scheduler.ErrInsNotExist, scheduler.ErrInsSubHealthy: + return snerror.New(statuscode.InstanceNotFoundErrCode, err.Error()) + case scheduler.ErrInsReqTimeout: + return snerror.New(statuscode.InsThdReqTimeoutCode, err.Error()) + case scheduler.ErrInvalidSession: + return snerror.New(statuscode.InstanceSessionInvalidErrCode, err.Error()) + default: + return snerror.New(statuscode.StatusInternalServerError, err.Error()) + } +} diff --git a/yuanrong/pkg/functionscaler/instancequeue/instance_queue_builder.go b/yuanrong/pkg/functionscaler/instancequeue/instance_queue_builder.go new file mode 100644 index 0000000..522cd4f --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancequeue/instance_queue_builder.go @@ -0,0 +1,246 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancequeue - +package instancequeue + +import ( + "errors" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/scaler" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler" + "yuanrong/pkg/functionscaler/scheduler/microservicescheduler" + "yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +var ( + // ErrUnsupportedInstanceType is the error of unsupported instance type + ErrUnsupportedInstanceType = errors.New("unsupported instance type") +) + +// BuildInstanceQueue builds an instanceQueue +func BuildInstanceQueue(config *InsQueConfig, insAcqReqQue *requestqueue.InsAcqReqQueue, + metricsCollector metrics.Collector) (InstanceQueue, error) { + if config.InstanceType == types.InstanceTypeOnDemand { + return NewOnDemandInstanceQueue(config), nil + } + instanceQueue := NewScaledInstanceQueue(config, metricsCollector) + var err error + err = AssembleScheduler(config.FuncSpec.InstanceMetaData.SchedulePolicy, instanceQueue, insAcqReqQue) + if err != nil { + log.GetLogger().Errorf("failed to assemble instanceScheduler for function %s", config.FuncSpec.FuncKey) + return nil, err + } + switch config.FuncSpec.InstanceMetaData.ScalePolicy { + case types.InstanceScalePolicyPredict: + err = assembleScalerWithPredictPolicy(instanceQueue) + case types.InstanceScalePolicyStaticFunction: + err = assembleScalerWithStaticPolicy(instanceQueue, insAcqReqQue) + default: + err = assembleScalerWithConcurrencyPolicy(instanceQueue) + } + if err != nil { + log.GetLogger().Errorf("failed to assemble instanceScaler for function %s err %v", + config.FuncSpec.FuncKey, err) + return nil, err + } + if instanceQueue.isFuncOwner { + log.GetLogger().Infof("instance queue is funcOwner of function %s", config.FuncSpec.FuncKey) + instanceQueue.instanceScheduler.HandleFuncOwnerUpdate(true) + } + return instanceQueue, nil +} + +// AssembleScheduler assemble scheduler to queue +func AssembleScheduler(schedulePolicy string, instanceQueue *ScaledInstanceQueue, + insAcqReqQue *requestqueue.InsAcqReqQueue) error { + var err error + switch schedulePolicy { + case types.InstanceSchedulePolicyConcurrency: + err = assembleSchedulerWithConcurrencyPolicy(instanceQueue, insAcqReqQue) + case types.InstanceSchedulePolicyRoundRobin: + err = assembleSchedulerWithRoundRobinPolicy(instanceQueue, insAcqReqQue) + case types.InstanceSchedulePolicyMicroService: + err = assembleSchedulerWithMicroService(instanceQueue) + default: + err = assembleSchedulerWithConcurrencyPolicy(instanceQueue, insAcqReqQue) + } + if err != nil { + return err + } + return nil +} + +func assembleSchedulerWithConcurrencyPolicy(instanceQueue *ScaledInstanceQueue, + insThdReqQueue *requestqueue.InsAcqReqQueue) error { + requestTimeout := utils.GetRequestTimeout(instanceQueue.funcSpec) + var instanceScheduler scheduler.InstanceScheduler + if instanceQueue.instanceType == types.InstanceTypeReserved { + instanceScheduler = concurrencyscheduler.NewReservedConcurrencyScheduler(instanceQueue.funcSpec, + instanceQueue.resKey, requestTimeout, insThdReqQueue) + log.GetLogger().Infof("assemble with concurrency scheduler for %s instance queue of function %s", + instanceQueue.instanceType, instanceQueue.funcKeyWithRes) + } else if instanceQueue.instanceType == types.InstanceTypeScaled { + instanceScheduler = concurrencyscheduler.NewScaledConcurrencyScheduler(instanceQueue.funcSpec, + instanceQueue.resKey, insThdReqQueue) + log.GetLogger().Infof("assemble with concurrency scheduler for %s instance queue of function %s", + instanceQueue.instanceType, instanceQueue.funcKeyWithRes) + } else { + log.GetLogger().Errorf("unsupported instance type %s", instanceQueue.instanceType) + return ErrUnsupportedInstanceType + } + instanceQueue.SetInstanceScheduler(instanceScheduler) + return nil +} + +func assembleSchedulerWithRoundRobinPolicy(instanceQueue *ScaledInstanceQueue, + insThdReqQueue *requestqueue.InsAcqReqQueue) error { + requestTimeout := utils.GetRequestTimeout(instanceQueue.funcSpec) + var instanceScheduler scheduler.InstanceScheduler + if instanceQueue.instanceType == types.InstanceTypeReserved { + instanceScheduler = roundrobinscheduler.NewRoundRobinScheduler(instanceQueue.funcKeyWithRes, true, requestTimeout) + log.GetLogger().Infof("assemble with round-robin scheduler for %s instance queue of function %s", + instanceQueue.instanceType, instanceQueue.funcKeyWithRes) + } else if instanceQueue.instanceType == types.InstanceTypeScaled { + instanceScheduler = roundrobinscheduler.NewRoundRobinScheduler(instanceQueue.funcKeyWithRes, false, requestTimeout) + log.GetLogger().Infof("assemble with round-robin scheduler for %s instance queue of function %s", + instanceQueue.instanceType, instanceQueue.funcKeyWithRes) + } else { + log.GetLogger().Errorf("unsupported instance type %s", instanceQueue.instanceType) + return ErrUnsupportedInstanceType + } + instanceQueue.SetInstanceScheduler(instanceScheduler) + return nil + +} + +func assembleSchedulerWithMicroService(instanceQueue *ScaledInstanceQueue) error { + var instanceScheduler scheduler.InstanceScheduler + if instanceQueue.instanceType == types.InstanceTypeReserved { + instanceScheduler = microservicescheduler.NewMicroServiceScheduler(instanceQueue.funcKeyWithRes, + config.GlobalConfig.MicroServiceSchedulingPolicy) + log.GetLogger().Infof("assemble with microService for %s instance queue of function %s", + instanceQueue.instanceType, instanceQueue.funcKeyWithRes) + } else if instanceQueue.instanceType == types.InstanceTypeScaled { + instanceScheduler = microservicescheduler.NewMicroServiceScheduler(instanceQueue.funcKeyWithRes, + config.GlobalConfig.MicroServiceSchedulingPolicy) + log.GetLogger().Infof("assemble with microService for %s instance queue of function %s", + instanceQueue.instanceType, instanceQueue.funcKeyWithRes) + } else { + log.GetLogger().Errorf("unsupported instance type %s", instanceQueue.instanceType) + return ErrUnsupportedInstanceType + } + instanceQueue.SetInstanceScheduler(instanceScheduler) + return nil +} + +func assembleScalerWithStaticPolicy(instanceQueue *ScaledInstanceQueue, + insAcqReqQue *requestqueue.InsAcqReqQueue) error { + if instanceQueue.instanceScheduler == nil { + return errors.New("missing instanceScheduler in instanceQueue") + } + var instanceScaler scaler.InstanceScaler + // only reserve queue and func owner can trigger nuwa cold start + if instanceQueue.instanceType == types.InstanceTypeReserved { + instanceScaler = scaler.NewWiseCloudScaler(instanceQueue.funcKeyWithRes, instanceQueue.resKey, true, + insAcqReqQue.HandleCreateError) + } else if instanceQueue.instanceType == types.InstanceTypeScaled { + instanceScaler = scaler.NewWiseCloudScaler(instanceQueue.funcKeyWithRes, instanceQueue.resKey, false, + insAcqReqQue.HandleCreateError) + } else { + log.GetLogger().Errorf("unsupported instance type %s", instanceQueue.instanceType) + return ErrUnsupportedInstanceType + } + // set concurrentNum for first time + instanceScaler.HandleFuncSpecUpdate(instanceQueue.funcSpec) + instanceQueue.instanceScheduler.ConnectWithInstanceScaler(instanceScaler) + instanceQueue.SetInstanceScaler(instanceScaler) + return nil +} + +func assembleScalerWithConcurrencyPolicy(instanceQueue *ScaledInstanceQueue) error { + if instanceQueue.instanceScheduler == nil { + return errors.New("missing instanceScheduler in instanceQueue") + } + var instanceScaler scaler.InstanceScaler + if instanceQueue.instanceType == types.InstanceTypeReserved { + instanceScaler = scaler.NewReplicaScaler(instanceQueue.funcKeyWithRes, instanceQueue.metricsCollector, + instanceQueue.ScaleUpHandler, instanceQueue.ScaleDownHandler) + log.GetLogger().Infof("assemble with replica scaler for %s instance queue of function %s", + instanceQueue.instanceType, instanceQueue.funcKeyWithRes) + } else if instanceQueue.instanceType == types.InstanceTypeScaled { + instanceScheduler, ok := instanceQueue.instanceScheduler.(*concurrencyscheduler.ScaledConcurrencyScheduler) + if ok { + instanceScaler = scaler.NewAutoScaler(instanceQueue.funcKeyWithRes, instanceQueue.metricsCollector, + instanceScheduler.GetReqQueLen, instanceQueue.ScaleUpHandler, instanceQueue.ScaleDownHandler) + } else { + log.GetLogger().Warnf("missing concurrencyScheduler when build concurrencyScaler for scaled instance") + instanceScaler = scaler.NewAutoScaler(instanceQueue.funcKeyWithRes, instanceQueue.metricsCollector, + func() int { return 0 }, instanceQueue.ScaleUpHandler, instanceQueue.ScaleDownHandler) + } + log.GetLogger().Infof("assemble with auto scaler for %s instance queue of function %s", + instanceQueue.instanceType, instanceQueue.funcKeyWithRes) + } else { + log.GetLogger().Errorf("unsupported instance type %s", instanceQueue.instanceType) + return ErrUnsupportedInstanceType + } + // set concurrentNum for first time + instanceScaler.HandleFuncSpecUpdate(instanceQueue.funcSpec) + instanceQueue.instanceScheduler.ConnectWithInstanceScaler(instanceScaler) + instanceQueue.SetInstanceScaler(instanceScaler) + return nil +} + +func assembleScalerWithPredictPolicy(instanceQueue *ScaledInstanceQueue) error { + if instanceQueue.instanceScheduler == nil { + return errors.New("missing instanceScheduler in instanceQueue") + } + var instanceScaler scaler.InstanceScaler + switch instanceQueue.instanceType { + case types.InstanceTypeReserved: + instanceScaler = scaler.NewReplicaScaler(instanceQueue.funcKeyWithRes, instanceQueue.metricsCollector, + instanceQueue.ScaleUpHandler, instanceQueue.ScaleDownHandler) + log.GetLogger().Infof("assemble with replica scaler for %s instance queue of function %s", + instanceQueue.instanceType, instanceQueue.funcKeyWithRes) + case types.InstanceTypeScaled: + instanceScheduler, ok := instanceQueue.instanceScheduler.(*concurrencyscheduler.ScaledConcurrencyScheduler) + if ok { + instanceScaler = scaler.NewPredictScaler(instanceQueue.funcKeyWithRes, instanceQueue.metricsCollector, + instanceScheduler.GetReqQueLen, instanceQueue.ScaleUpHandler, instanceQueue.ScaleDownHandler) + } else { + log.GetLogger().Warnf("missing concurrencyScheduler when build concurrencyScaler for scaled instance") + instanceScaler = scaler.NewPredictScaler(instanceQueue.funcKeyWithRes, instanceQueue.metricsCollector, + func() int { return 0 }, instanceQueue.ScaleUpHandler, instanceQueue.ScaleDownHandler) + } + log.GetLogger().Infof("assemble with predict scaler for %s instance queue of function %s", + instanceQueue.instanceType, instanceQueue.funcKeyWithRes) + default: + log.GetLogger().Errorf("unsupported instance type %s", instanceQueue.instanceType) + return ErrUnsupportedInstanceType + } + // set concurrentNum for first time + instanceScaler.HandleFuncSpecUpdate(instanceQueue.funcSpec) + instanceQueue.instanceScheduler.ConnectWithInstanceScaler(instanceScaler) + instanceQueue.SetInstanceScaler(instanceScaler) + return nil +} diff --git a/yuanrong/pkg/functionscaler/instancequeue/instance_queue_builder_test.go b/yuanrong/pkg/functionscaler/instancequeue/instance_queue_builder_test.go new file mode 100644 index 0000000..f84df99 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancequeue/instance_queue_builder_test.go @@ -0,0 +1,204 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancequeue - +package instancequeue + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/resspeckey" + commontypes "yuanrong/pkg/common/faas_common/types" + wisecloudTypes "yuanrong/pkg/common/faas_common/wisecloudtool/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/scaler" + "yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler" + "yuanrong/pkg/functionscaler/scheduler/microservicescheduler" + "yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler" + "yuanrong/pkg/functionscaler/types" +) + +var testFuncSpec = &types.FunctionSpecification{ + InstanceMetaData: commontypes.InstanceMetaData{ + ConcurrentNum: 100, + }, +} + +func TestBuildInstanceQueue(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{}, + ResKey: resspeckey.ResSpecKey{}, + MetricsCollector: &metrics.BucketCollector{}, + InsThdReqQueue: requestqueue.NewInsAcqReqQueue("", 10), + } + insAcqReqQue := &requestqueue.InsAcqReqQueue{} + metricsCollector := &metrics.BucketCollector{} + insQue, err := BuildInstanceQueue(basicInsQueConfig, insAcqReqQue, metricsCollector) + assert.Nil(t, err) + typedInsQue, ok := insQue.(*ScaledInstanceQueue) + assert.Equal(t, true, ok) + assert.IsType(t, &concurrencyscheduler.ScaledConcurrencyScheduler{}, typedInsQue.instanceScheduler) + + basicInsQueConfig = &InsQueConfig{ + InstanceType: types.InstanceTypeReserved, + FuncSpec: &types.FunctionSpecification{}, + ResKey: resspeckey.ResSpecKey{}, + MetricsCollector: &metrics.BucketCollector{}, + } + insQue, err = BuildInstanceQueue(basicInsQueConfig, insAcqReqQue, metricsCollector) + assert.Nil(t, err) + typedInsQue, ok = insQue.(*ScaledInstanceQueue) + assert.Equal(t, true, ok) + assert.IsType(t, &concurrencyscheduler.ReservedConcurrencyScheduler{}, typedInsQue.instanceScheduler) + + basicInsQueConfig = &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + InstanceMetaData: commontypes.InstanceMetaData{ + SchedulePolicy: types.InstanceSchedulePolicyRoundRobin, + }}, + ResKey: resspeckey.ResSpecKey{}, + MetricsCollector: &metrics.BucketCollector{}, + } + insQue, err = BuildInstanceQueue(basicInsQueConfig, insAcqReqQue, metricsCollector) + assert.Nil(t, err) + typedInsQue, ok = insQue.(*ScaledInstanceQueue) + assert.Equal(t, true, ok) + assert.IsType(t, &roundrobinscheduler.RoundRobinScheduler{}, typedInsQue.instanceScheduler) + + basicInsQueConfig = &InsQueConfig{ + InstanceType: types.InstanceTypeReserved, + FuncSpec: &types.FunctionSpecification{ + InstanceMetaData: commontypes.InstanceMetaData{ + SchedulePolicy: types.InstanceSchedulePolicyRoundRobin, + }}, + ResKey: resspeckey.ResSpecKey{}, + MetricsCollector: &metrics.BucketCollector{}, + } + insQue, err = BuildInstanceQueue(basicInsQueConfig, insAcqReqQue, metricsCollector) + assert.Nil(t, err) + typedInsQue, ok = insQue.(*ScaledInstanceQueue) + assert.Equal(t, true, ok) + assert.IsType(t, &roundrobinscheduler.RoundRobinScheduler{}, typedInsQue.instanceScheduler) + + basicInsQueConfig = &InsQueConfig{ + InstanceType: types.InstanceTypeReserved, + FuncSpec: &types.FunctionSpecification{ + InstanceMetaData: commontypes.InstanceMetaData{ + SchedulePolicy: types.InstanceSchedulePolicyRoundRobin, + ScalePolicy: types.InstanceScalePolicyPredict, + }}, + ResKey: resspeckey.ResSpecKey{}, + MetricsCollector: &metrics.BucketCollector{}, + } + insQue, err = BuildInstanceQueue(basicInsQueConfig, insAcqReqQue, metricsCollector) + assert.Nil(t, err) + typedInsQue, ok = insQue.(*ScaledInstanceQueue) + assert.Equal(t, true, ok) + assert.IsType(t, &roundrobinscheduler.RoundRobinScheduler{}, typedInsQue.instanceScheduler) + assert.IsType(t, &scaler.ReplicaScaler{}, typedInsQue.instanceScaler) + + basicInsQueConfig = &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + InstanceMetaData: commontypes.InstanceMetaData{ + SchedulePolicy: types.InstanceSchedulePolicyRoundRobin, + ScalePolicy: types.InstanceScalePolicyPredict, + }}, + ResKey: resspeckey.ResSpecKey{}, + MetricsCollector: &metrics.BucketCollector{}, + } + insQue, err = BuildInstanceQueue(basicInsQueConfig, insAcqReqQue, metricsCollector) + assert.Nil(t, err) + typedInsQue, ok = insQue.(*ScaledInstanceQueue) + assert.Equal(t, true, ok) + assert.IsType(t, &roundrobinscheduler.RoundRobinScheduler{}, typedInsQue.instanceScheduler) + assert.IsType(t, &scaler.PredictScaler{}, typedInsQue.instanceScaler) + + basicInsQueConfig = &InsQueConfig{ + InstanceType: types.InstanceTypeReserved, + FuncSpec: &types.FunctionSpecification{ + InstanceMetaData: commontypes.InstanceMetaData{ + SchedulePolicy: types.InstanceSchedulePolicyMicroService, + ScalePolicy: types.InstanceScalePolicyConcurrency, + }}, + ResKey: resspeckey.ResSpecKey{}, + MetricsCollector: &metrics.BucketCollector{}, + } + insQue, err = BuildInstanceQueue(basicInsQueConfig, insAcqReqQue, metricsCollector) + assert.Nil(t, err) + typedInsQue, ok = insQue.(*ScaledInstanceQueue) + assert.Equal(t, true, ok) + assert.IsType(t, µservicescheduler.MicroServiceScheduler{}, typedInsQue.instanceScheduler) + + basicInsQueConfig = &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + InstanceMetaData: commontypes.InstanceMetaData{ + SchedulePolicy: types.InstanceSchedulePolicyMicroService, + ScalePolicy: types.InstanceScalePolicyConcurrency, + }}, + ResKey: resspeckey.ResSpecKey{}, + MetricsCollector: &metrics.BucketCollector{}, + } + insQue, err = BuildInstanceQueue(basicInsQueConfig, insAcqReqQue, metricsCollector) + assert.Nil(t, err) + typedInsQue, ok = insQue.(*ScaledInstanceQueue) + assert.Equal(t, true, ok) + assert.IsType(t, µservicescheduler.MicroServiceScheduler{}, typedInsQue.instanceScheduler) +} + +func TestAssembleWithConcurrencyScaler(t *testing.T) { + err := assembleScalerWithConcurrencyPolicy(&ScaledInstanceQueue{}) + assert.Equal(t, "missing instanceScheduler in instanceQueue", err.Error()) + err = assembleScalerWithConcurrencyPolicy(&ScaledInstanceQueue{instanceScheduler: &fakeInstanceScheduler{}}) + assert.Contains(t, "unsupported instance type", err.Error()) + err = assembleScalerWithConcurrencyPolicy(&ScaledInstanceQueue{funcSpec: testFuncSpec, + instanceType: types.InstanceTypeScaled, instanceScheduler: &fakeInstanceScheduler{}}) + assert.Equal(t, nil, err) +} + +func TestAssembleScalerWithPredictPolicy(t *testing.T) { + err := assembleScalerWithPredictPolicy(&ScaledInstanceQueue{}) + assert.Equal(t, "missing instanceScheduler in instanceQueue", err.Error()) + err = assembleScalerWithPredictPolicy(&ScaledInstanceQueue{instanceScheduler: &fakeInstanceScheduler{}}) + assert.Contains(t, "unsupported instance type", err.Error()) + err = assembleScalerWithPredictPolicy(&ScaledInstanceQueue{funcSpec: testFuncSpec, + instanceType: types.InstanceTypeScaled, instanceScheduler: &fakeInstanceScheduler{}}) + assert.Equal(t, nil, err) +} + +func TestAssembleScalerWithStaticPolicy(t *testing.T) { + config.GlobalConfig.ServiceAccountJwt = wisecloudTypes.ServiceAccountJwt{ + ServiceAccount: &wisecloudTypes.ServiceAccount{}, + TlsConfig: &wisecloudTypes.TLSConfig{}, + } + err := assembleScalerWithStaticPolicy(&ScaledInstanceQueue{}, requestqueue.NewInsAcqReqQueue("", 10)) + assert.Equal(t, "missing instanceScheduler in instanceQueue", err.Error()) + err = assembleScalerWithStaticPolicy(&ScaledInstanceQueue{instanceScheduler: &fakeInstanceScheduler{}}, requestqueue.NewInsAcqReqQueue("", 10)) + assert.Contains(t, "unsupported instance type", err.Error()) + err = assembleScalerWithStaticPolicy(&ScaledInstanceQueue{funcSpec: testFuncSpec, + instanceType: types.InstanceTypeScaled, instanceScheduler: &fakeInstanceScheduler{}}, requestqueue.NewInsAcqReqQueue("", 10)) + assert.Equal(t, nil, err) + err = assembleScalerWithStaticPolicy(&ScaledInstanceQueue{funcSpec: testFuncSpec, + instanceType: types.InstanceTypeReserved, instanceScheduler: &fakeInstanceScheduler{}}, requestqueue.NewInsAcqReqQueue("", 10)) + assert.Equal(t, nil, err) +} diff --git a/yuanrong/pkg/functionscaler/instancequeue/instance_queue_test.go b/yuanrong/pkg/functionscaler/instancequeue/instance_queue_test.go new file mode 100644 index 0000000..13cced7 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancequeue/instance_queue_test.go @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancequeue - +package instancequeue + +import ( + "errors" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/statuscode" + commontypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/selfregister" +) + +func TestMain(m *testing.M) { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc((*etcd3.EtcdWatcher).StartList, func(_ *etcd3.EtcdWatcher) {}), + gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + gomonkey.ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + gomonkey.ApplyFunc(etcd3.GetCAEMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + gomonkey.ApplyFunc((*registry.FaasSchedulerRegistry).WaitForETCDList, func() {}), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + _ = registry.InitRegistry(make(chan struct{})) + registry.GlobalRegistry.FaaSSchedulerRegistry = registry.NewFaasSchedulerRegistry(make(chan struct{})) + selfregister.SelfInstanceID = "schedulerID-1" + selfregister.GlobalSchedulerProxy.Add(&commontypes.InstanceInfo{ + TenantID: "123456789", + FunctionName: "faasscheduler", + Version: "lastest", + InstanceName: "schedulerID-1", + }, "") + m.Run() +} + +func TestBuildSnError(t *testing.T) { + assert.Equal(t, statuscode.NoInstanceAvailableErrCode, buildSnError(scheduler.ErrNoInsAvailable).Code()) + assert.Equal(t, statuscode.InstanceNotFoundErrCode, buildSnError(scheduler.ErrInsNotExist).Code()) + assert.Equal(t, statuscode.InsThdReqTimeoutCode, buildSnError(scheduler.ErrInsReqTimeout).Code()) + assert.Equal(t, statuscode.StatusInternalServerError, buildSnError(errors.New("some error")).Code()) +} diff --git a/yuanrong/pkg/functionscaler/instancequeue/ondemand_instance_queue.go b/yuanrong/pkg/functionscaler/instancequeue/ondemand_instance_queue.go new file mode 100644 index 0000000..7b247b6 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancequeue/ondemand_instance_queue.go @@ -0,0 +1,200 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancequeue - +package instancequeue + +import ( + "context" + "fmt" + "sync" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +// OnDemandInstanceQueue - +type OnDemandInstanceQueue struct { + funcSpec *types.FunctionSpecification + instanceMap map[string]*types.Instance + insNameMap map[string]*types.Instance + instanceType types.InstanceType + resKey resspeckey.ResSpecKey + funcCtx context.Context + funcKey string + funcSig string + funcKeyWithRes string + createInstanceFunc CreateInstanceFunc + deleteInstanceFunc DeleteInstanceFunc + signalInstanceFunc SignalInstanceFunc + sync.RWMutex +} + +// NewOnDemandInstanceQueue - +func NewOnDemandInstanceQueue(config *InsQueConfig) *OnDemandInstanceQueue { + return &OnDemandInstanceQueue{ + funcSpec: config.FuncSpec, + resKey: config.ResKey, + instanceMap: make(map[string]*types.Instance, utils.DefaultMapSize), + insNameMap: make(map[string]*types.Instance, utils.DefaultMapSize), + instanceType: config.InstanceType, + funcCtx: config.FuncSpec.FuncCtx, + funcKey: config.FuncSpec.FuncKey, + funcSig: config.FuncSpec.FuncMetaSignature, + funcKeyWithRes: fmt.Sprintf("%s-%s", config.FuncSpec.FuncKey, config.ResKey.String()), + createInstanceFunc: config.CreateInstanceFunc, + deleteInstanceFunc: config.DeleteInstanceFunc, + signalInstanceFunc: config.SignalInstanceFunc, + } +} + +// AcquireInstance will acquire an instance +func (oi *OnDemandInstanceQueue) AcquireInstance(insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, + snerror.SNError) { + select { + case <-oi.funcCtx.Done(): + log.GetLogger().Errorf("function is deleted, can not acquire instance InsAlloc") + return nil, snerror.New(statuscode.FuncMetaNotFoundErrCode, statuscode.FuncMetaNotFoundErrMsg) + default: + } + var instance *types.Instance + instanceByName := oi.insNameMap[insAcqReq.InstanceName] + instanceByID := oi.instanceMap[insAcqReq.DesignateInstanceID] + if instanceByName != nil { + instance = instanceByName + } else if instanceByID != nil { + instance = instanceByID + } else { + return nil, snerror.New(statuscode.InstanceNotFoundErrCode, statuscode.InstanceNotFoundErrMsg) + } + return &types.InstanceAllocation{ + AllocationID: instance.InstanceID, + Instance: instance, + }, nil +} + +// CreateInstance - +func (oi *OnDemandInstanceQueue) CreateInstance(insCrtReq *types.InstanceCreateRequest) (*types.Instance, + snerror.SNError) { + var ( + instance *types.Instance + createErr error + ) + oi.RLock() + functionSignature := oi.funcSig + oi.RUnlock() + instance, createErr = oi.createInstanceFunc(insCrtReq.InstanceName, oi.instanceType, oi.resKey, + insCrtReq.CreateEvent) + if createErr != nil { + log.GetLogger().Errorf("failed to create instance for function %s error %s", oi.funcKeyWithRes, + createErr.Error()) + } + select { + case _, ok := <-oi.funcCtx.Done(): + if !ok { + log.GetLogger().Warnf("function %s is deleted, killing instance now", oi.funcKey) + createErr = ErrFunctionDeleted + } + default: + // in case of function signature change during instance creating + oi.RLock() + checkFunctionSignature := oi.funcSig + oi.RUnlock() + if functionSignature != checkFunctionSignature { + log.GetLogger().Errorf("function signature changes while creating instance for function %s, "+ + "killing instance now", oi.funcKeyWithRes) + createErr = ErrFuncSigMismatch + } + } + if createErr != nil { + if instance != nil { + log.GetLogger().Warnf("killing failed created instance %s for function %s", instance.InstanceID, + oi.funcKeyWithRes) + go oi.DeleteInstance(instance) + } + return nil, buildSnError(createErr) + } + oi.Lock() + oi.instanceMap[instance.InstanceID] = instance + if len(insCrtReq.InstanceName) != 0 { + oi.insNameMap[insCrtReq.InstanceName] = instance + } + oi.Unlock() + return instance, nil +} + +// DeleteInstance - +func (oi *OnDemandInstanceQueue) DeleteInstance(instance *types.Instance) snerror.SNError { + oi.Lock() + delete(oi.instanceMap, instance.InstanceID) + if len(instance.InstanceName) != 0 { + delete(oi.insNameMap, instance.InstanceName) + } + oi.Unlock() + return buildSnError(oi.deleteInstanceFunc(instance)) +} + +// HandleInstanceUpdate - +func (oi *OnDemandInstanceQueue) HandleInstanceUpdate(instance *types.Instance) { + log.GetLogger().Infof("handling instance update of function %s instanceID %s instanceName %s", oi.funcKeyWithRes, + instance.InstanceID, oi.funcKeyWithRes) +} + +// HandleInstanceDelete - +func (oi *OnDemandInstanceQueue) HandleInstanceDelete(instance *types.Instance) { + log.GetLogger().Infof("handling instance delete of function %s instanceID %s instanceName %s", oi.funcKeyWithRes, + instance.InstanceID, oi.funcKeyWithRes) + oi.Lock() + delete(oi.instanceMap, instance.InstanceID) + delete(oi.insNameMap, instance.InstanceName) + oi.Unlock() +} + +// HandleFuncSpecUpdate - +func (oi *OnDemandInstanceQueue) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { + log.GetLogger().Infof("handling funcSpec update of function %s", oi.funcKeyWithRes) + if oi.funcSig != funcSpec.FuncMetaSignature { + log.GetLogger().Warnf("function %s signature changes from %s to %s", oi.funcKeyWithRes, oi.funcSig, + funcSpec.FuncMetaSignature) + deleteList := make([]*types.Instance, len(oi.insNameMap)) + oi.Lock() + for _, instance := range oi.instanceMap { + deleteList = append(deleteList, instance) + } + oi.instanceMap = make(map[string]*types.Instance, utils.DefaultMapSize) + oi.insNameMap = make(map[string]*types.Instance, utils.DefaultMapSize) + oi.Unlock() + for _, instance := range deleteList { + go oi.deleteInstanceFunc(instance) + } + } +} + +// Destroy - +func (oi *OnDemandInstanceQueue) Destroy() { +} + +// GetInstanceNumber - +func (oi *OnDemandInstanceQueue) GetInstanceNumber(onlySelf bool) int { + oi.RLock() + insNum := len(oi.insNameMap) + oi.RUnlock() + return insNum +} diff --git a/yuanrong/pkg/functionscaler/instancequeue/ondemand_instance_queue_test.go b/yuanrong/pkg/functionscaler/instancequeue/ondemand_instance_queue_test.go new file mode 100644 index 0000000..80d266c --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancequeue/ondemand_instance_queue_test.go @@ -0,0 +1,183 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancequeue - +package instancequeue + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/functionscaler/types" +) + +func TestNewOnDemandInstanceQueue(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeOnDemand, + FuncSpec: &types.FunctionSpecification{}, + ResKey: resspeckey.ResSpecKey{}, + } + q := NewOnDemandInstanceQueue(basicInsQueConfig) + assert.NotNil(t, q) +} + +func TestOnDemandAcquireInstance(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + var createErr error + createInstanceFunc := func(name string, _ types.InstanceType, _ resspeckey.ResSpecKey, _ []byte) ( + *types.Instance, error) { + if createErr != nil { + return nil, createErr + } + return &types.Instance{InstanceID: "testInstance1", InstanceName: name}, nil + } + deleteInstanceFunc := func(*types.Instance) error { return nil } + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeOnDemand, + FuncSpec: &types.FunctionSpecification{ + FuncCtx: ctx, + }, + ResKey: resspeckey.ResSpecKey{}, + CreateInstanceFunc: createInstanceFunc, + DeleteInstanceFunc: deleteInstanceFunc, + } + q := NewOnDemandInstanceQueue(basicInsQueConfig) + ins, err := q.CreateInstance(&types.InstanceCreateRequest{InstanceName: "testInsName1"}) + assert.Nil(t, err) + assert.Equal(t, "testInstance1", ins.InstanceID) + insAlloc1, err := q.AcquireInstance(&types.InstanceAcquireRequest{InstanceName: "testInsName1"}) + assert.Nil(t, err) + assert.Equal(t, "testInstance1", insAlloc1.AllocationID) + insAlloc2, err := q.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "testInstance1"}) + assert.Nil(t, err) + assert.Equal(t, "testInstance1", insAlloc2.AllocationID) + cancel() + _, err = q.AcquireInstance(&types.InstanceAcquireRequest{InstanceName: "testInsName1"}) + assert.NotNil(t, err) +} + +func TestCreateInstance(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + var createErr error + createInstanceFunc := func(name string, _ types.InstanceType, _ resspeckey.ResSpecKey, _ []byte) ( + *types.Instance, error) { + if createErr != nil { + return nil, createErr + } + return &types.Instance{InstanceID: "testInstance1", InstanceName: name}, nil + } + deleteInstanceFunc := func(*types.Instance) error { return nil } + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeOnDemand, + FuncSpec: &types.FunctionSpecification{ + FuncCtx: ctx, + }, + ResKey: resspeckey.ResSpecKey{}, + CreateInstanceFunc: createInstanceFunc, + DeleteInstanceFunc: deleteInstanceFunc, + } + q := NewOnDemandInstanceQueue(basicInsQueConfig) + ins, err := q.CreateInstance(&types.InstanceCreateRequest{InstanceName: "testInsName1"}) + assert.Nil(t, err) + assert.Equal(t, "testInstance1", ins.InstanceID) + createErr = errors.New("some error") + ins, err = q.CreateInstance(&types.InstanceCreateRequest{InstanceName: "testInsName1"}) + assert.NotNil(t, err) + assert.Nil(t, ins) + cancel() + ins, err = q.CreateInstance(&types.InstanceCreateRequest{InstanceName: "testInsName1"}) + assert.NotNil(t, err) + assert.Nil(t, ins) +} + +func TestDeleteInstance(t *testing.T) { + createInstanceFunc := func(name string, _ types.InstanceType, _ resspeckey.ResSpecKey, _ []byte) ( + *types.Instance, error) { + return &types.Instance{InstanceID: "testInstance1", InstanceName: name}, nil + } + deleteInstanceFunc := func(*types.Instance) error { return nil } + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeOnDemand, + FuncSpec: &types.FunctionSpecification{FuncCtx: context.TODO()}, + ResKey: resspeckey.ResSpecKey{}, + CreateInstanceFunc: createInstanceFunc, + DeleteInstanceFunc: deleteInstanceFunc, + } + q := NewOnDemandInstanceQueue(basicInsQueConfig) + ins, _ := q.CreateInstance(&types.InstanceCreateRequest{InstanceName: "testInsName1"}) + err := q.DeleteInstance(ins) + assert.Nil(t, err) + assert.Equal(t, 0, len(q.instanceMap)) + assert.Equal(t, 0, len(q.insNameMap)) +} + +func TestOnDemandHandleInstanceDelete(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeOnDemand, + FuncSpec: &types.FunctionSpecification{FuncCtx: context.TODO()}, + ResKey: resspeckey.ResSpecKey{}, + } + q := NewOnDemandInstanceQueue(basicInsQueConfig) + instance := &types.Instance{ + InstanceID: "testInstance1", + InstanceName: "testInsName1", + } + q.instanceMap[instance.InstanceID] = instance + q.insNameMap[instance.InstanceName] = instance + q.HandleInstanceDelete(instance) + assert.Equal(t, 0, len(q.instanceMap)) + assert.Equal(t, 0, len(q.insNameMap)) +} + +func TestOnDemandHandleFuncSpecUpdate(t *testing.T) { + deleteInstanceFunc := func(*types.Instance) error { return nil } + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeOnDemand, + FuncSpec: &types.FunctionSpecification{FuncCtx: context.TODO()}, + ResKey: resspeckey.ResSpecKey{}, + DeleteInstanceFunc: deleteInstanceFunc, + } + q := NewOnDemandInstanceQueue(basicInsQueConfig) + instance := &types.Instance{ + InstanceID: "testInstance1", + InstanceName: "testInsName1", + } + q.instanceMap[instance.InstanceID] = instance + q.insNameMap[instance.InstanceName] = instance + q.HandleFuncSpecUpdate(&types.FunctionSpecification{FuncMetaSignature: "testFuncSig"}) + assert.Equal(t, 0, len(q.instanceMap)) + assert.Equal(t, 0, len(q.insNameMap)) +} + +func TestOnDemandGetInstanceNumber(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeOnDemand, + FuncSpec: &types.FunctionSpecification{FuncCtx: context.TODO()}, + ResKey: resspeckey.ResSpecKey{}, + } + q := NewOnDemandInstanceQueue(basicInsQueConfig) + instance := &types.Instance{ + InstanceID: "testInstance1", + InstanceName: "testInsName1", + } + q.instanceMap[instance.InstanceID] = instance + q.insNameMap[instance.InstanceName] = instance + assert.Equal(t, 1, q.GetInstanceNumber(true)) +} diff --git a/yuanrong/pkg/functionscaler/instancequeue/scaled_instance_queue.go b/yuanrong/pkg/functionscaler/instancequeue/scaled_instance_queue.go new file mode 100644 index 0000000..a6da5ee --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancequeue/scaled_instance_queue.go @@ -0,0 +1,509 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancequeue - +package instancequeue + +import ( + "context" + "fmt" + "sync" + "time" + + "go.uber.org/zap" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/scaler" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/scheduler/microservicescheduler" + "yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +var ( + initialErrorDelay = 1 * time.Second + retryDelayLimit = 16 * time.Second + retryDelayFactor = 2 + dataSystemFeatureUsedStream = "stream" +) + +// ScaledInstanceQueue stores instances and handles scaling automatically +type ScaledInstanceQueue struct { + funcSpec *types.FunctionSpecification + metricsCollector metrics.Collector + instanceScheduler scheduler.InstanceScheduler + instanceScaler scaler.InstanceScaler + insCreateQueue *InstanceCreateQueue + instanceType types.InstanceType + resKey resspeckey.ResSpecKey + funcCtx context.Context + funcKey string + funcSig string + funcKeyWithRes string + concurrentNum int + updating bool + isFuncOwner bool + stopCh chan struct{} + createInstanceFunc CreateInstanceFunc + deleteInstanceFunc DeleteInstanceFunc + signalInstanceFunc SignalInstanceFunc + *sync.Cond +} + +// NewScaledInstanceQueue - +func NewScaledInstanceQueue(config *InsQueConfig, metricsCollector metrics.Collector) *ScaledInstanceQueue { + funcSpec := config.FuncSpec + instanceQueue := &ScaledInstanceQueue{ + funcSpec: funcSpec, + resKey: config.ResKey, + insCreateQueue: NewInstanceCreateQueue(), + metricsCollector: metricsCollector, + instanceType: config.InstanceType, + funcCtx: funcSpec.FuncCtx, + funcKey: funcSpec.FuncKey, + funcSig: funcSpec.FuncMetaSignature, + funcKeyWithRes: fmt.Sprintf("%s-%s", funcSpec.FuncKey, config.ResKey.String()), + concurrentNum: utils.GetConcurrentNum(funcSpec.InstanceMetaData.ConcurrentNum), + stopCh: make(chan struct{}), + createInstanceFunc: config.CreateInstanceFunc, + deleteInstanceFunc: config.DeleteInstanceFunc, + signalInstanceFunc: config.SignalInstanceFunc, + Cond: sync.NewCond(&sync.Mutex{}), + isFuncOwner: selfregister.GlobalSchedulerProxy.CheckFuncOwner(funcSpec.FuncKey), + } + go instanceQueue.startScaleUpWorker() + return instanceQueue +} + +// DisableCreateRetry - +func DisableCreateRetry() { + initialErrorDelay = 1 * time.Second + retryDelayLimit = 1 * time.Second + retryDelayFactor = 2 +} + +// SetInstanceScheduler sets instanceScheduler +func (si *ScaledInstanceQueue) SetInstanceScheduler(instanceScheduler scheduler.InstanceScheduler) { + si.instanceScheduler = instanceScheduler +} + +// GetInstanceScheduler return queue's isntance scheduler +func (si *ScaledInstanceQueue) GetInstanceScheduler() scheduler.InstanceScheduler { + if si.instanceScheduler != nil { + return si.instanceScheduler + } + return nil +} + +// GetSchedulerPolicy get current scheduelr policy +func (si *ScaledInstanceQueue) GetSchedulerPolicy() string { + if _, ok := si.instanceScheduler.(*microservicescheduler.MicroServiceScheduler); ok { + return types.InstanceSchedulePolicyMicroService + } + if _, ok := si.instanceScheduler.(*roundrobinscheduler.RoundRobinScheduler); ok { + return types.InstanceSchedulePolicyRoundRobin + } + return types.InstanceSchedulePolicyConcurrency +} + +// ReconnectWithScaler change scheduelr +func (si *ScaledInstanceQueue) ReconnectWithScaler() { + si.instanceScheduler.ConnectWithInstanceScaler(si.instanceScaler) +} + +// SetInstanceScaler sets instanceScaler +func (si *ScaledInstanceQueue) SetInstanceScaler(instanceScaler scaler.InstanceScaler) { + si.instanceScaler = instanceScaler +} + +// AcquireInstance will acquire one instance +func (si *ScaledInstanceQueue) AcquireInstance(insAcqReq *types.InstanceAcquireRequest) ( + *types.InstanceAllocation, snerror.SNError) { + select { + case <-si.funcCtx.Done(): + log.GetLogger().Errorf("function is deleted, can not acquire instance InsAlloc") + return nil, snerror.New(statuscode.FuncMetaNotFoundErrCode, statuscode.FuncMetaNotFoundErrMsg) + default: + si.metricsCollector.UpdateInvokeRequests(1) + si.Cond.L.Lock() + if si.updating { + si.Wait() + } + si.Cond.L.Unlock() + insAlloc, err := si.instanceScheduler.AcquireInstance(insAcqReq) + return insAlloc, buildSnError(err) + } +} + +// ReleaseInstance will release an instance (this is not a necessary method, need to be removed in future) +func (si *ScaledInstanceQueue) ReleaseInstance(thread *types.InstanceAllocation) snerror.SNError { + return buildSnError(si.instanceScheduler.ReleaseInstance(thread)) +} + +// HandleInstanceUpdate handles instance update comes from ETCD +func (si *ScaledInstanceQueue) HandleInstanceUpdate(instance *types.Instance) { + logger := log.GetLogger().With(zap.Any("instanceID", instance.InstanceID), zap.Any("funcKey", si.funcKeyWithRes)) + logger.Infof("handling instance: %s update, pod id: %s, deployment name: %s, status code is %d, error code is %d", + instance.InstanceID, instance.PodID, instance.PodDeploymentName, instance.InstanceStatus.Code, + instance.InstanceStatus.ErrorCode) + si.instanceScheduler.HandleInstanceUpdate(instance) + wiseScale, ok := si.instanceScaler.(*scaler.WiseCloudScaler) + if ok && constant.InstanceStatus(instance.InstanceStatus.Code) == constant.KernelInstanceStatusSubHealth && + (instance.InstanceStatus.ErrorCode == constant.KernelDataSystemUnavailable || + instance.InstanceStatus.ErrorCode == constant.KernelNPUFAULTErrCode) { + go wiseScale.DelNuwaPod(instance) + } +} + +// HandleInstanceDelete handles instance delete comes from ETCD +func (si *ScaledInstanceQueue) HandleInstanceDelete(instance *types.Instance) { + logger := log.GetLogger().With(zap.Any("instanceID", instance.InstanceID), zap.Any("funcKey", si.funcKeyWithRes)) + logger.Infof("handling instance delete") + if err := si.instanceScheduler.DelInstance(instance); err != nil { + logger.Errorf("failed to delete Instance error %s", err.Error()) + } +} + +// HandleFuncSpecUpdate handles instance metadata update +// 1. if concurrentNum changes, clean all instances since we can't modify runtime concurrency here +// 2. if resourceMetadata changes, only reserved instance will handle +func (si *ScaledInstanceQueue) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { + logger := log.GetLogger().With(zap.Any("funcKeyWithRes", si.funcKeyWithRes)) + logger.Infof("handling funcSpec update") + needUpdate := false + si.Cond.L.Lock() + if si.funcSig != funcSpec.FuncMetaSignature { + log.GetLogger().Warnf("signature changes from %s to %s", si.funcSig, funcSpec.FuncMetaSignature) + si.funcSig = funcSpec.FuncMetaSignature + // only reserved instance needs to handle resource update + if si.instanceType == types.InstanceTypeReserved || + funcSpec.InstanceMetaData.ScalePolicy == types.InstanceScalePolicyPredict { + si.resKey = resspeckey.ConvertToResSpecKey(resspeckey.ConvertResourceMetaDataToResSpec( + funcSpec.ResourceMetaData)) + } + // reset create error when funcSpec is updated since user code may be modified and unrecoverable error may be + // resolved + si.instanceScheduler.HandleCreateError(nil) + si.concurrentNum = utils.GetConcurrentNum(funcSpec.InstanceMetaData.ConcurrentNum) + needUpdate = true + si.updating = true + } + si.Cond.L.Unlock() + if needUpdate { + if config.GlobalConfig.Scenario != types.ScenarioWiseCloud { + logger.Infof("cleaning old instances") + si.cleanInstances() + } + logger.Infof("updating instanceScheduler and instanceScaler") + si.instanceScheduler.HandleFuncSpecUpdate(funcSpec) + si.instanceScaler.HandleFuncSpecUpdate(funcSpec) + // if unrecoverable createError happens, instanceScaler may stop, try again since function is updated + si.instanceScaler.SetEnable(si.isFuncOwner) + si.Cond.L.Lock() + si.updating = false + si.Cond.L.Unlock() + si.Cond.Broadcast() + } +} + +// HandleInsConfigUpdate updates instance configuration +func (si *ScaledInstanceQueue) HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) { + si.instanceScaler.HandleInsConfigUpdate(insConfig) +} + +// EnableInstanceScale enable scaler scale replica +func (si *ScaledInstanceQueue) EnableInstanceScale() { + si.L.Lock() + isFuncOwner := si.isFuncOwner + si.L.Unlock() + si.instanceScaler.SetEnable(isFuncOwner) +} + +// HandleFaultyInstance will remove a faulty instance +func (si *ScaledInstanceQueue) HandleFaultyInstance(instance *types.Instance) { + log.GetLogger().Warnf("handling faulty instance %s for function %s", instance.InstanceID, si.funcKeyWithRes) + if err := si.instanceScheduler.DelInstance(instance); err != nil { + log.GetLogger().Warnf("delete Instance from instance queue failed, err: %v", err) + } + si.deleteInstance(instance) +} + +// HandleAliasUpdate - +func (si *ScaledInstanceQueue) HandleAliasUpdate() { + if !selfregister.GlobalSchedulerProxy.CheckFuncOwner(si.funcKey) { + log.GetLogger().Infof("no %s funcOwner, skip update alias to instance", si.funcKey) + return + } + log.GetLogger().Infof("%s funcOwner, begin update alias to instance", si.funcKey) + si.instanceScheduler.SignalAllInstances(func(instance *types.Instance) { + si.signalInstanceFunc(instance, constant.KillSignalAliasUpdate) + }) +} + +// HandleFaaSSchedulerUpdate - +func (si *ScaledInstanceQueue) HandleFaaSSchedulerUpdate() { + // 去除对函数owner的判断,防止这样的场景,两个scheduler a和b,b由于接收事件延后,先推送了旧的事件给实例,但是在新的事件到来时,改变了scheduler对这个函数的owner,导致不推新的事件,导致脏数据残留 + log.GetLogger().Infof("begin update scheduler to %s's instance", si.funcKey) + si.instanceScheduler.SignalAllInstances(func(instance *types.Instance) { + si.signalInstanceFunc(instance, constant.KillSignalFaaSSchedulerUpdate) + }) +} + +// GetInstanceNumber will get current instance number +func (si *ScaledInstanceQueue) GetInstanceNumber(onlySelf bool) int { + return si.instanceScheduler.GetInstanceNumber(onlySelf) +} + +// RecoverInstance recover instances from scheduler state +func (si *ScaledInstanceQueue) RecoverInstance(instanceMap map[string]*types.Instance) { + si.instanceScaler.SetEnable(false) + for _, instanceState := range instanceMap { + si.instanceScheduler.HandleInstanceUpdate(instanceState) + if si.funcSig != instanceState.FuncSig { + si.HandleFaultyInstance(instanceState) + } + } +} + +// Destroy will destroy instance queue and its components +func (si *ScaledInstanceQueue) Destroy() { + log.GetLogger().Infof("destroy instance queue type %s for function %s", si.instanceType, si.funcKeyWithRes) + commonUtils.SafeCloseChannel(si.stopCh) + si.insCreateQueue.destroy() + si.instanceScheduler.Destroy() + si.instanceScaler.Destroy() + si.cleanInstances() + log.GetLogger().Debugf("destroy instance queue type %s for function %s completed", si.instanceType, + si.funcKeyWithRes) +} + +func (si *ScaledInstanceQueue) startScaleUpWorker() { + for { + select { + case _, ok := <-si.stopCh: + if !ok { + log.GetLogger().Warnf("stop scale up worker for function %s", si.funcKeyWithRes) + return + } + default: + } + // getForCreate blocks until a createReq is pushed in or insCreateQueue is destroyed + if createReq := si.insCreateQueue.getForCreate(); createReq != nil { + go si.scaleUpInstanceWithRetry(createReq.callback) + } + } +} + +// ScaleUpHandler handles instance scale up planed by instanceScaler +func (si *ScaledInstanceQueue) ScaleUpHandler(insNum int, callback scaler.ScaleUpCallback) { + for i := 0; i < insNum; i++ { + si.insCreateQueue.push(&InstanceCreateRequest{callback: callback}) + } + if insNum > 0 { + log.GetLogger().Debugf("succeed to submit %d instance to scale up for function %s", insNum, si.funcKeyWithRes) + } +} + +func (si *ScaledInstanceQueue) scaleUpInstanceWithRetry(callback scaler.ScaleUpCallback) { + var ( + instance *types.Instance + createErr error + ) + var retryDelay time.Duration + for retryDelay = initialErrorDelay; retryDelay <= retryDelayLimit; retryDelay *= time.Duration(retryDelayFactor) { + select { + case _, ok := <-si.funcCtx.Done(): + if !ok { + log.GetLogger().Warnf("function %s is deleted, stop scale up instance now", si.funcKey) + return + } + default: + } + instance, createErr = si.createInstance() + if createErr == nil && instance != nil { + break + } + if si.instanceType == types.InstanceTypeReserved && + config.GlobalConfig.ScaleRetryConfig.ReservedInstanceAlwaysRetry { + continue + } + if utils.IsUnrecoverableError(createErr) { + break + } + log.GetLogger().Warnf("failed to create type %s instance for function %s createErr %v should be retried "+ + "after %.0fs", si.instanceType, si.funcKey, createErr, retryDelay.Seconds()) + time.Sleep(retryDelay) + } + // offset pendingInsThdNum pre-increased in instanceScaler before add instance into instanceScheduler + callback(1) + if createErr != nil || instance == nil { + // retry may finish after function is updated, handle create error again to trigger scale again + si.instanceScaler.HandleCreateError(createErr) + si.instanceScheduler.HandleCreateError(createErr) + log.GetLogger().Errorf("failed to create instance after retry %s s for function %s create error %v", + retryDelay, si.funcKeyWithRes, createErr) + return + } + if si.signalInstanceFunc != nil { + si.signalInstanceFunc(instance, constant.KillSignalAliasUpdate) + si.signalInstanceFunc(instance, constant.KillSignalFaaSSchedulerUpdate) + } + // instance event may come before or after scale up process, so ErrInsAlreadyExist is ok + if err := si.instanceScheduler.AddInstance(instance); err != nil && err != scheduler.ErrInsAlreadyExist { + log.GetLogger().Errorf("failed to add instance to instanceScheduler for function %s error %s", + si.funcKeyWithRes, err.Error()) + } +} + +func (si *ScaledInstanceQueue) createInstance() (instance *types.Instance, createErr error) { + defer func() { + // nil should also be handled by instanceScaler and instanceScheduler + si.instanceScaler.HandleCreateError(createErr) + si.instanceScheduler.HandleCreateError(createErr) + }() + si.Cond.L.Lock() + functionSignature := si.funcSig + si.Cond.L.Unlock() + startTime := time.Now() + instance, createErr = si.createInstanceFunc("", si.instanceType, si.resKey, nil) + if createErr != nil { + log.GetLogger().Errorf("failed to create instance for function %s error %s", si.funcKeyWithRes, + createErr.Error()) + } else { + si.instanceScaler.UpdateCreateMetrics(time.Now().Sub(startTime)) + } + select { + case _, ok := <-si.funcCtx.Done(): + if !ok { + log.GetLogger().Warnf("function %s is deleted, killing instance now", si.funcKey) + createErr = ErrFunctionDeleted + } + default: + // in case of function signature change during instance creating + si.Cond.L.Lock() + checkFunctionSignature := si.funcSig + si.Cond.L.Unlock() + if functionSignature != checkFunctionSignature { + log.GetLogger().Errorf("function signature changes while creating instance for function %s, "+ + "killing instance now", si.funcKeyWithRes) + createErr = ErrFuncSigMismatch + } + } + if createErr != nil && instance != nil { + log.GetLogger().Warnf("killing failed created instance %s for function %s", instance.InstanceID, + si.funcKeyWithRes) + go si.deleteInstance(instance) + } + return instance, createErr +} + +// ScaleDownHandler handles instance scale down planed by instanceScaler, consider to handle delete retry in future +func (si *ScaledInstanceQueue) ScaleDownHandler(insNum int, callback scaler.ScaleDownCallback) { + for i := 0; i < insNum; i++ { + go func() { + si.scaleDownInstance(callback) + }() + } +} + +// consider to add retry with backoff process in future +func (si *ScaledInstanceQueue) scaleDownInstance(callback scaler.ScaleDownCallback) { + select { + case _, ok := <-si.funcCtx.Done(): + if !ok { + log.GetLogger().Warnf("function %s is deleted, stop scale down instance now", si.funcKey) + return + } + default: + } + // offset pendingInsThdNum pre-decreased in instanceScaler after pop instance from instanceScheduler + defer callback(1) + if createReq := si.insCreateQueue.getForCancel(); createReq != nil { + return + } + if si.funcSpec.InstanceMetaData.ScalePolicy == types.InstanceScalePolicyPredict { + if instance := si.instanceScheduler.PopInstance(true); instance != nil { + si.deleteInstance(instance) + } + } else { + if instance := si.instanceScheduler.PopInstance(false); instance != nil { + si.deleteInstance(instance) + } + } +} + +func (si *ScaledInstanceQueue) deleteInstance(instance *types.Instance) { + log.GetLogger().Infof("deleting instance %s for function %s", instance.InstanceID, si.funcKeyWithRes) + if _, isWiseCloudScaler := si.instanceScaler.(*scaler.WiseCloudScaler); isWiseCloudScaler { + log.GetLogger().Warnf("skipping deleting instance %s for function %s", instance.InstanceID, + si.funcKeyWithRes) + return + } + if err := si.deleteInstanceFunc(instance); err != nil { + log.GetLogger().Errorf("failed to delete instance %s function %s error %s", instance.InstanceID, + si.funcKeyWithRes, err.Error()) + } +} + +func (si *ScaledInstanceQueue) cleanInstances() { + log.GetLogger().Infof("cleaning all instances for function %s", si.funcKeyWithRes) + var instance *types.Instance + for { + if instance = si.instanceScheduler.PopInstance(true); instance != nil { + go si.deleteInstance(instance) + } else { + break + } + } +} + +// HandleFuncOwnerChange - +func (si *ScaledInstanceQueue) HandleFuncOwnerChange() { + isFuncOwner := selfregister.GlobalSchedulerProxy.CheckFuncOwner(si.funcKey) + si.Cond.L.Lock() + if si.isFuncOwner == isFuncOwner { + si.Cond.L.Unlock() + log.GetLogger().Infof("function owner change of %s doesn't affect this instance queue", si.funcKey) + return + } + log.GetLogger().Infof("function owner of %s changes from %t to %t", si.funcKey, si.isFuncOwner, isFuncOwner) + si.isFuncOwner = isFuncOwner + si.Cond.L.Unlock() + si.instanceScaler.SetEnable(false) + // reassign in instanceScheduler first then reset instanceScaler + si.instanceScheduler.HandleFuncOwnerUpdate(isFuncOwner) + si.instanceScaler.SetEnable(isFuncOwner) +} + +// HandleRatioUpdate - +func (si *ScaledInstanceQueue) HandleRatioUpdate(ratio int) { + si.Cond.L.Lock() + isFuncOwner := si.isFuncOwner + si.Cond.L.Unlock() + if isFuncOwner { + si.instanceScheduler.ReassignInstanceWhenGray(ratio) + } +} diff --git a/yuanrong/pkg/functionscaler/instancequeue/scaled_instance_queue_test.go b/yuanrong/pkg/functionscaler/instancequeue/scaled_instance_queue_test.go new file mode 100644 index 0000000..c627142 --- /dev/null +++ b/yuanrong/pkg/functionscaler/instancequeue/scaled_instance_queue_test.go @@ -0,0 +1,755 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancequeue - +package instancequeue + +import ( + "context" + "fmt" + "reflect" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey" + . "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commontypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/scaler" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" +) + +type fakeInstanceScheduler struct { + insQue []*types.Instance + index int + acquireErr error + releaseErr error +} + +func (f *fakeInstanceScheduler) ReassignInstanceWhenGray(ratio int) { +} + +func (f *fakeInstanceScheduler) AcquireInstance(insAcqReq *types.InstanceAcquireRequest) ( + *types.InstanceAllocation, error) { + if f.acquireErr != nil { + return nil, f.acquireErr + } + instance := &types.Instance{ + InstanceID: fmt.Sprintf("instance%d", f.index), + } + thread := &types.InstanceAllocation{ + Instance: instance, + AllocationID: fmt.Sprintf("%s-thd1", instance.InstanceID), + } + f.index++ + return thread, nil +} + +func (f *fakeInstanceScheduler) HandleFuncOwnerUpdate(isManaged bool) { + // TODO implement me +} + +func (f *fakeInstanceScheduler) GetInstanceNumber(onlySelf bool) int { + return len(f.insQue) +} + +func (f *fakeInstanceScheduler) ReleaseInstance(thread *types.InstanceAllocation) error { + if f.releaseErr != nil { + return f.releaseErr + } + return nil +} + +func (f *fakeInstanceScheduler) AddInstance(instance *types.Instance) error { + f.insQue = append(f.insQue, instance) + return nil +} + +func (f *fakeInstanceScheduler) PopInstance(force bool) *types.Instance { + if len(f.insQue) == 0 { + return nil + } + instance := f.insQue[len(f.insQue)-1] + f.insQue = f.insQue[:len(f.insQue)-1] + return instance +} + +func (f *fakeInstanceScheduler) DelInstance(instance *types.Instance) error { + index := -1 + for i, item := range f.insQue { + if item.InstanceID == instance.InstanceID { + index = i + break + } + } + if index == -1 { + return scheduler.ErrInsNotExist + } + f.insQue = append(f.insQue[:index], f.insQue[index+1:]...) + return nil +} + +func (f *fakeInstanceScheduler) ConnectWithInstanceScaler(instanceScaler scaler.InstanceScaler) { + return +} + +func (f *fakeInstanceScheduler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { +} + +func (f *fakeInstanceScheduler) HandleInstanceUpdate(instance *types.Instance) { + f.insQue = append(f.insQue, instance) +} + +func (f *fakeInstanceScheduler) HandleCreateError(createErr error) { +} + +func (f *fakeInstanceScheduler) SignalAllInstances(signalFunc scheduler.SignalInstanceFunc) { + for _, item := range f.insQue { + signalFunc(item) + } +} + +func (f *fakeInstanceScheduler) Destroy() { +} + +type fakeInstanceScaler struct { + enable bool + insNum int +} + +func (f *fakeInstanceScaler) SetFuncOwner(isManaged bool) { +} + +func (f *fakeInstanceScaler) SetEnable(enable bool) { + f.enable = enable +} + +func (f *fakeInstanceScaler) TriggerScale() { +} + +func (f *fakeInstanceScaler) CheckScaling() bool { + return false +} + +func (f *fakeInstanceScaler) UpdateCreateMetrics(coldStartTime time.Duration) { +} + +func (f *fakeInstanceScaler) HandleInsThdUpdate(inUseInsThdDiff, totalInsThdDiff int) { +} + +func (f *fakeInstanceScaler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { +} + +func (f *fakeInstanceScaler) HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) { +} + +func (f *fakeInstanceScaler) HandleCreateError(createError error) { +} + +func (f *fakeInstanceScaler) GetExpectInstanceNumber() int { + return f.insNum +} + +func (f *fakeInstanceScaler) Destroy() { +} + +func TestNewScaledInstanceQueue(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{}, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + assert.NotNil(t, q) +} + +func TestAcquireInstance(t *testing.T) { + funcCtx, cancelFunc := context.WithCancel(context.TODO()) + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncCtx: funcCtx, + }, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := metrics.NewBucketMetricsCollector("testFunction", "500-500") + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + q.SetInstanceScheduler(&fakeInstanceScheduler{index: 1}) + insThd, err := q.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", insThd.Instance.InstanceID) + someErr := snerror.New(1234, "some error") + q.SetInstanceScheduler(&fakeInstanceScheduler{acquireErr: someErr}) + _, err = q.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, someErr, err) + q.updating = true + acqChan := make(chan error, 1) + go func() { + _, err = q.AcquireInstance(&types.InstanceAcquireRequest{}) + acqChan <- err + }() + time.Sleep(1 * time.Millisecond) + select { + case <-acqChan: + t.Errorf("should not acquire instance thread") + default: + } + q.Cond.Broadcast() + time.Sleep(1 * time.Millisecond) + select { + case <-acqChan: + default: + t.Errorf("should acquire instance thread") + } + cancelFunc() + _, err = q.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, statuscode.FuncMetaNotFoundErrCode, err.Code()) +} + +func TestReleaseInstance(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{}, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + q.SetInstanceScheduler(&fakeInstanceScheduler{}) + err := q.ReleaseInstance(&types.InstanceAllocation{AllocationID: "testThread"}) + assert.Nil(t, err) + someErr := snerror.New(1234, "some error") + q.SetInstanceScheduler(&fakeInstanceScheduler{releaseErr: someErr}) + err = q.ReleaseInstance(&types.InstanceAllocation{AllocationID: "testThread"}) + assert.Equal(t, someErr, err) +} + +func TestHandleInstanceDelete(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{}, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + instanceScheduler := &fakeInstanceScheduler{ + insQue: []*types.Instance{ + { + InstanceID: "instance1", + }, + }, + } + q.SetInstanceScheduler(instanceScheduler) + q.HandleInstanceDelete(&types.Instance{ + InstanceID: "instance1", + }) + assert.Equal(t, 0, len(instanceScheduler.insQue)) +} + +func TestHandleFuncSpecUpdate(t *testing.T) { + defer ApplyFunc((*etcd3.EtcdWatcher).StartList, func(_ *etcd3.EtcdWatcher) {}).Reset() + defer ApplyFunc((*registry.FaasSchedulerRegistry).WaitForETCDList, func() {}).Reset() + deleteFunc := func(ins *types.Instance) error { + return nil + } + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncKey: "testFunction", + FuncMetaSignature: "funcSig1", + InstanceMetaData: commontypes.InstanceMetaData{ + ConcurrentNum: 1, + }, + }, + ResKey: resspeckey.ResSpecKey{ + CPU: 500, + Memory: 500, + }, + DeleteInstanceFunc: deleteFunc, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + instanceScheduler := &fakeInstanceScheduler{ + insQue: []*types.Instance{ + { + InstanceID: "instance1", + }, + }, + } + stop := make(chan struct{}) + registry.InitRegistry(stop) + instanceScaler := &fakeInstanceScaler{} + q.SetInstanceScheduler(instanceScheduler) + q.SetInstanceScaler(instanceScaler) + q.HandleFuncSpecUpdate(&types.FunctionSpecification{ + FuncMetaSignature: "funcSig2", + ResourceMetaData: commontypes.ResourceMetaData{ + CPU: 300, + Memory: 300, + }, + InstanceMetaData: commontypes.InstanceMetaData{ + ConcurrentNum: 2, + }, + }) + assert.Equal(t, "funcSig2", q.funcSig) + assert.Equal(t, "testFunction-cpu-500-mem-500-storage-0-cstRes--cstResSpec--invokeLabel-", q.funcKeyWithRes) + assert.Equal(t, 2, q.concurrentNum) + assert.Equal(t, 0, len(instanceScheduler.insQue)) + assert.Equal(t, true, instanceScaler.enable) + q.instanceType = types.InstanceTypeReserved + q.HandleFuncSpecUpdate(&types.FunctionSpecification{ + FuncMetaSignature: "funcSig3", + ResourceMetaData: commontypes.ResourceMetaData{ + CPU: 300, + Memory: 300, + }, + InstanceMetaData: commontypes.InstanceMetaData{ + ConcurrentNum: 4, + }, + }) + assert.Equal(t, "funcSig3", q.funcSig) + assert.Equal(t, 4, q.concurrentNum) + assert.Equal(t, 0, len(instanceScheduler.insQue)) + assert.Equal(t, true, instanceScaler.enable) + close(stop) +} + +func TestHandleFaultyInstance(t *testing.T) { + var delIns *types.Instance + deleteFunc := func(ins *types.Instance) error { + delIns = ins + return nil + } + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{}, + ResKey: resspeckey.ResSpecKey{}, + DeleteInstanceFunc: deleteFunc, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + instanceScheduler := &fakeInstanceScheduler{ + insQue: []*types.Instance{ + { + InstanceID: "instance1", + }, + }, + } + q.SetInstanceScheduler(instanceScheduler) + q.HandleFaultyInstance(&types.Instance{ + InstanceID: "instance1", + }) + time.Sleep(1 * time.Millisecond) + assert.Equal(t, "instance1", delIns.InstanceID) + + q.HandleFaultyInstance(&types.Instance{InstanceID: "instance2"}) // 是为了构造在scheduler刚启动后,收到被状态为fatal的函数实例更新事件 + time.Sleep(1 * time.Millisecond) + // 目的是为了测试,即使本地缓存没有,也应该要调用deleteFunc来删除etcd里残留的数据 + assert.Equal(t, "instance2", delIns.InstanceID) +} + +func TestHandleAliasUpdate(t *testing.T) { + var signalIns *types.Instance + var signalNum int + signalFunc := func(ins *types.Instance, sig int) { + signalIns = ins + signalNum = sig + } + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{}, + ResKey: resspeckey.ResSpecKey{}, + SignalInstanceFunc: signalFunc, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + instanceScheduler := &fakeInstanceScheduler{ + insQue: []*types.Instance{ + { + InstanceID: "instance1", + }, + }, + } + q.SetInstanceScheduler(instanceScheduler) + + defer ApplyMethod(reflect.TypeOf(&selfregister.SchedulerProxy{}), "CheckFuncOwner", func( + *selfregister.SchedulerProxy, string) bool { + return true + }).Reset() + + q.HandleAliasUpdate() + assert.Equal(t, "instance1", signalIns.InstanceID) + assert.Equal(t, constant.KillSignalAliasUpdate, signalNum) +} + +func TestGetInstanceNumber(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{}, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + q.SetInstanceScheduler(&fakeInstanceScheduler{ + insQue: []*types.Instance{ + { + InstanceID: "instance1", + }, + { + InstanceID: "instance2", + }, + { + InstanceID: "instance3", + }, + }, + }) + assert.Equal(t, 3, q.GetInstanceNumber(true)) +} + +func TestRecoverInstance(t *testing.T) { + var delIns *types.Instance + deleteFunc := func(ins *types.Instance) error { + delIns = ins + return nil + } + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncMetaSignature: "funcSig1", + }, + ResKey: resspeckey.ResSpecKey{}, + DeleteInstanceFunc: deleteFunc, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + instanceScheduler := &fakeInstanceScheduler{} + q.SetInstanceScheduler(instanceScheduler) + instanceScaler := &fakeInstanceScaler{enable: false} + q.SetInstanceScaler(instanceScaler) + q.RecoverInstance(map[string]*types.Instance{ + "instance1": { + InstanceID: "instance1", + FuncSig: "funcSig1", + InstanceStatus: commontypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusSubHealth), + }, + }, + "instance2": { + InstanceID: "instance2", + FuncSig: "funcSig2", + InstanceStatus: commontypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + }, + }) + time.Sleep(1 * time.Millisecond) + assert.Equal(t, "instance2", delIns.InstanceID) + assert.Equal(t, false, instanceScaler.enable) +} + +func TestDestroy(t *testing.T) { + var delIns *types.Instance + deleteFunc := func(ins *types.Instance) error { + delIns = ins + return nil + } + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncMetaSignature: "funcSig1", + }, + ResKey: resspeckey.ResSpecKey{}, + DeleteInstanceFunc: deleteFunc, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + instanceScheduler := &fakeInstanceScheduler{ + insQue: []*types.Instance{ + { + InstanceID: "instance1", + }, + }, + } + instanceScaler := &fakeInstanceScaler{} + q.SetInstanceScheduler(instanceScheduler) + q.SetInstanceScaler(instanceScaler) + q.Destroy() + time.Sleep(1 * time.Millisecond) + assert.Equal(t, "instance1", delIns.InstanceID) +} + +func TestHandleInstanceSync(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncMetaSignature: "funcSig1", + }, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + instanceScheduler := &fakeInstanceScheduler{ + insQue: []*types.Instance{}, + } + instanceScaler := &fakeInstanceScaler{} + q.SetInstanceScheduler(instanceScheduler) + q.SetInstanceScaler(instanceScaler) + cnt := 0 + assert.Equal(t, cnt, 0) + gomonkey.ApplyFunc((*concurrencyscheduler.ScaledConcurrencyScheduler).HandleInstanceUpdate, func(_ *concurrencyscheduler.ScaledConcurrencyScheduler, instance *types.Instance) { cnt++ }) + q.HandleInstanceUpdate(&types.Instance{InstanceID: "aaa"}) + DisableCreateRetry() + q.EnableInstanceScale() + assert.Equal(t, retryDelayFactor, 2) +} + +func TestStartScaleUpWorker(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncMetaSignature: "funcSig1", + }, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + q.SetInstanceScheduler(&fakeInstanceScheduler{}) + q.SetInstanceScaler(&fakeInstanceScaler{}) + q.Destroy() + time.Sleep(1 * time.Millisecond) + assert.Equal(t, true, q.insCreateQueue.stopped) +} + +func TestScaleUpProcess(t *testing.T) { + createFunc := func(string, types.InstanceType, resspeckey.ResSpecKey, []byte) (*types.Instance, error) { + return &types.Instance{}, nil + } + var delIns *types.Instance + deleteFunc := func(ins *types.Instance) error { + delIns = ins + return nil + } + retryDelayLimit = 1 * time.Second + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncCtx: context.TODO(), + FuncMetaSignature: "funcSig1", + }, + ResKey: resspeckey.ResSpecKey{}, + CreateInstanceFunc: createFunc, + DeleteInstanceFunc: deleteFunc, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + instanceScheduler := &fakeInstanceScheduler{} + instanceScaler := &fakeInstanceScaler{} + q.SetInstanceScheduler(instanceScheduler) + q.SetInstanceScaler(instanceScaler) + // no error + callCount := 0 + callCountMutex := new(sync.Mutex) + callback := func(i int) { + callCountMutex.Lock() + callCount++ + callCountMutex.Unlock() + } + q.ScaleUpHandler(3, callback) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 3, callCount) + // instance nil & error not nil + patchCreateFunc := ApplyFunc(createFunc, func(string, types.InstanceType, resspeckey.ResSpecKey, []byte) (*types.Instance, error) { + return nil, snerror.New(4001, "user error") + }) + q.ScaleUpHandler(1, callback) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 4, callCount) + patchCreateFunc.Reset() + // instance not nil & error not nil + patchCreateFunc = ApplyFunc(createFunc, func(string, types.InstanceType, resspeckey.ResSpecKey, []byte) (*types.Instance, error) { + return &types.Instance{InstanceID: "instance1"}, snerror.New(4001, "user error") + }) + q.ScaleUpHandler(1, callback) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 5, callCount) + assert.Equal(t, "instance1", delIns.InstanceID) + patchCreateFunc.Reset() +} + +func TestScaleDownProcess(t *testing.T) { + var delIns *types.Instance + deleteFunc := func(ins *types.Instance) error { + delIns = ins + return nil + } + createFunc := func(string, types.InstanceType, resspeckey.ResSpecKey, []byte) ( + *types.Instance, error) { + return nil, nil + } + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncCtx: context.TODO(), + FuncMetaSignature: "funcSig1", + }, + ResKey: resspeckey.ResSpecKey{}, + DeleteInstanceFunc: deleteFunc, + CreateInstanceFunc: createFunc, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + instanceScheduler := &fakeInstanceScheduler{} + instanceScaler := &fakeInstanceScaler{} + q.SetInstanceScheduler(instanceScheduler) + q.SetInstanceScaler(instanceScaler) + callCount := 0 + callback := func(i int) { + callCount++ + } + q.ScaleDownHandler(1, callback) + time.Sleep(1 * time.Millisecond) + assert.Equal(t, 1, callCount) + assert.Nil(t, delIns) + instanceScheduler.insQue = append(instanceScheduler.insQue, &types.Instance{InstanceID: "instance1"}) + q.scaleDownInstance(callback) + time.Sleep(1 * time.Millisecond) + assert.Equal(t, 2, callCount) + assert.Equal(t, "instance1", delIns.InstanceID) + assert.Equal(t, 0, instanceScheduler.GetInstanceNumber(true)) + delIns = nil + q.insCreateQueue.push(&InstanceCreateRequest{callback: callback}) + q.ScaleDownHandler(1, callback) + time.Sleep(1 * time.Millisecond) + assert.Equal(t, 3, callCount) + assert.Nil(t, delIns) +} + +func TestHandleFuncOwnerChange(t *testing.T) { + setFuncOwner := false + patches := []*Patches{ + ApplyFunc((*selfregister.SchedulerProxy).CheckFuncOwner, func(_ *selfregister.SchedulerProxy, _ string) bool { + return setFuncOwner + }), + ApplyFunc((*concurrencyscheduler.ScaledConcurrencyScheduler).HandleFuncOwnerUpdate, func( + _ *concurrencyscheduler.ScaledConcurrencyScheduler, _ bool) { + return + }), + ApplyFunc((*scaler.AutoScaler).SetEnable, func(_ *scaler.AutoScaler, _ bool) { + return + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncCtx: context.TODO(), + FuncMetaSignature: "funcSig1", + }, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + q.instanceScheduler = &concurrencyscheduler.ScaledConcurrencyScheduler{} + q.instanceScaler = &scaler.AutoScaler{} + q.HandleFuncOwnerChange() + setFuncOwner = true + q.HandleFuncOwnerChange() + assert.Equal(t, true, q.isFuncOwner) +} +func TestHandleRatioUpdate(t *testing.T) { + patches := []*Patches{} + expectRatio := 0 + patches = append(patches, ApplyFunc( + (*concurrencyscheduler.ScaledConcurrencyScheduler).ReassignInstanceWhenGray, + func(s *concurrencyscheduler.ScaledConcurrencyScheduler, ratio int) { + expectRatio = ratio + }, + )) + defer func() { + for _, p := range patches { + p.Reset() + } + }() + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncCtx: context.TODO(), + FuncMetaSignature: "funcSig1", + }, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("testFunction", 1000*time.Millisecond) + q.instanceScheduler = concurrencyscheduler.NewScaledConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 1}, + }, resspeckey.ResSpecKey{}, InsThdReqQueue) + q.instanceScaler = &scaler.AutoScaler{} + + q.HandleRatioUpdate(50) + assert.Equal(t, 50, expectRatio) +} + +func TestScaledInstanceQueue_HandleFaaSSchedulerUpdate(t *testing.T) { + basicInsQueConfig := &InsQueConfig{ + InstanceType: types.InstanceTypeScaled, + FuncSpec: &types.FunctionSpecification{ + FuncCtx: context.TODO(), + FuncMetaSignature: "funcSig1", + }, + ResKey: resspeckey.ResSpecKey{}, + } + metricsCollector := &metrics.BucketCollector{} + q := NewScaledInstanceQueue(basicInsQueConfig, metricsCollector) + scheduler := &fakeInstanceScheduler{index: 1, insQue: make([]*types.Instance, 0)} + scheduler.insQue = append(scheduler.insQue, &types.Instance{InstanceID: "instance1"}) + q.SetInstanceScheduler(scheduler) + result := false + defer gomonkey.ApplyMethod(reflect.TypeOf(&selfregister.SchedulerProxy{}), "CheckFuncOwner", func( + *selfregister.SchedulerProxy, string) bool { + return result + }).Reset() + + flag := false + q.signalInstanceFunc = func(instance *types.Instance, i int) { + flag = true + } + q.HandleFaaSSchedulerUpdate() + assert.Equal(t, flag, true) + + result = true + q.HandleFaaSSchedulerUpdate() + assert.Equal(t, flag, true) +} diff --git a/yuanrong/pkg/functionscaler/lease/generic_lease_manager.go b/yuanrong/pkg/functionscaler/lease/generic_lease_manager.go new file mode 100644 index 0000000..1054784 --- /dev/null +++ b/yuanrong/pkg/functionscaler/lease/generic_lease_manager.go @@ -0,0 +1,271 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package lease - +package lease + +import ( + "errors" + "sync" + "time" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/timewheel" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + defaultTimeWheelPace = 5 * time.Millisecond + defaultTimeWheelSlots = 100 +) + +var ( + errInstanceLeaseHolderNotEnable = errors.New("instance lease holder not enable") + // ErrInstanceNotFound instance not found err + ErrInstanceNotFound = errors.New("instance doesn't exist") +) + +type instanceLeaseHolder struct { + instance *types.Instance + timeWheel timewheel.TimeWheel + intervalMap map[string]time.Duration + callbackMap map[string]func() + enable bool + sync.RWMutex +} + +func newInstanceLeaseHolder(instance *types.Instance) *instanceLeaseHolder { + holder := &instanceLeaseHolder{ + instance: instance, + timeWheel: timewheel.NewSimpleTimeWheel(defaultTimeWheelPace, defaultTimeWheelSlots), + intervalMap: make(map[string]time.Duration, instance.ConcurrentNum), + callbackMap: make(map[string]func(), instance.ConcurrentNum), + enable: true, + } + go holder.pollLease() + return holder +} + +func (il *instanceLeaseHolder) stop() { + il.Lock() + enable := il.enable + il.enable = false + il.Unlock() + // make sure just stop once + if enable { + il.timeWheel.Stop() + log.GetLogger().Warnf("stopping timeWheel for instance %s function %s", + il.instance.InstanceID, il.instance.FuncKey) + } +} + +func (il *instanceLeaseHolder) pollLease() { + for { + il.RLock() + // if instanceLeaseHolder stopped, break loop + if !il.enable { + log.GetLogger().Warnf("stopping leases for instance %s function %s", + il.instance.InstanceID, il.instance.FuncKey) + il.RUnlock() + return + } + il.RUnlock() + readyList := il.timeWheel.Wait() + for _, allocationID := range readyList { + log.GetLogger().Warnf("lease %s expires, now release", allocationID) + il.Lock() + if err := il.timeWheel.DelTask(allocationID); err != nil { + log.GetLogger().Errorf("failed to delete task %s in time wheel", allocationID) + } + callback, exist := il.callbackMap[allocationID] + delete(il.intervalMap, allocationID) + delete(il.callbackMap, allocationID) + il.Unlock() + if exist { + callback() + } + } + } +} + +func (il *instanceLeaseHolder) createLease(insAlloc *types.InstanceAllocation, interval time.Duration, + callback func()) error { + il.Lock() + defer il.Unlock() + if !il.enable { + return errInstanceLeaseHolderNotEnable + } + _, err := il.timeWheel.AddTask(insAlloc.AllocationID, interval, 1) + if err != nil { + return err + } + il.intervalMap[insAlloc.AllocationID] = interval + il.callbackMap[insAlloc.AllocationID] = callback + return nil +} + +func (il *instanceLeaseHolder) extendLease(insAlloc *types.InstanceAllocation) error { + il.Lock() + defer il.Unlock() + if !il.enable { + return errInstanceLeaseHolderNotEnable + } + interval, exist := il.intervalMap[insAlloc.AllocationID] + if !exist { + return errors.New("lease doesn't exist or released") + } + err := il.timeWheel.DelTask(insAlloc.AllocationID) + if err != nil { + return err + } + _, err = il.timeWheel.AddTask(insAlloc.AllocationID, interval, 1) + if err != nil { + return err + } + return nil +} + +func (il *instanceLeaseHolder) releaseLease(insAlloc *types.InstanceAllocation) error { + il.Lock() + if !il.enable { + il.Unlock() + return errInstanceLeaseHolderNotEnable + } + callback, exist := il.callbackMap[insAlloc.AllocationID] + delete(il.intervalMap, insAlloc.AllocationID) + delete(il.callbackMap, insAlloc.AllocationID) + err := il.timeWheel.DelTask(insAlloc.AllocationID) + il.Unlock() + if err != nil { + return err + } + if exist { + callback() + } + return nil +} + +// GenericInstanceLeaseManager manages insAlloc leases of instances of a specific function +type GenericInstanceLeaseManager struct { + leaseHolders map[string]*instanceLeaseHolder + funcKey string + sync.RWMutex +} + +// NewGenericLeaseManager creates a GenericInstanceLeaseManager +func NewGenericLeaseManager(funcKey string) InstanceLeaseManager { + return &GenericInstanceLeaseManager{ + leaseHolders: make(map[string]*instanceLeaseHolder, utils.DefaultMapSize), + funcKey: funcKey, + } +} + +// CreateInstanceLease creates a lease for an instance insAlloc +func (gm *GenericInstanceLeaseManager) CreateInstanceLease(insAlloc *types.InstanceAllocation, interval time.Duration, + callback func()) (types.InstanceLease, error) { + if insAlloc == nil || insAlloc.Instance == nil { + log.GetLogger().Errorf("invalid instance insAlloc") + return nil, errors.New("invalid instance insAlloc") + } + instance := insAlloc.Instance + gm.Lock() + leaseHolder, exist := gm.leaseHolders[instance.InstanceID] + if !exist { + leaseHolder = newInstanceLeaseHolder(insAlloc.Instance) + gm.leaseHolders[instance.InstanceID] = leaseHolder + } + gm.Unlock() + if err := leaseHolder.createLease(insAlloc, interval, callback); err != nil { + return nil, err + } + return &GenericInstanceLease{ + insAlloc: insAlloc, + manager: gm, + interval: interval, + }, nil +} + +// HandleInstanceDelete handles instance delete +func (gm *GenericInstanceLeaseManager) HandleInstanceDelete(instance *types.Instance) { + gm.Lock() + leaseHolder, exist := gm.leaseHolders[instance.InstanceID] + delete(gm.leaseHolders, instance.InstanceID) + gm.Unlock() + if !exist { + return + } + leaseHolder.stop() +} + +// CleanAllLeases cleans all leases +func (gm *GenericInstanceLeaseManager) CleanAllLeases() { + gm.Lock() + for instanceID, leaseHolder := range gm.leaseHolders { + leaseHolder.stop() + log.GetLogger().Infof("leases for instance %s function %s stopped", instanceID, gm.funcKey) + } + gm.leaseHolders = map[string]*instanceLeaseHolder{} + gm.Unlock() +} + +func (gm *GenericInstanceLeaseManager) extendLease(insAlloc *types.InstanceAllocation) error { + gm.Lock() + leaseHolder, exist := gm.leaseHolders[insAlloc.Instance.InstanceID] + gm.Unlock() + if !exist { + return ErrInstanceNotFound + } + return leaseHolder.extendLease(insAlloc) +} + +func (gm *GenericInstanceLeaseManager) releaseLease(insAlloc *types.InstanceAllocation) error { + gm.Lock() + leaseHolder, exist := gm.leaseHolders[insAlloc.Instance.InstanceID] + gm.Unlock() + if !exist { + return ErrInstanceNotFound + } + return leaseHolder.releaseLease(insAlloc) +} + +// GenericInstanceLease provides lease operations of a specified instance allocation +type GenericInstanceLease struct { + insAlloc *types.InstanceAllocation + manager *GenericInstanceLeaseManager + interval time.Duration +} + +// Extend will extend lease +func (gl *GenericInstanceLease) Extend() error { + if gl.manager == nil { + return errors.New("lease manager doesn't exist") + } + return gl.manager.extendLease(gl.insAlloc) +} + +// Release will release lease +func (gl *GenericInstanceLease) Release() error { + if gl.manager == nil { + return errors.New("lease manager doesn't exist") + } + return gl.manager.releaseLease(gl.insAlloc) +} + +// GetInterval will return interval of lease +func (gl *GenericInstanceLease) GetInterval() time.Duration { + return gl.interval +} diff --git a/yuanrong/pkg/functionscaler/lease/generic_lease_manager_test.go b/yuanrong/pkg/functionscaler/lease/generic_lease_manager_test.go new file mode 100644 index 0000000..2567a74 --- /dev/null +++ b/yuanrong/pkg/functionscaler/lease/generic_lease_manager_test.go @@ -0,0 +1,128 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package lease - +package lease + +import ( + "fmt" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/functionscaler/state" + "yuanrong/pkg/functionscaler/types" +) + +func TestInstanceLeaseHolder(t *testing.T) { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + instance := &types.Instance{ + InstanceID: "test-instance", + ConcurrentNum: 100, + } + leaseHolder := newInstanceLeaseHolder(instance) + assert.Equal(t, true, leaseHolder.enable) + thread1 := &types.InstanceAllocation{ + Instance: instance, + AllocationID: fmt.Sprintf("%s-1", instance.InstanceID), + } + thread1Released := false + err := leaseHolder.createLease(thread1, 500*time.Millisecond, func() { + thread1Released = true + }) + assert.Nil(t, err) + time.Sleep(300 * time.Millisecond) + err = leaseHolder.extendLease(thread1) + time.Sleep(300 * time.Millisecond) + assert.Equal(t, false, thread1Released) + time.Sleep(600 * time.Millisecond) + assert.Equal(t, true, thread1Released) + err = leaseHolder.releaseLease(thread1) + assert.Nil(t, err) +} + +func TestHandleInstanceDelete(t *testing.T) { + convey.Convey("HandleInstanceDelete", t, func() { + manager := NewGenericLeaseManager("funcKey") + instance := &types.Instance{InstanceID: "instanceID"} + thread := &types.InstanceAllocation{Instance: instance} + lease, err := manager.CreateInstanceLease(nil, 5*time.Second, func() { + }) + convey.So(err, convey.ShouldNotBeNil) + lease, err = manager.CreateInstanceLease(thread, 5*time.Second, func() { + }) + convey.So(err, convey.ShouldBeNil) + convey.So(lease, convey.ShouldNotBeNil) + manager.HandleInstanceDelete(instance) + }) +} + +func TestCleanAllLeases(t *testing.T) { + convey.Convey("CleanAllLeases", t, func() { + manager := NewGenericLeaseManager("funcKey") + instance := &types.Instance{InstanceID: "instanceID"} + thread := &types.InstanceAllocation{Instance: instance} + lease, err := manager.CreateInstanceLease(thread, 5*time.Second, func() { + }) + convey.So(err, convey.ShouldBeNil) + convey.So(lease, convey.ShouldNotBeNil) + manager.CleanAllLeases() + }) +} + +func TestGenericInstanceLeaseManager(t *testing.T) { + lm := &GenericInstanceLeaseManager{leaseHolders: make(map[string]*instanceLeaseHolder)} + instance := &types.Instance{InstanceID: "instanceID"} + insAlloc := &types.InstanceAllocation{Instance: instance} + lm.leaseHolders["instance1"] = newInstanceLeaseHolder(instance) + convey.Convey("Test GenericInstanceLeaseManager", t, func() { + convey.Convey("test extendLease", func() { + lm.CreateInstanceLease(insAlloc, 1*time.Second, func() {}) + err := lm.extendLease(insAlloc) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("test releaseLease", func() { + lm.CreateInstanceLease(insAlloc, 1*time.Second, func() {}) + err := lm.releaseLease(insAlloc) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestGenericInstanceLease(t *testing.T) { + manager := NewGenericLeaseManager("funcKey1") + instance := &types.Instance{InstanceID: "instance1"} + insAlloc := &types.InstanceAllocation{AllocationID: "allocation1", Instance: instance} + lease, _ := manager.CreateInstanceLease(insAlloc, 1*time.Second, func() {}) + convey.Convey("Test GenericInstanceLease", t, func() { + convey.Convey("test extendLease", func() { + interval := lease.GetInterval() + convey.So(interval, convey.ShouldEqual, 1*time.Second) + }) + convey.Convey("test Extend", func() { + err := lease.Extend() + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("test Release", func() { + err := lease.Release() + convey.So(err, convey.ShouldBeNil) + }) + }) +} diff --git a/yuanrong/pkg/functionscaler/lease/lease.go b/yuanrong/pkg/functionscaler/lease/lease.go new file mode 100644 index 0000000..a6113d5 --- /dev/null +++ b/yuanrong/pkg/functionscaler/lease/lease.go @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package lease - +package lease + +import ( + "time" + + "yuanrong/pkg/functionscaler/types" +) + +// InstanceLeaseManager manages leases of a specific function +type InstanceLeaseManager interface { + CreateInstanceLease(insAlloc *types.InstanceAllocation, interval time.Duration, callback func()) ( + types.InstanceLease, error) + HandleInstanceDelete(instance *types.Instance) + CleanAllLeases() +} diff --git a/yuanrong/pkg/functionscaler/registry/agentregistry.go b/yuanrong/pkg/functionscaler/registry/agentregistry.go new file mode 100644 index 0000000..f08a3d6 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/agentregistry.go @@ -0,0 +1,195 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "context" + "errors" + "os" + "sync" + "time" + + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/dynamic/dynamicinformer" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/util/workqueue" + + "yuanrong/pkg/common/faas_common/k8sclient" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + // CRD-related configurations are consistent with the AI cloud platform + agentGroup = "yr.cap.io" + agentVersion = "v1" + agentResource = "yrtasks" + cedEventsQueue = "crdEventsQueue" + + // consistency with 1.x + defaultMaxRetryTimes = 12 +) + +// AgentRegistry watches agent event of CR +type AgentRegistry struct { + dynamicClient dynamic.Interface + informerFactory dynamicinformer.DynamicSharedInformerFactory + workQueue workqueue.RateLimitingInterface + funcSpecs map[string]*types.FunctionSpecification + stopCh <-chan struct{} + sync.RWMutex +} + +// crdEvent include eventType and obj +type crdEvent struct { + eventType EventType + obj *unstructured.Unstructured +} + +// NewAgentRegistry will create AgentRegistry +func NewAgentRegistry(stopCh <-chan struct{}) *AgentRegistry { + // prevent component startup exceptions when the YAML file for deployment permissions is not configured + if os.Getenv("ENABLE_AGENT_CRD_REGISTRY") == "" { + return nil + } + dynamicClient := k8sclient.NewDynamicClient() + // Different CR events share the same rate-limiting queue + workQueue := workqueue.NewNamedRateLimitingQueue( + workqueue.DefaultControllerRateLimiter(), + cedEventsQueue, + ) + agentRegistry := &AgentRegistry{ + dynamicClient: dynamicClient, + informerFactory: dynamicinformer.NewDynamicSharedInformerFactory(dynamicClient, time.Minute), + workQueue: workQueue, + funcSpecs: make(map[string]*types.FunctionSpecification, utils.DefaultMapSize), + stopCh: stopCh, + } + return agentRegistry +} + +// RunWatcher will start CR watch process +func (ar *AgentRegistry) RunWatcher() { + crdGVR := schema.GroupVersionResource{ + Group: agentGroup, + Version: agentVersion, + Resource: agentResource, + } + crdInformer := ar.informerFactory.ForResource(crdGVR).Informer() + ar.setupEventHandlers(crdInformer) + ar.startController(crdInformer) +} + +// setupEventHandlers setup CRD Event Handlers +func (ar *AgentRegistry) setupEventHandlers(informer cache.SharedInformer) { + informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + ar.enqueueEvent(SubEventTypeAdd, obj) + }, + UpdateFunc: func(oldObj, newObj interface{}) { + if oldObj.(*unstructured.Unstructured).GetResourceVersion() != + newObj.(*unstructured.Unstructured).GetResourceVersion() { + ar.enqueueEvent(SubEventTypeUpdate, newObj) + } + }, + DeleteFunc: func(obj interface{}) { + ar.enqueueEvent(SubEventTypeDelete, obj) + }, + }) +} + +// enqueueEvent handle crd event enqueue +func (ar *AgentRegistry) enqueueEvent(eventType EventType, obj interface{}) { + unstructObj, ok := obj.(*unstructured.Unstructured) + if !ok { + log.GetLogger().Errorf("failed to assert crd event") + return + } + ar.workQueue.Add(&crdEvent{ + eventType: eventType, + obj: unstructObj, + }) +} + +// startController start crd Controller +func (ar *AgentRegistry) startController(informer cache.SharedInformer) { + ctx, _ := context.WithCancel(context.Background()) + go ar.informerFactory.Start(ctx.Done()) + go ar.processQueue() + if !cache.WaitForCacheSync(ctx.Done(), informer.HasSynced) { + log.GetLogger().Warnf("failed to sync crd cache") + } +} + +// processQueue process crd event Queue +func (ar *AgentRegistry) processQueue() { + for { + item, shutdown := ar.workQueue.Get() + if shutdown { + return + } + event, ok := item.(*crdEvent) + if !ok { + log.GetLogger().Warnf("invalid crd event") + ar.workQueue.Forget(item) + continue + } + if err := ar.processEvent(event); err != nil { + // Limited number of retries + if ar.workQueue.NumRequeues(item) < defaultMaxRetryTimes { + log.GetLogger().Warnf("process crd event error: %s, retry", err.Error()) + ar.workQueue.AddRateLimited(item) + } + } else { + ar.workQueue.Forget(item) + } + ar.workQueue.Done(item) + } +} + +// processEvent process cr add update delete Event +func (ar *AgentRegistry) processEvent(event *crdEvent) error { + spec, ok := event.obj.UnstructuredContent()["spec"].(map[string]interface{}) + if !ok { + return errors.New("crd has no spec key") + } + + var info types.FunctionSpecification + if err := runtime.DefaultUnstructuredConverter.FromUnstructured(spec, &info); err != nil { + log.GetLogger().Errorf("failed to convert crd spec: %s", err.Error()) + return err + } + + switch event.eventType { + case SubEventTypeAdd: + log.GetLogger().Infof("[ADD] FunctionURN: %s, Image: %s", info.FuncMetaData.FunctionURN, + info.ExtendedMetaData.CustomContainerConfig.Image) + case SubEventTypeUpdate: + log.GetLogger().Infof("[UPDATE] FunctionURN: %s, Image: %s", info.FuncMetaData.FunctionURN, + info.ExtendedMetaData.CustomContainerConfig.Image) + case SubEventTypeDelete: + log.GetLogger().Infof("[DELETE] FunctionURN: %s", info.FuncMetaData.FunctionURN) + default: + log.GetLogger().Warnf("invalid event type") + } + return nil +} diff --git a/yuanrong/pkg/functionscaler/registry/agentregistry_test.go b/yuanrong/pkg/functionscaler/registry/agentregistry_test.go new file mode 100644 index 0000000..e0889dc --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/agentregistry_test.go @@ -0,0 +1,167 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "os" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" +) + +func TestNewAgentRegistry_BasicInitialization(t *testing.T) { + os.Setenv("ENABLE_AGENT_CRD_REGISTRY", "true") + stopCh := make(chan struct{}) + registry := NewAgentRegistry(stopCh) + assert.NotNil(t, registry.dynamicClient, "dynamic client initialization failed") + assert.NotEqual(t, stopCh, registry.stopCh, "stop channel transfer error") + os.Setenv("ENABLE_AGENT_CRD_REGISTRY", "") +} + +func TestNewAgentRegistry_WorkQueueSetup(t *testing.T) { + os.Setenv("ENABLE_AGENT_CRD_REGISTRY", "true") + registry := NewAgentRegistry(make(chan struct{})) + assert.NotNil(t, registry.workQueue, "work queue not initialized") + os.Setenv("ENABLE_AGENT_CRD_REGISTRY", "") +} + +func TestNewAgentRegistry_ConcurrentAccess(t *testing.T) { + os.Setenv("ENABLE_AGENT_CRD_REGISTRY", "true") + callCount := 0 + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + NewAgentRegistry(make(chan struct{})) + }() + } + wg.Wait() + assert.Equal(t, 0, callCount, "dynamic client should be initialized only once") + os.Setenv("ENABLE_AGENT_CRD_REGISTRY", "") +} + +func TestProcessEvent_SuccessfulCases(t *testing.T) { + tests := []struct { + name string + eventType EventType + specData map[string]interface{} + expected string + }{ + { + "ADD event", + SubEventTypeAdd, + map[string]interface{}{ + "funcMetaData": map[string]interface{}{ + "functionURN": "test-urn", + }, + "extendedMetaData": map[string]interface{}{ + "customContainerConfig": map[string]interface{}{ + "image": "nginx:latest", + }, + }, + }, + "[ADD]", + }, + { + "UPDATE event", + SubEventTypeUpdate, + map[string]interface{}{ + "funcMetaData": map[string]interface{}{ + "functionURN": "test-urn", + }, + "extendedMetaData": map[string]interface{}{ + "customContainerConfig": map[string]interface{}{ + "image": "nginx:latest", + }, + }, + }, + "[UPDATE]", + }, + { + "DELETE event", + SubEventTypeDelete, + map[string]interface{}{ + "funcMetaData": map[string]interface{}{ + "functionURN": "test-urn", + }, + "extendedMetaData": map[string]interface{}{ + "customContainerConfig": map[string]interface{}{ + "image": "nginx:latest", + }, + }, + }, + "[DELETE]", + }, + } + + registry := &AgentRegistry{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := &crdEvent{ + eventType: tt.eventType, + obj: &unstructured.Unstructured{ + Object: map[string]interface{}{ + "spec": tt.specData, + }, + }, + } + err := registry.processEvent(event) + assert.NoError(t, err) + }) + } +} + +func TestProcessEvent_ErrorCases(t *testing.T) { + tests := []struct { + name string + event *crdEvent + expectErrMsg string + }{ + { + "The spec field is missing", + &crdEvent{ + obj: &unstructured.Unstructured{ + Object: map[string]interface{}{}, + }, + }, + "crd has no spec key", + }, + { + "Invalid spec format", + &crdEvent{ + obj: &unstructured.Unstructured{ + Object: map[string]interface{}{ + "spec": "invalid-string", + }, + }, + }, + "crd has no spec key", + }, + } + registry := &AgentRegistry{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := registry.processEvent(tt.event) + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectErrMsg) + }) + } +} diff --git a/yuanrong/pkg/functionscaler/registry/aliasregistry.go b/yuanrong/pkg/functionscaler/registry/aliasregistry.go new file mode 100644 index 0000000..5e080ff --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/aliasregistry.go @@ -0,0 +1,111 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/aliasroute" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" +) + +// AliasRegistry watches instance event of etcd +type AliasRegistry struct { + watcher etcd3.Watcher + subscriberChans []chan SubEvent + stopCh <-chan struct{} + sync.RWMutex +} + +// NewAliasRegistry will create InstanceRegistry +func NewAliasRegistry(stopCh <-chan struct{}) *AliasRegistry { + aliasRegistry := &AliasRegistry{ + stopCh: stopCh, + } + return aliasRegistry +} + +func (ar *AliasRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + ar.watcher = etcd3.NewEtcdWatcher( + constant.AliasPrefix, + ar.watcherFilter, + ar.watcherHandler, + ar.stopCh, + etcdClient) + ar.watcher.StartList() +} + +// RunWatcher will start etcd watch process for instance event +func (ar *AliasRegistry) RunWatcher() { + go ar.watcher.StartWatch() +} + +// watcherFilter will filter alias event from etcd event +func (ar *AliasRegistry) watcherFilter(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != validEtcdKeyLenForAlias { + return true + } + if items[aliasKeyIndex] != "aliases" || items[tenantKeyIndex] != "tenant" || + items[functionKeyIndex] != "function" { + return true + } + return false +} + +// watcherHandler will handle instance event from etcd +func (ar *AliasRegistry) watcherHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling alias event type %s key %s", event.Type, event.Key) + switch event.Type { + case etcd3.PUT: + aliasURN, err := aliasroute.ProcessUpdate(event) + if err != nil { + return + } + ar.publishEvent(SubEventTypeUpdate, aliasURN) + case etcd3.DELETE: + aliasURN := aliasroute.ProcessDelete(event) + ar.publishEvent(SubEventTypeDelete, aliasURN) + case etcd3.ERROR: + log.GetLogger().Warnf("etcd error event: %s", event.Value) + default: + log.GetLogger().Warnf("unsupported event, key: %s", event.Key) + } +} + +// addSubscriberChan will add channel, subscribed by FaaSScheduler +func (ar *AliasRegistry) addSubscriberChan(subChan chan SubEvent) { + ar.Lock() + ar.subscriberChans = append(ar.subscriberChans, subChan) + ar.Unlock() +} + +// publishEvent will publish instance event via channel +func (ar *AliasRegistry) publishEvent(eventType EventType, aliasUrn string) { + for _, subChan := range ar.subscriberChans { + if subChan != nil { + subChan <- SubEvent{ + EventType: eventType, + EventMsg: aliasUrn, + } + } + } +} diff --git a/yuanrong/pkg/functionscaler/registry/faasfrontendregistry.go b/yuanrong/pkg/functionscaler/registry/faasfrontendregistry.go new file mode 100644 index 0000000..0cb8f5f --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/faasfrontendregistry.go @@ -0,0 +1,157 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/utils" +) + +// FaaSFrontendRegistry watches frontend instance info event of etcd +type FaaSFrontendRegistry struct { + watcher etcd3.Watcher + ClusterFrontends map[string][]string // cluster frontend instance + stopCh <-chan struct{} + sync.RWMutex +} + +// NewFaaSFrontendRegistry will create FaaSFrontendRegistry +func NewFaaSFrontendRegistry(stopCh <-chan struct{}) *FaaSFrontendRegistry { + faaSFrontendRegistry := &FaaSFrontendRegistry{ + ClusterFrontends: make(map[string][]string, utils.DefaultMapSize), + stopCh: stopCh, + } + return faaSFrontendRegistry +} + +func (ffr *FaaSFrontendRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + ffr.watcher = etcd3.NewEtcdWatcher( + constant.FrontendInstancePrefix, + ffr.watcherFilter, + ffr.watcherHandler, + ffr.stopCh, + etcdClient) + ffr.watcher.StartList() +} + +// RunWatcher will start etcd watch process for frontend instance info event +func (ffr *FaaSFrontendRegistry) RunWatcher() { + go ffr.watcher.StartWatch() +} + +// GetFrontends - +func (ffr *FaaSFrontendRegistry) GetFrontends(cluster string) []string { + ffr.RLock() + frontends := ffr.ClusterFrontends[cluster] + ffr.RUnlock() + return frontends +} + +// watcherFilter will filter frontend instance info event from etcd event +func (ffr *FaaSFrontendRegistry) watcherFilter(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != validEtcdKeyLenForFrontend { + return true + } + return false +} + +// watcherHandler will handle frontend instance info event from etcd +func (ffr *FaaSFrontendRegistry) watcherHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling alias event type %s key %s", event.Type, event.Key) + + switch event.Type { + case etcd3.PUT: + ffr.updateFrontendInstances(event) + case etcd3.DELETE: + ffr.deleteFrontendInstances(event) + case etcd3.ERROR: + log.GetLogger().Warnf("etcd error event: %s", event.Value) + default: + log.GetLogger().Warnf("unsupported event, key: %s", event.Key) + } +} + +func (ffr *FaaSFrontendRegistry) updateFrontendInstances(ev *etcd3.Event) { + key := ev.Key + parts := strings.Split(key, keySeparator) + if len(parts) != validEtcdKeyLenForFrontend { + log.GetLogger().Warnf("invalied cluster frontend key:%s, ignore it", key) + return + } + status := string(ev.Value) + if status != "active" { + ffr.Lock() + cluster := parts[clusterFrontendClusterIndex] + ips := ffr.ClusterFrontends[cluster] + deleteIndex := -1 + for i, val := range ips { + if val == parts[clusterFrontendIPIndex] { + deleteIndex = i + } + } + if deleteIndex != -1 { + ips = append(ips[:deleteIndex], ips[deleteIndex+1:]...) + } + if len(ips) == 0 { + delete(ffr.ClusterFrontends, cluster) + } else { + ffr.ClusterFrontends[cluster] = ips + } + ffr.Unlock() + return + } + ffr.Lock() + cluster := parts[clusterFrontendClusterIndex] + ips := ffr.ClusterFrontends[cluster] + if ips == nil { + ips = make([]string, 0) + ffr.ClusterFrontends[cluster] = ips + } + ffr.ClusterFrontends[cluster] = append(ffr.ClusterFrontends[cluster], parts[clusterFrontendIPIndex]) + ffr.Unlock() +} + +func (ffr *FaaSFrontendRegistry) deleteFrontendInstances(ev *etcd3.Event) { + key := ev.Key + parts := strings.Split(key, keySeparator) + if len(parts) != validEtcdKeyLenForFrontend { + log.GetLogger().Warnf("invalied cluster frontend key:%s, ignore it", key) + return + } + ffr.Lock() + cluster := parts[clusterFrontendClusterIndex] + newIPs := make([]string, 0) + ips := ffr.ClusterFrontends[cluster] + for _, val := range ips { + if val != parts[clusterFrontendIPIndex] { + newIPs = append(newIPs, val) + } + } + if len(newIPs) == 0 { + delete(ffr.ClusterFrontends, cluster) + } else { + ffr.ClusterFrontends[cluster] = newIPs + } + ffr.Unlock() +} diff --git a/yuanrong/pkg/functionscaler/registry/faasmanagerregistry.go b/yuanrong/pkg/functionscaler/registry/faasmanagerregistry.go new file mode 100644 index 0000000..c54b7ce --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/faasmanagerregistry.go @@ -0,0 +1,129 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instance" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/utils" +) + +// FaaSManagerRegistry watches faas manager instance event of etcd +type FaaSManagerRegistry struct { + watcher etcd3.Watcher + subscriberChans []chan SubEvent + stopCh <-chan struct{} + sync.RWMutex +} + +// NewFaaSManagerRegistry will create FaaSManagerRegistry +func NewFaaSManagerRegistry(stopCh <-chan struct{}) *FaaSManagerRegistry { + faasManagerRegistry := &FaaSManagerRegistry{ + stopCh: stopCh, + } + return faasManagerRegistry +} + +func (fr *FaaSManagerRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + fr.watcher = etcd3.NewEtcdWatcher( + instanceEtcdPrefix, + fr.watcherFilter, + fr.watcherHandler, + fr.stopCh, + etcdClient) + fr.watcher.StartList() +} + +// RunWatcher will start etcd watch process for instance event +func (fr *FaaSManagerRegistry) RunWatcher() { + go fr.watcher.StartWatch() +} + +// watcherFilter will filter instance event from etcd event +func (fr *FaaSManagerRegistry) watcherFilter(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != validEtcdKeyLenForInstance { + return true + } + if items[instanceKeyIndex] != "instance" || items[tenantKeyIndex] != "tenant" || + items[functionKeyIndex] != "function" { + return true + } + // also check tenantID + return !utils.IsFaaSManager(items[functionKeyIndex+1]) +} + +// watcherHandler will handle instance event from etcd +func (fr *FaaSManagerRegistry) watcherHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling instance event type %s key %s", event.Type, event.Key) + if event.Type == etcd3.SYNCED { + log.GetLogger().Infof("manager registry ready to receive etcd kv") + return + } + instanceID := instance.GetInstanceIDFromEtcdKey(event.Key) + if len(instanceID) == 0 { + log.GetLogger().Warnf("ignoring invalid etcd key of key %s", event.Key) + return + } + fr.Lock() + defer fr.Unlock() + switch event.Type { + case etcd3.PUT: + insSpec := instance.GetInsSpecFromEtcdValue(event.Key, event.Value) + if insSpec == nil { + log.GetLogger().Warnf("ignoring invalid etcd value of key %s", event.Key) + return + } + insSpec.InstanceID = instanceID + fr.publishEvent(SubEventTypeUpdate, insSpec) + case etcd3.DELETE: + insSpec := instance.GetInsSpecFromEtcdValue(event.Key, event.PrevValue) + if insSpec == nil { + log.GetLogger().Warnf("ignoring invalid etcd value of key %s", event.Key) + return + } + insSpec.InstanceID = instanceID + fr.publishEvent(SubEventTypeDelete, insSpec) + default: + log.GetLogger().Warnf("unsupported event: %#v", event) + } +} + +// addSubscriberChan will add channel, subscribed by FaaSScheduler +func (fr *FaaSManagerRegistry) addSubscriberChan(subChan chan SubEvent) { + fr.Lock() + fr.subscriberChans = append(fr.subscriberChans, subChan) + fr.Unlock() +} + +// publishEvent will publish instance event via channel +func (fr *FaaSManagerRegistry) publishEvent(eventType EventType, insSpec *types.InstanceSpecification) { + for _, subChan := range fr.subscriberChans { + if subChan != nil { + subChan <- SubEvent{ + EventType: eventType, + EventMsg: insSpec, + } + } + } +} diff --git a/yuanrong/pkg/functionscaler/registry/faasschedulerregistry.go b/yuanrong/pkg/functionscaler/registry/faasschedulerregistry.go new file mode 100644 index 0000000..0861429 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/faasschedulerregistry.go @@ -0,0 +1,415 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "os" + "strings" + "sync" + "time" + + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instance" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/urnutils" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/selfregister" +) + +// FaasSchedulerRegistry watches faasscheduler instance event of etcd +type FaasSchedulerRegistry struct { + subscriberChans []chan SubEvent + stopDelayRemove map[string]chan struct{} + FunctionScheduler map[string]*types.InstanceSpecification + ModuleScheduler map[string]*types.InstanceSpecification + functionSchedulerWatcher etcd3.Watcher + moduleSchedulerWatcher etcd3.Watcher + discoveryKeyType string + functionListDoneCh chan struct{} + moduleListDoneCh chan struct{} + stopCh <-chan struct{} + sync.RWMutex +} + +// SchedulerInfo scheduler info +type SchedulerInfo struct { + SchedulerFuncKey string `json:"schedulerFuncKey"` + SchedulerIDList []string `json:"schedulerIDList"` + SchedulerInstanceList []*types.InstanceInfo `json:"schedulerInstanceList"` +} + +// NewFaasSchedulerRegistry will create FaasSchedulerRegistry +func NewFaasSchedulerRegistry(stopCh <-chan struct{}) *FaasSchedulerRegistry { + discoveryKeyType := constant.SchedulerKeyTypeFunction + if config.GlobalConfig.SchedulerDiscovery != nil { + discoveryKeyType = config.GlobalConfig.SchedulerDiscovery.KeyPrefixType + } + faasSchedulerRegistry := &FaasSchedulerRegistry{ + stopDelayRemove: make(map[string]chan struct{}, constant.DefaultMapSize), + FunctionScheduler: make(map[string]*types.InstanceSpecification, constant.DefaultMapSize), + ModuleScheduler: make(map[string]*types.InstanceSpecification, constant.DefaultMapSize), + discoveryKeyType: discoveryKeyType, + functionListDoneCh: make(chan struct{}, 1), + moduleListDoneCh: make(chan struct{}, 1), + stopCh: stopCh, + } + return faasSchedulerRegistry +} + +func (fsr *FaasSchedulerRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + fsr.initFunctionSchedulerWatcher(etcdClient) + fsr.initModuleSchedulerWatcher(etcdClient) + fsr.WaitForETCDList() +} + +func (fsr *FaasSchedulerRegistry) initFunctionSchedulerWatcher(etcdClient *etcd3.EtcdClient) { + fsr.functionSchedulerWatcher = etcd3.NewEtcdWatcher( + instanceEtcdPrefix, + fsr.functionSchedulerFilter, + fsr.functionSchedulerHandler, + fsr.stopCh, + etcdClient) + fsr.functionSchedulerWatcher.StartList() +} + +func (fsr *FaasSchedulerRegistry) initModuleSchedulerWatcher(etcdClient *etcd3.EtcdClient) { + fsr.moduleSchedulerWatcher = etcd3.NewEtcdWatcher( + constant.ModuleSchedulerPrefix, + fsr.moduleSchedulerFilter, + fsr.moduleSchedulerHandler, + fsr.stopCh, + etcdClient) + fsr.moduleSchedulerWatcher.StartList() +} + +// GetFaaSScheduler - +func (fsr *FaasSchedulerRegistry) GetFaaSScheduler(instanceID, keyType string) *types.InstanceSpecification { + fsr.RLock() + var insSpec *types.InstanceSpecification + if keyType == constant.SchedulerKeyTypeFunction { + insSpec = fsr.FunctionScheduler[instanceID] + } else if keyType == constant.SchedulerKeyTypeModule { + insSpec = fsr.ModuleScheduler[instanceID] + } + fsr.RUnlock() + return insSpec +} + +// WaitForETCDList - +func (fsr *FaasSchedulerRegistry) WaitForETCDList() { + log.GetLogger().Infof("start to wait faasscheduler ETCD list") + if fsr.discoveryKeyType == constant.SchedulerKeyTypeModule { + var ( + functionListDone bool + moduleListDone bool + ) + for !functionListDone || !moduleListDone { + select { + case <-fsr.functionListDoneCh: + log.GetLogger().Infof("receive function scheduler list done, stop waiting ETCD list") + functionListDone = true + case <-fsr.moduleListDoneCh: + log.GetLogger().Infof("receive module scheduler list done, stop waiting ETCD list") + moduleListDone = true + case <-fsr.stopCh: + log.GetLogger().Warnf("registry is stopped, stop waiting ETCD list") + return + } + } + } else { + select { + case <-fsr.functionListDoneCh: + log.GetLogger().Infof("receive function scheduler list done, stop waiting ETCD list") + case <-fsr.stopCh: + log.GetLogger().Warnf("registry is stopped, stop waiting ETCD list") + } + } +} + +// RunWatcher - +func (fsr *FaasSchedulerRegistry) RunWatcher() { + go fsr.functionSchedulerWatcher.StartWatch() + go fsr.moduleSchedulerWatcher.StartWatch() +} + +func (fsr *FaasSchedulerRegistry) functionSchedulerFilter(event *etcd3.Event) bool { + return !isFaaSScheduler(event.Key) +} + +func (fsr *FaasSchedulerRegistry) moduleSchedulerFilter(event *etcd3.Event) bool { + return !strings.Contains(event.Key, constant.ModuleSchedulerPrefix) +} + +func (fsr *FaasSchedulerRegistry) moduleSchedulerHandler(event *etcd3.Event) { + log.GetLogger().Infof("module scheduler event type %d received: %+v", event.Type, event.Key) + if event.Type == etcd3.SYNCED { + log.GetLogger().Infof("received module faasscheduler synced event") + fsr.moduleListDoneCh <- struct{}{} + return + } + fsr.handleEvent(event, constant.SchedulerKeyTypeModule) +} + +func (fsr *FaasSchedulerRegistry) functionSchedulerHandler(event *etcd3.Event) { + log.GetLogger().Infof("function scheduler event type %d received: %+v", event.Type, event.Key) + if event.Type == etcd3.SYNCED { + log.GetLogger().Infof("received function faasscheduler synced event") + fsr.functionListDoneCh <- struct{}{} + return + } + fsr.handleEvent(event, constant.SchedulerKeyTypeFunction) +} + +func (fsr *FaasSchedulerRegistry) handleEvent(event *etcd3.Event, keyType string) { + switch event.Type { + case etcd3.PUT: + if keyType == constant.SchedulerKeyTypeModule { + fsr.handleModuleSchedulerUpdate(event) + } + if keyType == constant.SchedulerKeyTypeFunction { + fsr.handleFunctionSchedulerUpdate(event) + } + case etcd3.DELETE: + if keyType == constant.SchedulerKeyTypeModule { + fsr.handleModuleSchedulerRemove(event) + } + if keyType == constant.SchedulerKeyTypeFunction { + fsr.handleFunctionSchedulerRemove(event) + } + default: + log.GetLogger().Warnf("unsupported event type d% for key %s", event.Type, event.Key) + } +} + +// when registerMode is set to registerByContend, the etcd value of module scheduler may be empty if no scheduler locks +// this key +func (fsr *FaasSchedulerRegistry) handleModuleSchedulerUpdate(event *etcd3.Event) { + insSpec := instance.GetInsSpecFromEtcdValue(event.Key, event.Value) + insInfo, err := utils.GetModuleSchedulerInfoFromEtcdKey(event.Key) + if err != nil { + log.GetLogger().Errorf("failed to parse instanceInfo from key %s error %s", event.Key, err.Error()) + return + } + if fsr.discoveryKeyType == constant.SchedulerKeyTypeModule { + exclusivity := "" + if insSpec != nil { + insInfo.InstanceID = insSpec.InstanceID + if insSpec.CreateOptions != nil { + exclusivity = insSpec.CreateOptions[constant.SchedulerExclusivityKey] + } + } + selfregister.GlobalSchedulerProxy.Add(insInfo, exclusivity) + } + fsr.Lock() + fsr.ModuleScheduler[insInfo.InstanceName] = insSpec + fsr.Unlock() + fsr.publishEvent(SubEventTypeUpdate, insSpec) + // 目标的效果是,老版本scheduler退出后, rolloutObject置为false,新的scheduler抢锁,registered置为true + if !selfregister.Registered && insInfo.InstanceName == selfregister.SelfInstanceName && + len(insSpec.InstanceID) == 0 { + selfregister.ReplaceRolloutSubject(fsr.stopCh) + } +} + +func (fsr *FaasSchedulerRegistry) handleModuleSchedulerRemove(event *etcd3.Event) { + insSpec := instance.GetInsSpecFromEtcdValue(event.Key, event.Value) + insInfo, err := utils.GetModuleSchedulerInfoFromEtcdKey(event.Key) + if err != nil { + log.GetLogger().Errorf("failed to parse instanceInfo from key %s error %s", event.Key, err.Error()) + return + } + if fsr.discoveryKeyType == constant.SchedulerKeyTypeModule { + selfregister.GlobalSchedulerProxy.Remove(insInfo) + } + fsr.Lock() + delete(fsr.ModuleScheduler, insInfo.InstanceName) + fsr.Unlock() + fsr.publishEvent(SubEventTypeRemove, insSpec) +} + +func (fsr *FaasSchedulerRegistry) handleFunctionSchedulerUpdate(event *etcd3.Event) { + insSpec := instance.GetInsSpecFromEtcdValue(event.Key, event.Value) + insInfo, err := utils.GetFunctionInstanceInfoFromEtcdKey(event.Key) + if err != nil { + log.GetLogger().Errorf("failed to parse instanceInfo from key %s error %s", event.Key, err.Error()) + return + } + if fsr.discoveryKeyType == constant.SchedulerKeyTypeFunction { + if utils.CheckFaaSSchedulerInstanceFault(insSpec.InstanceStatus) { + go fsr.delFunctionSchedulerFromProxy(insSpec, insInfo) + } else if insSpec.InstanceStatus.Code == int32(constant.KernelInstanceStatusRunning) || + insSpec.InstanceStatus.Code == int32(constant.KernelInstanceStatusCreating) { + fsr.addFunctionSchedulerToProxy(insSpec, insInfo) + } + } + fsr.Lock() + fsr.FunctionScheduler[insInfo.InstanceName] = insSpec + fsr.Unlock() + if insSpec.InstanceID == selfregister.SelfInstanceID { + selfregister.SetSelfInstanceSpec(insSpec) + } +} + +func (fsr *FaasSchedulerRegistry) handleFunctionSchedulerRemove(event *etcd3.Event) { + insSpec := instance.GetInsSpecFromEtcdValue(event.Key, event.Value) + insInfo, err := utils.GetFunctionInstanceInfoFromEtcdKey(event.Key) + if err != nil { + log.GetLogger().Errorf("failed to parse instanceInfo from key %s error %s", event.Key, err.Error()) + return + } + if fsr.discoveryKeyType == constant.SchedulerKeyTypeFunction { + selfregister.GlobalSchedulerProxy.Remove(insInfo) + } + fsr.Lock() + delete(fsr.FunctionScheduler, insInfo.InstanceName) + fsr.Unlock() + fsr.publishEvent(SubEventTypeRemove, insSpec) +} + +func (fsr *FaasSchedulerRegistry) addFunctionSchedulerToProxy(insSpec *types.InstanceSpecification, + info *types.InstanceInfo) { + exclusivity := "" + if insSpec.CreateOptions != nil { + exclusivity = insSpec.CreateOptions[constant.SchedulerExclusivityKey] + } + if !selfregister.GlobalSchedulerProxy.Contains(info.InstanceName) { + selfregister.GlobalSchedulerProxy.Add(info, exclusivity) + } + fsr.Lock() + if stopRemoveCh, ok := fsr.stopDelayRemove[info.InstanceName]; ok { + // 这里表示有scheduler从故障恢复到正常了 + utils.SafeCloseChannel(stopRemoveCh) + delete(fsr.stopDelayRemove, info.InstanceName) + fsr.Unlock() + return + } + selfregister.GlobalSchedulerProxy.Reset() + fsr.Unlock() + fsr.publishEvent(SubEventTypeUpdate, insSpec) +} + +func (fsr *FaasSchedulerRegistry) delFunctionSchedulerFromProxy(insSpec *types.InstanceSpecification, + info *types.InstanceInfo) { + fsr.Lock() + if !selfregister.GlobalSchedulerProxy.Contains(info.InstanceName) { + fsr.Unlock() + return + } + if _, exist := fsr.stopDelayRemove[info.InstanceName]; exist { + fsr.Unlock() + return + } + ch := make(chan struct{}) + fsr.stopDelayRemove[info.InstanceName] = ch + fsr.Unlock() + go func(info *types.InstanceInfo, ch chan struct{}) { + log.GetLogger().Infof("start to delay delete faasscheduler %s", info.InstanceName) + delayRemoveTimer := time.NewTimer(constant.SchedulerRecoverTime) + defer delayRemoveTimer.Stop() + select { + case <-ch: + log.GetLogger().Infof("faasscheduler %s recovered, won't delete hash node", info.InstanceName) + return + case <-delayRemoveTimer.C: + log.GetLogger().Infof("delay timer triggers, deleting faasscheduler %s now", info.InstanceName) + fsr.Lock() + if _, exist := fsr.stopDelayRemove[info.InstanceName]; exist { + delete(fsr.stopDelayRemove, info.InstanceName) + selfregister.GlobalSchedulerProxy.Remove(info) + fsr.Unlock() + fsr.publishEvent(SubEventTypeRemove, insSpec) + } else { + fsr.Unlock() + } + reportOrClearAlarm() + } + }(info, ch) +} + +func reportOrClearAlarm() { + if config.GlobalConfig.AlarmConfig.EnableAlarm { + alarmDetail := &alarm.Detail{ + SourceTag: os.Getenv(constant.PodNameEnvKey) + "|" + os.Getenv(constant.PodIPEnvKey) + + "|" + os.Getenv(constant.ClusterName) + "|FaaSSchedulerHashRingRemoved", + OpType: alarm.GenerateAlarmLog, + Details: "faasscheduler has removed from hash ring", + StartTimestamp: int(time.Now().Unix()), + EndTimestamp: 0, + } + + alarmInfo := &alarm.LogAlarmInfo{ + AlarmID: alarm.FaaSSchedulerRemovedFromHashRing00001, + AlarmName: "FaaSSchedulerRemoved", + AlarmLevel: alarm.Level2, + } + alarm.ReportOrClearAlarm(alarmInfo, alarmDetail) + } +} + +// isFaaSScheduler used to filter the etcd event which stands for a faas scheduler +func isFaaSScheduler(etcdPath string) bool { + info, err := utils.GetFunctionInstanceInfoFromEtcdKey(etcdPath) + if err != nil { + return false + } + return strings.Contains(info.FunctionName, "faasscheduler") +} + +// GetSchedulerInfo return scheduler info +func (fsr *FaasSchedulerRegistry) GetSchedulerInfo() *SchedulerInfo { + schedulerInfo := &SchedulerInfo{} + selfregister.GlobalSchedulerProxy.FaaSSchedulers.Range(func(key, value any) bool { + faasSchedulerID, ok := key.(string) + if !ok { + return true + } + faaSScheduler, ok := value.(*types.InstanceInfo) + if !ok { + return true + } + schedulerInfo.SchedulerIDList = append(schedulerInfo.SchedulerIDList, faasSchedulerID) + schedulerInfo.SchedulerFuncKey = urnutils.CombineFunctionKey(faaSScheduler.TenantID, + faaSScheduler.FunctionName, faaSScheduler.Version) + schedulerInfo.SchedulerInstanceList = append(schedulerInfo.SchedulerInstanceList, faaSScheduler) + return true + }) + return schedulerInfo +} + +// addSubscriberChan will add channel, subscribed by FaaSScheduler +func (fsr *FaasSchedulerRegistry) addSubscriberChan(subChan chan SubEvent) { + fsr.Lock() + fsr.subscriberChans = append(fsr.subscriberChans, subChan) + fsr.Unlock() +} + +// publishEvent will publish instance event via channel +func (fsr *FaasSchedulerRegistry) publishEvent(eventType EventType, insSpec *types.InstanceSpecification) { + for _, subChan := range fsr.subscriberChans { + if subChan != nil { + subChan <- SubEvent{ + EventType: eventType, + EventMsg: insSpec, + } + } + } +} diff --git a/yuanrong/pkg/functionscaler/registry/faasschedulerregistry_test.go b/yuanrong/pkg/functionscaler/registry/faasschedulerregistry_test.go new file mode 100644 index 0000000..2ee3db1 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/faasschedulerregistry_test.go @@ -0,0 +1,252 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "testing" + "time" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/selfregister" +) + +func TestFaasSchedulerRegistryWatcherHandler(t *testing.T) { + fsr := &FaasSchedulerRegistry{ + FunctionScheduler: make(map[string]*commonTypes.InstanceSpecification), + ModuleScheduler: make(map[string]*commonTypes.InstanceSpecification), + functionListDoneCh: make(chan struct{}, 1), + moduleListDoneCh: make(chan struct{}, 1), + stopDelayRemove: make(map[string]chan struct{}), + } + + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasExecutorJava8/version/$latest/defaultaz/task-b23aa1c4-2084-42b8-99b2-8907fa5ae6f4/f71875b1-3c20-4827-8600-0000000005d5", + Value: []byte("123"), + PrevValue: []byte("123"), + Rev: 1, + } + convey.Convey("test discoveryKeyType function", t, func() { + fsr.discoveryKeyType = constant.SchedulerKeyTypeFunction + convey.Convey("etcd opt error", func() { + event.Type = 999 + fsr.functionSchedulerHandler(event) + }) + convey.Convey("etcd put value success", func() { + event.Type = etcd3.PUT + event.Value = []byte(`{ + "instanceID": "1f060613-68af-4a02-8000-000000e077ce", + "instanceStatus": { + "code": 3, + "msg": "running" + }}`) + fsr.functionSchedulerHandler(event) + _, ok := selfregister.GlobalSchedulerProxy.FaaSSchedulers.Load("f71875b1-3c20-4827-8600-0000000005d5") + convey.So(ok, convey.ShouldBeTrue) + }) + convey.Convey("etcd delete value success", func() { + timer := time.NewTimer(1 * time.Second) + defer gomonkey.ApplyFunc(time.NewTimer, func(d time.Duration) *time.Timer { + return timer + }).Reset() + event.Type = etcd3.DELETE + fsr.functionSchedulerHandler(event) + time.Sleep(2 * time.Second) + _, ok := selfregister.GlobalSchedulerProxy.FaaSSchedulers.Load("f71875b1-3c20-4827-8600-0000000005d5") + convey.So(ok, convey.ShouldBeFalse) + }) + convey.Convey("etcd put invalid funcKey", func() { + event.Type = etcd3.PUT + event.Key = "/sn/instance/business/yrk/tenant//function" + fsr.functionSchedulerHandler(event) + _, ok := selfregister.GlobalSchedulerProxy.FaaSSchedulers.Load("f71875b1-3c20-4827-8600-0000000005d5") + convey.So(ok, convey.ShouldBeFalse) + }) + convey.Convey("etcd SYNCED", func() { + event.Type = etcd3.SYNCED + fsr.functionSchedulerHandler(event) + _, ok := selfregister.GlobalSchedulerProxy.FaaSSchedulers.Load("f71875b1-3c20-4827-8600-0000000005d5") + convey.So(ok, convey.ShouldBeFalse) + }) + }) + + convey.Convey("test discoveryKeyType module", t, func() { + defer func() { + selfregister.GlobalSchedulerProxy.FaaSSchedulers.Delete("faas-scheduler-59ddbc4b75-8xdjf") + }() + fsr.discoveryKeyType = constant.SchedulerKeyTypeModule + convey.Convey("etcd put valid funcKey for module scheduler", func() { + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/faas-scheduler/instances/cluster001/7.218.100.25", + } + fsr.moduleSchedulerHandler(event) + _, ok := selfregister.GlobalSchedulerProxy.FaaSSchedulers.Load("faas-scheduler-59ddbc4b75-8xdjf") + convey.So(ok, convey.ShouldBeFalse) + }) + convey.Convey("etcd SYNCED for module scheduler", func() { + event.Type = etcd3.SYNCED + fsr.moduleSchedulerHandler(event) + _, ok := selfregister.GlobalSchedulerProxy.FaaSSchedulers.Load("faas-scheduler-59ddbc4b75-8xdjf") + convey.So(ok, convey.ShouldBeFalse) + }) + convey.Convey("etcd put invalid funcKey for module scheduler", func() { + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/faas-scheduler/instances/cluster001/7.218.100.25/faas-scheduler-59ddbc4b75-8xdjf", + Value: []byte(`{ + "instanceID": "1f060613-68af-4a02-8000-000000e077ce", + "instanceStatus": { + "code": 3, + "msg": "running" + }}`), + PrevValue: []byte("123"), + Rev: 1, + } + fsr.moduleSchedulerHandler(event) + _, ok := selfregister.GlobalSchedulerProxy.FaaSSchedulers.Load("faas-scheduler-59ddbc4b75-8xdjf") + convey.So(ok, convey.ShouldBeTrue) + }) + }) +} + +func TestFaaSSchedulerRegistryWatcherHandlerDelayDelete(t *testing.T) { + fsr := &FaasSchedulerRegistry{ + FunctionScheduler: make(map[string]*commonTypes.InstanceSpecification), + ModuleScheduler: make(map[string]*commonTypes.InstanceSpecification), + functionListDoneCh: make(chan struct{}, 1), + moduleListDoneCh: make(chan struct{}, 1), + stopDelayRemove: make(map[string]chan struct{}), + } + event := &etcd3.Event{ + Type: etcd3.PUT, + PrevValue: []byte("123"), + Rev: 1, + } + + convey.Convey("test delete discoveryKeyType function", t, func() { + fsr.discoveryKeyType = constant.SchedulerKeyTypeFunction + instanceID := "1f060613-68af-4a02-8000-000000e077ce" + functionSchedulerKey := "/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasExecutorJava8/version/$latest/defaultaz/task-b23aa1c4-2084-42b8-99b2-8907fa5ae6f4/1f060613-68af-4a02-8000-000000e077ce" + event.Type = etcd3.PUT + event.Key = functionSchedulerKey + event.Value = []byte(`{ + "instanceID": "1f060613-68af-4a02-8000-000000e077ce", + "instanceStatus": { + "code": 3, + "msg": "running" + }}`) + fsr.functionSchedulerHandler(event) + convey.So(fsr.FunctionScheduler[instanceID], convey.ShouldNotBeNil) + event.Type = etcd3.DELETE + event.Key = "invalid key" + fsr.functionSchedulerHandler(event) + convey.So(fsr.FunctionScheduler[instanceID], convey.ShouldNotBeNil) + event.Type = etcd3.DELETE + event.Key = functionSchedulerKey + fsr.functionSchedulerHandler(event) + convey.So(fsr.FunctionScheduler[instanceID], convey.ShouldBeNil) + }) + convey.Convey("test delete discoveryKeyType module", t, func() { + fsr.discoveryKeyType = constant.SchedulerKeyTypeModule + instanceID := "1f060613-68af-4a02-8000-000000e077ce" + ModuleSchedulerKey := "/sn/faas-scheduler/instances/cluster001/node001/1f060613-68af-4a02-8000-000000e077ce" + event.Type = etcd3.PUT + event.Key = ModuleSchedulerKey + event.Value = []byte(`{ + "instanceID": "1f060613-68af-4a02-8000-000000e077ce", + "instanceStatus": { + "code": 3, + "msg": "running" + }}`) + fsr.moduleSchedulerHandler(event) + convey.So(fsr.ModuleScheduler[instanceID], convey.ShouldNotBeNil) + event.Type = etcd3.DELETE + event.Key = "invalid key" + fsr.moduleSchedulerHandler(event) + convey.So(fsr.ModuleScheduler[instanceID], convey.ShouldNotBeNil) + event.Type = etcd3.DELETE + event.Key = ModuleSchedulerKey + fsr.moduleSchedulerHandler(event) + convey.So(fsr.ModuleScheduler[instanceID], convey.ShouldBeNil) + }) +} + +func TestDelFunctionSchedulerFromProxy(t *testing.T) { + fsr := &FaasSchedulerRegistry{ + FunctionScheduler: make(map[string]*commonTypes.InstanceSpecification), + ModuleScheduler: make(map[string]*commonTypes.InstanceSpecification), + functionListDoneCh: make(chan struct{}, 1), + moduleListDoneCh: make(chan struct{}, 1), + stopDelayRemove: make(map[string]chan struct{}), + } + insSpec := &commonTypes.InstanceSpecification{InstanceID: "instance1"} + insInfo := &commonTypes.InstanceInfo{InstanceID: "instance1", InstanceName: "insName1"} + config.GlobalConfig.AlarmConfig.EnableAlarm = true + defer func() { + selfregister.GlobalSchedulerProxy.FaaSSchedulers.Delete("insName1") + config.GlobalConfig.AlarmConfig.EnableAlarm = false + }() + mockTimer := time.NewTimer(200 * time.Millisecond) + defer gomonkey.ApplyFunc(time.NewTimer, func(d time.Duration) *time.Timer { + return mockTimer + }).Reset() + convey.Convey("Test DelFunctionSchedulerFromProxy", t, func() { + fsr.delFunctionSchedulerFromProxy(insSpec, insInfo) + convey.So(fsr.stopDelayRemove["insName1"], convey.ShouldBeNil) + selfregister.GlobalSchedulerProxy.FaaSSchedulers.Store("insName1", insInfo) + fsr.delFunctionSchedulerFromProxy(insSpec, insInfo) + convey.So(fsr.stopDelayRemove["insName1"], convey.ShouldNotBeNil) + time.Sleep(500 * time.Millisecond) + convey.So(fsr.stopDelayRemove["insName1"], convey.ShouldBeNil) + }) +} + +func TestWaitForETCDList(t *testing.T) { + fsr := &FaasSchedulerRegistry{ + functionListDoneCh: make(chan struct{}, 1), + moduleListDoneCh: make(chan struct{}, 1), + stopCh: make(chan struct{}, 1), + } + convey.Convey("Test WaitForETCDList", t, func() { + fsr.discoveryKeyType = constant.SchedulerKeyTypeFunction + go func() { + time.Sleep(100 * time.Millisecond) + close(fsr.functionListDoneCh) + }() + fsr.WaitForETCDList() + _, ok := <-fsr.functionListDoneCh + convey.So(ok, convey.ShouldBeFalse) + + fsr.discoveryKeyType = constant.SchedulerKeyTypeModule + fsr.functionListDoneCh = make(chan struct{}) + go func() { + time.Sleep(100 * time.Millisecond) + close(fsr.functionListDoneCh) + time.Sleep(100 * time.Millisecond) + close(fsr.moduleListDoneCh) + }() + fsr.WaitForETCDList() + _, ok = <-fsr.moduleListDoneCh + convey.So(ok, convey.ShouldBeFalse) + }) +} diff --git a/yuanrong/pkg/functionscaler/registry/functionavailableregistry.go b/yuanrong/pkg/functionscaler/registry/functionavailableregistry.go new file mode 100644 index 0000000..54fe5ac --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/functionavailableregistry.go @@ -0,0 +1,121 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "encoding/json" + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/utils" +) + +// FunctionAvailableRegistry watches instance event of etcd +type FunctionAvailableRegistry struct { + watcher etcd3.Watcher + FuncAvailableClusters map[string][]string // function available clusters + stopCh <-chan struct{} + sync.RWMutex +} + +// NewFunctionAvailableRegistry will create FunctionAvailableRegistry +func NewFunctionAvailableRegistry(stopCh <-chan struct{}) *FunctionAvailableRegistry { + functionAvailableRegistry := &FunctionAvailableRegistry{ + FuncAvailableClusters: make(map[string][]string, utils.DefaultMapSize), + stopCh: stopCh, + } + return functionAvailableRegistry +} + +func (cr *FunctionAvailableRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + cr.watcher = etcd3.NewEtcdWatcher( + constant.FunctionAvailClusterPrefix, + cr.watcherFilter, + cr.watcherHandler, + cr.stopCh, + etcdClient) + cr.watcher.StartList() +} + +// RunWatcher will start etcd watch process for function available clusters event +func (cr *FunctionAvailableRegistry) RunWatcher() { + go cr.watcher.StartWatch() +} + +// GeClusters - +func (cr *FunctionAvailableRegistry) GeClusters(cluster string) []string { + cr.RLock() + clusters := cr.FuncAvailableClusters[cluster] + cr.RUnlock() + return clusters +} + +// watcherFilter will filter function available clusters event from etcd event +func (cr *FunctionAvailableRegistry) watcherFilter(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != validEtcdKeyLenForCluster { + return true + } + return false +} + +// watcherHandler will handle function available clusters event from etcd +func (cr *FunctionAvailableRegistry) watcherHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling function available clusters event type %s key %s", event.Type, event.Key) + clusters, funcURN := GetFuncAvailableClusterFromEtcd(event) + if clusters == nil { + log.GetLogger().Warnf("ignoring invalid etcd value of key %s", event.Key) + return + } + switch event.Type { + case etcd3.PUT: + cr.Lock() + cr.FuncAvailableClusters[funcURN] = clusters + cr.Unlock() + case etcd3.DELETE: + cr.Lock() + delete(cr.FuncAvailableClusters, funcURN) + cr.Unlock() + case etcd3.ERROR: + log.GetLogger().Warnf("etcd error event: %s", event.Value) + default: + log.GetLogger().Warnf("unsupported event, key: %s", event.Key) + } +} + +// GetFuncAvailableClusterFromEtcd get function available clusters from etcd +func GetFuncAvailableClusterFromEtcd(ev *etcd3.Event) ([]string, string) { + clusters := make([]string, 0) + err := json.Unmarshal(ev.Value, &clusters) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal function cluster update event: %s, reason: %s", ev.Value, err) + return nil, "" + } + + key := ev.Key + parts := strings.Split(key, constant.ETCDEventKeySeparator) + if len(parts) != validEtcdKeyLenForCluster { + log.GetLogger().Warnf("invalied function cluster key:%s, ignore it", key) + return nil, "" + } + + return clusters, parts[functionClusterKeyIdx] +} diff --git a/yuanrong/pkg/functionscaler/registry/functionregistry.go b/yuanrong/pkg/functionscaler/registry/functionregistry.go new file mode 100644 index 0000000..303a7a6 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/functionregistry.go @@ -0,0 +1,274 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "context" + "fmt" + "strings" + "sync" + + "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + commonTypes "yuanrong/pkg/common/faas_common/types" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +// FunctionRegistry watches function event of etcd +type FunctionRegistry struct { + userAgencyRegistry *UserAgencyRegistry + watcher etcd3.Watcher + funcSpecs map[string]*types.FunctionSpecification + subscriberChans []chan SubEvent + listDoneCh chan struct{} + stopCh <-chan struct{} + sync.RWMutex +} + +// NewFunctionRegistry will create FunctionRegistry +func NewFunctionRegistry(stopCh <-chan struct{}) *FunctionRegistry { + functionRegistry := &FunctionRegistry{ + userAgencyRegistry: NewUserAgencyRegistry(stopCh), + funcSpecs: make(map[string]*types.FunctionSpecification, utils.DefaultMapSize), + listDoneCh: make(chan struct{}, 1), + stopCh: stopCh, + } + return functionRegistry +} + +func (fr *FunctionRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + fr.watcher = etcd3.NewEtcdWatcher( + functionEtcdPrefix, + fr.watcherFilter, + fr.watcherHandler, + fr.stopCh, + etcdClient) + fr.watcher.StartList() + fr.WaitForETCDList() + fr.userAgencyRegistry.initWatcher(etcdClient) +} + +// WaitForETCDList - +func (fr *FunctionRegistry) WaitForETCDList() { + log.GetLogger().Infof("start to wait function ETCD list") + select { + case <-fr.listDoneCh: + log.GetLogger().Infof("receive list done, stop waiting ETCD list") + case <-fr.stopCh: + log.GetLogger().Warnf("registry is stopped, stop waiting ETCD list") + } +} + +// RunWatcher will start etcd watch process +func (fr *FunctionRegistry) RunWatcher() { + go fr.watcher.StartWatch() + fr.userAgencyRegistry.RunWatcher() +} + +func (fr *FunctionRegistry) getFuncSpec(funcKey string) *types.FunctionSpecification { + fr.RLock() + funcSpec := fr.funcSpecs[funcKey] + fr.RUnlock() + return funcSpec +} +func (fr *FunctionRegistry) fetchSilentFuncSpec(funcKey string) *types.FunctionSpecification { + tenantID, funcName, funcVersion := commonUtils.ParseFuncKey(funcKey) + silentEtcdKey := fmt.Sprintf(constant.SilentFuncKey, tenantID, funcName, funcVersion) + etcdValue, err := etcd3.GetValueFromEtcdWithRetry(silentEtcdKey, etcd3.GetMetaEtcdClient()) + if err != nil { + log.GetLogger().Errorf("failed to get silent function, error: %s", err.Error()) + return nil + } + metaEtcdKey := fmt.Sprintf(constant.MetaFuncKey, tenantID, funcName, funcVersion) + fr.Lock() + defer fr.Unlock() + funcSpec := fr.buildFuncSpec(metaEtcdKey, etcdValue, funcKey) + if funcSpec == nil { + return nil + } + fr.funcSpecs[funcKey] = funcSpec + log.GetLogger().Infof("get silent function success,funcKey :%s", funcKey) + return funcSpec +} + +func (fr *FunctionRegistry) watcherFilter(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != validEtcdKeyLenForFunction { + return true + } + if items[instanceKeyIndex] != "functions" || items[tenantKeyIndex] != "tenant" || + items[functionKeyIndex] != "function" { + return true + } + + return false +} + +func (fr *FunctionRegistry) watcherHandler(event *etcd3.Event) { + fr.Lock() + defer fr.Unlock() + log.GetLogger().Infof("handling function event type %s key %s", event.Type, event.Key) + if event.Type == etcd3.SYNCED { + log.GetLogger().Infof("publish function synced event") + fr.publishEvent(SubEventTypeSynced, &types.FunctionSpecification{}) + return + } + funcKey := GetFuncKeyFromEtcdKey(event.Key) + if len(funcKey) == 0 { + log.GetLogger().Warnf("ignoring invalid etcd key of key %s", event.Key) + return + } + switch event.Type { + case etcd3.PUT: + funcSpec := fr.buildFuncSpec(event.Key, event.Value, funcKey) + if funcSpec == nil { + return + } + fr.funcSpecs[funcKey] = funcSpec + log.GetLogger().Infof("get function key :%s", funcKey) + fr.publishEvent(SubEventTypeUpdate, funcSpec) + case etcd3.DELETE: + funcSpec, exist := fr.funcSpecs[funcKey] + if !exist { + log.GetLogger().Errorf("function %s doesn't exist in registry", funcKey) + return + } + funcSpec.CancelFunc() + delete(fr.funcSpecs, funcKey) + fr.publishEvent(SubEventTypeDelete, funcSpec) + default: + log.GetLogger().Warnf("unsupported event: %v", event.Type) + } +} + +// buildFuncSpec without lock should lock outside +func (fr *FunctionRegistry) buildFuncSpec(etcdKey string, etcdValue []byte, + funcKey string) *types.FunctionSpecification { + funcMetaInfo := GetFuncMetaInfoFromEtcdValue(etcdValue) + if funcMetaInfo == nil { + log.GetLogger().Errorf("ignoring invalid etcd value of key %s", etcdKey) + return nil + } + funcMetaInfo.ExtendedMetaData.UserAgency = fr.userAgencyRegistry.GetUserAgencyByFuncMeta(funcMetaInfo) + funcSpec := fr.createOrUpdateFuncSpec(funcKey, funcMetaInfo) + funcSpec.FuncSecretName = utils.GenerateStsSecretName(etcdKey) + return funcSpec +} + +func (fr *FunctionRegistry) createOrUpdateFuncSpec(funcKey string, + funcMetaInfo *commonTypes.FunctionMetaInfo) *types.FunctionSpecification { + commonUtils.SetFuncMetaDynamicConfEnable(funcMetaInfo) + funcSpec, exist := fr.funcSpecs[funcKey] + if !exist { + funcCtx, cancelFunc := context.WithCancel(context.TODO()) + funcSpec = &types.FunctionSpecification{ + FuncCtx: funcCtx, + CancelFunc: cancelFunc, + FuncKey: funcKey, + FuncMetaSignature: commonUtils.GetFuncMetaSignature(funcMetaInfo, + config.GlobalConfig.RawStsConfig.StsEnable), + FuncMetaData: funcMetaInfo.FuncMetaData, + S3MetaData: funcMetaInfo.S3MetaData, + CodeMetaData: funcMetaInfo.CodeMetaData, + EnvMetaData: funcMetaInfo.EnvMetaData, + StsMetaData: funcMetaInfo.StsMetaData, + ResourceMetaData: funcMetaInfo.ResourceMetaData, + InstanceMetaData: funcMetaInfo.InstanceMetaData, + ExtendedMetaData: funcMetaInfo.ExtendedMetaData, + } + } else { + funcSpec.FuncMetaSignature = commonUtils.GetFuncMetaSignature(funcMetaInfo, + config.GlobalConfig.RawStsConfig.StsEnable) + funcSpec.FuncMetaData = funcMetaInfo.FuncMetaData + funcSpec.S3MetaData = funcMetaInfo.S3MetaData + funcSpec.CodeMetaData = funcMetaInfo.CodeMetaData + funcSpec.EnvMetaData = funcMetaInfo.EnvMetaData + funcSpec.StsMetaData = funcMetaInfo.StsMetaData + funcSpec.ResourceMetaData = funcMetaInfo.ResourceMetaData + funcSpec.InstanceMetaData = funcMetaInfo.InstanceMetaData + funcSpec.ExtendedMetaData = funcMetaInfo.ExtendedMetaData + } + return funcSpec +} + +func (fr *FunctionRegistry) addSubscriberChan(subChan chan SubEvent) { + fr.Lock() + fr.subscriberChans = append(fr.subscriberChans, subChan) + fr.Unlock() +} + +func (fr *FunctionRegistry) publishEvent(eventType EventType, funcSpec *types.FunctionSpecification) { + for _, subChan := range fr.subscriberChans { + if subChan != nil { + subChan <- SubEvent{ + EventType: eventType, + EventMsg: funcSpec, + } + } + } +} + +// FinishEtcdList - +func (fr *FunctionRegistry) FinishEtcdList() { + log.GetLogger().Infof("received function synced event") + fr.listDoneCh <- struct{}{} + return +} + +// EtcdList - +func (fr *FunctionRegistry) EtcdList() []*types.FunctionSpecification { + client := etcd3.GetMetaEtcdClient() + if client == nil { + return nil + } + ctx, cancel := context.WithTimeout(context.Background(), etcd3.DurationContextTimeout) + defer cancel() + res, err := client.Client.Get(ctx, functionEtcdPrefix, clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("get function meta failed, error: %v", err) + return nil + } + var result []*types.FunctionSpecification + for _, kv := range res.Kvs { + e := &etcd3.Event{ + Key: string(kv.Key), + Value: kv.Value, + } + if fr.watcherFilter(e) { + continue + } + funcKey := GetFuncKeyFromEtcdKey(e.Key) + if len(funcKey) == 0 { + log.GetLogger().Warnf("ignoring invalid etcd key of key %s", e.Key) + continue + } + funcSpec := fr.buildFuncSpec(e.Key, e.Value, funcKey) + if funcSpec == nil { + continue + } + fr.funcSpecs[funcKey] = funcSpec + result = append(result, funcSpec) + } + return result +} diff --git a/yuanrong/pkg/functionscaler/registry/instanceconfigregistry.go b/yuanrong/pkg/functionscaler/registry/instanceconfigregistry.go new file mode 100644 index 0000000..42e0240 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/instanceconfigregistry.go @@ -0,0 +1,129 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "sync" + + "go.uber.org/zap" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/config" +) + +// InstanceConfigRegistry watches /instances event of etcd +type InstanceConfigRegistry struct { + watcher etcd3.Watcher + subscriberChans []chan SubEvent + listDoneCh chan struct{} + stopCh <-chan struct{} + sync.RWMutex +} + +// NewInstanceConfigRegistry will create InstanceConfigRegistry +func NewInstanceConfigRegistry(stopCh <-chan struct{}) *InstanceConfigRegistry { + instanceConfigRegistry := &InstanceConfigRegistry{ + listDoneCh: make(chan struct{}, 1), + stopCh: stopCh, + } + return instanceConfigRegistry +} + +func (ifr *InstanceConfigRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + ifr.watcher = etcd3.NewEtcdWatcher( + instanceconfig.InsConfigEtcdPrefix, + instanceconfig.GetWatcherFilter(config.GlobalConfig.ClusterID), + ifr.watcherHandler, + ifr.stopCh, + etcdClient) + ifr.watcher.StartList() + ifr.WaitForETCDList() +} + +// WaitForETCDList - +func (ifr *InstanceConfigRegistry) WaitForETCDList() { + log.GetLogger().Infof("start to wait instance config ETCD list") + select { + case <-ifr.listDoneCh: + log.GetLogger().Infof("receive list done, stop waiting ETCD list") + case <-ifr.stopCh: + log.GetLogger().Warnf("registry is stopped, stop waiting ETCD list") + } +} + +// RunWatcher will start etcd watch process for instance event +func (ifr *InstanceConfigRegistry) RunWatcher() { + go ifr.watcher.StartWatch() +} + +func parseInstanceConfig(event *etcd3.Event) (*instanceconfig.Configuration, error) { + value := event.Value + if event.Type == etcd3.DELETE || event.Type == etcd3.HISTORYDELETE { + value = event.PrevValue + } + return instanceconfig.ParseInstanceConfigFromEtcdEvent(event.Key, value) +} + +// watcherHandler will handle instance event from etcd +func (ifr *InstanceConfigRegistry) watcherHandler(event *etcd3.Event) { + logger := log.GetLogger().With(zap.Any("eventType", event.Type), zap.Any("eventKey", event.Key)) + logger.Infof("handling instances info") + if event.Type == etcd3.SYNCED { + logger.Infof("received instance config synced event") + ifr.listDoneCh <- struct{}{} + ifr.publishEvent(SubEventTypeSynced, &instanceconfig.Configuration{}) + return + } + + insSpec, err := parseInstanceConfig(event) + if err != nil { + log.GetLogger().Warnf("ParseInstanceConfigFromEtcdEvent failed, eventvalue is %s, err: %s", + string(event.Value), err.Error()) + return + } + + switch event.Type { + case etcd3.PUT, etcd3.HISTORYUPDATE: + ifr.publishEvent(SubEventTypeUpdate, insSpec) + case etcd3.DELETE, etcd3.HISTORYDELETE: + ifr.publishEvent(SubEventTypeDelete, insSpec) + default: + logger.Warnf("unsupported event") + } +} + +// addSubscriberChan will add channel, subscribed by FaaSScheduler +func (ifr *InstanceConfigRegistry) addSubscriberChan(subChan chan SubEvent) { + ifr.Lock() + ifr.subscriberChans = append(ifr.subscriberChans, subChan) + ifr.Unlock() +} + +// publishEvent will publish instance event via channel +func (ifr *InstanceConfigRegistry) publishEvent(eventType EventType, insConfig *instanceconfig.Configuration) { + for _, subChan := range ifr.subscriberChans { + if subChan != nil { + subChan <- SubEvent{ + EventType: eventType, + EventMsg: insConfig, + } + } + } +} diff --git a/yuanrong/pkg/functionscaler/registry/instanceregistry.go b/yuanrong/pkg/functionscaler/registry/instanceregistry.go new file mode 100644 index 0000000..bed2e13 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/instanceregistry.go @@ -0,0 +1,284 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "context" + "strings" + "sync" + + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instance" + "yuanrong/pkg/common/faas_common/logger/log" + commonTypes "yuanrong/pkg/common/faas_common/types" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/state" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +// InstanceRegistry watches instance event of etcd +type InstanceRegistry struct { + watcher etcd3.Watcher + fgWatcher etcd3.Watcher + InstanceIDMap map[string]*types.Instance + functionInstanceIDMap map[string]map[string]*commonTypes.InstanceSpecification + subscriberChans []chan SubEvent + listDoneCh chan struct{} + fgListDoneCh chan struct{} + stopCh <-chan struct{} + once sync.Once + sync.RWMutex +} + +type getInstanceSpecFunc func(etcdKey string, etcdValue []byte) *commonTypes.InstanceSpecification + +// NewInstanceRegistry will create InstanceRegistry +func NewInstanceRegistry(stopCh <-chan struct{}) *InstanceRegistry { + instanceRegistry := &InstanceRegistry{ + InstanceIDMap: make(map[string]*types.Instance, utils.DefaultMapSize), + functionInstanceIDMap: make(map[string]map[string]*commonTypes.InstanceSpecification, utils.DefaultMapSize), + listDoneCh: make(chan struct{}, 1), + fgListDoneCh: make(chan struct{}, 1), + stopCh: stopCh, + } + return instanceRegistry +} + +// GetInstance - +func (ir *InstanceRegistry) GetInstance(instanceID string) *types.Instance { + ir.RLock() + instance := ir.InstanceIDMap[instanceID] + ir.RUnlock() + return instance +} + +// GetFunctionInstanceIDMap - +func (ir *InstanceRegistry) GetFunctionInstanceIDMap() map[string]map[string]*commonTypes.InstanceSpecification { + ir.Lock() + defer ir.Unlock() + var idMap map[string]map[string]*commonTypes.InstanceSpecification + commonUtils.DeepCopyObj(ir.functionInstanceIDMap, &idMap) + return idMap +} + +func (ir *InstanceRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + ir.watcher = etcd3.NewEtcdWatcher( + instanceEtcdPrefix, + ir.watcherFilter, + ir.watcherHandler, + ir.stopCh, + etcdClient) + ir.watcher.StartList() + ir.fgWatcher = etcd3.NewEtcdWatcher( + workersEtcdPrefix, + ir.watcherFGFilter, + ir.watcherFGHandler, + ir.stopCh, + etcdClient) + ir.fgWatcher.StartList() + ir.WaitForETCDList() +} + +// WaitForETCDList while recovering, must get all instance including running and creating +func (ir *InstanceRegistry) WaitForETCDList() { + log.GetLogger().Infof("start to wait instance ETCD list") + select { + case <-ir.listDoneCh: + log.GetLogger().Infof("receive list done, stop waiting ETCD list") + case <-ir.fgListDoneCh: + log.GetLogger().Infof("receive fg list done, stop waiting ETCD list") + case <-ir.stopCh: + log.GetLogger().Warnf("registry is stopped, stop waiting ETCD list") + } +} + +// RunWatcher will start etcd watch process for instance event +func (ir *InstanceRegistry) RunWatcher() { + go ir.watcher.StartWatch() + go ir.fgWatcher.StartWatch() +} + +// watcherFilter will filter instance event from etcd event +func (ir *InstanceRegistry) watcherFilter(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != validEtcdKeyLenForInstance { + return true + } + if items[instanceKeyIndex] != "instance" || items[tenantKeyIndex] != "tenant" || + items[functionKeyIndex] != "function" || + !strings.HasPrefix(items[executorKeyIndex], "0-system-faasExecutor") && + !strings.HasPrefix(items[executorKeyIndex], "0-system-serveExecutor") { + return true + } + return false +} + +// watcherHandler will handle instance event from etcd +func (ir *InstanceRegistry) watcherHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling instance event type %s key %s", event.Type, event.Key) + if event.Type == etcd3.SYNCED { + log.GetLogger().Infof("received instance synced event") + ir.listDoneCh <- struct{}{} + return + } + instanceID := instance.GetInstanceIDFromEtcdKey(event.Key) + if len(instanceID) == 0 { + log.GetLogger().Warnf("ignoring invalid etcd key of key %s", event.Key) + return + } + ir.handleEtcdEvent(event, instanceID, instance.GetInsSpecFromEtcdValue) +} + +func (ir *InstanceRegistry) handleEtcdEvent(event *etcd3.Event, instanceID string, + getInstanceSpec getInstanceSpecFunc) { + ir.Lock() + defer ir.Unlock() + switch event.Type { + case etcd3.PUT, etcd3.HISTORYUPDATE: + insSpec := getInstanceSpec(event.Key, event.Value) + if insSpec == nil { + log.GetLogger().Warnf("ignoring invalid etcd value of key %s", event.Key) + return + } + ir.InstanceIDMap[insSpec.InstanceID] = utils.BuildInstanceFromInsSpec(insSpec, nil) + functionKey := insSpec.CreateOptions[types.FunctionKeyNote] + if functionKey == "" { + log.GetLogger().Warnf("ignoring invalid instance meta data, function is empty") + return + } + if _, ok := ir.functionInstanceIDMap[functionKey]; !ok { + ir.functionInstanceIDMap[functionKey] = make(map[string]*commonTypes.InstanceSpecification, + utils.DefaultMapSize) + } + ir.functionInstanceIDMap[functionKey][instanceID] = insSpec + if insSpec.CreateOptions[types.SchedulerIDNote] != selfregister.SelfInstanceID { + log.GetLogger().Warnf( + "carefully, instance[%s][%s] dose not created by this faaSScheduler[%s]", instanceID, + insSpec.CreateOptions[types.SchedulerIDNote], selfregister.SelfInstanceID) + } + insSpec.InstanceID = instanceID + if event.Type == etcd3.HISTORYUPDATE { + ir.publishEvent(SubEventTypeUpdate, insSpec) + return + } + log.GetLogger().Infof("put instance event, instanceId: %s, instanceStatus: %v", instanceID, + insSpec.InstanceStatus) + ir.publishEvent(SubEventTypeUpdate, insSpec) + state.Update(int32(event.Rev)) + case etcd3.DELETE, etcd3.HISTORYDELETE: + insSpec := getInstanceSpec(event.Key, event.PrevValue) + if insSpec == nil { + log.GetLogger().Warnf("ignoring invalid etcd value of key %s", event.Key) + return + } + delete(ir.InstanceIDMap, insSpec.InstanceID) + functionKey := insSpec.CreateOptions[types.FunctionKeyNote] + if functionKey == "" { + log.GetLogger().Warnf("ignoring invalid instance meta data, function is empty") + return + } + if instanceIDMap, ok := ir.functionInstanceIDMap[functionKey]; ok { + delete(instanceIDMap, instanceID) + if len(instanceIDMap) == 0 { + delete(ir.functionInstanceIDMap, functionKey) + } + } else { + log.GetLogger().Warnf("no instances of function %s exist", functionKey) + return + } + if insSpec.CreateOptions[types.SchedulerIDNote] != selfregister.SelfInstanceID { + log.GetLogger().Warnf( + "carefully, instance[%s] dose not created by this faaSScheduler[%s]", instanceID, + selfregister.SelfInstanceID) + } + insSpec.InstanceID = instanceID + ir.publishEvent(SubEventTypeDelete, insSpec) + state.Update(int32(event.Rev)) + default: + log.GetLogger().Warnf("unsupported event: %#v", event) + } +} + +// addSubscriberChan will add channel, subscribed by FaaSScheduler +func (ir *InstanceRegistry) addSubscriberChan(subChan chan SubEvent) { + ir.Lock() + ir.subscriberChans = append(ir.subscriberChans, subChan) + ir.Unlock() +} + +// publishEvent will publish instance event via channel +func (ir *InstanceRegistry) publishEvent(eventType EventType, insSpec *commonTypes.InstanceSpecification) { + for _, subChan := range ir.subscriberChans { + if subChan != nil { + subChan <- SubEvent{ + EventType: eventType, + EventMsg: insSpec, + } + } + } +} + +// EtcdList - +func (ir *InstanceRegistry) EtcdList() []*commonTypes.InstanceSpecification { + client := etcd3.GetRouterEtcdClient() + if client == nil { + return nil + } + ctx, cancel := context.WithTimeout(context.Background(), etcd3.DurationContextTimeout) + etcdCtx := etcd3.EtcdCtxInfo{ + Ctx: ctx, + Cancel: cancel, + } + res, err := client.Get(etcdCtx, instanceEtcdPrefix, clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("get function meta failed, error: %v", err) + return nil + } + var result []*commonTypes.InstanceSpecification + for _, kv := range res.Kvs { + e := &etcd3.Event{ + Key: string(kv.Key), + Value: kv.Value, + } + if ir.watcherFilter(e) { + continue + } + instanceID := instance.GetInstanceIDFromEtcdKey(e.Key) + if len(instanceID) == 0 { + log.GetLogger().Warnf("ignoring invalid etcd key of key %s", e.Key) + continue + } + insSpec := instance.GetInsSpecFromEtcdValue(e.Key, e.Value) + if insSpec == nil { + log.GetLogger().Warnf("ignoring invalid etcd value of key %s", e.Key) + continue + } + functionKey := insSpec.CreateOptions[types.FunctionKeyNote] + if functionKey == "" { + log.GetLogger().Warnf("ignoring invalid instance meta data, function is empty") + continue + } + insSpec.InstanceID = instanceID + result = append(result, insSpec) + } + return result +} diff --git a/yuanrong/pkg/functionscaler/registry/instanceregistry_fg.go b/yuanrong/pkg/functionscaler/registry/instanceregistry_fg.go new file mode 100644 index 0000000..8ca26af --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/instanceregistry_fg.go @@ -0,0 +1,172 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "encoding/json" + "strconv" + "strings" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/types" +) + +const ( + workersEtcdPrefix = "/sn/workers" + validEtcdKeyLenForWorkers = 13 + versionKeyIndex = 9 + instanceIDFGValueIndex = 12 + workersKeyIndex = 2 +) + +// watcherFGFilter will filter version FG instance event from etcd event +func (ir *InstanceRegistry) watcherFGFilter(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != validEtcdKeyLenForWorkers { + return true + } + if items[workersKeyIndex] != "workers" || items[tenantKeyIndex] != "tenant" || + items[functionKeyIndex] != "function" || items[versionKeyIndex] != "version" || + !strings.HasPrefix(items[executorKeyIndex], "0@") { + return true + } + return false +} + +// watcherFGHandler will handle FG instance event from etcd +func (ir *InstanceRegistry) watcherFGHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling instance event type %s key %s", event.Type, event.Key) + if event.Type == etcd3.SYNCED { + log.GetLogger().Infof("instance registry ready to receive etcd kv") + ir.fgListDoneCh <- struct{}{} + return + } + instanceID := GetInstanceIDFromFGEtcdKey(event.Key) + if len(instanceID) == 0 { + log.GetLogger().Warnf("ignoring invalid etcd key of key %s", event.Key) + return + } + ir.handleEtcdEvent(event, instanceID, GetInsSpecFromEtcdKVForFG) +} + +// GetInsSpecFromEtcdKVForFG gets InstanceSpecification from etcd key and value of instance +func GetInsSpecFromEtcdKVForFG(etcdKey string, etcdValue []byte) *commonTypes.InstanceSpecification { + insSpecFG := &commonTypes.InstanceSpecificationFG{} + insSpec := &commonTypes.InstanceSpecification{} + // extract fields from etcd key + keyFeilds := strings.Split(etcdKey, keySeparator) + if len(keyFeilds) != validEtcdKeyLenForWorkers { + log.GetLogger().Errorf("the etcdKey length doesn't match FG vesrsion!") + return nil + } + tenantID := keyFeilds[6] + functionName := keyFeilds[8] + version := keyFeilds[10] + insSpec.Function = tenantID + "/" + functionName + "/" + version + if tenantID == "" || functionName == "" { + log.GetLogger().Errorf("failed to get tenatID or functionName!,etcdKey %s", etcdKey) + return nil + } + insSpec.Function = tenantID + "/" + functionName + "/" + version + // extract fields from etcd value + err := json.Unmarshal(etcdValue, insSpecFG) + if err != nil { + log.GetLogger().Errorf("funcKey %s,failed to unmarshal etcd value to instance specification %s", + insSpec.Function, err.Error()) + return nil + } + if !buildCreateOptions(insSpecFG, insSpec) { + log.GetLogger().Warnf("funcKey %s,applier %s is invalid,ignore this instance", + insSpec.Function, insSpecFG.Applier) + return nil + } + insSpec.StartTime = strconv.Itoa(insSpecFG.CreationTime) + insSpec.RuntimeAddress = insSpecFG.NodeIP + ":" + insSpecFG.NodePort + insSpec.FunctionProxyID = insSpecFG.InstanceIP + ":" + insSpecFG.InstancePort + insSpec.ParentID = insSpecFG.Applier + insSpec.InstanceStatus = commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + Msg: "", + } + buildResource(insSpec, insSpecFG) + return insSpec +} + +func buildResource(insSpec *commonTypes.InstanceSpecification, insSpecFG *commonTypes.InstanceSpecificationFG) { + var ephemeralStorage int + funcSpec := GlobalRegistry.GetFuncSpec(insSpec.Function) + if funcSpec != nil && funcSpec.ResourceMetaData.EphemeralStorage != 0 { + ephemeralStorage = funcSpec.ResourceMetaData.EphemeralStorage + } + insSpec.Resources = commonTypes.Resources{ + Resources: map[string]commonTypes.Resource{ + constant.ResourceCPUName: { + Name: constant.ResourceCPUName, + Scalar: commonTypes.ValueScalar{ + Value: float64(insSpecFG.Resource.Runtime.CPULimit), + }, + }, + constant.ResourceMemoryName: { + Name: constant.ResourceMemoryName, + Scalar: commonTypes.ValueScalar{ + Value: float64(insSpecFG.Resource.Runtime.MemoryLimit), + }, + }, + constant.ResourceEphemeralStorage: { + Name: constant.ResourceEphemeralStorage, + Scalar: commonTypes.ValueScalar{ + Value: float64(ephemeralStorage), + }, + }, + }} +} + +func buildCreateOptions(insSpecFG *commonTypes.InstanceSpecificationFG, + insSpec *commonTypes.InstanceSpecification) bool { + if strings.HasPrefix(insSpecFG.Applier, constant.FaasSchedulerApplier) { + insSpec.CreateOptions = map[string]string{ + types.InstanceTypeNote: string(types.InstanceTypeScaled), + types.FunctionKeyNote: insSpec.Function, + types.SchedulerIDNote: insSpecFG.Applier, + } + return true + } + if insSpecFG.Applier == constant.WorkerManagerApplier || + insSpecFG.BusinessType == constant.BusinessTypeCAE { + insSpec.CreateOptions = map[string]string{ + types.InstanceTypeNote: string(types.InstanceTypeReserved), + types.FunctionKeyNote: insSpec.Function, + types.SchedulerIDNote: insSpecFG.Applier, + } + return true + } + return false +} + +// GetInstanceIDFromFGEtcdKey gets instance id from etcd key of instance +func GetInstanceIDFromFGEtcdKey(etcdKey string) string { + items := strings.Split(etcdKey, keySeparator) + if len(items) != validEtcdKeyLenForWorkers { + return "" + } + instanceID := items[instanceIDFGValueIndex] + return instanceID +} diff --git a/yuanrong/pkg/functionscaler/registry/instanceregistry_fg_test.go b/yuanrong/pkg/functionscaler/registry/instanceregistry_fg_test.go new file mode 100644 index 0000000..a6c02f9 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/instanceregistry_fg_test.go @@ -0,0 +1,130 @@ +package registry + +import ( + "reflect" + "testing" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/types" +) + +func TestGetInstanceIDFromFGEtcdKey(t *testing.T) { + type args struct { + etcdKey string + } + tests := []struct { + name string + args args + want string + }{ + {"test1", + args{etcdKey: "/sn/workers/business/yrk/tenant/6d5b16f6ef0e4b7d938d5035356aa378/function/0@default@app1/version/latest/defaultaz" + + "/defaultaz-#-custom-600-512-cbf49869-a2e7-46e0-b9bc-c11533f38db5"}, + "defaultaz-#-custom-600-512-cbf49869-a2e7-46e0-b9bc-c11533f38db5"}, + {"test2", args{etcdKey: ""}, ""}, + {"test3", + args{etcdKey: "/sn/workers/business/yrk/tenant/6d5b16f6ef0e4b7d938d5035356aa378/function/0@default@app1/version/latest/defaultaz" + + "/defaultaz"}, "defaultaz"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetInstanceIDFromFGEtcdKey(tt.args.etcdKey) + assert.Equal(t, tt.want, got) + }) + } + +} + +func TestInstanceWatcherFGFilter(t *testing.T) { + stopCh := make(chan struct{}) + instanceRegistry := NewInstanceRegistry(stopCh) + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/workers/business/yrk/tenant/6d5b16f6ef0e4b7d938d5035356aa378/function/0@default@app1/version/latest/defaultaz" + + "/defaultaz-#-custom-600-512-cbf49869-a2e7-46e0-b9bc-c11533f38db5", + Value: []byte("value"), + Rev: 1, + } + assert.Equal(t, false, instanceRegistry.watcherFGFilter(event)) + + event.Key = "/sn/workers/business/yrk/tenant/123/instance/faasscheduler/version/$latest/defaultaz/requestID/abc" + assert.Equal(t, true, instanceRegistry.watcherFGFilter(event)) +} + +func TestGetInsSpecFromEtcdKV(t *testing.T) { + etcdKey := "/sn/workers/business/yrk/tenant/6d5b16f6ef0e4b7d938d5035356aa378/function/0@default@app1/" + + "version/latest/defaultaz/defaultaz-#-custom-600-512-cbf49869-a2e7-46e0-b9bc-c11533f38db5" + etcdValue := []byte("{\"ip\":\"192.168.0.97\",\"port\":\"8080\",\"cluster\":\"cluster001\",\"status\"" + + ":\"ready\",\"p2pPort\":\"22668\",\"nodeIP\":\"10.29.111.186\",\"nodePort\":\"22423\"," + + "\"applier\":\"worker-manager\",\"ownerIP\":\"10.29.111.186\",\"cpu\":600,\"memory\":512,\"businessType\":\"CAE\"," + + "\"hasInitializer\":true,\"creationTime\":1719788553,\"resource\":{\"worker\":{\"cpuLimit\":1000,\"cpuRequest\":100," + + "\"memoryLimit\":200,\"memoryRequest\":100}," + + "\"runtime\":{\"cpuLimit\":400,\"cpuRequest\":60,\"memoryLimit\":256,\"memoryRequest\":256}}}") + defer ApplyMethod(reflect.TypeOf(GlobalRegistry), "GetFuncSpec", + func(_ *Registry, funcKey string) *types.FunctionSpecification { + return &types.FunctionSpecification{ + ResourceMetaData: commonTypes.ResourceMetaData{ + EphemeralStorage: 512, + }, + } + }).Reset() + insSpecTrans := GetInsSpecFromEtcdKVForFG(etcdKey, etcdValue) + insSpecExpected := &commonTypes.InstanceSpecification{ + InstanceID: "", + RequestID: "", + RuntimeID: "", + RuntimeAddress: "10.29.111.186:22423", + FunctionAgentID: "", + FunctionProxyID: "192.168.0.97:8080", + Function: "6d5b16f6ef0e4b7d938d5035356aa378/0@default@app1/latest", + RestartPolicy: "", + Resources: commonTypes.Resources{ + Resources: map[string]commonTypes.Resource{ + constant.ResourceCPUName: { + Name: constant.ResourceCPUName, + Scalar: commonTypes.ValueScalar{ + Value: float64(400), + }, + }, + constant.ResourceMemoryName: { + Name: constant.ResourceMemoryName, + Scalar: commonTypes.ValueScalar{ + Value: float64(256), + }, + }, + constant.ResourceEphemeralStorage: { + Name: constant.ResourceEphemeralStorage, + Scalar: commonTypes.ValueScalar{ + Value: float64(512), + }, + }, + }, + }, + ActualUse: commonTypes.Resources{}, + ScheduleOption: commonTypes.ScheduleOption{ + Affinity: commonTypes.Affinity{ + InstanceAffinity: commonTypes.InstanceAffinity{}, + }, + }, + CreateOptions: map[string]string{ + types.InstanceTypeNote: "reserved", + types.FunctionKeyNote: "6d5b16f6ef0e4b7d938d5035356aa378/0@default@app1/latest", + types.SchedulerIDNote: "worker-manager", + }, + Labels: nil, + StartTime: "1719788553", + InstanceStatus: commonTypes.InstanceStatus{ + Code: 3, + Msg: "", + }, + JobID: "", + SchedulerChain: nil, + ParentID: "worker-manager", + } + assert.Equal(t, insSpecExpected, insSpecTrans) +} diff --git a/yuanrong/pkg/functionscaler/registry/instanceregistry_test.go b/yuanrong/pkg/functionscaler/registry/instanceregistry_test.go new file mode 100644 index 0000000..43c9aae --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/instanceregistry_test.go @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package registry + +import ( + "testing" + "yuanrong/pkg/common/faas_common/types" + + "github.com/smartystreets/goconvey/convey" +) + +func TestInstanceRegistry_GetFunctionInstanceIDMap(t *testing.T) { + convey.Convey("test GetFunctionInstanceIDMap", t, func() { + ir := &InstanceRegistry{ + functionInstanceIDMap: map[string]map[string]*types.InstanceSpecification{ + "aaa": {"aa": &types.InstanceSpecification{}}, + }, + } + idMap := ir.GetFunctionInstanceIDMap() + convey.So(len(idMap), convey.ShouldEqual, 1) + }) +} diff --git a/yuanrong/pkg/functionscaler/registry/registry.go b/yuanrong/pkg/functionscaler/registry/registry.go new file mode 100644 index 0000000..3ef3f08 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/registry.go @@ -0,0 +1,311 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "encoding/json" + "fmt" + "strings" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + commonTypes "yuanrong/pkg/common/faas_common/types" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + functionEtcdPrefix = "/sn/function" + instanceEtcdPrefix = "/sn/instance" + userAgencyEtcdPrefix = "/sn/agency/" + instancesInfoEtcdPrefix = "/instances" + keySeparator = "/" + validEtcdKeyLenForCluster = 6 + validEtcdKeyLenForFrontend = 7 + validEtcdKeyLenForAlias = 10 + validEtcdKeyLenForFunction = 11 + validEtcdKeyLenForInstance = 14 + validEtcdKeyLenForAgency = 11 + validEtcdKeyLenForInsConfig = 12 + validEtcdKeyLenForInsWithLabelConf = 14 + validEtcdKeyLenForQuota1 = 7 + validEtcdKeyLenForQuota2 = 8 + quotaKeyIndex = 2 + instanceMetadataKeyIndex1 = 6 + instanceMetadataKeyIndex2 = 7 + clusterFrontendClusterIndex = 4 + clusterFrontendIPIndex = 5 + tenantValueIndex = 6 + funcNameValueIndex = 8 + versionValueIndex = 10 + instanceIDValueIndex = 13 + instanceKeyIndex = 2 + aliasKeyIndex = 2 + tenantKeyIndex = 5 + functionKeyIndex = 7 + executorKeyIndex = 8 + agencyKeyIndex1 = 2 + agencyKeyIndex2 = 9 + agencyBusinessKeyIndex = 3 + agencyTenantKeyIndex = 5 + agencyDomainKeyIndex = 7 + insInfoKeyIndex = 1 + insInfoClusterKeyIndex = 4 + insInfoClusterValueIndex = 5 + insInfoTenantKeyIndex = 6 + insInfoFunctionKeyIndex = 8 + insInfoTenantValueIndex = 7 + insInfoFuncNameValueIndex = 9 + insInfoVersionValueIndex = 11 + insInfoLabelKeyIndex = 12 + insInfoLabelValueIndex = 13 + functionClusterKeyIdx = 5 +) + +const ( + // SubEventTypeUpdate is update type of subscribe event + SubEventTypeUpdate EventType = "update" + // SubEventTypeDelete is delete type of subscribe event + SubEventTypeDelete EventType = "delete" + // SubEventTypeAdd is add type of subscribe event + SubEventTypeAdd EventType = "add" + // SubEventTypeSynced is synced type of subscribe event + SubEventTypeSynced EventType = "synced" + // SubEventTypeRemove is remove type of instance event + SubEventTypeRemove EventType = "remove" + defaultEphemeralStorage = 512 +) + +var ( + // GlobalRegistry is the global registry + GlobalRegistry *Registry +) + +// EventType defines registry event type +type EventType string + +// SubEvent contains event published to subscribers +type SubEvent struct { + EventType + EventMsg interface{} +} + +// Registry watches etcd and builds registry cache based on etcd watch +type Registry struct { + FaaSSchedulerRegistry *FaasSchedulerRegistry + FunctionRegistry *FunctionRegistry + AgentRegistry *AgentRegistry + InstanceRegistry *InstanceRegistry + FaaSManagerRegistry *FaaSManagerRegistry + InstanceConfigRegistry *InstanceConfigRegistry + AliasRegistry *AliasRegistry + FunctionAvailableRegistry *FunctionAvailableRegistry + FaaSFrontendRegistry *FaaSFrontendRegistry + TenantQuotaRegistry *TenantQuotaRegistry + RolloutRegistry *RolloutRegistry + stopCh <-chan struct{} +} + +// InitRegistry will initialize registry +func InitRegistry(stopCh <-chan struct{}) error { + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + AgentRegistry: NewAgentRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + InstanceConfigRegistry: NewInstanceConfigRegistry(stopCh), + AliasRegistry: NewAliasRegistry(stopCh), + FunctionAvailableRegistry: NewFunctionAvailableRegistry(stopCh), + FaaSFrontendRegistry: NewFaaSFrontendRegistry(stopCh), + TenantQuotaRegistry: NewTenantQuotaRegistry(stopCh), + RolloutRegistry: NewRolloutRegistry(stopCh), + } + GlobalRegistry.FaaSSchedulerRegistry.initWatcher(etcd3.GetRouterEtcdClient()) + GlobalRegistry.RolloutRegistry.initWatcher(etcd3.GetMetaEtcdClient()) + return nil +} + +// ProcessETCDList before watch etcd event, list etcd kv first +func ProcessETCDList() { + routerEtcdClient := etcd3.GetRouterEtcdClient() + metaEtcdClient := etcd3.GetMetaEtcdClient() + // Serial Execution + if GlobalRegistry != nil { + GlobalRegistry.FunctionRegistry.initWatcher(metaEtcdClient) + GlobalRegistry.AliasRegistry.initWatcher(metaEtcdClient) + GlobalRegistry.InstanceRegistry.initWatcher(routerEtcdClient) + GlobalRegistry.FaaSManagerRegistry.initWatcher(routerEtcdClient) + GlobalRegistry.InstanceConfigRegistry.initWatcher(metaEtcdClient) + GlobalRegistry.FunctionAvailableRegistry.initWatcher(metaEtcdClient) + GlobalRegistry.FaaSFrontendRegistry.initWatcher(metaEtcdClient) + GlobalRegistry.TenantQuotaRegistry.initWatcher(metaEtcdClient) + } +} + +// StartRegistry will start registry +func StartRegistry() { + if GlobalRegistry == nil { + log.GetLogger().Errorf("faaSScheduler registry is nil") + return + } + if GlobalRegistry.FaaSSchedulerRegistry != nil { + GlobalRegistry.FaaSSchedulerRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("faaSScheduler registry is nil") + } + if GlobalRegistry.FunctionRegistry != nil { + GlobalRegistry.FunctionRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("function registry is nil") + } + if GlobalRegistry.AgentRegistry != nil { + GlobalRegistry.AgentRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("agent registry is nil") + } + if GlobalRegistry.InstanceRegistry != nil { + GlobalRegistry.InstanceRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("instance registry is nil") + } + if GlobalRegistry.FaaSManagerRegistry != nil { + GlobalRegistry.FaaSManagerRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("faas manager registry is nil") + } + if GlobalRegistry.InstanceConfigRegistry != nil { + GlobalRegistry.InstanceConfigRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("instances info registry is nil") + } + if GlobalRegistry.AliasRegistry != nil { + GlobalRegistry.AliasRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("instances info registry is nil") + } + if GlobalRegistry.FunctionAvailableRegistry != nil { + GlobalRegistry.FunctionAvailableRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("function available clusters registry is nil") + } + if GlobalRegistry.FaaSFrontendRegistry != nil { + GlobalRegistry.FaaSFrontendRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("frontend instance registry is nil") + } + if GlobalRegistry.TenantQuotaRegistry != nil { + GlobalRegistry.TenantQuotaRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("tenant instance registry is nil") + } + if GlobalRegistry.RolloutRegistry != nil { + GlobalRegistry.RolloutRegistry.RunWatcher() + } else { + log.GetLogger().Errorf("rollout registry is nil") + } + commonUtils.ClearStringMemory(config.GlobalConfig.MetaETCDConfig.Password) + commonUtils.ClearStringMemory(config.GlobalConfig.RouterETCDConfig.Password) +} + +// GetFuncSpec will get function specification +func (r *Registry) GetFuncSpec(funcKey string) *types.FunctionSpecification { + return r.FunctionRegistry.getFuncSpec(funcKey) +} + +// GetInstance will get instance +func (r *Registry) GetInstance(instanceID string) *types.Instance { + return r.InstanceRegistry.GetInstance(instanceID) +} + +// FetchSilentFuncSpec will get silent function specification +func (r *Registry) FetchSilentFuncSpec(funcKey string) *types.FunctionSpecification { + return r.FunctionRegistry.fetchSilentFuncSpec(funcKey) +} + +// SubscribeFuncSpec will add subscriber for function registry +func (r *Registry) SubscribeFuncSpec(subChan chan SubEvent) { + r.FunctionRegistry.addSubscriberChan(subChan) +} + +// SubscribeInsSpec will add subscriber for instance registry +func (r *Registry) SubscribeInsSpec(subChan chan SubEvent) { + r.InstanceRegistry.addSubscriberChan(subChan) + r.FaaSManagerRegistry.addSubscriberChan(subChan) +} + +// SubscribeInsConfig will add subscriber for instanceConfig registry +func (r *Registry) SubscribeInsConfig(subChan chan SubEvent) { + r.InstanceConfigRegistry.addSubscriberChan(subChan) +} + +// SubscribeAliasSpec will add subscriber for instanceConfig registry +func (r *Registry) SubscribeAliasSpec(subChan chan SubEvent) { + r.AliasRegistry.addSubscriberChan(subChan) +} + +// SubscribeSchedulerProxy will add subscriber for scheduler registry +func (r *Registry) SubscribeSchedulerProxy(subChan chan SubEvent) { + r.FaaSSchedulerRegistry.addSubscriberChan(subChan) +} + +// SubscribeRolloutConfig will add subscriber for rollout config +func (r *Registry) SubscribeRolloutConfig(subChan chan SubEvent) { + r.RolloutRegistry.addSubscriberChan(subChan) +} + +// GetFuncKeyFromEtcdKey will get funcKey from etcd key +func GetFuncKeyFromEtcdKey(etcdKey string) string { + items := strings.Split(etcdKey, keySeparator) + if len(items) != validEtcdKeyLenForFunction { + return "" + } + return fmt.Sprintf("%s/%s/%s", items[tenantValueIndex], items[funcNameValueIndex], items[versionValueIndex]) +} + +// GetFuncMetaInfoFromEtcdValue will get FunctionMetaInfo from etcd value +func GetFuncMetaInfoFromEtcdValue(etcdValue []byte) *commonTypes.FunctionMetaInfo { + funcMetaInfo := &commonTypes.FunctionMetaInfo{} + err := json.Unmarshal(etcdValue, funcMetaInfo) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal etcd value to function meta info %s", err.Error()) + return nil + } + // set default value + if config.GlobalConfig.Scenario == types.ScenarioFunctionGraph && strings.Contains(funcMetaInfo.FuncMetaData.Runtime, + types.CustomContainerRuntimeType) && funcMetaInfo.ResourceMetaData.EphemeralStorage == 0 { + funcMetaInfo.ResourceMetaData.EphemeralStorage = int(config.GlobalConfig.EphemeralStorage) + if npu, _ := utils.GetNpuTypeAndInstanceTypeFromStr(funcMetaInfo.ResourceMetaData.CustomResources, + funcMetaInfo.ResourceMetaData.CustomResourcesSpec); npu != "" && + config.GlobalConfig.NpuEphemeralStorage != 0 { + funcMetaInfo.ResourceMetaData.EphemeralStorage = int(config.GlobalConfig.NpuEphemeralStorage) + } + if funcMetaInfo.ResourceMetaData.EphemeralStorage == 0 { + funcMetaInfo.ResourceMetaData.EphemeralStorage = defaultEphemeralStorage + } + } + // currently for instances of algorithm function in CaaS which utilizes NPU need to be scheduled with round-robin + // policy + if funcMetaInfo.FuncMetaData.BusinessType == constant.BusinessTypeCAE { + funcMetaInfo.InstanceMetaData.SchedulePolicy = types.InstanceSchedulePolicyMicroService + } + return funcMetaInfo +} diff --git a/yuanrong/pkg/functionscaler/registry/registry_test.go b/yuanrong/pkg/functionscaler/registry/registry_test.go new file mode 100644 index 0000000..e1673d7 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/registry_test.go @@ -0,0 +1,1792 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/aliasroute" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instanceconfig" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +func TestInitRegistry(t *testing.T) { + config.GlobalConfig = types.Configuration{} + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&etcd3.EtcdInitParam{}), "InitClient", + func(_ *etcd3.EtcdInitParam) error { + return nil + }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartList", func(_ *etcd3.EtcdWatcher) { + return + }), + ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }), + ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }), + ApplyFunc(etcd3.GetCAEMetaEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }), + ApplyFunc((*FaasSchedulerRegistry).WaitForETCDList, func() {}), + } + defer func() { + time.Sleep(1000 * time.Millisecond) + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + stopCh := make(chan struct{}) + err := InitRegistry(stopCh) + assert.Equal(t, true, err == nil) + instance := GlobalRegistry.GetInstance("instance1") + assert.Equal(t, true, instance == nil) + function := GlobalRegistry.FetchSilentFuncSpec("function1") + assert.Equal(t, true, function == nil) + subChan := make(chan SubEvent, 1) + GlobalRegistry.SubscribeFuncSpec(subChan) + assert.Equal(t, 1, len(GlobalRegistry.FunctionRegistry.subscriberChans)) + GlobalRegistry.SubscribeInsSpec(subChan) + assert.Equal(t, 1, len(GlobalRegistry.InstanceRegistry.subscriberChans)) + GlobalRegistry.SubscribeInsConfig(subChan) + assert.Equal(t, 1, len(GlobalRegistry.InstanceConfigRegistry.subscriberChans)) + GlobalRegistry.SubscribeAliasSpec(subChan) + assert.Equal(t, 1, len(GlobalRegistry.AliasRegistry.subscriberChans)) + GlobalRegistry.SubscribeSchedulerProxy(subChan) + assert.Equal(t, 1, len(GlobalRegistry.FaaSSchedulerRegistry.subscriberChans)) +} + +func TestNewInstanceRegistry(t *testing.T) { + stopCh := make(chan struct{}) + instanceRegistry := NewInstanceRegistry(stopCh) + assert.Equal(t, true, instanceRegistry != nil) +} + +func TestInstanceWatcherFilter(t *testing.T) { + stopCh := make(chan struct{}) + ir := NewInstanceRegistry(stopCh) + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/instance/business/yrk/tenant", + Value: []byte("value"), + Rev: 1, + } + assert.Equal(t, true, ir.watcherFilter(event)) + + event.Key = "/sn/instance/business/yrk/tenant/123/instance/faasscheduler/version/$latest/defaultaz/requestID/abc" + assert.Equal(t, true, ir.watcherFilter(event)) + + event.Key = "/sn/instance/business/yrk/tenant/123/function/faasscheduler/version/$latest/defaultaz/requestID/abc" + assert.Equal(t, true, ir.watcherFilter(event)) + + convey.Convey("EtcdList", t, func() { + defer ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "Get", func(_ *etcd3.EtcdClient, ctx etcd3.EtcdCtxInfo, + key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + getRsp := &clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{ + { + Key: []byte("/sn/instance/business/yrk/tenant/0/function/0-system-faasscheduler/version/$latest/defaultaz/8c1e66d5f21be4fc00/45a6e8e0-d99a-46ec-afb1-feb6f640f37d"), + Value: []byte("{}"), + }, + }, + } + return getRsp, nil + }).Reset() + insList := ir.EtcdList() + convey.So(len(insList), convey.ShouldEqual, 0) + }) +} + +func TestGetFuncSpecFromEtcdValue(t *testing.T) { + type args struct { + etcdValue []byte + } + tests := []struct { + name string + args args + IsNil bool + }{ + {"test1", + args{etcdValue: []byte("{\"funcMetaData\":{\"layers\":[],\"name\":\"0-system-hello\",\"description\"" + + ":\"\",\"functionUrn\":\"sn:cn:yrk:12345678901234561234567890123456:function:0-system-hello\",\"rever" + + "sedConcurrency\":0,\"tags\":null,\"functionUpdateTime\":\"\",\"functionVersionUrn\":\"sn:cn:yrk:12345" + + "678901234561234567890123456:function:0-system-hello:$latest\",\"codeSize\":1619264,\"codeSha256\":\"c" + + "e9f7446a54331137c8386cedc38eec942f33bab0575c81d5f3b5633caff2596\",\"handler\":\"\",\"runtime\":\"go1" + + ".13\",\"timeout\":900,\"version\":\"$latest\",\"versionDescription\":\"$latest\",\"deadLetterConfi" + + "g\":\"\",\"latestVersionUpdateTime\":\"\",\"publishTime\":\"\",\"businessId\":\"yrk\",\"tenantId\":\"1" + + "2345678901234561234567890123456\",\"domain_id\":\"\",\"project_name\":\"\",\"revisionId\":\"202212150" + + "92604748\",\"created\":\"2022-12-13 13:01:44.376 UTC\",\"statefulFlag\":false,\"hookHandler\":{\"cal" + + "l\":\"main.CallHandler\",\"init\":\"main.InitHandler\"}},\"codeMetaData\":{\"storage_type\":\"s3\",\"a" + + "ppId\":\"61022\",\"bucketId\":\"bucket-test-log1\",\"objectId\":\"hello-1671096364751\",\"bucketUr" + + "l\":\"http://10.244.162.129:19002\",\"sha256\":\"\",\"code_type\":\"\",\"code_url\":\"\",\"code_filen" + + "ame\":\"\",\"func_code\":{\"file\":\"\",\"link\":\"\"},\"code_path\":\"\"},\"envMetaData\":{\"envKe" + + "y\":\"85c545e0e31241d681031542:8231fc7f6dd9f6411d03bb5cf751a398bcf1d3d4fa1098022228c75cdb7420116807" + + "2edc1bb265f53bc8b4fee10e757693935bd8d412e292ac2349207c52311b9cef460a65c91a4103b9aed5dc920b49\",\"env" + + "ironment\":\"9cce4db16c95d4a215999e26:010456a679a83eefda685c2eff8330c69285a1196e53afaca04fd0f0bef5b87" + + "e369f603794c9942a1c38e9b8ef0e49286f8b2a06aebc007f90ebf11f97eeb16f2668eb66e551a23206896df0391f6e16536d" + + "8141d6f4f94ce75ad4125e5c9fba83bd594cda705beb9b215846e580f2594930c7d61f9f2f2ce6c14de68d5a44369e7e51aea" + + "3b8d60d44f7673bd143ac688b1e5530a9714083aac51d0d6a776ed9d72da2960972f37e48\",\"encrypted_user_data" + + "\":\"\"},\"resourceMetaData\":{\"cpu\":500,\"memory\":500,\"customResources\":\"\"},\"extendedMetaDa" + + "ta\":{\"image_name\":\"\",\"role\":{\"xrole\":\"\",\"app_xrole\":\"\"},\"mount_config\":{\"mount_use" + + "r\":{\"user_id\":0,\"user_group_id\":0},\"func_mounts\":null},\"strategy_config\":{\"concurrenc" + + "y\":0},\"extend_config\":\"\",\"initializer\":{\"initializer_handler\":\"\",\"initializer_timeou" + + "t\":0},\"enterprise_project_id\":\"\",\"log_tank_service\":{\"logGroupId\":\"\",\"logStrea" + + "mId\":\"\"},\"tracing_config\":{\"tracing_ak\":\"\",\"tracing_sk\":\"\",\"project_name\":\"\"},\"us" + + "er_type\":\"\",\"instance_meta_data\":{\"maxInstance\":100,\"minInstance\":0,\"concurrentN" + + "um\":100,\"cacheInstance\":0},\"extended_handler\":null,\"extended_timeout\":null}}")}, false}, + {"test2", args{etcdValue: []byte("")}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetFuncMetaInfoFromEtcdValue(tt.args.etcdValue) + assert.Equal(t, tt.IsNil, got == nil) + }) + } +} + +func TestGetUserAgencyByFuncMeta(t *testing.T) { + type args struct { + funcMetaInfo *commonTypes.FunctionMetaInfo + } + tests := []struct { + name string + args args + want commonTypes.UserAgency + }{ + { + name: "test funcMeta nil", + args: args{funcMetaInfo: nil}, + want: commonTypes.UserAgency{}, + }, + { + name: "test AppXRole not empty", + args: args{funcMetaInfo: &commonTypes.FunctionMetaInfo{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Role: commonTypes.Role{ + AppXRole: "", + }, + }, + }}, + want: commonTypes.UserAgency{}, + }, + { + name: "test AppXRole empty but XRole not empty", + args: args{funcMetaInfo: &commonTypes.FunctionMetaInfo{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Role: commonTypes.Role{ + XRole: "test", + AppXRole: "", + }, + }, + }}, + want: commonTypes.UserAgency{}, + }, + } + userAgencyRegistry := NewUserAgencyRegistry(make(chan struct{})) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, userAgencyRegistry.GetUserAgencyByFuncMeta(tt.args.funcMetaInfo), + "GetUserAgencyByFuncMeta(%v)", tt.args.funcMetaInfo) + }) + } +} + +func TestAgencyWatcherFilter(t *testing.T) { + userAgencyRegistry := NewUserAgencyRegistry(make(chan struct{})) + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/agency/business/yrk/tenant/123/domain/123/agency/123", + Value: []byte("value"), + Rev: 1, + } + assert.Equal(t, false, userAgencyRegistry.watcherFilter(event)) + + event.Key = "/sn/functions/business/yrk/tenant/123/domain/123/agency/123" + assert.Equal(t, true, userAgencyRegistry.watcherFilter(event)) + + event.Key = "/sn/agency/busi/yrk/tenant/123/domain/123/agency/123" + assert.Equal(t, true, userAgencyRegistry.watcherFilter(event)) + + event.Key = "/sn/agency/business/yrk/ten/123/domain/123/agency/123" + assert.Equal(t, true, userAgencyRegistry.watcherFilter(event)) + + event.Key = "/sn/agency/business/yrk/tenant/123/dom/123/agency/123" + assert.Equal(t, true, userAgencyRegistry.watcherFilter(event)) + + event.Key = "/sn/agency/business/yrk/tenant/123/domain/123/agen/123" + assert.Equal(t, true, userAgencyRegistry.watcherFilter(event)) + + event.Key = "/sn/agency/business/yrk/tenant/123/instance/faasscheduler/version/$latest/defaultaz/abc" + assert.Equal(t, true, userAgencyRegistry.watcherFilter(event)) + + event.Key = "/sn/agency/business/yrk/tenant/123/function/faasscheduler/version/$latest/defaultaz/abc" + assert.Equal(t, true, userAgencyRegistry.watcherFilter(event)) +} + +func TestAgencyWatcherHandler(t *testing.T) { + agency := `{"accessKey": "ak","secretKey": "sk","token": "token","akSkExpireTime": "2021-06-11T23:00:00.101073Z","tokenExpireTime": "2021-06-12T13:23:00.483000Z"}` + convey.Convey("agencyWatcherHandler", t, func() { + stopCh := make(chan struct{}) + userAgencyRegistry := NewUserAgencyRegistry(stopCh) + funcMetaInfo := &commonTypes.FunctionMetaInfo{ + FuncMetaData: commonTypes.FuncMetaData{ + TenantID: "123", + DomainID: "123", + }, + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Role: commonTypes.Role{ + XRole: "123", + }, + }, + } + convey.Convey("agencyWatcherHandler put unmarshal failed", func() { + defer ApplyFunc(json.Unmarshal, func(data []byte, v interface{}) error { + return fmt.Errorf("unmarshal failed") + }).Reset() + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/agency/business/yrk/tenant/123/domain/123/agency/123", + Value: []byte(agency), + Rev: 1, + } + userAgencyRegistry.watcherHandler(event) + getAgency := userAgencyRegistry.GetUserAgencyByFuncMeta(funcMetaInfo) + convey.So(getAgency.AccessKey, convey.ShouldBeZeroValue) + convey.So(getAgency.SecretKey, convey.ShouldBeZeroValue) + convey.So(getAgency.Token, convey.ShouldBeZeroValue) + convey.So(getAgency.SecurityAk, convey.ShouldBeZeroValue) + convey.So(getAgency.SecuritySk, convey.ShouldBeZeroValue) + convey.So(getAgency.SecurityToken, convey.ShouldBeZeroValue) + }) + convey.Convey("agencyWatcherHandler put", func() { + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/agency/business/yrk/tenant/123/domain/123/agency/123", + Value: []byte(agency), + Rev: 1, + } + userAgencyRegistry.watcherHandler(event) + getAgency := userAgencyRegistry.GetUserAgencyByFuncMeta(funcMetaInfo) + convey.So(getAgency.AccessKey, convey.ShouldNotBeZeroValue) + convey.So(getAgency.SecretKey, convey.ShouldNotBeZeroValue) + convey.So(getAgency.Token, convey.ShouldNotBeZeroValue) + convey.So(getAgency.SecurityAk, convey.ShouldBeZeroValue) + convey.So(getAgency.SecuritySk, convey.ShouldBeZeroValue) + convey.So(getAgency.SecurityToken, convey.ShouldBeZeroValue) + }) + convey.Convey("agencyWatcherHandler delete", func() { + event := &etcd3.Event{ + Type: etcd3.DELETE, + Key: "/sn/agency/business/yrk/tenant/123/domain/123/agency/123", + Value: []byte(agency), + Rev: 1, + } + userAgencyRegistry.watcherHandler(event) + getAgency := userAgencyRegistry.GetUserAgencyByFuncMeta(funcMetaInfo) + convey.So(getAgency.AccessKey, convey.ShouldBeZeroValue) + convey.So(getAgency.SecretKey, convey.ShouldBeZeroValue) + convey.So(getAgency.Token, convey.ShouldBeZeroValue) + convey.So(getAgency.SecurityAk, convey.ShouldBeZeroValue) + convey.So(getAgency.SecuritySk, convey.ShouldBeZeroValue) + convey.So(getAgency.SecurityToken, convey.ShouldBeZeroValue) + }) + + convey.Convey("etcd3.SYNCED & unknow etcd opt", func() { + event := &etcd3.Event{ + Type: etcd3.SYNCED, + } + userAgencyRegistry.watcherHandler(event) + event.Type = 4 + userAgencyRegistry.watcherHandler(event) + }) + close(stopCh) + }) +} + +func TestInstanceRegistryWatcherHandler(t *testing.T) { + config.GlobalConfig.Scenario = types.ScenarioWiseCloud + defer func() { + config.GlobalConfig.Scenario = "faas" + }() + stopCh := make(chan struct{}) + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + } + + recvMsg := make(chan SubEvent, 1) + ir := &InstanceRegistry{ + InstanceIDMap: make(map[string]*types.Instance, utils.DefaultMapSize), + functionInstanceIDMap: make(map[string]map[string]*commonTypes.InstanceSpecification, utils.DefaultMapSize), + listDoneCh: make(chan struct{}, 1), + } + ir.addSubscriberChan(recvMsg) + + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasExecutorJava8/version/$latest/defaultaz/task-b23aa1c4-2084-42b8-99b2-8907fa5ae6f4/f71875b1-3c20-4827-8600-0000000005d5", + Value: []byte("123"), + PrevValue: []byte("123"), + Rev: 1, + } + + convey.Convey("etcd put value error", t, func() { + event.Type = etcd3.PUT + ir.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd put value success", t, func() { + selfregister.SelfInstanceID = "" + instanceSpecByte, _ := json.Marshal(&commonTypes.InstanceSpecification{Labels: []string{""}, + CreateOptions: map[string]string{types.FunctionKeyNote: "test-function"}}) + event.Value = instanceSpecByte + event.Type = etcd3.PUT + ir.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeUpdate) + convey.So(msg.EventMsg, convey.ShouldResemble, &commonTypes.InstanceSpecification{ + InstanceID: "f71875b1-3c20-4827-8600-0000000005d5", + Labels: []string{""}, + CreateOptions: map[string]string{types.FunctionKeyNote: "test-function"}, + }) + }) + convey.Convey("etcd delete value error", t, func() { + event.Type = etcd3.DELETE + ir.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd delete value success", t, func() { + instanceSpecByte, _ := json.Marshal(&commonTypes.InstanceSpecification{Labels: []string{""}, + CreateOptions: map[string]string{types.FunctionKeyNote: "test-function"}}) + event.PrevValue = instanceSpecByte + event.Type = etcd3.DELETE + ir.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeDelete) + convey.So(msg.EventMsg, convey.ShouldResemble, &commonTypes.InstanceSpecification{ + InstanceID: "f71875b1-3c20-4827-8600-0000000005d5", + Labels: []string{""}, + CreateOptions: map[string]string{types.FunctionKeyNote: "test-function"}, + }) + }) + convey.Convey("etcd put invalid funcKey", t, func() { + event.Type = etcd3.PUT + event.Key = "/sn/instance/business/yrk/tenant//function" + ir.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd put invalid instanceID", t, func() { + event.Type = etcd3.PUT + event.Key = "/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasExecutorJava8/version/$latest/defaultaz//" + ir.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd SYNCED", t, func() { + event.Type = etcd3.SYNCED + ir.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd put invalid instanceID", t, func() { + event.Type = etcd3.PUT + event.Key = "/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasExecutorJava8/version/$latest/defaultaz//" + ir.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + close(stopCh) +} + +func TestInstanceRegistryWatcherFGHandler(t *testing.T) { + config.GlobalConfig.Scenario = types.ScenarioWiseCloud + defer func() { + config.GlobalConfig.Scenario = "faas" + }() + stopCh := make(chan struct{}) + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + } + + recvMsg := make(chan SubEvent, 1) + ir := &InstanceRegistry{ + InstanceIDMap: make(map[string]*types.Instance, utils.DefaultMapSize), + functionInstanceIDMap: make(map[string]map[string]*commonTypes.InstanceSpecification, utils.DefaultMapSize), + fgListDoneCh: make(chan struct{}, 1), + } + ir.addSubscriberChan(recvMsg) + + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/workers/business/yrk/tenant/c53626012ba84727b938ca8bf03108ef/function/0@default@zscaetest/" + + "version/latest/defaultaz/defaultaz-#-custom-1000-1024-935e9454-93fa-43f1-b5e4-7cd82737dd62", + Value: []byte("{\"ip\":\"192.168.0.154\",\"port\":\"8080\",\"cluster\":\"cluster001\",\"status\":\"ready\",\"p2pPort\":\"22668\",\"nodeIP\":\"10.29.111.186\",\"nodePort\":\"22423\",\"applier\":\"worker-manager\",\"ownerIP\":\"10.29.111.186\",\"businessType\":\"CAE\",\"hasInitializer\":true,\"creationTime\":1724393756,\"podUID\":\"c00e66f0-a4b1-46db-8d9c-61d7ab8c2405\",\"containerIDs\":{\"worker\":\"58dd3e59f60f7533cfe5604a076470d51b1e9b5f3e87a70c0954937fddfa7280\",\"runtime\":\"b616e60c55adab60271507c8df3aefec5923ee02fc851689294d04325ef522d1\"},\"resource\":{\"worker\":{\"cpuLimit\":1000,\"cpuRequest\":100,\"memoryLimit\":3686,\"memoryRequest\":100},\"runtime\":{\"cpuLimit\":1000,\"cpuRequest\":250,\"memoryLimit\":1024,\"memoryRequest\":1024}}}"), + PrevValue: []byte("{\"ip\":\"192.168.0.154\",\"port\":\"8080\",\"cluster\":\"cluster001\",\"status\":\"ready\",\"p2pPort\":\"22668\",\"nodeIP\":\"10.29.111.186\",\"nodePort\":\"22423\",\"applier\":\"worker-manager\",\"ownerIP\":\"10.29.111.186\",\"businessType\":\"CAE\",\"hasInitializer\":true,\"creationTime\":1724393756,\"podUID\":\"c00e66f0-a4b1-46db-8d9c-61d7ab8c2405\",\"containerIDs\":{\"worker\":\"58dd3e59f60f7533cfe5604a076470d51b1e9b5f3e87a70c0954937fddfa7280\",\"runtime\":\"b616e60c55adab60271507c8df3aefec5923ee02fc851689294d04325ef522d1\"},\"resource\":{\"worker\":{\"cpuLimit\":1000,\"cpuRequest\":100,\"memoryLimit\":3686,\"memoryRequest\":100},\"runtime\":{\"cpuLimit\":1000,\"cpuRequest\":250,\"memoryLimit\":1024,\"memoryRequest\":1024}}}"), + Rev: 1, + } + + convey.Convey("etcd put value success", t, func() { + event.Type = etcd3.PUT + ir.watcherFGHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeUpdate) + convey.So(msg.EventMsg, convey.ShouldResemble, &commonTypes.InstanceSpecification{ + InstanceID: "defaultaz-#-custom-1000-1024-935e9454-93fa-43f1-b5e4-7cd82737dd62", + DataSystemHost: "", + RequestID: "", + RuntimeID: "", + RuntimeAddress: "10.29.111.186:22423", + FunctionAgentID: "", + FunctionProxyID: "192.168.0.154:8080", + Function: "c53626012ba84727b938ca8bf03108ef/0@default@zscaetest/latest", + RestartPolicy: "", + Resources: commonTypes.Resources{ + Resources: map[string]commonTypes.Resource{ + constant.ResourceCPUName: { + Name: constant.ResourceCPUName, + Scalar: commonTypes.ValueScalar{ + Value: float64(1000), + }, + }, + constant.ResourceMemoryName: { + Name: constant.ResourceMemoryName, + Scalar: commonTypes.ValueScalar{ + Value: float64(1024), + }, + }, + constant.ResourceEphemeralStorage: { + Name: constant.ResourceEphemeralStorage, + Scalar: commonTypes.ValueScalar{ + Value: float64(0), + }, + }, + }, + }, + ActualUse: commonTypes.Resources{ + Resources: map[string]commonTypes.Resource(nil), + }, + ScheduleOption: commonTypes.ScheduleOption{ + SchedPolicyName: "", + Priority: 0, + Affinity: commonTypes.Affinity{ + NodeAffinity: commonTypes.NodeAffinity{ + Affinity: map[string]string(nil), + }, + InstanceAffinity: commonTypes.InstanceAffinity{ + Affinity: map[string]string(nil), + }, + InstanceAntiAffinity: commonTypes.InstanceAffinity{ + Affinity: map[string]string(nil), + }, + }, + }, + CreateOptions: map[string]string{ + "FUNCTION_KEY_NOTE": "c53626012ba84727b938ca8bf03108ef/0@default@zscaetest/latest", + "INSTANCE_TYPE_NOTE": "reserved", + "SCHEDULER_ID_NOTE": "worker-manager", + }, + Labels: []string(nil), + StartTime: "1724393756", + InstanceStatus: commonTypes.InstanceStatus{ + Code: 3, + Msg: "", + }, + JobID: "", + SchedulerChain: []string(nil), + ParentID: "worker-manager", + }) + }) + convey.Convey("etcd delete value failed", t, func() { + event.Type = etcd3.DELETE + event.Key = "/sn/workers/business/yrk/tenant/c53626012ba84727b938ca8bf03108ef/function//" + + "version/latest/defaultaz/defaultaz-#-TEST" + ir.watcherFGHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(string(msg.EventType), convey.ShouldEqual, "") + convey.So(msg.EventMsg, convey.ShouldResemble, nil) + }) + convey.Convey("etcd delete value success", t, func() { + event.Key = "/sn/workers/business/yrk/tenant/c53626012ba84727b938ca8bf03108ef/function/0@default@zscaetest/" + + "version/latest/defaultaz/defaultaz-#-custom-1000-1024-935e9454-93fa-43f1-b5e4-7cd82737dd62" + event.Type = etcd3.DELETE + ir.watcherFGHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeDelete) + convey.So(msg.EventMsg, convey.ShouldResemble, &commonTypes.InstanceSpecification{ + InstanceID: "defaultaz-#-custom-1000-1024-935e9454-93fa-43f1-b5e4-7cd82737dd62", + DataSystemHost: "", + RequestID: "", + RuntimeID: "", + RuntimeAddress: "10.29.111.186:22423", + FunctionAgentID: "", + FunctionProxyID: "192.168.0.154:8080", + Function: "c53626012ba84727b938ca8bf03108ef/0@default@zscaetest/latest", + RestartPolicy: "", + Resources: commonTypes.Resources{ + Resources: map[string]commonTypes.Resource{ + constant.ResourceCPUName: { + Name: constant.ResourceCPUName, + Scalar: commonTypes.ValueScalar{ + Value: float64(1000), + }, + }, + constant.ResourceMemoryName: { + Name: constant.ResourceMemoryName, + Scalar: commonTypes.ValueScalar{ + Value: float64(1024), + }, + }, + constant.ResourceEphemeralStorage: { + Name: constant.ResourceEphemeralStorage, + Scalar: commonTypes.ValueScalar{ + Value: float64(0), + }, + }, + }, + }, + ActualUse: commonTypes.Resources{ + Resources: map[string]commonTypes.Resource(nil), + }, + ScheduleOption: commonTypes.ScheduleOption{ + SchedPolicyName: "", + Priority: 0, + Affinity: commonTypes.Affinity{ + NodeAffinity: commonTypes.NodeAffinity{ + Affinity: map[string]string(nil), + }, + InstanceAffinity: commonTypes.InstanceAffinity{ + Affinity: map[string]string(nil), + }, + InstanceAntiAffinity: commonTypes.InstanceAffinity{ + Affinity: map[string]string(nil), + }, + }, + }, + CreateOptions: map[string]string{ + "FUNCTION_KEY_NOTE": "c53626012ba84727b938ca8bf03108ef/0@default@zscaetest/latest", + "INSTANCE_TYPE_NOTE": "reserved", + "SCHEDULER_ID_NOTE": "worker-manager", + }, + Labels: []string(nil), + StartTime: "1724393756", + InstanceStatus: commonTypes.InstanceStatus{ + Code: 3, + Msg: "", + }, + JobID: "", + SchedulerChain: []string(nil), + ParentID: "worker-manager", + }) + }) + convey.Convey("etcd put invalid instanceID", t, func() { + event.Type = etcd3.PUT + event.Key = "/sn/workers/business/yrk/tenant/c53626012ba84727b938ca8bf03108ef/function/0@default@zscaetest/" + + "version/latest/defaultaz//" + ir.watcherFGHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd synced test", t, func() { + event.Type = etcd3.SYNCED + ir.watcherFGHandler(event) + msg := SubEvent{} + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + close(stopCh) +} + +func TestFunctionRegistryWatcherHandler(t *testing.T) { + fr := &FunctionRegistry{ + userAgencyRegistry: &UserAgencyRegistry{}, + funcSpecs: make(map[string]*types.FunctionSpecification), + listDoneCh: make(chan struct{}, 1), + } + recvMsg := make(chan SubEvent, 1) + fr.addSubscriberChan(recvMsg) + + stopCh := make(chan struct{}) + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + } + + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/functions/business/yrk/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/function/0@base@testresourcejava11128/version/latest", + Value: []byte("123"), + PrevValue: []byte("123"), + Rev: 1, + } + + convey.Convey("etcd put value error", t, func() { + event.Type = etcd3.PUT + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd put value success", t, func() { + instanceSpecByte, _ := json.Marshal(&commonTypes.FunctionMetaInfo{}) + event.Value = instanceSpecByte + event.Type = etcd3.PUT + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeUpdate) + convey.So(msg.EventMsg, convey.ShouldNotResemble, &commonTypes.FunctionMetaInfo{}) + }) + convey.Convey("etcd put exist value success", t, func() { + instanceSpecByte, _ := json.Marshal(&commonTypes.FunctionMetaInfo{}) + event.Value = instanceSpecByte + event.Type = etcd3.PUT + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeUpdate) + convey.So(msg.EventMsg, convey.ShouldNotResemble, &commonTypes.FunctionMetaInfo{}) + }) + + convey.Convey("etcd delete value success", t, func() { + event.Type = etcd3.DELETE + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeDelete) + convey.So(msg.EventMsg, convey.ShouldNotResemble, &commonTypes.FunctionMetaInfo{}) + }) + convey.Convey("etcd delete doesn't exist value error", t, func() { + event.Type = etcd3.DELETE + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + convey.So(fr.funcSpecs, convey.ShouldHaveLength, 0) + }) + + convey.Convey("etcd put invalid funcKey", t, func() { + event.Type = etcd3.PUT + event.Key = "/sn/functions/busi" + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd SYNCED", t, func() { + event.Type = etcd3.SYNCED + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeSynced) + }) + convey.Convey("getFuncSpec", t, func() { + res := fr.getFuncSpec("") + convey.So(res, convey.ShouldBeNil) + }) + convey.Convey("fetchSilentFuncSpec", t, func() { + defer ApplyFunc(etcd3.GetValueFromEtcdWithRetry, + func(key string, etcdClient *etcd3.EtcdClient) ([]byte, error) { + funcSpec := types.FunctionSpecification{} + spec, _ := json.Marshal(funcSpec) + return spec, nil + }).Reset() + defer ApplyPrivateMethod(reflect.TypeOf(&FunctionRegistry{}), "buildFuncSpec", + func(_ *FunctionRegistry, etcdKey string, etcdValue []byte, + funcKey string) *types.FunctionSpecification { + return &types.FunctionSpecification{ + FuncKey: "1234/test-func/latest", + } + }).Reset() + res := fr.fetchSilentFuncSpec("1234/test-func/latest") + convey.So(res.FuncKey, convey.ShouldEqual, "1234/test-func/latest") + }) + convey.Convey("EtcdList", t, func() { + defer ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "Get", func(_ *etcd3.EtcdClient, ctx etcd3.EtcdCtxInfo, + key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + getRsp := &clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{ + { + Key: []byte("/sn/functions/business/yrk/tenant/0/function/0-system-faasscheduler/version/$latest"), + Value: []byte("{}"), + }, + }, + } + return getRsp, nil + }).Reset() + funcList := fr.EtcdList() + convey.So(len(funcList), convey.ShouldEqual, 0) + }) +} + +func TestFaaSManagerRegistryWatcherHandler(t *testing.T) { + fr := &FaaSManagerRegistry{} + recvMsg := make(chan SubEvent, 1) + fr.addSubscriberChan(recvMsg) + + stopCh := make(chan struct{}) + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + } + + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasExecutorJava8/version/$latest/defaultaz/task-b23aa1c4-2084-42b8-99b2-8907fa5ae6f4/f71875b1-3c20-4827-8600-0000000005d5", + Value: []byte("123"), + PrevValue: []byte("123"), + Rev: 1, + } + convey.Convey("etcd put value error", t, func() { + event.Type = etcd3.PUT + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd put value success", t, func() { + instanceSpecByte, _ := json.Marshal(&commonTypes.InstanceSpecification{Labels: []string{""}}) + event.Value = instanceSpecByte + event.Type = etcd3.PUT + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeUpdate) + convey.So(msg.EventMsg, convey.ShouldResemble, &commonTypes.InstanceSpecification{ + InstanceID: "f71875b1-3c20-4827-8600-0000000005d5", + Labels: []string{""}, + }) + }) + convey.Convey("etcd delete value error", t, func() { + event.Type = etcd3.DELETE + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd delete value success", t, func() { + instanceSpecByte, _ := json.Marshal(&commonTypes.InstanceSpecification{Labels: []string{""}}) + event.PrevValue = instanceSpecByte + event.Type = etcd3.DELETE + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeDelete) + convey.So(msg.EventMsg, convey.ShouldResemble, &commonTypes.InstanceSpecification{ + InstanceID: "f71875b1-3c20-4827-8600-0000000005d5", + Labels: []string{""}, + }) + }) + convey.Convey("etcd put invalid funcKey", t, func() { + event.Type = etcd3.PUT + event.Key = "/sn/instance/business/yrk/tenant//function" + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd SYNCED", t, func() { + event.Type = etcd3.SYNCED + fr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) +} + +func TestStartRegistry(t *testing.T) { + stopCh := make(chan struct{}) + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + } + ew := &etcd3.EtcdWatcher{} + GlobalRegistry.FaaSSchedulerRegistry.functionSchedulerWatcher = ew + GlobalRegistry.FaaSSchedulerRegistry.moduleSchedulerWatcher = ew + GlobalRegistry.FunctionRegistry.watcher = ew + GlobalRegistry.FunctionRegistry.userAgencyRegistry = &UserAgencyRegistry{watcher: ew} + GlobalRegistry.InstanceRegistry.watcher = ew + GlobalRegistry.InstanceRegistry.fgWatcher = ew + GlobalRegistry.FaaSManagerRegistry.watcher = ew + config.GlobalConfig = types.Configuration{} + + watch := false + defer ApplyMethod(reflect.TypeOf(ew), "StartWatch", func(_ *etcd3.EtcdWatcher) { + watch = true + }).Reset() + + convey.Convey("start success", t, func() { + StartRegistry() + time.Sleep(100 * time.Millisecond) + convey.So(watch, convey.ShouldBeTrue) + }) + convey.Convey("start success", t, func() { + GlobalRegistry.FunctionRegistry = nil + GlobalRegistry.InstanceRegistry = nil + GlobalRegistry.FaaSManagerRegistry = nil + watch = false + StartRegistry() + time.Sleep(100 * time.Millisecond) + convey.So(watch, convey.ShouldBeTrue) + }) + close(stopCh) +} + +func TestMiscellaneous(t *testing.T) { + /*convey.Convey("ProcessETCDList", t, func() { + count := 0 + defer ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartList", func(_ *etcd3.EtcdWatcher) { + count++ + }).Reset() + stopCh := make(chan struct{}) + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + InstanceConfigRegistry: NewInstanceConfigRegistry(stopCh), + AliasRegistry: NewAliasRegistry(stopCh), + FunctionAvailableRegistry: NewFunctionAvailableRegistry(stopCh), + FaaSFrontendRegistry: NewFaaSFrontendRegistry(stopCh), + TenantQuotaRegistry: NewTenantQuotaRegistry(stopCh), + } + ProcessETCDList() + convey.So(count, convey.ShouldEqual, 8) + })*/ + convey.Convey("FunctionRegistry watcherFilterForConfig", t, func() { + stopCh := make(chan struct{}) + fr := NewFunctionRegistry(stopCh) + convey.Convey("success", func() { + res := fr.watcherFilter(&etcd3.Event{ + Key: "/sn/functions/business/yrk/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/function/0@base@testresourcejava11128/version/latest", + }) + convey.So(res, convey.ShouldBeFalse) + }) + convey.Convey("failed1", func() { + res := fr.watcherFilter(&etcd3.Event{ + Key: "/sn/functions/business/yrk/tenant/7e1ad6a6-cc5c-44fa-bd5", + }) + convey.So(res, convey.ShouldBeTrue) + }) + convey.Convey("failed2", func() { + res := fr.watcherFilter(&etcd3.Event{ + Key: "/sn/instance/business/yrk/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/function/0@base@testresourcejava11128/version/latest", + }) + convey.So(res, convey.ShouldBeTrue) + }) + }) + convey.Convey("FaaSManagerRegistry watcherFilterForConfig", t, func() { + stopCh := make(chan struct{}) + fm := NewFaaSManagerRegistry(stopCh) + convey.Convey("success", func() { + res := fm.watcherFilter(&etcd3.Event{ + Key: "/sn/instance/business/yrk/tenant/1234567890123456/function/0-system-faasmanager/version/$latest/defaultaz/task-29eea890-fd17/a16e7302-0000-4000-80de-84e02e5d6717", + }) + convey.So(res, convey.ShouldBeFalse) + }) + convey.Convey("failed1", func() { + res := fm.watcherFilter(&etcd3.Event{ + Key: "/sn/instance/business/yrk/302-0000-4000-80de-84e02e5d6717", + }) + convey.So(res, convey.ShouldBeTrue) + }) + convey.Convey("failed2", func() { + res := fm.watcherFilter(&etcd3.Event{ + Key: "/sn/function/business/yrk/tenant/1234567890123456/function/0-system-faasmanager/version/$latest/defaultaz/task-29eea890-fd17/a16e7302-0000-4000-80de-84e02e5d6717", + }) + convey.So(res, convey.ShouldBeTrue) + }) + }) + + convey.Convey("*FaaSSchedulerRegistry watcherFilterForConfig", t, func() { + stopCh := make(chan struct{}) + fsr := NewFaasSchedulerRegistry(stopCh) + convey.Convey("instance success", func() { + res := fsr.functionSchedulerFilter(&etcd3.Event{ + Key: "/sn/instance/business/yrk/tenant/1234567890123456/function/0-system-faasscheduler/version/$latest/defaultaz/task-29eea890-fd17/a16e7302-0000-4000-80de-84e02e5d6717", + }) + convey.So(res, convey.ShouldBeFalse) + }) + convey.Convey("instance failed1", func() { + res := fsr.functionSchedulerFilter(&etcd3.Event{ + Key: "/sn/instance/business/yrk/302-0000-4000-80de-84e02e5d6717", + }) + convey.So(res, convey.ShouldBeTrue) + }) + convey.Convey("module success", func() { + res := fsr.moduleSchedulerFilter(&etcd3.Event{ + Key: "/sn/faas-scheduler/instances/cluster001/7.218.100.25/faas-scheduler-59ddbc4b75-8xdjf", + }) + convey.So(res, convey.ShouldBeFalse) + }) + convey.Convey("module failed1", func() { + res := fsr.moduleSchedulerFilter(&etcd3.Event{ + Key: "/sn/instance/business/yrk/302-0000-4000-80de-84e02e5d6717", + }) + convey.So(res, convey.ShouldBeTrue) + }) + }) + + convey.Convey("InstancesInfoRegistry watcherFilterForConfig", t, func() { + convey.Convey("success", func() { + config.GlobalConfig = types.Configuration{ClusterID: "cluster001"} + res := instanceconfig.GetWatcherFilter("cluster001")(&etcd3.Event{ + Key: "/instances/business/yrk/cluster/cluster001/tenant/12345678901234561234567890123456/function/0@yrservice@test-faas-scheduler-reserved-exist/version/$latest", + }) + convey.So(res, convey.ShouldBeFalse) + }) + convey.Convey("failed1", func() { + config.GlobalConfig = types.Configuration{ClusterID: "cluster001"} + res := instanceconfig.GetWatcherFilter("cluster001")(&etcd3.Event{ + Key: "/instances/business/yrk/cluster/cluster001/tenant/12345678901234561234567890123456", + }) + convey.So(res, convey.ShouldBeTrue) + }) + }) +} + +func TestInstancesInfoRegistryWatcherHandler(t *testing.T) { + stopCh := make(chan struct{}) + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + InstanceConfigRegistry: NewInstanceConfigRegistry(stopCh), + } + + recvMsg := make(chan SubEvent, 1) + ifr := &InstanceConfigRegistry{ + listDoneCh: make(chan struct{}, 1), + } + ifr.addSubscriberChan(recvMsg) + + event := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/instances/business/yrk/cluster/cluster001/tenant/12345678901234561234567890123456/function/0@yrservice@test-faas-scheduler-reserved-exist/version/$latest", + Value: []byte("123"), + PrevValue: []byte("123"), + Rev: 1, + } + + Patches := ApplyMethod(reflect.TypeOf(&selfregister.SchedulerProxy{}), "CheckFuncOwner", + func(_ *selfregister.SchedulerProxy, funcKey string) bool { + return true + }) + + convey.Convey("etcd put value error", t, func() { + event.Type = etcd3.PUT + ifr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd put value success", t, func() { + instanceSpecByte, _ := json.Marshal(&instanceconfig.Configuration{}) + event.Value = instanceSpecByte + event.Type = etcd3.PUT + ifr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeUpdate) + convey.So(msg.EventMsg, convey.ShouldResemble, &instanceconfig.Configuration{ + FuncKey: "12345678901234561234567890123456/0@yrservice@test-faas-scheduler-reserved-exist/$latest", + }) + }) + convey.Convey("etcd delete value success", t, func() { + instanceSpecByte, _ := json.Marshal(&instanceconfig.Configuration{}) + event.PrevValue = instanceSpecByte + event.Type = etcd3.DELETE + ifr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + + convey.So(msg.EventType, convey.ShouldEqual, SubEventTypeDelete) + convey.So(msg.EventMsg, convey.ShouldResemble, &instanceconfig.Configuration{ + FuncKey: "12345678901234561234567890123456/0@yrservice@test-faas-scheduler-reserved-exist/$latest", + }) + }) + convey.Convey("CheckFuncOwner does not allow", t, func() { + Patches.Reset() + event.Type = etcd3.PUT + ifr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg.EventMsg.(*instanceconfig.Configuration).InstanceMetaData.MinInstance, convey.ShouldEqual, 0) + }) + convey.Convey("etcd put invalid funcKey", t, func() { + event.Type = etcd3.PUT + event.Key = "/instances/business/yrk/cluster/cluster001/tenant/12345678901234561234567890123456" + ifr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{}) + }) + convey.Convey("etcd SYNCED", t, func() { + event.Type = etcd3.SYNCED + ifr.watcherHandler(event) + msg := SubEvent{} + select { + case msg = <-recvMsg: + default: + } + convey.So(msg, convey.ShouldResemble, SubEvent{SubEventTypeSynced, &instanceconfig.Configuration{}}) + }) +} + +func TestWaitForHash(t *testing.T) { + sp := &selfregister.SchedulerProxy{} + sp.FaaSSchedulers.Store("instance121", "scheduler1") + sp.FaaSSchedulers.Store("instance122", "scheduler1") + sp.FaaSSchedulers.Store("instance123", "scheduler1") + cnt := 0 * selfregister.GetHashLenInternal + defer ApplyFunc(time.Sleep, func(d time.Duration) { + cnt += d + }).Reset() + convey.Convey("WaitForHash", t, func() { + sp.WaitForHash(0) + convey.So(cnt, convey.ShouldEqual, 0) + }) + convey.Convey("WaitForHash", t, func() { + go sp.WaitForHash(4) + time.Sleep(selfregister.GetHashLenInternal * 2) + sp.FaaSSchedulers.Store("instance124", "scheduler1") + convey.So(cnt, convey.ShouldBeGreaterThan, 0*selfregister.GetHashLenInternal) + }) +} + +func TestAliasRegistry_RunWatcher(t *testing.T) { + convey.Convey("AliasRegistry", t, func() { + GlobalRegistry = nil + stopCh := make(chan struct{}) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartList", func(ew *etcd3.EtcdWatcher) { + ew.ResultChan <- &etcd3.Event{ + Type: etcd3.SYNCED, + Key: "", + Value: nil, + PrevValue: nil, + Rev: 0, + ETCDType: "", + } + }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "EtcdHistory", func(ew *etcd3.EtcdWatcher, + revision int64) { + }), + ApplyFunc((*FunctionRegistry).WaitForETCDList, func() {}), + ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", func(ew *etcd3.EtcdWatcher) {}), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + InstanceConfigRegistry: NewInstanceConfigRegistry(stopCh), + AliasRegistry: NewAliasRegistry(stopCh), + FunctionAvailableRegistry: NewFunctionAvailableRegistry(stopCh), + FaaSFrontendRegistry: NewFaaSFrontendRegistry(stopCh), + TenantQuotaRegistry: NewTenantQuotaRegistry(stopCh), + RolloutRegistry: NewRolloutRegistry(stopCh), + } + ProcessETCDList() + convey.Convey("update", func() { + defer ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", + func(ew *etcd3.EtcdWatcher) { + alias := &aliasroute.AliasElement{AliasURN: "123456"} + bytes, _ := json.Marshal(alias) + ew.ResultChan <- &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/aliases/xxx/xxx/tenant/1234567890/function/functionName/requestID", + Value: bytes, + PrevValue: nil, + Rev: 0, + } + }).Reset() + GlobalRegistry.AliasRegistry.RunWatcher() + ch := make(chan SubEvent, 1) + GlobalRegistry.AliasRegistry.addSubscriberChan(ch) + envet := <-ch + convey.So(envet.EventType, convey.ShouldEqual, SubEventTypeUpdate) + convey.So(envet.EventMsg.(string), convey.ShouldEqual, "123456") + }) + + convey.Convey("delete", func() { + defer ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", + func(ew *etcd3.EtcdWatcher) { + alias := &aliasroute.AliasElement{AliasURN: "123456"} + bytes, _ := json.Marshal(alias) + ew.ResultChan <- &etcd3.Event{ + Type: etcd3.DELETE, + Key: "/sn/aliases/xxx/xxx/tenant/1234567890/function/functionName/requestID", + Value: bytes, + PrevValue: nil, + Rev: 0, + } + }).Reset() + GlobalRegistry.AliasRegistry.RunWatcher() + ch := make(chan SubEvent, 1) + GlobalRegistry.AliasRegistry.addSubscriberChan(ch) + envet := <-ch + convey.So(envet.EventType, convey.ShouldEqual, SubEventTypeDelete) + convey.So(envet.EventMsg.(string), convey.ShouldEqual, + "sn:cn:xxx:1234567890:function:functionName:requestID") + }) + + convey.Convey("error", func() { + defer ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", + func(ew *etcd3.EtcdWatcher) { + alias := &aliasroute.AliasElement{AliasURN: "123456"} + bytes, _ := json.Marshal(alias) + ew.ResultChan <- &etcd3.Event{ + Type: etcd3.ERROR, + Key: "/sn/aliases/xxx/xxx/tenant/1234567890/function/functionName/requestID", + Value: bytes, + PrevValue: nil, + Rev: 0, + } + ew.ResultChan <- &etcd3.Event{ + Type: 4, + Key: "/sn/aliases/xxx/xxx/tenant/1234567890/function/functionName/requestID", + Value: bytes, + PrevValue: nil, + Rev: 0, + } + }).Reset() + GlobalRegistry.AliasRegistry.RunWatcher() + ch := make(chan SubEvent, 1) + GlobalRegistry.AliasRegistry.addSubscriberChan(ch) + convey.So(len(ch), convey.ShouldEqual, 0) + }) + }) +} + +func TestGetSchedulerInfo(t *testing.T) { + convey.Convey("GetSchedulerInfo", t, func() { + schedulerRegistry := NewFaasSchedulerRegistry(make(chan struct{})) + selfregister.GlobalSchedulerProxy.Add(&commonTypes.InstanceInfo{InstanceName: "scheduler1", + InstanceID: "scheduler1-id"}, "") + info := schedulerRegistry.GetSchedulerInfo() + convey.So(info.SchedulerIDList[0], convey.ShouldEqual, "scheduler1") + convey.So(info.SchedulerInstanceList[0].InstanceID, convey.ShouldEqual, "scheduler1-id") + convey.So(info.SchedulerInstanceList[0].InstanceName, convey.ShouldEqual, "scheduler1") + }) +} + +func TestFaaSFrontendRegistry_RunWatcher(t *testing.T) { + convey.Convey("FaaSFrontendRegistry", t, func() { + GlobalRegistry = nil + stopCh := make(chan struct{}) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartList", func(ew *etcd3.EtcdWatcher) { + ew.ResultChan <- &etcd3.Event{ + Type: etcd3.SYNCED, + Key: "", + Value: nil, + PrevValue: nil, + Rev: 0, + ETCDType: "", + } + }), + ApplyFunc((*FunctionRegistry).WaitForETCDList, func() {}), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "EtcdHistory", func(ew *etcd3.EtcdWatcher, + revision int64) { + }), + ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", func(ew *etcd3.EtcdWatcher) {}), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + InstanceConfigRegistry: NewInstanceConfigRegistry(stopCh), + AliasRegistry: NewAliasRegistry(stopCh), + FunctionAvailableRegistry: NewFunctionAvailableRegistry(stopCh), + FaaSFrontendRegistry: NewFaaSFrontendRegistry(stopCh), + TenantQuotaRegistry: NewTenantQuotaRegistry(stopCh), + RolloutRegistry: NewRolloutRegistry(stopCh), + } + ProcessETCDList() + convey.Convey("update", func() { + errEv := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/frontend/instances/cluster001/7.218.74.120/frontend-768df8f66b-gvz4z", + Value: []byte(`aaa`), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FaaSFrontendRegistry.watcherHandler(errEv) + assert.Equal(t, len(GlobalRegistry.FaaSFrontendRegistry.ClusterFrontends), 0) + + errEv1 := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/frontend/instances/cluster001/7.218.74.120/frontend-768df8f66b-gvz4z/latest", + Value: []byte(`active`), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FaaSFrontendRegistry.watcherHandler(errEv1) + GlobalRegistry.FaaSFrontendRegistry.watcherFilter(errEv1) + assert.Equal(t, len(GlobalRegistry.FaaSFrontendRegistry.ClusterFrontends), 0) + + ev := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/frontend/instances/cluster001/7.218.74.120/frontend-768df8f66b-gvz4z", + Value: []byte(`active`), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FaaSFrontendRegistry.watcherHandler(ev) + ev1 := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/frontend/instances/cluster001/7.218.74.220/frontend-768df8f66b-gvz4z", + Value: []byte(`active`), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FaaSFrontendRegistry.watcherHandler(ev1) + assert.Equal(t, len(GlobalRegistry.FaaSFrontendRegistry.ClusterFrontends["cluster001"]), 2) + assert.Equal(t, GlobalRegistry.FaaSFrontendRegistry.GetFrontends("cluster001"), + []string{"7.218.74.120", "7.218.74.220"}) + + ev2 := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/frontend/instances/cluster001/7.218.74.120/frontend-768df8f66b-gvz4z", + Value: []byte(``), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FaaSFrontendRegistry.watcherHandler(ev2) + assert.Equal(t, len(GlobalRegistry.FaaSFrontendRegistry.ClusterFrontends["cluster001"]), 1) + assert.Equal(t, GlobalRegistry.FaaSFrontendRegistry.GetFrontends("cluster001"), []string{"7.218.74.220"}) + ev3 := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/frontend/instances/cluster001/7.218.74.220/frontend-768df8f66b-gvz4z", + Value: []byte(``), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FaaSFrontendRegistry.watcherHandler(ev3) + assert.Equal(t, len(GlobalRegistry.FaaSFrontendRegistry.ClusterFrontends), 0) + assert.Equal(t, GlobalRegistry.FaaSFrontendRegistry.GetFrontends("cluster001"), []string(nil)) + }) + + convey.Convey("delete", func() { + ev := &etcd3.Event{ + Key: "/sn/frontend/instances/cluster001/7.218.74.120/frontend-768df8f66b-gvz4z", + Value: []byte(`active`), + } + GlobalRegistry.FaaSFrontendRegistry.updateFrontendInstances(ev) + assert.Equal(t, len(GlobalRegistry.FaaSFrontendRegistry.ClusterFrontends), 1) + + deleteErrEv := &etcd3.Event{ + Type: etcd3.DELETE, + Key: "/sn/frontend/instances/cluster001/7.218.74.121/frontend-768df8f66b-gvz4z", + Value: []byte(`active`), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FaaSFrontendRegistry.watcherHandler(deleteErrEv) + assert.Equal(t, len(GlobalRegistry.FaaSFrontendRegistry.ClusterFrontends), 1) + + deleteErrEv1 := &etcd3.Event{ + Key: "/sn/frontend/instances/cluster001/frontend-768df8f66b-gvz4z", + Value: []byte(`active`), + } + GlobalRegistry.FaaSFrontendRegistry.deleteFrontendInstances(deleteErrEv1) + assert.Equal(t, len(GlobalRegistry.FaaSFrontendRegistry.ClusterFrontends), 1) + + deleteEv := &etcd3.Event{ + Type: etcd3.DELETE, + Key: "/sn/frontend/instances/cluster001/7.218.74.120/frontend-768df8f66b-gvz4z", + Value: []byte(`active`), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FaaSFrontendRegistry.watcherHandler(deleteEv) + assert.Equal(t, len(GlobalRegistry.FaaSFrontendRegistry.ClusterFrontends), 0) + }) + }) +} + +func TestFunctionAvailableRegistry_RunWatcher(t *testing.T) { + convey.Convey("FunctionAvailableRegistry", t, func() { + GlobalRegistry = nil + stopCh := make(chan struct{}) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartList", func(ew *etcd3.EtcdWatcher) { + ew.ResultChan <- &etcd3.Event{ + Type: etcd3.SYNCED, + Key: "", + Value: nil, + PrevValue: nil, + Rev: 0, + ETCDType: "", + } + }), + ApplyFunc((*FunctionRegistry).WaitForETCDList, func() {}), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "EtcdHistory", func(ew *etcd3.EtcdWatcher, + revision int64) { + }), + ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", func(ew *etcd3.EtcdWatcher) {}), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + InstanceConfigRegistry: NewInstanceConfigRegistry(stopCh), + AliasRegistry: NewAliasRegistry(stopCh), + FunctionAvailableRegistry: NewFunctionAvailableRegistry(stopCh), + FaaSFrontendRegistry: NewFaaSFrontendRegistry(stopCh), + TenantQuotaRegistry: NewTenantQuotaRegistry(stopCh), + RolloutRegistry: NewRolloutRegistry(stopCh), + } + ProcessETCDList() + convey.Convey("update and delete", func() { + errEv := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/function/available/clusters/sn:cn:yrk:580943580943580943:function:0@debugservice@hello-world/latest", + Value: []byte(`["cluster1", "cluster2"]`), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FunctionAvailableRegistry.watcherHandler(errEv) + assert.Equal(t, len(GlobalRegistry.FunctionAvailableRegistry.FuncAvailableClusters), 0) + + errEv1 := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/function/available/clusters/sn:cn:yrk:580943580943580943:function:0@debugservice@hello-world", + Value: []byte("aaa"), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FunctionAvailableRegistry.watcherHandler(errEv1) + assert.Equal(t, len(GlobalRegistry.FunctionAvailableRegistry.FuncAvailableClusters), 0) + + ev := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/function/available/clusters/sn:cn:yrk:580943580943580943:function:0@debugservice@hello-world", + Value: []byte(`["cluster1", "cluster2"]`), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FunctionAvailableRegistry.watcherHandler(ev) + clusters := []string{"cluster1", "cluster2"} + assert.Equal(t, + GlobalRegistry.FunctionAvailableRegistry.GeClusters("sn:cn:yrk:580943580943580943:function:0@debugservice@hello-world"), + clusters) + + deleteEv := &etcd3.Event{ + Type: etcd3.DELETE, + Key: "/sn/function/available/clusters/sn:cn:yrk:580943580943580943:function:0@debugservice@hello-world", + Value: []byte(`["cluster1", "cluster2"]`), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.FunctionAvailableRegistry.watcherHandler(deleteEv) + assert.Equal(t, len(GlobalRegistry.FunctionAvailableRegistry.FuncAvailableClusters), 0) + }) + }) +} + +func Test_handleDefaultQuotaEvent(t *testing.T) { + tenantMetaValueForTest := `{"tenantInstanceMetaData": {"maxOnDemandInstance": 1000,"maxReversedInstance": 1000}}` + type args struct { + event *etcd3.Event + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "case_01 invalid event", + args: args{ + event: &etcd3.Event{}, + }, + wantErr: true, + }, { + name: "case_02 unmarshal failed", + args: args{ + event: &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/quota/cluster/cluster001/default/instancemetadata", + Value: []byte(""), + PrevValue: nil, + Rev: 0, + }, + }, + wantErr: true, + }, { + name: "case_03 update event", + args: args{ + event: &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/quota/cluster/cluster001/default/instancemetadata", + Value: []byte(tenantMetaValueForTest), + PrevValue: nil, + Rev: 0, + }, + }, + wantErr: false, + }, { + name: "case_04 delete event", + args: args{ + event: &etcd3.Event{ + Type: etcd3.DELETE, + Key: "/sn/quota/cluster/cluster001/default/instancemetadata", + Value: nil, + PrevValue: []byte(tenantMetaValueForTest), + Rev: 0, + }, + }, + wantErr: false, + }, { + name: "case_05 unkown event", + args: args{ + event: &etcd3.Event{ + Type: etcd3.ERROR, + Key: "/sn/quota/cluster/cluster001/default/instancemetadata", + Value: []byte(tenantMetaValueForTest), + PrevValue: nil, + Rev: 0, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := handleDefaultQuotaEvent(tt.args.event); (err != nil) != tt.wantErr { + t.Errorf("handleDefaultQuotaEvent() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_handleTenantQuotaEvent(t *testing.T) { + tenantMetaValueForTest := `{"tenantInstanceMetaData": {"maxOnDemandInstance": 1000,"maxReversedInstance": 1000}}` + type args struct { + event *etcd3.Event + tenantID string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "case_01 invalid event", + args: args{ + event: &etcd3.Event{}, + tenantID: "test", + }, + wantErr: true, + }, { + name: "case_02 unmarshal failed", + args: args{ + event: &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/quota/cluster/cluster001/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/instancemetadata", + Value: []byte(""), + PrevValue: nil, + Rev: 0, + }, + tenantID: "test", + }, + wantErr: true, + }, { + name: "case_04 update event", + args: args{ + event: &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/quota/cluster/cluster001/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/instancemetadata", + Value: []byte(tenantMetaValueForTest), + PrevValue: nil, + Rev: 0, + }, + tenantID: "test", + }, + wantErr: false, + }, { + name: "case_05 delete event", + args: args{ + event: &etcd3.Event{ + Type: etcd3.DELETE, + Key: "/sn/quota/cluster/cluster001/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/instancemetadata", + Value: nil, + PrevValue: []byte(tenantMetaValueForTest), + Rev: 0, + }, + tenantID: "test", + }, + wantErr: false, + }, { + name: "case_06 unkown event", + args: args{ + event: &etcd3.Event{ + Type: etcd3.ERROR, + Key: "/sn/quota/cluster/cluster001/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/instancemetadata", + Value: []byte(tenantMetaValueForTest), + PrevValue: nil, + Rev: 0, + }, + tenantID: "test", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := handleTenantQuotaEvent(tt.args.event, tt.args.tenantID); (err != nil) != tt.wantErr { + t.Errorf("handleTenantQuotaEvent() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestTenantInstanceRegistry_RunWatcher(t *testing.T) { + convey.Convey("TenantQuotaRegistry", t, func() { + GlobalRegistry = nil + stopCh := make(chan struct{}) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartList", func(ew *etcd3.EtcdWatcher) { + ew.ResultChan <- &etcd3.Event{ + Type: etcd3.SYNCED, + Key: "", + Value: nil, + PrevValue: nil, + Rev: 0, + ETCDType: "", + } + }), + ApplyFunc((*FunctionRegistry).WaitForETCDList, func() {}), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "EtcdHistory", func(ew *etcd3.EtcdWatcher, + revision int64) { + }), + ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", func(ew *etcd3.EtcdWatcher) {}), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + GlobalRegistry = &Registry{ + FaaSSchedulerRegistry: NewFaasSchedulerRegistry(stopCh), + FunctionRegistry: NewFunctionRegistry(stopCh), + InstanceRegistry: NewInstanceRegistry(stopCh), + FaaSManagerRegistry: NewFaaSManagerRegistry(stopCh), + InstanceConfigRegistry: NewInstanceConfigRegistry(stopCh), + AliasRegistry: NewAliasRegistry(stopCh), + FunctionAvailableRegistry: NewFunctionAvailableRegistry(stopCh), + FaaSFrontendRegistry: NewFaaSFrontendRegistry(stopCh), + TenantQuotaRegistry: NewTenantQuotaRegistry(stopCh), + RolloutRegistry: NewRolloutRegistry(stopCh), + } + ProcessETCDList() + convey.Convey("update and delete", func() { + errEv := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/quota/cluster/cluster001/tenant", + Value: []byte(`["cluster1", "cluster2"]`), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.TenantQuotaRegistry.watcherHandler(errEv) + assert.Equal(t, GlobalRegistry.TenantQuotaRegistry.watcherFilter(errEv), true) + + ev := &etcd3.Event{ + Type: etcd3.PUT, + Key: "/sn/quota/cluster/cluster001/tenant/7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/instancemetadata", + Value: []byte(""), + PrevValue: nil, + Rev: 0, + } + GlobalRegistry.TenantQuotaRegistry.watcherHandler(ev) + assert.Equal(t, GlobalRegistry.TenantQuotaRegistry.watcherFilter(ev), false) + + deleteEv := &etcd3.Event{ + Type: etcd3.DELETE, + Key: "/sn/quota/cluster/cluster001/default/instancemetadata", + Value: nil, + PrevValue: []byte(`{"tenantInstanceMetaData": {"maxOnDemandInstance": 1000,"maxReversedInstance": 1000}}`), + Rev: 0, + } + GlobalRegistry.TenantQuotaRegistry.watcherHandler(deleteEv) + assert.Equal(t, GlobalRegistry.TenantQuotaRegistry.watcherFilter(deleteEv), false) + }) + }) +} diff --git a/yuanrong/pkg/functionscaler/registry/rolloutregistry.go b/yuanrong/pkg/functionscaler/registry/rolloutregistry.go new file mode 100644 index 0000000..f592768 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/rolloutregistry.go @@ -0,0 +1,189 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "encoding/json" + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/rollout" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" +) + +const ( + validRolloutKeyLen = 7 + rolloutConfigKeyLen = 5 + clusterIndex = 4 +) + +// RolloutRegistry watches Rollout event of etcd +type RolloutRegistry struct { + subscriberChans []chan SubEvent + configWatcher etcd3.Watcher + rolloutWatcher etcd3.Watcher + configDone chan struct{} + stopCh <-chan struct{} + sync.RWMutex +} + +// NewRolloutRegistry will create RolloutRegistry +func NewRolloutRegistry(stopCh <-chan struct{}) *RolloutRegistry { + rolloutRegistry := &RolloutRegistry{ + configDone: make(chan struct{}, 1), + stopCh: stopCh, + } + return rolloutRegistry +} + +func (rr *RolloutRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + if !config.GlobalConfig.EnableRollout { + return + } + rr.configWatcher = etcd3.NewEtcdWatcher( + constant.RolloutConfigPrefix, + rr.watcherFilterForConfig, + rr.watcherHandlerForConfig, + rr.stopCh, + etcdClient) + rr.configWatcher.StartList() + rr.rolloutWatcher = etcd3.NewEtcdWatcher( + constant.SchedulerRolloutPrefix, + rr.watcherFilterForRollout, + rr.watcherHandlerForRollout, + rr.stopCh, + etcdClient) + rr.rolloutWatcher.StartList() + rr.WaitForETCDList() +} + +// WaitForETCDList - +func (rr *RolloutRegistry) WaitForETCDList() { + select { + case <-rr.configDone: + log.GetLogger().Infof("receive rollout config list done, stop waiting ETCD list") + return + case <-rr.stopCh: + log.GetLogger().Warnf("registry is stopped, stop waiting ETCD list") + return + } +} + +// RunWatcher will start etcd watch process for instance event +func (rr *RolloutRegistry) RunWatcher() { + if !config.GlobalConfig.EnableRollout { + return + } + go rr.configWatcher.StartWatch() + go rr.rolloutWatcher.StartWatch() +} + +// watcherFilterForConfig will filter alias event from etcd event eg:/sn/faas-scheduler/rolloutConfig/cluster1 +func (rr *RolloutRegistry) watcherFilterForConfig(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != rolloutConfigKeyLen { + return true + } + if items[clusterIndex] != config.GlobalConfig.ClusterID { + return true + } + return false +} + +// watcherFilterForConfig will filter alias event from etcd event eg:/sn/faas-scheduler/rollout/aaa/bbb/ccc +func (rr *RolloutRegistry) watcherFilterForRollout(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != validRolloutKeyLen { + return true + } + if items[validRolloutKeyLen-1] != selfregister.SelfInstanceID { + return true + } + return false +} + +// watcherHandlerForConfig will handle instance event from etcd +func (rr *RolloutRegistry) watcherHandlerForConfig(event *etcd3.Event) { + log.GetLogger().Infof("handling rollout config event type %s key %s", event.Type, event.Key) + switch event.Type { + case etcd3.SYNCED: + log.GetLogger().Infof("received rollout config synced event") + rr.configDone <- struct{}{} + case etcd3.PUT: + err := rollout.GetGlobalRolloutHandler().ProcessRatioUpdate(event.Value) + if err != nil { + log.GetLogger().Errorf("process ratio update error: %s", err.Error()) + return + } + rr.publishEvent(SubEventTypeUpdate, rollout.GetGlobalRolloutHandler().GetCurrentRatio()) + if selfregister.IsRolloutObject { + rollout.GetGlobalRolloutHandler().ProcessAllocRecordSync(selfregister.SelfInstanceID, + selfregister.RolloutSubjectID) + } + case etcd3.DELETE: + rollout.GetGlobalRolloutHandler().ProcessRatioDelete() + rr.publishEvent(SubEventTypeUpdate, 0) + case etcd3.ERROR: + log.GetLogger().Warnf("etcd error event: %s", event.Value) + default: + log.GetLogger().Warnf("unsupported event, key: %s", event.Key) + } +} + +// watcherHandlerForConfig will handle instance event from etcd +func (rr *RolloutRegistry) watcherHandlerForRollout(event *etcd3.Event) { + log.GetLogger().Infof("handling rollout object event type %s key %s", event.Type, event.Key) + insSpec := &types.RolloutInstanceSpecification{} + err := json.Unmarshal(event.Value, insSpec) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal rollout insSpec from key %s error %s", event.Type, err.Error()) + return + } + switch event.Type { + case etcd3.PUT: + rollout.GetGlobalRolloutHandler().UpdateForwardInstance(insSpec.InstanceID) + case etcd3.DELETE: + rollout.GetGlobalRolloutHandler().UpdateForwardInstance("") + default: + log.GetLogger().Warnf("unexpected event type %d", event.Type) + } +} + +// addSubscriberChan will add channel, subscribed by FaaSScheduler +func (rr *RolloutRegistry) addSubscriberChan(subChan chan SubEvent) { + rr.Lock() + rr.subscriberChans = append(rr.subscriberChans, subChan) + rr.Unlock() +} + +// publishEvent will publish instance event via channel +func (rr *RolloutRegistry) publishEvent(eventType EventType, ratio int) { + for _, subChan := range rr.subscriberChans { + if subChan != nil { + subChan <- SubEvent{ + EventType: eventType, + EventMsg: ratio, + } + } + } +} diff --git a/yuanrong/pkg/functionscaler/registry/rolloutregistry_test.go b/yuanrong/pkg/functionscaler/registry/rolloutregistry_test.go new file mode 100644 index 0000000..547f740 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/rolloutregistry_test.go @@ -0,0 +1,161 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package registry + +import ( + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/rollout" + "yuanrong/pkg/functionscaler/selfregister" +) + +func TestWatcherFilter(t *testing.T) { + config.GlobalConfig.ClusterID = "cluster1" + registry := RolloutRegistry{} + event := &etcd3.Event{ + Key: "/sn/faas-scheduler/rolloutConfig/cluster1", + } + ignore := registry.watcherFilterForConfig(event) + assert.False(t, ignore) + event = &etcd3.Event{ + Key: "/sn/faas-scheduler/rolloutConfig", + } + ignore = registry.watcherFilterForConfig(event) + assert.True(t, ignore) + event = &etcd3.Event{ + Key: "/sn/faas-scheduler/rolloutConfig/cluster2", + } + ignore = registry.watcherFilterForConfig(event) + assert.True(t, ignore) + selfregister.SelfInstanceID = "instance1" + event = &etcd3.Event{ + Key: "/sn/faas-scheduler/rollout/cluster1/node1", + } + ignore = registry.watcherFilterForRollout(event) + assert.True(t, ignore) + event = &etcd3.Event{ + Key: "/sn/faas-scheduler/rollout/cluster1/node1/instance2", + } + ignore = registry.watcherFilterForRollout(event) + assert.True(t, ignore) + event = &etcd3.Event{ + Key: "/sn/faas-scheduler/rollout/cluster1/node1/instance1", + } + ignore = registry.watcherFilterForRollout(event) + assert.False(t, ignore) + selfregister.SelfInstanceID = "" +} + +func TestInitWatch(t *testing.T) { + config.GlobalConfig.EnableRollout = true + defer func() { + config.GlobalConfig.EnableRollout = false + }() + stopCh := make(chan struct{}) + once := sync.Once{} + rr := NewRolloutRegistry(stopCh) + listCalled := 0 + watchCalled := 0 + defer gomonkey.ApplyFunc((*etcd3.EtcdWatcher).StartList, func(_ *etcd3.EtcdWatcher) { + once.Do(func() { + rr.configDone <- struct{}{} + }) + listCalled++ + }).Reset() + defer gomonkey.ApplyFunc((*etcd3.EtcdWatcher).StartWatch, func(_ *etcd3.EtcdWatcher) { + watchCalled++ + }).Reset() + convey.Convey("test init watch", t, func() { + rr.initWatcher(&etcd3.EtcdClient{}) + convey.So(listCalled, convey.ShouldEqual, 2) + rr.RunWatcher() + time.Sleep(100 * time.Millisecond) + convey.So(watchCalled, convey.ShouldEqual, 2) + }) +} + +func TestWatchHandlerForConfig(t *testing.T) { + config.GlobalConfig.EnableRollout = true + selfregister.IsRolloutObject = true + defer func() { + config.GlobalConfig.EnableRollout = false + selfregister.IsRolloutObject = false + }() + stopCh := make(chan struct{}) + rr := NewRolloutRegistry(stopCh) + event := &etcd3.Event{ + Rev: 1, + } + subChan := make(chan SubEvent, 1) + rr.addSubscriberChan(subChan) + processCalled := 0 + defer gomonkey.ApplyFunc((*rollout.RFHandler).ProcessAllocRecordSync, func(_ *rollout.RFHandler, selfInsID, + targetInsID string) { + processCalled++ + }).Reset() + convey.Convey("test watch process", t, func() { + event.Type = etcd3.PUT + event.Key = "/sn/faas-scheduler/rolloutConfig/cluster1" + rr.watcherHandlerForConfig(event) + convey.So(len(subChan), convey.ShouldEqual, 0) + event.Value = []byte(`{"rolloutRatio":"100%"}`) + rr.watcherHandlerForConfig(event) + convey.So(len(subChan), convey.ShouldEqual, 1) + e := <-subChan + ratio := e.EventMsg.(int) + convey.So(ratio, convey.ShouldEqual, 100) + convey.So(processCalled, convey.ShouldEqual, 1) + event.Type = etcd3.DELETE + rr.watcherHandlerForConfig(event) + convey.So(len(subChan), convey.ShouldEqual, 1) + e = <-subChan + ratio = e.EventMsg.(int) + convey.So(ratio, convey.ShouldEqual, 0) + }) +} + +func TestWatchHandlerForRollout(t *testing.T) { + config.GlobalConfig.EnableRollout = true + defer func() { + config.GlobalConfig.EnableRollout = false + }() + stopCh := make(chan struct{}) + rr := NewRolloutRegistry(stopCh) + event := &etcd3.Event{ + Rev: 1, + } + convey.Convey("test watch process", t, func() { + event.Type = etcd3.PUT + event.Key = "/sn/faas-scheduler/rollout/cluster1/node1/instance1" + rr.watcherHandlerForRollout(event) + convey.So(len(rollout.GetGlobalRolloutHandler().ForwardInstance), convey.ShouldEqual, 0) + event.Value = []byte(`{"instanceID":"aaa"}`) + rr.watcherHandlerForRollout(event) + convey.So(rollout.GetGlobalRolloutHandler().ForwardInstance, convey.ShouldEqual, "aaa") + event.Type = etcd3.DELETE + rr.watcherHandlerForRollout(event) + convey.So(len(rollout.GetGlobalRolloutHandler().ForwardInstance), convey.ShouldEqual, 0) + }) +} diff --git a/yuanrong/pkg/functionscaler/registry/tenantquotaregistry.go b/yuanrong/pkg/functionscaler/registry/tenantquotaregistry.go new file mode 100644 index 0000000..591877e --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/tenantquotaregistry.go @@ -0,0 +1,188 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "encoding/json" + "errors" + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/urnutils" + "yuanrong/pkg/functionscaler/tenantquota" + "yuanrong/pkg/functionscaler/types" +) + +// TenantQuotaRegistry watches tenant instance event of etcd +type TenantQuotaRegistry struct { + watcher etcd3.Watcher + stopCh <-chan struct{} + sync.RWMutex +} + +// NewTenantQuotaRegistry will create TenantQuotaRegistry +func NewTenantQuotaRegistry(stopCh <-chan struct{}) *TenantQuotaRegistry { + tenantQuotaRegistry := &TenantQuotaRegistry{ + stopCh: stopCh, + } + return tenantQuotaRegistry +} + +func (tr *TenantQuotaRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + tr.watcher = etcd3.NewEtcdWatcher( + constant.TenantQuotaPrefix, + tr.watcherFilter, + tr.watcherHandler, + tr.stopCh, + etcdClient) + tr.watcher.StartList() +} + +// RunWatcher will start etcd watch process for tenant instance event +func (tr *TenantQuotaRegistry) RunWatcher() { + go tr.watcher.StartWatch() +} + +// watcherFilter will filter tenant instance event from etcd event +func (tr *TenantQuotaRegistry) watcherFilter(event *etcd3.Event) bool { + if !isTenantQuota(event.Key) && !isDefaultTenantQuota(event.Key) { + return true + } + return false +} + +func isTenantQuota(etcdRef string) bool { + // An example of a tenant quota key: + // /sn/quota/cluster//tenant//instancemetadata + strs := strings.Split(etcdRef, keySeparator) + if len(strs) != validEtcdKeyLenForQuota2 { + return false + } + if strs[quotaKeyIndex] != "quota" || strs[instanceMetadataKeyIndex2] != "instancemetadata" { + return false + } + return true +} + +func isDefaultTenantQuota(etcdRef string) bool { + // An example of default tenant quota key: + // /sn/quota/cluster//default/instancemetadata + strs := strings.Split(etcdRef, keySeparator) + if len(strs) != validEtcdKeyLenForQuota1 { + return false + } + if strs[quotaKeyIndex] != "quota" || strs[instanceMetadataKeyIndex1] != "instancemetadata" { + return false + } + return true +} + +// watcherHandler will handle tenant instance event from etcd +func (tr *TenantQuotaRegistry) watcherHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling TenantQuota event type %d key %s", event.Type, event.Key) + var err error + elements := strings.Split(event.Key, keySeparator) + if isDefaultTenantQuota(event.Key) { + log.GetLogger().Infof("default tenant quota key changed: %s", event.Key) + err = handleDefaultQuotaEvent(event) + } else if len(elements) == validEtcdKeyLenForQuota2 { + log.GetLogger().Infof("different tenant quota key changed: %s", event.Key) + tenantID := elements[tenantValueIndex] + err = handleTenantQuotaEvent(event, tenantID) + } else { + log.GetLogger().Errorf("invalid tenant quota key: %s", event.Key) + return + } + if err != nil { + log.GetLogger().Errorf("failed to process event type: %s and key: %s, error %s", + event.Type, event.Key, err.Error()) + return + } + return +} + +func handleDefaultQuotaEvent(event *etcd3.Event) error { + var ( + tenantMetaInfo = types.TenantMetaInfo{} + err error + ) + + // check value + if event.Type == etcd3.DELETE { + err = json.Unmarshal(event.PrevValue, &tenantMetaInfo) + } else { + err = json.Unmarshal(event.Value, &tenantMetaInfo) + } + if err != nil { + log.GetLogger().Errorf("unmarshal default tenant quota info failed, key: %s and error: %s", + event.Key, err.Error()) + return err + } + log.GetLogger().Infof("default tenant quota info: %+v", tenantMetaInfo) + + switch event.Type { + case etcd3.PUT: + log.GetLogger().Infof("default tenant quota update event type %s", event.Type) + tenantquota.GetTenantCache().UpdateDefaultQuota(tenantMetaInfo) + case etcd3.DELETE: + log.GetLogger().Infof("default tenant quota delete event type %s, ignore", event.Type) + default: + log.GetLogger().Errorf("default tenant quota unsupported event type %s", event.Type) + return errors.New("unsupported event type for tenant quota") + } + log.GetLogger().Infof("finished to process default tenant quota event, resource key: %s, type %d", + event.Key, event.Type) + return nil +} + +func handleTenantQuotaEvent(event *etcd3.Event, tenantID string) error { + var ( + tenantMetaInfo = types.TenantMetaInfo{} + err error + ) + + // check value + if event.Type == etcd3.DELETE { + err = json.Unmarshal(event.PrevValue, &tenantMetaInfo) + } else { + err = json.Unmarshal(event.Value, &tenantMetaInfo) + } + if err != nil { + log.GetLogger().Errorf("unmarshal tenant quota info failed, key: %s and error: %s", + urnutils.AnonymizeTenantMetadataEtcdKey(event.Key), err.Error()) + return err + } + + switch event.Type { + case etcd3.PUT: + log.GetLogger().Infof("tenant quota update event type %s", event.Type) + tenantquota.GetTenantCache().UpdateOrAddTenantQuota(tenantID, tenantMetaInfo) + case etcd3.DELETE: + log.GetLogger().Infof("tenant quota delete event type %s", event.Type) + tenantquota.GetTenantCache().DeleteTenantQuota(tenantID) + default: + log.GetLogger().Errorf("tenant quota unsupported event type %s", event.Type) + return errors.New("unsupported event type for tenant quota") + } + log.GetLogger().Infof("finished to process tenant quota event, resource key: %s, type %d", + event.Key, event.Type) + return nil +} diff --git a/yuanrong/pkg/functionscaler/registry/useragencyregistry.go b/yuanrong/pkg/functionscaler/registry/useragencyregistry.go new file mode 100644 index 0000000..4392605 --- /dev/null +++ b/yuanrong/pkg/functionscaler/registry/useragencyregistry.go @@ -0,0 +1,126 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "encoding/json" + "fmt" + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" +) + +// UserAgencyRegistry watches user agency event of etcd +type UserAgencyRegistry struct { + userAgencyCacheMap sync.Map + watcher etcd3.Watcher + stopCh <-chan struct{} + sync.RWMutex +} + +// NewUserAgencyRegistry will create UserAgencyRegistry +func NewUserAgencyRegistry(stopCh <-chan struct{}) *UserAgencyRegistry { + userAgencyRegistry := &UserAgencyRegistry{ + stopCh: stopCh, + } + return userAgencyRegistry +} + +func (ur *UserAgencyRegistry) initWatcher(etcdClient *etcd3.EtcdClient) { + ur.watcher = etcd3.NewEtcdWatcher( + userAgencyEtcdPrefix, + ur.watcherFilter, + ur.watcherHandler, + ur.stopCh, + etcdClient) + ur.watcher.StartList() +} + +// RunWatcher will start etcd watch process for instance event +func (ur *UserAgencyRegistry) RunWatcher() { + go ur.watcher.StartWatch() +} + +// GetUserAgencyByFuncMeta get user agency by function meta +func (ur *UserAgencyRegistry) GetUserAgencyByFuncMeta( + funcMetaInfo *types.FunctionMetaInfo) types.UserAgency { + userAgency := types.UserAgency{} + if funcMetaInfo == nil { + log.GetLogger().Errorf("funcMeta is nil and agency is empty") + return userAgency + } + agencyID := "" + if funcMetaInfo.ExtendedMetaData.Role.AppXRole != "" { + agencyID = funcMetaInfo.ExtendedMetaData.Role.AppXRole + } else if funcMetaInfo.ExtendedMetaData.Role.XRole != "" { + agencyID = funcMetaInfo.ExtendedMetaData.Role.XRole + } + if agencyID != "" { + tenantID := funcMetaInfo.FuncMetaData.TenantID + domainID := funcMetaInfo.FuncMetaData.DomainID + key := fmt.Sprintf("/sn/agency/business/yrk/tenant/%s/domain/%s/agency/%s", tenantID, domainID, agencyID) + if value, ok := ur.userAgencyCacheMap.Load(key); ok { + userAgency, ok = value.(types.UserAgency) + if !ok { + log.GetLogger().Errorf("not a valid userAgency cache") + } + } + return userAgency + } + return userAgency +} + +// watcherFilter will filter instance event from etcd event +func (ur *UserAgencyRegistry) watcherFilter(event *etcd3.Event) bool { + items := strings.Split(event.Key, keySeparator) + if len(items) != validEtcdKeyLenForAgency { + return true + } + if items[agencyKeyIndex1] != "agency" || items[agencyBusinessKeyIndex] != "business" || + items[agencyTenantKeyIndex] != "tenant" || items[agencyDomainKeyIndex] != "domain" || + items[agencyKeyIndex2] != "agency" { + return true + } + return false +} + +func (ur *UserAgencyRegistry) watcherHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling user agency event type %s key %s", event.Type, event.Key) + switch event.Type { + case etcd3.PUT: + var agency = types.UserAgency{} + if err := json.Unmarshal(event.Value, &agency); err != nil { + log.GetLogger().Errorf("failed to unmarshal the json, error: %s", err.Error()) + return + } + ur.userAgencyCacheMap.Delete(event.Key) + ur.userAgencyCacheMap.Store(event.Key, agency) + case etcd3.DELETE: + if _, ok := ur.userAgencyCacheMap.Load(event.Key); ok { + ur.userAgencyCacheMap.Delete(event.Key) + } + return + case etcd3.SYNCED: + log.GetLogger().Infof("userAgency registry ready to receive etcd kv") + default: + log.GetLogger().Warnf("unknown event type: %d", event.Type) + } +} diff --git a/yuanrong/pkg/functionscaler/requestqueue/instance_request_queue.go b/yuanrong/pkg/functionscaler/requestqueue/instance_request_queue.go new file mode 100644 index 0000000..179d609 --- /dev/null +++ b/yuanrong/pkg/functionscaler/requestqueue/instance_request_queue.go @@ -0,0 +1,372 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package requestqueue - +package requestqueue + +import ( + "errors" + "sync" + "time" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/queue" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" + "yuanrong/pkg/functionscaler/workermanager" +) + +const ( + defaultTriggerChSize = 10000 + timeoutAccuracy = time.Duration(500) * time.Millisecond +) + +var ( + // DefaultRequestTimeout is defined here for the convenience of mocking + DefaultRequestTimeout = time.Duration(30) * time.Second + errInsThdReqReachMaxNum = errors.New("instance thread request reach max number") +) + +// ScheduleFunction - +type ScheduleFunction func(insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, error) + +// PendingInsAcqReq is the instance acquire request +type PendingInsAcqReq struct { + ResultChan chan *PendingInsAcqRsp + CreatedTime time.Time + InsAcqReq *types.InstanceAcquireRequest +} + +// PendingInsAcqRsp is the instance acquire response +type PendingInsAcqRsp struct { + InsAlloc *types.InstanceAllocation + Error error +} + +// InsAcqReqQueue is queue to store instance thread requests and also responsible for scheduling them +type InsAcqReqQueue struct { + queue *queue.FifoQueue + maxQueueLen int + insNum int + funcKeyWithRes string + recoverableError error + unrecoverableError error + requestTimeout time.Duration + triggerCh chan struct{} + stopCh chan struct{} + *sync.RWMutex + + schFuncMap sync.Map +} + +// NewInsAcqReqQueue creates a InsAcqReqQueue +func NewInsAcqReqQueue(funcKeyWithRes string, requestTimeout time.Duration) *InsAcqReqQueue { + if requestTimeout < DefaultRequestTimeout { + requestTimeout = DefaultRequestTimeout + } + insThdReqQue := &InsAcqReqQueue{ + queue: queue.NewFifoQueue(nil), + maxQueueLen: config.GlobalConfig.AutoScaleConfig.BurstScaleNum, + funcKeyWithRes: funcKeyWithRes, + requestTimeout: requestTimeout, + triggerCh: make(chan struct{}, defaultTriggerChSize), + stopCh: make(chan struct{}), + RWMutex: &sync.RWMutex{}, + } + + go insThdReqQue.TimeoutReqHandleLoop() + return insThdReqQue +} + +// HandleInsNumUpdate - +func (iq *InsAcqReqQueue) HandleInsNumUpdate(InsNumDiff int) { + iq.Lock() + iq.insNum += InsNumDiff + iq.Unlock() +} + +// Len returns length of queue +func (iq *InsAcqReqQueue) Len() int { + iq.RLock() + l := iq.queue.Len() + iq.RUnlock() + return l +} + +// AddRequest adds request into queue +func (iq *InsAcqReqQueue) AddRequest(insThdReq *PendingInsAcqReq) error { + needTriggerSchedule := false + iq.Lock() + defer iq.Unlock() + if iq.unrecoverableError != nil { + return iq.unrecoverableError + } + if iq.queue.Len() == 0 { + needTriggerSchedule = true + } + if iq.queue.Len() >= iq.maxQueueLen { + return errInsThdReqReachMaxNum + } + err := iq.queue.PushBack(insThdReq) + if err != nil { + return err + } + metrics.OnPendingRequestAdd(insThdReq.InsAcqReq) + if needTriggerSchedule { + iq.TriggerSchedule() + } + return nil +} + +// RegisterSchFunc register schFunc for schedule instance +func (iq *InsAcqReqQueue) RegisterSchFunc(schFuncKey string, schFunc ScheduleFunction) { + scheCh := make(chan struct{}, defaultTriggerChSize) + iq.schFuncMap.Store(schFuncKey, scheCh) + go func(schFunc ScheduleFunction) { + for { + select { + case _, ok := <-scheCh: + if !ok { + return + } + iq.realScheduleRequest(schFunc) + case <-iq.stopCh: + return + } + } + }(schFunc) +} + +// ScheduleRequest schedules requests to instance threads +func (iq *InsAcqReqQueue) ScheduleRequest(schFuncKey string) { + ch, loaded := iq.schFuncMap.Load(schFuncKey) + if !loaded { + log.GetLogger().Errorf("schFunc has not register, skip") + return + } + scheCh, ok := ch.(chan struct{}) + if !ok { + log.GetLogger().Errorf("schFunc type error, skip") + return + } + if len(scheCh) == 0 { + scheCh <- struct{}{} + } +} + +// realScheduleRequest schedules requests to instance threads +func (iq *InsAcqReqQueue) realScheduleRequest(schFunc ScheduleFunction) { + for { + iq.Lock() + obj := iq.queue.PopFront() + if obj == nil { + iq.Unlock() + return + } + pendingRequest, ok := obj.(*PendingInsAcqReq) + if !ok { + iq.Unlock() + continue + } + iq.Unlock() + insAlloc, err := schFunc(pendingRequest.InsAcqReq) + if err != nil { + iq.Lock() + // PushBack 不会返回err 不需要处理 + _ = iq.queue.PushBack(pendingRequest) + iq.Unlock() + return + } + metrics.OnPendingRequestRelease(pendingRequest.InsAcqReq) + pendingRequest.ResultChan <- &PendingInsAcqRsp{ + InsAlloc: insAlloc, + Error: nil, + } + } +} + +// TriggerSchedule triggers request scheduling +func (iq *InsAcqReqQueue) TriggerSchedule() { + select { + case iq.triggerCh <- struct{}{}: + default: + log.GetLogger().Warnf("trigger channel is blocked in request queue") + } +} + +// TimeoutReqHandleLoop is a loop to continually handle timeout request in queue +func (iq *InsAcqReqQueue) TimeoutReqHandleLoop() { + timeoutCh := make(<-chan time.Time, 1) + var timer *time.Timer + defer func() { + if timer != nil { + timer.Stop() + } + }() + for { + select { + case <-iq.stopCh: + log.GetLogger().Warnf("stop schedule instance thread for function %s now", iq.funcKeyWithRes) + err := snerror.New(statuscode.FuncMetaNotFoundErrCode, statuscode.FuncMetaNotFoundErrMsg) + iq.Lock() + iq.ClearReqQueueWithError(err) + iq.Unlock() + return + case _, ok := <-iq.triggerCh: + if !ok { + log.GetLogger().Warnf("trigger channel is closed, stop schedule request now") + return + } + case <-timeoutCh: + } + var nextReqTimeout time.Duration + currentTime := time.Now() + iq.Lock() + for iq.queue.Len() != 0 { + obj := iq.queue.Front() + insThdReq, ok := obj.(*PendingInsAcqReq) + if !ok { + iq.queue.PopFront() + continue + } + reqHasWaitTime := currentTime.Sub(insThdReq.CreatedTime) + if (reqHasWaitTime + timeoutAccuracy).Milliseconds() >= iq.requestTimeout.Milliseconds() { + err := scheduler.ErrInsReqTimeout + if iq.recoverableError != nil { + err = iq.recoverableError + } + metrics.OnPendingRequestRelease(insThdReq.InsAcqReq) + insThdReq.ResultChan <- &PendingInsAcqRsp{ + InsAlloc: nil, + Error: err, + } + iq.queue.PopFront() + continue + } + // request queue is a fifo queue based on enqueue time of request, if request at front can't be scheduled + // or not yet timed out then there is no need to go further + nextReqTimeout = iq.requestTimeout - reqHasWaitTime + break + } + // if queue is not empty then we start a timer to handle request timeout + if iq.queue.Len() != 0 { + timer = time.NewTimer(nextReqTimeout) + timeoutCh = timer.C + } + iq.Unlock() + } +} + +// ClearReqQueueWithError - without lock,should lock/unlock outside +func (iq *InsAcqReqQueue) ClearReqQueueWithError(err error) { + for iq.queue.Len() != 0 { + obj := iq.queue.Front() + insThdReq, ok := obj.(*PendingInsAcqReq) + if !ok { + iq.queue.PopFront() + continue + } + metrics.OnPendingRequestRelease(insThdReq.InsAcqReq) + insThdReq.ResultChan <- &PendingInsAcqRsp{ + InsAlloc: nil, + Error: err, + } + iq.queue.PopFront() + } +} + +// HandleCreateError will return unrecoverable error immediately and save recoverable error for further return +func (iq *InsAcqReqQueue) HandleCreateError(createError error) { + select { + case <-iq.stopCh: + log.GetLogger().Warnf("stop schedule instance thread for function %s now", iq.funcKeyWithRes) + return + default: + if config.GlobalConfig.InstanceOperationBackend == constant.BackendTypeFG && createError != nil { + iq.handlerWorkMangerError(createError) + return + } + if createError == nil { + iq.Lock() + iq.unrecoverableError = nil + iq.recoverableError = nil + iq.Unlock() + return + } + log.GetLogger().Warnf("instance request queue for function %s handling create error %s", iq.funcKeyWithRes, + createError.Error()) + if utils.IsUnrecoverableError(createError) { + log.GetLogger().Warnf("set unrecoverable create error %s for function %s", createError.Error(), + iq.funcKeyWithRes) + iq.Lock() + iq.unrecoverableError = createError + if iq.insNum == 0 { + iq.ClearReqQueueWithError(createError) + } + iq.Unlock() + return + } + log.GetLogger().Warnf("set recoverable create error %s for function %s", createError.Error(), + iq.funcKeyWithRes) + iq.Lock() + iq.recoverableError = createError + iq.Unlock() + } +} + +func (iq *InsAcqReqQueue) handlerWorkMangerError(createError error) { + iq.Lock() + currentInstanceNum := iq.insNum + iq.Unlock() + if currentInstanceNum == 0 && !workermanager.NeedTryError(createError) { + log.GetLogger().Errorf("worker manager return error %s for function %s, returning all request now", + createError.Error(), iq.funcKeyWithRes) + iq.Lock() + iq.ClearReqQueueWithError(createError) + iq.Unlock() + return + } + log.GetLogger().Warnf("set delay return instance create error %v for function %s", + createError, iq.funcKeyWithRes) + iq.Lock() + iq.recoverableError = createError + iq.Unlock() +} + +// UpdateRequestTimeout handles request timeout update +func (iq *InsAcqReqQueue) UpdateRequestTimeout(requestTimeout time.Duration) { + iq.Lock() + defer iq.Unlock() + if requestTimeout == iq.requestTimeout { + return + } + if requestTimeout < DefaultRequestTimeout { + requestTimeout = DefaultRequestTimeout + } + iq.requestTimeout = requestTimeout +} + +// Stop stops InsAcqReqQueue +func (iq *InsAcqReqQueue) Stop() { + commonUtils.SafeCloseChannel(iq.stopCh) +} diff --git a/yuanrong/pkg/functionscaler/requestqueue/instance_request_queue_test.go b/yuanrong/pkg/functionscaler/requestqueue/instance_request_queue_test.go new file mode 100644 index 0000000..9a8626c --- /dev/null +++ b/yuanrong/pkg/functionscaler/requestqueue/instance_request_queue_test.go @@ -0,0 +1,291 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package concurrencyscheduler - +package requestqueue + +import ( + "context" + "errors" + "reflect" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/queue" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/workermanager" +) + +var testFuncSpec = &types.FunctionSpecification{ + FuncMetaSignature: "123", + InstanceMetaData: commonTypes.InstanceMetaData{ + ConcurrentNum: 2, + }, + FuncCtx: context.TODO(), +} + +func TestMain(m *testing.M) { + config.GlobalConfig = types.Configuration{ + AutoScaleConfig: types.AutoScaleConfig{ + SLAQuota: 500, + ScaleDownTime: 1000, + BurstScaleNum: 1000, + }, + LeaseSpan: 500, + } + m.Run() +} + +func TestNewInsAcqReqQueue(t *testing.T) { + q := NewInsAcqReqQueue("testFuncKey", 10*time.Second) + assert.Equal(t, false, q == nil) + q.Stop() +} + +func TestHandleInsNumUpdate(t *testing.T) { + q := NewInsAcqReqQueue("testFuncKey", 10*time.Second) + q.HandleInsNumUpdate(1) + assert.Equal(t, 1, q.insNum) + q.Stop() +} + +func TestLen(t *testing.T) { + q := NewInsAcqReqQueue("testFuncKey", 10*time.Second) + assert.Equal(t, 0, q.Len()) + q.Stop() +} + +func TestScheduleRequest(t *testing.T) { + q := NewInsAcqReqQueue("testFuncKey", 10*time.Second) + ch := make(chan string, 1) + q.RegisterSchFunc("scheFuncFail", func(insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, error) { + ch <- "scheFuncFail" + return nil, errors.New("some error") + }) + q.RegisterSchFunc("scheFuncSucc", func(insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, error) { + ch <- "scheFuncSucc" + return &types.InstanceAllocation{}, nil + }) + req := &PendingInsAcqReq{ResultChan: make(chan *PendingInsAcqRsp, 1)} + q.AddRequest(req) + q.ScheduleRequest("scheFuncFail") + failOut := <-ch + assert.Equal(t, "scheFuncFail", failOut) + q.ScheduleRequest("scheFuncSucc") + succOut := <-ch + assert.Equal(t, "scheFuncSucc", succOut) + q.Stop() +} + +func TestReturnCreateError(t *testing.T) { + stopChan := make(chan struct{}) + iq := &InsAcqReqQueue{ + stopCh: stopChan, + queue: queue.NewFifoQueue(nil), + RWMutex: &sync.RWMutex{}, + } + createErr := snerror.New(statuscode.UserFuncEntryNotFoundErrCode, "entry not found") + getRespCh := make(chan *PendingInsAcqRsp, 3) + getResp := &PendingInsAcqReq{ + ResultChan: getRespCh, + } + iq.queue.PushBack(getResp) + convey.Convey("get error success", t, func() { + iq.HandleCreateError(createErr) + var err error + select { + case itq := <-getRespCh: + err = itq.Error + default: + err = nil + } + convey.So(err.Error(), convey.ShouldEqual, "entry not found") + }) + close(stopChan) + convey.Convey("chan stop", t, func() { + iq.HandleCreateError(createErr) + var err error + select { + case itq := <-getRespCh: + err = itq.Error + default: + err = nil + } + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestTimeoutReqHandleLoop(t *testing.T) { + var insThdReq1RspTime time.Time + var insThdReq2RspTime time.Time + var insThdReq3RspTime time.Time + var insThdReq4RspTime time.Time + + insThdReqQue := &InsAcqReqQueue{ + funcKeyWithRes: "funcKeyWithRes", + queue: queue.NewFifoQueue(nil), + maxQueueLen: 1000, + requestTimeout: time.Duration(1) * time.Second, + triggerCh: make(chan struct{}, defaultTriggerChSize), + stopCh: make(chan struct{}), + RWMutex: &sync.RWMutex{}, + } + go insThdReqQue.TimeoutReqHandleLoop() + time.Sleep(10 * time.Millisecond) + insThdReq1 := &PendingInsAcqReq{ + CreatedTime: time.Now(), + ResultChan: make(chan *PendingInsAcqRsp, 1), + } + insThdReq2 := &PendingInsAcqReq{ + CreatedTime: time.Now(), + ResultChan: make(chan *PendingInsAcqRsp, 1), + } + + insThdReqQue.AddRequest(insThdReq1) + var wg sync.WaitGroup + wg.Add(1) + go func(insThdReq *PendingInsAcqReq) { + defer wg.Done() + insThdRsp := <-insThdReq.ResultChan + insThdReq1RspTime = time.Now() + assert.True(t, insThdReq1RspTime.Sub(insThdReq.CreatedTime) <= insThdReqQue.requestTimeout+50*time.Millisecond) + assert.Equal(t, insThdRsp.Error.Error(), scheduler.ErrInsReqTimeout.Error()) + }(insThdReq1) + insThdReqQue.AddRequest(insThdReq2) + wg.Add(1) + go func(insThdReq *PendingInsAcqReq) { + defer wg.Done() + insThdRsp := <-insThdReq.ResultChan + insThdReq2RspTime = time.Now() + assert.True(t, insThdReq2RspTime.Sub(insThdReq.CreatedTime) <= insThdReqQue.requestTimeout+50*time.Millisecond) + assert.Equal(t, insThdRsp.Error.Error(), scheduler.ErrInsReqTimeout.Error()) + }(insThdReq2) + + time.Sleep(450 * time.Millisecond) + insThdReq3 := &PendingInsAcqReq{ + CreatedTime: time.Now(), + ResultChan: make(chan *PendingInsAcqRsp, 1), + } + insThdReqQue.AddRequest(insThdReq3) + wg.Add(1) + go func(insThdReq *PendingInsAcqReq) { + defer wg.Done() + insThdRsp := <-insThdReq.ResultChan + insThdReq3RspTime = time.Now() + assert.True(t, insThdReq3RspTime.Sub(insThdReq.CreatedTime) <= insThdReqQue.requestTimeout+50*time.Millisecond) + assert.Equal(t, insThdRsp.Error.Error(), scheduler.ErrInsReqTimeout.Error()) + }(insThdReq3) + + time.Sleep(100 * time.Millisecond) + insThdReq4 := &PendingInsAcqReq{ + CreatedTime: time.Now(), + ResultChan: make(chan *PendingInsAcqRsp, 1), + } + insThdReqQue.AddRequest(insThdReq4) + wg.Add(1) + go func(insThdReq *PendingInsAcqReq) { + defer wg.Done() + insThdRsp := <-insThdReq.ResultChan + insThdReq4RspTime = time.Now() + assert.True(t, insThdReq3RspTime.Sub(insThdReq.CreatedTime) <= insThdReqQue.requestTimeout+50*time.Millisecond) + assert.Equal(t, insThdRsp.Error.Error(), scheduler.ErrInsReqTimeout.Error()) + }(insThdReq4) + wg.Wait() + assert.True(t, insThdReq1RspTime.Sub(insThdReq2RspTime) <= 10*time.Millisecond) + assert.True(t, insThdReq2RspTime.Sub(insThdReq3RspTime) <= 10*time.Millisecond) + assert.True(t, insThdReq4RspTime.Sub(insThdReq3RspTime) >= 500*time.Millisecond) +} + +func TestHandlerWorkMangerError(t *testing.T) { + tests := []struct { + name string + insNum int + createError error + needTryError bool + expectClearCalled bool + expectRecoverable bool + }{ + { + name: "Current instance number is 0, error not retryable", + insNum: 0, + createError: errors.New("non-retryable error"), + needTryError: false, + expectClearCalled: true, + expectRecoverable: false, + }, + { + name: "Current instance number is not 0, error is retryable", + insNum: 1, + createError: errors.New("retryable error"), + needTryError: true, + expectClearCalled: false, + expectRecoverable: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + iq := &InsAcqReqQueue{ + insNum: tt.insNum, + funcKeyWithRes: "test-func-key", + recoverableError: nil, + RWMutex: &sync.RWMutex{}, + } + + patches := gomonkey.NewPatches() + defer patches.Reset() + + clearCalled := false + + patches.ApplyMethod(reflect.TypeOf(iq), "ClearReqQueueWithError", + func(_ *InsAcqReqQueue, err error) { + clearCalled = true + }) + + patches.ApplyFunc(workermanager.NeedTryError, + func(err error) bool { + return tt.needTryError + }) + + iq.handlerWorkMangerError(tt.createError) + + assert.Equal(t, tt.expectClearCalled, clearCalled, "ClearReqQueueWithError should be called accordingly") + if tt.expectRecoverable { + assert.Equal(t, tt.createError, iq.recoverableError, "recoverableError should be set accordingly") + } else { + assert.Nil(t, iq.recoverableError, "recoverableError should be nil") + } + }) + } +} + +func TestUpdateRequestTimeout(t *testing.T) { + q := NewInsAcqReqQueue("testFuncKey", 10*time.Second) + q.UpdateRequestTimeout(20 * time.Second) + assert.Equal(t, DefaultRequestTimeout, q.requestTimeout) + q.Stop() +} diff --git a/yuanrong/pkg/functionscaler/rollout/rollouthandler.go b/yuanrong/pkg/functionscaler/rollout/rollouthandler.go new file mode 100644 index 0000000..3bb1c29 --- /dev/null +++ b/yuanrong/pkg/functionscaler/rollout/rollouthandler.go @@ -0,0 +1,218 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package rollout - +package rollout + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/loadbalance" + "yuanrong/pkg/common/faas_common/logger/log" + commonTypes "yuanrong/pkg/common/faas_common/types" +) + +const ( + defaultAllocRecordSyncChanSize = 1000 +) + +var ( + globalRolloutHandler = &RFHandler{allocRecordSyncCh: make(chan map[string][]string, defaultAllocRecordSyncChanSize)} + rolloutSdkClient api.LibruntimeAPI +) + +// GetGlobalRolloutHandler - +func GetGlobalRolloutHandler() *RFHandler { + return globalRolloutHandler +} + +// SetRolloutSdkClient - +func SetRolloutSdkClient(sdkClient api.LibruntimeAPI) { + rolloutSdkClient = sdkClient +} + +// RolloutRatio - +type RolloutRatio struct { + CurrentVersion string `json:"CurrentVersion"` + RolloutRatio string `json:"rolloutRatio"` +} + +// RFHandler - +type RFHandler struct { + LoadBalance loadbalance.WNGINX + allocRecordSyncCh chan map[string][]string + IsGaryUpdating bool + ForwardInstance string + CurrentVersion string + CurrentRatio int + sync.RWMutex +} + +const ( + forward = "forward" + notForward = "noForward" + maxRatio = 100 +) + +// UpdateForwardInstance - +func (rf *RFHandler) UpdateForwardInstance(instanceID string) { + rf.Lock() + defer rf.Unlock() + log.GetLogger().Infof("update forward instance to %s", instanceID) + rf.ForwardInstance = instanceID +} + +// ShouldForwardRequest - +func (rf *RFHandler) ShouldForwardRequest() bool { + rf.Lock() + defer rf.Unlock() + if len(rf.ForwardInstance) == 0 { + return false + } + node := rf.LoadBalance.Next("", true) + nodeStr, ok := node.(string) + if !ok { + return false + } + if nodeStr == forward { + return true + } + return false +} + +// GetCurrentRatio - +func (rf *RFHandler) GetCurrentRatio() int { + rf.RLock() + ratio := rf.CurrentRatio + rf.RUnlock() + return ratio +} + +// ProcessRatioUpdate - +func (rf *RFHandler) ProcessRatioUpdate(ratioData []byte) error { + rf.Lock() + defer rf.Unlock() + rolloutRatio := &RolloutRatio{} + err := json.Unmarshal(ratioData, rolloutRatio) + if err != nil { + log.GetLogger().Errorf("failed to process ratio update, unmarshal error %s", err.Error()) + return err + } + ratio, err := strconv.Atoi(strings.TrimSuffix(rolloutRatio.RolloutRatio, "%")) + if err != nil { + log.GetLogger().Errorf("failed to process ratio update, ratio parse error %s", err.Error()) + return err + } + if ratio > maxRatio { + log.GetLogger().Errorf("failed to process ratio update, ratio %s is invalid", ratio) + return errors.New("rolloutRatio larger than 100%") + } + rf.CurrentVersion = rolloutRatio.CurrentVersion + grayLoadBalance := loadbalance.WNGINX{} + grayLoadBalance.Add(forward, ratio) + grayLoadBalance.Add(notForward, maxRatio-ratio) + rf.LoadBalance = grayLoadBalance + rf.CurrentRatio = ratio + log.GetLogger().Infof("succeed to update rollout ratio to %d%", ratio) + return nil +} + +// ProcessRatioDelete - +func (rf *RFHandler) ProcessRatioDelete() { + rf.Lock() + defer rf.Unlock() + grayLoadBalance := loadbalance.WNGINX{} + rf.LoadBalance = grayLoadBalance + log.GetLogger().Infof("succeed to delete rollout ratio") +} + +// GetAllocRecordSyncChan - +func (rf *RFHandler) GetAllocRecordSyncChan() chan map[string][]string { + return rf.allocRecordSyncCh +} + +// ProcessAllocRecordSync - +func (rf *RFHandler) ProcessAllocRecordSync(selfInsID, targetInsID string) { + log.GetLogger().Infof("start to process allocation record synchronize") + rsp, err := rf.SendRolloutRequest(selfInsID, targetInsID) + if err != nil { + log.GetLogger().Errorf("failed to sync alloc record from instance %s error %s", targetInsID, err.Error()) + return + } + rf.allocRecordSyncCh <- rsp.AllocRecord + log.GetLogger().Infof("succeed to process allocation record synchronize") +} + +// SendRolloutRequest - +func (rf *RFHandler) SendRolloutRequest(selfInsID, targetInsID string) (*commonTypes.RolloutResponse, error) { + log.GetLogger().Infof("start to send rollout request from %s to %s", selfInsID, targetInsID) + rolloutArg := api.Arg{ + Type: api.Value, + Data: []byte(fmt.Sprintf("rollout#%s", selfInsID)), + } + rspData, err := InvokeByInstanceId([]api.Arg{rolloutArg}, targetInsID, "") + if err != nil { + log.GetLogger().Errorf("failed to send rollout request to %s error %s", targetInsID, err.Error()) + return nil, err + } + rolloutResp := &commonTypes.RolloutResponse{} + err = json.Unmarshal(rspData, rolloutResp) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal rollout response error %s", err.Error()) + return nil, err + } + log.GetLogger().Infof("succeed to send rollout request from %s to %s", selfInsID, targetInsID) + return rolloutResp, nil +} + +// InvokeByInstanceId - +func InvokeByInstanceId(args []api.Arg, instanceID string, traceID string) ([]byte, error) { + wait := make(chan struct{}, 1) + var ( + res []byte + resErr error + ) + invokeStart := time.Now() + funcMeta := api.FunctionMeta{FuncID: constant.FaasSchedulerName, Api: api.PosixApi} + invokeOpts := api.InvokeOptions{TraceID: traceID} + objID, invokeErr := rolloutSdkClient.InvokeByInstanceId(funcMeta, instanceID, args, invokeOpts) + if invokeErr != nil { + log.GetLogger().Errorf("failed to invoke by id %s, traceID: %s, function: %s, error: %s", + instanceID, traceID, constant.FaasSchedulerName, invokeErr.Error()) + return nil, invokeErr + } + rolloutSdkClient.GetAsync(objID, func(result []byte, err error) { + res = result + resErr = err + wait <- struct{}{} + if _, err := rolloutSdkClient.GDecreaseRef([]string{objID}); err != nil { + log.GetLogger().Warnf("GDecreaseRef objID %s failed, %s", objID, err.Error()) + } + }) + <-wait + log.GetLogger().Infof("success invoke instance: %s, resErr %v, totalTime: %f", instanceID, resErr, + time.Since(invokeStart).Seconds()) + return res, resErr +} diff --git a/yuanrong/pkg/functionscaler/rollout/rollouthandler_test.go b/yuanrong/pkg/functionscaler/rollout/rollouthandler_test.go new file mode 100644 index 0000000..0257aaf --- /dev/null +++ b/yuanrong/pkg/functionscaler/rollout/rollouthandler_test.go @@ -0,0 +1,142 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rollout + +import ( + "encoding/json" + "errors" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "yuanrong.org/kernel/runtime/libruntime/api" + + mockUtils "yuanrong/pkg/common/faas_common/utils" +) + +func TestGetGlobalUpdateHandler(t *testing.T) { + handler := GetGlobalRolloutHandler() + assert.NotNil(t, handler) +} + +func TestProcessRatio(t *testing.T) { + handler := GetGlobalRolloutHandler() + handler.UpdateForwardInstance("scheduler2") + assert.Equal(t, handler.ForwardInstance, "scheduler2") + rolloutRatio := &RolloutRatio{ + RolloutRatio: "aa%", + } + ratio, _ := json.Marshal(rolloutRatio) + err := GetGlobalRolloutHandler().ProcessRatioUpdate(ratio) + assert.NotNil(t, err) + + rolloutRatio = &RolloutRatio{ + RolloutRatio: "200%", + } + ratio, _ = json.Marshal(rolloutRatio) + err = GetGlobalRolloutHandler().ProcessRatioUpdate(ratio) + assert.NotNil(t, err) + + rolloutRatio = &RolloutRatio{ + RolloutRatio: "20%", + } + ratio, _ = json.Marshal(rolloutRatio) + _ = GetGlobalRolloutHandler().ProcessRatioUpdate(ratio) + forwardCount := 0 + notForwardCount := 0 + totalCount := 100 + for i := 0; i < totalCount; i++ { + if handler.ShouldForwardRequest() { + forwardCount++ + } else { + notForwardCount++ + } + } + assert.Equal(t, 20, forwardCount) + assert.Equal(t, 80, notForwardCount) + + GetGlobalRolloutHandler().ProcessRatioDelete() + + forwardCount = 0 + notForwardCount = 0 + totalCount = 100 + for i := 0; i < totalCount; i++ { + if handler.ShouldForwardRequest() { + forwardCount++ + } else { + notForwardCount++ + } + } + assert.Equal(t, 0, forwardCount) + assert.Equal(t, 100, notForwardCount) +} + +func TestProcessAllocRecordSync(t *testing.T) { + var ( + invokeRes []byte + invokeErr error + ) + defer gomonkey.ApplyFunc(InvokeByInstanceId, func(args []api.Arg, instanceID string, traceID string) ([]byte, error) { + return invokeRes, invokeErr + }).Reset() + convey.Convey("Test ProcessAllocRecordSync", t, func() { + invokeErr = errors.New("some error") + globalRolloutHandler.ProcessAllocRecordSync("instance1", "instance2") + convey.So(len(globalRolloutHandler.allocRecordSyncCh), convey.ShouldEqual, 0) + invokeRes = []byte("error data") + invokeErr = nil + globalRolloutHandler.ProcessAllocRecordSync("instance1", "instance2") + convey.So(len(globalRolloutHandler.allocRecordSyncCh), convey.ShouldEqual, 0) + invokeRes = []byte(`{}`) + globalRolloutHandler.ProcessAllocRecordSync("instance1", "instance2") + convey.So(len(globalRolloutHandler.allocRecordSyncCh), convey.ShouldEqual, 1) + }) +} + +func TestInvokeByInstanceId(t *testing.T) { + SetRolloutSdkClient(&mockUtils.FakeLibruntimeSdkClient{}) + convey.Convey("test InvokeByInstanceId", t, func() { + convey.Convey("invoke error", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&mockUtils.FakeLibruntimeSdkClient{}), + "InvokeByInstanceId", func(_ *mockUtils.FakeLibruntimeSdkClient, + funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpt api.InvokeOptions) (string, error) { + return "", errors.New("invoke error") + }).Reset() + _, err := InvokeByInstanceId([]api.Arg{}, "testInstance", "123") + convey.So(err.Error(), convey.ShouldContainSubstring, "invoke error") + }) + convey.Convey("success invoke", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&mockUtils.FakeLibruntimeSdkClient{}), + "InvokeByInstanceId", func(_ *mockUtils.FakeLibruntimeSdkClient, + funcMeta api.FunctionMeta, instanceID string, args []api.Arg, + invokeOpt api.InvokeOptions) (string, error) { + return "123", nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&mockUtils.FakeLibruntimeSdkClient{}), + "GetAsync", func(_ *mockUtils.FakeLibruntimeSdkClient, + objectID string, cb api.GetAsyncCallback) { + cb([]byte("hello"), nil) + }).Reset() + res, err := InvokeByInstanceId([]api.Arg{}, "testInstance", "123") + convey.So(err, convey.ShouldBeNil) + convey.So(string(res), convey.ShouldEqual, "hello") + }) + }) +} diff --git a/yuanrong/pkg/functionscaler/scaler/autoscaler.go b/yuanrong/pkg/functionscaler/scaler/autoscaler.go new file mode 100644 index 0000000..c233c39 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scaler/autoscaler.go @@ -0,0 +1,394 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package scaler - +package scaler + +import ( + "math" + "sync" + "time" + + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/log" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + // SLA time should not be shorter than cold start time + minSLATime = time.Duration(500) * time.Millisecond + defaultScaleUpInitTime = 100 +) + +var ( + scaleUpInitTime = time.Duration(defaultScaleUpInitTime) * time.Millisecond +) + +// AutoScaler will scales instance automatically based on calculation upon instance metrics +type AutoScaler struct { + metricsCollector metrics.Collector + funcKeyWithRes string + scaleUpWindow time.Duration + scaleDownWindow time.Duration + coldStartTime time.Duration + concurrentNum int + // create new instance may take longer than scale up window, uses pendingInsThdNum to record these instances to + // avoid excessive scaling + pendingInsThdNum int + inUseInsThdNum int + totalInsThdNum int + // remainedInsThdReqNum stores this difference between pendingInsThdReqNum and scaleInsThdNum if scaleInsTheNum is + // less than pendingInsThdReqNum + remainedInsThdReqNum int + // autoScaleUpFlag tells if auto scale up process is running + autoScaleUpFlag bool + // autoScaleDownFlag tells if auto scale up process is running + autoScaleDownFlag bool + enable bool + checkReqNumFunc func() int + scaleUpHandler ScaleUpHandler + scaleDownHandler ScaleDownHandler + // scaleUpTriggerCh triggers auto scale up process, we will trigger auto scale up process if metrics are collected, + // before then we will manually create instance for each instance thread request + scaleUpTriggerCh chan struct{} + // scaleDownTriggerCh triggers auto scale down process, currently it is triggered as soon as autoScaler is created, + // consider to control the scale down timing to make scale down process more efficient + scaleDownTriggerCh chan struct{} + stopCh chan struct{} + sync.RWMutex +} + +// NewAutoScaler will create a AutoScaler +func NewAutoScaler(funcKeyWithRes string, metricsCollector metrics.Collector, checkReqNumFunc CheckReqNumFunc, + scaleUpHandler ScaleUpHandler, scaleDownHandler ScaleDownHandler) InstanceScaler { + scaleUpWindow := time.Duration(config.GlobalConfig.AutoScaleConfig.SLAQuota) * time.Millisecond + if scaleUpWindow < minSLATime { + scaleUpWindow = minSLATime + } + scaleDownWindow := time.Duration(config.GlobalConfig.AutoScaleConfig.ScaleDownTime) * time.Millisecond + if scaleDownWindow < scaleUpWindow { + scaleDownWindow = scaleUpWindow + } + autoScaler := &AutoScaler{ + funcKeyWithRes: funcKeyWithRes, + metricsCollector: metricsCollector, + scaleUpWindow: scaleUpWindow, + scaleDownWindow: scaleDownWindow, + checkReqNumFunc: checkReqNumFunc, + scaleUpHandler: scaleUpHandler, + scaleDownHandler: scaleDownHandler, + enable: false, + scaleUpTriggerCh: make(chan struct{}, 1), + scaleDownTriggerCh: make(chan struct{}, 1), + stopCh: make(chan struct{}), + } + if scaleUpInitTime > autoScaler.scaleUpWindow { + scaleUpInitTime = autoScaler.scaleUpWindow + } + go autoScaler.scaleUpLoop() + // Abandoned before 20250330 + if config.GlobalConfig.Scenario != types.ScenarioWiseCloud { + go autoScaler.scaleDownLoop() + } + return autoScaler +} + +// SetEnable will configure the enable of scaler +func (as *AutoScaler) SetEnable(enable bool) { +} + +// TriggerScale will trigger scale +func (as *AutoScaler) TriggerScale() { + as.Lock() + if !as.autoScaleUpFlag { + as.autoScaleUpFlag = true + as.scaleUpTriggerCh <- struct{}{} + } + as.Unlock() +} + +// CheckScaling will check if scaler is scaling +func (as *AutoScaler) CheckScaling() bool { + isScaling := false + as.RLock() + isScaling = as.autoScaleUpFlag || as.autoScaleDownFlag + as.RUnlock() + return isScaling +} + +// GetExpectInstanceNumber - number of pending and running instance +func (as *AutoScaler) GetExpectInstanceNumber() int { + as.RLock() + expectNum := (as.totalInsThdNum + as.pendingInsThdNum) / as.concurrentNum + as.RUnlock() + return expectNum +} + +// UpdateCreateMetrics will update create metrics +func (as *AutoScaler) UpdateCreateMetrics(coldStartTime time.Duration) { + as.Lock() + if coldStartTime > as.coldStartTime { + as.coldStartTime = coldStartTime + log.GetLogger().Infof("cold start time for function %s is updated to %d ms", as.funcKeyWithRes, + coldStartTime.Milliseconds()) + } + as.Unlock() +} + +// HandleInsThdUpdate will update instance thread metrics, totalInsThd increase should be coupled with pendingInsThd +// decrease for better consistency +func (as *AutoScaler) HandleInsThdUpdate(inUseInsThdDiff, totalInsThdDiff int) { + as.metricsCollector.UpdateInsThdMetrics(inUseInsThdDiff) + as.Lock() + as.inUseInsThdNum += inUseInsThdDiff + as.totalInsThdNum += totalInsThdDiff + // trigger scale down process if curRsvInsThdNum != 0 + if (as.totalInsThdNum - as.inUseInsThdNum) >= as.concurrentNum { + select { + case as.scaleDownTriggerCh <- struct{}{}: + default: + log.GetLogger().Warnf("scale down channel blocks for function %s", as.funcKeyWithRes) + } + } + as.Unlock() +} + +// HandleFuncSpecUpdate - +func (as *AutoScaler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { + as.Lock() + as.concurrentNum = utils.GetConcurrentNum(funcSpec.InstanceMetaData.ConcurrentNum) + as.Unlock() + log.GetLogger().Infof("config concurrentNum to %d for auto scaler %s", as.concurrentNum, as.funcKeyWithRes) +} + +// HandleInsConfigUpdate - +func (as *AutoScaler) HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) { +} + +// HandleCreateError handles instance create error +func (as *AutoScaler) HandleCreateError(createError error) { +} + +// Destroy will destroy scaler +func (as *AutoScaler) Destroy() { + commonUtils.SafeCloseChannel(as.stopCh) +} + +func (as *AutoScaler) scaleUpLoop() { + // unlike scale down, scale up process should start right away, so we define scale up channel this way to manually + // trigger first round's scale up before ticker fires + scaleUpChan := make(chan struct{}, 1) + tickerChan := make(<-chan time.Time, 1) + var ticker *time.Ticker + for { + select { + case _, ok := <-scaleUpChan: + if !ok { + log.GetLogger().Warnf("scale up channel is closed") + return + } + insThdReq := as.checkReqNumFunc() + if insThdReq == 0 { + as.pauseScale(ticker) + continue + } + as.scaleUpInstances() + case _, ok := <-as.scaleUpTriggerCh: + if !ok { + log.GetLogger().Warnf("trigger channel is closed") + return + } + // let requests come in for certain time to calculate a more reasonable scale up number + time.Sleep(scaleUpInitTime) + scaleUpChan <- struct{}{} + ticker = time.NewTicker(as.scaleUpWindow) + tickerChan = ticker.C + log.GetLogger().Infof("scale up loop for function %s is running", as.funcKeyWithRes) + case <-tickerChan: + scaleUpChan <- struct{}{} + case <-as.stopCh: + log.GetLogger().Warnf("stop scale up loop for function %s now", as.funcKeyWithRes) + return + } + } +} + +func (as *AutoScaler) pauseScale(ticker *time.Ticker) { + if ticker != nil { + ticker.Stop() + } + as.Lock() + as.autoScaleUpFlag = false + as.Unlock() + log.GetLogger().Infof("scale up loop for function %s is paused", as.funcKeyWithRes) +} + +func (as *AutoScaler) scaleDownLoop() { + scaleDownChan := make(<-chan time.Time, 1) + var timer *time.Timer + for { + select { + case <-scaleDownChan: + as.scaleDownInstances() + if timer != nil { + timer.Stop() + } + as.Lock() + as.autoScaleDownFlag = false + as.Unlock() + log.GetLogger().Infof("scale down loop for function %s is paused", as.funcKeyWithRes) + case _, ok := <-as.scaleDownTriggerCh: + if !ok { + log.GetLogger().Warnf("trigger channel is closed") + return + } + as.Lock() + as.autoScaleDownFlag = true + as.Unlock() + if timer == nil { + timer = time.NewTimer(as.scaleDownWindow) + } else { + select { + case <-timer.C: + default: + } + timer.Reset(as.scaleDownWindow) + } + scaleDownChan = timer.C + log.GetLogger().Infof("scale down loop for function %s is running", as.funcKeyWithRes) + case <-as.stopCh: + log.GetLogger().Warnf("stop scale down loop for function %s now", as.funcKeyWithRes) + return + } + } +} + +func (as *AutoScaler) handlePendingInsNumIncrease(insDiff int) { + as.Lock() + as.pendingInsThdNum += insDiff * as.concurrentNum + as.Unlock() +} + +func (as *AutoScaler) handlePendingInsNumDecrease(insDiff int) { + as.Lock() + as.pendingInsThdNum -= insDiff * as.concurrentNum + as.Unlock() +} + +func (as *AutoScaler) scaleUpInstances() { + as.Lock() + defer as.Unlock() + pendingInsThdReqNum := float64(as.checkReqNumFunc()) + if !as.metricsCollector.InvokeMetricsCollected() { + // fire at will if no metrics is ever collected for this function + scaleInsThdNum := pendingInsThdReqNum + // when request triggers scale up, there must be no availInsThdNum, no need to calculate with availInsThdNum, + // be aware that scaleInsThdNum is unsigned but pendingInsThdNum is signed + scaleInsThdNum = math.Max(scaleInsThdNum-float64(as.pendingInsThdNum), 0) + scaleInsNum := int(math.Ceil(scaleInsThdNum / float64(as.concurrentNum))) + as.pendingInsThdNum += scaleInsNum * as.concurrentNum + log.GetLogger().Infof("calculated scale up instance number for function %s is %d", as.funcKeyWithRes, + scaleInsNum) + as.scaleUpHandler(scaleInsNum, as.handlePendingInsNumDecrease) + return + } + avgProcTime, insThdProcNumPS, insThdReqNumPS := as.metricsCollector.GetCalculatedInvokeMetrics() + log.GetLogger().Infof("parameters for calculating scale up of function %s avgProcTime %f insThdProcNumPS %f "+ + "insThdReqNumPS %f pendingInsThdReqNum %f pendingInsThdNum %d", as.funcKeyWithRes, + avgProcTime, insThdProcNumPS, insThdReqNumPS, pendingInsThdReqNum, as.pendingInsThdNum) + if insThdProcNumPS == 0 { + log.GetLogger().Errorf("invalid value for insThdProcNumPS") + return + } + // when request triggers scale up, there must be no availInsThdNum, no need to calculate with availInsThdNum + procNumToDo := pendingInsThdReqNum + insThdReqNumPS*as.scaleUpWindow.Seconds() + procCapAvail := insThdProcNumPS * float64(as.inUseInsThdNum) * math.Max(as.scaleUpWindow.Seconds()-avgProcTime, 0) + procNumToDo = math.Max(procNumToDo-procCapAvail, 0) + procWindow := as.scaleUpWindow.Seconds() - as.coldStartTime.Seconds() + // handle the exception case that cold start takes longer than scale up window, set scaleInsThdNum to difference + // between procNumToDo and pendingInsThdNum to avoid scale up 0 + scaleInsThdNum := float64(0) + if procWindow > 0 { + scaleInsThdNum = math.Max(math.Ceil(procNumToDo/insThdProcNumPS/procWindow), 0) + } else { + scaleInsThdNum = math.Max(math.Ceil(procNumToDo), 0) + } + // try to scale less than pendingInsThdReqNum to take most advantage of instance reuse, considering these cases: + // 1. insThdProcNumPS is relatively large and one instance thread can process several requests during this scaling + // window + // 2. in-used instance thread may be released to faas scheduler during autoscaling which can be reused + scaleInsThdNum = math.Min(scaleInsThdNum, pendingInsThdReqNum) + // check remainedInsThdReqNum from last round, if it's less than pendingInsThdReqNum (otherwise some requests in + // remainedInsThdReqNum must be fulfilled and we just need to calculate with pendingInsThdReqNum in previous step), + // calibrate the scaleInsNum to avoid these remained instance thread requests being remained again + if float64(as.remainedInsThdReqNum) < pendingInsThdReqNum { + scaleInsThdNum = math.Max(scaleInsThdNum, float64(as.remainedInsThdReqNum)) + } + // calibrate scaleInsThdNum with pendingInsThdNum to avoid excessive scaling, be aware that scaleInsThdNum is + // unsigned but pendingInsThdNum is signed + scaleInsThdNum = math.Max(scaleInsThdNum-float64(as.pendingInsThdNum), 0) + // scaleInsThdNum may be smaller than pendingInsThdReqNum if insThdProcNumPS is relatively large, in this case + // remainedInsThdReqNum will store this difference between pendingInsThdReqNum and scaleInsThdNum to tell next + // round's scaleUpInstances there may be several insThdReqs unfulfilled which it should be aware + as.remainedInsThdReqNum = int(math.Max(pendingInsThdReqNum-scaleInsThdNum, 0)) + scaleInsNum := int(math.Ceil(scaleInsThdNum / float64(as.concurrentNum))) + log.GetLogger().Infof("calculated scale up instance number for function %s is %d", as.funcKeyWithRes, scaleInsNum) + if scaleInsNum > 0 { + as.pendingInsThdNum += scaleInsNum * as.concurrentNum + as.scaleUpHandler(scaleInsNum, as.handlePendingInsNumDecrease) + } +} + +func (as *AutoScaler) scaleDownInstances() { + if !as.metricsCollector.InvokeMetricsCollected() { + as.Lock() + scaleInsThdNum := math.Max(float64(as.totalInsThdNum-as.inUseInsThdNum), 0) + // be aware that scaleInsThdNum is unsigned but pendingInsThdNum is signed + scaleInsThdNum = -math.Min(-scaleInsThdNum-float64(as.pendingInsThdNum), 0) + scaleInsNum := int(math.Floor(scaleInsThdNum / float64(as.concurrentNum))) + as.pendingInsThdNum -= scaleInsNum * as.concurrentNum + as.Unlock() + as.scaleDownHandler(scaleInsNum, as.handlePendingInsNumIncrease) + return + } + avgProcTime, insThdProcNumPS, insThdReqNumPS := as.metricsCollector.GetCalculatedInvokeMetrics() + log.GetLogger().Infof("parameters for calculating scale down of function %s avgProcTime %f insThdProcNumPS %f "+ + "insThdReqNumPS %f totalInsThdNum %d pendingInsThdNum %d", as.funcKeyWithRes, avgProcTime, insThdProcNumPS, + insThdReqNumPS, as.totalInsThdNum, as.pendingInsThdNum) + if insThdProcNumPS == 0 { + log.GetLogger().Errorf("invalid value for insThdProcNumPS") + return + } + as.Lock() + defer as.Unlock() + procNumToDo := insThdReqNumPS * as.scaleDownWindow.Seconds() + procCapAvail := insThdProcNumPS * (math.Max(float64(as.totalInsThdNum-as.inUseInsThdNum), 0)* + as.scaleDownWindow.Seconds() + float64(as.inUseInsThdNum)*math.Max(as.scaleUpWindow.Seconds()-avgProcTime, 0)) + procCapExcess := math.Max(procCapAvail-procNumToDo, 0) + scaleInsThdNum := math.Ceil(procCapExcess / insThdProcNumPS / as.scaleDownWindow.Seconds()) + // be aware that scaleInsThdNum is unsigned but pendingInsThdNum is signed + scaleInsThdNum = -math.Min(-scaleInsThdNum-float64(as.pendingInsThdNum), 0) + scaleInsNum := int(math.Floor(scaleInsThdNum / float64(as.concurrentNum))) + log.GetLogger().Infof("calculated scale down instance number for function %s is %d", as.funcKeyWithRes, scaleInsNum) + if scaleInsNum > 0 { + as.pendingInsThdNum -= scaleInsNum * as.concurrentNum + as.scaleDownHandler(scaleInsNum, as.handlePendingInsNumIncrease) + } +} diff --git a/yuanrong/pkg/functionscaler/scaler/autoscaler_test.go b/yuanrong/pkg/functionscaler/scaler/autoscaler_test.go new file mode 100644 index 0000000..11ccafb --- /dev/null +++ b/yuanrong/pkg/functionscaler/scaler/autoscaler_test.go @@ -0,0 +1,117 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package scaler - +package scaler + +import ( + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" +) + +func TestAutoScaler_TriggerScale(t *testing.T) { + as := &AutoScaler{ + autoScaleUpFlag: false, + scaleUpTriggerCh: make(chan struct{}, 1), + } + as.TriggerScale() + assert.Equal(t, true, as.autoScaleUpFlag) +} + +func TestAutoScaler_UpdateCreateMetrics(t *testing.T) { + as := &AutoScaler{} + as.UpdateCreateMetrics(5 * time.Second) + assert.Equal(t, 5*time.Second, as.coldStartTime) +} + +func TestAutoScaler_scaleUpLoop(t *testing.T) { + as := &AutoScaler{} + convey.Convey("channel close", t, func() { + as.scaleUpTriggerCh = make(chan struct{}, 1) + as.stopCh = make(chan struct{}, 1) + close(as.scaleUpTriggerCh) + as.scaleUpLoop() + as.scaleUpTriggerCh = make(chan struct{}, 1) + close(as.stopCh) + as.scaleUpLoop() + }) + convey.Convey("normal case", t, func() { + reqNum := 1 + as.checkReqNumFunc = func() int { return reqNum } + as.scaleUpTriggerCh = make(chan struct{}, 1) + as.stopCh = make(chan struct{}, 1) + as.scaleUpWindow = 50 * time.Millisecond + callCount := 0 + p := gomonkey.ApplyFunc((*AutoScaler).scaleUpInstances, func() { + callCount++ + reqNum = 0 + }) + go as.scaleUpLoop() + time.Sleep(100 * time.Millisecond) + as.scaleUpTriggerCh <- struct{}{} + time.Sleep(200 * time.Millisecond) + convey.So(as.autoScaleUpFlag, convey.ShouldBeFalse) + convey.So(callCount, convey.ShouldEqual, 1) + p.Reset() + close(as.stopCh) + }) +} + +func TestAutoScaler_scaleDownLoop(t *testing.T) { + as := &AutoScaler{} + convey.Convey("channel close", t, func() { + as.scaleDownTriggerCh = make(chan struct{}, 1) + as.stopCh = make(chan struct{}, 1) + close(as.scaleDownTriggerCh) + as.scaleDownLoop() + as.scaleDownTriggerCh = make(chan struct{}, 1) + close(as.stopCh) + as.scaleDownLoop() + }) + convey.Convey("normal case", t, func() { + as.scaleDownTriggerCh = make(chan struct{}, 1) + as.stopCh = make(chan struct{}, 1) + as.scaleDownWindow = 50 * time.Millisecond + callCount := 0 + p := gomonkey.ApplyFunc((*AutoScaler).scaleDownInstances, func() { + callCount++ + }) + go as.scaleDownLoop() + time.Sleep(100 * time.Millisecond) + as.scaleDownTriggerCh <- struct{}{} + time.Sleep(200 * time.Millisecond) + convey.So(as.autoScaleDownFlag, convey.ShouldBeFalse) + convey.So(callCount, convey.ShouldEqual, 1) + p.Reset() + close(as.stopCh) + }) +} + +func TestAutoScaler_pendingInsNumOperation(t *testing.T) { + as := &AutoScaler{concurrentNum: 100} + convey.Convey("increase", t, func() { + as.handlePendingInsNumIncrease(1) + convey.So(as.pendingInsThdNum, convey.ShouldEqual, 100) + }) + convey.Convey("decrease", t, func() { + as.handlePendingInsNumDecrease(1) + convey.So(as.pendingInsThdNum, convey.ShouldEqual, 0) + }) +} diff --git a/yuanrong/pkg/functionscaler/scaler/instance_scaler.go b/yuanrong/pkg/functionscaler/scaler/instance_scaler.go new file mode 100644 index 0000000..854c184 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scaler/instance_scaler.go @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package scaler - +package scaler + +import ( + "time" + + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/functionscaler/types" +) + +// ScaleUpCallback executes some logic after scale up +type ScaleUpCallback func(int) + +// ScaleDownCallback executes some logic after scale down +type ScaleDownCallback func(int) + +// CheckReqNumFunc returns current request number of instance thread +type CheckReqNumFunc func() int + +// ScaleUpHandler handles instance scale up +type ScaleUpHandler func(int, ScaleUpCallback) + +// ScaleDownHandler handles instance scale down +type ScaleDownHandler func(int, ScaleDownCallback) + +// InstanceScaler scales instance to meet certain need +type InstanceScaler interface { + SetEnable(enable bool) + TriggerScale() + CheckScaling() bool + UpdateCreateMetrics(coldStartTime time.Duration) + HandleInsThdUpdate(inUseInsThdDiff, totalInsThdDiff int) + HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) + HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) + HandleCreateError(createError error) + GetExpectInstanceNumber() int + Destroy() +} diff --git a/yuanrong/pkg/functionscaler/scaler/instance_scaler_test.go b/yuanrong/pkg/functionscaler/scaler/instance_scaler_test.go new file mode 100644 index 0000000..ba388b8 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scaler/instance_scaler_test.go @@ -0,0 +1,563 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package scaler - +package scaler + +import ( + "encoding/json" + "errors" + "os" + "reflect" + "sync" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "go.uber.org/zap" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commontypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" +) + +func TestMiscellaneous(t *testing.T) { + config.GlobalConfig = types.Configuration{} + scaler := NewAutoScaler("test", &metrics.BucketCollector{}, func() int { return 1 }, + func(i int, cb ScaleUpCallback) {}, func(i int, cb ScaleDownCallback) {}) + as := scaler.(*AutoScaler) + as.HandleFuncSpecUpdate(&types.FunctionSpecification{InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}}) + as.pendingInsThdNum = 6 + convey.Convey("CheckScaling", t, func() { + res := as.CheckScaling() + convey.So(res, convey.ShouldBeFalse) + }) + convey.Convey("HandleCreateError", t, func() { + as.HandleCreateError(nil) + convey.So(as.pendingInsThdNum, convey.ShouldEqual, 6) + }) +} + +func TestReplicaScaler(t *testing.T) { + scaler := NewReplicaScaler("test", &metrics.BucketCollector{}, func(i int, cb ScaleUpCallback) {}, + func(i int, cb ScaleDownCallback) {}) + rs := scaler.(*ReplicaScaler) + rs.HandleFuncSpecUpdate(&types.FunctionSpecification{InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 1}}) + convey.Convey("CheckScaling", t, func() { + res := rs.CheckScaling() + convey.So(res, convey.ShouldBeFalse) + }) +} + +func TestScaleUpInstances(t *testing.T) { + me := metrics.NewBucketMetricsCollector("funcKey123", "resource300") + as := &AutoScaler{ + metricsCollector: me, + autoScaleUpFlag: true, + concurrentNum: 2, + pendingInsThdNum: 6, + scaleUpWindow: 1 * time.Second, + coldStartTime: 1 * time.Second, + scaleUpHandler: func(_ int, _ ScaleUpCallback) {}, + } + checkNum := func() int { + return as.pendingInsThdNum + 2 + } + as.checkReqNumFunc = checkNum + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "InvokeMetricsCollected", + func(_ *metrics.BucketCollector) bool { + return true + }).Reset() + + convey.Convey("insThdProcNumPS is 0", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "GetCalculatedInvokeMetrics", + func(_ *metrics.BucketCollector) (float64, float64, float64) { + return 0, 0, 0 + }).Reset() + as.scaleUpInstances() + convey.So(as.pendingInsThdNum, convey.ShouldEqual, 6) + convey.So(as.remainedInsThdReqNum, convey.ShouldEqual, 0) + }) + convey.Convey("procWindow = 0", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "GetCalculatedInvokeMetrics", + func(_ *metrics.BucketCollector) (float64, float64, float64) { + return 0, 1, 0 + }).Reset() + as.scaleUpInstances() + convey.So(as.pendingInsThdNum, convey.ShouldEqual, 8) + convey.So(as.remainedInsThdReqNum, convey.ShouldEqual, 6) + }) + convey.Convey("procWindow > 0", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "GetCalculatedInvokeMetrics", + func(_ *metrics.BucketCollector) (float64, float64, float64) { + return 0, 1, 0 + }).Reset() + as.scaleUpWindow = 2 * time.Second + as.remainedInsThdReqNum = 1 + as.scaleUpInstances() + convey.So(as.pendingInsThdNum, convey.ShouldEqual, 10) + convey.So(as.remainedInsThdReqNum, convey.ShouldEqual, 8) + }) + convey.Convey("InvokeMetricsCollected is false", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "InvokeMetricsCollected", + func(_ *metrics.BucketCollector) bool { + return false + }).Reset() + as.scaleUpInstances() + convey.So(as.pendingInsThdNum, convey.ShouldEqual, 12) + }) +} + +func TestScaleDownInstances(t *testing.T) { + me := metrics.NewBucketMetricsCollector("funcKey123", "resource300") + res := 0 + as := &AutoScaler{ + metricsCollector: me, + autoScaleUpFlag: true, + inUseInsThdNum: 2, + totalInsThdNum: 6, + concurrentNum: 2, + pendingInsThdNum: 0, + scaleDownWindow: 1 * time.Second, + coldStartTime: 1 * time.Second, + scaleUpHandler: func(_ int, cb ScaleUpCallback) {}, + scaleDownHandler: func(input int, cb ScaleDownCallback) { res = input }, + } + checkNum := func() int { + return as.pendingInsThdNum + 2 + } + as.checkReqNumFunc = checkNum + + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "InvokeMetricsCollected", + func(_ *metrics.BucketCollector) bool { + return false + }).Reset() + convey.Convey("pendingThsThdNum negative and little", t, func() { + as.pendingInsThdNum = -2 + as.scaleDownInstances() + convey.So(res, convey.ShouldEqual, 1) + }) + convey.Convey("pendingThsThdNum negative and big", t, func() { + as.pendingInsThdNum = -6 + as.scaleDownInstances() + convey.So(res, convey.ShouldEqual, 0) + }) + convey.Convey("pendingThsThdNum positive", t, func() { + as.pendingInsThdNum = 2 + as.scaleDownInstances() + convey.So(res, convey.ShouldEqual, 3) + }) + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "InvokeMetricsCollected", + func(_ *metrics.BucketCollector) bool { + return true + }).Reset() + convey.Convey("insThdProcNumPS is 0", t, func() { + res = 0 + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "GetCalculatedInvokeMetrics", + func(_ *metrics.BucketCollector) (float64, float64, float64) { + return 0, 0, 0 + }).Reset() + as.scaleDownInstances() + convey.So(res, convey.ShouldEqual, 0) + }) + convey.Convey("insThdProcNumPS is not 0", t, func() { + res = 0 + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "GetCalculatedInvokeMetrics", + func(_ *metrics.BucketCollector) (float64, float64, float64) { + return 0, 10, 10 + }).Reset() + as.scaleDownInstances() + convey.So(res, convey.ShouldEqual, 0) + }) +} + +func TestAutoScalerProcess(t *testing.T) { + me := metrics.NewBucketMetricsCollector("funcKey123", "resource300") + res := 0 + as := &AutoScaler{ + metricsCollector: me, + autoScaleUpFlag: true, + concurrentNum: 2, + pendingInsThdNum: 6, + scaleUpWindow: 1 * time.Second, + coldStartTime: 1 * time.Second, + scaleUpHandler: func(_ int, cb ScaleUpCallback) {}, + scaleDownHandler: func(input int, cb ScaleDownCallback) { res = input }, + } + + convey.Convey("HandleInsThdUpdate", t, func() { + as.inUseInsThdNum = 4 + as.totalInsThdNum = 6 + as.pendingInsThdNum = 0 + as.HandleInsThdUpdate(0, 1) + expectNum := as.GetExpectInstanceNumber() + convey.So(expectNum, convey.ShouldEqual, 3) + convey.So(res, convey.ShouldEqual, 0) + }) +} + +func TestHandleCreateError(t *testing.T) { + res := 0 + rs := &ReplicaScaler{ + metricsCollector: metrics.NewBucketMetricsCollector("", ""), + scaleUpHandler: func(in int, cb ScaleUpCallback) { res = in }, + enable: true, + pendingRsvInsNum: 2, + concurrentNum: 1, + currentRsvInsNum: 1, + targetRsvInsNum: 4, + } + + convey.Convey("HandleCreateError", t, func() { + rs.HandleCreateError(snerror.New(statuscode.StsConfigErrCode, "config sts error")) + convey.So(res, convey.ShouldEqual, 0) + rs.enable = true + rs.HandleCreateError(snerror.New(statuscode.UserFuncEntryNotFoundErrCode, "entry not found")) + convey.So(res, convey.ShouldEqual, 1) + rs.HandleCreateError(errors.New("recoverable error")) + convey.So(res, convey.ShouldEqual, 1) + }) +} + +func TestReplicaScalerProcess(t *testing.T) { + rs := &ReplicaScaler{ + metricsCollector: metrics.NewBucketMetricsCollector("", ""), + targetRsvInsNum: 1, + concurrentNum: 1, + scaleUpHandler: func(i int, cb ScaleUpCallback) { + }, + scaleDownHandler: func(i int, cb ScaleDownCallback) { + }, + } + convey.Convey("HandleInsThdUpdate", t, func() { + rs.HandleInsThdUpdate(0, 1) + convey.So(rs.currentRsvInsNum, convey.ShouldEqual, 1) + }) + convey.Convey("HandleInsConfigUpdate", t, func() { + rs.HandleInsConfigUpdate(&instanceconfig.Configuration{InstanceMetaData: commontypes.InstanceMetaData{MinInstance: 2}}) + convey.So(rs.targetRsvInsNum, convey.ShouldEqual, 2) + }) + convey.Convey("HandleFuncSpecUpdate", t, func() { + rs.HandleFuncSpecUpdate(&types.FunctionSpecification{InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}}) + convey.So(rs.concurrentNum, convey.ShouldEqual, 2) + }) +} + +func TestTriggerScale(t *testing.T) { + convey.Convey("TriggerScale", t, func() { + var scaleDownNum int + replicaScaler := NewReplicaScaler("testfunc-cpu-500-mem-500", metrics.NewBucketMetricsCollector("", ""), + func(i int, callback ScaleUpCallback) {}, func(i int, callback ScaleDownCallback) { + scaleDownNum = i + callback(1) + }) + replicaScaler.SetEnable(false) + replicaScaler.HandleFuncSpecUpdate(&types.FunctionSpecification{InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 1}}) + replicaScaler.HandleInsThdUpdate(0, 2) + replicaScaler.HandleInsConfigUpdate(&instanceconfig.Configuration{InstanceMetaData: commontypes.InstanceMetaData{ + MaxInstance: 1, + MinInstance: 1, + ConcurrentNum: 1, + }}) + replicaScaler.SetEnable(true) + convey.So(scaleDownNum, convey.ShouldEqual, 1) + }) +} + +func TestCalSleepTime(t *testing.T) { + config.GlobalConfig.PredictGroupWindow = 15 * 60 * 1000 + convey.Convey("CalSleepTime", t, func() { + // 2024-05-08 21:33:00 --> 1715175180000 + t1 := 1715175180000 % config.GlobalConfig.PredictGroupWindow + convey.So(t1, convey.ShouldEqual, 3000*60) + + time := CalSleepTime(1715175180000, time.Duration(10)*time.Minute) + convey.So(time.Milliseconds(), convey.ShouldEqual, 2000*60) + }) +} + +func TestHandlePredictUpdate(t *testing.T) { + config.GlobalConfig = types.Configuration{PredictGroupWindow: 15 * 60 * 1000} + pScaler := NewPredictScaler("test", &metrics.BucketCollector{}, func() int { return 1 }, + func(i int, cb ScaleUpCallback) {}, func(i int, cb ScaleDownCallback) {}) + ps := pScaler.(*PredictScaler) + ps.coldStartTime = 15 * 60 * 1000 + result := []float64{30.0, 60.0} // [2,3] + gomonkey.ApplyFunc(CalSleepTime, + func(currentTimeStamp int64, coldStartTime time.Duration) time.Duration { + return time.Duration(1) + }) + // currentInsNum is 6 + ps.totalInsThdNum = 4 + ps.pendingInsThdNum = 1 + ps.concurrentNum = 1 + ps.minRsvInsNum = 1 + + convey.Convey("HandlePredictUpdate", t, func() { + ps.HandlePredictUpdate(&types.PredictQPSGroups{ + QPSGroups: result, + }) + ps.scaleDownChan <- struct{}{} + convey.So(ps.predictDownDiff, convey.ShouldEqual, 4) + // convey.So(ps.predictUpDiff, convey.ShouldEqual, 1) + }) + + convey.Convey("CheckScaling", t, func() { + res := ps.CheckScaling() + convey.So(res, convey.ShouldBeFalse) + }) + + // currentInsNum is 1 + ps.totalInsThdNum = 0 + ps.pendingInsThdNum = 0 + ps.concurrentNum = 1 + ps.minRsvInsNum = 1 + convey.Convey("HandlePredictUpdate", t, func() { + ps.HandlePredictUpdate(&types.PredictQPSGroups{ + QPSGroups: result, + }) + ps.scaleUpChan <- struct{}{} + convey.So(ps.predictDownDiff, convey.ShouldEqual, 0) + // convey.So(ps.predictUpDiff, convey.ShouldEqual, 2) + }) + + convey.Convey("HandleInsConfigUpdate", t, func() { + ps.HandleInsConfigUpdate(&instanceconfig.Configuration{InstanceMetaData: commontypes.InstanceMetaData{MinInstance: 2}}) + convey.So(ps.minRsvInsNum, convey.ShouldEqual, 2) + }) + + convey.Convey("HandleFuncSpecUpdate", t, func() { + ps.HandleFuncSpecUpdate(&types.FunctionSpecification{InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}}) + convey.So(ps.concurrentNum, convey.ShouldEqual, 2) + }) +} + +func TestPredictScalerProcess(t *testing.T) { + me := metrics.NewBucketMetricsCollector("funcKey123", "resource300") + scaleInsNum := 0 + res := 0 + ps := &PredictScaler{ + logger: log.GetLogger().With(zap.Any("funcKeyWithRes", "funcKey123")), + metricsCollector: me, + predictScaleUpFlag: true, + concurrentNum: 2, + pendingInsThdNum: 6, + scaleUpWindow: 1 * time.Second, + scaleDownWindow: 1 * time.Second, + coldStartTime: 1 * time.Second, + scaleUpHandler: func(_ int, cb ScaleUpCallback) {}, + scaleDownHandler: func(input int, cb ScaleDownCallback) { res = input }, + scaleUpChan: make(chan struct{}, 1), + scaleDownChan: make(chan struct{}, 1), + scaleUpTriggerCh: make(chan struct{}, 1), + scaleDownTriggerCh: make(chan struct{}, 1), + stopCh: make(chan struct{}, 1), + } + checkNum := func() int { + return ps.pendingInsThdNum + 2 + } + ps.checkReqNumFunc = checkNum + + convey.Convey("InvokeMetricsCollected is true", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "InvokeMetricsCollected", + func(_ *metrics.BucketCollector) bool { + return true + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "GetCalculatedInvokeMetrics", + func(_ *metrics.BucketCollector) (float64, float64, float64) { + return 0, 0, 0 + }).Reset() + scaleInsNum = ps.getScaleDownInstancesNum() + convey.So(scaleInsNum, convey.ShouldEqual, -1) + convey.So(res, convey.ShouldEqual, 0) + + defer gomonkey.ApplyMethod(reflect.TypeOf(&metrics.BucketCollector{}), "GetCalculatedInvokeMetrics", + func(_ *metrics.BucketCollector) (float64, float64, float64) { + return 3, 1, 0 + }).Reset() + ps.totalInsThdNum = 4 + ps.inUseInsThdNum = 1 + ps.scaleDownWindow = 60000 * time.Millisecond + ps.pendingInsThdNum = -1 + scaleInsNum = ps.getScaleDownInstancesNum() + convey.So(scaleInsNum, convey.ShouldEqual, 1) + }) + + convey.Convey("getScaleUpInstancesNum", t, func() { + scaleInsNum = ps.getScaleUpInstancesNum() + convey.So(scaleInsNum, convey.ShouldEqual, 1) + convey.So(res, convey.ShouldEqual, 0) + }) + + convey.Convey("calculate insThdNum", t, func() { + ps.inUseInsThdNum = 1 + ps.totalInsThdNum = 4 + ps.concurrentNum = 1 + ps.HandleInsThdUpdate(1, 0) + ps.TriggerScale() + + ps.pendingInsThdNum = 2 + ps.predictDownDiff = 1 + expectNum := ps.GetExpectInstanceNumber() + convey.So(expectNum, convey.ShouldEqual, 6) + ps.handlePendingInsNumIncrease(1) + convey.So(ps.pendingInsThdNum, convey.ShouldEqual, 3) + ps.handlePendingInsNumDecrease(1) + convey.So(ps.pendingInsThdNum, convey.ShouldEqual, 2) + }) + + convey.Convey("scaleUp", t, func() { + go ps.scaleUp() + ps.scaleUpTriggerCh <- struct{}{} + convey.So(res, convey.ShouldEqual, 0) + time.Sleep(2 * time.Second) + close(ps.stopCh) + }) + + convey.Convey("scaleDown", t, func() { + ps.stopCh = make(chan struct{}, 1) + go ps.scaleDown() + ps.scaleDownTriggerCh <- struct{}{} + convey.So(res, convey.ShouldEqual, 0) + time.Sleep(2 * time.Second) + close(ps.stopCh) + }) +} + +func TestStartPredictRegistry(t *testing.T) { + convey.Convey("StartPredictRegistry", t, func() { + + registry.GlobalRegistry = ®istry.Registry{FaaSSchedulerRegistry: registry.NewFaasSchedulerRegistry(make(chan struct{}))} + selfregister.GlobalSchedulerProxy.Add(&commontypes.InstanceInfo{ + FunctionName: "faasscheduler", + InstanceName: "abcdefg", + }, "") + selfregister.SelfInstanceID = "abcdefg" + os.Setenv(constant.ClusterNameEnvKey, "localAZ") + os.Setenv("INSTANCE_ID", "abcdefg") + defer func() { + os.Unsetenv(constant.ClusterNameEnvKey) + os.Unsetenv("INSTANCE_ID") + }() + config.GlobalConfig.PredictGroupWindow = 15 * 60 * 1000 + dsw0 := time.Now().Add(-1 * time.Minute) + dsw1 := time.Now().Add(30 * time.Minute) + dataSetTimeWindow := []int64{dsw0.UnixMilli(), dsw1.UnixMilli()} + qpsResult := map[string][]float64{ + "244177614494574475/0@default@primary_secondary/latest": {0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + 4.5, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, + 9.0, 9.5, 10.0, 10.5, 11.0, 11.5, 12.0, + 12.5, 13.0, 13.5, 14.0, 14.5, 15.0}, + "244177614494574475/0@default@primary_secondary-slv0/latest": {0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + 4.5, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, + 9.0, 9.5, 10.0, 10.5, 11.0, 11.5, 12.0, + 12.5, 13.0, 13.5, 14.0}, + } + p := &types.PredictResult{ + DataSetTimeWindow: dataSetTimeWindow, + QPSResult: qpsResult, + IsValid: false, + } + etcdClient := &etcd3.EtcdClient{} + var ( + patches []*gomonkey.Patches + mockHandlePredictUpdate func(prg *types.PredictQPSGroups) + ) + patches = append(patches, + gomonkey.ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { + return etcdClient + }), + gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", func(ew *etcd3.EtcdWatcher) { + bytes, _ := json.Marshal(p) + ew.ResultChan <- &etcd3.Event{ + Type: 0, + Key: "/arrivalPrediction/localAZ", + Value: bytes, + PrevValue: nil, + Rev: 0, + } + p.DataSetTimeWindow = []int64{time.Now().Add(-40 * time.Minute).UnixMilli(), + time.Now().Add(-10 * time.Minute).UnixMilli()} + bytes, _ = json.Marshal(p) + ew.ResultChan <- &etcd3.Event{ + Type: 0, + Key: "/arrivalPrediction/localAZ", + Value: bytes, + PrevValue: nil, + Rev: 0, + } + }), + gomonkey.ApplyMethod(reflect.TypeOf(&PredictScaler{}), "HandlePredictUpdate", + func(ps *PredictScaler, prg *types.PredictQPSGroups) { + mockHandlePredictUpdate(prg) + }), + ) + defer func() { + for _, patch := range patches { + patch.Reset() + time.Sleep(10 * time.Millisecond) + } + }() + + predictScaler := NewPredictScaler("test-function-res", nil, func() int { + return 0 + }, func(i int, callback ScaleUpCallback) { + + }, func(i int, callback ScaleDownCallback) { + + }).(*PredictScaler) + predictScaler.HandleInsConfigUpdate(&instanceconfig.Configuration{InstanceMetaData: commontypes.InstanceMetaData{MinInstance: 0}}) + var predictGroupNum int + var wg sync.WaitGroup + var ( + result_primary_secondary []float64 + result_primary_secondary_slv0 []float64 + ) + wg.Add(2) + mockHandlePredictUpdate = func(prg *types.PredictQPSGroups) { + predictGroupNum++ + if prg.FuncKey == "244177614494574475/0@default@primary_secondary/latest" { + result_primary_secondary = prg.QPSGroups + + } + if prg.FuncKey == "244177614494574475/0@default@primary_secondary-slv0/latest" { + result_primary_secondary_slv0 = prg.QPSGroups + } + wg.Done() + } + + wg.Wait() + convey.So(predictGroupNum, convey.ShouldEqual, 2) + convey.So(len(result_primary_secondary), convey.ShouldEqual, 2) + convey.So(result_primary_secondary[0], convey.ShouldEqual, 3.6666666666666665) + convey.So(result_primary_secondary[1], convey.ShouldEqual, 10.966666666666667) + + convey.So(len(result_primary_secondary_slv0), convey.ShouldEqual, 2) + convey.So(result_primary_secondary_slv0[0], convey.ShouldEqual, 3.6666666666666665) + convey.So(result_primary_secondary_slv0[1], convey.ShouldEqual, 9) + predictScaler.Destroy() + }) +} diff --git a/yuanrong/pkg/functionscaler/scaler/predictscaler.go b/yuanrong/pkg/functionscaler/scaler/predictscaler.go new file mode 100644 index 0000000..00a519b --- /dev/null +++ b/yuanrong/pkg/functionscaler/scaler/predictscaler.go @@ -0,0 +1,589 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package scaler - +package scaler + +import ( + "encoding/json" + "fmt" + "math" + "os" + "strings" + "sync" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + // Millisecond2Second - + Millisecond2Second = 1000 + // Minute2Second - + Minute2Second = 60 + // DefaultReqDelay - + DefaultReqDelay = 3000 + + predictEtcdPrefix = "/arrivalPrediction/" + dataSetTimeWindowLen = 2 + qpsGroupsNum = 2 + millisecondToMinute = 60 * 1000 + longColdStartTime = 3 * 60 * 1000 * time.Millisecond +) + +// PredictScaler will scale based on a predict instance number +type PredictScaler struct { + logger api.FormatLogger + metricsCollector metrics.Collector + funcKeyWithRes string + scaleUpWindow time.Duration + scaleDownWindow time.Duration + coldStartTime time.Duration + predictGroupWindow int64 // Grouping window for prediction results + concurrentNum int + pendingInsThdNum int + inUseInsThdNum int // The number of leases used by this function version + totalInsThdNum int // The total number of leases that this function version can provide + remainedInsThdReqNum int + predictInsNum int + predictDownDiff int + predictUpDiff int + minRsvInsNum int + predictScaleUpFlag bool + predictScaleDownFlag bool + predictUpChanFlag bool + predictDownChanFlag bool + enable bool + checkReqNumFunc func() int + scaleUpHandler ScaleUpHandler + scaleDownHandler ScaleDownHandler + // active, Triggered by prediction results + scaleUpChan chan struct{} + scaleDownChan chan struct{} + // passive, Triggered based on real-time lease application status + scaleUpTriggerCh chan struct{} + scaleDownTriggerCh chan struct{} + stopCh chan struct{} + sync.RWMutex + sync.Once +} + +// NewPredictScaler will create a PredictScaler +func NewPredictScaler(funcKeyWithRes string, metricsCollector metrics.Collector, checkReqNumFunc CheckReqNumFunc, + scaleUpHandler ScaleUpHandler, scaleDownHandler ScaleDownHandler) InstanceScaler { + scaleUpWindow := time.Duration(config.GlobalConfig.AutoScaleConfig.SLAQuota) * time.Millisecond + if scaleUpWindow < minSLATime { + scaleUpWindow = minSLATime + } + scaleDownWindow := time.Duration(config.GlobalConfig.AutoScaleConfig.ScaleDownTime) * time.Millisecond + if scaleDownWindow < scaleUpWindow { + scaleDownWindow = scaleUpWindow + } + + predictScaler := &PredictScaler{ + logger: log.GetLogger().With(zap.Any("funcKeyWithRes", funcKeyWithRes)), + metricsCollector: metricsCollector, + funcKeyWithRes: funcKeyWithRes, + scaleUpWindow: scaleUpWindow, + scaleDownWindow: scaleDownWindow, + predictGroupWindow: config.GlobalConfig.PredictGroupWindow, + checkReqNumFunc: checkReqNumFunc, + predictInsNum: 0, + enable: false, + scaleUpHandler: scaleUpHandler, + scaleDownHandler: scaleDownHandler, + scaleUpChan: make(chan struct{}, 1), + scaleDownChan: make(chan struct{}, 1), + scaleUpTriggerCh: make(chan struct{}, 1), + scaleDownTriggerCh: make(chan struct{}, 1), + stopCh: make(chan struct{}, 1), + } + predictScaler.logger.Infof("create predict scaler") + + go predictScaler.scaleUp() + go predictScaler.scaleDown() + return predictScaler +} + +func (ps *PredictScaler) startPredictRegistry() { + metaEtcdClient := etcd3.GetMetaEtcdClient() + if metaEtcdClient == nil { + ps.logger.Errorf("failed to get meta etcd") + return + } + watcher := etcd3.NewEtcdWatcher( + predictEtcdPrefix+os.Getenv(constant.ClusterName), + func(event *etcd3.Event) bool { + return false + }, + ps.watcherHandler, + ps.stopCh, + metaEtcdClient) + watcher.StartWatch() +} + +func (ps *PredictScaler) watcherHandler(event *etcd3.Event) { + ps.logger.Infof("handling predict event type %s key %s", event.Type, event.Key) + if event.Type == etcd3.SYNCED { + ps.logger.Infof("predict registry ready to receive etcd kv") + return + } + switch event.Type { + case etcd3.PUT: + predictBytes := event.Value + predictResult := getPredictFromEtcdValue(predictBytes) + if predictResult == nil { + return + } + if !checkPredictResultValid(predictResult, ps.coldStartTime) { + ps.logger.Warnf("predict result is invalid, dataSetTimeWindow is in [%d, %d]", + predictResult.DataSetTimeWindow[0], predictResult.DataSetTimeWindow[1]) + return + } + filterPredictFunction(predictResult) + for funcKey, QPSNums := range predictResult.QPSResult { + predictQPSGroups := &types.PredictQPSGroups{FuncKey: funcKey} + predictQPSGroups.QPSGroups = groupQPSNumByPredictWindow(QPSNums, config.GlobalConfig.PredictGroupWindow) + ps.HandlePredictUpdate(predictQPSGroups) + } + default: + } +} + +// getPredictFromEtcdValue parse the PredictResult from etcd value +func getPredictFromEtcdValue(etcdValue []byte) *types.PredictResult { + if len(etcdValue) == 0 { + return nil + } + predictResult := &types.PredictResult{} + if err := json.Unmarshal(etcdValue, predictResult); err != nil { + log.GetLogger().Errorf("failed to unmarshal etcd value to PredictResult, err: %s", err.Error()) + return nil + } + return predictResult +} + +func checkPredictResultValid(predictResult *types.PredictResult, startTime time.Duration) bool { + if len(predictResult.DataSetTimeWindow) != dataSetTimeWindowLen { + return false + } + + // Limitation: The instance cold start time must be less than the expansion and contraction window + if config.GlobalConfig.PredictGroupWindow < startTime.Milliseconds() { + return false + } + + currentTimeStamp := time.Now().UnixMilli() + // If there is no time to scaleUp for the next predictGroupWindow when faasScheduler is started, + // the prediction results at this time are ignored. + // valid : dataSetTimeWindow[0] < currentTimeStamp < dataSetTimeWindow[0] + PredictGroupWindow - startTime + if currentTimeStamp < predictResult.DataSetTimeWindow[0] || currentTimeStamp > + predictResult.DataSetTimeWindow[0]+config.GlobalConfig.PredictGroupWindow-startTime.Milliseconds() { + return false + } + + return true +} + +// before, QpsResult is kv of functionURN and QpsNum +// after, QpsResult is kv of funcKey and QpsNum +func filterPredictFunction(predictResult *types.PredictResult) { + filterMap := make(map[string][]float64, len(predictResult.QPSResult)) + for functionURN, predictNum := range predictResult.QPSResult { + if selfregister.GlobalSchedulerProxy.CheckFuncOwner(functionURN) { + filterMap[functionURN] = predictNum + } + } + predictResult.QPSResult = filterMap +} + +// len(QpsNums) should be 30, unit is minute +// In the future, the prediction instance strategy can be made configurable, such as average value and extreme value. +func groupQPSNumByPredictWindow(QPSNums []float64, predictWindow int64) []float64 { + if predictWindow == 0 { + return []float64{} + } + step := int(predictWindow / millisecondToMinute) + var groupQPS []float64 + for i := 0; i < len(QPSNums); i += step { + var totalNum float64 + for j := i + 1; j < i+step && j < len(QPSNums); j++ { + totalNum += QPSNums[j] + } + avg := totalNum / float64(step*1.0) + groupQPS = append(groupQPS, avg) + } + return groupQPS +} + +// SetEnable will configure the enable of scaler +func (ps *PredictScaler) SetEnable(enable bool) { +} + +// TriggerScale will trigger scale +func (ps *PredictScaler) TriggerScale() { + ps.Lock() + // 实时租约触发实例扩容 + if !ps.predictScaleUpFlag { + ps.predictScaleUpFlag = true + ps.scaleUpTriggerCh <- struct{}{} + } + ps.Unlock() +} + +// CheckScaling will check if scaler is scaling +func (ps *PredictScaler) CheckScaling() bool { + isScaling := false + ps.RLock() + isScaling = ps.predictScaleUpFlag || ps.predictScaleDownFlag + ps.RUnlock() + return isScaling +} + +// UpdateCreateMetrics will update create metrics +func (ps *PredictScaler) UpdateCreateMetrics(coldStartTime time.Duration) { +} + +// HandleInsThdUpdate will update instance thread metrics, totalInsThd increase should be coupled with pendingInsThd +// decrease for better consistency +func (ps *PredictScaler) HandleInsThdUpdate(inUseInsThdDiff, totalInsThdDiff int) { + ps.metricsCollector.UpdateInsThdMetrics(inUseInsThdDiff) + ps.Lock() + ps.inUseInsThdNum += inUseInsThdDiff + ps.totalInsThdNum += totalInsThdDiff + if (ps.totalInsThdNum - ps.inUseInsThdNum) >= ps.concurrentNum { + select { + case ps.scaleDownTriggerCh <- struct{}{}: + default: + ps.logger.Warnf("scale down channel blocks") + } + } + ps.Unlock() +} + +// HandleFuncSpecUpdate - +func (ps *PredictScaler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { + ps.Lock() + ps.concurrentNum = utils.GetConcurrentNum(funcSpec.InstanceMetaData.ConcurrentNum) + ps.coldStartTime = time.Duration(funcSpec.ExtendedMetaData.Initializer.Timeout) * time.Second + ps.Unlock() + resSpec := resspeckey.ConvertResourceMetaDataToResSpec(funcSpec.ResourceMetaData) + ps.funcKeyWithRes = fmt.Sprintf("%s-%s", funcSpec.FuncKey, resSpec.String()) + ps.logger = log.GetLogger().With(zap.Any("funcKeyWithRes", ps.funcKeyWithRes)) + ps.logger.Infof("config concurrentNum to %d for predict scaler, coldStartTime is %f s", + ps.concurrentNum, ps.coldStartTime.Seconds()) +} + +// HandleInsConfigUpdate - +func (ps *PredictScaler) HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) { + replicaNum := int(insConfig.InstanceMetaData.MinInstance) + if replicaNum < 0 { + replicaNum = 0 + } + ps.Lock() + prevMinRsvInsNum := ps.minRsvInsNum + ps.minRsvInsNum = replicaNum + ps.Unlock() + ps.logger.Infof("config reserved num from %d to %d for predict scaler", prevMinRsvInsNum, replicaNum) + ps.Once.Do(func() { + go ps.startPredictRegistry() + }) +} + +// HandleCreateError handles instance create error +func (ps *PredictScaler) HandleCreateError(createError error) { +} + +// GetExpectInstanceNumber get function minInstance in etcd +func (ps *PredictScaler) GetExpectInstanceNumber() int { + ps.RLock() + expectNum := (ps.totalInsThdNum + ps.pendingInsThdNum) / ps.concurrentNum + ps.RUnlock() + return expectNum +} + +// Destroy will destroy scaler +func (ps *PredictScaler) Destroy() { + commonUtils.SafeCloseChannel(ps.stopCh) +} + +// HandlePredictUpdate Use prediction results +func (ps *PredictScaler) HandlePredictUpdate(prg *types.PredictQPSGroups) { + if len(prg.QPSGroups) != qpsGroupsNum || !strings.HasPrefix(ps.funcKeyWithRes, prg.FuncKey) { + ps.logger.Errorf("prg.QPSGroups is invalid") + return + } + var avgProcTime float64 + if !ps.metricsCollector.InvokeMetricsCollected() { + // if no metrics is ever collected + avgProcTime = DefaultReqDelay + ps.logger.Info("if no metrics is ever collected") + } else { + avgProcTime, _, _ = ps.metricsCollector.GetCalculatedInvokeMetrics() + ps.logger.Infof("parameters avgProcTime %f ms", avgProcTime) + } + predictInsNum0 := int(math.Ceil(prg.QPSGroups[0] * avgProcTime / + (float64(ps.concurrentNum) * Minute2Second * Millisecond2Second))) + predictInsNum1 := int(math.Ceil(prg.QPSGroups[1] * avgProcTime / + (float64(ps.concurrentNum) * Minute2Second * Millisecond2Second))) + expectNum := ps.GetExpectInstanceNumber() // managed by predictScaler + currentInsNum := expectNum + ps.minRsvInsNum + ps.logger.Infof("forecast results: predictInsNum0 %d predictInsNum1 %d expectNum %d"+ + " ps.minRsvInsNum %d Current number of instances", predictInsNum0, predictInsNum1, expectNum, ps.minRsvInsNum) + ps.logger.Infof("current number of instances %d", currentInsNum) + + ps.predictDownDiff = 0 + ps.predictUpDiff = 0 + ps.predictInsNum = int(math.Max(float64(predictInsNum0), float64(predictInsNum1))) + predictInsNum0 = int(math.Max(float64(predictInsNum0), float64(ps.minRsvInsNum))) + + if currentInsNum > predictInsNum0 { + ps.predictDownChanFlag = true + ps.predictDownDiff = currentInsNum - predictInsNum0 + ps.scaleDownChan <- struct{}{} + } + + // 为下个窗口期提前扩容,下个窗口期提前一个冷启时间开始扩容 + sleepTime := CalSleepTime(time.Now().UnixMilli(), ps.coldStartTime) + time.Sleep(sleepTime) + currentInsNum = ps.GetExpectInstanceNumber() + ps.minRsvInsNum + if currentInsNum < predictInsNum1 { + ps.predictUpDiff = predictInsNum1 - currentInsNum + ps.scaleUpChan <- struct{}{} + } +} + +// CalSleepTime - +func CalSleepTime(currentTimeStamp int64, coldStartTime time.Duration) time.Duration { + return time.Duration(config.GlobalConfig.PredictGroupWindow)*time.Millisecond - coldStartTime - + time.Duration(currentTimeStamp%config.GlobalConfig.PredictGroupWindow)*time.Millisecond +} + +func (ps *PredictScaler) getScaleUpInstancesNum() int { + pendingInsThdReqNum := float64(ps.checkReqNumFunc()) + // fire at will if no metrics is ever collected for this function + ps.Lock() + scaleInsThdNum := math.Max(pendingInsThdReqNum-math.Max(float64(ps.pendingInsThdNum), 0), 0) + scaleInsNum := int(math.Ceil(scaleInsThdNum / float64(ps.concurrentNum))) + ps.Unlock() + ps.logger.Infof("calculated scale up instance number is %d", scaleInsNum) + if scaleInsNum > 0 { + return scaleInsNum + } + return 0 +} + +func (ps *PredictScaler) getScaleDownInstancesNum() int { + if !ps.metricsCollector.InvokeMetricsCollected() { + // try to scale down instance even if no metrics is ever collected + return 1 + } + avgProcTime, insThdProcNumPS, insThdReqNumPS := ps.metricsCollector.GetCalculatedInvokeMetrics() + ps.logger.Infof("parameters for calculating scale down avgProcTime %f insThdProcNumPS %f "+ + "insThdReqNumPS %f totalInsThdNum %d pendingInsThdNum %d", avgProcTime, insThdProcNumPS, + insThdReqNumPS, ps.totalInsThdNum, ps.pendingInsThdNum) + if insThdProcNumPS == 0 { + ps.logger.Errorf("invalid value for insThdProcNumPS") + return -1 + } + ps.Lock() + defer ps.Unlock() + procNumExcess := insThdProcNumPS*float64(ps.totalInsThdNum-ps.inUseInsThdNum)*ps.scaleDownWindow.Seconds() - + insThdReqNumPS*ps.scaleDownWindow.Seconds() + math.Min(float64(ps.pendingInsThdNum), 0) + scaleInsThdNum := math.Ceil(procNumExcess / insThdProcNumPS / ps.scaleDownWindow.Seconds()) + scaleInsNum := int(math.Floor(scaleInsThdNum / float64(ps.concurrentNum))) + ps.logger.Infof("calculated scale down instance number is %d", scaleInsNum) + if scaleInsNum > 0 { + return scaleInsNum + } + return scaleInsNum +} + +func (ps *PredictScaler) handlePendingInsNumIncrease(insDiff int) { + ps.Lock() + ps.pendingInsThdNum += insDiff * ps.concurrentNum + ps.predictDownChanFlag = false + ps.Unlock() +} + +func (ps *PredictScaler) handlePendingInsNumDecrease(insDiff int) { + ps.Lock() + ps.pendingInsThdNum -= insDiff * ps.concurrentNum + ps.predictUpChanFlag = false + ps.Unlock() +} + +func (ps *PredictScaler) scaleUp() { + scaleUpChan := make(chan struct{}, 1) + tickerChan := make(<-chan time.Time, 1) + var ticker *time.Ticker + for { + select { + case _, ok := <-ps.scaleUpChan: + if !ok { + ps.logger.Warnf("trigger channel is closed") + return + } + ps.logger.Infof("ps.predictUpDiff %d ", ps.predictUpDiff) + ps.Lock() + ps.pendingInsThdNum += ps.predictUpDiff * ps.concurrentNum + ps.Unlock() + ps.scaleUpHandler(ps.predictUpDiff, ps.handlePendingInsNumDecrease) + case _, ok := <-scaleUpChan: + if !ok { + ps.logger.Warnf("scale up channel is closed") + return + } + insThdReq := ps.checkReqNumFunc() + if insThdReq == 0 { + ps.stopScale(ticker) + continue + } + ps.getScaleNum() + case _, ok := <-ps.scaleUpTriggerCh: + ps.logger.Info("receive scaleUpTriggerCh") + if !ok { + ps.logger.Warnf("trigger channel is closed") + return + } + // Functions with long cold start times do not use passive scaleDown + if ps.coldStartTime > longColdStartTime { + continue + } + if ps.predictUpChanFlag || ps.predictDownChanFlag { + ps.logger.Warnf("It is handling predictive scaling, ignoring current signals") + continue + } + // let requests come in for certain time to calculate a more reasonable scale up number + time.Sleep(ps.scaleUpWindow) + scaleUpChan <- struct{}{} + ticker = time.NewTicker(ps.scaleUpWindow) + tickerChan = ticker.C + ps.logger.Infof("scale up loop is running") + case <-tickerChan: + scaleUpChan <- struct{}{} + case <-ps.stopCh: + ps.logger.Warnf("stop scale up loop now") + return + } + } +} + +func (ps *PredictScaler) getScaleNum() { + scaleInsNum := ps.getScaleUpInstancesNum() + ps.logger.Infof("scaleUpTriggerCh scaleInsNum : %d", scaleInsNum) + if scaleInsNum > 0 { + ps.Lock() + ps.pendingInsThdNum += scaleInsNum * ps.concurrentNum + ps.Unlock() + ps.scaleUpHandler(scaleInsNum, ps.handlePendingInsNumDecrease) + } +} + +func (ps *PredictScaler) stopScale(ticker *time.Ticker) { + if ticker != nil { + ticker.Stop() + } + ps.Lock() + ps.predictScaleUpFlag = false + ps.Unlock() + ps.logger.Infof("scale up loop is paused") +} + +func (ps *PredictScaler) scaleDown() { + scaleDownChan := make(<-chan time.Time, 1) + var timer *time.Timer + for { + select { + case _, ok := <-ps.scaleDownChan: + if !ok { + ps.logger.Warnf("trigger channel is closed") + return + } + ps.logger.Infof("ps.predictDownDiff %d ", ps.predictDownDiff) + + ps.Lock() + ps.pendingInsThdNum -= ps.predictDownDiff * ps.concurrentNum + ps.Unlock() + ps.scaleDownHandler(ps.predictDownDiff, ps.handlePendingInsNumIncrease) + case <-scaleDownChan: + scaleInsNum := ps.getScaleDownInstancesNum() + ps.logger.Infof("scaleDownTriggerCh scaleInsNum: %d", scaleInsNum) + if scaleInsNum > 0 { + ps.Lock() + // Scale down based on the predicted number of instances as the bottom line + currentAvailableInsNum := ps.totalInsThdNum/ps.concurrentNum + ps.minRsvInsNum + scaleInsNum = int(math.Min(float64(scaleInsNum), + math.Max(float64(currentAvailableInsNum-ps.predictInsNum), 0))) + ps.logger.Infof("scaleDownTriggerCh new scaleInsNum: %d", scaleInsNum) + ps.pendingInsThdNum -= scaleInsNum * ps.concurrentNum + ps.Unlock() + ps.scaleDownHandler(scaleInsNum, ps.handlePendingInsNumIncrease) + } + if timer != nil { + timer.Stop() + } + ps.Lock() + ps.predictScaleDownFlag = false + ps.Unlock() + ps.logger.Infof("scale down loop is paused") + case _, ok := <-ps.scaleDownTriggerCh: + if !ok { + ps.logger.Warnf("trigger channel is closed") + return + } + // Functions with long cold start times do not use passive scaleUp + if ps.coldStartTime > longColdStartTime { + continue + } + if ps.predictUpChanFlag || ps.predictDownChanFlag { + ps.logger.Warnf("handling predictive scaling, ignoring current signals") + continue + } + ps.Lock() + ps.predictScaleDownFlag = true + ps.Unlock() + if timer == nil { + timer = time.NewTimer(ps.scaleDownWindow) + } else { + select { + case <-timer.C: + default: + } + timer.Reset(ps.scaleDownWindow) + } + scaleDownChan = timer.C + ps.logger.Infof("scale down loop is running", ps.funcKeyWithRes) + case <-ps.stopCh: + ps.logger.Warnf("stop scale down loop now", ps.funcKeyWithRes) + return + } + } +} diff --git a/yuanrong/pkg/functionscaler/scaler/replicascaler.go b/yuanrong/pkg/functionscaler/scaler/replicascaler.go new file mode 100644 index 0000000..c123860 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scaler/replicascaler.go @@ -0,0 +1,245 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package scaler - +package scaler + +import ( + "sync" + "time" + + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +// ReplicaScaler will scale based on a given instance number +type ReplicaScaler struct { + metricsCollector metrics.Collector + funcKeyWithRes string + concurrentNum int + pendingRsvInsNum int + targetRsvInsNum int + currentRsvInsNum int + createError error + enable bool + scaleLimit int + scaleUpHandler ScaleUpHandler + scaleDownHandler ScaleDownHandler + sync.RWMutex +} + +// NewReplicaScaler will create a ReplicaScaler +func NewReplicaScaler(funcKeyWithRes string, metricsCollector metrics.Collector, scaleUpHandler ScaleUpHandler, + scaleDownHandler ScaleDownHandler) InstanceScaler { + replicaScaler := &ReplicaScaler{ + metricsCollector: metricsCollector, + funcKeyWithRes: funcKeyWithRes, + targetRsvInsNum: 0, + enable: false, + scaleUpHandler: scaleUpHandler, + scaleDownHandler: scaleDownHandler, + } + log.GetLogger().Infof("create replica scaler for function %s, isManaged is %v", replicaScaler.funcKeyWithRes) + return replicaScaler +} + +// SetEnable will configure the enable of scaler +func (rs *ReplicaScaler) SetEnable(enable bool) { + rs.Lock() + if enable == rs.enable { + rs.Unlock() + return + } + rs.enable = enable + rs.Unlock() + if enable { + rs.handleScale() + } +} + +// TriggerScale will trigger scale +func (rs *ReplicaScaler) TriggerScale() { + rs.RLock() + if !rs.enable { + rs.RUnlock() + return + } + rs.RUnlock() + rs.handleScale() +} + +// CheckScaling will check if scaler is scaling +func (rs *ReplicaScaler) CheckScaling() bool { + rs.RLock() + isScaling := rs.currentRsvInsNum != rs.targetRsvInsNum + rs.RUnlock() + return isScaling +} + +// UpdateCreateMetrics will update create metrics +func (rs *ReplicaScaler) UpdateCreateMetrics(coldStartTime time.Duration) { +} + +// HandleInsThdUpdate will update instance thread metrics, totalInsThd increase should be coupled with pendingInsThd +// decrease for better consistency +func (rs *ReplicaScaler) HandleInsThdUpdate(inUseInsThdDiff, totalInsThdDiff int) { + rs.metricsCollector.UpdateInsThdMetrics(inUseInsThdDiff) + // replica scaler won't handle scale when only inUseInsThdDiff is non-zero + if totalInsThdDiff == 0 { + return + } + rs.Lock() + rs.currentRsvInsNum += totalInsThdDiff / rs.concurrentNum + rs.Unlock() + rs.handleScale() +} + +// HandleFuncSpecUpdate - +func (rs *ReplicaScaler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { + rs.Lock() + rs.concurrentNum = utils.GetConcurrentNum(funcSpec.InstanceMetaData.ConcurrentNum) + rs.Unlock() + log.GetLogger().Infof("config concurrentNum to %d for replica scaler %s", rs.concurrentNum, rs.funcKeyWithRes) + // some error may be cleared after function update + rs.handleScale() +} + +// HandleInsConfigUpdate - +func (rs *ReplicaScaler) HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) { + replicaNum := int(insConfig.InstanceMetaData.MinInstance) + if replicaNum < 0 { + replicaNum = 0 + } + rs.Lock() + prevTargetRsvInsNum := rs.targetRsvInsNum + rs.targetRsvInsNum = replicaNum + log.GetLogger().Infof("config reserved num from %d to %d for replica scaler %s", prevTargetRsvInsNum, + replicaNum, rs.funcKeyWithRes) + if !rs.enable { + rs.Unlock() + return + } + rs.Unlock() + rs.handleScale() +} + +// HandleCreateError handles instance create error +func (rs *ReplicaScaler) HandleCreateError(createError error) { + log.GetLogger().Infof("handle create error %s for function %s", createError, rs.funcKeyWithRes) + if utils.IsUnrecoverableError(createError) { + if utils.IsNoNeedToRePullError(createError) { + rs.Lock() + rs.enable = false + rs.createError = createError + rs.Unlock() + return + } + } + // if unrecoverable error is cleared then trigger scale, if createError stays nil then skip scale + rs.Lock() + noNeedScale := rs.createError == nil && createError == nil + rs.createError = createError + rs.Unlock() + if noNeedScale { + return + } + rs.handleScale() +} + +// GetExpectInstanceNumber get function minInstance in etcd +func (rs *ReplicaScaler) GetExpectInstanceNumber() int { + rs.RLock() + enable := rs.enable + currentReplicaNum := rs.targetRsvInsNum + rs.RUnlock() + if !enable { + return 0 + } + return currentReplicaNum +} + +// Destroy will destroy scaler +func (rs *ReplicaScaler) Destroy() { + rs.Lock() + rs.enable = false + rs.targetRsvInsNum = 0 + rs.Unlock() +} + +func (rs *ReplicaScaler) handlePendingInsNumIncrease(insDiff int) { + rs.Lock() + rs.pendingRsvInsNum += insDiff + rs.Unlock() +} + +func (rs *ReplicaScaler) handlePendingInsNumDecrease(insDiff int) { + rs.Lock() + rs.pendingRsvInsNum -= insDiff + rs.Unlock() +} + +func (rs *ReplicaScaler) handleScale() { + rs.RLock() + enable := rs.enable + scaleNum := rs.targetRsvInsNum - rs.currentRsvInsNum + rs.RUnlock() + log.GetLogger().Infof("parameters for handle scale of function %s targetRsvInsNum %d currentRsvInsNum %d "+ + "pendingRsvInsNum %d scaleNum %d", rs.funcKeyWithRes, rs.targetRsvInsNum, rs.currentRsvInsNum, + rs.pendingRsvInsNum, scaleNum) + if !enable || config.GlobalConfig.DisableReplicaScaler { + log.GetLogger().Warnf("replicaScaler of function %s disable, targetNum is %d, currentNum is %d, pendingNum "+ + "is %d, funcKey is %s", rs.funcKeyWithRes, rs.targetRsvInsNum, rs.currentRsvInsNum, rs.pendingRsvInsNum) + return + } + if scaleNum > 0 { + rs.Lock() + if rs.pendingRsvInsNum >= 0 { + scaleNum -= rs.pendingRsvInsNum + if scaleNum <= 0 { + rs.Unlock() + log.GetLogger().Warnf("scaleNum <= 0, no need to scale up, "+ + "targetNum is %d, currentNum is %d, pendingNum is %d, funcKey is %s", + rs.targetRsvInsNum, rs.currentRsvInsNum, rs.pendingRsvInsNum, rs.funcKeyWithRes) + return + } + } + rs.pendingRsvInsNum += scaleNum + rs.Unlock() + log.GetLogger().Infof("calculate scale up instance number for function %s is %d", rs.funcKeyWithRes, scaleNum) + rs.scaleUpHandler(scaleNum, rs.handlePendingInsNumDecrease) + } else if scaleNum < 0 { + rs.Lock() + if rs.pendingRsvInsNum <= 0 { + scaleNum -= rs.pendingRsvInsNum + if scaleNum >= 0 { + rs.Unlock() + log.GetLogger().Warnf("scaleNum >= 0, no need to scale down, "+ + "targetNum is %d, currentNum is %d, pendingNum is %d, funcKey is %s", + rs.targetRsvInsNum, rs.currentRsvInsNum, rs.pendingRsvInsNum, rs.funcKeyWithRes) + return + } + } + rs.pendingRsvInsNum += scaleNum + rs.Unlock() + log.GetLogger().Infof("calculate scale down instance number for function %s is %d", rs.funcKeyWithRes, + -scaleNum) + rs.scaleDownHandler(-scaleNum, rs.handlePendingInsNumIncrease) + } +} diff --git a/yuanrong/pkg/functionscaler/scaler/replicascaler_test.go b/yuanrong/pkg/functionscaler/scaler/replicascaler_test.go new file mode 100644 index 0000000..0596a35 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scaler/replicascaler_test.go @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package scaler - +package scaler + +import ( + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/instanceconfig" + commonTypes "yuanrong/pkg/common/faas_common/types" +) + +func TestReplicaScaler_CreateAndDestroy(t *testing.T) { + convey.Convey("create and destroy", t, func() { + s := NewReplicaScaler("testFunc", nil, nil, nil) + rs, ok := s.(*ReplicaScaler) + convey.So(ok, convey.ShouldEqual, true) + rs.Destroy() + convey.So(rs.enable, convey.ShouldEqual, false) + convey.So(rs.targetRsvInsNum, convey.ShouldEqual, 0) + }) +} + +func TestReplicaScaler_TriggerScale(t *testing.T) { + rs := &ReplicaScaler{} + convey.Convey("trigger scale", t, func() { + callCount := 0 + p := gomonkey.ApplyFunc((*ReplicaScaler).handleScale, func() { + callCount++ + }) + rs.TriggerScale() + convey.So(callCount, convey.ShouldEqual, 0) + rs.SetEnable(true) + rs.TriggerScale() + convey.So(callCount, convey.ShouldEqual, 2) + p.Reset() + }) +} + +func TestReplicaScaler_GetExpectInstanceNumber(t *testing.T) { + rs := &ReplicaScaler{concurrentNum: 100} + rs.scaleUpHandler = func(i int, callback ScaleUpCallback) {} + rs.SetEnable(true) + convey.Convey("increase", t, func() { + rs.HandleInsConfigUpdate(&instanceconfig.Configuration{ + InstanceMetaData: commonTypes.InstanceMetaData{ + MinInstance: 1, + }, + }) + convey.So(rs.GetExpectInstanceNumber(), convey.ShouldEqual, 1) + }) +} + +func TestReplicaScaler_pendingInsNumOperation(t *testing.T) { + rs := &ReplicaScaler{concurrentNum: 100} + convey.Convey("increase", t, func() { + rs.handlePendingInsNumIncrease(1) + convey.So(rs.pendingRsvInsNum, convey.ShouldEqual, 1) + }) + convey.Convey("decrease", t, func() { + rs.handlePendingInsNumDecrease(1) + convey.So(rs.pendingRsvInsNum, convey.ShouldEqual, 0) + }) +} diff --git a/yuanrong/pkg/functionscaler/scaler/wisecloudscaler.go b/yuanrong/pkg/functionscaler/scaler/wisecloudscaler.go new file mode 100644 index 0000000..1d191a4 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scaler/wisecloudscaler.go @@ -0,0 +1,227 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package scaler - +package scaler + +import ( + "fmt" + "sync" + "time" + + "go.uber.org/zap" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/faas_common/wisecloudtool" + wisecloudTypes "yuanrong/pkg/common/faas_common/wisecloudtool/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + errorExpireTime = 10 * time.Second // 失败后等10s重置unrecovererr +) + +var ( + localColdStarter *wisecloudtool.PodOperator + localColdStarterOnce sync.Once +) + +func newColdStarter(logger api.FormatLogger) *wisecloudtool.PodOperator { + localColdStarterOnce.Do(func() { + localColdStarter = wisecloudtool.NewColdStarter(&config.GlobalConfig.ServiceAccountJwt, logger) + }) + return localColdStarter +} + +// WiseCloudScaler will scales instance automatically based on calculation upon instance metrics +type WiseCloudScaler struct { + funcKeyWithRes string + nuwaRuntimeInfo *wisecloudTypes.NuwaRuntimeInfo + resSpec resspeckey.ResSpecKey + logger api.FormatLogger + + podOperator *wisecloudtool.PodOperator + + CreateCallback func(err error) + totalInsThdNum int + concurrentNum int + isReserve bool + enableScale bool + coldStartTrigger chan struct{} + stopCh chan struct{} + sync.RWMutex +} + +// NewWiseCloudScaler will create a WiseCloudScaler +func NewWiseCloudScaler(funcKeyWithRes string, resSpec resspeckey.ResSpecKey, + isReserve bool, createCallback func(err error)) InstanceScaler { + stopCh := make(chan struct{}) + wiseCloudScaler := &WiseCloudScaler{ + funcKeyWithRes: funcKeyWithRes, + resSpec: resSpec, + podOperator: newColdStarter(log.GetLogger().With(zap.Any("funcKey", funcKeyWithRes))), + logger: log.GetLogger().With(zap.Any("funcKey", funcKeyWithRes)), + CreateCallback: createCallback, + totalInsThdNum: 0, + concurrentNum: 1, + isReserve: isReserve, + enableScale: false, + coldStartTrigger: make(chan struct{}, 1), + stopCh: stopCh, + } + if isReserve { + go wiseCloudScaler.coldStartLoop() + } + return wiseCloudScaler +} + +// SetEnable will configure the enable of scaler +func (as *WiseCloudScaler) SetEnable(enable bool) { + if as.isReserve { + as.logger.Infof("set enable, from %v to %v", as.enableScale, enable) + as.enableScale = enable + } +} + +// DelNuwaPod will send a req to erase runtime pod +func (as *WiseCloudScaler) DelNuwaPod(ins *types.Instance) error { + if ins == nil { + return fmt.Errorf("ins is nil, skip") + } + var err error + for { + select { + case <-as.stopCh: + as.logger.Warnf("stop del nuwa pod %s loop, now", ins.PodID) + return err + default: + if as.nuwaRuntimeInfo == nil { + as.logger.Errorf("failed del nuwa pod %s, nuwaRuntimeInfo empty", ins.PodID) + time.Sleep(time.Second) + continue + } + err = as.podOperator.DelPod(as.nuwaRuntimeInfo, ins.PodDeploymentName, ins.PodID) + if err != nil { + as.logger.Errorf("failed del nuwa pod %s, %s", ins.PodID, err.Error()) + } + return err + } + } + +} + +// TriggerScale will trigger scale +func (as *WiseCloudScaler) TriggerScale() { + as.logger.Infof("trigger scale, enablescale: %v, totalInsThdNum: %v", as.enableScale, as.totalInsThdNum) + if as.enableScale && as.totalInsThdNum == 0 { + select { + case as.coldStartTrigger <- struct{}{}: + default: + as.logger.Debugf("scale has been trigger, skip") + } + } +} + +// CheckScaling will check if scaler is scaling +func (as *WiseCloudScaler) CheckScaling() bool { + return false +} + +// GetExpectInstanceNumber - number of pending and running instance +func (as *WiseCloudScaler) GetExpectInstanceNumber() int { + as.RLock() + defer as.RUnlock() + expectNum := as.totalInsThdNum / as.concurrentNum + return expectNum +} + +// UpdateCreateMetrics will update create metrics +func (as *WiseCloudScaler) UpdateCreateMetrics(coldStartTime time.Duration) { +} + +// HandleInsThdUpdate will update instance thread metrics, totalInsThd increase should be coupled with pendingInsThd +// decrease for better consistency +func (as *WiseCloudScaler) HandleInsThdUpdate(inUseInsThdDiff, totalInsThdDiff int) { + as.Lock() + defer as.Unlock() + as.totalInsThdNum += totalInsThdDiff +} + +// HandleFuncSpecUpdate - +func (as *WiseCloudScaler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { + as.Lock() + defer as.Unlock() + as.concurrentNum = utils.GetConcurrentNum(funcSpec.InstanceMetaData.ConcurrentNum) +} + +// HandleInsConfigUpdate - +func (as *WiseCloudScaler) HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) { + if insConfig == nil || insConfig.NuwaRuntimeInfo.WisecloudRuntimeId == "" { + return + } + as.nuwaRuntimeInfo = &insConfig.NuwaRuntimeInfo +} + +// HandleCreateError handles instance create error +func (as *WiseCloudScaler) HandleCreateError(createError error) { +} + +// Destroy will destroy scaler +func (as *WiseCloudScaler) Destroy() { + commonUtils.SafeCloseChannel(as.stopCh) +} + +func (as *WiseCloudScaler) coldStartLoop() { + for { + select { + case _, ok := <-as.coldStartTrigger: + if !ok { + as.logger.Warnf("trigger channel is closed") + return + } + if as.nuwaRuntimeInfo == nil { + as.logger.Warnf("nuwa runtime info is empty, skip") + continue + } + if as.totalInsThdNum != 0 { + continue + } + err := as.podOperator.ColdStart(as.funcKeyWithRes, as.resSpec, as.nuwaRuntimeInfo) + if err != nil { + as.logger.Errorf("cold start failed, err %s", err.Error()) + as.CreateCallback(snerror.New(statuscode.WiseCloudNuwaColdStartErrCode, "cold start failed, err: "+err.Error())) + go func() { + time.Sleep(errorExpireTime) + as.CreateCallback(nil) + }() + continue + } + as.CreateCallback(nil) + as.logger.Infof("cold start succeed") + case <-as.stopCh: + as.logger.Warnf("stop cold create loop for function %s now", as.funcKeyWithRes) + return + } + } +} diff --git a/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/basic_concurrency_scheduler.go b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/basic_concurrency_scheduler.go new file mode 100644 index 0000000..1e5d1f6 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/basic_concurrency_scheduler.go @@ -0,0 +1,1509 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package concurrencyscheduler - +package concurrencyscheduler + +import ( + "context" + "errors" + "fmt" + "os" + "sync" + "sync/atomic" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/queue" + "yuanrong/pkg/common/faas_common/resspeckey" + commonTypes "yuanrong/pkg/common/faas_common/types" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/lease" + "yuanrong/pkg/functionscaler/metrics" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/rollout" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/signalmanager" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +type popDirection bool + +const ( + forward popDirection = true + backward popDirection = false + + randomThreadIDLength = 8 +) + +var ( + // ErrInsThdNotExist is the error of instance thread not exist + ErrInsThdNotExist = errors.New("instance thread not exist") + // ErrNoInsThdAvail is the error of instance thread not exist + ErrNoInsThdAvail = errors.New("no instance thread available") +) + +type sessionRecord struct { + ctx context.Context + timer *time.Timer + availThdMap map[string]struct{} + allocThdMap map[string]struct{} + expiring atomic.Value + ttl time.Duration + concurrency int + sessionID string + expireCancelCh chan struct{} + expireCh chan struct{} + cancelFunc func() +} + +func (s *sessionRecord) PutThreadToAvailThdMap(threadID string) error { + if _, ok := s.allocThdMap[threadID]; !ok { + return fmt.Errorf("thread %s doesn't belong to session %s for function", threadID, s.sessionID) + } + s.availThdMap[threadID] = struct{}{} + return nil +} + +func (s *sessionRecord) PutThreadToAllocThdMap(threadID string) { + s.allocThdMap[threadID] = struct{}{} +} + +func (s *sessionRecord) GetThreadFromAvailThdMap() string { + var ( + threadID string + ) + for key := range s.availThdMap { + threadID = key + break + } + delete(s.availThdMap, threadID) + return threadID +} + +type instanceElement struct { + instance *types.Instance + threadIndex int + threadIDPrefix string + isNewInstance bool + threadMap map[string]struct{} + sessionMap map[string]*sessionRecord +} + +func (i *instanceElement) PutThreadToThreadMap(threadID string) { + // 如果put回来的租约在map中仍然存在,则不处理(这种情况理论上不存在) + if _, ok := i.threadMap[threadID]; ok { + return + } + if len(i.threadMap) >= i.instance.ConcurrentNum { + return + } + i.threadMap[fmt.Sprintf("%s-thread%s-%d", i.instance.InstanceID, i.threadIDPrefix, i.threadIndex)] = struct{}{} + i.threadIndex++ +} + +func (i *instanceElement) GetThreadFromThreadMap() string { + var ( + threadID string + ) + for key := range i.threadMap { + threadID = key + break + } + delete(i.threadMap, threadID) + return threadID +} + +func (i *instanceElement) initThreadMap() { + i.threadIndex = 1 + i.threadMap = make(map[string]struct{}, i.instance.ConcurrentNum) + i.threadIDPrefix = utils.GenRandomString(randomThreadIDLength) + for ; i.threadIndex <= i.instance.ConcurrentNum; i.threadIndex++ { + i.threadMap[fmt.Sprintf("%s-thread%s-%d", i.instance.InstanceID, i.threadIDPrefix, i.threadIndex)] = struct{}{} + } + return +} + +type instanceObserver struct { + callback func(interface{}) +} + +type instanceQueueWithSubHealthAndEvictingRecord struct { + instanceQueue queue.Queue + subHealthRecord map[string]*instanceElement + evictingRecord map[string]*instanceElement +} + +// Front - +func (i *instanceQueueWithSubHealthAndEvictingRecord) Front() interface{} { + return i.instanceQueue.Front() +} + +// Back - +func (i *instanceQueueWithSubHealthAndEvictingRecord) Back() interface{} { + return i.instanceQueue.Back() +} + +// PopFront - +func (i *instanceQueueWithSubHealthAndEvictingRecord) PopFront() interface{} { + return i.instanceQueue.PopFront() +} + +// PopBack - +func (i *instanceQueueWithSubHealthAndEvictingRecord) PopBack() interface{} { + return i.instanceQueue.PopBack() +} + +// PopSubHealth - +func (i *instanceQueueWithSubHealthAndEvictingRecord) PopSubHealth() interface{} { + insElem := getSubHealthInstanceFromRecord(i.subHealthRecord) + if !commonUtils.IsNil(insElem) { + delete(i.subHealthRecord, insElem.instance.InstanceID) + return insElem + } + return nil +} + +// PushBack - +func (i *instanceQueueWithSubHealthAndEvictingRecord) PushBack(obj interface{}) error { + insElem, ok := obj.(*instanceElement) + if !ok { + return scheduler.ErrTypeConvertFail + } + _, existSubHealth := i.subHealthRecord[insElem.instance.InstanceID] + _, existGShut := i.evictingRecord[insElem.instance.InstanceID] + existHealth := i.instanceQueue.GetByID(insElem.instance.InstanceID) != nil + if existSubHealth || existHealth || existGShut { + return scheduler.ErrInsAlreadyExist + } + switch insElem.instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusRunning): + log.GetLogger().Infof("put into instanceQueue, ins: %+v", insElem) + return i.instanceQueue.PushBack(insElem) + case int32(constant.KernelInstanceStatusSubHealth): + i.subHealthRecord[insElem.instance.InstanceID] = insElem + case int32(constant.KernelInstanceStatusEvicting): + i.evictingRecord[insElem.instance.InstanceID] = insElem + default: + log.GetLogger().Warnf("ignore instance %s with unexpected status code %d", insElem.instance.InstanceID, + insElem.instance.InstanceStatus.Code) + return scheduler.ErrInternal + } + return nil +} + +// GetByID - +func (i *instanceQueueWithSubHealthAndEvictingRecord) GetByID(objID string) interface{} { + if insElem, exist := i.subHealthRecord[objID]; exist { + return insElem + } + if insElem, exist := i.evictingRecord[objID]; exist { + return insElem + } + return i.instanceQueue.GetByID(objID) +} + +// DelByID - +func (i *instanceQueueWithSubHealthAndEvictingRecord) DelByID(objID string) error { + _, existSubHealth := i.subHealthRecord[objID] + existHealth := i.instanceQueue.GetByID(objID) != nil + _, existGShut := i.evictingRecord[objID] + if !existSubHealth && !existHealth && !existGShut { + return scheduler.ErrInsNotExist + } + + delete(i.evictingRecord, objID) + delete(i.subHealthRecord, objID) + i.instanceQueue.DelByID(objID) + return nil +} + +// UpdateObjByID - +func (i *instanceQueueWithSubHealthAndEvictingRecord) UpdateObjByID(objID string, obj interface{}) error { + insElem, ok := obj.(*instanceElement) + if !ok { + return scheduler.ErrTypeConvertFail + } + _, existSubHealth := i.subHealthRecord[objID] + existHealth := i.instanceQueue.GetByID(objID) != nil + _, existGShut := i.evictingRecord[objID] + if !existSubHealth && !existHealth && !existGShut { + return scheduler.ErrInsNotExist + } + + i.instanceQueue.DelByID(objID) + delete(i.subHealthRecord, objID) + delete(i.evictingRecord, objID) + switch insElem.instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusRunning): + err := i.instanceQueue.PushBack(insElem) + if err != nil { + return err + } + case int32(constant.KernelInstanceStatusSubHealth): + i.subHealthRecord[objID] = insElem + case int32(constant.KernelInstanceStatusEvicting): + i.evictingRecord[objID] = insElem + default: + log.GetLogger().Warnf("ignore instance %s with unexpected status code %d", insElem.instance.InstanceID, + insElem.instance.InstanceStatus.Code) + return scheduler.ErrInternal + } + return nil +} + +// Len - +func (i *instanceQueueWithSubHealthAndEvictingRecord) Len() int { + return i.instanceQueue.Len() + len(i.subHealthRecord) // 判断长度不需要考虑gShutRrecord中实例数 +} + +// Range - +func (i *instanceQueueWithSubHealthAndEvictingRecord) Range(f func(obj interface{}) bool) { + i.instanceQueue.Range(f) + for _, insElem := range i.subHealthRecord { + if !f(insElem) { + break + } + } + + for _, insElem := range i.evictingRecord { + if !f(insElem) { + break + } + } +} + +// SortedRange - +func (i *instanceQueueWithSubHealthAndEvictingRecord) SortedRange(f func(obj interface{}) bool) { + i.instanceQueue.SortedRange(f) + for _, insElem := range i.subHealthRecord { + if !f(insElem) { + break + } + } + + for _, insElem := range i.evictingRecord { + if !f(insElem) { + break + } + } +} + +type instanceQueueWithObserver struct { + instanceQueueWithSubHealthAndEvictingRecord + /* + * insAvailThdCount记录队列中实例的可用租约数,这里作用是当外部修改的queue中实例中租约信息,通过该count可以计算出可用实例数的变化, + * 以确保指标上报时可以提供准确的数值。 + * 注意,该count不考虑实例的状态,仅供指标上报,不做他用。 + */ + insAvailThdCount map[string]int + pubAvailTopicFunc func(int) + pubInUseTopicFunc func(int) + pubTotalTopicFunc func(int) +} + +// PopFront - +func (i *instanceQueueWithObserver) PopFront() interface{} { + obj := i.instanceQueueWithSubHealthAndEvictingRecord.PopFront() // 仅pop health实例 + if obj == nil { + return nil + } + insElem, ok := obj.(*instanceElement) + if !ok { + return nil + } + delete(i.insAvailThdCount, insElem.instance.InstanceID) + i.pubAvailTopicFunc(-len(insElem.threadMap)) + i.pubInUseTopicFunc(-(insElem.instance.ConcurrentNum - len(insElem.threadMap))) + i.pubTotalTopicFunc(-insElem.instance.ConcurrentNum) + return insElem +} + +// PopBack - +func (i *instanceQueueWithObserver) PopBack() interface{} { + obj := i.instanceQueueWithSubHealthAndEvictingRecord.PopBack() // 仅pop health实例 + if obj == nil { + return nil + } + insElem, ok := obj.(*instanceElement) + if !ok { + return nil + } + delete(i.insAvailThdCount, insElem.instance.InstanceID) + i.pubAvailTopicFunc(-len(insElem.threadMap)) + i.pubInUseTopicFunc(-(insElem.instance.ConcurrentNum - len(insElem.threadMap))) + i.pubTotalTopicFunc(-insElem.instance.ConcurrentNum) + return insElem +} + +// PopSubHealth - +func (i *instanceQueueWithObserver) PopSubHealth() interface{} { + obj := i.instanceQueueWithSubHealthAndEvictingRecord.PopSubHealth() + if obj == nil { + return nil + } + insElem, ok := obj.(*instanceElement) + if !ok { + return nil + } + // sub-health instance doesn't have availInsThd + delete(i.insAvailThdCount, insElem.instance.InstanceID) + i.pubInUseTopicFunc(-(insElem.instance.ConcurrentNum - len(insElem.threadMap))) + i.pubTotalTopicFunc(-insElem.instance.ConcurrentNum) + return insElem +} + +// PushBack - pushback仅考虑queue中无实例场景 +func (i *instanceQueueWithObserver) PushBack(obj interface{}) error { + insElem, ok := obj.(*instanceElement) + if !ok { + return scheduler.ErrTypeConvertFail + } + if err := i.instanceQueueWithSubHealthAndEvictingRecord.PushBack(obj); err != nil { + return err + } + + i.insAvailThdCount[insElem.instance.InstanceID] = len(insElem.threadMap) + switch insElem.instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusRunning): + i.pubAvailTopicFunc(len(insElem.threadMap)) + i.pubTotalTopicFunc(insElem.instance.ConcurrentNum) + case int32(constant.KernelInstanceStatusSubHealth): + i.pubTotalTopicFunc(insElem.instance.ConcurrentNum) + case int32(constant.KernelInstanceStatusEvicting): + default: + + } + return nil +} + +// DelByID - +func (i *instanceQueueWithObserver) DelByID(objID string) error { + if _, ok := i.evictingRecord[objID]; ok { + delete(i.evictingRecord, objID) + return nil + } + obj := i.instanceQueueWithSubHealthAndEvictingRecord.GetByID(objID) + if obj == nil { + return scheduler.ErrInsNotExist + } + insElem, ok := obj.(*instanceElement) + if !ok { + return scheduler.ErrTypeConvertFail + } + if err := i.instanceQueueWithSubHealthAndEvictingRecord.DelByID(objID); err != nil { + return err + } + delete(i.insAvailThdCount, insElem.instance.InstanceID) + switch insElem.instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusRunning): + i.pubAvailTopicFunc(-len(insElem.threadMap)) + i.pubInUseTopicFunc(-(insElem.instance.ConcurrentNum - len(insElem.threadMap))) + i.pubTotalTopicFunc(-insElem.instance.ConcurrentNum) + case int32(constant.KernelInstanceStatusSubHealth): + i.pubInUseTopicFunc(-(insElem.instance.ConcurrentNum - len(insElem.threadMap))) + i.pubTotalTopicFunc(-insElem.instance.ConcurrentNum) + + // 忽略优雅退出实例的指标上报 + case int32(constant.KernelInstanceStatusEvicting): + default: + } + return nil +} + +// UpdateObjByID - updateObjByID考虑queue中有实例场景 +func (i *instanceQueueWithObserver) UpdateObjByID(objID string, obj interface{}) error { + insElem, ok := obj.(*instanceElement) + if !ok { + return scheduler.ErrTypeConvertFail + } + + oldInstanceStatus := int32(constant.KernelInstanceStatusRunning) + _, ok = i.subHealthRecord[objID] + if ok { + oldInstanceStatus = int32(constant.KernelInstanceStatusSubHealth) + } + _, ok = i.evictingRecord[objID] + if ok { + oldInstanceStatus = int32(constant.KernelInstanceStatusEvicting) + } + if err := i.instanceQueueWithSubHealthAndEvictingRecord.UpdateObjByID(objID, insElem); err != nil { + return err + } + oldInsAvailThdCount, exist := i.insAvailThdCount[objID] + if !exist { + return scheduler.ErrInternal + } + i.insAvailThdCount[objID] = len(insElem.threadMap) + + switch oldInstanceStatus { + case int32(constant.KernelInstanceStatusRunning): + i.handleHealthInstanceUpdateMetrics(oldInsAvailThdCount, insElem) + case int32(constant.KernelInstanceStatusSubHealth): + i.handleSubHealthInstanceUpdateMetrics(oldInsAvailThdCount, insElem) + + // 处于优雅退出状态的实例,不会转换成health或者subhealth实例 + // 处于优雅退出状态的实例的状态变化,不用考虑指标变化和上报 + case int32(constant.KernelInstanceStatusEvicting): + default: + + } + return nil +} + +func (i *instanceQueueWithObserver) handleHealthInstanceUpdateMetrics(oldInsAvailThdCount int, + insElem *instanceElement) { + availInsThdDiff := len(insElem.threadMap) - oldInsAvailThdCount + switch insElem.instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusRunning): // health实例的申请租约 + i.pubInUseTopicFunc(-availInsThdDiff) + i.pubAvailTopicFunc(availInsThdDiff) + case int32(constant.KernelInstanceStatusSubHealth): // health实例转换成了subhealth实例 + // i.pubInUseTopicFunc(-availInsThdDiff) // 这个diff应该是0 + i.pubAvailTopicFunc(-len(insElem.threadMap)) // 这里暗示的是 new和old的可用租约数要一致 + case int32(constant.KernelInstanceStatusEvicting): // health实例转换成了evicting实例 + i.pubTotalTopicFunc(-insElem.instance.ConcurrentNum) + i.pubAvailTopicFunc(-len(insElem.threadMap)) + i.pubInUseTopicFunc(-(insElem.instance.ConcurrentNum - len(insElem.threadMap))) + default: + + } +} + +func (i *instanceQueueWithObserver) handleSubHealthInstanceUpdateMetrics(oldInsAvailThdCount int, + insElem *instanceElement) { + availInsThdDiff := len(insElem.threadMap) - oldInsAvailThdCount + switch insElem.instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusRunning): // subhealth实例转换成health实例了 + i.pubInUseTopicFunc(-availInsThdDiff) + i.pubAvailTopicFunc(len(insElem.threadMap)) + case int32(constant.KernelInstanceStatusSubHealth): // subhealth实例重复收到subhealth事件,不用更新指标 + case int32(constant.KernelInstanceStatusEvicting): // subhealth实例转换成了evicting实例 + i.pubTotalTopicFunc(-insElem.instance.ConcurrentNum) + i.pubInUseTopicFunc(-(insElem.instance.ConcurrentNum - len(insElem.threadMap))) + default: + + } +} + +type basicConcurrencyScheduler struct { + funcSpec *types.FunctionSpecification + insAcqReqQueue *requestqueue.InsAcqReqQueue + leaseManager lease.InstanceLeaseManager + selfInstanceQueue queue.Queue + otherInstanceQueue queue.Queue + selfSubHealthRecord map[string]*instanceElement + otherSubHealthRecord map[string]*instanceElement + sessionRecord map[string]*instanceElement + observers map[scheduler.InstanceTopic][]*instanceObserver + funcKeyWithRes string + concurrentNum int + isFuncOwner bool + stopped bool + leaseInterval time.Duration + *sync.RWMutex + *sync.Cond + grayAllocator GrayInstanceAllocator +} + +func newBasicConcurrencyScheduler(funcSpec *types.FunctionSpecification, resKey resspeckey.ResSpecKey, + selfInstanceQueue queue.Queue, otherInstanceQueue queue.Queue) basicConcurrencyScheduler { + leaseInterval := time.Duration(config.GlobalConfig.LeaseSpan) * time.Millisecond + if leaseInterval < types.MinLeaseInterval { + leaseInterval = types.MinLeaseInterval + } + mutex := &sync.RWMutex{} + funcKeyWitRes := fmt.Sprintf("%s-%s", funcSpec.FuncKey, resKey.String()) + bcs := basicConcurrencyScheduler{ + funcSpec: funcSpec, + funcKeyWithRes: funcKeyWitRes, + leaseManager: lease.NewGenericLeaseManager(funcKeyWitRes), + selfSubHealthRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + otherSubHealthRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + sessionRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + observers: make(map[scheduler.InstanceTopic][]*instanceObserver, utils.DefaultMapSize), + concurrentNum: utils.GetConcurrentNum(funcSpec.InstanceMetaData.ConcurrentNum), + leaseInterval: leaseInterval, + RWMutex: mutex, + Cond: sync.NewCond(mutex), + grayAllocator: NewHashBasedInstanceAllocator(0), + isFuncOwner: selfregister.GlobalSchedulerProxy.CheckFuncOwner(funcSpec.FuncKey), + } + bcs.grayAllocator.UpdateRolloutRatio(rollout.GetGlobalRolloutHandler().GetCurrentRatio()) + bcs.createOtherInstanceQueue(otherInstanceQueue) + bcs.createSelfInstanceQueue(selfInstanceQueue) + return bcs +} + +func (bcs *basicConcurrencyScheduler) createOtherInstanceQueue(instanceQueue queue.Queue) { + bcs.otherInstanceQueue = &instanceQueueWithSubHealthAndEvictingRecord{ + instanceQueue: instanceQueue, + subHealthRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + evictingRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + } +} + +func (bcs *basicConcurrencyScheduler) createSelfInstanceQueue(instanceQueue queue.Queue) { + InsQueWithSubHealth := instanceQueueWithSubHealthAndEvictingRecord{ + instanceQueue: instanceQueue, + subHealthRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + evictingRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + } + bcs.selfInstanceQueue = &instanceQueueWithObserver{ + instanceQueueWithSubHealthAndEvictingRecord: InsQueWithSubHealth, + insAvailThdCount: make(map[string]int, utils.DefaultMapSize), + pubAvailTopicFunc: func(data int) { bcs.publishInsThdEvent(scheduler.AvailInsThdTopic, data) }, + pubInUseTopicFunc: func(data int) { bcs.publishInsThdEvent(scheduler.InUseInsThdTopic, data) }, + pubTotalTopicFunc: func(data int) { bcs.publishInsThdEvent(scheduler.TotalInsThdTopic, data) }, + } +} + +func (bcs *basicConcurrencyScheduler) scheduleRequest(insAcqReq *types.InstanceAcquireRequest) ( + *types.InstanceAllocation, error) { + bcs.Lock() + defer bcs.Unlock() + useSelfInstance := bcs.isFuncOwner || insAcqReq.TrafficLimited + var ( + insAlloc *types.InstanceAllocation + err error + ) + // once a session is bound with an instance, all pending requests of this session will be marked as designate + // instance requests to avoid other instanceScheduler scheduling them (others don't have a bound record) + if len(insAcqReq.InstanceSession.SessionID) != 0 { + insElem, exist := bcs.sessionRecord[insAcqReq.InstanceSession.SessionID] + if exist { + insAcqReq.DesignateInstanceID = insElem.instance.InstanceID + } + } + if useSelfInstance { + insAlloc, err = bcs.acquireInstanceInternal(bcs.selfInstanceQueue, insAcqReq) + } else { + insAlloc, err = bcs.acquireInstanceInternal(bcs.otherInstanceQueue, insAcqReq) + } + return insAlloc, err +} + +// GetInstanceNumber gets instance number inside instance queue +func (bcs *basicConcurrencyScheduler) GetInstanceNumber(onlySelf bool) int { + bcs.RLock() + insNum := bcs.selfInstanceQueue.Len() + if !onlySelf { + insNum += bcs.otherInstanceQueue.Len() + } + bcs.RUnlock() + return insNum +} + +// AcquireInstance acquires an instance +func (bcs *basicConcurrencyScheduler) AcquireInstance(insAcqReq *types.InstanceAcquireRequest) ( + *types.InstanceAllocation, error) { + bcs.Lock() + defer bcs.Unlock() + // use self instance when: 1. this scheduler is the funcOwner 2. this scheduler is not the funcOwner and funcOwner + // encounters traffic limitation so acquire request sent to this scheduler + // use other instance when: this scheduler is not the funcOwner and funcOwner breaks down so acquire request sent + // to this scheduler + useSelfInstance := bcs.isFuncOwner || insAcqReq.TrafficLimited + var ( + insAlloc *types.InstanceAllocation + err error + ) + if useSelfInstance { + insAlloc, err = bcs.acquireInstanceInternal(bcs.selfInstanceQueue, insAcqReq) + } else { + insAlloc, err = bcs.acquireInstanceInternal(bcs.otherInstanceQueue, insAcqReq) + } + return insAlloc, err +} + +func (bcs *basicConcurrencyScheduler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { + bcs.handleFuncSpecUpdate(bcs.selfInstanceQueue, funcSpec) + bcs.handleFuncSpecUpdate(bcs.otherInstanceQueue, funcSpec) +} + +func (bcs *basicConcurrencyScheduler) handleFuncSpecUpdate(instanceQueue queue.Queue, + funcSpec *types.FunctionSpecification) { + needUpdate := make(map[string]*instanceElement) + instanceQueue.Range(func(obj interface{}) bool { + insElem, ok := obj.(*instanceElement) + if !ok { + return true + } + if insElem.instance.FuncSig != funcSpec.FuncMetaSignature && insElem.isNewInstance { + insElem.isNewInstance = false + needUpdate[insElem.instance.InstanceID] = insElem + } + if insElem.instance.FuncSig == funcSpec.FuncMetaSignature && !insElem.isNewInstance { + insElem.isNewInstance = true + needUpdate[insElem.instance.InstanceID] = insElem + } + return true + }) + for id, element := range needUpdate { + if err := instanceQueue.UpdateObjByID(id, element); err != nil { + log.GetLogger().Errorf("failed to update instance %s error %s", id, err.Error()) + } + } +} + +func (bcs *basicConcurrencyScheduler) acquireInstanceInternal(instanceQueue queue.Queue, + insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, error) { + var ( + insAlloc *types.InstanceAllocation + acqErr error + ) + if insAcqReq.DesignateInstanceID != "" { + insAlloc, acqErr = bcs.acquireDesignateInstance(instanceQueue, insAcqReq) + if acqErr == scheduler.ErrInsNotExist && len(insAcqReq.InstanceSession.SessionID) != 0 { + insAcqReq.DesignateInstanceID = "" + insAlloc, acqErr = bcs.acquireSessionInstance(instanceQueue, insAcqReq) + } + } else if len(insAcqReq.InstanceSession.SessionID) != 0 { + insAlloc, acqErr = bcs.acquireSessionInstance(instanceQueue, insAcqReq) + } else { + insAlloc, acqErr = bcs.acquireDefaultInstance(instanceQueue, insAcqReq) + } + if acqErr != nil { + return nil, acqErr + } + newLease, leaseErr := bcs.leaseManager.CreateInstanceLease(insAlloc, bcs.leaseInterval, func() { + if err := bcs.ReleaseInstance(insAlloc); err != nil { + log.GetLogger().Errorf("failed to release lease %s of instance %s for function %s error %s", + insAlloc.AllocationID, insAlloc.Instance.InstanceID, bcs.funcKeyWithRes, err.Error()) + } + }) + if leaseErr != nil { + log.GetLogger().Errorf("failed to create lease of instance %s for function %s error %s", + insAlloc.Instance.InstanceID, bcs.funcKeyWithRes, leaseErr.Error()) + if err := bcs.releaseInstanceInternal(instanceQueue, insAlloc); err != nil { + log.GetLogger().Errorf("failed to release instance %s for function %s error %s", + insAlloc.Instance.InstanceID, bcs.funcKeyWithRes, err.Error()) + } + return nil, leaseErr + } + insAlloc.Lease = newLease + return insAlloc, nil +} + +func (bcs *basicConcurrencyScheduler) acquireDefaultInstance(instanceQueue queue.Queue, + insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, error) { + log.GetLogger().Infof("acquire default instance for function %s traceID %s", bcs.funcKeyWithRes, + insAcqReq.TraceID) + obj := instanceQueue.Front() + if obj == nil { + return nil, scheduler.ErrNoInsAvailable + } + insElem, ok := obj.(*instanceElement) + if !ok { + return nil, scheduler.ErrTypeConvertFail + } + return acquireInstanceThread(insAcqReq.DesignateThreadID, instanceQueue, insElem) +} + +func (bcs *basicConcurrencyScheduler) acquireSessionInstance(instanceQueue queue.Queue, + insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, error) { + log.GetLogger().Infof("acquire session instance for function %s session %+v traceID %s", + bcs.funcKeyWithRes, insAcqReq.InstanceSession, insAcqReq.TraceID) + if insAcqReq.InstanceSession.Concurrency > bcs.concurrentNum { + return nil, scheduler.ErrInvalidSession + } + insElem, exist := fetchInsElemForSessionAcquire(instanceQueue, insAcqReq, bcs) + if !exist { + var ( + found bool + ok bool + ) + instanceQueue.SortedRange(func(obj interface{}) bool { + insElem, ok = obj.(*instanceElement) + if !ok { + return true + } + if insElem.instance.InstanceStatus.Code != int32(constant.KernelInstanceStatusRunning) { + return true + } + if len(insElem.threadMap) >= insAcqReq.InstanceSession.Concurrency { + found = true + return false + } + return true + }) + if !found { + return nil, scheduler.ErrNoInsAvailable + } + if err := bcs.bindSessionWithInstance(instanceQueue, insElem, insAcqReq.InstanceSession); err != nil { + return nil, err + } + bcs.sessionRecord[insAcqReq.InstanceSession.SessionID] = insElem + return bcs.acquireInstanceThreadWithSession(insElem, insAcqReq.InstanceSession) + } + insAlloc, acqErr := bcs.acquireInstanceThreadWithSession(insElem, insAcqReq.InstanceSession) + // if acqErr equals ErrNoInsThdAvail, try getting thread without session from the same instance + if acqErr != ErrNoInsThdAvail { + return insAlloc, acqErr + } + return acquireInstanceThread(insAcqReq.DesignateThreadID, instanceQueue, insElem) +} + +func fetchInsElemForSessionAcquire(instanceQueue queue.Queue, insAcqReq *types.InstanceAcquireRequest, + bcs *basicConcurrencyScheduler) (*instanceElement, bool) { + insElem, exist := bcs.sessionRecord[insAcqReq.InstanceSession.SessionID] + if exist { + // 缓存中有session但是instance已经被删除时,更新缓存 + obj := instanceQueue.GetByID(insElem.instance.InstanceID) + if obj == nil { + delete(bcs.sessionRecord, insAcqReq.InstanceSession.SessionID) + exist = false + } else { + insElemNew, ok := obj.(*instanceElement) + if !ok { + delete(bcs.sessionRecord, insAcqReq.InstanceSession.SessionID) + exist = false + } else { + insElem = insElemNew + } + } + } + return insElem, exist +} + +func (bcs *basicConcurrencyScheduler) acquireDesignateInstance(instanceQueue queue.Queue, + insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, error) { + log.GetLogger().Infof("acquire designate instance %s for function %s session %+v traceID %s", + insAcqReq.DesignateInstanceID, bcs.funcKeyWithRes, insAcqReq.InstanceSession, insAcqReq.TraceID) + var ( + insAlloc *types.InstanceAllocation + acqErr error + ) + obj := instanceQueue.GetByID(insAcqReq.DesignateInstanceID) + if obj == nil { + if len(insAcqReq.InstanceSession.SessionID) != 0 { + delete(bcs.sessionRecord, insAcqReq.InstanceSession.SessionID) + } + return nil, scheduler.ErrInsNotExist + } + insElem, ok := obj.(*instanceElement) + if !ok { + return nil, scheduler.ErrTypeConvertFail + } + if insElem.instance.InstanceStatus.Code == int32(constant.KernelInstanceStatusSubHealth) { + return nil, scheduler.ErrInsSubHealthy + } + if len(insAcqReq.InstanceSession.SessionID) != 0 { + insAlloc, acqErr = bcs.acquireInstanceThreadWithSession(insElem, insAcqReq.InstanceSession) + // if acqErr equals ErrNoInsThdAvail, try getting thread without session from the same instance + if acqErr != ErrNoInsThdAvail { + return insAlloc, acqErr + } + } + + // 对于evicting实例,仅供绑定会话的请求使用,否则认为改实例不可用 + if insElem.instance.InstanceStatus.Code != int32(constant.KernelInstanceStatusRunning) { + return nil, scheduler.ErrDesignateInsNotAvailable + } + return acquireInstanceThread(insAcqReq.DesignateThreadID, instanceQueue, insElem) +} + +func acquireInstanceThread(designateThreadID string, insQue queue.Queue, + insElem *instanceElement) (*types.InstanceAllocation, error) { + // 这里如果指定了租约id,是需要可以超发的 + if len(insElem.threadMap) == 0 && designateThreadID == "" { + return nil, scheduler.ErrNoInsAvailable + } + // 无论有没有指定租约id,都需要从map中取一个租约出来,如果map为空,也不会报错 + threadID := insElem.GetThreadFromThreadMap() + if designateThreadID != "" { + threadID = designateThreadID + } + err := insQue.UpdateObjByID(insElem.instance.InstanceID, insElem) + if err != nil { + log.GetLogger().Errorf("failed to update instance %s in queue error %s", insElem.instance.InstanceID, + err.Error()) + return nil, err + } + insAlloc := &types.InstanceAllocation{ + Instance: insElem.instance, + AllocationID: threadID, + } + metrics.OnAcquireLease(insAlloc) + return insAlloc, nil +} + +func (bcs *basicConcurrencyScheduler) bindSessionWithInstance(insQue queue.Queue, insElem *instanceElement, + session commonTypes.InstanceSessionConfig) error { + if len(insElem.threadMap) < session.Concurrency { + return scheduler.ErrNoInsAvailable + } + ctx, cancelFunc := context.WithCancel(context.TODO()) + record := &sessionRecord{ + ctx: ctx, + sessionID: session.SessionID, + ttl: time.Duration(session.SessionTTL) * time.Second, + availThdMap: make(map[string]struct{}, utils.DefaultMapSize), + allocThdMap: make(map[string]struct{}, utils.DefaultMapSize), + concurrency: session.Concurrency, + expireCancelCh: make(chan struct{}, 1), + cancelFunc: cancelFunc, + } + insElem.sessionMap[session.SessionID] = record + for i := 0; i < session.Concurrency; i++ { + threadID := insElem.GetThreadFromThreadMap() + record.PutThreadToAllocThdMap(threadID) + // there must be no error. + if err := record.PutThreadToAvailThdMap(threadID); err != nil { + log.GetLogger().Errorf("acquire thread failed, skip") + } + } + err := insQue.UpdateObjByID(insElem.instance.InstanceID, insElem) + if err != nil { + log.GetLogger().Errorf("failed to update instance %s during session binding of function %s error %s", + insElem.instance.InstanceID, bcs.funcKeyWithRes, err.Error()) + return err + } + for threadId, _ := range record.allocThdMap { + insAlloc := &types.InstanceAllocation{ + Instance: insElem.instance, + AllocationID: threadId, + } + metrics.OnAcquireLease(insAlloc) + } + log.GetLogger().Infof("bind session %s with instance %s for function %s", record.sessionID, + insElem.instance.InstanceID, bcs.funcKeyWithRes) + return nil +} + +func (bcs *basicConcurrencyScheduler) acquireInstanceThreadWithSession(insElem *instanceElement, + sessionConfig commonTypes.InstanceSessionConfig) ( + *types.InstanceAllocation, error) { + record, exist := insElem.sessionMap[sessionConfig.SessionID] + if !exist { + log.GetLogger().Errorf("session %s is not bound with instance %s for function %s", sessionConfig.SessionID, + insElem.instance.InstanceID, bcs.funcKeyWithRes) + return nil, scheduler.ErrInternal + } + if len(record.availThdMap) == 0 { + return nil, ErrNoInsThdAvail + } + expiring, _ := record.expiring.Load().(bool) + if expiring { + select { + case record.expireCancelCh <- struct{}{}: + default: + } + } + record.ttl = time.Duration(sessionConfig.SessionTTL) * time.Second + // every object here is pointer, no need to call UpdateObjByID + threadID := record.GetThreadFromAvailThdMap() + return &types.InstanceAllocation{ + Instance: insElem.instance.Copy(), + SessionInfo: types.SessionInfo{ + SessionID: sessionConfig.SessionID, + SessionCtx: record.ctx, + }, + AllocationID: threadID, + }, nil +} + +// ReleaseInstance releases an instance +func (bcs *basicConcurrencyScheduler) ReleaseInstance(insAlloc *types.InstanceAllocation) error { + bcs.Lock() + defer bcs.Unlock() + useSelfInstance := bcs.checkSelfInstance(insAlloc.Instance) + var ( + err error + ) + if useSelfInstance { + err = bcs.releaseInstanceInternal(bcs.selfInstanceQueue, insAlloc) + } else { + err = bcs.releaseInstanceInternal(bcs.otherInstanceQueue, insAlloc) + } + return err +} + +func (bcs *basicConcurrencyScheduler) releaseInstanceInternal(instanceQueue queue.Queue, + insAlloc *types.InstanceAllocation) error { + instance := insAlloc.Instance + obj := instanceQueue.GetByID(instance.InstanceID) + if obj == nil { + return scheduler.ErrInsNotExist + } + var ok bool + insElem, ok := obj.(*instanceElement) + if !ok { + return scheduler.ErrTypeConvertFail + } + releaseInSession := false + if len(insAlloc.SessionInfo.SessionID) != 0 { + // not all insAlloc with session comes from session record, handle release of this type of insAlloc below like + // normal insAlloc + err := bcs.releaseInstanceThreadWithSession(instanceQueue, insElem, insAlloc) + if err == nil { + releaseInSession = true + } else if err != ErrInsThdNotExist { + return err + } + } + if !releaseInSession { + insElem.PutThreadToThreadMap(insAlloc.AllocationID) + } + err := instanceQueue.UpdateObjByID(instance.InstanceID, insElem) + if err != nil { + log.GetLogger().Errorf("failed to update instance %s during allocation release for function %s error %s", + insAlloc.Instance.InstanceID, bcs.funcKeyWithRes, err.Error()) + return err + } + if !releaseInSession { + metrics.OnReleaseLease(insAlloc) + } + return nil +} + +func (bcs *basicConcurrencyScheduler) releaseInstanceThreadWithSession(insQue queue.Queue, insElem *instanceElement, + insAlloc *types.InstanceAllocation) error { + log.GetLogger().Infof("start to unbind session %s with thread %s for function %s", insAlloc.SessionInfo.SessionID, + insAlloc.AllocationID, bcs.funcKeyWithRes) + record, exist := insElem.sessionMap[insAlloc.SessionInfo.SessionID] + if !exist { + log.GetLogger().Errorf("session %s is not bound with instance %s for function %s", + insAlloc.SessionInfo.SessionID, insElem.instance.InstanceID, bcs.funcKeyWithRes) + return scheduler.ErrInternal + } + err := record.PutThreadToAvailThdMap(insAlloc.AllocationID) + if err != nil { + log.GetLogger().Warnf("put thread to availthdmap failed, err %s, func %s", err.Error(), bcs.funcKeyWithRes) + return ErrInsThdNotExist + } + expiring, _ := record.expiring.Load().(bool) + if len(record.availThdMap) == len(record.allocThdMap) && !expiring { + record.expiring.Store(true) + record.timer = time.NewTimer(record.ttl) + go bcs.unbindInstanceSession(insQue, insElem, record) + } + return nil +} + +func (bcs *basicConcurrencyScheduler) unbindInstanceSession(insQue queue.Queue, insElem *instanceElement, + record *sessionRecord) { + select { + case <-record.timer.C: + bcs.L.Lock() + if len(record.availThdMap) != len(record.allocThdMap) { + <-record.expireCancelCh + log.GetLogger().Infof("avail thd has been acquired, session %s expire canceled", record.sessionID) + record.timer.Stop() + record.expiring.Store(false) + bcs.L.Unlock() + return + } + record.cancelFunc() + for threadID, _ := range record.allocThdMap { + insElem.PutThreadToThreadMap(threadID) + insAlloc := &types.InstanceAllocation{ + Instance: insElem.instance, + AllocationID: threadID, + } + metrics.OnReleaseLease(insAlloc) + } + delete(insElem.sessionMap, record.sessionID) + delete(bcs.sessionRecord, record.sessionID) + if err := insQue.UpdateObjByID(insElem.instance.InstanceID, insElem); err != nil { + log.GetLogger().Errorf("failed to update instance %s during unbinding with session %s for function %s"+ + " error %s", insElem.instance.InstanceID, record.sessionID, bcs.funcKeyWithRes, err.Error()) + } + bcs.L.Unlock() + log.GetLogger().Infof("unbind session %s with instance %s for function %s", record.sessionID, + insElem.instance.InstanceID, bcs.funcKeyWithRes) + case <-record.expireCancelCh: + // set lock here may cause deadlock because multiple acquire requests of a same session may trigger this + // case many times + log.GetLogger().Infof("session %s expire canceled", record.sessionID) + record.timer.Stop() + record.expiring.Store(false) + } +} + +// popInstanceElement pops an instance for scale down, use condition lock to wait for creating instances which already +// be processing by kernel to be enqueued +func (bcs *basicConcurrencyScheduler) popInstanceElement(popDirection popDirection, + shouldPop func(*instanceElement) bool, wait bool) *instanceElement { + bcs.L.Lock() + defer bcs.L.Unlock() + if wait && bcs.selfInstanceQueue.Len() == 0 { + bcs.Wait() + } + instanceQueue, ok := bcs.selfInstanceQueue.(*instanceQueueWithObserver) + if !ok { + return nil + } + var obj interface{} + if obj = instanceQueue.PopSubHealth(); obj != nil { + insElem, ok := obj.(*instanceElement) + if !ok { + return nil + } + return insElem + } + if popDirection == forward { + obj = bcs.selfInstanceQueue.Front() + } else { + obj = bcs.selfInstanceQueue.Back() + } + if obj == nil { + return nil + } + insElem, ok := obj.(*instanceElement) + if !ok { + return nil + } + if shouldPop != nil && !shouldPop(insElem) { + return nil + } + if popDirection == forward { + bcs.selfInstanceQueue.PopFront() + } else { + bcs.selfInstanceQueue.PopBack() + } + if bcs.grayAllocator.ShouldReassign(Del, insElem.instance.InstanceID) { + log.GetLogger().Infof("pop instance gray invoke reassign. instance: %s, funcKey: %s", + insElem.instance.InstanceID, bcs.funcKeyWithRes) + bcs.reassignInstanceWhenGray() + } + return insElem +} + +// AddInstance adds an instance to instanceScheduler +func (bcs *basicConcurrencyScheduler) AddInstance(instance *types.Instance) error { + bcs.Lock() + defer bcs.Unlock() + isSelfInstance := bcs.checkSelfInstance(instance) + var ( + err error + ) + insElem := &instanceElement{ + instance: instance, + sessionMap: make(map[string]*sessionRecord, utils.DefaultMapSize), + } + insElem.initThreadMap() + if isSelfInstance { + err = bcs.selfInstanceQueue.PushBack(insElem) + } else { + err = bcs.otherInstanceQueue.PushBack(insElem) + } + if err != nil { + return err + } + if bcs.grayAllocator.ShouldReassign(Add, insElem.instance.InstanceID) { + log.GetLogger().Infof("add instance gray invoke reassign. instance: %s, funcKey: %s", + instance.InstanceID, bcs.funcKeyWithRes) + bcs.reassignInstanceWhenGray() + } + return err +} + +// DelInstance deletes an instance from instanceScheduler +func (bcs *basicConcurrencyScheduler) DelInstance(instance *types.Instance) error { + bcs.Lock() + defer bcs.Unlock() + isSelfInstance := bcs.checkSelfInstance(instance) + var ( + err error + ) + if isSelfInstance { + err = bcs.selfInstanceQueue.DelByID(instance.InstanceID) + } else { + err = bcs.otherInstanceQueue.DelByID(instance.InstanceID) + } + bcs.leaseManager.HandleInstanceDelete(instance) + if err != nil { + return err + } + if bcs.grayAllocator.ShouldReassign(Del, instance.InstanceID) { + log.GetLogger().Infof("del instance gray invoke reassign. instance: %s, funcKey: %s", + instance.InstanceID, bcs.funcKeyWithRes) + bcs.reassignInstanceWhenGray() + } + return err +} + +// SignalAllInstances sends signal to all instances +func (bcs *basicConcurrencyScheduler) SignalAllInstances(signalFunc scheduler.SignalInstanceFunc) { + bcs.RLock() + bcs.selfInstanceQueue.Range(func(obj interface{}) bool { + insElem, ok := obj.(*instanceElement) + if !ok { + return true + } + signalFunc(insElem.instance) + return true + }) + bcs.RUnlock() +} + +// HandleInstanceUpdate handles instance update comes from ETCD, it's worth noting that this method will also handle +// instance recover from scheduler state +func (bcs *basicConcurrencyScheduler) HandleInstanceUpdate(instance *types.Instance) { + logger := log.GetLogger().With(zap.Any("funcKey", bcs.funcKeyWithRes), zap.Any("instance", instance.InstanceID), + zap.Any("instanceStatus", instance.InstanceStatus.Code)) + isSelfInstance := bcs.checkSelfInstance(instance) + logger.Infof("handle instance update isSelfInstance %t", isSelfInstance) + bcs.Lock() + defer bcs.Unlock() + var instanceQueue queue.Queue + if isSelfInstance { + instanceQueue = bcs.selfInstanceQueue + } else { + instanceQueue = bcs.otherInstanceQueue + } + isNewInstance := true + if instance.FuncSig != bcs.funcSpec.FuncMetaSignature { + isNewInstance = false + } + obj := instanceQueue.GetByID(instance.InstanceID) + if obj == nil { + signalmanager.GetSignalManager().SignalInstance(instance, constant.KillSignalAliasUpdate) + signalmanager.GetSignalManager().SignalInstance(instance, constant.KillSignalFaaSSchedulerUpdate) + insElem := &instanceElement{ + instance: instance, + isNewInstance: isNewInstance, + sessionMap: make(map[string]*sessionRecord, utils.DefaultMapSize), + } + insElem.initThreadMap() + if err := instanceQueue.PushBack(insElem); err != nil { + logger.Errorf("failed to add new instance with status %+v", instance.InstanceStatus) + return + } + if instance.InstanceStatus.Code != int32(constant.KernelInstanceStatusEvicting) && + bcs.grayAllocator.ShouldReassign(Add, instance.InstanceID) { + logger.Infof("update add instance invoke reassign") + bcs.reassignInstanceWhenGray() + } + } else { + insElem, ok := obj.(*instanceElement) + if !ok { + logger.Errorf("can't convert object to insQueElement type") + return + } + insElem.instance = instance + insElem.isNewInstance = isNewInstance + if err := instanceQueue.UpdateObjByID(instance.InstanceID, insElem); err != nil { + logger.Errorf("failed to update instance %s with status %+v", instance.InstanceID, instance.InstanceStatus) + return + } + } + logger.Infof("handle instance update success") +} + +// IsFuncOwner - +func (bcs *basicConcurrencyScheduler) IsFuncOwner() bool { + bcs.RLock() + isFuncOwner := bcs.isFuncOwner + bcs.RUnlock() + return isFuncOwner +} + +func (bcs *basicConcurrencyScheduler) checkSelfInstance(instance *types.Instance) bool { + isSelf := bcs.checkSelfInstanceInternal(instance) + if !isSelf { + return false + } + + if !config.GlobalConfig.EnableRollout || !selfregister.IsRollingOut { + return isSelf + } else { + return bcs.grayAllocator.CheckSelf(selfregister.IsRolloutObject, instance.InstanceID) + } +} + +func (bcs *basicConcurrencyScheduler) checkSelfInstanceInternal(instance *types.Instance) bool { + if instance.Permanent { + if instance.CreateSchedulerID == selfregister.GetSchedulerProxyName() { + return true + } + funcOwnerExist := selfregister.GlobalSchedulerProxy.Contains(instance.CreateSchedulerID) + return !funcOwnerExist && bcs.isFuncOwner + } + return bcs.isFuncOwner +} + +func (bcs *basicConcurrencyScheduler) selectInstanceQueue(isSelfInstance bool) (queue.Queue, + map[string]*instanceElement) { + if isSelfInstance { + return bcs.selfInstanceQueue, bcs.selfSubHealthRecord + } + return bcs.otherInstanceQueue, bcs.otherSubHealthRecord +} + +// Destroy destroys instanceScheduler +func (bcs *basicConcurrencyScheduler) Destroy() { + bcs.Lock() + bcs.stopped = true + bcs.Unlock() + bcs.leaseManager.CleanAllLeases() +} + +// publishInsThdEvent will notify observers of specific topic of instance +func (bcs *basicConcurrencyScheduler) publishInsThdEvent(topic scheduler.InstanceTopic, data interface{}) { + if bcs.stopped { + return + } + for _, observer := range bcs.observers[topic] { + observer.callback(data) + } +} + +// addObservers will add observer of instance scaledInsQue +func (bcs *basicConcurrencyScheduler) addObservers(topic scheduler.InstanceTopic, callback func(interface{})) { + topicObservers, exist := bcs.observers[topic] + if !exist { + topicObservers = make([]*instanceObserver, 0, utils.DefaultSliceSize) + bcs.observers[topic] = topicObservers + } + bcs.observers[topic] = append(topicObservers, &instanceObserver{ + callback: callback, + }) +} + +// ReassignInstanceWhenGray 监听到进入灰度状态后重新分配self队列 +func (bcs *basicConcurrencyScheduler) ReassignInstanceWhenGray(ratio int) { + bcs.Lock() + defer bcs.Unlock() + bcs.grayAllocator.UpdateRolloutRatio(ratio) + log.GetLogger().Infof("updateRolloutRatio invoke reassign self len: %d, other len: %d, funcKey %s", + bcs.selfInstanceQueue.Len(), bcs.otherInstanceQueue.Len(), bcs.funcKeyWithRes) + bcs.reassignInstanceWhenGray() +} + +func (bcs *basicConcurrencyScheduler) reassignInstanceWhenGray() { + if !config.GlobalConfig.EnableRollout { + return + } + // 灰度结束的最后一步,将currentVersion修改为当前版本,ratio修改为0时,需要修改grayAllocator中的IsRolloutObject,不能直接返回 + if !selfregister.IsRollingOut && bcs.grayAllocator.GetRolloutRatio() > 0 { + log.GetLogger().Infof("no need to reassign Instance when rollout %v, ratio %d , funcKey %s", + selfregister.IsRollingOut, bcs.grayAllocator.GetRolloutRatio(), bcs.funcKeyWithRes) + return + } + + if !bcs.isFuncOwner { + log.GetLogger().Warnf("this scheduler is not funcOwner of function %s skipping reassign", bcs.funcKeyWithRes) + return + } + + var ( + selfInstancesToKeep []*instanceElement + otherInstancesToKeep []*instanceElement + fixedOtherInstances []*instanceElement + ) + + canPartitionInstances, fixedOtherInstances := bcs.collectInstancesForReassign() + + selfInstancesToKeep, otherInstancesToKeep = bcs.grayAllocator.Partition( + canPartitionInstances, + selfregister.IsRolloutObject, + ) + + otherInstancesToKeep = append(otherInstancesToKeep, fixedOtherInstances...) + + log.GetLogger().Infof("Gray reassign isGrayNode: %t - Before self: %d, other %d. "+ + "After self: %d, other, %d, fix: %d, funcKey %s", selfregister.IsRolloutObject, bcs.selfInstanceQueue.Len(), + bcs.otherInstanceQueue.Len(), len(selfInstancesToKeep), len(otherInstancesToKeep), len(fixedOtherInstances), + bcs.funcKeyWithRes) + + clearAndFillQueue(bcs.selfInstanceQueue, selfInstancesToKeep) + clearAndFillQueue(bcs.otherInstanceQueue, otherInstancesToKeep) + + log.GetLogger().Infof("finish reassign. funckey %s", bcs.funcKeyWithRes) +} + +func (bcs *basicConcurrencyScheduler) collectInstancesForReassign() ([]*HashedInstance, []*instanceElement) { + canPartitionInstances := make([]*HashedInstance, 0, bcs.selfInstanceQueue.Len()+bcs.otherInstanceQueue.Len()) + fixedOtherInstances := make([]*instanceElement, 0, bcs.otherInstanceQueue.Len()) + + partitionInstance := func(insElem *instanceElement) { + if !bcs.checkSelfInstanceInternal(insElem.instance) { + fixedOtherInstances = append(fixedOtherInstances, insElem) + return + } + hashValue := bcs.grayAllocator.ComputeHash(insElem.instance.InstanceID) + canPartitionInstances = append(canPartitionInstances, &HashedInstance{ + InsElem: insElem, + hash: hashValue, + }) + } + + processQueue := func(queue queue.Queue) { + queue.Range(func(obj interface{}) bool { + insElem, ok := obj.(*instanceElement) + if !ok { + return true + } + if insElem.instance.InstanceStatus.Code == int32(constant.KernelInstanceStatusEvicting) { + return true + } + partitionInstance(insElem) + return true + }) + } + + processQueue(bcs.selfInstanceQueue) + processQueue(bcs.otherInstanceQueue) + + return canPartitionInstances, fixedOtherInstances +} + +func clearAndFillQueue(instanceQueue queue.Queue, targetInstances []*instanceElement) { + currentMap := make(map[string]*instanceElement, instanceQueue.Len()) + instanceQueue.Range(func(obj interface{}) bool { + insElem, ok := obj.(*instanceElement) + if !ok { + return true + } + if insElem.instance.InstanceStatus.Code == int32(constant.KernelInstanceStatusEvicting) { // evicting实例不考虑在内 + return true + } + currentMap[insElem.instance.InstanceID] = insElem + return true + }) + + targetMap := make(map[string]*instanceElement, len(targetInstances)) + for _, insElem := range targetInstances { + targetMap[insElem.instance.InstanceID] = insElem + } + + for instanceID := range currentMap { + if _, exists := targetMap[instanceID]; !exists { + if err := instanceQueue.DelByID(instanceID); err != nil { + log.GetLogger().Errorf("failed to delete instance %s", instanceID) + } + } + } + + for _, insElem := range targetInstances { + if _, exists := currentMap[insElem.instance.InstanceID]; !exists { + if err := instanceQueue.PushBack(insElem); err != nil { + log.GetLogger().Errorf("failed to push instance %s", insElem.instance.InstanceID) + } + } + } +} + +// HandleFuncOwnerUpdate will reset funcOwner and reassign instances if necessary +func (bcs *basicConcurrencyScheduler) HandleFuncOwnerUpdate(isFuncOwner bool) { + logger := log.GetLogger().With(zap.Any("funcKey", bcs.funcKeyWithRes), zap.Any("isFuncOwner", isFuncOwner)) + logger.Infof("handle funcOwner update") + bcs.Lock() + defer bcs.Unlock() + var ( + becomeOwner bool + resignOwner bool + srcQueue queue.Queue + dstQueue queue.Queue + ) + isOwnerBefore := bcs.isFuncOwner + bcs.isFuncOwner = isFuncOwner + if !isOwnerBefore && isFuncOwner { + becomeOwner = true + srcQueue = bcs.otherInstanceQueue + dstQueue = bcs.selfInstanceQueue + } else if isOwnerBefore && !isFuncOwner { + resignOwner = true + srcQueue = bcs.selfInstanceQueue + dstQueue = bcs.otherInstanceQueue + } else { + logger.Warnf("funcOwner of function in this scheduler %s remains %t, no need to reassign instance", + selfregister.SelfInstanceID, bcs.isFuncOwner) + return + } + + reassignList := bcs.reassignQueues(srcQueue, dstQueue, becomeOwner, resignOwner, logger) + logger.Infof("funcOwner of function in this scheduler %s changes, succeed to reassign instances %+v", + selfregister.SelfInstanceID, reassignList) +} + +func (bcs *basicConcurrencyScheduler) reassignQueues(srcQueue queue.Queue, dstQueue queue.Queue, + becomeOwner bool, resignOwner bool, logger api.FormatLogger) []string { + reassignList := make([]string, 0, utils.DefaultSliceSize) + srcQueue.Range(func(obj interface{}) bool { + insElem, ok := obj.(*instanceElement) + if !ok { + return true + } + // evicting实例self和other都放 + if insElem.instance.InstanceStatus.Code == int32(constant.KernelInstanceStatusEvicting) { + dstQueue.PushBack(insElem) + return true + } + + // isFuncOwner is set before calling this method, checkSelfInstanceInternal will work under new ownership + if (becomeOwner && !bcs.checkSelfInstanceInternal(insElem.instance)) || + (resignOwner && bcs.checkSelfInstanceInternal(insElem.instance)) { + return true + } + reassignList = append(reassignList, insElem.instance.InstanceID) + if err := dstQueue.PushBack(insElem); err != nil { + logger.Errorf("failed to push instance in instance queue error %s", err.Error()) + } + return true + }) + for _, instanceID := range reassignList { + if err := srcQueue.DelByID(instanceID); err != nil { + logger.Errorf("failed to delete instance in instance queue error %s", err.Error()) + } + } + return reassignList +} + +func (bcs *basicConcurrencyScheduler) shouldTriggerColdStart() bool { + // 灰度状态下,新的scheduler不应该触发冷启动,应该快速返回失败 + selfCurVer := os.Getenv(selfregister.CurrentVersionEnvKey) + etcdCurVer := rollout.GetGlobalRolloutHandler().CurrentVersion + if selfCurVer != etcdCurVer && rollout.GetGlobalRolloutHandler().GetCurrentRatio() != 100 { // 100 mean 100% + return false + } + // 灰度状态到100%时,老的scheduler不应该负责冷启动,应该快速返回失败 + if selfCurVer == etcdCurVer && rollout.GetGlobalRolloutHandler().GetCurrentRatio() == 100 { // 100 mean 100% + return false + } + return true +} + +func getInstanceID(obj interface{}) string { + insElem, ok := obj.(*instanceElement) + if ok && insElem.instance != nil { + return insElem.instance.InstanceID + } + return "" +} + +func getSubHealthInstanceFromRecord(subHealthRecord map[string]*instanceElement) *instanceElement { + var ins *instanceElement + for _, v := range subHealthRecord { + if ins == nil { + ins = v + } + if len(ins.threadMap) >= len(v.threadMap) { + ins = v + } + } + return ins +} diff --git a/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/basic_concurrency_scheduler_test.go b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/basic_concurrency_scheduler_test.go new file mode 100644 index 0000000..cff1111 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/basic_concurrency_scheduler_test.go @@ -0,0 +1,1375 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package concurrencyscheduler - +package concurrencyscheduler + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/queue" + "yuanrong/pkg/common/faas_common/resspeckey" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/lease" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" +) + +type fakeInstanceScaler struct { + timer *time.Timer + scaling bool + scaleUpFunc func() + targetRsvInsNum int +} + +func (f *fakeInstanceScaler) SetFuncOwner(isManaged bool) { +} + +func (f *fakeInstanceScaler) SetEnable(enable bool) { +} + +func (f *fakeInstanceScaler) TriggerScale() { + go func() { + time.Sleep(10 * time.Millisecond) + if f.scaleUpFunc != nil { + f.scaleUpFunc() + } + }() +} + +func (f *fakeInstanceScaler) CheckScaling() bool { + if f.timer == nil { + return false + } + select { + case <-f.timer.C: + f.scaling = false + return false + default: + return f.scaling + } +} + +func (f *fakeInstanceScaler) UpdateCreateMetrics(coldStartTime time.Duration) { +} + +func (f *fakeInstanceScaler) HandleInsThdUpdate(inUseInsThdDiff, totalInsThdDiff int) { +} + +func (f *fakeInstanceScaler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { +} + +func (f *fakeInstanceScaler) HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) { +} + +func (f *fakeInstanceScaler) HandleCreateError(createError error) { +} + +func (f *fakeInstanceScaler) GetExpectInstanceNumber() int { + return f.targetRsvInsNum +} + +func (f *fakeInstanceScaler) Destroy() { +} + +func TestMain(m *testing.M) { + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc((*etcd3.EtcdWatcher).StartList, func(_ *etcd3.EtcdWatcher) {}), + gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + gomonkey.ApplyFunc(etcd3.GetMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + gomonkey.ApplyFunc(etcd3.GetCAEMetaEtcdClient, func() *etcd3.EtcdClient { return &etcd3.EtcdClient{} }), + gomonkey.ApplyFunc((*registry.FaasSchedulerRegistry).WaitForETCDList, func() {}), + gomonkey.ApplyFunc((*etcd3.EtcdClient).AttachAZPrefix, func(_ *etcd3.EtcdClient, key string) string { return key }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + config.GlobalConfig = types.Configuration{} + config.GlobalConfig.AutoScaleConfig = types.AutoScaleConfig{ + SLAQuota: 1000, + ScaleDownTime: 1000, + BurstScaleNum: 100000, + } + config.GlobalConfig.LeaseSpan = 500 + registry.InitRegistry(make(chan struct{})) + m.Run() +} + +func TestNewBasicConcurrencyScheduler(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 1}, + }, resspeckey.ResSpecKey{}, nil, nil) + assert.NotNil(t, bcs) +} + +func TestGetInstanceNumber(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 1}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + err := bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + assert.Nil(t, err) + getNum := bcs.GetInstanceNumber(true) + assert.Equal(t, 0, getNum) + bcs.isFuncOwner = true + err = bcs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + assert.Nil(t, err) + getNum = bcs.GetInstanceNumber(true) + assert.Equal(t, 1, getNum) +} + +func TestAcquireInstanceBasic(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + checkInUseInsThd := 0 + checkAvailInsThd := 0 + bcs.addObservers(scheduler.InUseInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkInUseInsThd += delta + }) + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + acqIns1, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance2"}) + assert.Equal(t, scheduler.ErrInsNotExist, err) + assert.Nil(t, acqIns1) + acqIns2, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns2.Instance.InstanceID) + assert.Equal(t, 1, checkInUseInsThd) + assert.Equal(t, 1, checkAvailInsThd) + acqIns3, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns3.Instance.InstanceID) + assert.Equal(t, 2, checkInUseInsThd) + assert.Equal(t, 0, checkAvailInsThd) + acqIns4, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, scheduler.ErrNoInsAvailable, err) + assert.Nil(t, acqIns4) + defer gomonkey.ApplyFunc((*lease.GenericInstanceLeaseManager).CreateInstanceLease, + func(_ *lease.GenericInstanceLeaseManager, + insAlloc *types.InstanceAllocation, interval time.Duration, callback func()) (types.InstanceLease, error) { + return nil, errors.New("some error") + }).Reset() + bcs.ReleaseInstance(acqIns3) + _, err = bcs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.NotNil(t, err) +} + +func TestAcquireInstanceOtherQueue(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = false + checkInUseInsThd := 0 + checkAvailInsThd := 0 + bcs.addObservers(scheduler.InUseInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkInUseInsThd += delta + }) + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + acqIns1, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance2"}) + assert.Equal(t, scheduler.ErrInsNotExist, err) + assert.Nil(t, acqIns1) + acqIns2, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns2.Instance.InstanceID) + assert.Equal(t, 0, checkInUseInsThd) + assert.Equal(t, 0, checkAvailInsThd) + acqIns3, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns3.Instance.InstanceID) + assert.Equal(t, 0, checkInUseInsThd) + assert.Equal(t, 0, checkAvailInsThd) + acqIns4, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, scheduler.ErrNoInsAvailable, err) + assert.Nil(t, acqIns4) +} + +func TestAcquireInstanceWithSession(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 4}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + checkInUseInsThd := 0 + checkAvailInsThd := 0 + bcs.addObservers(scheduler.InUseInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkInUseInsThd += delta + }) + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 4, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 4, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + acqIns1, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{ + InstanceSession: commonTypes.InstanceSessionConfig{ + SessionID: "session1", + Concurrency: 2, + }, + }) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns1.Instance.InstanceID) + assert.Equal(t, 2, checkInUseInsThd) + assert.Equal(t, 6, checkAvailInsThd) + acqIns2, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{ + InstanceSession: commonTypes.InstanceSessionConfig{ + SessionID: "session1", + Concurrency: 2, + }, + }) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns2.Instance.InstanceID) + assert.Equal(t, 2, checkInUseInsThd) + assert.Equal(t, 6, checkAvailInsThd) + acqIns3, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{ + InstanceSession: commonTypes.InstanceSessionConfig{ + SessionID: "session1", + Concurrency: 2, + }, + }) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns3.Instance.InstanceID) + assert.Equal(t, 3, checkInUseInsThd) + assert.Equal(t, 5, checkAvailInsThd) +} + +func TestReleaseInstance(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + checkInUseInsThd := 0 + checkAvailInsThd := 0 + bcs.addObservers(scheduler.InUseInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkInUseInsThd += delta + }) + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + err := bcs.ReleaseInstance(&types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceID: "instance3", + }, + }) + assert.Equal(t, scheduler.ErrInsNotExist, err) + acqIns1, _ := bcs.AcquireInstance(&types.InstanceAcquireRequest{}) + err = bcs.ReleaseInstance(acqIns1) + assert.Nil(t, err) + assert.Equal(t, 0, checkInUseInsThd) + assert.Equal(t, 2, checkAvailInsThd) + err = bcs.ReleaseInstance(&types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceID: "instance2", + }, + }) + assert.Nil(t, err) + assert.Equal(t, 2, checkAvailInsThd) +} + +func TestReleaseInstanceWithSession(t *testing.T) { + mockTimer := time.NewTimer(100 * time.Millisecond) + defer gomonkey.ApplyFunc(time.NewTimer, func(d time.Duration) *time.Timer { + mockTimer.Reset(100 * time.Millisecond) + return mockTimer + }).Reset() + defer gomonkey.ApplyFunc((*lease.GenericInstanceLeaseManager).CreateInstanceLease, + func(_ *lease.GenericInstanceLeaseManager, + insAlloc *types.InstanceAllocation, interval time.Duration, callback func()) (types.InstanceLease, error) { + return nil, nil + }).Reset() + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + checkInUseInsThd := 0 + checkAvailInsThd := 0 + bcs.addObservers(scheduler.InUseInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkInUseInsThd += delta + }) + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 4, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + acqIns1, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{ + InstanceSession: commonTypes.InstanceSessionConfig{ + SessionID: "session1", + SessionTTL: 1, + Concurrency: 2, + }, + }) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns1.Instance.InstanceID) + err = bcs.ReleaseInstance(acqIns1) + assert.Nil(t, err) + assert.Equal(t, 2, checkInUseInsThd) + assert.Equal(t, 2, checkAvailInsThd) + time.Sleep(50 * time.Millisecond) + acqIns2, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{ + InstanceSession: commonTypes.InstanceSessionConfig{ + SessionID: "session1", + SessionTTL: 1, + Concurrency: 2, + }, + }) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns2.Instance.InstanceID) + time.Sleep(50 * time.Millisecond) + err = bcs.ReleaseInstance(acqIns2) + assert.Nil(t, err) + assert.Equal(t, 2, checkInUseInsThd) + assert.Equal(t, 2, checkAvailInsThd) + time.Sleep(150 * time.Millisecond) + assert.Equal(t, 0, checkInUseInsThd) + assert.Equal(t, 4, checkAvailInsThd) +} + +func TestAddInstance(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + checkAvailInsThd := 0 + checkTotalInsThd := 0 + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.addObservers(scheduler.TotalInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkTotalInsThd += delta + }) + err := bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Equal(t, scheduler.ErrInternal, err) + assert.Equal(t, 0, checkAvailInsThd) + assert.Equal(t, 0, checkTotalInsThd) + err = bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + assert.Nil(t, err) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 2, checkTotalInsThd) + err = bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Equal(t, scheduler.ErrInsAlreadyExist, err) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 2, checkTotalInsThd) + err = bcs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + assert.Nil(t, err) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) + err = bcs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Equal(t, scheduler.ErrInsAlreadyExist, err) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) + + // evicting实例能添加进去,但是指标不上报 + err = bcs.AddInstance(&types.Instance{ + InstanceID: "instance3", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusEvicting)}, + }) + assert.Nil(t, err) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) +} + +func TestDelInstance(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + checkAvailInsThd := 0 + checkInUsedInsThd := 0 + checkTotalInsThd := 0 + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.addObservers(scheduler.InUseInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkInUsedInsThd += delta + }) + bcs.addObservers(scheduler.TotalInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkTotalInsThd += delta + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance_evicting", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusEvicting)}, + }) + err := bcs.DelInstance(&types.Instance{ + InstanceID: "instance3", + }) + assert.Equal(t, scheduler.ErrInsNotExist, err) + err = bcs.DelInstance(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Nil(t, err) + assert.Equal(t, 0, checkAvailInsThd) + assert.Equal(t, 0, checkInUsedInsThd) + assert.Equal(t, 2, checkTotalInsThd) + err = bcs.DelInstance(&types.Instance{ + InstanceID: "instance2", + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Nil(t, err) + assert.Equal(t, 0, checkAvailInsThd) + assert.Equal(t, 0, checkInUsedInsThd) + assert.Equal(t, 0, checkTotalInsThd) + + // evicting实例能正常删除,并且不影响指标 + err = bcs.DelInstance(&types.Instance{ + InstanceID: "instance_evicting", + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Nil(t, err) + assert.Equal(t, 0, checkAvailInsThd) + assert.Equal(t, 0, checkInUsedInsThd) + assert.Equal(t, 0, checkTotalInsThd) + +} + +func TestPopInstanceElement(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + checkAvailInsThd := 0 + checkInUsedInsThd := 0 + checkTotalInsThd := 0 + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.addObservers(scheduler.InUseInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkInUsedInsThd += delta + }) + bcs.addObservers(scheduler.TotalInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkTotalInsThd += delta + }) + popIns1 := bcs.popInstanceElement(forward, nil, false) + assert.Nil(t, popIns1) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance_evicting", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusEvicting)}, + }) + popIns2 := bcs.popInstanceElement(forward, nil, false) + assert.Equal(t, "instance2", popIns2.instance.InstanceID) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 0, checkInUsedInsThd) + assert.Equal(t, 2, checkTotalInsThd) + popIns3 := bcs.popInstanceElement(forward, func(element *instanceElement) bool { return false }, false) + assert.Nil(t, popIns3) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 0, checkInUsedInsThd) + assert.Equal(t, 2, checkTotalInsThd) + popIns4 := bcs.popInstanceElement(forward, func(element *instanceElement) bool { return true }, false) + assert.Equal(t, "instance1", popIns4.instance.InstanceID) + assert.Equal(t, 0, checkAvailInsThd) + assert.Equal(t, 0, checkInUsedInsThd) + assert.Equal(t, 0, checkTotalInsThd) + + // evicting实例仅供绑定会话的申请租约请求使用,不干涉扩缩容逻辑,因此无法pop该实例 + popIns5 := bcs.popInstanceElement(forward, func(element *instanceElement) bool { return true }, false) + assert.Nil(t, popIns5) + assert.Equal(t, 0, checkAvailInsThd) + assert.Equal(t, 0, checkInUsedInsThd) + assert.Equal(t, 0, checkTotalInsThd) +} + +func TestSignalAllInstances(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + bcs.AddInstance(&types.Instance{ + InstanceID: "instance_evicting", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusEvicting)}, + }) + insIDList := make([]string, 0, 3) + bcs.SignalAllInstances(func(instance *types.Instance) { + insIDList = append(insIDList, instance.InstanceID) + }) + assert.Contains(t, insIDList, "instance1") + assert.Contains(t, insIDList, "instance2") + // evicting实例可能还需要给会话请求使用,因此仍然需要被signal + assert.Contains(t, insIDList, "instance_evicting") +} + +func TestHandleInstanceUpdate(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + checkAvailInsThd := 0 + checkTotalInsThd := 0 + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.addObservers(scheduler.TotalInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkTotalInsThd += delta + }) + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) + _, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance2"}) + assert.Equal(t, scheduler.ErrInsSubHealthy, err) + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + assert.Equal(t, 0, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) + selfregister.GlobalSchedulerProxy.Add(&commonTypes.InstanceInfo{InstanceName: "scheduler1"}, "") + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance3", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + CreateSchedulerID: "scheduler1", + Permanent: true, + }) + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance3", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + CreateSchedulerID: "scheduler1", + Permanent: true, + }) + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance3", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + CreateSchedulerID: "scheduler1", + Permanent: true, + }) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) +} + +func TestHandleInstanceUpdate_withEvictingInstance(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + checkAvailInsThd := 0 + checkTotalInsThd := 0 + checkInUseInsThd := 0 + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.addObservers(scheduler.TotalInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkTotalInsThd += delta + }) + bcs.addObservers(scheduler.InUseInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkInUseInsThd += delta + }) + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) + assert.Equal(t, 0, checkInUseInsThd) + + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusEvicting)}, + }) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 2, checkTotalInsThd) + assert.Equal(t, 0, checkInUseInsThd) + + obj := bcs.selfInstanceQueue.GetByID("instance2") + ins2, ok := obj.(*instanceElement) + assert.True(t, ok) + ins2.sessionMap["0000"] = &sessionRecord{ + availThdMap: make(map[string]struct{}), + } + ins2.sessionMap["0000"].availThdMap["00"] = struct{}{} + _, err := bcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance2", InstanceSession: commonTypes.InstanceSessionConfig{ + SessionID: "0000", + }}) + assert.Nil(t, err) + + _, err = bcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance2"}) + assert.NotNil(t, err) + + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 2, checkTotalInsThd) + assert.Equal(t, 0, checkInUseInsThd) + + bcs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusEvicting)}, + }) + assert.Equal(t, 0, checkAvailInsThd) + assert.Equal(t, 0, checkTotalInsThd) + assert.Equal(t, 0, checkInUseInsThd) +} + +func Test_basicConcurrencyScheduler_ReassignInstance(t *testing.T) { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + checkAvailInsThd := 0 + checkTotalInsThd := 0 + bcs.addObservers(scheduler.AvailInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkAvailInsThd += delta + }) + bcs.addObservers(scheduler.TotalInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkTotalInsThd += delta + }) + instance1 := &types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + } + instance2 := &types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + } + instance_evicting := &types.Instance{ + InstanceID: "instance_evicting", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusEvicting)}, + } + convey.Convey("test HandleFuncOwnerUpdate", t, func() { + convey.Convey("become owner", func() { + checkAvailInsThd = 0 + checkTotalInsThd = 0 + bcs.isFuncOwner = false + bcs.AddInstance(instance1) + bcs.AddInstance(instance2) + bcs.AddInstance(instance_evicting) + defer bcs.DelInstance(instance_evicting) + bcs.HandleFuncOwnerUpdate(true) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) + assert.True(t, bcs.selfInstanceQueue.GetByID(instance_evicting.InstanceID) != nil) + assert.True(t, bcs.otherInstanceQueue.GetByID(instance_evicting.InstanceID) != nil) + bcs.DelInstance(instance1) + bcs.DelInstance(instance2) + }) + convey.Convey("resign owner", func() { + checkAvailInsThd = 0 + checkTotalInsThd = 0 + bcs.isFuncOwner = true + bcs.AddInstance(instance1) + bcs.AddInstance(instance2) + bcs.AddInstance(instance_evicting) + defer bcs.DelInstance(instance_evicting) + bcs.HandleFuncOwnerUpdate(false) + assert.Equal(t, 0, checkAvailInsThd) + assert.Equal(t, 0, checkTotalInsThd) + assert.True(t, bcs.selfInstanceQueue.GetByID(instance_evicting.InstanceID) != nil) + assert.True(t, bcs.otherInstanceQueue.GetByID(instance_evicting.InstanceID) != nil) + bcs.DelInstance(instance1) + bcs.DelInstance(instance2) + }) + convey.Convey("no change", func() { + checkAvailInsThd = 0 + checkTotalInsThd = 0 + bcs.isFuncOwner = true + bcs.AddInstance(instance1) + bcs.AddInstance(instance2) + bcs.AddInstance(instance_evicting) + defer bcs.DelInstance(instance_evicting) + bcs.HandleFuncOwnerUpdate(true) + assert.Equal(t, 2, checkAvailInsThd) + assert.Equal(t, 4, checkTotalInsThd) + assert.True(t, bcs.selfInstanceQueue.GetByID(instance_evicting.InstanceID) != nil) + bcs.DelInstance(instance1) + bcs.DelInstance(instance2) + }) + }) +} + +func Test_basicConcurrencyScheduler_scheduleRequest(t *testing.T) { + convey.Convey("test scheduleRequest", t, func() { + convey.Convey("baseline", func() { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = false + p := gomonkey.ApplyFunc((*basicConcurrencyScheduler).acquireInstanceInternal, + func(_ *basicConcurrencyScheduler, + queue queue.Queue, request *types.InstanceAcquireRequest) (*types.InstanceAllocation, error) { + return &types.InstanceAllocation{ + Instance: &types.Instance{ + InstanceType: "bbb", + InstanceID: "ccc", + }, + AllocationID: "aaa", + }, nil + }) + defer p.Reset() + insAlloc, err := bcs.scheduleRequest(&types.InstanceAcquireRequest{}) + convey.So(err, convey.ShouldBeNil) + convey.So(insAlloc.Instance.InstanceID, convey.ShouldEqual, "ccc") + }) + convey.Convey("acquire failed", func() { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = false + p := gomonkey.ApplyFunc((*basicConcurrencyScheduler).acquireInstanceInternal, + func(_ *basicConcurrencyScheduler, + queue queue.Queue, request *types.InstanceAcquireRequest) (*types.InstanceAllocation, error) { + return nil, fmt.Errorf("error") + }) + defer p.Reset() + insAlloc, err := bcs.scheduleRequest(&types.InstanceAcquireRequest{}) + convey.So(err, convey.ShouldNotBeNil) + convey.So(insAlloc, convey.ShouldBeNil) + }) + convey.Convey("session bind", func() { + bcs := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 1}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + bcs.isFuncOwner = true + bcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + }) + insAcqReq := &types.InstanceAcquireRequest{ + InstanceSession: commonTypes.InstanceSessionConfig{ + SessionID: "123", + SessionTTL: 10, + Concurrency: 1, + }, + } + insAlloc1, err := bcs.AcquireInstance(insAcqReq) + convey.So(err, convey.ShouldBeNil) + convey.So(insAlloc1.Instance.InstanceID, convey.ShouldEqual, "instance1") + _, err = bcs.scheduleRequest(insAcqReq) + convey.So(err, convey.ShouldNotBeNil) + err = insAlloc1.Lease.Release() + convey.So(err, convey.ShouldBeNil) + insAlloc2, err := bcs.scheduleRequest(insAcqReq) + convey.So(err, convey.ShouldBeNil) + convey.So(insAlloc2.Instance.InstanceID, convey.ShouldEqual, "instance1") + }) + }) +} + +// 测试初始10个instance,分配、更新ratio后分配,两个scheduler self和other队列对应相等 +func TestReassignInstancesGray(t *testing.T) { + mainScheduler := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + mainScheduler.isFuncOwner = true + + grayScheduler := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + grayScheduler.isFuncOwner = true + + for i := 1; i <= 10; i++ { + instance := &types.Instance{ + InstanceID: fmt.Sprintf("instance%d", i), + ConcurrentNum: i, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + } + assert.NoError(t, mainScheduler.AddInstance(instance)) + assert.NoError(t, grayScheduler.AddInstance(instance)) + } + + config.GlobalConfig.EnableRollout = true + selfregister.IsRollingOut = true + defer func() { + config.GlobalConfig.EnableRollout = false + selfregister.IsRollingOut = false + }() + + selfregister.IsRolloutObject = false + mainScheduler.ReassignInstanceWhenGray(50) + assert.Equal(t, 5, mainScheduler.selfInstanceQueue.Len()) + assert.Equal(t, 5, mainScheduler.otherInstanceQueue.Len()) + + selfregister.IsRolloutObject = true + grayScheduler.ReassignInstanceWhenGray(50) + assert.Equal(t, mainScheduler.selfInstanceQueue.Len(), grayScheduler.otherInstanceQueue.Len()) + mainScheduler.selfInstanceQueue.Range(func(obj interface{}) bool { + insElem, _ := obj.(*instanceElement) + insElemIn2 := grayScheduler.otherInstanceQueue.GetByID(insElem.instance.InstanceID) + assert.NotNil(t, insElemIn2) + return true + }) + + selfregister.IsRolloutObject = false + mainScheduler.ReassignInstanceWhenGray(70) + assert.Equal(t, 3, mainScheduler.selfInstanceQueue.Len()) + assert.Equal(t, 7, mainScheduler.otherInstanceQueue.Len()) + + selfregister.IsRolloutObject = true + grayScheduler.ReassignInstanceWhenGray(70) + assert.Equal(t, 7, grayScheduler.selfInstanceQueue.Len()) + assert.Equal(t, 3, grayScheduler.otherInstanceQueue.Len()) +} + +// 10个节点10%灰度 9个节点到10个节点删除、增加触发重分配 +func TestReassignInstancesGrayWhenAddOrDelReassign(t *testing.T) { + mainScheduler := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + mainScheduler.isFuncOwner = true + + grayScheduler := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + grayScheduler.isFuncOwner = true + // 当前(9,0) + for i := 1; i <= 9; i++ { + instance := &types.Instance{ + InstanceID: fmt.Sprintf("instance%d", i), + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + } + assert.NoError(t, mainScheduler.AddInstance(instance)) + assert.NoError(t, grayScheduler.AddInstance(instance)) + } + config.GlobalConfig.EnableRollout = true + selfregister.IsRollingOut = true + defer func() { + config.GlobalConfig.EnableRollout = false + selfregister.IsRollingOut = false + }() + // main重分配 + selfregister.IsRolloutObject = false + mainScheduler.ReassignInstanceWhenGray(10) + + assert.Equal(t, 9, mainScheduler.selfInstanceQueue.Len()) + assert.Equal(t, 0, mainScheduler.otherInstanceQueue.Len()) + // gray重分配 + selfregister.IsRolloutObject = true + grayScheduler.ReassignInstanceWhenGray(10) + + // 2个sc各再加入一个 判断先加入了self (10,0) 应该自动变成(9,1) + instance := &types.Instance{ + InstanceID: fmt.Sprintf("instance%d", 10), + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + } + + instance9 := &types.Instance{ + InstanceID: fmt.Sprintf("instance%d", 9), + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + } + + // main加入 + selfregister.IsRolloutObject = false + assert.NoError(t, mainScheduler.AddInstance(instance)) + // gray加入 + selfregister.IsRolloutObject = true + assert.NoError(t, grayScheduler.AddInstance(instance)) + + assert.Equal(t, mainScheduler.selfInstanceQueue.Len(), grayScheduler.otherInstanceQueue.Len()) + assert.Equal(t, 9, mainScheduler.selfInstanceQueue.Len()) + assert.Equal(t, 1, mainScheduler.otherInstanceQueue.Len()) + mainScheduler.selfInstanceQueue.Range(func(obj interface{}) bool { + insElem, _ := obj.(*instanceElement) + insElemIn2 := grayScheduler.otherInstanceQueue.GetByID(insElem.instance.InstanceID) + assert.NotNil(t, insElemIn2) + return true + }) + //确定9 hash最大被分到了other + assert.Equal(t, "instance9", mainScheduler.otherInstanceQueue.Front().(*instanceElement).instance.InstanceID) + + // 假设删除9 + selfregister.IsRolloutObject = false + assert.NoError(t, mainScheduler.DelInstance(&types.Instance{ + InstanceID: "instance9", + })) + selfregister.IsRolloutObject = true + assert.NoError(t, grayScheduler.DelInstance(&types.Instance{ + InstanceID: "instance9", + })) + // 这里应该不会触发reassign + assert.Equal(t, 9, mainScheduler.selfInstanceQueue.Len()) + assert.Equal(t, 0, mainScheduler.otherInstanceQueue.Len()) + + // 但是如果加入9,然后删除1,会触发reassign + // 以update的方式加入 + // main重新加入9 - 不会触发reassign + selfregister.IsRolloutObject = false + mainScheduler.HandleInstanceUpdate(instance9) + + // gray重新加入9 + selfregister.IsRolloutObject = true + grayScheduler.HandleInstanceUpdate(instance9) + + // 验证9还是被分配到other + assert.Equal(t, "instance9", mainScheduler.otherInstanceQueue.Front().(*instanceElement).instance.InstanceID) + + // 删除1 -应该触发重分配(8,1)->(9,0) + selfregister.IsRolloutObject = false + assert.NoError(t, mainScheduler.DelInstance(&types.Instance{ + InstanceID: "instance1", + })) + selfregister.IsRolloutObject = true + assert.NoError(t, grayScheduler.DelInstance(&types.Instance{ + InstanceID: "instance1", + })) + assert.Equal(t, 9, mainScheduler.selfInstanceQueue.Len()) + assert.Equal(t, 0, mainScheduler.otherInstanceQueue.Len()) +} + +// 测试空指针防御 +func TestReassignInstancesGrayBothQueuesInitiallyEmpty(t *testing.T) { + config.GlobalConfig.EnableRollout = true + selfregister.IsRollingOut = true + defer func() { + config.GlobalConfig.EnableRollout = false + selfregister.IsRollingOut = false + }() + scheduler1 := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + scheduler1.isFuncOwner = true + // 空状态下reassign无空指针 + scheduler1.ReassignInstanceWhenGray(50) + assert.Equal(t, 0, scheduler1.selfInstanceQueue.Len()) + assert.Equal(t, 0, scheduler1.otherInstanceQueue.Len()) + // 测试空状态下增加删除无空指针 + instance := &types.Instance{ + InstanceID: fmt.Sprintf("instance%d", 10), + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + } + // 当前是旧sc + selfregister.IsRolloutObject = false + scheduler1.HandleInstanceUpdate(instance) + assert.Equal(t, 1, scheduler1.selfInstanceQueue.Len()) + assert.NoError(t, scheduler1.DelInstance(&types.Instance{ + InstanceID: "instance10", + })) + + // 测试边界 + scheduler1.HandleInstanceUpdate(instance) + scheduler1.ReassignInstanceWhenGray(0) + assert.Equal(t, 1, scheduler1.selfInstanceQueue.Len()) + + // 全部灰度后 + scheduler1.ReassignInstanceWhenGray(100) + // 加入other,已经有了报错 + assert.Error(t, scheduler1.AddInstance(instance)) + assert.Equal(t, 1, scheduler1.otherInstanceQueue.Len()) + + // 当前是新sc + selfregister.IsRolloutObject = true + scheduler1.HandleInstanceUpdate(instance) + scheduler1.ReassignInstanceWhenGray(0) + assert.Error(t, scheduler1.AddInstance(instance)) + assert.Equal(t, 1, scheduler1.otherInstanceQueue.Len()) + + // 全部灰度后 + scheduler1.ReassignInstanceWhenGray(100) + // 加入self,已经有了报错 + assert.Error(t, scheduler1.AddInstance(instance)) + assert.Equal(t, 1, scheduler1.selfInstanceQueue.Len()) +} + +func TestReassignInstancesGrayWhenFixedOtherInstance(t *testing.T) { + config.GlobalConfig.EnableRollout = true + selfregister.IsRollingOut = true + defer func() { + config.GlobalConfig.EnableRollout = false + selfregister.IsRollingOut = false + }() + + scheduler1 := newBasicConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commonTypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance), + queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance)) + instance := &types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + Permanent: true, + CreateSchedulerID: "abc", + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + } + selfregister.IsRolloutObject = false + assert.NoError(t, scheduler1.AddInstance(instance)) + assert.Equal(t, 0, scheduler1.selfInstanceQueue.Len()) + assert.Equal(t, 1, scheduler1.otherInstanceQueue.Len()) + + selfregister.IsRolloutObject = false + scheduler1.ReassignInstanceWhenGray(50) + assert.Equal(t, 0, scheduler1.selfInstanceQueue.Len()) + assert.Equal(t, 1, scheduler1.otherInstanceQueue.Len()) +} + +func TestInstanceQueueWithSubHealthAndEvictingRecord(t *testing.T) { + convey.Convey("Test instanceQueueWithSubHealthAndEvictingRecord", t, func() { + // 创建 mock queue 和 mock instanceElement + newInstanceQueue := func() *instanceQueueWithSubHealthAndEvictingRecord { + mockQueue := queue.NewFifoQueue(getInstanceID) + mockSubHealthRecord := make(map[string]*instanceElement) + mockEvictingRecord := make(map[string]*instanceElement) + return &instanceQueueWithSubHealthAndEvictingRecord{ + instanceQueue: mockQueue, + subHealthRecord: mockSubHealthRecord, + evictingRecord: mockEvictingRecord, + } + } + iq := newInstanceQueue() + + insID := "test-instance-id" + insElem := &instanceElement{ + instance: &types.Instance{ + InstanceID: insID, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + }, + } + + convey.Convey("PushBack, instance already exists", func() { + iq = newInstanceQueue() + iq.subHealthRecord[insID] = insElem + err := iq.PushBack(insElem) + convey.So(err, convey.ShouldEqual, scheduler.ErrInsAlreadyExist) + }) + + convey.Convey("PopSubHealth, subHealthRecord is empty", func() { + iq = newInstanceQueue() + result := iq.PopSubHealth() + convey.So(result, convey.ShouldBeNil) + }) + + convey.Convey("GetByID should return instance from subHealthRecord", func() { + iq = newInstanceQueue() + iq.subHealthRecord[insID] = insElem + result := iq.GetByID(insID) + convey.So(result, convey.ShouldEqual, insElem) + }) + + convey.Convey("DelByID", func() { + iq = newInstanceQueue() + iq.PushBack(insElem) + + err := iq.DelByID(insID) + assert.NoError(t, err) + assert.False(t, iq.subHealthRecord[insID] != nil) + assert.False(t, iq.evictingRecord[insID] != nil) + assert.Nil(t, iq.instanceQueue.GetByID(insID)) + }) + + convey.Convey("complex", func() { + iq = newInstanceQueue() + err := iq.PushBack(insElem) + convey.So(err, convey.ShouldBeNil) + convey.So(iq.instanceQueue.GetByID(insElem.instance.InstanceID) == nil, convey.ShouldBeFalse) + convey.So(len(iq.subHealthRecord), convey.ShouldEqual, 0) + convey.So(len(iq.evictingRecord), convey.ShouldEqual, 0) + convey.So(iq.Len(), convey.ShouldEqual, 1) + + insElem.instance.InstanceStatus.Code = int32(constant.KernelInstanceStatusSubHealth) + err = iq.UpdateObjByID(insID, insElem) + convey.So(err, convey.ShouldBeNil) + convey.So(iq.instanceQueue.GetByID(insElem.instance.InstanceID) == nil, convey.ShouldBeTrue) + _, ok := iq.subHealthRecord[insID] + convey.So(ok, convey.ShouldBeTrue) + convey.So(len(iq.evictingRecord), convey.ShouldEqual, 0) + convey.So(iq.Len(), convey.ShouldEqual, 1) + + insElem.instance.InstanceStatus.Code = int32(constant.KernelInstanceStatusEvicting) + err = iq.UpdateObjByID(insID, insElem) + convey.So(err, convey.ShouldBeNil) + convey.So(iq.instanceQueue.GetByID(insElem.instance.InstanceID) == nil, convey.ShouldBeTrue) + _, ok = iq.evictingRecord[insID] + convey.So(ok, convey.ShouldBeTrue) + convey.So(len(iq.subHealthRecord), convey.ShouldEqual, 0) + convey.So(iq.Len(), convey.ShouldEqual, 0) + }) + }) +} diff --git a/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/grayinstanceallocator.go b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/grayinstanceallocator.go new file mode 100644 index 0000000..ab47d51 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/grayinstanceallocator.go @@ -0,0 +1,215 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package concurrencyscheduler + +import ( + "hash/crc32" + "math" + "sort" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/selfregister" +) + +// InstanceOperType - +type InstanceOperType int + +const ( + // Add - + Add InstanceOperType = iota + // Del - + Del + // Find - + Find +) + +const ( + percentageBase = 100 +) + +// HashedInstance - +type HashedInstance struct { + InsElem *instanceElement + hash uint32 +} + +// Hasher 哈希函数接口 +type Hasher interface { + Hash(key string) uint32 +} + +// CRC32Hasher - +type CRC32Hasher struct{} + +// Hash - +func (*CRC32Hasher) Hash(key string) uint32 { + return crc32.ChecksumIEEE([]byte(key)) +} + +// GrayInstanceAllocator 灰度分配器接口 +type GrayInstanceAllocator interface { + ComputeHash(instanceKey string) uint32 + UpdateRolloutRatio(rolloutPercent int) + Partition(instances []*HashedInstance, isGrayNode bool) (self, other []*instanceElement) + CheckSelf(isGrayNode bool, instanceKey string) bool + ShouldReassign(operType InstanceOperType, instanceKey string) bool + GetRolloutRatio() float64 +} + +// HashBasedInstanceAllocator 灰度分配器实现 +type HashBasedInstanceAllocator struct { + hasher Hasher + rolloutRatio float64 + boundaryHash uint32 + maxHashValue uint32 + grayCount int + notGrayCount int +} + +// ShouldReassign - +func (h *HashBasedInstanceAllocator) ShouldReassign(operType InstanceOperType, instanceKey string) bool { + if !selfregister.IsRollingOut { + return false + } + h.modifyCount(operType, instanceKey) + count := h.CountFloorGrayCount(h.grayCount + h.notGrayCount) + // 多了: (9,1)->(8,1) 10% 需要调成(9,0) + // 少了: (19,2)->(19,1)应该重新划分成(18,2) + if h.grayCount != count { + return true + } else { + return false + } +} + +func (h *HashBasedInstanceAllocator) checkSelfByBoundary(isGrayNode bool, hashValue uint32) bool { + // [旧节点self 70% | 新节点self 30%] 从小到大 + if isGrayNode { + return hashValue > h.boundaryHash + } + return hashValue <= h.boundaryHash +} + +// CheckSelf - +func (h *HashBasedInstanceAllocator) CheckSelf(isGrayNode bool, instanceKey string) bool { + hashValue := h.hasher.Hash(instanceKey) + return h.checkSelfByBoundary(isGrayNode, hashValue) +} + +// ModifyCount - +func (h *HashBasedInstanceAllocator) modifyCount(operType InstanceOperType, instanceKey string) { + hashValue := h.hasher.Hash(instanceKey) + switch operType { + case Add: + if hashValue <= h.boundaryHash { + h.notGrayCount++ + } else { + h.grayCount++ + } + case Del: + if hashValue <= h.boundaryHash { + h.notGrayCount-- + } else { + h.grayCount-- + } + default: + } +} + +// NewHashBasedInstanceAllocator 创建新的基于哈希的灰度分配器 +func NewHashBasedInstanceAllocator(rolloutRatio float64) GrayInstanceAllocator { + var hasher Hasher = &CRC32Hasher{} + + return &HashBasedInstanceAllocator{ + hasher: hasher, + rolloutRatio: rolloutRatio, + maxHashValue: math.MaxUint32, + boundaryHash: math.MaxUint32, + } +} + +// CountFloorGrayCount 向下取整 灰度数量 9*10% = 0, 19*10% = 1 +func (h *HashBasedInstanceAllocator) CountFloorGrayCount(totalCount int) int { + targetGrayCount := int(math.Floor(float64(totalCount) * h.rolloutRatio)) + return targetGrayCount +} + +// Partition 调整成2个队列,并且更新内部的graycount计数和boundary +func (h *HashBasedInstanceAllocator) Partition(instances []*HashedInstance, isGrayNode bool) (self, + other []*instanceElement) { + total := len(instances) + if total == 0 { + return + } + + targetGrayCount := h.CountFloorGrayCount(total) + targetNotGrayCount := total - targetGrayCount + + sort.Slice(instances, func(i, j int) bool { + return instances[i].hash > instances[j].hash + }) + + selfCap, otherCap := targetNotGrayCount, targetGrayCount + if isGrayNode { + selfCap, otherCap = targetGrayCount, targetNotGrayCount + } + + self = make([]*instanceElement, 0, selfCap) + other = make([]*instanceElement, 0, otherCap) + + for i, item := range instances { + addToSelf := isGrayNode && (i < targetGrayCount) || !isGrayNode && (i >= targetGrayCount) + if addToSelf { + self = append(self, item.InsElem) + } else { + other = append(other, item.InsElem) + } + } + log.GetLogger().Infof("partition ratio %.2f (notGrayCount,grayCount): (%d,%d)->(%d,%d). isGrayNode: %t, "+ + "(self,other): (%d, %d)", h.rolloutRatio, h.notGrayCount, h.grayCount, targetNotGrayCount, targetGrayCount, + isGrayNode, len(self), len(other)) + // 因为旧节点判断self的时候带==, 灰度100%时候边界应为0,否则旧节点添加、删除时会判断错误 + newBoundary := h.maxHashValue + if targetGrayCount == total { + newBoundary = 0 + } else { + newBoundary = instances[targetGrayCount].hash + } + + h.boundaryHash = newBoundary + h.notGrayCount = targetNotGrayCount + h.grayCount = targetGrayCount + return +} + +// ComputeHash - +func (h *HashBasedInstanceAllocator) ComputeHash(instanceKey string) uint32 { + hashValue := h.hasher.Hash(instanceKey) + return hashValue +} + +// UpdateRolloutRatio - +func (h *HashBasedInstanceAllocator) UpdateRolloutRatio(rolloutPercent int) { + log.GetLogger().Infof("update gray allocator ratio from %.2f to %d", h.rolloutRatio, rolloutPercent) + h.rolloutRatio = float64(rolloutPercent) / percentageBase + return +} + +// GetRolloutRatio - +func (h *HashBasedInstanceAllocator) GetRolloutRatio() float64 { + return h.rolloutRatio +} diff --git a/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/reserved_concurrency_scheduler.go b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/reserved_concurrency_scheduler.go new file mode 100644 index 0000000..6a47fe5 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/reserved_concurrency_scheduler.go @@ -0,0 +1,213 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package concurrencyscheduler - +package concurrencyscheduler + +import ( + "time" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/queue" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/scaler" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + checkScalingInterval = 50 * time.Millisecond +) + +// ReservedConcurrencyScheduler will schedule instance according to concurrency usage for reserved instance +type ReservedConcurrencyScheduler struct { + basicConcurrencyScheduler + instanceScaler scaler.InstanceScaler + checkScalingTimeout time.Duration + createErr error +} + +// NewReservedConcurrencyScheduler creates ReservedConcurrencyScheduler +func NewReservedConcurrencyScheduler(funcSpec *types.FunctionSpecification, resKey resspeckey.ResSpecKey, + requestTimeout time.Duration, insThdReqQueue *requestqueue.InsAcqReqQueue) scheduler.InstanceScheduler { + instanceQueue := queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance) + otherQueue := queue.NewPriorityQueue(getInstanceID, priorityFuncForReservedInstance) + reservedConcurrencyScheduler := &ReservedConcurrencyScheduler{ + basicConcurrencyScheduler: newBasicConcurrencyScheduler(funcSpec, resKey, instanceQueue, otherQueue), + checkScalingTimeout: requestTimeout, + } + reservedConcurrencyScheduler.insAcqReqQueue = insThdReqQueue + insThdReqQueue.RegisterSchFunc("reserveScheduleFunc", reservedConcurrencyScheduler.scheduleRequest) + reservedConcurrencyScheduler.addObservers(scheduler.AvailInsThdTopic, func(data interface{}) { + availInsThdDiff, ok := data.(int) + // schedule request even if availInsThdDiff == 0, because some instance thread changes like session doesn't + // affect availInsThdDiff + if !ok || availInsThdDiff < 0 { + return + } + insThdReqQueue.ScheduleRequest("reserveScheduleFunc") + }) + log.GetLogger().Infof("succeed to create ReservedConcurrencyScheduler for function %s isFuncOwner %t", + reservedConcurrencyScheduler.funcKeyWithRes, reservedConcurrencyScheduler.isFuncOwner) + return reservedConcurrencyScheduler +} + +// AcquireInstance acquires an instance chosen by instanceScheduler +func (rcs *ReservedConcurrencyScheduler) AcquireInstance(insAcqReq *types.InstanceAcquireRequest) ( + *types.InstanceAllocation, error) { + var ( + insAlloc *types.InstanceAllocation + acquireErr error + recordErr error + ) + insAlloc, acquireErr = rcs.basicConcurrencyScheduler.AcquireInstance(insAcqReq) + if acquireErr != nil && acquireErr != scheduler.ErrNoInsAvailable { + return nil, acquireErr + } + if acquireErr == scheduler.ErrNoInsAvailable && rcs.shouldTriggerColdStart() { + // 这里如果是静态函数,则会触发到wisecloudscaler,触发一次nuwa cold start,如果不是静态函数,则会走到replicascaler,没有其他影响 + rcs.publishInsThdEvent(scheduler.TriggerScaleTopic, nil) + if len(insAcqReq.DesignateInstanceID) != 0 { + pendingRequest := &requestqueue.PendingInsAcqReq{ + CreatedTime: time.Now(), + ResultChan: make(chan *requestqueue.PendingInsAcqRsp, 1), + InsAcqReq: insAcqReq, + } + if err := rcs.insAcqReqQueue.AddRequest(pendingRequest); err != nil { + return nil, err + } + insAcqRsp := <-pendingRequest.ResultChan + return insAcqRsp.InsAlloc, insAcqRsp.Error + } + if config.GlobalConfig.DisableReplicaScaler { + return nil, acquireErr + } + // 没租约可用且预留实例个数已经全部拉起来了,此时应该立即返回错误 + if rcs.GetInstanceNumber(true) >= rcs.instanceScaler.GetExpectInstanceNumber() { + return nil, acquireErr + } + // block until instanceScaler finishes scaling reserved instance + var createErr error + ticker := time.NewTicker(checkScalingInterval) + timer := time.NewTimer(rcs.checkScalingTimeout) + defer ticker.Stop() + defer timer.Stop() + loop: + for { + select { + case <-ticker.C: + if insAlloc, acquireErr = rcs.basicConcurrencyScheduler.AcquireInstance(insAcqReq); insAlloc != nil { + return insAlloc, nil + } + // 扩容出来的实例的租约瞬间获取完了,但仍有部分请求未获取到租约,则直接返回去扩弹性实例 + if rcs.GetInstanceNumber(true) == rcs.instanceScaler.GetExpectInstanceNumber() { + return nil, scheduler.ErrNoInsAvailable + } + rcs.RLock() + createErr = rcs.createErr + rcs.RUnlock() + if createErr != nil { + recordErr = createErr + } + // 用户错误直接返回 + if createSnErr, ok := createErr.(snerror.SNError); ok { + if snerror.IsUserError(createSnErr) { + return nil, rcs.createErr + } + } + case <-timer.C: + break loop + } + } + // 超时后若等待过程中出现过错误则返回此错误 + if recordErr != nil { + return nil, recordErr + } + return nil, scheduler.ErrNoInsAvailable + } + return insAlloc, acquireErr +} + +// PopInstance pops an instance, will block and wait for instance which is creating +func (rcs *ReservedConcurrencyScheduler) PopInstance(force bool) *types.Instance { + wait := rcs.instanceScaler.CheckScaling() + if force { + wait = false + } + insElem := rcs.popInstanceElement(forward, nil, wait) + if insElem == nil { + return nil + } + return insElem.instance +} + +// ConnectWithInstanceScaler connects instanceScheduler with an instanceScaler +func (rcs *ReservedConcurrencyScheduler) ConnectWithInstanceScaler(instanceScaler scaler.InstanceScaler) { + rcs.addObservers(scheduler.TriggerScaleTopic, func(data interface{}) { + instanceScaler.TriggerScale() + }) + // check if instanceScaler is a concurrencyInstanceScaler type in future and return error if otherwise + rcs.instanceScaler = instanceScaler + rcs.addObservers(scheduler.InUseInsThdTopic, func(data interface{}) { + inUsedInsThdDiff, ok := data.(int) + if !ok { + return + } + instanceScaler.HandleInsThdUpdate(inUsedInsThdDiff, 0) + }) + rcs.addObservers(scheduler.TotalInsThdTopic, func(data interface{}) { + totalInsThdDiff, ok := data.(int) + if !ok { + return + } + instanceScaler.HandleInsThdUpdate(0, totalInsThdDiff) + rcs.insAcqReqQueue.HandleInsNumUpdate(totalInsThdDiff / rcs.concurrentNum) + }) + return +} + +// HandleFuncSpecUpdate handles funcSpec update comes from ETCD +func (rcs *ReservedConcurrencyScheduler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { + rcs.Lock() + rcs.concurrentNum = utils.GetConcurrentNum(funcSpec.InstanceMetaData.ConcurrentNum) + rcs.basicConcurrencyScheduler.HandleFuncSpecUpdate(funcSpec) + rcs.Unlock() +} + +// HandleCreateError handles create error +func (rcs *ReservedConcurrencyScheduler) HandleCreateError(createErr error) { + rcs.Lock() + rcs.createErr = createErr + rcs.Unlock() +} + +// priorityFuncForReservedInstance is the priority function for reserved instance which will average requests among all +// instances since they will not be scaled down +func priorityFuncForReservedInstance(obj interface{}) (int, error) { + insElem, ok := obj.(*instanceElement) + if ok { + weight := len(insElem.threadMap) + if !insElem.isNewInstance { + weight -= insElem.instance.ConcurrentNum + } + return weight, nil + } + return -1, scheduler.ErrTypeConvertFail +} diff --git a/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/reserved_concurrency_scheduler_test.go b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/reserved_concurrency_scheduler_test.go new file mode 100644 index 0000000..5609eed --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/reserved_concurrency_scheduler_test.go @@ -0,0 +1,241 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package concurrencyscheduler - +package concurrencyscheduler + +import ( + "errors" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + commontypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" +) + +func TestNewReservedConcurrencyScheduler(t *testing.T) { + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("", 100*time.Millisecond) + rcs := NewReservedConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 1}, + }, resspeckey.ResSpecKey{}, 0, InsThdReqQueue) + assert.NotNil(t, rcs) +} + +func TestAcquireInstanceReservedNew(t *testing.T) { + config.GlobalConfig.LeaseSpan = 5000 + defer func() { + config.GlobalConfig.LeaseSpan = 0 + }() + defer gomonkey.ApplyGlobalVar(&requestqueue.DefaultRequestTimeout, 100*time.Millisecond).Reset() + defer gomonkey.ApplyFunc((*selfregister.SchedulerProxy).CheckFuncOwner, func(_ *selfregister.SchedulerProxy, + funcKey string) bool { + return true + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&fakeInstanceScaler{}), "GetExpectInstanceNumber", + func(f *fakeInstanceScaler) int { + return 1 + }).Reset() + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("", 100*time.Millisecond) + rcs := NewReservedConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 1}, + }, resspeckey.ResSpecKey{}, 50*time.Millisecond, InsThdReqQueue) + rcs.ConnectWithInstanceScaler(&fakeInstanceScaler{ + scaling: true, + timer: time.NewTimer(100 * time.Millisecond), + }) + + rcs.HandleCreateError(errors.New("some error")) + _, err := rcs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, "some error", err.Error()) + rcs.HandleCreateError(snerror.New(4011, "user error")) + _, err = rcs.AcquireInstance(&types.InstanceAcquireRequest{}) + + rcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commontypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + rcs.HandleCreateError(nil) + insAlloc1, err := rcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + assert.Equal(t, nil, err) + assert.Equal(t, "instance1", insAlloc1.Instance.InstanceID) + _, err = rcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + assert.Equal(t, false, err == nil) +} + +func TestAcquireInstanceReserved(t *testing.T) { + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("", 10) + rcs := NewReservedConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, 50*time.Millisecond, InsThdReqQueue) + rcs.ConnectWithInstanceScaler(&fakeInstanceScaler{ + scaling: true, + timer: time.NewTimer(100 * time.Millisecond), + }) + _, err := rcs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, scheduler.ErrNoInsAvailable, err) + rcs.ConnectWithInstanceScaler(&fakeInstanceScaler{ + scaling: true, + timer: time.NewTimer(10 * time.Millisecond), + }) + _, err = rcs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, scheduler.ErrNoInsAvailable, err) + _, err = rcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + assert.Equal(t, scheduler.ErrInsNotExist, err) + rcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commontypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + acqIns1, err := rcs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns1.Instance.InstanceID) + rcs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commontypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + acqIns2, err := rcs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns2.Instance.InstanceID) + acqIns3, err := rcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns3.Instance.InstanceID) + _, err = rcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance3"}) + assert.Equal(t, "instance does not exist in queue", err.Error()) + + rc := NewReservedConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, 150*time.Millisecond, InsThdReqQueue) + rc.ConnectWithInstanceScaler(&fakeInstanceScaler{ + scaling: true, + timer: time.NewTimer(100 * time.Millisecond), + targetRsvInsNum: 2, + }) + rc.HandleCreateError(errors.New("resource not enough")) + _, err = rc.AcquireInstance(&types.InstanceAcquireRequest{}) + go func() { + time.Sleep(80 * time.Millisecond) + rc.HandleCreateError(nil) + }() + assert.Equal(t, "resource not enough", err.Error()) +} + +func TestPopInstanceReserved(t *testing.T) { + defer gomonkey.ApplyFunc((*selfregister.SchedulerProxy).CheckFuncOwner, func(_ *selfregister.SchedulerProxy, + funcKey string) bool { + return true + }).Reset() + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("", 10) + rcs := NewReservedConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, 50*time.Millisecond, InsThdReqQueue) + rcs.ConnectWithInstanceScaler(&fakeInstanceScaler{}) + popIns1 := rcs.PopInstance(false) + assert.Nil(t, popIns1) + rcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commontypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + rcs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commontypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + rcs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + popIns2 := rcs.PopInstance(false) + assert.Equal(t, "instance2", popIns2.InstanceID) +} + +func TestHandleFuncSpecUpdateReserved(t *testing.T) { + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("", 10) + rcs := NewReservedConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, 50*time.Millisecond, InsThdReqQueue) + rcs.ConnectWithInstanceScaler(&fakeInstanceScaler{}) + rcs.HandleFuncSpecUpdate(&types.FunctionSpecification{ + InstanceMetaData: commontypes.InstanceMetaData{ + ConcurrentNum: 4, + }, + }) +} + +func TestAddInstancePublishReserved(t *testing.T) { + defer gomonkey.ApplyFunc((*selfregister.SchedulerProxy).CheckFuncOwner, func(_ *selfregister.SchedulerProxy, + funcKey string) bool { + return true + }).Reset() + config.GlobalConfig.AutoScaleConfig.BurstScaleNum = 1000 + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("testFunction", 50*time.Millisecond) + insThdReq1 := &requestqueue.PendingInsAcqReq{ + ResultChan: make(chan *requestqueue.PendingInsAcqRsp, 1), + InsAcqReq: &types.InstanceAcquireRequest{}, + } + insThdReq2 := &requestqueue.PendingInsAcqReq{ + ResultChan: make(chan *requestqueue.PendingInsAcqRsp, 1), + InsAcqReq: &types.InstanceAcquireRequest{}, + } + InsThdReqQueue.AddRequest(insThdReq1) + InsThdReqQueue.AddRequest(insThdReq2) + rcs := NewReservedConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, 50*time.Millisecond, InsThdReqQueue) + rcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commontypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + }) + time.Sleep(10 * time.Millisecond) + select { + case insThd := <-insThdReq1.ResultChan: + assert.Equal(t, "instance1", insThd.InsAlloc.Instance.InstanceID) + default: + t.Errorf("should get instance from result channel") + } + select { + case insThd := <-insThdReq2.ResultChan: + assert.Equal(t, "instance1", insThd.InsAlloc.Instance.InstanceID) + default: + t.Errorf("should get instance from result channel") + } +} diff --git a/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/scaled_concurrency_scheduler.go b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/scaled_concurrency_scheduler.go new file mode 100644 index 0000000..6356df2 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/scaled_concurrency_scheduler.go @@ -0,0 +1,393 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package concurrencyscheduler - +package concurrencyscheduler + +import ( + "time" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/queue" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/scaler" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +type instanceQueueWithBuffer struct { + queue *queue.PriorityQueue + buffer []*instanceElement + idFunc func(interface{}) string +} + +// Front - +func (p *instanceQueueWithBuffer) Front() interface{} { + if p.queue.Len() != 0 { + return p.queue.Front() + } + if len(p.buffer) == 0 { + return nil + } + return p.buffer[0] +} + +// Back - +func (p *instanceQueueWithBuffer) Back() interface{} { + if p.queue.Len() != 0 { + return p.queue.Back() + } + length := len(p.buffer) + if length == 0 { + return nil + } + item := p.buffer[length-1] + return item +} + +// PopFront - +func (p *instanceQueueWithBuffer) PopFront() interface{} { + if p.queue.Len() != 0 { + return p.queue.PopFront() + } + if len(p.buffer) == 0 { + return nil + } + item := p.buffer[0] + p.buffer = p.buffer[1:] + return item +} + +// PopBack - +func (p *instanceQueueWithBuffer) PopBack() interface{} { + if p.queue.Len() != 0 { + return p.queue.PopBack() + } + length := len(p.buffer) + if length == 0 { + return nil + } + item := p.buffer[length-1] + p.buffer = p.buffer[:length-1] + return item +} + +// PushBack - 在调用pushBack之前,需确保该obj不在instanceQueueWithBuffer中 +func (p *instanceQueueWithBuffer) PushBack(obj interface{}) error { + insElem, ok := obj.(*instanceElement) + if !ok { + return scheduler.ErrTypeConvertFail + } + if len(insElem.threadMap) == 0 { // threadMap == 0 实例需要放入buffer中 + p.buffer = append(p.buffer, insElem) + return nil + } + return p.queue.PushBack(obj) +} + +// GetByID - +func (p *instanceQueueWithBuffer) GetByID(objID string) interface{} { + get := p.queue.GetByID(objID) + if get != nil { + return get + } + for _, item := range p.buffer { + if p.idFunc(item) == objID { + return item + } + } + return nil +} + +// DelByID - +func (p *instanceQueueWithBuffer) DelByID(objID string) error { + err := p.queue.DelByID(objID) + if err == nil { + return nil + } + if err != queue.ErrObjectNotFound { + return err + } + index := -1 + for i, item := range p.buffer { + if p.idFunc(item) == objID { + index = i + break + } + } + if index == -1 { + return queue.ErrObjectNotFound + } + p.buffer = append(p.buffer[0:index], p.buffer[index+1:]...) + return nil +} + +// Range iterates item in queue and process item with given function +func (p *instanceQueueWithBuffer) Range(f func(obj interface{}) bool) { + p.queue.Range(f) + for _, item := range p.buffer { + if !f(item) { + break + } + } +} + +// SortedRange iterates item in queue and process item with given function in order +func (p *instanceQueueWithBuffer) SortedRange(f func(obj interface{}) bool) { + p.queue.SortedRange(f) + for _, item := range p.buffer { + if !f(item) { + break + } + } +} + +// UpdateObjByID has some cases which should be aware: +// 1. for instance who will have no available thread after updating, we should move it from queue to buffer +// 2. for instance had no available thread before and will have available thread after updating, we should mov it from +// buffer to queue +func (p *instanceQueueWithBuffer) UpdateObjByID(objID string, obj interface{}) error { + insElem, ok := obj.(*instanceElement) + if !ok { + return scheduler.ErrTypeConvertFail + } + updateErr := p.queue.UpdateObjByID(objID, obj) + // transfer instance from queue to buffer if there is no available thread in it, transfer instance from buffer to + // queue if there is available thread in it + if updateErr == nil && len(insElem.threadMap) == 0 { + err := p.queue.DelByID(insElem.instance.InstanceID) + if err != nil { + log.GetLogger().Errorf("failed to remove instance %s from scaled instance queue error %s", + insElem.instance.InstanceID, err.Error()) + return err + } + p.buffer = append(p.buffer, insElem) + return nil + } else if updateErr == queue.ErrObjectNotFound { + index := -1 + for i, item := range p.buffer { + if p.idFunc(item) == objID { + index = i + p.buffer[index] = insElem + } + } + if index == -1 { + return queue.ErrObjectNotFound + } + if len(insElem.threadMap) != 0 { + p.buffer = append(p.buffer[0:index], p.buffer[index+1:]...) + err := p.queue.PushBack(insElem) + if err != nil { + return err + } + } + return nil + } + return updateErr +} + +// Len - +func (p *instanceQueueWithBuffer) Len() int { + return p.queue.Len() + len(p.buffer) +} + +// ScaledConcurrencyScheduler will schedule instance according to concurrency usage for scaled instance +type ScaledConcurrencyScheduler struct { + basicConcurrencyScheduler +} + +// NewScaledConcurrencyScheduler creates ScaledConcurrencyScheduler +func NewScaledConcurrencyScheduler(funcSpec *types.FunctionSpecification, resKey resspeckey.ResSpecKey, + insThdReqQueue *requestqueue.InsAcqReqQueue) scheduler.InstanceScheduler { + concurrentNum := utils.GetConcurrentNum(funcSpec.InstanceMetaData.ConcurrentNum) + instanceQueue := &instanceQueueWithBuffer{ + queue: queue.NewPriorityQueue(getInstanceID, priorityFuncForScaledInstance(concurrentNum)), + buffer: make([]*instanceElement, 0, utils.DefaultSliceSize), + idFunc: getInstanceID, + } + otherQueue := &instanceQueueWithBuffer{ + queue: queue.NewPriorityQueue(getInstanceID, priorityFuncForScaledInstance(concurrentNum)), + buffer: make([]*instanceElement, 0, utils.DefaultSliceSize), + idFunc: getInstanceID, + } + scaledConcurrencyScheduler := &ScaledConcurrencyScheduler{ + basicConcurrencyScheduler: newBasicConcurrencyScheduler(funcSpec, resKey, instanceQueue, otherQueue), + } + scaledConcurrencyScheduler.insAcqReqQueue = insThdReqQueue + insThdReqQueue.RegisterSchFunc("scaledScheduleFunc", scaledConcurrencyScheduler.scheduleRequest) + scaledConcurrencyScheduler.addObservers(scheduler.AvailInsThdTopic, func(data interface{}) { + availInsThdDiff, ok := data.(int) + // schedule request even if availInsThdDiff == 0, because some instance thread changes like session doesn't + // affect availInsThdDiff + if !ok || availInsThdDiff < 0 { + return + } + insThdReqQueue.ScheduleRequest("scaledScheduleFunc") + }) + scaledConcurrencyScheduler.addObservers(scheduler.CreateErrorTopic, func(data interface{}) { + if data == nil { + insThdReqQueue.HandleCreateError(nil) + return + } + createError, ok := data.(error) + if !ok { + return + } + insThdReqQueue.HandleCreateError(createError) + }) + log.GetLogger().Infof("succeed to create ScaledConcurrencyScheduler for function %s isFuncOwner %t", + scaledConcurrencyScheduler.funcKeyWithRes, scaledConcurrencyScheduler.isFuncOwner) + return scaledConcurrencyScheduler +} + +// GetReqQueLen will get instance request queue length +func (scs *ScaledConcurrencyScheduler) GetReqQueLen() int { + return scs.insAcqReqQueue.Len() +} + +// AcquireInstance acquires an instance chosen by instanceScheduler +func (scs *ScaledConcurrencyScheduler) AcquireInstance(insAcqReq *types.InstanceAcquireRequest) ( + *types.InstanceAllocation, error) { + insAlloc, acquireErr := scs.basicConcurrencyScheduler.AcquireInstance(insAcqReq) + if acquireErr != nil && acquireErr != scheduler.ErrNoInsAvailable { + return nil, acquireErr + } + if acquireErr == scheduler.ErrNoInsAvailable && !insAcqReq.SkipWaitPending && scs.shouldTriggerColdStart() { + // if this scheduler is not the funcOwner, only scale in the case of traffic limit of the original funcOwner + if !scs.IsFuncOwner() && !insAcqReq.TrafficLimited { + return nil, acquireErr + } + pendingRequest := &requestqueue.PendingInsAcqReq{ + CreatedTime: time.Now(), + ResultChan: make(chan *requestqueue.PendingInsAcqRsp, 1), + InsAcqReq: insAcqReq, + } + if err := scs.insAcqReqQueue.AddRequest(pendingRequest); err != nil { + return nil, err + } + scs.publishInsThdEvent(scheduler.TriggerScaleTopic, nil) + insAcqRsp := <-pendingRequest.ResultChan + return insAcqRsp.InsAlloc, insAcqRsp.Error + } + return insAlloc, acquireErr +} + +// PopInstance pops an instance, set force to false may return nil if there is no instance with full concurrency +// available +func (scs *ScaledConcurrencyScheduler) PopInstance(force bool) *types.Instance { + var insElem *instanceElement + if force { + insElem = scs.popInstanceElement(backward, nil, false) + } else { + insElem = scs.popInstanceElement(backward, shouldPopScaledInstance, false) + } + + if insElem == nil { + return nil + } + return insElem.instance +} + +// ConnectWithInstanceScaler connects instanceScheduler with an instanceScaler +func (scs *ScaledConcurrencyScheduler) ConnectWithInstanceScaler(instanceScaler scaler.InstanceScaler) { + // check if instanceScaler is a concurrencyInstanceScaler type in future and return error if otherwise + scs.addObservers(scheduler.TriggerScaleTopic, func(data interface{}) { + instanceScaler.TriggerScale() + }) + scs.addObservers(scheduler.InUseInsThdTopic, func(data interface{}) { + inUsedInsThdDiff, ok := data.(int) + if !ok { + return + } + instanceScaler.HandleInsThdUpdate(inUsedInsThdDiff, 0) + }) + scs.addObservers(scheduler.TotalInsThdTopic, func(data interface{}) { + totalInsThdDiff, ok := data.(int) + if !ok { + return + } + instanceScaler.HandleInsThdUpdate(0, totalInsThdDiff) + scs.insAcqReqQueue.HandleInsNumUpdate(totalInsThdDiff / scs.concurrentNum) + }) + return +} + +// HandleFuncSpecUpdate handles funcSpec update comes from ETCD +func (scs *ScaledConcurrencyScheduler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { + scs.Lock() + concurrentNum := utils.GetConcurrentNum(funcSpec.InstanceMetaData.ConcurrentNum) + if scs.concurrentNum != concurrentNum { + scs.concurrentNum = concurrentNum + scs.selfInstanceQueue = &instanceQueueWithObserver{ + instanceQueueWithSubHealthAndEvictingRecord: instanceQueueWithSubHealthAndEvictingRecord{ + instanceQueue: &instanceQueueWithBuffer{ + queue: queue.NewPriorityQueue(getInstanceID, priorityFuncForScaledInstance(scs.concurrentNum)), + buffer: make([]*instanceElement, 0, utils.DefaultSliceSize), + idFunc: getInstanceID, + }, + subHealthRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + evictingRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + }, + insAvailThdCount: make(map[string]int, utils.DefaultMapSize), + pubAvailTopicFunc: func(data int) { scs.publishInsThdEvent(scheduler.AvailInsThdTopic, data) }, + pubInUseTopicFunc: func(data int) { scs.publishInsThdEvent(scheduler.InUseInsThdTopic, data) }, + pubTotalTopicFunc: func(data int) { scs.publishInsThdEvent(scheduler.TotalInsThdTopic, data) }, + } + scs.otherInstanceQueue = &instanceQueueWithSubHealthAndEvictingRecord{ + instanceQueue: &instanceQueueWithBuffer{ + queue: queue.NewPriorityQueue(getInstanceID, priorityFuncForScaledInstance(scs.concurrentNum)), + buffer: make([]*instanceElement, 0, utils.DefaultSliceSize), + idFunc: getInstanceID, + }, + subHealthRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + evictingRecord: make(map[string]*instanceElement, utils.DefaultMapSize), + } + } + scs.Unlock() +} + +// HandleCreateError handles create error +func (scs *ScaledConcurrencyScheduler) HandleCreateError(createErr error) { + scs.publishInsThdEvent(scheduler.CreateErrorTopic, createErr) +} + +// Destroy destroys instanceScheduler +func (scs *ScaledConcurrencyScheduler) Destroy() { + scs.basicConcurrencyScheduler.Destroy() +} + +// priorityFuncForScaledInstance will create priority function which put instance with less concurrentNum in front. for +// scaled instance, aggregate requests on busy instances will benefit the scaled down process which improves resource +// utilization +func priorityFuncForScaledInstance(concurrency int) func(obj interface{}) (int, error) { + return func(obj interface{}) (int, error) { + insElem, ok := obj.(*instanceElement) + if ok { + weight := concurrency - len(insElem.threadMap) + if !insElem.isNewInstance { + weight -= concurrency + } + return weight, nil + } + return -1, scheduler.ErrTypeConvertFail + } +} + +func shouldPopScaledInstance(insElem *instanceElement) bool { + return len(insElem.threadMap) == insElem.instance.ConcurrentNum +} diff --git a/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/scaled_concurrency_scheduler_test.go b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/scaled_concurrency_scheduler_test.go new file mode 100644 index 0000000..52e351a --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/concurrencyscheduler/scaled_concurrency_scheduler_test.go @@ -0,0 +1,274 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package concurrencyscheduler - +package concurrencyscheduler + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/queue" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + commontypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/requestqueue" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +func TestInstanceQueueWithBuffer(t *testing.T) { + insQue := &instanceQueueWithBuffer{ + queue: queue.NewPriorityQueue(getInstanceID, priorityFuncForScaledInstance(2)), + buffer: make([]*instanceElement, 0, utils.DefaultSliceSize), + idFunc: getInstanceID, + } + instance1 := &types.Instance{InstanceID: "instance1", ConcurrentNum: 2} + insElem1 := &instanceElement{ + instance: instance1, + } + insElem1.initThreadMap() + insQue.PushBack(insElem1) + instance2 := &types.Instance{InstanceID: "instance2", ConcurrentNum: 2} + insElem2 := &instanceElement{ + instance: instance2, + } + insElem2.initThreadMap() + insQue.PushBack(insElem2) + assert.Equal(t, 2, insQue.Len()) + insQue.DelByID(instance1.InstanceID) + insList := make([]string, 0, 2) + insQue.Range(func(obj interface{}) bool { + insList = append(insList, obj.(*instanceElement).instance.InstanceID) + return true + }) + assert.Contains(t, insList, instance2.InstanceID) + assert.NotContains(t, insList, instance1.InstanceID) + insQue.SortedRange(func(obj interface{}) bool { + insList = append(insList, obj.(*instanceElement).instance.InstanceID) + return true + }) + assert.Contains(t, insList, instance2.InstanceID) + assert.NotContains(t, insList, instance1.InstanceID) + insElem1 = &instanceElement{ + instance: instance1, + } + insElem1.initThreadMap() + insQue.PushBack(insElem1) + insQue.UpdateObjByID("instance1", &instanceElement{ + instance: instance1, + threadMap: make(map[string]struct{}, 0), + }) + popIns1 := insQue.PopFront().(*instanceElement) + assert.Equal(t, instance2.InstanceID, popIns1.instance.InstanceID) + popIns2 := insQue.PopFront().(*instanceElement) + assert.Equal(t, instance1.InstanceID, popIns2.instance.InstanceID) +} + +func TestNewScaledConcurrencyScheduler(t *testing.T) { + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("testFunction", 1000*time.Millisecond) + scs := NewScaledConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 1}, + }, resspeckey.ResSpecKey{}, InsThdReqQueue) + assert.NotNil(t, scs) +} + +func TestGetReqQueLen(t *testing.T) { + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("testFunction", 1000*time.Millisecond) + scs := NewScaledConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, InsThdReqQueue).(*ScaledConcurrencyScheduler) + assert.Equal(t, 0, scs.GetReqQueLen()) +} + +func TestAcquireInstanceScaled(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(&selfregister.SchedulerProxy{}), "CheckFuncOwner", func( + *selfregister.SchedulerProxy, string) bool { + return true + }).Reset() + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("testFunction", 1000*time.Millisecond) + scs := NewScaledConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, InsThdReqQueue) + index := 0 + scs.ConnectWithInstanceScaler(&fakeInstanceScaler{ + scaleUpFunc: func() { + index++ + fmt.Printf("fakeInstanceScaler add instance %d start\n", index) + scs.AddInstance(&types.Instance{ + InstanceID: fmt.Sprintf("instance%d", index), + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commontypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + fmt.Printf("fakeInstanceScaler add instance %d finish\n", index) + }, + }) + _, err := scs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + assert.Equal(t, scheduler.ErrInsNotExist, err) + acqIns1, err := scs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns1.Instance.InstanceID) + acqIns2, err := scs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns2.Instance.InstanceID) + scs.HandleCreateError(nil) + acqIns3, err := scs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns3.Instance.InstanceID) + scs.ReleaseInstance(&types.InstanceAllocation{Instance: &types.Instance{InstanceID: "instance1"}}) + scs.ReleaseInstance(&types.InstanceAllocation{Instance: &types.Instance{InstanceID: "instance2"}}) + acqIns4, err := scs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns4.Instance.InstanceID) + scs.PopInstance(true) + scs.PopInstance(true) + snErr := snerror.New(4001, "some error") + scs.HandleCreateError(snErr) + _, err = scs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, snErr, err) +} + +func TestPopInstanceScaled(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(&selfregister.SchedulerProxy{}), "CheckFuncOwner", func( + *selfregister.SchedulerProxy, string) bool { + return true + }).Reset() + config.GlobalConfig.AutoScaleConfig.BurstScaleNum = 1000 + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("testFunction", 50*time.Millisecond) + scs := NewScaledConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, InsThdReqQueue).(*ScaledConcurrencyScheduler) + scs.ConnectWithInstanceScaler(&fakeInstanceScaler{}) + scs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commontypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + scs.AcquireInstance(&types.InstanceAcquireRequest{}) + popIns1 := scs.PopInstance(false) + assert.Nil(t, popIns1) + scs.AcquireInstance(&types.InstanceAcquireRequest{}) + popIns2 := scs.PopInstance(true) + assert.Equal(t, "instance1", popIns2.InstanceID) +} + +func TestHandleFuncSpecUpdateScaled(t *testing.T) { + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("testFunction", 50*time.Millisecond) + rcs := NewScaledConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, InsThdReqQueue) + rcs.ConnectWithInstanceScaler(&fakeInstanceScaler{}) + rcs.HandleFuncSpecUpdate(&types.FunctionSpecification{ + InstanceMetaData: commontypes.InstanceMetaData{ + ConcurrentNum: 4, + }, + }) +} + +func TestAddInstancePublishScaled(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(&selfregister.SchedulerProxy{}), "CheckFuncOwner", func( + *selfregister.SchedulerProxy, string) bool { + return true + }).Reset() + config.GlobalConfig.AutoScaleConfig.BurstScaleNum = 1000 + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("testFunction", 50*time.Millisecond) + insThdReq1 := &requestqueue.PendingInsAcqReq{ + ResultChan: make(chan *requestqueue.PendingInsAcqRsp, 1), + InsAcqReq: &types.InstanceAcquireRequest{}, + } + insThdReq2 := &requestqueue.PendingInsAcqReq{ + ResultChan: make(chan *requestqueue.PendingInsAcqRsp, 1), + InsAcqReq: &types.InstanceAcquireRequest{}, + } + InsThdReqQueue.AddRequest(insThdReq1) + InsThdReqQueue.AddRequest(insThdReq2) + rcs := NewScaledConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, InsThdReqQueue) + rcs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commontypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + }) + time.Sleep(time.Millisecond) + select { + case insThd := <-insThdReq1.ResultChan: + assert.Equal(t, "instance1", insThd.InsAlloc.Instance.InstanceID) + default: + t.Errorf("should get instance from result channel") + } + select { + case insThd := <-insThdReq2.ResultChan: + assert.Equal(t, "instance1", insThd.InsAlloc.Instance.InstanceID) + default: + t.Errorf("should get instance from result channel") + } +} + +func TestDestroy(t *testing.T) { + InsThdReqQueue := requestqueue.NewInsAcqReqQueue("testFunction", 50*time.Millisecond) + rcs := NewScaledConcurrencyScheduler(&types.FunctionSpecification{ + FuncKey: "testFunction", + InstanceMetaData: commontypes.InstanceMetaData{ConcurrentNum: 2}, + }, resspeckey.ResSpecKey{}, InsThdReqQueue) + rcs.Destroy() +} + +func Test_instanceQueueWithBuffer_DelByID(t *testing.T) { + convey.Convey("test DelByID", t, func() { + insQue := &instanceQueueWithBuffer{ + queue: queue.NewPriorityQueue(getInstanceID, priorityFuncForScaledInstance(2)), + buffer: make([]*instanceElement, 0, utils.DefaultSliceSize), + idFunc: getInstanceID, + } + insQue.PushBack(&instanceElement{ + instance: &types.Instance{ + InstanceID: "2", + }, + threadMap: nil, + }) + convey.So(insQue.DelByID("1"), convey.ShouldNotBeNil) + insQue.PushBack(&instanceElement{ + instance: &types.Instance{ + InstanceID: "1", + }, + threadMap: nil, + }) + convey.So(insQue.DelByID("1"), convey.ShouldBeNil) + }) + +} diff --git a/yuanrong/pkg/functionscaler/scheduler/instance_scheduler.go b/yuanrong/pkg/functionscaler/scheduler/instance_scheduler.go new file mode 100644 index 0000000..a7d2ae9 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/instance_scheduler.go @@ -0,0 +1,89 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package scheduler - +package scheduler + +import ( + "errors" + + "yuanrong/pkg/functionscaler/scaler" + "yuanrong/pkg/functionscaler/types" +) + +// InstanceTopic defines topic in instanceScheduler +type InstanceTopic string + +const ( + // AvailInsThdTopic is the topic of available instance thread + AvailInsThdTopic InstanceTopic = "availInsThd" + // InUseInsThdTopic is the topic of in-use instance thread + InUseInsThdTopic InstanceTopic = "inUseInsThd" + // TotalInsThdTopic is the topic of total instance thread + TotalInsThdTopic InstanceTopic = "totalInsThd" + // TriggerScaleTopic is the topic of trigger scale + TriggerScaleTopic InstanceTopic = "triggerScale" + // CreateErrorTopic is the topic of instance thread create error + CreateErrorTopic InstanceTopic = "createErr" +) + +var ( + // ErrUnsupported is the error of operation unsupported + ErrUnsupported = errors.New("operation unsupported") + // ErrInternal is the error of internal error + ErrInternal = errors.New("internal error") + // ErrTypeConvertFail is the error of type convert failed + ErrTypeConvertFail = errors.New("type convert failed") + // ErrInsNotExist is the error of instance does not exist + ErrInsNotExist = errors.New("instance does not exist in queue") + // ErrInsAlreadyExist is the error of instance already exist + ErrInsAlreadyExist = errors.New("instance already exist in queue") + // ErrInsSubHealthy is the error of instance already exist + ErrInsSubHealthy = errors.New("instance is not healthy") + // ErrNoInsAvailable is the error of no instance available + ErrNoInsAvailable = errors.New("no instance available") + // ErrDesignateInsNotAvailable is the error of designateInstance not available + ErrDesignateInsNotAvailable = errors.New("designateInstance not available") + // ErrInsReqTimeout is the error of no instance request timeout + ErrInsReqTimeout = errors.New("instance request timeout") + // ErrInvalidSession is the error of invalid session parameter + ErrInvalidSession = errors.New("invalid session parameter") + // ErrFuncSigMismatch is the error of "function signature mismatch" + ErrFuncSigMismatch = errors.New("function signature mismatch") + // ErrFunctionDeleted is the error of "function is deleted" + ErrFunctionDeleted = errors.New("function is deleted") +) + +// SignalInstanceFunc sends certain signal to instance +type SignalInstanceFunc func(*types.Instance) + +// InstanceScheduler schedules instance +type InstanceScheduler interface { + GetInstanceNumber(onlySelf bool) int + AcquireInstance(insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, error) + ReleaseInstance(insAlloc *types.InstanceAllocation) error + AddInstance(instance *types.Instance) error + DelInstance(instance *types.Instance) error + PopInstance(force bool) *types.Instance + ConnectWithInstanceScaler(instanceScaler scaler.InstanceScaler) + HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) + HandleInstanceUpdate(instance *types.Instance) + HandleFuncOwnerUpdate(isFuncOwner bool) + HandleCreateError(createErr error) + SignalAllInstances(signalFunc SignalInstanceFunc) + ReassignInstanceWhenGray(ratio int) + Destroy() +} diff --git a/yuanrong/pkg/functionscaler/scheduler/microservicescheduler/microservice_scheduler.go b/yuanrong/pkg/functionscaler/scheduler/microservicescheduler/microservice_scheduler.go new file mode 100644 index 0000000..531efb0 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/microservicescheduler/microservice_scheduler.go @@ -0,0 +1,302 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package microservicescheduler - +package microservicescheduler + +import ( + "fmt" + "math" + "sync" + + "go.uber.org/zap" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/scaler" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + randomThreadIDLength = 4 + // LeastConnections - + LeastConnections = "LC" + // RoundRobin - + RoundRobin = "RR" +) + +type instanceObserver struct { + callback func(interface{}) +} + +type instanceElement struct { + instance *types.Instance + requestCount int +} + +// MicroServiceScheduler will schedule instance according to concurrency usage for reserved instance +type MicroServiceScheduler struct { + instanceScaler scaler.InstanceScaler + instanceQueue []*instanceElement + funcKeyWithRes string + SchedulePolicy string + curIndex int + sync.RWMutex +} + +// NewMicroServiceScheduler creates MicroServiceScheduler +func NewMicroServiceScheduler(funcKeyWithRes string, schedulePolicy string) scheduler.InstanceScheduler { + microServiceScheduler := &MicroServiceScheduler{ + funcKeyWithRes: funcKeyWithRes, + instanceQueue: make([]*instanceElement, 0, utils.DefaultSliceSize), + SchedulePolicy: schedulePolicy, + curIndex: 0, + } + return microServiceScheduler +} + +// GetInstanceNumber gets instance number inside instance queue +func (ms *MicroServiceScheduler) GetInstanceNumber(onlySelf bool) int { + ms.RLock() + insNum := len(ms.instanceQueue) + ms.RUnlock() + return insNum +} + +// CheckInstanceExist checks if instance exist +func (ms *MicroServiceScheduler) CheckInstanceExist(instance *types.Instance) bool { + ms.RLock() + exist := ms.getInstance(instance.InstanceID) != nil + ms.RUnlock() + return exist +} + +// AcquireInstance acquires an instance chosen by instanceScheduler +func (ms *MicroServiceScheduler) AcquireInstance(insAcqReq *types.InstanceAcquireRequest) (*types.InstanceAllocation, + error) { + ms.Lock() + defer ms.Unlock() + var ( + instance *types.Instance + err error + ) + if len(ms.instanceQueue) == 0 { + return nil, scheduler.ErrNoInsAvailable + } + switch ms.SchedulePolicy { + case RoundRobin: + instance, err = ms.selectInstanceWithRR() + case LeastConnections: + instance, err = ms.selectInstanceWithLC() + default: + instance, err = ms.selectInstanceWithRR() + } + if err != nil { + return nil, err + } + if instance == nil { + return nil, scheduler.ErrNoInsAvailable + } + return &types.InstanceAllocation{ + Instance: instance, + AllocationID: fmt.Sprintf("%s-%s", instance.InstanceID, utils.GenRandomString(randomThreadIDLength)), + }, nil +} + +// ReleaseInstance releases an instance to instanceScheduler +func (ms *MicroServiceScheduler) ReleaseInstance(thread *types.InstanceAllocation) error { + ms.Lock() + defer ms.Unlock() + if ms.SchedulePolicy == LeastConnections { + for _, element := range ms.instanceQueue { + if element.instance.InstanceID == thread.Instance.InstanceID { + if element.requestCount > 0 { + element.requestCount-- + } + break + } + } + } + return nil +} + +// AddInstance adds an instance to instanceScheduler +func (ms *MicroServiceScheduler) AddInstance(instance *types.Instance) error { + ms.Lock() + defer ms.Unlock() + if getInstance := ms.getInstance(instance.InstanceID); getInstance != nil { + return scheduler.ErrInsAlreadyExist + } + switch instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusRunning): + ms.addInstance(instance) + default: + log.GetLogger().Warnf("ignore unexpected instance %s with status code %d", instance.InstanceID, + instance.InstanceStatus.Code) + return scheduler.ErrInternal + } + return nil +} + +// PopInstance pops an instance from instanceScheduler +func (ms *MicroServiceScheduler) PopInstance(force bool) *types.Instance { + ms.Lock() + defer ms.Unlock() + var instance *types.Instance + if instance = ms.popInstance(); instance == nil { + return nil + } + return instance +} + +// DelInstance deletes an instance from instanceScheduler +func (ms *MicroServiceScheduler) DelInstance(instance *types.Instance) error { + ms.Lock() + defer ms.Unlock() + var ( + // concurrentNum may be updated and this instance could be an old one, so we get inQueIns from internal + inQueIns *types.Instance + ) + if inQueIns = ms.getInstance(instance.InstanceID); inQueIns == nil { + return scheduler.ErrInsNotExist + } + ms.delInstance(instance.InstanceID) + return nil +} + +// ConnectWithInstanceScaler connects instanceScheduler with an instanceScaler, currently connects with replicaScaler +func (ms *MicroServiceScheduler) ConnectWithInstanceScaler(instanceScaler scaler.InstanceScaler) { + return +} + +// HandleFuncSpecUpdate handles funcSpec update comes from ETCD +func (ms *MicroServiceScheduler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { +} + +// HandleInstanceUpdate handles instance update comes from ETCD +func (ms *MicroServiceScheduler) HandleInstanceUpdate(instance *types.Instance) { + logger := log.GetLogger().With(zap.Any("funcKey", ms.funcKeyWithRes), zap.Any("instance", instance.InstanceID), + zap.Any("instanceStatus", instance.InstanceStatus.Code)) + ms.RLock() + exist := ms.getInstance(instance.InstanceID) != nil + ms.RUnlock() + switch instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusRunning): + if exist { + logger.Warnf("no need to add instance repeatedly") + } else { + logger.Infof("add new instance from instance update") + if err := ms.AddInstance(instance); err != nil { + logger.Errorf("failed to add instance error %s", err.Error()) + } + } + case int32(constant.KernelInstanceStatusEvicting): + if exist { + err := ms.DelInstance(instance) + if err != nil { + logger.Errorf("failed to delete evicting instance %s", err.Error()) + } + } + default: + } +} + +// HandleCreateError handles create error +func (ms *MicroServiceScheduler) HandleCreateError(createErr error) { +} + +// SignalAllInstances sends signal to all instances +func (ms *MicroServiceScheduler) SignalAllInstances(signalFunc scheduler.SignalInstanceFunc) { +} + +// Destroy destroys instanceScheduler +func (ms *MicroServiceScheduler) Destroy() { +} + +func (ms *MicroServiceScheduler) getInstance(targetID string) *types.Instance { + for _, element := range ms.instanceQueue { + if element.instance.InstanceID == targetID { + return element.instance + } + } + return nil +} + +func (ms *MicroServiceScheduler) addInstance(instance *types.Instance) { + ms.instanceQueue = append(ms.instanceQueue, &instanceElement{instance: instance}) +} + +func (ms *MicroServiceScheduler) delInstance(targetID string) { + targetIndex := -1 + for index, element := range ms.instanceQueue { + if element.instance.InstanceID == targetID { + targetIndex = index + break + } + } + if targetIndex == -1 { + return + } + ms.instanceQueue = append(ms.instanceQueue[0:targetIndex], ms.instanceQueue[targetIndex+1:]...) +} + +func (ms *MicroServiceScheduler) popInstance() *types.Instance { + queLen := len(ms.instanceQueue) + if queLen == 0 { + return nil + } + element := ms.instanceQueue[queLen-1] + ms.instanceQueue = ms.instanceQueue[:queLen-1] + return element.instance +} + +func (ms *MicroServiceScheduler) selectInstanceWithLC() (*types.Instance, error) { + var ( + chosenIns *instanceElement + chosenReqNum = math.MaxInt32 + ) + for _, element := range ms.instanceQueue { + if element.requestCount < chosenReqNum { + chosenReqNum = element.requestCount + chosenIns = element + } + } + if chosenIns != nil { + chosenIns.requestCount++ + return chosenIns.instance, nil + } + return nil, scheduler.ErrNoInsAvailable +} + +func (ms *MicroServiceScheduler) selectInstanceWithRR() (*types.Instance, error) { + if ms.curIndex >= len(ms.instanceQueue) { + ms.curIndex = ms.curIndex % len(ms.instanceQueue) + } + element := ms.instanceQueue[ms.curIndex] + ms.curIndex++ + return element.instance, nil +} + +// HandleFuncOwnerUpdate - +func (ms *MicroServiceScheduler) HandleFuncOwnerUpdate(isFuncOwner bool) { +} + +// ReassignInstanceWhenGray - +func (ms *MicroServiceScheduler) ReassignInstanceWhenGray(ratio int) { + return +} diff --git a/yuanrong/pkg/functionscaler/scheduler/microservicescheduler/microservice_scheduler_test.go b/yuanrong/pkg/functionscaler/scheduler/microservicescheduler/microservice_scheduler_test.go new file mode 100644 index 0000000..3910db0 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/microservicescheduler/microservice_scheduler_test.go @@ -0,0 +1,307 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package leastconnectionscheduler - +package microservicescheduler + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/resspeckey" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/types" +) + +type fakeInstanceScaler struct { + timer *time.Timer + expectInsNum int + inUseInsThdNum int + totalInsThdNum int + scaling bool + createErr error + scaleUpFunc func() +} + +func (f *fakeInstanceScaler) SetEnable(enable bool) { +} + +func (f *fakeInstanceScaler) TriggerScale() { + go func() { + time.Sleep(10 * time.Millisecond) + f.scaleUpFunc() + }() +} + +func (f *fakeInstanceScaler) CheckScaling() bool { + if f.timer == nil { + return false + } + select { + case <-f.timer.C: + f.scaling = false + return false + default: + return f.scaling + } +} + +func (f *fakeInstanceScaler) UpdateCreateMetrics(coldStartTime time.Duration) { +} + +func (f *fakeInstanceScaler) HandleInsThdUpdate(inUseInsThdDiff, totalInsThdDiff int) { + f.inUseInsThdNum += inUseInsThdDiff + f.totalInsThdNum += totalInsThdDiff +} + +func (f *fakeInstanceScaler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { +} + +func (f *fakeInstanceScaler) HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) { +} + +func (f *fakeInstanceScaler) HandleCreateError(createError error) { + f.createErr = createError +} + +func (f *fakeInstanceScaler) GetExpectInstanceNumber() int { + return f.expectInsNum +} + +func (f *fakeInstanceScaler) Destroy() { +} + +func (f *fakeInstanceScaler) SetFuncOwner(isManaged bool) { + +} + +func TestNewMicroServiceScheduler(t *testing.T) { + ms := NewMicroServiceScheduler("testFunction", RoundRobin) + assert.NotNil(t, ms) +} + +func TestGetInstanceNumber(t *testing.T) { + ms := NewMicroServiceScheduler("testFunction", RoundRobin) + ms.AddInstance(&types.Instance{ + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Equal(t, 1, ms.GetInstanceNumber(true)) +} + +func TestAcquireInstanceWithRR(t *testing.T) { + ms := NewMicroServiceScheduler("testFunction", RoundRobin) + _, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, scheduler.ErrNoInsAvailable, err) + ms.AddInstance(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + ms.AddInstance(&types.Instance{ + InstanceID: "instance2", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + ms.AddInstance(&types.Instance{ + InstanceID: "instance3", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + acqIns1, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns1.Instance.InstanceID) + acqIns2, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns2.Instance.InstanceID) + acqIns3, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance3", acqIns3.Instance.InstanceID) + acqIns4, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns4.Instance.InstanceID) + acqIns5, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns5.Instance.InstanceID) + acqIns6, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance3", acqIns6.Instance.InstanceID) +} + +func TestAcquireInstanceWithLeastConnections(t *testing.T) { + ms := NewMicroServiceScheduler("testFunction", LeastConnections) + _, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, scheduler.ErrNoInsAvailable, err) + ms.AddInstance(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + ms.AddInstance(&types.Instance{ + InstanceID: "instance2", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + ms.AddInstance(&types.Instance{ + InstanceID: "instance3", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + acqIns1, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns1.Instance.InstanceID) + acqIns2, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns2.Instance.InstanceID) + acqIns3, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance3", acqIns3.Instance.InstanceID) + ms.ReleaseInstance(acqIns3) + acqIns4, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance3", acqIns4.Instance.InstanceID) + ms.ReleaseInstance(acqIns2) + acqIns5, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns5.Instance.InstanceID) + ms.ReleaseInstance(acqIns1) + acqIns6, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns6.Instance.InstanceID) +} + +func TestAddInstance(t *testing.T) { + ms := NewMicroServiceScheduler("testFunction", RoundRobin).(*MicroServiceScheduler) + err := ms.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Equal(t, scheduler.ErrInternal, err) + err = ms.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + assert.Nil(t, err) + err = ms.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + assert.Equal(t, scheduler.ErrInsAlreadyExist, err) + err = ms.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + assert.Nil(t, err) +} + +func TestPopInstance(t *testing.T) { + ms := NewMicroServiceScheduler("testFunction", RoundRobin).(*MicroServiceScheduler) + popIns1 := ms.PopInstance(false) + assert.Nil(t, popIns1) + ms.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + ms.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + popIns3 := ms.PopInstance(false) + assert.Equal(t, "instance2", popIns3.InstanceID) + popIns4 := ms.PopInstance(false) + assert.Equal(t, "instance1", popIns4.InstanceID) +} + +func TestDelInstance(t *testing.T) { + ms := NewMicroServiceScheduler("testFunction", RoundRobin).(*MicroServiceScheduler) + err := ms.DelInstance(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Equal(t, scheduler.ErrInsNotExist, err) + ms.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + ms.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + err = ms.DelInstance(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Nil(t, err) + err = ms.DelInstance(&types.Instance{ + InstanceID: "instance2", + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Nil(t, err) +} + +func TestHandleInstanceUpdate(t *testing.T) { + ms := NewMicroServiceScheduler("testFunction", RoundRobin) + fs := &fakeInstanceScaler{} + ms.ConnectWithInstanceScaler(fs) + ms.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + }) + acqIns1, err := ms.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns1.Instance.InstanceID) +} + +func TestCheckInstanceExist(t *testing.T) { + ms := NewMicroServiceScheduler("testFunction", RoundRobin).(*MicroServiceScheduler) + instance := &types.Instance{ + InstanceID: "instance1", + } + // 测试实例不存在的情况 + assert.False(t, ms.CheckInstanceExist(instance)) + ms.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + + // 测试实例存在的情况 + assert.True(t, ms.CheckInstanceExist(instance)) +} diff --git a/yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler/roundrobin_scheduler.go b/yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler/roundrobin_scheduler.go new file mode 100644 index 0000000..91df007 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler/roundrobin_scheduler.go @@ -0,0 +1,412 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package roundrobinscheduler - +package roundrobinscheduler + +import ( + "fmt" + "os" + "sync" + "time" + + "go.uber.org/zap" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/uuid" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/rollout" + "yuanrong/pkg/functionscaler/scaler" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/signalmanager" + "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/functionscaler/utils" +) + +const ( + randomAllocationIDLength = 4 + checkScalingInterval = 50 * time.Millisecond +) + +type instanceObserver struct { + callback func(interface{}) +} + +// RoundRobinScheduler will schedule instance according to concurrency usage for reserved instance +type RoundRobinScheduler struct { + instanceScaler scaler.InstanceScaler + instanceQueue []*types.Instance + subHealthInstance map[string]*types.Instance + observers map[scheduler.InstanceTopic][]*instanceObserver + funcKeyWithRes string + curIndex int + isReserve bool + checkScalingTimeout time.Duration + sync.RWMutex +} + +// NewRoundRobinScheduler creates RoundRobinScheduler +func NewRoundRobinScheduler(funcKeyWithRes string, isReserve bool, + requestTimeout time.Duration) scheduler.InstanceScheduler { + return &RoundRobinScheduler{ + funcKeyWithRes: funcKeyWithRes, + instanceQueue: make([]*types.Instance, 0, utils.DefaultSliceSize), + subHealthInstance: make(map[string]*types.Instance, utils.DefaultMapSize), + observers: make(map[scheduler.InstanceTopic][]*instanceObserver, utils.DefaultMapSize), + curIndex: 0, + isReserve: isReserve, + checkScalingTimeout: requestTimeout, + } +} + +// GetInstanceNumber gets instance number inside instance queue +func (rs *RoundRobinScheduler) GetInstanceNumber(onlySelf bool) int { + rs.RLock() + insNum := len(rs.instanceQueue) + len(rs.subHealthInstance) + rs.RUnlock() + return insNum +} + +// CheckInstanceExist checks if instance exist +func (rs *RoundRobinScheduler) CheckInstanceExist(instance *types.Instance) bool { + rs.RLock() + _, existSubHealth := rs.subHealthInstance[instance.InstanceID] + existHealth := rs.getHealthyInstance(instance.InstanceID) != nil + rs.RUnlock() + return existSubHealth || existHealth +} + +// AcquireInstance acquires an instance chosen by instanceScheduler +func (rs *RoundRobinScheduler) AcquireInstance(insAcqReq *types.InstanceAcquireRequest) ( + *types.InstanceAllocation, error) { + var ( + instance *types.Instance + exist bool + ) + rs.Lock() + if len(insAcqReq.DesignateInstanceID) != 0 { + insAlloc, err := rs.acquireInstanceDesignateInstanceID(insAcqReq, instance, exist) + rs.Unlock() + return insAlloc, err + } + rs.Unlock() + if rs.instanceScaler.GetExpectInstanceNumber() <= 0 { + // 灰度状态下,新的scheduler不应该触发冷启动,应该快速返回失败 + selfCurVer := os.Getenv(selfregister.CurrentVersionEnvKey) + etcdCurVer := rollout.GetGlobalRolloutHandler().CurrentVersion + if selfCurVer != etcdCurVer && rollout.GetGlobalRolloutHandler().GetCurrentRatio() != 100 { // 100 mean 100% + return nil, scheduler.ErrNoInsAvailable + } + // 灰度状态到100%时,老的scheduler不应该负责冷启动,应该快速返回失败 + if selfCurVer == etcdCurVer && rollout.GetGlobalRolloutHandler().GetCurrentRatio() == 100 { // 100 mean 100% + return nil, scheduler.ErrNoInsAvailable + } + // 这里如果是静态函数,则会触发到wisecloudscaler,触发一次nuwa cold start,如果不是静态函数,则会走到replicascaler,没有其他影响 + rs.publishInsThdEvent(scheduler.TriggerScaleTopic, nil) + } + if config.GlobalConfig.Scenario == types.ScenarioWiseCloud && !rs.isReserve { + return nil, scheduler.ErrNoInsAvailable + } + for checkTime := time.Duration(0); checkTime <= rs.checkScalingTimeout; checkTime += checkScalingInterval { + rs.Lock() + currentInstanceNum := len(rs.instanceQueue) + len(rs.subHealthInstance) + rs.Unlock() + if currentInstanceNum > 0 { + break + } + time.Sleep(checkScalingInterval) + } + rs.Lock() + defer rs.Unlock() + if len(rs.instanceQueue) == 0 { + return nil, scheduler.ErrNoInsAvailable + } + instance = rs.selectHealthyInstance() + return &types.InstanceAllocation{ + Instance: instance, + AllocationID: fmt.Sprintf("%s-%d-%s", instance.InstanceID, time.Now().UnixMilli(), uuid.New().String()), + }, nil +} + +func (rs *RoundRobinScheduler) acquireInstanceDesignateInstanceID(insAcqReq *types.InstanceAcquireRequest, + instance *types.Instance, exist bool) (*types.InstanceAllocation, error) { + if instance, exist = rs.subHealthInstance[insAcqReq.DesignateInstanceID]; exist { + return nil, scheduler.ErrInsSubHealthy + } + if instance = rs.getHealthyInstance(insAcqReq.DesignateInstanceID); instance != nil { + return &types.InstanceAllocation{ + Instance: instance, + AllocationID: fmt.Sprintf("%s-%d-%s", instance.InstanceID, time.Now().UnixMilli(), + uuid.New().String()), + }, nil + } + return nil, scheduler.ErrInsNotExist +} + +// ReleaseInstance releases an instance to instanceScheduler +func (rs *RoundRobinScheduler) ReleaseInstance(insAlloc *types.InstanceAllocation) error { + return nil +} + +// AddInstance adds an instance to instanceScheduler +func (rs *RoundRobinScheduler) AddInstance(instance *types.Instance) error { + rs.Lock() + defer rs.Unlock() + if _, exist := rs.subHealthInstance[instance.InstanceID]; exist { + return scheduler.ErrInsAlreadyExist + } + if getInstance := rs.getHealthyInstance(instance.InstanceID); getInstance != nil { + return scheduler.ErrInsAlreadyExist + } + switch instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusSubHealth): + rs.subHealthInstance[instance.InstanceID] = instance + case int32(constant.KernelInstanceStatusRunning): + rs.addHealthyInstance(instance) + default: + log.GetLogger().Warnf("ignore unexpected instance %s with status code %d", instance.InstanceID, + instance.InstanceStatus.Code) + return scheduler.ErrInternal + } + + // notify observers that total thread number increases + rs.publishInsThdEvent(scheduler.TotalInsThdTopic, instance.ConcurrentNum) + return nil +} + +// PopInstance pops an instance from instanceScheduler +func (rs *RoundRobinScheduler) PopInstance(force bool) *types.Instance { + rs.Lock() + defer rs.Unlock() + var instance *types.Instance + if instance = rs.selectSubHealthyInstance(); instance != nil { + delete(rs.subHealthInstance, instance.InstanceID) + } else { + if instance = rs.popHealthyInstance(); instance == nil { + return nil + } + } + rs.publishInsThdEvent(scheduler.TotalInsThdTopic, -instance.ConcurrentNum) + return instance +} + +// DelInstance deletes an instance from instanceScheduler +func (rs *RoundRobinScheduler) DelInstance(instance *types.Instance) error { + rs.Lock() + defer rs.Unlock() + var ( + // concurrentNum may be updated and this instance could be an old one, so we get inQueIns from internal + inQueIns *types.Instance + exist bool + ) + if inQueIns, exist = rs.subHealthInstance[instance.InstanceID]; exist { + delete(rs.subHealthInstance, instance.InstanceID) + } else { + if inQueIns = rs.getHealthyInstance(instance.InstanceID); inQueIns == nil { + return scheduler.ErrInsNotExist + } + rs.delHealthyInstance(instance.InstanceID) + } + rs.publishInsThdEvent(scheduler.TotalInsThdTopic, -inQueIns.ConcurrentNum) + return nil +} + +// ConnectWithInstanceScaler connects instanceScheduler with an instanceScaler, currently connects with replicaScaler +func (rs *RoundRobinScheduler) ConnectWithInstanceScaler(instanceScaler scaler.InstanceScaler) { + rs.instanceScaler = instanceScaler + rs.addObservers(scheduler.TotalInsThdTopic, func(data interface{}) { + totalInsThdDiff, ok := data.(int) + if !ok { + return + } + instanceScaler.HandleInsThdUpdate(0, totalInsThdDiff) + }) + rs.addObservers(scheduler.CreateErrorTopic, func(data interface{}) { + createError, ok := data.(error) + if !ok { + return + } + instanceScaler.HandleCreateError(createError) + }) + return +} + +// HandleFuncSpecUpdate handles funcSpec update comes from ETCD +func (rs *RoundRobinScheduler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { +} + +// roundrobinscheduler不处理evicting状态实例, +var instanceStatusCodeMap = map[int32]struct{}{ + int32(constant.KernelInstanceStatusRunning): {}, + int32(constant.KernelInstanceStatusSubHealth): {}, +} + +// HandleInstanceUpdate handles instance update comes from ETCD +func (rs *RoundRobinScheduler) HandleInstanceUpdate(instance *types.Instance) { + logger := log.GetLogger().With(zap.Any("funcKey", rs.funcKeyWithRes), zap.Any("instance", instance.InstanceID), + zap.Any("instanceStatus", instance.InstanceStatus.Code)) + + if _, ok := instanceStatusCodeMap[instance.InstanceStatus.Code]; !ok { + logger.Infof("unexpect instance status, ignore it") + return + } + rs.RLock() + existHealth := rs.getHealthyInstance(instance.InstanceID) != nil + _, existSubHealth := rs.subHealthInstance[instance.InstanceID] + rs.RUnlock() + if !existHealth && !existSubHealth { + // 适配静态函数 + logger.Infof("update alias and faasscheduler event to instance") + signalmanager.GetSignalManager().SignalInstance(instance, constant.KillSignalAliasUpdate) + signalmanager.GetSignalManager().SignalInstance(instance, constant.KillSignalFaaSSchedulerUpdate) + } + + switch instance.InstanceStatus.Code { + case int32(constant.KernelInstanceStatusSubHealth): + if existHealth && !existSubHealth { + logger.Infof("instance transitions from healthy to sub-healthy") + rs.Lock() + rs.delHealthyInstance(instance.InstanceID) + rs.subHealthInstance[instance.InstanceID] = instance + rs.Unlock() + } else if !existHealth && existSubHealth { + logger.Warnf("no need to add sub-healthy instance repeatedly") + } else if !existHealth && !existSubHealth { + // maybe we should update this instance object to handle update inside an instance + logger.Infof("add new sub-healthy instance from instance update") + if err := rs.AddInstance(instance); err != nil { + logger.Errorf("failed to add instance error %s", err.Error()) + } + } + case int32(constant.KernelInstanceStatusRunning): + if existHealth && !existSubHealth { + logger.Warnf("no need to add healthy instance repeatedly") + } else if !existHealth && existSubHealth { + // maybe we should update this instance object to handle update inside an instance + rs.Lock() + delete(rs.subHealthInstance, instance.InstanceID) + rs.addHealthyInstance(instance) + rs.Unlock() + } else if !existHealth && !existSubHealth { + logger.Infof("add new healthy instance from instance update") + if err := rs.AddInstance(instance); err != nil { + logger.Errorf("failed to add instance error %s", err.Error()) + } + } + default: + } +} + +// HandleCreateError handles create error +func (rs *RoundRobinScheduler) HandleCreateError(createErr error) { + rs.publishInsThdEvent(scheduler.CreateErrorTopic, createErr) +} + +// SignalAllInstances sends signal to all instances +func (rs *RoundRobinScheduler) SignalAllInstances(signalFunc scheduler.SignalInstanceFunc) { + rs.RLock() + for _, instance := range rs.instanceQueue { + signalFunc(instance) + } + rs.RUnlock() +} + +// Destroy destroys instanceScheduler +func (rs *RoundRobinScheduler) Destroy() { +} + +func (rs *RoundRobinScheduler) getHealthyInstance(targetID string) *types.Instance { + for _, instance := range rs.instanceQueue { + if instance.InstanceID == targetID { + return instance + } + } + return nil +} + +func (rs *RoundRobinScheduler) addHealthyInstance(instance *types.Instance) { + rs.instanceQueue = append(rs.instanceQueue, instance) +} + +func (rs *RoundRobinScheduler) delHealthyInstance(targetID string) { + targetIndex := -1 + for index, instance := range rs.instanceQueue { + if instance.InstanceID == targetID { + targetIndex = index + break + } + } + if targetIndex == -1 { + return + } + rs.instanceQueue = append(rs.instanceQueue[0:targetIndex], rs.instanceQueue[targetIndex+1:]...) +} + +func (rs *RoundRobinScheduler) popHealthyInstance() *types.Instance { + queLen := len(rs.instanceQueue) + if queLen == 0 { + return nil + } + instance := rs.instanceQueue[queLen-1] + rs.instanceQueue = rs.instanceQueue[:queLen-1] + return instance +} + +func (rs *RoundRobinScheduler) selectHealthyInstance() *types.Instance { + if rs.curIndex >= len(rs.instanceQueue) { + rs.curIndex = rs.curIndex % len(rs.instanceQueue) + } + instance := rs.instanceQueue[rs.curIndex] + rs.curIndex++ + return instance +} + +func (rs *RoundRobinScheduler) selectSubHealthyInstance() *types.Instance { + for _, instance := range rs.subHealthInstance { + return instance + } + return nil +} + +// publishInsThdEvent will notify observers of specific topic of instance +func (rs *RoundRobinScheduler) publishInsThdEvent(topic scheduler.InstanceTopic, data interface{}) { + for _, observer := range rs.observers[topic] { + observer.callback(data) + } +} + +// addObservers will add observer of instance scaledInsQue +func (rs *RoundRobinScheduler) addObservers(topic scheduler.InstanceTopic, callback func(interface{})) { + topicObservers, exist := rs.observers[topic] + if !exist { + topicObservers = make([]*instanceObserver, 0, utils.DefaultSliceSize) + rs.observers[topic] = topicObservers + } + rs.observers[topic] = append(topicObservers, &instanceObserver{ + callback: callback, + }) +} + +// HandleFuncOwnerUpdate - +func (rs *RoundRobinScheduler) HandleFuncOwnerUpdate(isFuncOwner bool) { +} + +// ReassignInstanceWhenGray - +func (rs *RoundRobinScheduler) ReassignInstanceWhenGray(ratio int) { + return +} diff --git a/yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler/roundrobin_scheduler_test.go b/yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler/roundrobin_scheduler_test.go new file mode 100644 index 0000000..3cd3d64 --- /dev/null +++ b/yuanrong/pkg/functionscaler/scheduler/roundrobinscheduler/roundrobin_scheduler_test.go @@ -0,0 +1,440 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package roundrobinscheduler - +package roundrobinscheduler + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/instanceconfig" + "yuanrong/pkg/common/faas_common/resspeckey" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/scheduler" + "yuanrong/pkg/functionscaler/types" +) + +type fakeInstanceScaler struct { + timer *time.Timer + expectInsNum int + inUseInsThdNum int + totalInsThdNum int + scaling bool + createErr error + scaleUpFunc func() +} + +func (f *fakeInstanceScaler) SetFuncOwner(isManaged bool) { +} + +func (f *fakeInstanceScaler) SetEnable(enable bool) { +} + +func (f *fakeInstanceScaler) TriggerScale() { + go func() { + time.Sleep(10 * time.Millisecond) + f.scaleUpFunc() + }() +} + +func (f *fakeInstanceScaler) CheckScaling() bool { + if f.timer == nil { + return false + } + select { + case <-f.timer.C: + f.scaling = false + return false + default: + return f.scaling + } +} + +func (f *fakeInstanceScaler) UpdateCreateMetrics(coldStartTime time.Duration) { +} + +func (f *fakeInstanceScaler) HandleInsThdUpdate(inUseInsThdDiff, totalInsThdDiff int) { + f.inUseInsThdNum += inUseInsThdDiff + f.totalInsThdNum += totalInsThdDiff +} + +func (f *fakeInstanceScaler) HandleFuncSpecUpdate(funcSpec *types.FunctionSpecification) { +} + +func (f *fakeInstanceScaler) HandleInsConfigUpdate(insConfig *instanceconfig.Configuration) { +} + +func (f *fakeInstanceScaler) HandleCreateError(createError error) { + f.createErr = createError +} + +func (f *fakeInstanceScaler) GetExpectInstanceNumber() int { + return f.expectInsNum +} + +func (f *fakeInstanceScaler) Destroy() { +} + +func TestNewRoundRobinScheduler(t *testing.T) { + rs := NewRoundRobinScheduler("testFunction", true, 10*time.Millisecond) + assert.NotNil(t, rs) +} + +func TestGetInstanceNumber(t *testing.T) { + rs := NewRoundRobinScheduler("testFunction", true, 10*time.Millisecond) + rs.AddInstance(&types.Instance{ + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Equal(t, 1, rs.GetInstanceNumber(true)) +} + +func TestAcquireInstance(t *testing.T) { + rs := NewRoundRobinScheduler("testFunction", true, 10*time.Millisecond) + fs := &fakeInstanceScaler{} + rs.ConnectWithInstanceScaler(fs) + _, err := rs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Equal(t, scheduler.ErrNoInsAvailable, err) + _, err = rs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + assert.Equal(t, scheduler.ErrInsNotExist, err) + rs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + rs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + rs.AddInstance(&types.Instance{ + InstanceID: "instance3", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + fs.expectInsNum = 3 + acqIns1, err := rs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns1.Instance.InstanceID) + acqIns2, err := rs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns2.Instance.InstanceID) + acqIns3, err := rs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance3", acqIns3.Instance.InstanceID) + acqIns4, err := rs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance1"}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns4.Instance.InstanceID) + m := make(map[string]int, 0) + for i := 0; i < 10000; i++ { + acqIns, err := rs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + _, ok := m[acqIns.AllocationID] + if ok { + assert.Errorf(t, fmt.Errorf("task repeat, AllocationID: %s", acqIns.AllocationID), "") + return + } + m[acqIns.AllocationID] = 1 + } +} + +func TestAddInstance(t *testing.T) { + rs := NewRoundRobinScheduler("testFunction", true, 10*time.Millisecond).(*RoundRobinScheduler) + checkTotalInsThd := 0 + rs.addObservers(scheduler.TotalInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkTotalInsThd += delta + }) + err := rs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Equal(t, scheduler.ErrInternal, err) + err = rs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + assert.Nil(t, err) + assert.Equal(t, 1, checkTotalInsThd) + err = rs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + assert.Equal(t, scheduler.ErrInsAlreadyExist, err) + assert.Equal(t, 1, checkTotalInsThd) + err = rs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + assert.Nil(t, err) + assert.Equal(t, 2, checkTotalInsThd) + err = rs.AddInstance(&types.Instance{ + InstanceID: "instance3", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + assert.Nil(t, err) + assert.Equal(t, 3, checkTotalInsThd) + assert.Equal(t, 1, len(rs.subHealthInstance)) + assert.Equal(t, 2, len(rs.instanceQueue)) +} + +func TestPopInstance(t *testing.T) { + rs := NewRoundRobinScheduler("testFunction", true, 10*time.Millisecond).(*RoundRobinScheduler) + checkTotalInsThd := 0 + rs.addObservers(scheduler.TotalInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkTotalInsThd += delta + }) + popIns1 := rs.PopInstance(false) + assert.Nil(t, popIns1) + rs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + rs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + rs.AddInstance(&types.Instance{ + InstanceID: "instance3", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + popIns2 := rs.PopInstance(false) + assert.Equal(t, "instance3", popIns2.InstanceID) + assert.Equal(t, 2, checkTotalInsThd) + popIns3 := rs.PopInstance(false) + assert.Equal(t, "instance2", popIns3.InstanceID) + assert.Equal(t, 1, checkTotalInsThd) + popIns4 := rs.PopInstance(false) + assert.Equal(t, "instance1", popIns4.InstanceID) + assert.Equal(t, 0, checkTotalInsThd) +} + +func TestDelInstance(t *testing.T) { + rs := NewRoundRobinScheduler("testFunction", true, 10*time.Millisecond).(*RoundRobinScheduler) + checkTotalInsThd := 0 + rs.addObservers(scheduler.TotalInsThdTopic, func(obj interface{}) { + delta := obj.(int) + checkTotalInsThd += delta + }) + err := rs.DelInstance(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Equal(t, scheduler.ErrInsNotExist, err) + rs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + rs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ConcurrentNum: 1, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusSubHealth)}, + }) + err = rs.DelInstance(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Nil(t, err) + assert.Equal(t, 1, checkTotalInsThd) + err = rs.DelInstance(&types.Instance{ + InstanceID: "instance2", + ResKey: resspeckey.ResSpecKey{}, + }) + assert.Nil(t, err) + assert.Equal(t, 0, checkTotalInsThd) +} + +func TestHandleInstanceUpdate(t *testing.T) { + rs := NewRoundRobinScheduler("testFunction", true, 10*time.Millisecond) + fs := &fakeInstanceScaler{} + rs.ConnectWithInstanceScaler(fs) + rs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + }) + rs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance2", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusSubHealth), + }, + }) + fs.expectInsNum = 2 + acqIns1, err := rs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns1.Instance.InstanceID) + acqIns2, err := rs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance1", acqIns2.Instance.InstanceID) + _, err = rs.AcquireInstance(&types.InstanceAcquireRequest{DesignateInstanceID: "instance2"}) + assert.Equal(t, scheduler.ErrInsSubHealthy, err) + rs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusSubHealth), + }, + }) + rs.HandleInstanceUpdate(&types.Instance{ + InstanceID: "instance2", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + }) + acqIns3, err := rs.AcquireInstance(&types.InstanceAcquireRequest{}) + assert.Nil(t, err) + assert.Equal(t, "instance2", acqIns3.Instance.InstanceID) +} + +func TestConnectWithInstanceScaler(t *testing.T) { + rs := NewRoundRobinScheduler("testFunction", true, 10*time.Millisecond) + instanceScaler := &fakeInstanceScaler{} + rs.ConnectWithInstanceScaler(instanceScaler) + rs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ConcurrentNum: 2, + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + assert.Equal(t, 2, instanceScaler.totalInsThdNum) + someErr := errors.New("some error") + rs.HandleCreateError(someErr) + assert.Equal(t, someErr, instanceScaler.createErr) +} + +func TestSignalAllInstances(t *testing.T) { + rs := NewRoundRobinScheduler("testFunction", true, 10*time.Millisecond) + rs.AddInstance(&types.Instance{ + InstanceID: "instance1", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + rs.AddInstance(&types.Instance{ + InstanceID: "instance2", + ResKey: resspeckey.ResSpecKey{}, + InstanceStatus: commonTypes.InstanceStatus{Code: int32(constant.KernelInstanceStatusRunning)}, + }) + insIDList := make([]string, 0, 2) + rs.SignalAllInstances(func(instance *types.Instance) { + insIDList = append(insIDList, instance.InstanceID) + }) + assert.Contains(t, insIDList, "instance1") + assert.Contains(t, insIDList, "instance2") +} + +func TestRoundRobinScheduler_CheckInstanceExist(t *testing.T) { + tests := []struct { + name string + healthyQueue []*types.Instance + subHealthMap map[string]*types.Instance + checkInstance *types.Instance + expectExist bool + }{ + { + name: "instance_exists_in_healthy_queue", + healthyQueue: []*types.Instance{ + {InstanceID: "instance-1"}, + {InstanceID: "instance-2"}, + }, + subHealthMap: make(map[string]*types.Instance), + checkInstance: &types.Instance{InstanceID: "instance-1"}, + expectExist: true, + }, + { + name: "instance_exists_in_subhealth_map", + healthyQueue: []*types.Instance{ + {InstanceID: "instance-2"}, + }, + subHealthMap: map[string]*types.Instance{ + "instance-1": {InstanceID: "instance-1"}, + }, + checkInstance: &types.Instance{InstanceID: "instance-1"}, + expectExist: true, + }, + { + name: "instance_exists_in_both", + healthyQueue: []*types.Instance{ + {InstanceID: "instance-1"}, + }, + subHealthMap: map[string]*types.Instance{ + "instance-1": {InstanceID: "instance-1"}, + }, + checkInstance: &types.Instance{InstanceID: "instance-1"}, + expectExist: true, + }, + { + name: "instance_not_exist", + healthyQueue: []*types.Instance{ + {InstanceID: "instance-2"}, + }, + subHealthMap: map[string]*types.Instance{ + "instance-3": {InstanceID: "instance-3"}, + }, + checkInstance: &types.Instance{InstanceID: "instance-1"}, + expectExist: false, + }, + { + name: "empty_scheduler", + healthyQueue: []*types.Instance{}, + subHealthMap: make(map[string]*types.Instance), + checkInstance: &types.Instance{InstanceID: "instance-1"}, + expectExist: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := &RoundRobinScheduler{ + instanceQueue: tt.healthyQueue, + subHealthInstance: tt.subHealthMap, + } + + exist := rs.CheckInstanceExist(tt.checkInstance) + + assert.Equal(t, tt.expectExist, exist, + "CheckInstanceExist() result mismatch for case: %s", tt.name) + }) + } +} diff --git a/yuanrong/pkg/functionscaler/selfregister/proxy.go b/yuanrong/pkg/functionscaler/selfregister/proxy.go new file mode 100644 index 0000000..b345fdb --- /dev/null +++ b/yuanrong/pkg/functionscaler/selfregister/proxy.go @@ -0,0 +1,177 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package selfregister - +package selfregister + +import ( + "os" + "strings" + "sync" + "time" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/loadbalance" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" +) + +const ( + // HashRingSize the concurrent hash ring length + HashRingSize = 5000 + // GetHashLenInternal - + GetHashLenInternal = 10 * time.Millisecond + etcdPathElementsLen = 14 +) + +var ( + // SelfInstanceID proxy is the singleton proxy + SelfInstanceID string + // SelfInstanceName is the instanceName used when discovery type is module + SelfInstanceName string + selfInstanceSpec *types.InstanceSpecification +) + +var ( + // GlobalSchedulerProxy - + GlobalSchedulerProxy = NewSchedulerProxy( + loadbalance.NewConcurrentCHGeneric(HashRingSize), + ) +) + +// SchedulerProxy is used to get instances from FaaSScheduler via a grpc stream +type SchedulerProxy struct { + FaaSSchedulers sync.Map + // used to select a FaaSScheduler by the func info Concurrent Consistent Hash + loadBalance loadbalance.LoadBalance +} + +func init() { + log.GetLogger().Infof("set SelfInstanceID to %s", os.Getenv("INSTANCE_ID")) + SelfInstanceID = os.Getenv("INSTANCE_ID") +} + +// SetSelfInstanceName - +func SetSelfInstanceName(instanceName string) { + log.GetLogger().Infof("set SelfInstanceName to %s", instanceName) + SelfInstanceName = instanceName +} + +// SetSelfInstanceSpec - +func SetSelfInstanceSpec(insSpec *types.InstanceSpecification) { + selfInstanceSpec = insSpec +} + +// GetSchedulerProxyName - +func GetSchedulerProxyName() string { + schedulerDiscovery := config.GlobalConfig.SchedulerDiscovery + if schedulerDiscovery != nil && schedulerDiscovery.KeyPrefixType == constant.SchedulerKeyTypeModule { + return SelfInstanceName + } + return SelfInstanceID +} + +// NewSchedulerProxy return an instance pool which get the instance from the remote FaaSScheduler +func NewSchedulerProxy(lb loadbalance.LoadBalance) *SchedulerProxy { + return &SchedulerProxy{ + loadBalance: lb, + } +} + +// Add an FaaSScheduler +func (sp *SchedulerProxy) Add(faaSScheduler *types.InstanceInfo, exclusivity string) { + sp.FaaSSchedulers.Store(faaSScheduler.InstanceName, faaSScheduler) + if exclusivity != "" { + // do not add exclusivity scheduler to load balance + log.GetLogger().Infof("no need to add scheduler %s to load balance for exclusivity %s", + faaSScheduler.InstanceName, exclusivity) + return + } + log.GetLogger().Debugf("add faasscheduler to proxy, id is %s, name is %s", + faaSScheduler.InstanceID, faaSScheduler.InstanceName) + sp.loadBalance.Add(faaSScheduler.InstanceName, 0) +} + +// Remove a FaaSScheduler +func (sp *SchedulerProxy) Remove(faasScheduler *types.InstanceInfo) { + sp.loadBalance.Remove(faasScheduler.InstanceName) + sp.FaaSSchedulers.Delete(faasScheduler.InstanceName) +} + +// Reset - reset hash anchor point +func (sp *SchedulerProxy) Reset() { + sp.loadBalance.Reset() +} + +// Contains - if hash ring contains this scheduelr +func (sp *SchedulerProxy) Contains(id string) bool { + _, ok := sp.FaaSSchedulers.Load(id) + return ok +} + +// CheckFuncOwner determine etcd event should or not to be deal with +func (sp *SchedulerProxy) CheckFuncOwner(funcKey string) bool { + log.GetLogger().Debugf("check which faas scheduler instance should process function %s", funcKey) + // select one FaaSScheduler by the func key + next := sp.loadBalance.Next(funcKey, false) + faasSchedulerName, ok := next.(string) + if !ok { + log.GetLogger().Errorf("failed to parse the result of load balance: %+v", next) + return false + } + if strings.TrimSpace(faasSchedulerName) == "" { + log.GetLogger().Errorf("no available faas scheduler was found") + return false + } + faaSSchedulerData, ok := sp.FaaSSchedulers.Load(faasSchedulerName) + if !ok { + log.GetLogger().Errorf("failed to get the faas scheduler named %s", faasSchedulerName) + return false + } + faaSScheduler, ok := faaSSchedulerData.(*types.InstanceInfo) + if !ok { + log.GetLogger().Errorf("invalid faas scheduler named %s: %#v", faasSchedulerName, faaSSchedulerData) + return false + } + if faaSScheduler.InstanceName != GetSchedulerProxyName() { + log.GetLogger().Warnf("instanceID self is: %s, hash computed: %s", GetSchedulerProxyName(), + faaSScheduler.InstanceName) + return false + } + log.GetLogger().Infof("this scheduler %s should process function %s", SelfInstanceID, funcKey) + return true +} + +// WaitForHash wait for num of concurrent hash node to add +func (sp *SchedulerProxy) WaitForHash(num int) { + if num == 0 { + return + } + for { + hashLen := 0 + sp.FaaSSchedulers.Range(func(k, v interface{}) bool { + hashLen++ + return true + }) + if hashLen < num { + time.Sleep(GetHashLenInternal) + continue + } + log.GetLogger().Infof("succeeded to create num: %d of hash ring node", num) + return + } +} diff --git a/yuanrong/pkg/functionscaler/selfregister/proxy_test.go b/yuanrong/pkg/functionscaler/selfregister/proxy_test.go new file mode 100644 index 0000000..c796b8c --- /dev/null +++ b/yuanrong/pkg/functionscaler/selfregister/proxy_test.go @@ -0,0 +1,137 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package selfregister + +import ( + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/loadbalance" + "yuanrong/pkg/common/faas_common/types" +) + +func Test_schedulerProxy_DealFilter(t *testing.T) { + convey.Convey("test_deal_filter", t, func() { + convey.Convey("base", func() { + + proxy := NewSchedulerProxy(loadbalance.NewCHGeneric()) + + proxy.Add(&types.InstanceInfo{ + InstanceName: "aa2794fb-dc9e-420d-ae54-bedfa3577930", + }, "") + + proxy.Add(&types.InstanceInfo{ + InstanceName: "7d3f736e-b2b0-4b7e-bc8d-3a390ec0ed31", + }, "") + + proxy.Add(&types.InstanceInfo{ + InstanceName: "d06832bc-8c02-4589-9c37-edae4109302d", + }, "") + + SelfInstanceID = "aa2794fb-dc9e-420d-ae54-bedfa3577930" + + flag := proxy.CheckFuncOwner("244177614494719500/0@default@testcustom001/latest") + + convey.So(flag, convey.ShouldBeFalse) + + proxy.Remove(&types.InstanceInfo{ + InstanceName: "d06832bc-8c02-4589-9c37-edae4109302d", + }) + + flag = proxy.CheckFuncOwner("244177614494719500/0@default@testcustom001/latest") + + convey.So(flag, convey.ShouldBeFalse) + + proxy.Remove(&types.InstanceInfo{ + InstanceName: "7d3f736e-b2b0-4b7e-bc8d-3a390ec0ed31", + }) + + flag = proxy.CheckFuncOwner("244177614494719500/0@default@testcustom001/latest") + + convey.So(flag, convey.ShouldBeTrue) + + proxy.Add(&types.InstanceInfo{ + InstanceName: "d06832bc-8c02-4589-9c37-edae4109302d", + }, "") + + flag = proxy.CheckFuncOwner("244177614494719500/0@default@testcustom001/latest") + + convey.So(flag, convey.ShouldBeTrue) + + proxy.Add(&types.InstanceInfo{ + InstanceName: "7d3f736e-b2b0-4b7e-bc8d-3a390ec0ed31", + }, "") + + proxy.Reset() + + flag = proxy.CheckFuncOwner("244177614494719500/0@default@testcustom001/latest") + + convey.So(flag, convey.ShouldBeFalse) + }) + }) +} + +func TestDealFilter(t *testing.T) { + proxy := NewSchedulerProxy(loadbalance.NewConcurrentCHGeneric(10)) + convey.Convey("start failed", t, func() { + res := proxy.CheckFuncOwner("mock-funcKey") + convey.So(res, convey.ShouldBeFalse) + }) + proxy.Add(&types.InstanceInfo{ + InstanceName: "scheduler-001", + }, "") + convey.Convey("start failed", t, func() { + res := proxy.CheckFuncOwner("mock-funcKey") + convey.So(res, convey.ShouldBeFalse) + }) + SelfInstanceID = "scheduler-001" + convey.Convey("start success", t, func() { + res := proxy.CheckFuncOwner("mock-funcKey") + convey.So(res, convey.ShouldBeTrue) + }) + proxy.FaaSSchedulers.Delete("scheduler-001") + convey.Convey("start failed", t, func() { + res := proxy.CheckFuncOwner("mock-funcKey") + convey.So(res, convey.ShouldBeFalse) + }) +} + +func TestContains(t *testing.T) { + proxy := NewSchedulerProxy(loadbalance.NewConcurrentCHGeneric(10)) + convey.Convey("not contains", t, func() { + res := proxy.Contains("instance1") + convey.So(res, convey.ShouldBeFalse) + }) +} + +func Test(t *testing.T) { + proxy := NewSchedulerProxy(loadbalance.NewConcurrentCHGeneric(10)) + callTime := 0 + defer gomonkey.ApplyFunc(time.Sleep, func(d time.Duration) { + callTime++ + }).Reset() + convey.Convey("wait for hash", t, func() { + proxy.WaitForHash(0) + convey.So(callTime, convey.ShouldEqual, 0) + proxy.FaaSSchedulers.Store("instance1", nil) + proxy.WaitForHash(1) + convey.So(callTime, convey.ShouldEqual, 0) + }) +} diff --git a/yuanrong/pkg/functionscaler/selfregister/rolloutregister.go b/yuanrong/pkg/functionscaler/selfregister/rolloutregister.go new file mode 100644 index 0000000..19a8c06 --- /dev/null +++ b/yuanrong/pkg/functionscaler/selfregister/rolloutregister.go @@ -0,0 +1,265 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package selfregister contains service route logic +package selfregister + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/rollout" + "yuanrong/pkg/functionscaler/types" +) + +const ( + // CurrentVersionEnvKey envKey for current version + CurrentVersionEnvKey = "CURRENT_VERSION" + defaultClusterID = "defaultCluster" + defaultNodeID = "defaultNode" + validRolloutKeyLen = 7 +) + +var ( + // IsRollingOut - + IsRollingOut bool + // IsRolloutObject - + IsRolloutObject bool + // RolloutSubjectID - + RolloutSubjectID string + // RolloutRegisterKey - + RolloutRegisterKey string + rolloutRegister *etcd3.EtcdRegister + rolloutLocker *etcd3.EtcdLocker +) + +// RegisterRolloutToEtcd - +func RegisterRolloutToEtcd(stopCh <-chan struct{}) error { + log.GetLogger().Infof("start to register rollout key in etcd, rollout enable %t", config.GlobalConfig.EnableRollout) + if !config.GlobalConfig.EnableRollout || IsRolloutObject { + log.GetLogger().Infof("skip register rollout key in etcd, isRolloutObject %t", IsRolloutObject) + return nil + } + key, err := getSelfRolloutKey() + if err != nil { + return err + } + rolloutRegister = &etcd3.EtcdRegister{ + EtcdClient: etcd3.GetRouterEtcdClient(), + InstanceKey: key, + StopCh: stopCh, + } + if err = rolloutRegister.Register(); err != nil { + log.GetLogger().Errorf("failed to register to etcd, register failed error %s", err.Error()) + return err + } + log.GetLogger().Infof("succeed to register rollout key %s in etcd", key) + return nil +} + +// ContendRolloutInEtcd - +func ContendRolloutInEtcd(stopCh <-chan struct{}) error { + log.GetLogger().Infof("start to contend for rollout key in etcd, rollout enable %t", + config.GlobalConfig.EnableRollout) + rolloutLocker = &etcd3.EtcdLocker{ + EtcdClient: etcd3.GetRouterEtcdClient(), + LeaseTTL: defaultLeaseTTL, + StopCh: stopCh, + LockCallback: putInsSpecForRolloutKey, + UnlockCallback: delInsSpecForRolloutKey, + FailCallback: unsetRolloutRegister, + } + var err error + for i := 0; i < maxContendTime; i++ { + err = rolloutLocker.TryLockWithPrefix(constant.SchedulerRolloutPrefix, contendFilterForRollout) + if err != nil { + log.GetLogger().Errorf("failed to contend for rollout key, lock failed error %s", err.Error()) + time.Sleep(contendWaitInterval) + continue + } + break + } + if err != nil { + log.GetLogger().Errorf("contend retry time reaches max value %d, lock failed error %s", maxContendTime, + err.Error()) + return err + } + log.GetLogger().Infof("succeed to contend for rollout key in etcd, rollout key is %s", + rolloutLocker.GetLockedKey()) + return nil +} + +// ReplaceRolloutSubject - +func ReplaceRolloutSubject(stopCh <-chan struct{}) { + if rollout.GetGlobalRolloutHandler().GetCurrentRatio() != 100 { // 100 is finish rollout + log.GetLogger().Infof("current ratio %d is still not 100, skip ReplaceRolloutSubject", + rollout.GetGlobalRolloutHandler().GetCurrentRatio()) + return + } + log.GetLogger().Infof("start to replace rollout subject %s of instance name %s", RolloutSubjectID, + SelfInstanceName) + if rolloutLocker == nil || selfLocker == nil { + log.GetLogger().Errorf("failed to replace rollout subject, rolloutLocker or selfLocker is nil") + return + } + lockedRolloutKey := rolloutLocker.GetLockedKey() + if err := rolloutLocker.Unlock(); err != nil { + log.GetLogger().Errorf("failed to unlock rollout key %s , unlock error %s", lockedRolloutKey, err.Error()) + } + ctx := etcd3.CreateEtcdCtxInfoWithTimeout(context.Background(), etcd3.DurationContextTimeout) + if err := rolloutLocker.EtcdClient.Delete(ctx, lockedRolloutKey); err != nil { + log.GetLogger().Errorf("failed to delete rollout key %s, delete error %s", lockedRolloutKey, err.Error()) + } + if err := selfLocker.TryLock(RolloutRegisterKey); err != nil { + log.GetLogger().Warnf("failed to lock rollout register key %s, lock error %s, try lock another key", + RolloutRegisterKey, err.Error()) + err = contendInstanceInEtcd(stopCh) + if err != nil { + log.GetLogger().Errorf("failed to lock register key %s, lock error %s", RolloutRegisterKey, err.Error()) + } + return + } + if err := processLockedInstanceKey(selfLocker.GetLockedKey()); err != nil { + log.GetLogger().Errorf("failed to process lock key %s, process error %s", RolloutRegisterKey, err.Error()) + return + } + IsRollingOut = false + IsRolloutObject = false + if err := RegisterRolloutToEtcd(stopCh); err != nil { + log.GetLogger().Errorf("failed to replace to etcd for rollout, register failed error %s", err.Error()) + } + log.GetLogger().Infof("succeed to replace rollout subject %s of instance name %s", RolloutSubjectID, + SelfInstanceName) +} + +func getSelfRolloutKey() (string, error) { + clusterID := config.GlobalConfig.ClusterID + if len(clusterID) == 0 { + clusterID = defaultClusterID + } + instanceID := SelfInstanceID + if len(instanceID) == 0 { + return "", errors.New("self instanceID is empty") + } + return fmt.Sprintf("%s/%s/%s/%s", constant.SchedulerRolloutPrefix, clusterID, defaultNodeID, instanceID), nil +} + +func contendFilterForRollout(key, value []byte) bool { + items := strings.Split(string(key), "/") + if len(items) != validRolloutKeyLen { + return true + } + return false +} + +func putInsSpecForRolloutKey(locker *etcd3.EtcdLocker) error { + lockedKey := locker.GetLockedKey() + log.GetLogger().Infof("start to put insSpec for rollout key %s", lockedKey) + if len(lockedKey) == 0 { + log.GetLogger().Errorf("failed to get locked key") + return errors.New("locked key is empty") + } + instanceID := SelfInstanceID + if len(instanceID) == 0 { + log.GetLogger().Errorf("failed to get self instance key") + return errors.New("self instance key is empty") + } + if selfInstanceSpec == nil { + log.GetLogger().Errorf("failed to get insSpec of this scheduler %s", instanceID) + return errors.New("insSpec not found") + } + if err := processRolloutRequest(lockedKey); err != nil { + log.GetLogger().Errorf("failed to process rollout request error %s", err.Error()) + return err + } + rolloutInsSpec := types.RolloutInstanceSpecification{ + RegisterKey: RolloutRegisterKey, + InstanceID: selfInstanceSpec.InstanceID, + RuntimeAddress: selfInstanceSpec.RuntimeAddress, + } + rolloutInsSpecData, err := json.Marshal(rolloutInsSpec) + if err != nil { + log.GetLogger().Errorf("failed to marshal insSpec error %s", err.Error()) + return err + } + if err = processEtcdPut(locker.EtcdClient, lockedKey, string(rolloutInsSpecData)); err != nil { + log.GetLogger().Errorf("failed to put insSpec for rollout key into etcd error %s", err.Error()) + return err + } + log.GetLogger().Infof("succeed to put insSpec for rollout key %s", lockedKey) + return nil +} + +func delInsSpecForRolloutKey(locker *etcd3.EtcdLocker) error { + lockedKey := locker.GetLockedKey() + log.GetLogger().Infof("start to clean insSpec for rollout key %s", lockedKey) + if len(lockedKey) == 0 { + log.GetLogger().Errorf("failed to get locked key") + return errors.New("locked key is empty") + } + if exist, err := isKeyExist(locker.EtcdClient, lockedKey); err != nil || !exist { + return fmt.Errorf("key not exist or get error %v, no need clean it", err) + } + if err := processEtcdPut(locker.EtcdClient, lockedKey, ""); err != nil { + log.GetLogger().Errorf("failed to clean insSpec for rollout key in etcd error %s", err.Error()) + return err + } + log.GetLogger().Infof("succeed to clean insSpec for rollout key %s", lockedKey) + return nil +} + +func unsetRolloutRegister() { + log.GetLogger().Warnf("locker of rollout key %s failed, unset register status", RolloutRegisterKey) + IsRolloutObject = false + RolloutRegisterKey = "" +} + +func processRolloutRequest(rolloutKey string) error { + log.GetLogger().Infof("start to process rollout request, rollout key %s", rolloutKey) + items := strings.Split(rolloutKey, "/") + if len(items) != validRolloutKeyLen { + log.GetLogger().Errorf("failed to parse rollout scheduler from key %s", rolloutKey) + return errors.New("invalid rollout key") + } + RolloutSubjectID = items[validRolloutKeyLen-1] + log.GetLogger().Infof("set RolloutSubjectID to %s", RolloutSubjectID) + rsp, err := rollout.GetGlobalRolloutHandler().SendRolloutRequest(SelfInstanceID, RolloutSubjectID) + if err != nil { + log.GetLogger().Errorf("failed to send rollout request to instance %s error %s", RolloutSubjectID, err.Error()) + return err + } + RolloutRegisterKey = rsp.RegisterKey + log.GetLogger().Infof("succeed to set RolloutRegisterKey to %s", RolloutRegisterKey) + registerInfo, err := utils.GetModuleSchedulerInfoFromEtcdKey(RolloutRegisterKey) + if err != nil { + log.GetLogger().Errorf("failed to get register info from key %s error %s", rsp.RegisterKey, err.Error()) + return err + } + SetSelfInstanceName(registerInfo.InstanceName) + IsRollingOut = true + IsRolloutObject = true + log.GetLogger().Infof("succeed to process rollout request, rollout key %s", rolloutKey) + return nil +} diff --git a/yuanrong/pkg/functionscaler/selfregister/rolloutregister_test.go b/yuanrong/pkg/functionscaler/selfregister/rolloutregister_test.go new file mode 100644 index 0000000..798cdee --- /dev/null +++ b/yuanrong/pkg/functionscaler/selfregister/rolloutregister_test.go @@ -0,0 +1,255 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package selfregister contains service route logic +package selfregister + +import ( + "errors" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/rollout" +) + +func TestRegisterRolloutToEtcd(t *testing.T) { + config.GlobalConfig.EnableRollout = true + defer func() { + config.GlobalConfig.EnableRollout = false + }() + stopCh := make(chan struct{}) + var regErr error + defer gomonkey.ApplyFunc((*etcd3.EtcdRegister).Register, func() error { + return regErr + }).Reset() + convey.Convey("Test RegisterRolloutToEtcd", t, func() { + regErr = errors.New("some error") + err := RegisterRolloutToEtcd(stopCh) + convey.So(err, convey.ShouldNotBeNil) + }) +} + +func TestContendRolloutInEtcd(t *testing.T) { + maxContendTime = 1 + config.GlobalConfig.EnableRollout = true + defer func() { + config.GlobalConfig.EnableRollout = false + }() + var ( + getResponse *clientv3.GetResponse + getError error + invokeRes []byte + invokeErr error + lockErr error + ) + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }), + gomonkey.ApplyFunc((*etcd3.EtcdClient).Get, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return getResponse, getError + }), + gomonkey.ApplyFunc((*etcd3.EtcdClient).Put, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, value string, opts ...clientv3.OpOption) error { + return nil + }), + gomonkey.ApplyFunc((*etcd3.EtcdClient).Delete, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, opts ...clientv3.OpOption) error { + return nil + }), + gomonkey.ApplyFunc((*etcd3.EtcdLocker).TryLock, func(_ *etcd3.EtcdLocker, key string) error { + return lockErr + }), + gomonkey.ApplyFunc(rollout.InvokeByInstanceId, func(args []api.Arg, instanceID string, traceID string) ([]byte, + error) { + return invokeRes, invokeErr + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + convey.Convey("Test RegisterRolloutToEtcd", t, func() { + convey.Convey("register success", func() { + SelfInstanceID = "instance1" + selfInstanceSpec = &types.InstanceSpecification{} + invokeRes = []byte(`{"registerKey": "/sn/faas-scheduler/instances/cluster001/node001/bj-pod1id"}`) + invokeErr = nil + getResponse = &clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{ + {Key: []byte("/sn/faas-scheduler/rollout/cluster1/node1/aaa"), Value: []byte("invalid value")}, + {Key: []byte("/sn/faas-scheduler/rollout/cluster1/node1/bbb")}, + }, + } + stopCh := make(chan struct{}) + err := ContendRolloutInEtcd(stopCh) + convey.So(err, convey.ShouldBeNil) + close(stopCh) + time.Sleep(200 * time.Millisecond) + }) + convey.Convey("putInsSpecForRolloutKey fail", func() { + SelfInstanceID = "" + selfInstanceSpec = nil + getResponse = &clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{ + {Key: []byte("/sn/faas-scheduler/rollout/cluster1/node1/aaa"), Value: []byte("invalid value")}, + {Key: []byte("/sn/faas-scheduler/rollout/cluster1/node1/bbb")}, + }, + } + stopCh := make(chan struct{}) + lockErr = errors.New("some error") + err := ContendRolloutInEtcd(stopCh) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} + +func TestReplaceRolloutSubject(t *testing.T) { + config.GlobalConfig.EnableRollout = true + defer func() { + config.GlobalConfig.EnableRollout = false + }() + rolloutLocker = &etcd3.EtcdLocker{} + selfLocker = &etcd3.EtcdLocker{LockedKey: "/sn/faas-scheduler/rollout/cluster1/node1/aaa"} + var ( + lockErr error + unlockErr error + regErr error + ) + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc((*etcd3.EtcdLocker).TryLock, func(_ *etcd3.EtcdLocker, key string) error { + return lockErr + }), + gomonkey.ApplyFunc((*etcd3.EtcdLocker).Unlock, func(_ *etcd3.EtcdLocker) error { + return unlockErr + }), + gomonkey.ApplyFunc((*etcd3.EtcdRegister).Register, func() error { + return regErr + }), + gomonkey.ApplyFunc((*etcd3.EtcdClient).Delete, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, opts ...clientv3.OpOption) error { + return unlockErr + }), + gomonkey.ApplyFunc(contendInstanceInEtcd, func(stopCh <-chan struct{}) error { + return lockErr + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + convey.Convey("Test ReplaceRolloutSubject", t, func() { + rollout.GetGlobalRolloutHandler().CurrentRatio = 100 + convey.Convey("replace success", func() { + IsRollingOut = true + IsRolloutObject = true + stopCh := make(chan struct{}) + ReplaceRolloutSubject(stopCh) + convey.So(IsRollingOut, convey.ShouldBeFalse) + convey.So(IsRolloutObject, convey.ShouldBeFalse) + }) + convey.Convey("replace fail", func() { + IsRollingOut = true + IsRolloutObject = true + stopCh := make(chan struct{}) + lockErr = errors.New("some error") + ReplaceRolloutSubject(stopCh) + convey.So(IsRollingOut, convey.ShouldBeTrue) + convey.So(IsRolloutObject, convey.ShouldBeTrue) + lockErr = nil + unlockErr = errors.New("some error") + ReplaceRolloutSubject(stopCh) + convey.So(IsRollingOut, convey.ShouldBeFalse) + convey.So(IsRolloutObject, convey.ShouldBeFalse) + lockErr = nil + unlockErr = nil + regErr = errors.New("some error") + ReplaceRolloutSubject(stopCh) + convey.So(IsRollingOut, convey.ShouldBeFalse) + convey.So(IsRolloutObject, convey.ShouldBeFalse) + }) + }) +} + +func TestPutInsSpecForRolloutKey(t *testing.T) { + var ( + putErr error + lockedKey string + rolloutRes *types.RolloutResponse + rolloutErr error + ) + rolloutRes = &types.RolloutResponse{} + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc((*etcd3.EtcdClient).Put, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, value string, opts ...clientv3.OpOption) error { + return putErr + }), + gomonkey.ApplyFunc((*etcd3.EtcdLocker).GetLockedKey, func(_ *etcd3.EtcdLocker) string { + return lockedKey + }), + gomonkey.ApplyFunc((*rollout.RFHandler).SendRolloutRequest, func(_ *rollout.RFHandler, selfInsID, + targetInsID string) (*types.RolloutResponse, error) { + return rolloutRes, rolloutErr + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + convey.Convey("Test putInsSpecForRolloutKey", t, func() { + locker := &etcd3.EtcdLocker{EtcdClient: &etcd3.EtcdClient{}} + putErr = errors.New("some error") + rolloutErr = errors.New("some error") + err := putInsSpecForRolloutKey(locker) + convey.So(err, convey.ShouldNotBeNil) + lockedKey = "testKey" + err = putInsSpecForRolloutKey(locker) + convey.So(err, convey.ShouldNotBeNil) + SelfInstanceID = "testInstanceID" + err = putInsSpecForRolloutKey(locker) + convey.So(err, convey.ShouldNotBeNil) + selfInstanceSpec = &types.InstanceSpecification{} + err = putInsSpecForRolloutKey(locker) + convey.So(err, convey.ShouldNotBeNil) + // test processRolloutRequest start + lockedKey = "/sn/faas-scheduler/rollout/cluster1/node1/aaa" + err = putInsSpecForRolloutKey(locker) + convey.So(err, convey.ShouldNotBeNil) + rolloutErr = nil + err = putInsSpecForRolloutKey(locker) + convey.So(err, convey.ShouldNotBeNil) + rolloutRes.RegisterKey = "/sn/faas-scheduler/instance/cluster1/node1/aaa" + err = putInsSpecForRolloutKey(locker) + convey.So(err, convey.ShouldNotBeNil) + // test processRolloutRequest end + putErr = nil + err = putInsSpecForRolloutKey(locker) + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/pkg/functionscaler/selfregister/selfregister.go b/yuanrong/pkg/functionscaler/selfregister/selfregister.go new file mode 100644 index 0000000..305de96 --- /dev/null +++ b/yuanrong/pkg/functionscaler/selfregister/selfregister.go @@ -0,0 +1,258 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package selfregister contains service route logic +package selfregister + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "time" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + commonTypes "yuanrong/pkg/common/faas_common/types" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/rollout" + "yuanrong/pkg/functionscaler/types" +) + +const ( + defaultLeaseTTL = 15 +) + +var ( + // Registered - + Registered bool + // RegisterKey - + RegisterKey string + selfRegister *etcd3.EtcdRegister + selfLocker *etcd3.EtcdLocker + maxContendTime = 300 + contendWaitInterval = 1 * time.Second +) + +// RegisterToEtcd - +func RegisterToEtcd(stopCh <-chan struct{}) error { + discoveryConfig := config.GlobalConfig.SchedulerDiscovery + log.GetLogger().Infof("start to register to etcd, discoveryConfig %+v", discoveryConfig) + if discoveryConfig == nil || len(discoveryConfig.RegisterMode) == 0 { + return nil + } + if discoveryConfig != nil && discoveryConfig.RegisterMode == types.RegisterTypeContend { + log.GetLogger().Infof("start to contend for instance name in etcd") + selfCurVer := os.Getenv(CurrentVersionEnvKey) + etcdCurVer := rollout.GetGlobalRolloutHandler().CurrentVersion + selfLocker = &etcd3.EtcdLocker{ + EtcdClient: etcd3.GetRouterEtcdClient(), + LeaseTTL: defaultLeaseTTL, + StopCh: stopCh, + LockCallback: putInsSpecForInstanceKey, + UnlockCallback: delInsSpecForInstanceKey, + FailCallback: unsetInstanceRegister, + } + if config.GlobalConfig.EnableRollout && selfCurVer != etcdCurVer { + log.GetLogger().Infof("rollout is enable, this scheduler's version %s doesn't equal to current "+ + "version %s, contend for rollout instead", + selfCurVer, etcdCurVer) + return ContendRolloutInEtcd(stopCh) + } + if err := contendInstanceInEtcd(stopCh); err != nil { + return err + } + } else { + log.GetLogger().Infof("start to register for instance name in etcd") + key, value, err := getInstanceKeyAndValue() + if err != nil { + log.GetLogger().Errorf("failed to get service key and value error %s", err.Error()) + return err + } + selfRegister = &etcd3.EtcdRegister{ + EtcdClient: etcd3.GetRouterEtcdClient(), + InstanceKey: key, + Value: value, + StopCh: stopCh, + } + if err = selfRegister.Register(); err != nil { + log.GetLogger().Errorf("failed to register to etcd, register failed error %s", err.Error()) + return err + } + } + log.GetLogger().Infof("succeed to register to etcd") + if err := RegisterRolloutToEtcd(stopCh); err != nil { + log.GetLogger().Errorf("failed to register to etcd for rollout, register failed error %s", err.Error()) + return err + } + return nil +} + +func contendInstanceInEtcd(stopCh <-chan struct{}) error { + log.GetLogger().Infof("start to contend for instance key in etcd") + var err error + for i := 0; i < maxContendTime; i++ { + err = selfLocker.TryLockWithPrefix(constant.ModuleSchedulerPrefix, contendFilterForInstance) + if err != nil { + log.GetLogger().Errorf("failed to contend for rollout key, lock failed error %s", err.Error()) + time.Sleep(contendWaitInterval) + continue + } + break + } + if err != nil { + log.GetLogger().Errorf("failed to contend for instance name, lock error %s", err.Error()) + return err + } + // succeed to lock instance key, set SelfInstanceName from this key + log.GetLogger().Infof("succeed to contend for instance name, lock key is %s", selfLocker.GetLockedKey()) + return processLockedInstanceKey(selfLocker.GetLockedKey()) +} + +func processLockedInstanceKey(lockedKey string) error { + info, err := commonUtils.GetModuleSchedulerInfoFromEtcdKey(lockedKey) + if err != nil { + log.GetLogger().Errorf("failed to register to etcd, get instanceInfo failed error %s", err.Error()) + return err + } + Registered = true + RegisterKey = lockedKey + SetSelfInstanceName(info.InstanceName) + log.GetLogger().Infof("succeed to set registerKey to %s selfInstanceName %s", RegisterKey, info.InstanceName) + return nil +} + +func getInstanceKeyAndValue() (string, string, error) { + clusterID := os.Getenv("CLUSTER_ID") + nodeIP := os.Getenv("NODE_IP") + podName := os.Getenv("POD_NAME") + podIP := os.Getenv("POD_IP") + + err := validateEnvs(clusterID, nodeIP, podName, podIP) + if err != nil { + return "", "", err + } + key := fmt.Sprintf("/sn/faas-scheduler/instances/%s/%s/%s", clusterID, nodeIP, podName) + + schedulerInfo := commonTypes.InstanceSpecification{ + InstanceID: podName, + RuntimeID: constant.ModuleScheduler, + DataSystemHost: "", + RuntimeAddress: fmt.Sprintf("%s:%s", podIP, config.GlobalConfig.ModuleConfig.ServicePort), + InstanceStatus: commonTypes.InstanceStatus{ + Code: int32(constant.KernelInstanceStatusRunning), + }, + } + value, err := json.Marshal(schedulerInfo) + if err != nil { + return "", "", err + } + return key, string(value), nil +} + +func validateEnvs(clusterID, nodeIP, podName, podIP string) error { + if clusterID == "" || nodeIP == "" || podName == "" || podIP == "" { + log.GetLogger().Errorf("can not find envs, clusterID %s, nodeIP %s podName %s podIP %s", + clusterID, nodeIP, podName, podIP) + return fmt.Errorf("can not find envs") + } + return nil +} + +func contendFilterForInstance(key, value []byte) bool { + _, err := commonUtils.GetModuleSchedulerInfoFromEtcdKey(string(key)) + if err != nil { + return true + } + return false +} + +func putInsSpecForInstanceKey(locker *etcd3.EtcdLocker) error { + lockedKey := locker.GetLockedKey() + log.GetLogger().Infof("start to put insSpec for instance key %s", lockedKey) + if len(lockedKey) == 0 { + log.GetLogger().Errorf("failed to get locked key") + return errors.New("locked key is empty") + } + if selfInstanceSpec == nil { + log.GetLogger().Errorf("failed to get insSpec of this scheduler %s", SelfInstanceID) + return errors.New("insSpec not found") + } + selfInsSpecData, err := json.Marshal(selfInstanceSpec) + if err != nil { + log.GetLogger().Errorf("failed to marshal insSpec error %s", err.Error()) + return err + } + if err = processEtcdPut(locker.EtcdClient, lockedKey, string(selfInsSpecData)); err != nil { + log.GetLogger().Errorf("failed to put insSpec for instance key into etcd error %s", err.Error()) + return err + } + log.GetLogger().Infof("succeed to put insSpec for instance key %s", lockedKey) + return nil +} + +func delInsSpecForInstanceKey(locker *etcd3.EtcdLocker) error { + lockedKey := locker.GetLockedKey() + log.GetLogger().Infof("start to clean insSpec for instance key %s", lockedKey) + if len(lockedKey) == 0 { + log.GetLogger().Errorf("failed to get locked key") + return errors.New("locked key is empty") + } + if exist, err := isKeyExist(locker.EtcdClient, lockedKey); err != nil || !exist { + return fmt.Errorf("key not exist or get error %s, no need clean it", err.Error()) + } + if err := processEtcdPut(locker.EtcdClient, lockedKey, ""); err != nil { + log.GetLogger().Errorf("failed to clean insSpec for instance key in etcd error %s", err.Error()) + return err + } + Registered = false + RegisterKey = "" + log.GetLogger().Infof("succeed to clean insSpec for instance key %s", lockedKey) + return nil +} + +func unsetInstanceRegister() { + Registered = false + RegisterKey = "" + log.GetLogger().Warnf("locker of key %s failed, unset register status", RegisterKey) +} + +func processEtcdPut(client *etcd3.EtcdClient, key, value string) error { + ctx := etcd3.CreateEtcdCtxInfoWithTimeout(context.Background(), etcd3.DurationContextTimeout) + err := client.Put(ctx, key, value) + if err != nil { + log.GetLogger().Errorf("failed to put key %s value %s to etcd %s error %s", key, value, err.Error()) + return err + } + return nil +} + +func isKeyExist(client *etcd3.EtcdClient, key string) (bool, error) { + ctx := etcd3.CreateEtcdCtxInfoWithTimeout(context.Background(), etcd3.DurationContextTimeout) + rsp, err := client.Get(ctx, key) + if err != nil { + log.GetLogger().Errorf("failed to check key %s exist error %s", key, err.Error()) + return false, err + } + if len(rsp.Kvs) == 0 { + log.GetLogger().Warnf("locker key has been deleted, skip") + return false, nil + } + return true, nil +} diff --git a/yuanrong/pkg/functionscaler/selfregister/selfregister_test.go b/yuanrong/pkg/functionscaler/selfregister/selfregister_test.go new file mode 100644 index 0000000..28f307f --- /dev/null +++ b/yuanrong/pkg/functionscaler/selfregister/selfregister_test.go @@ -0,0 +1,253 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package selfregister + +import ( + "errors" + "fmt" + "os" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + "yuanrong/pkg/common/faas_common/etcd3" + commontypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" +) + +func setEnv() { + os.Setenv("CLUSTER_ID", "cluster001") + os.Setenv("NODE_IP", "127.0.0.1") + os.Setenv("POD_NAME", "faas-scheduler-5b5886db99-66gn8") + os.Setenv("POD_IP", "127.0.0.1") +} + +func cleanEnv() { + os.Setenv("CLUSTER_ID", "") + os.Setenv("NODE_IP", "") + os.Setenv("POD_NAME", "") + os.Setenv("POD_IP", "") +} + +func TestGetServiceKeyAndValue(t *testing.T) { + convey.Convey("test getInstanceKeyAndValue success", t, func() { + convey.Convey("success", func() { + setEnv() + defer cleanEnv() + rawGConfig := config.GlobalConfig + config.GlobalConfig = types.Configuration{ + ModuleConfig: &types.ModuleConfig{ + ServicePort: "8888", + }, + } + defer func() { + config.GlobalConfig = rawGConfig + }() + key, _, err := getInstanceKeyAndValue() + convey.So(err, convey.ShouldBeNil) + convey.So(key, convey.ShouldEqual, + "/sn/faas-scheduler/instances/cluster001/127.0.0.1/faas-scheduler-5b5886db99-66gn8") + }) + convey.Convey("failed", func() { + os.Setenv("NODE_IP", "") + _, _, err := getInstanceKeyAndValue() + convey.So(err, convey.ShouldNotBeNil) + }) + + }) +} + +func TestRegisterToEtcd(t *testing.T) { + config.GlobalConfig.SchedulerDiscovery = &types.SchedulerDiscovery{RegisterMode: types.RegisterTypeSelf} + enableRollout := config.GlobalConfig.EnableRollout + patches := []*gomonkey.Patches{ + gomonkey.ApplyGlobalVar(&maxContendTime, 1), + gomonkey.ApplyGlobalVar(&contendWaitInterval, 100*time.Millisecond), + gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{Client: &clientv3.Client{}} + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + config.GlobalConfig.EnableRollout = enableRollout + config.GlobalConfig.SchedulerDiscovery.RegisterMode = types.RegisterTypeSelf + }() + convey.Convey("test RegisterToEtcd", t, func() { + config.GlobalConfig.SchedulerDiscovery.RegisterMode = types.RegisterTypeSelf + config.GlobalConfig.EnableRollout = false + convey.Convey("baseline", func() { + p := gomonkey.ApplyFunc(getInstanceKeyAndValue, func() (string, string, error) { + return "a", "b", nil + }) + defer p.Reset() + p2 := gomonkey.ApplyFunc((*etcd3.EtcdRegister).Register, func(_ *etcd3.EtcdRegister) error { + return nil + }) + defer p2.Reset() + ch := make(chan struct{}) + err := RegisterToEtcd(ch) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("get key failed", func() { + p := gomonkey.ApplyFunc(getInstanceKeyAndValue, func() (string, string, error) { + return "", "", fmt.Errorf("error") + }) + defer p.Reset() + ch := make(chan struct{}) + err := RegisterToEtcd(ch) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("register failed", func() { + p := gomonkey.ApplyFunc(getInstanceKeyAndValue, func() (string, string, error) { + return "a", "b", nil + }) + defer p.Reset() + p2 := gomonkey.ApplyFunc((*etcd3.EtcdRegister).Register, func(_ *etcd3.EtcdRegister) error { + return fmt.Errorf("error") + }) + defer p2.Reset() + ch := make(chan struct{}) + err := RegisterToEtcd(ch) + convey.So(err, convey.ShouldNotBeNil) + }) + config.GlobalConfig.SchedulerDiscovery.RegisterMode = types.RegisterTypeContend + convey.Convey("register by contend", func() { + patches1 := []*gomonkey.Patches{ + gomonkey.ApplyFunc((*etcd3.EtcdClient).Get, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return &clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{ + {Key: []byte("/sn/faas-scheduler/instances/cluster1/node1/aaa"), Value: []byte("invalid value")}, + {Key: []byte("/sn/faas-scheduler/instances/cluster1/node1/bbb")}, + }, + }, nil + }), + gomonkey.ApplyFunc((*etcd3.EtcdClient).Put, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, value string, opts ...clientv3.OpOption) error { + return nil + }), + gomonkey.ApplyFunc((*etcd3.EtcdClient).Delete, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, opts ...clientv3.OpOption) error { + return nil + }), + gomonkey.ApplyFunc((*etcd3.EtcdLocker).TryLock, func(_ *etcd3.EtcdLocker, key string) error { + return nil + }), + gomonkey.ApplyFunc((*etcd3.EtcdLocker).GetLockedKey, func(_ *etcd3.EtcdLocker) string { + return "/sn/faas-scheduler/instances/cluster1/node1/aaa" + }), + } + defer func() { + for _, p := range patches1 { + p.Reset() + } + }() + SetSelfInstanceSpec(&commontypes.InstanceSpecification{}) + ch := make(chan struct{}) + err := RegisterToEtcd(ch) + convey.So(err, convey.ShouldBeNil) + close(ch) + time.Sleep(200 * time.Millisecond) + }) + convey.Convey("register by contend failed", func() { + maxContendTime = 1 + selfLocker.LockedKey = "" + selfInstanceSpec = nil + var lockErr error + patches1 := []*gomonkey.Patches{ + gomonkey.ApplyFunc((*etcd3.EtcdClient).Get, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return &clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{ + {Key: []byte("/sn/faas-scheduler/instances/cluster1/node1/aaa"), Value: []byte("invalid value")}, + {Key: []byte("/sn/faas-scheduler/instances/cluster1/node1/bbb")}, + }, + }, nil + }), + gomonkey.ApplyFunc((*etcd3.EtcdClient).Put, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, value string, opts ...clientv3.OpOption) error { + return nil + }), + gomonkey.ApplyFunc((*etcd3.EtcdClient).Delete, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, opts ...clientv3.OpOption) error { + return nil + }), + gomonkey.ApplyFunc((*etcd3.EtcdLocker).TryLock, func(_ *etcd3.EtcdLocker, key string) error { + return lockErr + }), + } + defer func() { + for _, p := range patches1 { + p.Reset() + } + }() + ch := make(chan struct{}) + lockErr = errors.New("some error") + err := RegisterToEtcd(ch) + convey.So(err, convey.ShouldNotBeNil) + lockErr = nil + err = RegisterToEtcd(ch) + convey.So(err, convey.ShouldNotBeNil) + os.Setenv(CurrentVersionEnvKey, "blue") + config.GlobalConfig.EnableRollout = true + err = RegisterToEtcd(ch) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestPutInsSpecForInstanceKey(t *testing.T) { + var ( + putErr error + lockedKey string + ) + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc((*etcd3.EtcdClient).Put, func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, + etcdKey string, value string, opts ...clientv3.OpOption) error { + return putErr + }), + gomonkey.ApplyFunc((*etcd3.EtcdLocker).GetLockedKey, func(_ *etcd3.EtcdLocker) string { + return lockedKey + }), + } + defer func() { + for _, p := range patches { + p.Reset() + } + }() + convey.Convey("Test PutInsSpecForInstanceKey", t, func() { + locker := &etcd3.EtcdLocker{EtcdClient: &etcd3.EtcdClient{}} + putErr = errors.New("some error") + err := putInsSpecForInstanceKey(locker) + convey.So(err, convey.ShouldNotBeNil) + lockedKey = "testKey" + err = putInsSpecForInstanceKey(locker) + convey.So(err, convey.ShouldNotBeNil) + selfInstanceSpec = &commontypes.InstanceSpecification{} + err = putInsSpecForInstanceKey(locker) + convey.So(err, convey.ShouldNotBeNil) + putErr = nil + err = putInsSpecForInstanceKey(locker) + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/pkg/functionscaler/signalmanager/signalmanager.go b/yuanrong/pkg/functionscaler/signalmanager/signalmanager.go new file mode 100644 index 0000000..b81030c --- /dev/null +++ b/yuanrong/pkg/functionscaler/signalmanager/signalmanager.go @@ -0,0 +1,247 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package signalmanager - +package signalmanager + +import ( + "encoding/json" + "strings" + "sync" + "time" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/aliasroute" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/urnutils" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/uuid" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/registry" + "yuanrong/pkg/functionscaler/types" +) + +// signalInstance - +type signalInstance struct { + *types.Instance + signalProcessors map[int]*signalProcessor + Logger api.FormatLogger +} + +// PrepareSchedulerArg - +func PrepareSchedulerArg() ([]byte, error) { + schedulerInfo := registry.GlobalRegistry.FaaSSchedulerRegistry.GetSchedulerInfo() + schedulerData, err := json.Marshal(schedulerInfo) + if err != nil { + return nil, err + } + return schedulerData, nil +} + +type getDataFunc func() ([]byte, error) + +type signalProcessor struct { + InstanceId string + HasSignal bool + IsRunning bool + StopChan chan struct{} + SignalNo int + TenantId string + getDataFunc + sync.RWMutex + Logger api.FormatLogger + killFunc func(instanceID string, payload []byte) error +} + +// manager - +type manager struct { + instances map[string]*signalInstance + lock sync.RWMutex + Logger api.FormatLogger + killFunc func(instanceID string, signal int, payload []byte) error +} + +var signalManager *manager +var once sync.Once + +// GetSignalManager - +func GetSignalManager() *manager { + once.Do( + func() { + signalManager = &manager{ + instances: make(map[string]*signalInstance), + Logger: log.GetLogger(), + } + }) + return signalManager +} + +// SetKillFunc - +func (sm *manager) SetKillFunc(killFunc func(instanceID string, signal int, payload []byte) error) { + sm.killFunc = killFunc +} + +// SignalInstance - +func (sm *manager) SignalInstance(instance *types.Instance, signalNo int) { + if config.GlobalConfig.InstanceOperationBackend == constant.BackendTypeFG { + return + } + needProcessSignal := map[int]struct{}{ + constant.KillSignalAliasUpdate: {}, + constant.KillSignalFaaSSchedulerUpdate: {}, + } + + if _, ok := needProcessSignal[signalNo]; !ok { + sm.Logger.Warnf("no need process this signalNo: %d, instanceId: %sm", signalNo, instance.InstanceID) + return + } + sm.lock.Lock() + defer sm.lock.Unlock() + if sm.killFunc == nil { + sm.Logger.Errorf("killFunc not set") + return + } + sInstance, ok := sm.instances[instance.InstanceID] + if !ok { + tenantId := urnutils.GetTenantFromFuncKey(instance.FuncKey) + if tenantId == "" { + sm.Logger.Errorf("instance: tenantId parse failed, funcKey: %s, instanceId is %s", + instance.FuncKey, instance.InstanceID) + return + } + + sInstance = &signalInstance{ + Instance: instance, + Logger: log.GetLogger().With(zap.Any("funcKey", instance.FuncKey), + zap.Any("instanceId", instance.InstanceID), zap.Any("tenantId", tenantId)), + signalProcessors: make(map[int]*signalProcessor, 2), + } + sInstance.signalProcessors[constant.KillSignalAliasUpdate] = &signalProcessor{ + InstanceId: instance.InstanceID, + StopChan: make(chan struct{}), + SignalNo: constant.KillSignalAliasUpdate, + TenantId: tenantId, + getDataFunc: func() ([]byte, error) { + return aliasroute.MarshalTenantAliasList(tenantId) + }, + killFunc: func(instanceID string, payload []byte) error { + return sm.killFunc(instanceID, constant.KillSignalAliasUpdate, payload) + }, + RWMutex: sync.RWMutex{}, + Logger: sInstance.Logger.With(zap.Any("signal", constant.KillSignalAliasUpdate), + zap.Any("update alias", "")), + } + sInstance.signalProcessors[constant.KillSignalFaaSSchedulerUpdate] = &signalProcessor{ + InstanceId: instance.InstanceID, + StopChan: make(chan struct{}), + SignalNo: constant.KillSignalFaaSSchedulerUpdate, + getDataFunc: PrepareSchedulerArg, + killFunc: func(instanceID string, payload []byte) error { + return sm.killFunc(instanceID, constant.KillSignalFaaSSchedulerUpdate, payload) + }, + RWMutex: sync.RWMutex{}, + Logger: sInstance.Logger.With(zap.Any("signal", constant.KillSignalFaaSSchedulerUpdate), + zap.Any("update faasscheduler", "")), + } + sm.instances[instance.InstanceID] = sInstance + } + + processor, ok := sInstance.signalProcessors[signalNo] + if !ok { + sInstance.Logger.Warnf("abnormal!, no signalNo: %d in processors", signalNo) // 通常不会走到这里 + return + } + processor.Lock() + defer processor.Unlock() + processor.HasSignal = true + if !processor.IsRunning { + processor.IsRunning = true + go processor.signalInstance(uuid.New().String()) + } +} + +// RemoveInstance - +func (sm *manager) RemoveInstance(instanceId string) { + sm.lock.Lock() + defer sm.lock.Unlock() + sInstance, ok := sm.instances[instanceId] + if !ok { + return + } + for _, p := range sInstance.signalProcessors { + utils.SafeCloseChannel(p.StopChan) + } + sm.Logger.Infof("remove instance: %s", instanceId) + delete(sm.instances, instanceId) +} + +func (si *signalProcessor) signalInstance(randomId string) { + logger := si.Logger.With(zap.Any("uuid", randomId)) + isRetry := false + retryInterval := 100 * time.Millisecond // 间隔时间初始值 + logger.Infof("begin signal instance") + defer logger.Infof("signal instance over") + for { + si.Lock() + if !si.HasSignal && !isRetry { + si.IsRunning = false + si.Unlock() + return + } + si.HasSignal = false + si.Unlock() + + select { + case <-si.StopChan: + si.Lock() + si.IsRunning = false + si.HasSignal = false + si.Unlock() + logger.Infof("instance removed, exit signalProcessor") + return + default: + data, err := si.getDataFunc() + if err != nil { + logger.Errorf("get data for signal instance failed, err: %s", err.Error()) + isRetry = false + continue + } + if err := si.killFunc(si.InstanceId, data); err != nil { + logger.Errorf("failed to signal instance, error:%s", err.Error()) + // instance not found, the instance may have been killed + if strings.Contains(err.Error(), "instance not found") { + isRetry = false + GetSignalManager().RemoveInstance(si.InstanceId) + continue + } + time.Sleep(retryInterval) + retryInterval *= 2 // 翻倍 + if retryInterval >= 5*time.Minute { // 间隔时间最大值 + retryInterval = 5 * time.Minute // 间隔时间最大值 + } + isRetry = true + } else { + retryInterval = 100 * time.Millisecond // 间隔时间初始值 + isRetry = false + } + } + + } +} diff --git a/yuanrong/pkg/functionscaler/signalmanager/signalmanager_test.go b/yuanrong/pkg/functionscaler/signalmanager/signalmanager_test.go new file mode 100644 index 0000000..60788a7 --- /dev/null +++ b/yuanrong/pkg/functionscaler/signalmanager/signalmanager_test.go @@ -0,0 +1,317 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package signalmanager - +package signalmanager + +import ( + "fmt" + "reflect" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/aliasroute" + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/uuid" + "yuanrong/pkg/functionscaler/types" +) + +// 测试SignalManager初始化 +func TestGetSignalManager(t *testing.T) { + convey.Convey("TestGetSignalManager", t, func() { + // 测试SignalManager单例 + sm1 := GetSignalManager() + sm2 := GetSignalManager() + convey.So(sm1 == sm2, convey.ShouldBeTrue) + convey.So(sm1.Logger, convey.ShouldNotBeNil) + convey.So(sm2.instances, convey.ShouldNotBeNil) + }) +} + +func TestSignalManager_SetKillFunc(t *testing.T) { + convey.Convey("TestGetSignalManager", t, func() { + // 测试SignalManager单例 + GetSignalManager().SetKillFunc(func(_ string, value int, _ []byte) error { + if value == 0 { + return nil + } + return fmt.Errorf("%d", value) + }) + sm := GetSignalManager() + + convey.So(sm == nil, convey.ShouldBeFalse) + convey.So(sm.killFunc == nil, convey.ShouldBeFalse) + convey.So(sm.killFunc("", 0, nil), convey.ShouldBeNil) + convey.So(sm.killFunc("", 199, nil), convey.ShouldResemble, fmt.Errorf("199")) + }) +} + +func TestSignalManager_SignalInstance(t *testing.T) { + convey.Convey("SignalInstance", t, func() { + GetSignalManager().instances = make(map[string]*signalInstance) + result := make(map[string]struct{}) + var lock sync.RWMutex + defer gomonkey.ApplyPrivateMethod(reflect.TypeOf(&signalProcessor{}), "signalInstance", func(sp *signalProcessor, _ string) { + lock.Lock() + result[fmt.Sprintf("%s_%d", sp.InstanceId, sp.SignalNo)] = struct{}{} + lock.Unlock() + }).Reset() + convey.Convey("SignalInstance simple", func() { + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal0", FuncKey: "12345678901234561234567890123456/hello/$latest"}, 0) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal1", FuncKey: "12345678901234561234567890123456/hello/$latest"}, 1) + convey.So(len(GetSignalManager().instances) == 0, convey.ShouldBeTrue) + }) + convey.Convey("SignalInstance complex", func() { + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal2"}, constant.KillSignalAliasUpdate) + convey.So(len(GetSignalManager().instances), convey.ShouldEqual, 0) + + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal2", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal2", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalFaaSSchedulerUpdate) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal3", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal3", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalFaaSSchedulerUpdate) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal4", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal4", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + time.Sleep(500 * time.Millisecond) + convey.So(len(GetSignalManager().instances), convey.ShouldEqual, 3) + processorStrArr := []string{"signal2_64", "signal2_72", "signal3_64", "signal3_72", "signal4_64"} + fmt.Printf("result is %v\n", result) + for _, s := range processorStrArr { + _, ok := result[s] + convey.So(ok, convey.ShouldBeTrue) + } + convey.So(len(result), convey.ShouldEqual, 5) + }) + }) +} + +func TestSignalManager_RemoveInstance(t *testing.T) { + convey.Convey("RemoveInstance", t, func() { + GetSignalManager().instances = make(map[string]*signalInstance) + result := make(map[string]struct{}) + wg := sync.WaitGroup{} + lock := sync.Mutex{} + defer gomonkey.ApplyPrivateMethod(reflect.TypeOf(&signalProcessor{}), "signalInstance", func(sp *signalProcessor, _ string) { + wg.Add(1) + for { + select { + case <-sp.StopChan: + lock.Lock() + result[fmt.Sprintf("%s_%d", sp.InstanceId, sp.SignalNo)] = struct{}{} + lock.Unlock() + wg.Done() + return + } + } + }).Reset() + convey.Convey("RemoveInstance complex", func() { + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal2", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal2", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalFaaSSchedulerUpdate) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal3", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal3", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalFaaSSchedulerUpdate) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal4", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal4", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + + GetSignalManager().RemoveInstance("signal2") + convey.So(len(GetSignalManager().instances), convey.ShouldEqual, 2) + time.Sleep(500 * time.Millisecond) + var ok bool + _, ok = result["signal2_64"] + convey.So(ok, convey.ShouldBeTrue) + _, ok = result["signal2_72"] + convey.So(ok, convey.ShouldBeTrue) + GetSignalManager().RemoveInstance("signal3") + GetSignalManager().RemoveInstance("signal4") + wg.Wait() + convey.So(len(GetSignalManager().instances), convey.ShouldEqual, 0) + convey.So(len(result), convey.ShouldEqual, 5) + }) + }) +} + +func TestSignalProcessor_signalInstance(t *testing.T) { + convey.Convey("signalInstance", t, func() { + sp := &signalProcessor{ + InstanceId: "mock", + HasSignal: false, + IsRunning: false, + StopChan: make(chan struct{}), + SignalNo: 64, + getDataFunc: func() ([]byte, error) { + return []byte("aaa"), nil + }, + RWMutex: sync.RWMutex{}, + Logger: log.GetLogger(), + } + + result := []time.Duration{} + defer gomonkey.ApplyFunc(time.Sleep, func(duration time.Duration) { + result = append(result, duration) + }).Reset() + + failTimes := 0 + totalFailTimes := 100 + sp.killFunc = func(string, []byte) error { + if failTimes < totalFailTimes { + failTimes++ + return fmt.Errorf("%d times failed", failTimes) + } + return nil + } + sp.HasSignal = true + sp.signalInstance("111") + convey.So(len(result), convey.ShouldEqual, 100) + convey.So(sp.HasSignal, convey.ShouldEqual, false) + convey.So(sp.IsRunning, convey.ShouldEqual, false) + + sleeptTime := 100 * time.Millisecond + for i := 0; i < 99; i++ { + convey.So(result[i], convey.ShouldEqual, sleeptTime) + sleeptTime *= 2 + if sleeptTime > 5*time.Minute { + sleeptTime = 5 * time.Minute + } + } + failTimes = 0 + sp.getDataFunc = func() ([]byte, error) { + return nil, fmt.Errorf("error") + } + + sp.signalInstance("222") + convey.So(failTimes, convey.ShouldEqual, 0) + }) +} + +func TestSignalManager_killFunc_return_instanceNotExist(t *testing.T) { + convey.Convey("killFunc_return_instanceNotExist", t, func() { + flag := false + GetSignalManager().SetKillFunc(func(string, int, []byte) error { + flag = true + return fmt.Errorf("instance not found, the instance may have been killed") + }) + + GetSignalManager().SignalInstance(&types.Instance{ + InstanceID: "1111", + InstanceName: "1111", + FuncKey: "12345678901234561234567890123456/hello/$latest", + }, constant.KillSignalAliasUpdate) + time.Sleep(100 * time.Millisecond) + _, ok := GetSignalManager().instances["1111"] + convey.So(ok, convey.ShouldBeFalse) + convey.So(flag, convey.ShouldBeTrue) + }) +} + +func TestSignalManager_complex(t *testing.T) { + GetSignalManager().instances = make(map[string]*signalInstance) + defer gomonkey.ApplyFunc(PrepareSchedulerArg, func() ([]byte, error) { + return nil, nil + }).Reset() + + defer gomonkey.ApplyFunc(aliasroute.MarshalTenantAliasList, func(string) ([]byte, error) { + return nil, nil + }).Reset() + + var okFlag atomic.Bool + + var killFuncExecuted atomic.Bool + var blockFlag atomic.Bool + killFunc := func(string, int, []byte) error { + killFuncExecuted.Store(true) + if okFlag.Load() { + return nil + } + for blockFlag.Load() { + + } + return fmt.Errorf("") + } + + uuidCount := 0 + defer gomonkey.ApplyMethod(reflect.TypeOf(uuid.New()), "String", func(_ uuid.RandomUUID) string { + uuidCount++ + return "" + }).Reset() + defer gomonkey.ApplyFunc(time.Sleep, func(duration time.Duration) { + + }).Reset() + + convey.Convey("concurrency processor", t, func() { + GetSignalManager().SetKillFunc(killFunc) + killFuncExecuted.Store(false) + + okFlag.Store(false) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal0", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + + convey.So(GetSignalManager().instances["signal0"], convey.ShouldNotBeNil) + sp := GetSignalManager().instances["signal0"].signalProcessors[constant.KillSignalAliasUpdate] + convey.So(sp, convey.ShouldNotBeNil) + + for !killFuncExecuted.Load() { + } + blockFlag.Store(true) + fmt.Printf("executed is %v\n", killFuncExecuted.Load()) + sp.Lock() + convey.So(sp.IsRunning, convey.ShouldEqual, true) + convey.So(sp.HasSignal, convey.ShouldEqual, false) + convey.So(uuidCount, convey.ShouldEqual, 1) + + killFuncExecuted.Store(false) + blockFlag.Store(false) + sp.Unlock() + + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal0", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + blockFlag.Store(true) + <-time.After(200 * time.Millisecond) + sp.Lock() + convey.So(sp.IsRunning, convey.ShouldEqual, true) + convey.So(sp.HasSignal, convey.ShouldEqual, false) + convey.So(uuidCount, convey.ShouldEqual, 1) + sp.Unlock() + blockFlag.Store(false) + okFlag.Store(true) + <-time.After(200 * time.Millisecond) + convey.So(sp.IsRunning, convey.ShouldEqual, false) + convey.So(sp.HasSignal, convey.ShouldEqual, false) + }) + + convey.Convey("stop chan delete", t, func() { + okFlag.Store(false) + blockFlag.Store(false) + GetSignalManager().SignalInstance(&types.Instance{InstanceID: "signal0", FuncKey: "12345678901234561234567890123456/hello/$latest"}, constant.KillSignalAliasUpdate) + + convey.So(GetSignalManager().instances["signal0"], convey.ShouldNotBeNil) + sp := GetSignalManager().instances["signal0"].signalProcessors[constant.KillSignalAliasUpdate] + convey.So(sp, convey.ShouldNotBeNil) + <-time.After(100 * time.Millisecond) + blockFlag.Store(true) + sp.Lock() + convey.So(sp.IsRunning, convey.ShouldEqual, true) + convey.So(sp.HasSignal, convey.ShouldEqual, false) + sp.Unlock() + + blockFlag.Store(false) + GetSignalManager().RemoveInstance("signal0") + <-time.After(100 * time.Millisecond) + convey.So(sp.IsRunning, convey.ShouldEqual, false) + convey.So(sp.HasSignal, convey.ShouldEqual, false) + }) +} diff --git a/yuanrong/pkg/functionscaler/state/state.go b/yuanrong/pkg/functionscaler/state/state.go new file mode 100644 index 0000000..a2d60ce --- /dev/null +++ b/yuanrong/pkg/functionscaler/state/state.go @@ -0,0 +1,236 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package state is used to save and restore the scheduler state. +package state + +import ( + "encoding/json" + "fmt" + "os" + "sync" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/state" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" +) + +// SchedulerState add the status to be saved here. +type SchedulerState struct { + InstancePool map[string]*types.InstancePoolState `json:"InstancePool" valid:"optional"` // funcKey - instance +} + +// SchedulerEtcdRevision - +type SchedulerEtcdRevision struct { + // it means etcd key /sn/instance/ revision, all operation on the key will update LastInstanceRevision + LastInstanceRevision int32 `json:"lastInstanceRevision"` +} + +const defaultHandlerQueueSize = 2000 + +var ( + schedulerStateLock sync.RWMutex + // schedulerState - + schedulerState = &SchedulerState{ + InstancePool: make(map[string]*types.InstancePoolState), + } + schedulerEtcdRev *SchedulerEtcdRevision + schedulerHandlerQueue *state.Queue + stateKey = "" + stateRevKey = "" +) + +// RecoverConfig recover config +func RecoverConfig() error { + log.GetLogger().Infof("GlobalConfig recovered") + return nil +} + +func init() { + schedulerInstanceIDSelf := os.Getenv("INSTANCE_ID") + stateKey = "/faas/state/recover/faasscheduler/" + schedulerInstanceIDSelf + stateRevKey = "/faas/state/revision/faasscheduler/" + schedulerInstanceIDSelf +} + +// InitState - +func InitState() { + if config.GlobalConfig.StateDisable { + log.GetLogger().Warnf("state is disable, skip init state") + return + } + for k, v := range schedulerState.InstancePool { + log.GetLogger().Infof("recover state: %s: %v", k, *v) + } + + if schedulerHandlerQueue != nil { + return + } + schedulerHandlerQueue = state.NewStateQueue(defaultHandlerQueueSize) + if schedulerHandlerQueue == nil { + return + } + go schedulerHandlerQueue.Run(updateState) +} + +// SetState - +func SetState(byte []byte) error { + return json.Unmarshal(byte, schedulerState) +} + +// RecoverStateRev - +func RecoverStateRev() { + if schedulerEtcdRev == nil && schedulerHandlerQueue != nil { + stateRevBytes, err := schedulerHandlerQueue.GetState(stateRevKey) + if err != nil { + log.GetLogger().Warnf("failed to get stateRev from etcd, err:%s", err.Error()) + return + } + if err := json.Unmarshal(stateRevBytes, &schedulerEtcdRev); err != nil { + log.GetLogger().Warnf("failed to unmarshal stateRev, err:%s", err.Error()) + return + } + } +} + +// GetState - +func GetState() *SchedulerState { + schedulerStateLock.RLock() + defer schedulerStateLock.RUnlock() + return schedulerState +} + +// GetStateRev - +func GetStateRev() *SchedulerEtcdRevision { + schedulerStateLock.RLock() + defer schedulerStateLock.RUnlock() + return schedulerEtcdRev +} + +// GetStateByte is used to obtain the local state +func GetStateByte() ([]byte, error) { + if schedulerHandlerQueue == nil { + return nil, fmt.Errorf("schedulerHandlerQueue is not initialized") + } + schedulerStateLock.RLock() + defer schedulerStateLock.RUnlock() + stateBytes, err := schedulerHandlerQueue.GetState(stateKey) + if err != nil { + return nil, err + } + log.GetLogger().Debugf("get state from etcd schedulerState: %v", string(stateBytes)) + return stateBytes, nil +} + +func updateState(value interface{}, tags ...string) { + if schedulerHandlerQueue == nil { + log.GetLogger().Errorf("scheduler state schedulerHandlerQueue is nil") + return + } + var ( + stateBytes []byte + updateKey string + err error + ) + schedulerStateLock.Lock() + defer schedulerStateLock.Unlock() + switch v := value.(type) { + case *types.InstancePoolStateInput: + // tags[0] as opt + if len(tags) <= 0 { + log.GetLogger().Errorf("failed to operate the instancePool, tags is empty") + return + } + switch tags[0] { + case types.StateUpdate: + log.GetLogger().Infof("update scheduler state for instance queue") + updateInstancePool(v) + case types.StateDelete: + log.GetLogger().Info("delete scheduler state for instance queue") + deleteInstancePool(v) + default: + log.GetLogger().Errorf("failed to operate the instancePool, opt is error %s", tags[0]) + return + } + if stateBytes, err = json.Marshal(schedulerState); err != nil { + log.GetLogger().Errorf("get scheduler state error %s", err.Error()) + return + } + updateKey = stateKey + case int32: + if schedulerEtcdRev == nil { + schedulerEtcdRev = &SchedulerEtcdRevision{} + } + schedulerEtcdRev.LastInstanceRevision = v + if stateBytes, err = json.Marshal(schedulerEtcdRev); err != nil { + log.GetLogger().Errorf("get scheduler state error %s", err.Error()) + return + } + updateKey = stateRevKey + default: + log.GetLogger().Warnf("unknown data type for scheduler state") + return + } + if len(stateBytes) <= 0 || updateKey == "" { + return + } + if err = schedulerHandlerQueue.SaveState(stateBytes, updateKey); err != nil { + log.GetLogger().Errorf("save scheduler state error: %s", err.Error()) + return + } + log.GetLogger().Infof("update scheduler state successfully") +} + +// Update is used to write scheduler state to the cache queue +func Update(value interface{}, tags ...string) { + if schedulerHandlerQueue == nil { + return + } + if err := schedulerHandlerQueue.Push(value, tags...); err != nil { + log.GetLogger().Errorf("failed to push state to state queue: %s", err.Error()) + } +} + +// updateInstancePool adds an element to the queue specified by the tag +func updateInstancePool(value *types.InstancePoolStateInput) { + if value.InstanceType == types.InstanceTypeState { + if schedulerState.InstancePool[value.FuncKey].StateInstance[value.StateID] != nil { + log.GetLogger().Warnf("state, func %s state %s already has instance %s", + value.FuncKey, value.StateID, + schedulerState.InstancePool[value.FuncKey].StateInstance[value.StateID].InstanceID) + } + schedulerState.InstancePool[value.FuncKey].StateInstance[value.StateID] = &types.Instance{ + InstanceStatus: commonTypes.InstanceStatus{Code: value.InstanceStatusCode}, + InstanceType: types.InstanceTypeState, + InstanceID: value.InstanceID, + FuncKey: value.FuncKey, + FuncSig: value.FuncSig, + } + return + } + log.GetLogger().Errorf("failed to update the instancePool") +} + +// deleteInstancePool deletes an element from the queue specified by the tag +func deleteInstancePool(value *types.InstancePoolStateInput) { + if value.StateID != "" && value.InstanceType == types.InstanceTypeState { + delete(schedulerState.InstancePool[value.FuncKey].StateInstance, value.StateID) + log.GetLogger().Infof("state del state %s", value.StateID) + return + } + delete(schedulerState.InstancePool, value.FuncKey) +} diff --git a/yuanrong/pkg/functionscaler/state/state_test.go b/yuanrong/pkg/functionscaler/state/state_test.go new file mode 100644 index 0000000..61a7b0c --- /dev/null +++ b/yuanrong/pkg/functionscaler/state/state_test.go @@ -0,0 +1,192 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package state - +package state + +import ( + "encoding/json" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + clientv3 "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/state" + "yuanrong/pkg/functionscaler/types" +) + +func TestInitState(t *testing.T) { + convey.Convey("InitState success", t, func() { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + InitState() + }) +} + +func TestOptState(t *testing.T) { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "Put", func(_ *etcd3.EtcdClient, + ctxInfo etcd3.EtcdCtxInfo, key string, value string, opts ...clientv3.OpOption) error { + return nil + }).Reset() + InitState() + schedulerState = &SchedulerState{ + InstancePool: make(map[string]*types.InstancePoolState), + } + stateByte, _ := json.Marshal(schedulerState) + + convey.Convey("set state", t, func() { + err := SetState(stateByte) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("get state", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(schedulerHandlerQueue), "GetState", + func(q *state.Queue, key string) ([]byte, error) { + return stateByte, nil + }).Reset() + ssByte, err := GetStateByte() + outPut := &SchedulerState{} + json.Unmarshal(ssByte, outPut) + convey.So(err, convey.ShouldBeNil) + }) + time.Sleep(50 * time.Millisecond) +} + +func TestUpdateState(t *testing.T) { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "Put", func(_ *etcd3.EtcdClient, + ctxInfo etcd3.EtcdCtxInfo, key string, value string, opts ...clientv3.OpOption) error { + return nil + }).Reset() + // client := mockUtils.FakeLibruntimeSdkClient{} + InitState() + schedulerState = &SchedulerState{ + InstancePool: make(map[string]*types.InstancePoolState), + } + schedulerState.InstancePool["testStateFuncKey"] = &types.InstancePoolState{ + StateInstance: map[string]*types.Instance{ + "testState": {}, + }, + } + stateByte, _ := json.Marshal(schedulerState) + SetState(stateByte) + + convey.Convey("delete funcKey success", t, func() { + Update(&types.InstancePoolStateInput{ + FuncKey: "testFuncKey2", + }, types.StateDelete) + time.Sleep(100 * time.Millisecond) + convey.So(GetState().InstancePool, convey.ShouldNotContainKey, "testFuncKey2") + }) + + convey.Convey("type is error", t, func() { + type custom struct{} + temp1 := *GetState() + Update(&custom{}) + time.Sleep(100 * time.Millisecond) + temp2 := *GetState() + convey.So(temp1, convey.ShouldResemble, temp2) + }) + convey.Convey("instancepoll tags is error", t, func() { + Update(&types.InstancePoolStateInput{ + FuncKey: "testFuncKey3", + InstanceType: types.InstanceTypeScaled, + ResKey: resspeckey.ResSpecKey{CPU: 300, Memory: 128}, + InstanceID: "InstanceID-891011", + }) + time.Sleep(100 * time.Millisecond) + convey.So(GetState().InstancePool, convey.ShouldNotContainKey, "testFuncKey3") + + Update(&types.InstancePoolStateInput{ + FuncKey: "testFuncKey3", + InstanceType: types.InstanceTypeScaled, + ResKey: resspeckey.ResSpecKey{CPU: 300, Memory: 128}, + InstanceID: "InstanceID-891011", + }, "error Opt") + time.Sleep(100 * time.Millisecond) + convey.So(GetState().InstancePool, convey.ShouldNotContainKey, "testFuncKey3") + }) + + convey.Convey("update state instance", t, func() { + Update(&types.InstancePoolStateInput{ + FuncKey: "testStateFuncKey", + StateID: "testState", + InstanceType: types.InstanceTypeState, + ResKey: resspeckey.ResSpecKey{CPU: 300, Memory: 128}, + InstanceID: "InstanceID-891011", + }, types.StateUpdate) + time.Sleep(100 * time.Millisecond) + convey.So(GetState().InstancePool, convey.ShouldContainKey, "testStateFuncKey") + }) + + convey.Convey("delete state instance", t, func() { + Update(&types.InstancePoolStateInput{ + FuncKey: "testStateFuncKey", + StateID: "testState", + InstanceType: types.InstanceTypeState, + ResKey: resspeckey.ResSpecKey{CPU: 300, Memory: 128}, + InstanceID: "InstanceID-891011", + }, types.StateDelete) + time.Sleep(100 * time.Millisecond) + convey.So(len(GetState().InstancePool), convey.ShouldEqual, 1) + }) + + convey.Convey("schedulerHandlerQueue is nil", t, func() { + schedulerHandlerQueue = nil + temp1 := *GetState() + Update(&types.Configuration{ + AutoScaleConfig: types.AutoScaleConfig{ + SLAQuota: 1, + }, + }) + time.Sleep(100 * time.Millisecond) + temp2 := *GetState() + convey.So(temp1, convey.ShouldResemble, temp2) + updateState(&types.Configuration{ + AutoScaleConfig: types.AutoScaleConfig{ + SLAQuota: 1, + }, + }) + time.Sleep(100 * time.Millisecond) + temp2 = *GetState() + convey.So(temp1, convey.ShouldResemble, temp2) + }) + +} + +func TestRecoverStateRev(t *testing.T) { + convey.Convey("RecoverStateRev", t, func() { + schedulerHandlerQueue = &state.Queue{} + defer gomonkey.ApplyMethod(reflect.TypeOf(schedulerHandlerQueue), "GetState", + func(q *state.Queue, key string) ([]byte, error) { + return []byte(`{"lastInstanceRevision": 10086}`), nil + }).Reset() + RecoverStateRev() + stateRev := GetStateRev() + convey.So(stateRev, convey.ShouldNotBeNil) + convey.So(stateRev.LastInstanceRevision, convey.ShouldEqual, 10086) + }) +} diff --git a/yuanrong/pkg/functionscaler/stateinstance/lease.go b/yuanrong/pkg/functionscaler/stateinstance/lease.go new file mode 100644 index 0000000..9839039 --- /dev/null +++ b/yuanrong/pkg/functionscaler/stateinstance/lease.go @@ -0,0 +1,161 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package stateinstance - +package stateinstance + +import ( + "fmt" + "sync" + "time" + + "go.uber.org/zap" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + "yuanrong/pkg/common/uuid" +) + +// DeleteStateInstance - +type DeleteStateInstance func(stateID string, instanceID string) + +// Lease one Leaser has multiple leases +type Lease struct { + ID int + timer *time.Timer // retain lease interval +} + +// Leaser one state instance has one Leaser +type Leaser struct { + Leases map[int]*Lease + mu sync.Mutex + nextID int + maxLeases int + timer *time.Timer // Scale-down Interval Without lease + deleteStateInstance DeleteStateInstance + stateID string + instanceID string + scaleDownWindow time.Duration +} + +// NewLeaser new leaser for the state instance +func NewLeaser(maxLeases int, deleteStateInstance DeleteStateInstance, stateID string, instanceID string, + scaleDownWindow time.Duration) *Leaser { + return &Leaser{ + Leases: make(map[int]*Lease), + maxLeases: maxLeases, + deleteStateInstance: deleteStateInstance, + stateID: stateID, + instanceID: instanceID, + scaleDownWindow: scaleDownWindow, + } +} + +// Recover when the leaser is recovering, start timer for deleting the instance. +// because the timer is set when aquire a lease normally but no aquiring existing. +func (l *Leaser) Recover() { + l.mu.Lock() + defer l.mu.Unlock() + log.GetLogger().Infof("recover leaser timer") + if l.timer != nil { + l.timer.Stop() + } + l.timer = time.AfterFunc(l.scaleDownWindow, func() { + log.GetLogger().Warnf("No lease for %v, scale down, stateKey is %s", l.scaleDownWindow, l.stateID) + l.deleteStateInstance(l.stateID, l.instanceID) + }) +} + +// Terminate when state route is delete, the leaser should be terminate +func (l *Leaser) Terminate() { + l.mu.Lock() + defer l.mu.Unlock() + + log.GetLogger().Infof("terminate leaser %s", l.stateID) + if l.timer != nil { + l.timer.Stop() + l.timer = nil + } +} + +// AcquireLease acquire lease for instance +func (l *Leaser) AcquireLease(leaseInterval time.Duration) (*Lease, error) { + l.mu.Lock() + defer l.mu.Unlock() + + if len(l.Leases) >= l.maxLeases { + return nil, snerror.New(statuscode.StateInstanceNoLease, statuscode.StateInstanceNoLeaseMsg) + } + + lease := &Lease{ID: l.nextID} + lease.timer = time.AfterFunc(leaseInterval, func() { + l.ReleaseLease(lease.ID) + }) + + l.Leases[lease.ID] = lease + l.nextID++ + if l.timer != nil { + l.timer.Stop() + l.timer = nil + } else { + log.GetLogger().Infof("timer is nil") + } + return lease, nil +} + +// ReleaseLease lease expiration or call release lease +func (l *Leaser) ReleaseLease(id int) { + l.mu.Lock() + defer l.mu.Unlock() + logger := log.GetLogger().With(zap.Any("traceID", uuid.New())) + + lease, exists := l.Leases[id] + if exists { + logger.Warnf("Release Lease id %d, stateKey is %s", id, l.stateID) + lease.timer.Stop() + delete(l.Leases, id) + } + + if len(l.Leases) == 0 { + log.GetLogger().Debugf("set timer when lease len = 0") + if l.timer != nil { + l.timer.Stop() + } + l.timer = time.AfterFunc(l.scaleDownWindow, func() { + logger.Warnf("No lease for %v, scale down, stateKey is %s", l.scaleDownWindow, l.stateID) + l.deleteStateInstance(l.stateID, l.instanceID) + }) + } +} + +// RetainLease renew within the lease interval +func (l *Leaser) RetainLease(id int, leaseInterval time.Duration) error { + l.mu.Lock() + defer l.mu.Unlock() + + lease, exists := l.Leases[id] + if !exists { + return fmt.Errorf("lease%d not found", id) + } + + lease.timer.Stop() + lease.timer = time.AfterFunc(leaseInterval, func() { + l.ReleaseLease(id) + }) + + return nil +} diff --git a/yuanrong/pkg/functionscaler/stateinstance/lease_test.go b/yuanrong/pkg/functionscaler/stateinstance/lease_test.go new file mode 100644 index 0000000..27987b7 --- /dev/null +++ b/yuanrong/pkg/functionscaler/stateinstance/lease_test.go @@ -0,0 +1,122 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stateinstance + +import ( + "sync" + "testing" + "time" + + "github.com/smartystreets/goconvey/convey" +) + +func TestNewLeaser(t *testing.T) { + convey.Convey("test NewLeaser", t, func() { + leaser := NewLeaser(10, nil, "testStateID", "testInstanceID", 10*time.Second) + convey.So(leaser, convey.ShouldNotBeNil) + leaser.Recover() + }) +} + +func TestReleaseLease(t *testing.T) { + convey.Convey("test ReleaseLease", t, func() { + timer1 := time.NewTimer(1 * time.Second) + leaser := &Leaser{ + Leases: map[int]*Lease{1: { + timer: timer1, + }, 2: {}}, + mu: sync.Mutex{}, + timer: &time.Timer{}, + stateID: "id", + } + leaser.ReleaseLease(1) + convey.So(len(leaser.Leases), convey.ShouldEqual, 1) + }) +} + +func TestRetainLease(t *testing.T) { + convey.Convey("test RetainLease", t, func() { + timer1 := time.NewTimer(1 * time.Second) + leaser := &Leaser{ + Leases: map[int]*Lease{1: { + timer: timer1, + }, 2: {}}, + mu: sync.Mutex{}, + stateID: "id", + } + err := leaser.RetainLease(1, 1*time.Millisecond) + convey.So(err, convey.ShouldBeNil) + }) +} + +func TestAcquireLease(t *testing.T) { + convey.Convey("test AcquireLease", t, func() { + convey.Convey("max leases", func() { + leaser := &Leaser{ + Leases: map[int]*Lease{1: {}, 2: {}}, + mu: sync.Mutex{}, + timer: &time.Timer{}, + stateID: "id", + maxLeases: 1, + } + _, err := leaser.AcquireLease(1 * time.Millisecond) + convey.So(err, convey.ShouldNotBeNil) + }) + convey.Convey("timer not nil", func() { + timer1 := time.NewTimer(1 * time.Second) + leaser := &Leaser{ + Leases: map[int]*Lease{1: {}, 2: {}}, + mu: sync.Mutex{}, + timer: timer1, + stateID: "id", + maxLeases: 3, + nextID: 2, + } + lease, err := leaser.AcquireLease(1 * time.Millisecond) + convey.So(lease, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }) + convey.Convey("max leases timer is nil", func() { + timer1 := time.NewTimer(1 * time.Second) + leaser := &Leaser{ + Leases: map[int]*Lease{1: {}, 2: {}}, + mu: sync.Mutex{}, + stateID: "id", + maxLeases: 3, + nextID: 2, + timer: timer1, + } + lease, err := leaser.AcquireLease(1 * time.Millisecond) + convey.So(lease, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + }) + }) +} + +func TestTerminate(t *testing.T) { + convey.Convey("test Terminate", t, func() { + timer1 := time.NewTimer(1 * time.Second) + leaser := &Leaser{ + Leases: map[int]*Lease{1: {}, 2: {}}, + mu: sync.Mutex{}, + timer: timer1, + stateID: "id", + } + leaser.Terminate() + convey.So(len(leaser.Leases), convey.ShouldEqual, 2) + }) +} diff --git a/yuanrong/pkg/functionscaler/sts/sensitiveconfig.go b/yuanrong/pkg/functionscaler/sts/sensitiveconfig.go new file mode 100644 index 0000000..a1ed1d3 --- /dev/null +++ b/yuanrong/pkg/functionscaler/sts/sensitiveconfig.go @@ -0,0 +1,114 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package sts provide methods for obtaining sensitive information +package sts + +import ( + "fmt" + + "github.com/json-iterator/go" + "huawei.com/wisesecurity/sts-sdk/pkg/remote" + + "yuanrong/pkg/common/faas_common/logger/log" +) + +// GetEnvMap - environment variables for sensitive configuration items +func GetEnvMap(configs map[string]string) (map[string]string, error) { + envMap := make(map[string]string) + configIDs := getSensitiveValue(configs) + stsConfigs, err := GetSensitiveConfigIDs(configIDs) + if err != nil { + return envMap, err + } + + sensitiveMap, err := ParseStsResponseStrict(stsConfigs) + if err != nil { + return envMap, err + } + + // envMap key indicates the sensitive configuration name, and value indicates the encrypted value, for example, + // common.password.etcd.value: ENC(key=servicekek, value=xxx). + for k, v := range configs { + envMap[k] = sensitiveMap[v] + } + return envMap, nil +} + +// return configIDs slice, such as [Service/WiseFunctionService/common.password.etcd.value/dev +// Service/WiseFunctionService/common.password.iamAuth.value/dev] +func getSensitiveValue(configItems map[string]string) []string { + var configIDs []string + for _, v := range configItems { + configIDs = append(configIDs, v) + } + return configIDs +} + +// ParseStsResponseStrict return key is configID,The value is an encrypted value. +// such as Service/WiseFunctionService/common.password.etcd.value/dev: ENC(key=servicekek, value=xxx) +func ParseStsResponseStrict(respBody *SensitiveConfigResponse) (map[string]string, error) { + if len(respBody.MissingConfigItems) != 0 { + return nil, fmt.Errorf("missing item: %s", respBody.MissingConfigItems) + } + sensitiveMap := make(map[string]string) + for _, configItem := range respBody.ConfigItems { + if configItem.ConfigValue == "" { + return nil, fmt.Errorf("item value is empty, configID: %s", configItem.ConfigID) + } + sensitiveMap[configItem.ConfigID] = configItem.ConfigValue + } + return sensitiveMap, nil +} + +// GetSensitiveConfigIDs - +func GetSensitiveConfigIDs(configIDs []string) (*SensitiveConfigResponse, error) { + log.GetLogger().Info("[sts] start get sensitiveConfig") + size := len(configIDs) + httpRequest := new(remote.StsHttpRequestBuilder).SetMethod("POST").SetPath(SensitiveConfigPath).Build() + httpClient := GetStsHTTPClient() + var stsConfigs = SensitiveConfigResponse{} + + for i := 0; i < size; i += maxConfigIDPerRequest { + var tmpConfigIDs []string + for j := i; j < i+maxConfigIDPerRequest && j < size; j++ { + tmpConfigIDs = append(tmpConfigIDs, configIDs[j]) + } + req := &configIDsReq{ + ConfigIds: tmpConfigIDs, + } + requestJSONByte, err := jsoniter.Marshal(req) + if err != nil { + return nil, fmt.Errorf("marshal failed: %v", err) + } + httpRequest.Body = requestJSONByte + // If the config ID is incorrect, 200 is returned. + // The config ID is in Missing Config Items. If the config ID is empty, 400 is returned. + buf, err := doStsRequest(httpRequest, httpClient) + if err != nil { + return nil, err + } + + var configs SensitiveConfigResponse + err = jsoniter.Unmarshal(buf, &configs) + if err != nil { + return nil, fmt.Errorf("unmarshal failed, error: %v", err) + } + stsConfigs.MissingConfigItems = append(stsConfigs.MissingConfigItems, configs.MissingConfigItems...) + stsConfigs.ConfigItems = append(stsConfigs.ConfigItems, configs.ConfigItems...) + } + return &stsConfigs, nil +} diff --git a/yuanrong/pkg/functionscaler/sts/sensitiveconfig_test.go b/yuanrong/pkg/functionscaler/sts/sensitiveconfig_test.go new file mode 100644 index 0000000..27ce831 --- /dev/null +++ b/yuanrong/pkg/functionscaler/sts/sensitiveconfig_test.go @@ -0,0 +1,76 @@ +package sts + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "huawei.com/wisesecurity/sts-sdk/pkg/auth" + "huawei.com/wisesecurity/sts-sdk/pkg/remote" + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" +) + +type mockHttpBody struct { +} + +func (m *mockHttpBody) Read(p []byte) (n int, err error) { + p = append(p, byte(1)) + return 1, nil +} + +func (m *mockHttpBody) Close() error { + return nil +} + +func TestGetEnvMap(t *testing.T) { + convey.Convey("GetEnvMap", t, func() { + defer gomonkey.ApplyFunc(GetStsHTTPClient, func() *remote.StsKeyHttpClient { + return &remote.StsKeyHttpClient{} + }).Reset() + defer gomonkey.ApplyFunc(stsgoapi.SignRequest, func(httpRequest *remote.StsHttpRequest, providerServiceMeta auth.StsFullServiceMeta) ( + *remote.StsHttpRequest, error) { + return nil, nil + }).Reset() + defer gomonkey.ApplyFunc(io.ReadAll, func(r io.Reader) ([]byte, error) { + return json.Marshal(SensitiveConfigResponse{ConfigItems: []ConfigItem{{ + ConfigID: "config1", + ConfigValue: "configValue1", + }}}) + }).Reset() + convey.Convey("failed", func() { + p1 := gomonkey.ApplyMethod(reflect.TypeOf(&remote.StsKeyHttpClient{}), "SendMessage", + func(_ *remote.StsKeyHttpClient, stsHttpRequest *remote.StsHttpRequest) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusBadRequest, Body: &mockHttpBody{}}, nil + }) + configMap := make(map[string]string) + configMap["configID1"] = "config1" + _, err := GetEnvMap(configMap) + convey.So(err, convey.ShouldNotBeNil) + p1.Reset() + + p2 := gomonkey.ApplyMethod(reflect.TypeOf(&remote.StsKeyHttpClient{}), "SendMessage", + func(_ *remote.StsKeyHttpClient, stsHttpRequest *remote.StsHttpRequest) (*http.Response, error) { + return nil, errors.New("http error") + }) + _, err = GetEnvMap(configMap) + convey.So(err, convey.ShouldNotBeNil) + p2.Reset() + }) + convey.Convey("success", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&remote.StsKeyHttpClient{}), "SendMessage", + func(_ *remote.StsKeyHttpClient, stsHttpRequest *remote.StsHttpRequest) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusOK, Body: &mockHttpBody{}}, nil + }).Reset() + configMap := make(map[string]string) + configMap["configID1"] = "config1" + envMap, err := GetEnvMap(configMap) + convey.So(err, convey.ShouldBeNil) + convey.So(envMap["configID1"], convey.ShouldEqual, "configValue1") + }) + }) +} diff --git a/yuanrong/pkg/functionscaler/sts/sts.go b/yuanrong/pkg/functionscaler/sts/sts.go new file mode 100644 index 0000000..1ecab00 --- /dev/null +++ b/yuanrong/pkg/functionscaler/sts/sts.go @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package sts provide methods for obtaining sensitive information +package sts + +import ( + "fmt" + "io" + "net/http" + "strconv" + "strings" + "sync" + + "huawei.com/wisesecurity/sts-sdk/pkg/remote" + "huawei.com/wisesecurity/sts-sdk/pkg/stsgoapi" + + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + "yuanrong/pkg/functionscaler/config" +) + +var ( + httpClient = &remote.StsKeyHttpClient{} + once sync.Once +) + +func doStsRequest(httpRequest *remote.StsHttpRequest, httpClient *remote.StsKeyHttpClient) ([]byte, error) { + newStsHTTPReq, _ := stsgoapi.SignRequest(httpRequest, serviceMeta) + response, err := httpClient.SendMessage(newStsHTTPReq) + if err != nil { + return nil, fmt.Errorf("send message failed: %s", err.Error()) + } + defer response.Body.Close() + buf, err := io.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("io.ReadAll failed, error: %s, response body is %v", err.Error(), string(buf)) + } + if response.StatusCode == http.StatusOK { + return buf, nil + } + if response.StatusCode/100 == 4 { // 4xx + errString := strings.ReplaceAll(string(buf), `"`, "") + return nil, snerror.New(statuscode.StsConfigErrCode, + "The requested parameter or permission is abnormal, statusCode is "+strconv.Itoa(response.StatusCode)+ + ", err response is "+errString) + } + + // 5xx... etc + return nil, fmt.Errorf("http error, the code is %d, err response is %s", response.StatusCode, string(buf)) +} + +// GetStsHTTPClient create StsKeyHttpClient +func GetStsHTTPClient() *remote.StsKeyHttpClient { + once.Do(func() { + cfg := config.GlobalConfig + rawHTTPClient := remote.NewInstanceOfHttpClient() + httpClient = &remote.StsKeyHttpClient{ + StsServerHost: "http://" + cfg.RawStsConfig.MgmtServerConfig.Domain, + MyServiceMeta: rawHTTPClient.MyServiceMeta, + Cert: rawHTTPClient.Cert, + Signer: rawHTTPClient.Signer, + HttpClient: rawHTTPClient.HttpClient, + VerifyCert: rawHTTPClient.VerifyCert, + } + }) + return httpClient +} diff --git a/yuanrong/pkg/functionscaler/sts/sts_test.go b/yuanrong/pkg/functionscaler/sts/sts_test.go new file mode 100644 index 0000000..3c1ca09 --- /dev/null +++ b/yuanrong/pkg/functionscaler/sts/sts_test.go @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package sts provide methods for obtaining sensitive information +package sts + +import ( + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + "huawei.com/wisesecurity/sts-sdk/pkg/remote" +) + +func TestGetStsHTTPClient(t *testing.T) { + defer gomonkey.ApplyFunc(remote.NewInstanceOfHttpClient, func() *remote.StsKeyHttpClient { + return &remote.StsKeyHttpClient{} + }).Reset() + client := GetStsHTTPClient() + assert.NotNil(t, client) +} diff --git a/yuanrong/pkg/functionscaler/sts/types.go b/yuanrong/pkg/functionscaler/sts/types.go new file mode 100644 index 0000000..496a79c --- /dev/null +++ b/yuanrong/pkg/functionscaler/sts/types.go @@ -0,0 +1,127 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package sts provide methods for obtaining sensitive information +package sts + +import ( + "huawei.com/wisesecurity/sts-sdk/pkg/auth" +) + +const ( + gcmStandardNonceSize = 12 + // CertPath Get certificate interface url + CertPath = "/mgmt-inner/cms-mgmt/v1/deploy/%s/%s/cert-file" + // DateKeyPath Get dataKey interface url + DateKeyPath = "/mgmt-inner/cms-mgmt/v1/data-key" + // SensitiveConfigPath Get sensitive config interface url + SensitiveConfigPath = "/mgmt-inner/sts-mgmt/v2/sensitiveconfig" + // Channel Deployment channel, fill in the service and microservice name corresponding to the BaaS service + Channel = "HMSClientCloudAccelerateService_HMSCaaSYuanRongWorkerManager" + // DefaultKeyVersion Default Key Version + DefaultKeyVersion = "1" + // MgmtServiceName - Mgmt ServiceName + MgmtServiceName = "SecurityMgmtService" + // MgmtMicroServiceName - Mgmt MicroServiceName + MgmtMicroServiceName = "SecurityMgmtMicroService" + + baseCertFilePath = "/opt/certs" + + maxConfigIDPerRequest = 50 + + privateKeyByteLen = 32 + // ECDHKeyLen - ECDH key len + ECDHKeyLen = 16 +) + +var serviceMeta = &auth.StsMicroServiceMeta{ + Service: MgmtServiceName, + MicroService: MgmtMicroServiceName, +} + +// CertResponse - certificate interface resp +type CertResponse struct { + CertFileData string `json:"certFileData"` + PrivateKeyData string `json:"privateKeyData"` + Format string `json:"format"` + Password Password `json:"password"` +} + +// Password - +type Password struct { + ProtectKey string `json:"protectKey"` + RootKey string `json:"rootkey"` + WorkKey string `json:"workkey"` + CipherPwd string `json:"cipherPwd"` +} + +// RootKey - in sts cert +type RootKey struct { + Apple string `json:"apple"` + Boy string `json:"boy"` + Cat string `json:"cat"` + Dog string `json:"dog"` +} + +// DataKeyReq - DataKey interface req +type DataKeyReq struct { + Algo string `json:"algo"` + AppID string `json:"appId"` + PublicKey string `json:"publicKey"` +} + +// DataKeyResponse - DataKey interface resp +type DataKeyResponse struct { + DataKey string `json:"dataKey,omitempty"` + Version int64 `json:"version,omitempty"` + PublicKey string `json:"publicKey,omitempty"` +} + +// EcdhKeyPair - ECDH key +type EcdhKeyPair struct { + PublicKey []byte `json:"publicKey,omitempty"` + PrivateKey []byte `json:"privateKey,omitempty"` +} + +type configIDsReq struct { + ConfigIds []string `json:"configIds"` +} + +// SensitiveConfigResponse - sensitive config interface resp +type SensitiveConfigResponse struct { + Status string `json:"status"` + Message string `json:"message"` + ConfigItems []ConfigItem `json:"configItems"` + MissingConfigItems []MissingConfigItem `json:"missingConfigItems"` +} + +// ConfigItem - list of sensitive configurations that can successfully obtain ciphertext +type ConfigItem struct { + ConfigID string `json:"configId"` + ConfigValue string `json:"configValue"` +} + +// MissingConfigItem - List of sensitive configuration coordinates that cannot be queried +type MissingConfigItem struct { + ConfigID string `json:"configId"` +} + +// Info - sts cert info +type Info struct { + RootKey *RootKey + Cert *CertResponse + Ini string +} diff --git a/yuanrong/pkg/functionscaler/tenantquota/tenantcache.go b/yuanrong/pkg/functionscaler/tenantquota/tenantcache.go new file mode 100644 index 0000000..e36a8dc --- /dev/null +++ b/yuanrong/pkg/functionscaler/tenantquota/tenantcache.go @@ -0,0 +1,164 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package tenantquota - +package tenantquota + +import ( + "sync" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/types" +) + +var ( + cache = &tenantCache{} + defaultMaxOnDemandInstanceNum = int64(types.DefaultMaxOnDemandInstanceNumPerTenant) + defaultMaxReversedInstanceNum = int64(types.DefaultMaxReversedInstanceNumPerTenant) +) + +type tenantCache struct { + // key: tenantID value: *tenantQuota + tenantQuotaList sync.Map + // key: tenantID value: *tenantInstance + tenantInstanceList sync.Map +} + +type tenantQuota struct { + maxOnDemandInstance int64 + maxReversedInstance int64 + mux sync.RWMutex +} + +// GetTenantCache is cache for tenant +func GetTenantCache() *tenantCache { + return cache +} + +func createDefaultTenantQuota() *tenantQuota { + tenantIns := &tenantQuota{ + maxOnDemandInstance: defaultMaxOnDemandInstanceNum, + maxReversedInstance: defaultMaxReversedInstanceNum, + mux: sync.RWMutex{}, + } + return tenantIns +} + +type tenantInstance struct { + onDemandInsNum int64 + reversedInsNum int64 + mux sync.RWMutex +} + +func createTenantInstance() *tenantInstance { + tenantIns := &tenantInstance{ + onDemandInsNum: 0, + reversedInsNum: 0, + mux: sync.RWMutex{}, + } + return tenantIns +} + +func (tc *tenantCache) getOrCreateTenantInstance(tenantID string) *tenantInstance { + tenantInsIf, exist := tc.tenantInstanceList.LoadOrStore(tenantID, createTenantInstance()) + if !exist { + log.GetLogger().Infof("tenantInstance cache has not been created, generate, tenant: %s", tenantID) + } + + return tenantInsIf.(*tenantInstance) +} + +func (tc *tenantCache) getTenantInstanceNum(tenantID string) (int64, int64) { + tenantIns := tc.getOrCreateTenantInstance(tenantID) + tenantIns.mux.RLock() + defer tenantIns.mux.RUnlock() + return tenantIns.onDemandInsNum, tenantIns.reversedInsNum +} + +func (tc *tenantCache) updateTenantInstance(tenantID string, onDemandInsNum int64, reversedInsNum int64) { + if reversedInsNum == 0 && onDemandInsNum == 0 { + _, exist := tc.tenantInstanceList.Load(tenantID) + if !exist { + log.GetLogger().Infof("tenantInstance cache has not been created, tenant: %s", tenantID) + return + } + tc.tenantInstanceList.Delete(tenantID) + log.GetLogger().Infof("succeed to delete instance in tenant cache, tenant: %s", tenantID) + return + } + tenantIns := tc.getOrCreateTenantInstance(tenantID) + tenantIns.mux.Lock() + tenantIns.onDemandInsNum = onDemandInsNum + tenantIns.reversedInsNum = reversedInsNum + tenantIns.mux.Unlock() + log.GetLogger().Infof("succeed to update instance in tenant cache, tenant: %s", tenantID) +} + +func (tc *tenantCache) getTenantQuota(tenantID string) *tenantQuota { + tenantQuotaIf, exist := tc.tenantQuotaList.Load(tenantID) + if !exist { + log.GetLogger().Infof("tenant cache has not been created, tenant: %s", tenantID) + return nil + } + + return tenantQuotaIf.(*tenantQuota) +} + +// GetTenantQuotaNum get max instance quota for tenant +func (tc *tenantCache) GetTenantQuotaNum(tenantID string) (int64, int64) { + quota := tc.getTenantQuota(tenantID) + if quota == nil { + return defaultMaxOnDemandInstanceNum, defaultMaxReversedInstanceNum + } + quota.mux.RLock() + defer quota.mux.RUnlock() + return quota.maxOnDemandInstance, quota.maxReversedInstance +} + +// DeleteTenantQuota trigger by tenant quota key deleted in etcd +func (tc *tenantCache) DeleteTenantQuota(tenantID string) { + _, exist := tc.tenantQuotaList.Load(tenantID) + if !exist { + log.GetLogger().Infof("tenant cache has not been created, tenant: %s", tenantID) + return + } + tc.tenantQuotaList.Delete(tenantID) + log.GetLogger().Infof("succeed to delete quota in tenant cache trigger by etcd for tenant: %s", tenantID) +} + +// UpdateOrAddTenantQuota trigger by tenant quota key add or update in etcd +func (tc *tenantCache) UpdateOrAddTenantQuota(tenantID string, tenantMetaInfo types.TenantMetaInfo) { + quota := tc.getTenantQuota(tenantID) + if quota == nil { + quota = createDefaultTenantQuota() + tc.tenantQuotaList.Store(tenantID, quota) + } + quota.mux.Lock() + quota.maxOnDemandInstance = tenantMetaInfo.TenantInstanceMetaData.MaxOnDemandInstance + quota.maxReversedInstance = tenantMetaInfo.TenantInstanceMetaData.MaxReversedInstance + quota.mux.Unlock() + log.GetLogger().Infof("succeed to update quota in tenant cache trigger by etcd for tenant: %s", tenantID) +} + +// UpdateDefaultQuota trigger by default tenant quota key add or update in etcd +func (tc *tenantCache) UpdateDefaultQuota(tenantMetaInfo types.TenantMetaInfo) { + // update global config + defaultMaxOnDemandInstanceNum = tenantMetaInfo.TenantInstanceMetaData.MaxOnDemandInstance + defaultMaxReversedInstanceNum = tenantMetaInfo.TenantInstanceMetaData.MaxReversedInstance + log.GetLogger().Infof("succeed to reset default quota in tenant cache trigger by etcd, "+ + "defaultMaxOnDemandInstanceNum: %d, defaultMaxReversedInstanceNum: %d", + defaultMaxOnDemandInstanceNum, defaultMaxReversedInstanceNum) +} diff --git a/yuanrong/pkg/functionscaler/tenantquota/tenantcache_test.go b/yuanrong/pkg/functionscaler/tenantquota/tenantcache_test.go new file mode 100644 index 0000000..a7203cf --- /dev/null +++ b/yuanrong/pkg/functionscaler/tenantquota/tenantcache_test.go @@ -0,0 +1,149 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package tenantquota - +package tenantquota + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/functionscaler/types" +) + +func Test_tenantCache_UpdateOrAddTenantQuota(t *testing.T) { + type fields struct { + tenantQuotaList sync.Map + } + type args struct { + tenantID string + tenant types.TenantMetaInfo + } + tests := []struct { + name string + fields fields + args args + }{ + { + name: "case_01 normal scenario", + fields: fields{}, + args: args{ + tenantID: "test", + tenant: types.TenantMetaInfo{TenantInstanceMetaData: types.TenantInstanceMetaData{100, 100}}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &tenantCache{ + tenantQuotaList: tt.fields.tenantQuotaList, + } + tc.UpdateOrAddTenantQuota(tt.args.tenantID, tt.args.tenant) + }) + } +} + +func Test_tenantCache_DeleteTenantQuota(t *testing.T) { + tenantIns := &tenantQuota{ + maxOnDemandInstance: defaultMaxOnDemandInstanceNum, + maxReversedInstance: defaultMaxReversedInstanceNum, + mux: sync.RWMutex{}, + } + type fields struct { + tenantQuotaList sync.Map + } + type args struct { + tenantID string + } + tests := []struct { + name string + fields fields + args args + }{ + { + name: "case_01 normal scenario", + fields: fields{}, + args: args{ + tenantID: "test", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &tenantCache{ + tenantQuotaList: tt.fields.tenantQuotaList, + } + tc.DeleteTenantQuota(tt.args.tenantID) + tc.tenantQuotaList.Store(tt.args.tenantID, tenantIns) + tc.DeleteTenantQuota(tt.args.tenantID) + tc.tenantQuotaList.Store(tt.args.tenantID, "") + tc.DeleteTenantQuota(tt.args.tenantID) + }) + } +} + +func Test_tenantCache_UpdateDefaultQuota2(t *testing.T) { + var tenantQuotaList sync.Map + tc := &tenantCache{ + tenantQuotaList: tenantQuotaList, + } + defaultTenantMetaInfo := types.TenantMetaInfo{TenantInstanceMetaData: types.TenantInstanceMetaData{1000, 1000}} + + tc.getTenantQuota("test") + tc.UpdateDefaultQuota(defaultTenantMetaInfo) + maxOnDemandInstance, maxReversedInstance := tc.GetTenantQuotaNum("test") + assert.Equal(t, maxOnDemandInstance, int64(1000)) + assert.Equal(t, maxReversedInstance, int64(1000)) + + tenantMetaInfo := types.TenantMetaInfo{TenantInstanceMetaData: types.TenantInstanceMetaData{100, 100}} + tc.UpdateOrAddTenantQuota("test", tenantMetaInfo) + maxOnDemandInstance, maxReversedInstance = tc.GetTenantQuotaNum("test") + assert.Equal(t, maxOnDemandInstance, int64(100)) + assert.Equal(t, maxReversedInstance, int64(100)) +} + +func Test_tenantCache_UpdateDefaultQuota(t *testing.T) { + type fields struct { + tenantQuotaList sync.Map + } + type args struct { + defaultTenant types.TenantMetaInfo + } + tests := []struct { + name string + fields fields + args args + }{ + { + name: "case_01 normal scenario", + fields: fields{}, + args: args{ + defaultTenant: types.TenantMetaInfo{TenantInstanceMetaData: types.TenantInstanceMetaData{1000, 1000}}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &tenantCache{ + tenantQuotaList: tt.fields.tenantQuotaList, + } + tc.getTenantQuota("test") + tc.UpdateDefaultQuota(tt.args.defaultTenant) + }) + } +} diff --git a/yuanrong/pkg/functionscaler/tenantquota/tenantetcd.go b/yuanrong/pkg/functionscaler/tenantquota/tenantetcd.go new file mode 100644 index 0000000..2f71603 --- /dev/null +++ b/yuanrong/pkg/functionscaler/tenantquota/tenantetcd.go @@ -0,0 +1,181 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package tenantquota - +package tenantquota + +import ( + "context" + "encoding/json" + "fmt" + "os" + + "go.etcd.io/etcd/client/v3/concurrency" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/types" +) + +const ( + defaultTTL = 1 + defaultUnlimitedInstanceNumPerTenant = -1 // define the default unlimited value per tenant +) + +func max(a int64, b int64) int64 { + if a > b { + return a + } + return b +} + +func getTenantInsInfoFromETCD(tenantID string) types.TenantInsInfo { + tenantKey := fmt.Sprintf("/sn/functions/tenantinstancenumlimit/cluster/%s/tenant/%s", + os.Getenv("CLUSTER_ID"), tenantID) + tenantInsInfo := types.TenantInsInfo{} + etcdValue, err := etcd3.GetValueFromEtcdWithRetry(tenantKey, etcd3.GetRouterEtcdClient()) + if err != nil { + log.GetLogger().Warnf("failed to get tenant instance info, err: %s", err.Error()) + return tenantInsInfo + } + if err = json.Unmarshal(etcdValue, &tenantInsInfo); err != nil { + log.GetLogger().Warnf("failed to Unmarshal tenant instance info, err: %s", err.Error()) + return tenantInsInfo + } + return tenantInsInfo +} + +// IncreaseTenantInstanceNum 实例数未到上限,扩容先增加租户实例数 +func IncreaseTenantInstanceNum(tenantID string, isReserved bool) (bool, bool) { + var reachMaxOnDemandInsNum bool + maxOnDemandInstance, maxReversedInstance := GetTenantCache().GetTenantQuotaNum(tenantID) + // 若租户实例数quota在etcd中被置为-1,则跳过租户级流控(比如给某些VIP租户不限流) + if maxOnDemandInstance == int64(defaultUnlimitedInstanceNumPerTenant) || + maxReversedInstance == int64(defaultUnlimitedInstanceNumPerTenant) { + return false, false + } + + lockKey := fmt.Sprintf("/lock/cluster/%s/tenant/%s", os.Getenv("CLUSTER_ID"), tenantID) + routerEtcdClient := etcd3.GetRouterEtcdClient() + + onDemandInsNum, reversedInsNum := GetTenantCache().getTenantInstanceNum(tenantID) + session, err := concurrency.NewSession(routerEtcdClient.Client, concurrency.WithTTL(defaultTTL)) // Generate lease + if err != nil { + log.GetLogger().Errorf("failed to new session: %s, determine based on cache", err.Error()) + if isReserved { + reversedInsNum++ + } else { + reachMaxOnDemandInsNum = maxOnDemandInstance < (onDemandInsNum + 1) + if !reachMaxOnDemandInsNum { + onDemandInsNum++ + } + } + GetTenantCache().updateTenantInstance(tenantID, onDemandInsNum, reversedInsNum) + return reachMaxOnDemandInsNum, maxReversedInstance < reversedInsNum + } + defer session.Close() + + // Blocking, other requests will block waiting for the lock to be released + locker := concurrency.NewLocker(session, lockKey) + locker.Lock() + + // 1. 获取tenantID的函数实例数 + tenantInsInfo := getTenantInsInfoFromETCD(tenantID) + tenantInsInfo.ReversedInsNum = max(tenantInsInfo.ReversedInsNum, reversedInsNum) + tenantInsInfo.OnDemandInsNum = max(tenantInsInfo.OnDemandInsNum, onDemandInsNum) + if isReserved { + tenantInsInfo.ReversedInsNum++ + log.GetLogger().Debugf("tenantInsInfo.ReversedInsNum: %d", tenantInsInfo.ReversedInsNum) + } else { + // 2. Determine whether the limit is exceeded after increasing the number of instances + if maxOnDemandInstance < tenantInsInfo.OnDemandInsNum+1 { + onDemandInsNum = tenantInsInfo.OnDemandInsNum + reversedInsNum = tenantInsInfo.ReversedInsNum + GetTenantCache().updateTenantInstance(tenantID, onDemandInsNum, reversedInsNum) + locker.Unlock() + return true, maxReversedInstance < tenantInsInfo.ReversedInsNum + } + tenantInsInfo.OnDemandInsNum++ + log.GetLogger().Debugf("tenantInsInfo.OnDemandInsNum: %d", tenantInsInfo.OnDemandInsNum) + } + // 3. 弹性实例数不超限,更新(增加)实例数; 注意预留实例超限还是会创建实例,所以不管预留实例是否超限需要更新实例数 + updateTenantInstance(tenantID, tenantInsInfo.OnDemandInsNum, tenantInsInfo.ReversedInsNum) + onDemandInsNum = tenantInsInfo.OnDemandInsNum + reversedInsNum = tenantInsInfo.ReversedInsNum + GetTenantCache().updateTenantInstance(tenantID, onDemandInsNum, reversedInsNum) + locker.Unlock() + + return false, maxReversedInstance < tenantInsInfo.ReversedInsNum +} + +// DecreaseTenantInstance Reduce the number of instances +func DecreaseTenantInstance(tenantID string, isReserved bool) { + lockKey := fmt.Sprintf("/lock/cluster/%s/tenant/%s", os.Getenv("CLUSTER_ID"), tenantID) + routerEtcdClient := etcd3.GetRouterEtcdClient() + + onDemandInsNum, reversedInsNum := GetTenantCache().getTenantInstanceNum(tenantID) + session, err := concurrency.NewSession(routerEtcdClient.Client, concurrency.WithTTL(defaultTTL)) // Generate lease + if err != nil { + log.GetLogger().Errorf("failed to new session: %s", err.Error()) + if isReserved { + reversedInsNum-- + } else { + onDemandInsNum-- + } + GetTenantCache().updateTenantInstance(tenantID, onDemandInsNum, reversedInsNum) + return + } + defer session.Close() + + // Blocking, other requests will block waiting for the lock to be released + locker := concurrency.NewLocker(session, lockKey) + locker.Lock() + + // The number of instances needs to be reduced when creation fails, scales down, functions are deleted, etc. + tenantInsInfo := getTenantInsInfoFromETCD(tenantID) + tenantInsInfo.ReversedInsNum = max(tenantInsInfo.ReversedInsNum, reversedInsNum) + tenantInsInfo.OnDemandInsNum = max(tenantInsInfo.OnDemandInsNum, onDemandInsNum) + if isReserved { + tenantInsInfo.ReversedInsNum-- + log.GetLogger().Debugf("tenantInsInfo.ReversedInsNum: %d", tenantInsInfo.ReversedInsNum) + } else { + tenantInsInfo.OnDemandInsNum-- + log.GetLogger().Debugf("tenantInsInfo.OnDemandInsNum: %d", tenantInsInfo.OnDemandInsNum) + } + updateTenantInstance(tenantID, tenantInsInfo.OnDemandInsNum, tenantInsInfo.ReversedInsNum) + onDemandInsNum = tenantInsInfo.OnDemandInsNum + reversedInsNum = tenantInsInfo.ReversedInsNum + GetTenantCache().updateTenantInstance(tenantID, onDemandInsNum, reversedInsNum) + locker.Unlock() +} + +func updateTenantInstance(tenantID string, onDemandInsNum int64, reversedInsNum int64) { + tenantInsInfo := types.TenantInsInfo{OnDemandInsNum: onDemandInsNum, ReversedInsNum: reversedInsNum} + bytes, err := json.Marshal(tenantInsInfo) + if err != nil { + log.GetLogger().Errorf("failed to marshal tenantInsInfo, err: %s", err) + return + } + ctx := etcd3.CreateEtcdCtxInfoWithTimeout(context.Background(), etcd3.DurationContextTimeout) + routerEtcdClient := etcd3.GetRouterEtcdClient() + tenantKey := fmt.Sprintf("/sn/functions/tenantinstancenumlimit/cluster/%s/tenant/%s", + os.Getenv("CLUSTER_ID"), tenantID) + err = routerEtcdClient.Put(ctx, tenantKey, string(bytes)) + if err != nil { + log.GetLogger().Errorf("unable to put key: %s new value to router etcd, err:%s", tenantKey, err.Error()) + } + return +} diff --git a/yuanrong/pkg/functionscaler/tenantquota/tenantetcd_test.go b/yuanrong/pkg/functionscaler/tenantquota/tenantetcd_test.go new file mode 100644 index 0000000..38120be --- /dev/null +++ b/yuanrong/pkg/functionscaler/tenantquota/tenantetcd_test.go @@ -0,0 +1,173 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package tenantquota - +package tenantquota + +import ( + "encoding/json" + "errors" + "sync" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/client/v3/concurrency" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/functionscaler/types" +) + +func TestGetTenantInsInfoFromETCD(t *testing.T) { + info := types.TenantInsInfo{ + OnDemandInsNum: 1, + ReversedInsNum: 2, + } + bytes, _ := json.Marshal(info) + + convey.Convey("Test getTenantInsInfoFromETCD", t, func() { + convey.Convey("value got from etcd is empty", func() { + defer gomonkey.ApplyFunc(etcd3.GetValueFromEtcdWithRetry, + func(key string, etcdClient *etcd3.EtcdClient) ([]byte, error) { + return nil, nil + }).Reset() + tenantInsInfo := getTenantInsInfoFromETCD("test") + convey.So(tenantInsInfo.ReversedInsNum, convey.ShouldEqual, 0) + convey.So(tenantInsInfo.OnDemandInsNum, convey.ShouldEqual, 0) + }) + + convey.Convey("value got from etcd is valid", func() { + defer gomonkey.ApplyFunc(etcd3.GetValueFromEtcdWithRetry, + func(key string, etcdClient *etcd3.EtcdClient) ([]byte, error) { + return bytes, nil + }).Reset() + tenantInsInfo := getTenantInsInfoFromETCD("test") + convey.So(tenantInsInfo.ReversedInsNum, convey.ShouldEqual, 2) + convey.So(tenantInsInfo.OnDemandInsNum, convey.ShouldEqual, 1) + }) + + convey.Convey("value got from etcd err", func() { + defer gomonkey.ApplyFunc(etcd3.GetValueFromEtcdWithRetry, + func(key string, etcdClient *etcd3.EtcdClient) ([]byte, error) { + return nil, errors.New("fail") + }).Reset() + tenantInsInfo := getTenantInsInfoFromETCD("test") + convey.So(tenantInsInfo.ReversedInsNum, convey.ShouldEqual, 0) + convey.So(tenantInsInfo.OnDemandInsNum, convey.ShouldEqual, 0) + }) + }) +} + +func TestUpdateTenantInstance(t *testing.T) { + convey.Convey("Test updateTenantInstance", t, func() { + convey.Convey("normal process", func() { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer gomonkey.ApplyFunc((*etcd3.EtcdClient).Put, + func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, key string, value string, + opts ...clientv3.OpOption) error { + return nil + }).Reset() + updateTenantInstance("test", 1, 2) + }) + convey.Convey("put err", func() { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer gomonkey.ApplyFunc((*etcd3.EtcdClient).Put, + func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, key string, value string, + opts ...clientv3.OpOption) error { + return errors.New("fail") + }).Reset() + updateTenantInstance("test", 1, 2) + }) + }) +} + +func TestAddOrDelTenantInstanceNum(t *testing.T) { + convey.Convey("Test AddOrDelTenantInstanceNum", t, func() { + convey.Convey("new session err", func() { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer gomonkey.ApplyFunc(concurrency.NewSession, + func(client *clientv3.Client, opts ...concurrency.SessionOption) (*concurrency.Session, error) { + return nil, errors.New("fail") + }).Reset() + reachMaxOnDemandInsNum, reachMaxReversedInsNum := IncreaseTenantInstanceNum("test", true) + convey.So(reachMaxOnDemandInsNum, convey.ShouldEqual, false) + convey.So(reachMaxReversedInsNum, convey.ShouldEqual, false) + DecreaseTenantInstance("test", true) + + reachMaxOnDemandInsNum, reachMaxReversedInsNum = IncreaseTenantInstanceNum("test", false) + convey.So(reachMaxOnDemandInsNum, convey.ShouldEqual, false) + convey.So(reachMaxReversedInsNum, convey.ShouldEqual, false) + DecreaseTenantInstance("test", false) + }) + convey.Convey("new session success", func() { + info1 := types.TenantInsInfo{ + OnDemandInsNum: 1, + ReversedInsNum: 2, + } + bytes1, _ := json.Marshal(info1) + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer gomonkey.ApplyFunc(concurrency.NewSession, + func(client *clientv3.Client, opts ...concurrency.SessionOption) (*concurrency.Session, error) { + return nil, nil + }).Reset() + defer gomonkey.ApplyFunc(concurrency.NewLocker, func(s *concurrency.Session, pfx string) sync.Locker { + return &sync.RWMutex{} + }).Reset() + defer gomonkey.ApplyFunc((*concurrency.Session).Close, func(_ *concurrency.Session) error { + return nil + }).Reset() + defer gomonkey.ApplyFunc((*etcd3.EtcdClient).Put, + func(_ *etcd3.EtcdClient, ctxInfo etcd3.EtcdCtxInfo, key string, value string, + opts ...clientv3.OpOption) error { + return nil + }).Reset() + defer gomonkey.ApplyFunc(etcd3.GetValueFromEtcdWithRetry, + func(key string, etcdClient *etcd3.EtcdClient) ([]byte, error) { + return bytes1, nil + }).Reset() + reachMaxOnDemandInsNum, reachMaxReversedInsNum := IncreaseTenantInstanceNum("test", true) + convey.So(reachMaxOnDemandInsNum, convey.ShouldEqual, false) + convey.So(reachMaxReversedInsNum, convey.ShouldEqual, false) + DecreaseTenantInstance("test", true) + reachMaxOnDemandInsNum, reachMaxReversedInsNum = IncreaseTenantInstanceNum("test", false) + convey.So(reachMaxOnDemandInsNum, convey.ShouldEqual, false) + convey.So(reachMaxReversedInsNum, convey.ShouldEqual, false) + DecreaseTenantInstance("test", false) + + info2 := types.TenantInsInfo{ + OnDemandInsNum: 1000, + ReversedInsNum: 1000, + } + bytes2, _ := json.Marshal(info2) + defer gomonkey.ApplyFunc(etcd3.GetValueFromEtcdWithRetry, + func(key string, etcdClient *etcd3.EtcdClient) ([]byte, error) { + return bytes2, nil + }).Reset() + reachMaxOnDemandInsNum, reachMaxReversedInsNum = IncreaseTenantInstanceNum("test", false) + convey.So(reachMaxOnDemandInsNum, convey.ShouldEqual, true) + convey.So(reachMaxReversedInsNum, convey.ShouldEqual, false) + }) + }) +} diff --git a/yuanrong/pkg/functionscaler/types/constants.go b/yuanrong/pkg/functionscaler/types/constants.go new file mode 100644 index 0000000..f3d8832 --- /dev/null +++ b/yuanrong/pkg/functionscaler/types/constants.go @@ -0,0 +1,137 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package types - +package types + +import "time" + +const ( + // SystemTenantID is the tenantID for system functions + SystemTenantID = "0" + // FaasManagerFuncName is the function name of faas manager + FaasManagerFuncName = "faasmanager" + // UserInitEntryKey is the key for user init entry + UserInitEntryKey = "initializer" + // UserCallEntryKey is the key for user call entry + UserCallEntryKey = "handler" + // ConcurrentNumKey is the key for concurrency in CreateOption + ConcurrentNumKey = "ConcurrentNum" + // InitCallTimeoutKey is the key for init call timeout in CreateOption + InitCallTimeoutKey = "init_call_timeout" + // CallTimeoutKey is the key for init call timeout in CreateOption + CallTimeoutKey = "call_timeout" + // NetworkConfigKey is the key for NetworkConfig in CreateOption + NetworkConfigKey = "networkConfig" + // ProberConfigKey is the key for ProberConfig in CreateOption + ProberConfigKey = "proberConfig" + // InstanceNameNote notes instance name + InstanceNameNote = "INSTANCE_NAME_NOTE" + // FunctionKeyNote - is used to describe the function + FunctionKeyNote = "FUNCTION_KEY_NOTE" + // ResourceSpecNote - is used to describe the resource + ResourceSpecNote = "RESOURCE_SPEC_NOTE" + // SchedulerIDNote - is used to decribe the schedulerID + SchedulerIDNote = "SCHEDULER_ID_NOTE" + // InstanceTypeNote - is used to decribe the instance type: "scaled", "reserved", "state" + InstanceTypeNote = "INSTANCE_TYPE_NOTE" + // InstanceLabelNode - + InstanceLabelNode = "INSTANCE_LABEL_NOTE" + // PermanentInstance 不论scheduler怎么扩缩,这个实例都由创建该实例的scheduler纳管 + PermanentInstance = "-permanent" + // TemporaryInstance 普通实例,可能会在scheduler扩缩容时改变纳管关系 + TemporaryInstance = "-temporary" + // FunctionSign 函数签名 + FunctionSign = "FUNCTION_SIGNATURE" + // TenantID - + TenantID = "tenantId" + // HTTPRuntimeType is the runtime type for http function + HTTPRuntimeType = "http" + // CustomContainerRuntimeType is the runtime type for http function + CustomContainerRuntimeType = "custom image" + // HTTPFuncPort is the listening port for http function + HTTPFuncPort = 8000 + // HTTPCallRoute is the call route for http function + HTTPCallRoute = "invoke" + // GracefulShutdownTime is the key for GRACEFUL_SHUTDOWN_TIME in CreateOption + GracefulShutdownTime = "GRACEFUL_SHUTDOWN_TIME" + // MaxShutdownTimeout used to be the default request timeout,now is used to graceful shutdown + MaxShutdownTimeout = 900 + // MinLeaseInterval is the minimum interval for lease + MinLeaseInterval = 500 * time.Millisecond + // HeaderInstanceLabel - + HeaderInstanceLabel = "X-Instance-Label" + // RegisterTypeSelf - + RegisterTypeSelf = "registerBySelf" + // RegisterTypeContend - + RegisterTypeContend = "registerByContend" +) + +const ( + // AscendResourcePrefix is the prefix of ascend resource + AscendResourcePrefix = "huawei.com/ascend" + // AscendRankTableFileEnvKey is the env key of ascend ranktable + AscendRankTableFileEnvKey = "RANK_TABLE_FILE" + // AscendRankTableFileEnvValue is the env value of ascend ranktable + AscendRankTableFileEnvValue = "/opt/config/ascend_config/ranktable_file.json" + // AscendResourceD910B is one type of ascend resource + AscendResourceD910B = "huawei.com/ascend-1980" + // AscendResourceD910BInstanceType is type of D910B + AscendResourceD910BInstanceType = "instanceType" + // SystemNodeInstanceType is type of node instance type ,such as 280T + SystemNodeInstanceType = "X_SYSTEM_NODE_INSTANCE_TYPE" +) + +const ( + // IncrementTimeout - add 5 second + IncrementTimeout = 5 + // DefaultCommonQueueTimeout - + DefaultCommonQueueTimeout = 20 + // DefaultMaxInsQueueTimeout - + DefaultMaxInsQueueTimeout = 10 +) + +const ( + // InstanceSchedulePolicyConcurrency is the schedule policy based on concurrency + InstanceSchedulePolicyConcurrency = "concurrency" + // InstanceSchedulePolicyRoundRobin is the schedule policy based on round-robin + InstanceSchedulePolicyRoundRobin = "round-robin" + // InstanceSchedulePolicyMicroService is the schedule policy based on microservice + InstanceSchedulePolicyMicroService = "microservice" +) + +const ( + // InstanceScalePolicyConcurrency is the auto scaler policy based on concurrency + InstanceScalePolicyConcurrency = "concurrency" + // InstanceScalePolicyPredict is the auto scaler policy based on qps predict + InstanceScalePolicyPredict = "qpsPredict" + // InstanceScalePolicyStaticFunction is the schedule policy for static function + InstanceScalePolicyStaticFunction = "staticFunction" +) + +const ( + // ScenarioWiseCloud is the scenario of CaaS + ScenarioWiseCloud = "WiseCloud" + // ScenarioFunctionGraph is the scenario of FunctionGraph + ScenarioFunctionGraph = "FunctionGraph" +) + +const ( + // DefaultMaxOnDemandInstanceNumPerTenant define the default value for maximum on-demand instance per tenant + DefaultMaxOnDemandInstanceNumPerTenant = 1000 + // DefaultMaxReversedInstanceNumPerTenant define the default value for maximum reversed instance per tenant + DefaultMaxReversedInstanceNumPerTenant = 1000 +) diff --git a/yuanrong/pkg/functionscaler/types/types.go b/yuanrong/pkg/functionscaler/types/types.go new file mode 100644 index 0000000..ea64e71 --- /dev/null +++ b/yuanrong/pkg/functionscaler/types/types.go @@ -0,0 +1,691 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package types - +package types + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "k8s.io/api/core/v1" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/sts/raw" + "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/common/faas_common/types" + wisecloudTypes "yuanrong/pkg/common/faas_common/wisecloudtool/types" +) + +// Configuration defines configuration faas scheduler needs +type Configuration struct { + Scenario string `json:"scenario" valid:"optional"` + InstanceOperationBackend int `json:"instanceOperationBackend" valid:"optional"` + CPU float64 `json:"cpu" valid:"optional"` + Memory float64 `json:"memory" valid:"optional"` + PredictGroupWindow int64 `json:"predictGroupWindow"` + AutoScaleConfig AutoScaleConfig `json:"autoScaleConfig" valid:"required"` + ScaleRetryConfig ScaleRetryConfig `json:"scaleRetryConfig" valid:"optional"` + LeaseSpan int `json:"leaseSpan" valid:"required"` + FunctionLimitRate int `json:"functionLimitRate"` + RouterETCDConfig etcd3.EtcdConfig `json:"routerEtcd" valid:"required"` + MetaETCDConfig etcd3.EtcdConfig `json:"metaEtcd" valid:"required"` + SchedulerNum int `json:"schedulerNum" valid:"optional"` + DockerRootPath string `json:"dockerRootPath"` + RawStsConfig raw.StsConfig `json:"rawStsConfig,omitempty"` + ClusterID string `json:"clusterID" valid:"optional"` + ClusterName string `json:"clusterName" valid:"optional"` + DiskMonitorEnable bool `json:"diskMonitorEnable"` + RegionName string `json:"regionName" valid:"optional"` + AlarmConfig alarm.Config `json:"alarmConfig" valid:"optional"` + NodeSelector map[string]string `json:"nodeSelector,omitempty"` + EphemeralStorage float64 `json:"ephemeralStorage,omitempty"` + NpuEphemeralStorage float64 `json:"npuEphemeralStorage,omitempty"` + HostAliases []v1.HostAlias `json:"hostaliaseshostname"` + FunctionConfig []FunctionDefaultConfig `json:"functionConfig"` + HTTPSConfig *tls.InternalHTTPSConfig `json:"httpsConfig" valid:"optional"` + LocalAuth localauth.AuthConfig `json:"localAuth"` + XpuNodeLabels []XpuNodeLabel `json:"xpuNodeLabels,omitempty"` + ServiceAccountJwt wisecloudTypes.ServiceAccountJwt `json:"serviceAccountJwt,omitempty"` + Version string `json:"version"` + Image string `json:"image"` + ConcurrentNum int `json:"concurrentNum"` + ModuleConfig *ModuleConfig `json:"moduleConfig" valid:"optional"` + StateDisable bool `json:"stateDisable"` + EnableNPUDriverMount bool `json:"enableNPUDriverMount"` + DisableReplicaScaler bool `json:"disableReplicaScaler"` + TenantInsNumLimitEnable bool `json:"tenantInsNumLimitEnable"` + EnableHealthCheck bool `json:"enableHealthCheck"` + EnableRollout bool `json:"enableRollout"` + SccConfig crypto.SccConfig `json:"sccConfig" valid:"optional"` + NameSpace string `json:"nameSpace"` + Affinity string `json:"affinity"` + MicroServiceSchedulingPolicy string `json:"msSchedulingPolicy" valid:"optional"` + AuthenticationEnable bool `json:"authenticationEnable" valid:"optional"` + NodeAffinity string `json:"nodeAffinity" valid:"optional"` + NodeAffinityPolicy string `json:"nodeAffinityPolicy" valid:"optional"` + DeployMode string `json:"deployMode" valid:"optional"` + MetricsAddr string `json:"metricsAddr" valid:"optional"` + MetricsHTTPSEnable bool `json:"metricsHttpsEnable" valid:"optional"` + PprofAddr string `json:"pprofAddr" valid:"optional"` + SchedulerDiscovery *SchedulerDiscovery `json:"schedulerDiscovery" valid:"optional"` +} + +// AutoScaleConfig - +type AutoScaleConfig struct { + SLAQuota int `json:"slaQuota" valid:"required"` + ScaleDownTime int `json:"scaleDownTime" valid:"required"` + BurstScaleNum int `json:"burstScaleNum" valid:"required"` +} + +// ScaleRetryConfig - +type ScaleRetryConfig struct { + ReservedInstanceAlwaysRetry bool `json:"reservedInstanceAlwaysRetry" valid:"optional"` +} + +// ModuleConfig config info +type ModuleConfig struct { + ServicePort string `json:"servicePort" valid:",optional"` +} + +// SchedulerDiscovery - +type SchedulerDiscovery struct { + KeyPrefixType string `json:"keyPrefixType" valid:"optional"` + RegisterMode string `json:"registerMode" valid:"optional"` +} + +// XpuNodeLabel Distinguish heterogeneous card types +type XpuNodeLabel struct { + XpuType string `json:"xpuType,omitempty"` + InstanceType string `json:"instanceType,omitempty"` + NodeLabelKey string `json:"nodeLabelKey,omitempty"` + NodeLabelValues []string `json:"NodeLabelValues,omitempty"` +} + +// FunctionSpecification contains specification of a function +type FunctionSpecification struct { + FuncCtx context.Context `json:"-"` + CancelFunc context.CancelFunc `json:"-"` + FuncKey string `json:"-"` + FuncMetaSignature string `json:"-"` + FuncSecretName string `json:"-"` + FuncMetaData types.FuncMetaData `json:"funcMetaData" valid:",optional"` + S3MetaData types.S3MetaData `json:"s3MetaData" valid:",optional"` + CodeMetaData types.CodeMetaData `json:"codeMetaData" valid:",optional"` + EnvMetaData types.EnvMetaData `json:"envMetaData" valid:",optional"` + StsMetaData types.StsMetaData `json:"stsMetaData" valid:",optional"` + ResourceMetaData types.ResourceMetaData `json:"resourceMetaData" valid:",optional"` + InstanceMetaData types.InstanceMetaData `json:"instanceMetaData" valid:",optional"` // new add + ExtendedMetaData types.ExtendedMetaData `json:"extendedMetaData" valid:",optional"` +} + +// PATServiceRequest - +type PATServiceRequest struct { + ID string `json:"id,omitempty"` + DomainID string `json:"domain_id,omitempty"` + Namespace string `json:"namespace,omitempty"` + VpcName string `json:"vpc_name,omitempty"` + AppXrole string `json:"app_xrole,omitempty"` + VpcID string `json:"vpc_id,omitempty"` + SubnetName string `json:"subnet_name,omitempty"` // is duplicated with common types, note it + SubnetID string `json:"subnet_id,omitempty"` + TenantCidr string `json:"tenant_cidr,omitempty"` + HostVMCidr string `json:"host_vm_cidr,omitempty"` + Gateway string `json:"gateway,omitempty"` + Xrole string `json:"xrole,omitempty"` +} + +// NATConfigure include nat configure info for worker +type NATConfigure struct { + ContainerCidr string `json:"containerCidr"` + HostVMCidr string `json:"hostVmCidr"` + PatContainerIP string `json:"patContainerIP"` // ip tunnel + PatVMIP string `json:"patVmIP"` + PatPortIP string `json:"patPortIP"` + PatMacAddr string `json:"patMacAddr"` + PatGateway string `json:"patGateway"` // ping + PatPodName string `json:"patPodName"` + TenantCidr string `json:"tenantCidr"` // ip route + NatSubnetList map[string][]string `json:"natSubnetList"` // ip route + IsDeleted bool `json:"isDeleted"` + IsNewCreated bool `json:"isNewCreated"` +} + +// PullTriggerRequestInfo include info of pullTrigger Option Create +type PullTriggerRequestInfo struct { + PodName string `json:"pod_name"` + Image string `json:"image"` + DomainID string `json:"domain_id,omitempty"` + Namespace string `json:"namespace,omitempty"` + VpcName string `json:"vpc_name,omitempty"` + VpcID string `json:"vpc_id,omitempty"` + SubnetName string `json:"subnet_name,omitempty"` + SubnetID string `json:"subnet_id,omitempty"` + TenantCidr string `json:"tenant_cidr,omitempty"` + HostVMCidr string `json:"host_vm_cidr,omitempty"` + ContainerCidr string `json:"container_cidr"` + Gateway string `json:"gateway,omitempty"` + Xrole string `json:"xrole,omitempty"` + AppXrole string `json:"app_xrole,omitempty"` +} + +// PullTriggerDeleteInfo include info of pullTrigger Option delete +type PullTriggerDeleteInfo struct { + PodName string `json:"pod_name,omitempty"` +} + +// NetworkConfig is a config in createOption which describes how to setup network in the environment where instance +// is running +type NetworkConfig struct { + RouteConfig RouteConfig `json:"routeConfig"` + TunnelConfig TunnelConfig `json:"tunnelConfig"` + FirewallConfig FirewallConfig `json:"firewallConfig"` +} + +// RouteConfig is the config describes how to setup route +type RouteConfig struct { + Gateway string `json:"gateway"` + Cidr string `json:"cidr"` +} + +// TunnelConfig is the config describes how to setup tunnel +type TunnelConfig struct { + TunnelName string `json:"tunnelName"` + RemoteIP string `json:"remoteIP"` + Mode string `json:"mode"` +} + +// FirewallConfig is the config describes how to setup firewall +type FirewallConfig struct { + Chain string `json:"chain"` + Table string `json:"table"` + Operation string `json:"operation"` + Target string `json:"target"` + Args string `json:"args"` +} + +// ProberConfig is a config in createOption which describes how to perform certain prober action for instance +type ProberConfig struct { + Protocol string `json:"protocol"` + Address string `json:"address"` + Interval int `json:"interval"` + Timeout int `json:"timeout"` + FailureThreshold int `json:"failureThreshold"` +} + +// DelegateContainerConfig configures custom image in kernel +type DelegateContainerConfig struct { + Image string `json:"image"` + Env []v1.EnvVar `json:"env"` + Command []string `json:"command"` + Args []string `json:"args"` + UID int `json:"uid"` + GID int `json:"gid"` + VolumeMounts []v1.VolumeMount `json:"volumeMounts"` + CustomGracefulShutdown types.CustomGracefulShutdown `json:"runtime_graceful_shutdown"` + Lifecycle v1.Lifecycle `json:"lifecycle"` +} + +// DelegateContainerSideCarConfig configures custom image sidecar in kernel +type DelegateContainerSideCarConfig struct { + Name string `json:"name"` + Image string `json:"image"` + Env []v1.EnvVar `json:"env"` + ResourceRequirements v1.ResourceRequirements `json:"resourceRequirements"` + VolumeMounts []v1.VolumeMount `json:"volumeMounts"` + Lifecycle v1.Lifecycle `json:"lifecycle"` + LivenessProbe v1.Probe `json:"livenessProbe"` + ReadinessProbe v1.Probe `json:"readinessProbe"` +} + +// DelegateInitContainerConfig configures custom image init container in kernel +type DelegateInitContainerConfig struct { + Name string `json:"name"` + Image string `json:"image"` + Env []v1.EnvVar `json:"env"` + ResourceRequirements v1.ResourceRequirements `json:"resourceRequirements"` + VolumeMounts []v1.VolumeMount `json:"volumeMounts"` + Command []string `json:"command"` + Args []string `json:"args"` +} + +// InstanceType defines instance type +type InstanceType string + +const ( + // InstanceTypeOnDemand is the type of onDemand instance + InstanceTypeOnDemand InstanceType = "onDemand" + // InstanceTypeScaled is the type of scaled instance + InstanceTypeScaled InstanceType = "scaled" + // InstanceTypeReserved is the type of reserved instance + InstanceTypeReserved InstanceType = "reserved" + // InstanceTypeState is the type of state instance + InstanceTypeState InstanceType = "state" + // InstanceTypeUnknown is the type of unknown instance + InstanceTypeUnknown InstanceType = "unknown" +) + +const ( + base = 0.00001 +) + +// ValueType - +type ValueType int32 + +// ValueScalar - +type ValueScalar struct { + Value float64 `json:"value"` + Limit float64 `json:"limit"` +} + +// ValueRanges - +type ValueRanges struct { + Range []ValueRange `protobuf:"bytes,1,rep,name=range,proto3" json:"range,omitempty"` +} + +// ValueSet - +type ValueSet struct { + Items string `json:"items"` +} + +// ValueRange - +type ValueRange struct { + Begin uint64 `json:"begin"` + End uint64 `json:"end"` +} + +// DiskInfo - +type DiskInfo struct { + Volume Volume `json:"volume"` + Type string `json:"type"` + DevPath string `json:"devPath"` + MountPath string `json:"mountPath"` +} + +// Volume - +type Volume struct { + Mode int32 `json:"mode"` + SourceType int32 `json:"sourceType"` + HostPaths string `json:"hostPaths"` + ContainerPath string `json:"containerPath"` + ConfigMapPath string `json:"configMapPath"` + EmptyDir string `json:"emptyDir"` + ElaraPath string `json:"elaraPath"` +} + +// Affinity - +type Affinity struct { + NodeAffinity NodeAffinity `json:"nodeAffinity"` + InstanceAffinity InstanceAffinity `json:"instanceAffinity"` + InstanceAntiAffinity InstanceAffinity `json:"instanceAntiAffinity"` +} + +// NodeAffinity - +type NodeAffinity struct { + Affinity map[string]string `json:"affinity"` +} + +// InstanceAffinity - +type InstanceAffinity struct { + Affinity map[string]string `json:"affinity"` +} + +// InstanceSpecification contains specification of a instance in etcd +type InstanceSpecification struct { + InstanceID string + InstanceName string + RequestID string `json:"requestID" valid:",optional"` + RuntimeID string `json:"runtimeID" valid:",optional"` + RuntimeAddress string `json:"runtimeAddress" valid:",optional"` + FunctionAgentID string `json:"functionAgentID" valid:",optional"` + FunctionProxyID string `json:"functionProxyID" valid:",optional"` + Function string `json:"function"` + RestartPolicy string `json:"restartPolicy" valid:",optional"` + Resources types.Resources `json:"resources"` + ActualUse types.Resources `json:"actualUse" valid:",optional"` + ScheduleOption types.ScheduleOption `json:"scheduleOption"` + CreateOptions map[string]string `json:"createOptions"` + Labels []string `json:"labels"` + StartTime string `json:"startTime"` + InstanceStatus types.InstanceStatus `json:"instanceStatus"` + JobID string `json:"jobID"` + SchedulerChain []string `json:"schedulerChain" valid:",optional"` + ParentID string `json:"parentID"` + ConcurrentNum int +} + +// InstanceThreadMetrics contains metrics of a specified instance thread collected by function accessor +type InstanceThreadMetrics struct { + InsThdID string + ProcNumPS float32 + ProcReqNum int `json:"procReqNum"` + AvgProcTime int `json:"avgProcTime"` // millisecond + MaxProcTime int `json:"maxProcTime"` + IsAbnormal bool `json:"isAbnormal"` + ReacquireData []byte `json:"reacquireData"` + FunctionKey string `json:"functionKey"` +} + +// Instance defines a instance +type Instance struct { + InstanceStatus types.InstanceStatus + InstanceType InstanceType + MetricLabelValues []string + ResKey resspeckey.ResSpecKey + InstanceID string + InstanceName string + FuncKey string + FuncSig string + ConcurrentNum int + CreateSchedulerID string + InstanceIP string + InstancePort string + NodeIP string + NodePort string + Permanent bool + ParentID string + PodID string + PodDeploymentName string +} + +// Copy - +func (i *Instance) Copy() *Instance { + if i == nil { + return nil + } + return &Instance{ + InstanceStatus: i.InstanceStatus, + InstanceType: i.InstanceType, + MetricLabelValues: append([]string{}, i.MetricLabelValues...), + ResKey: i.ResKey, + InstanceID: i.InstanceID, + InstanceName: i.InstanceName, + FuncKey: i.FuncKey, + FuncSig: i.FuncSig, + ConcurrentNum: i.ConcurrentNum, + CreateSchedulerID: i.CreateSchedulerID, + InstanceIP: i.InstanceIP, + InstancePort: i.InstancePort, + NodeIP: i.NodeIP, + NodePort: i.NodePort, + Permanent: i.Permanent, + ParentID: i.ParentID, + PodID: i.PodID, + PodDeploymentName: i.PodDeploymentName, + } +} + +// WmInstance defines instance from workerManger +type WmInstance struct { + IP string `json:"ip"` + Port string `json:"port"` + InstanceID string `json:"instanceID"` + DeployedIP string `json:"deployed_ip"` + DeployedNode string `json:"deployed_node"` + TenantID string `json:"tenant_id"` + Version string `json:"version"` + OwnerIP string `json:"owner_ip"` + IsReserved bool `json:"isReserved"` + IsDirectFunc bool `json:"isDirectFunc"` + HasInitializer bool `json:"hasInitializer"` + Resource types.PodResourceInfo `json:"resource,omitempty"` +} + +// InstanceLease defines lease operations +type InstanceLease interface { + Extend() error + Release() error + GetInterval() time.Duration +} + +// SessionInfo - +type SessionInfo struct { + SessionID string + SessionCtx context.Context +} + +// InstanceAllocation defines a instance thread +type InstanceAllocation struct { + Instance *Instance + Lease InstanceLease + SessionInfo SessionInfo + AllocationID string +} + +// InstanceBuilder will create a instance +type InstanceBuilder func(string) *Instance + +// InstanceAcquireRequest contains specifications for acquiring an instance +type InstanceAcquireRequest struct { + FuncSpec *FunctionSpecification + ResSpec *resspeckey.ResourceSpecification + InstanceSession types.InstanceSessionConfig + InstanceName string + DesignateInstanceID string + DesignateThreadID string + PoolLabel string + PoolID string + TraceID string + StateID string + CallerPodName string + TrafficLimited bool + + SkipWaitPending bool +} + +// InstanceCreateRequest contains specifications for creating an instance +type InstanceCreateRequest struct { + FuncSpec *FunctionSpecification + ResSpec *resspeckey.ResourceSpecification + InstanceName string + TraceID string + CreateEvent []byte +} + +// InstanceDeleteRequest contains specifications for deleting an instance +type InstanceDeleteRequest struct { + FuncSpec *FunctionSpecification + ResSpec *resspeckey.ResourceSpecification + InstanceID string + InstanceName string + TraceID string +} + +// FuncMetaArg defines funcMeta args +type FuncMetaArg struct { + CodeID string `json:"codeID"` + Kind string `json:"kind"` + InvokeType int `json:"invokeType"` + ObjectDescriptor ObjectDescriptor `json:"objectDescriptor"` + Config ConfigDescriptor `json:"config"` +} + +// ObjectDescriptor defines object descriptor +type ObjectDescriptor struct { + ModuleName string `json:"moduleName"` + ClassName string `json:"className"` + FunctionName string `json:"functionName"` + TargetLanguage string `json:"targetLanguage"` + SrcLanguage string `json:"srcLanguage"` +} + +// ConfigDescriptor defines config descriptor +type ConfigDescriptor struct { + RecycleTime int `json:"RecycleTime"` + FunctionID map[string]string `json:"functionID"` + JobID string `json:"jobID"` + LogLevel int `json:"logLevel"` +} + +// ExecutorInitResponse - +type ExecutorInitResponse struct { + ErrorCode string `json:"errorCode"` + Message json.RawMessage `json:"message"` +} + +// InstancePoolState instance pool queue state save +type InstancePoolState struct { + // Key: stateID - val: InstanceID + StateInstance map[string]*Instance `json:"StateInstance" valid:"optional"` +} + +// InstancePoolStateInput - +type InstancePoolStateInput struct { + InstanceType InstanceType + ResKey resspeckey.ResSpecKey + ConcurrentNum int + InstanceStatusCode int32 + FuncKey string + FuncSig string + InstanceID string + StateID string + InstanceIP string + InstancePort string + NodeIP string + NodePort string +} + +// RolloutInstanceSpecification - +type RolloutInstanceSpecification struct { + RegisterKey string `json:"registerKey"` + InstanceID string `json:"instanceID"` + RuntimeAddress string `json:"runtimeAddress"` +} + +const ( + // StateUpdate - + StateUpdate = "update" + // StateDelete - + StateDelete = "delete" + // InstanceLifeCycleConsistentWithState - + InstanceLifeCycleConsistentWithState = "ConsistentWithInstance" +) + +// PodRequest define pod standard +type PodRequest struct { + FunSvcID string `json:"funSvcID"` + NameSpace string `json:"nameSpace"` +} + +// CustomUserArgs - +type CustomUserArgs struct { + AlarmConfig alarm.Config `json:"alarmConfig" valid:"optional"` + StsServerConfig raw.ServerConfig `json:"stsServerConfig"` + ClusterName string `json:"clusterName"` + DiskMonitorEnable bool `json:"diskMonitorEnable"` + LocalAuth localauth.AuthConfig `json:"localAuth"` +} + +// StateArgs value of the 'stateInstanceID' Map +type StateArgs struct { + IsStateValid bool + StateID string +} + +// FunctionDefaultConfig - +type FunctionDefaultConfig struct { + ConfigName string `json:"configName"` + Mount v1.VolumeMount `json:"mount"` +} + +// PredictResult - +type PredictResult struct { + DataSetTimeWindow []int64 `json:"dataSetTimeWindow"` + QPSResult map[string][]float64 `json:"QpsResult"` + IsValid bool +} + +// PredictQPSGroups - +type PredictQPSGroups struct { + FuncKey string + QPSGroups []float64 +} + +// TenantInstanceMetaData define tenant instance quota +type TenantInstanceMetaData struct { + MaxOnDemandInstance int64 `json:"maxOnDemandInstance" valid:",optional"` + MaxReversedInstance int64 `json:"maxReversedInstance" valid:",optional"` +} + +// TenantMetaInfo define tenant meta info +type TenantMetaInfo struct { + TenantInstanceMetaData TenantInstanceMetaData `json:"tenantInstanceMetaData" valid:",optional"` +} + +// TenantInsInfo define tenant instance info +type TenantInsInfo struct { + OnDemandInsNum int64 `json:"onDemandInsNum" valid:",optional"` + ReversedInsNum int64 `json:"reversedInsNum" valid:",optional"` +} + +// SchedulingOptions define tenant instance Scheduling Options +type SchedulingOptions struct { + resourcesMap map[string]float64 + Priority int32 + Resources map[string]float64 + Extension map[string]string + Affinity []api.Affinity + ScheduleAffinity []byte +} + +// TLSConfig tls config +type TLSConfig struct { + HttpsInsecureSkipVerify bool `json:"httpsInsecureSkipVerify"` + TlsCipherSuitesStr []string `json:"tlsCipherSuites"` + TlsCipherSuites []uint16 `json:"-"` +} + +// IntOrString is a type that can hold an int32 or a string +type IntOrString struct { + Type Type + IntVal int64 + StrVal string +} + +// Type represents the stored type of IntOrString. +type Type int64 + +const ( + // Int The IntOrString holds an int. + Int Type = iota + // String The IntOrString holds a string. + String +) + +// UnmarshalJSON Overwrite JSON.Unmarshal +func (i *IntOrString) UnmarshalJSON(data []byte) error { + var intValue int64 + if err := json.Unmarshal(data, &intValue); err == nil { + i.Type = Int + i.IntVal = intValue + return nil + } + var stringValue string + if err := json.Unmarshal(data, &stringValue); err == nil { + i.Type = String + i.StrVal = stringValue + return nil + } + return fmt.Errorf("expected int or string, but got %s", data) +} diff --git a/yuanrong/pkg/functionscaler/types/types_test.go b/yuanrong/pkg/functionscaler/types/types_test.go new file mode 100644 index 0000000..6b4083c --- /dev/null +++ b/yuanrong/pkg/functionscaler/types/types_test.go @@ -0,0 +1,99 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/resspeckey" +) + +func TestResourceSpecification_DeepCopy(t *testing.T) { + tests := []struct { + name string + resourceSpec *resspeckey.ResourceSpecification + }{ + { + name: "deep copy success", + resourceSpec: &resspeckey.ResourceSpecification{ + CPU: 1, + Memory: 2, + InvokeLabel: "xxx", + CustomResources: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + specCopy := tt.resourceSpec.DeepCopy() + if &specCopy == &tt.resourceSpec { + t.Errorf("deepCopyFailed is %d, expectStr is %d", &specCopy, &tt.resourceSpec) + } + if &specCopy.InvokeLabel == &tt.resourceSpec.InvokeLabel { + t.Errorf("deepCopy map failed is %d, expectStr is %d", &specCopy, &tt.resourceSpec) + } + }) + } +} + +func TestIntOrString_UnmarshalJSONJSON(t *testing.T) { + tests := []struct { + name string + b []byte + targetCPU int64 + targetLabel string + targetErr error + }{ + { + name: "Unmarshal Json Normal", + b: []byte("{\"CPU\": 128, \"label\": \"aaaaa\"}"), + targetCPU: 128, + targetLabel: "aaaaa", + targetErr: nil, + }, + { + name: "Unmarshal Json error label", + b: []byte("{\"CPU\": 128, \"label\": 123}"), + targetCPU: 128, + targetLabel: "", + targetErr: nil, + }, + { + name: "Unmarshal Json type error", + b: []byte("{\"CPU\": 128, \"label\": []}"), + targetCPU: 0, + targetLabel: "", + targetErr: fmt.Errorf("expected int or string, but got []"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resMap := map[string]IntOrString{} + err := json.Unmarshal(tt.b, &resMap) + if err != nil { + assert.Equal(t, tt.targetErr.Error(), err.Error()) + return + } + assert.Equal(t, tt.targetCPU, resMap["CPU"].IntVal) + assert.Equal(t, tt.targetLabel, resMap["label"].StrVal) + }) + } +} diff --git a/yuanrong/pkg/functionscaler/utils/configmap_util.go b/yuanrong/pkg/functionscaler/utils/configmap_util.go new file mode 100644 index 0000000..f366163 --- /dev/null +++ b/yuanrong/pkg/functionscaler/utils/configmap_util.go @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "fmt" + "net" + "strconv" + "strings" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/k8sclient" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" +) + +const ( + defaultLabelLength = 3 +) + +// DeleteConfigMapByFuncInfo - +func DeleteConfigMapByFuncInfo(funcSpec *types.FunctionSpecification) { + if config.GlobalConfig.Scenario != types.ScenarioWiseCloud { + return + } + configmapName := GetConfigmapName(funcSpec.FuncMetaData.FuncName, funcSpec.FuncMetaData.Version) + nameSpace := config.GlobalConfig.NameSpace + if nameSpace == "" { + nameSpace = constant.DefaultNameSpace + } + err := k8sclient.GetkubeClient().DeleteK8sConfigMap(nameSpace, configmapName) + if err != nil { + log.GetLogger().Errorf("delete configmap error, error is ", err.Error()) + } +} + +// GetConfigmapName - +func GetConfigmapName(functionName string, functionVersion string) string { + return strings.ToLower(fmt.Sprintf("%s-%s", strings.ReplaceAll(functionName, "_", "-"), functionVersion)) +} + +// IsNeedRaspSideCar - +func IsNeedRaspSideCar(funcSpec *types.FunctionSpecification) bool { + if funcSpec.ExtendedMetaData.RaspConfig.RaspImage == "" || funcSpec.ExtendedMetaData.RaspConfig.InitImage == "" { + return false + } + if net.ParseIP(funcSpec.ExtendedMetaData.RaspConfig.RaspServerIP) == nil { + log.GetLogger().Warnf("failed to parse rasp ip: %s ", funcSpec.ExtendedMetaData.RaspConfig.RaspServerIP) + return false + } + if !isValidPort(funcSpec.ExtendedMetaData.RaspConfig.RaspServerPort) { + log.GetLogger().Warnf("failed to parse rasp "+ + "port: %s ", funcSpec.ExtendedMetaData.RaspConfig.RaspServerPort) + return false + } + return true +} + +func isValidPort(port string) bool { + p, err := strconv.Atoi(port) + if err != nil { + return false + } + return p > 0 && p <= 65535 // port should between 0 and 65535 +} diff --git a/yuanrong/pkg/functionscaler/utils/configmap_util_test.go b/yuanrong/pkg/functionscaler/utils/configmap_util_test.go new file mode 100644 index 0000000..d3e0c72 --- /dev/null +++ b/yuanrong/pkg/functionscaler/utils/configmap_util_test.go @@ -0,0 +1,73 @@ +package utils + +import ( + "encoding/base64" + "testing" + + "github.com/smartystreets/goconvey/convey" + + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/functionscaler/types" +) + +func TestDeleteConfigMapByFuncInfo(t *testing.T) { + convey.Convey("DeleteConfigMapByFuncInfo", t, func() { + DeleteConfigMapByFuncInfo(&types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{FuncName: "func-name", Version: "latest"}, + ExtendedMetaData: commonTypes.ExtendedMetaData{ + CustomFilebeatConfig: commonTypes.CustomFilebeatConfig{ + ImageAddress: "images", + SidecarConfigInfo: &commonTypes.SidecarConfigInfo{ + ConfigFiles: []commonTypes.CustomLogConfigFile{{Path: "path", + Data: base64.StdEncoding.EncodeToString([]byte("data"))}}, + }, + }, + }}) + }) +} + +func TestIsNeedRaspSideCar(t *testing.T) { + convey.Convey("IsNeedRaspSideCar", t, func() { + isNeed := IsNeedRaspSideCar(&types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + RaspImage: "someImage", + InitImage: "someImage", + RaspServerIP: "1.2.3.4", + RaspServerPort: "1234", + }, + }, + }) + convey.So(isNeed, convey.ShouldEqual, true) + isNeed = IsNeedRaspSideCar(&types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + InitImage: "someImage", + RaspServerIP: "1.2.3.4", + RaspServerPort: "1234", + }, + }, + }) + convey.So(isNeed, convey.ShouldEqual, false) + isNeed = IsNeedRaspSideCar(&types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + RaspImage: "someImage", + InitImage: "someImage", + RaspServerPort: "1234", + }, + }, + }) + convey.So(isNeed, convey.ShouldEqual, false) + isNeed = IsNeedRaspSideCar(&types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + RaspConfig: commonTypes.RaspConfig{ + RaspImage: "someImage", + InitImage: "someImage", + RaspServerIP: "1.2.3.4", + }, + }, + }) + convey.So(isNeed, convey.ShouldEqual, false) + }) +} diff --git a/yuanrong/pkg/functionscaler/utils/utils.go b/yuanrong/pkg/functionscaler/utils/utils.go new file mode 100644 index 0000000..c152fa0 --- /dev/null +++ b/yuanrong/pkg/functionscaler/utils/utils.go @@ -0,0 +1,402 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "crypto/rand" + "encoding/json" + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/resspeckey" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/urnutils" + "yuanrong/pkg/common/faas_common/utils" + commonUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/faas_common/wisecloudtool" + "yuanrong/pkg/common/uuid" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" +) + +const ( + ipPortLens = 2 + defaultConcurrentNum = 100 + // DefaultMapSize is default size of map + DefaultMapSize = 16 + // DefaultSliceSize is default size of slice + DefaultSliceSize = 16 + defaultCPUValue = 0 + defaultMemValue = 0 + resourceCPUName = "CPU" + resourceMemoryName = "Memory" + // ValidFuncKeyLen is the valid length of funcKey + ValidFuncKeyLen = 3 + // FuncKeyDelimiter is the delimiter for parsing inner funcKey from funcKey + FuncKeyDelimiter = "/" + // InnerFuncKeyDelimiter is the delimiter for parsing funcName from inner funcKey + InnerFuncKeyDelimiter = "-" + funcNameIndexInFuncKey = 2 + defaultCreateTimeout = time.Duration(30) * time.Second + maxSessionLength = 63 + namespaceLabelKey = "POD_NAMESPACE" + deploymentNameLabelKey = "POD_DEPLOYMENT_NAME" + podNameLabelKey = "POD_NAME" +) + +var ( + labelRegexp = regexp.MustCompile(`-invoke-label-(.*?)-ephemeral-storage-`) +) + +// GetNpuInstanceType - +func GetNpuInstanceType(delegateContainer string) (bool, string) { + config := &types.DelegateContainerConfig{} + err := json.Unmarshal([]byte(delegateContainer), config) + if err != nil { + return false, "" + } + + for _, v := range config.Env { + if v.Name == types.SystemNodeInstanceType { + return true, v.Value + } + } + return false, "" +} + +// GetNpuTypeAndInstanceTypeFromStr 函数用于从字符串中获取 NPU 类型和实例类型 +func GetNpuTypeAndInstanceTypeFromStr(customResource string, customResourcesSpec string) (string, string) { + customResourceMap := make(map[string]int64) + customResourceSpecMap := make(map[string]interface{}) + + err1 := json.Unmarshal([]byte(customResource), &customResourceMap) + err2 := json.Unmarshal([]byte(customResourcesSpec), &customResourceSpecMap) + if err1 != nil { + return "", "" + } + + if err2 != nil { + // pass + } + return GetNpuTypeAndInstanceType(customResourceMap, customResourceSpecMap) +} + +// GetNpuTypeAndInstanceType 函数用于获取 NPU 类型和实例类型 +func GetNpuTypeAndInstanceType(customRes map[string]int64, customResSpec map[string]interface{}) (string, string) { + v, ok := customRes[types.AscendResourceD910B] + if !ok || v <= 0 { + return "", "" + } + d910bType, ok := customResSpec[types.AscendResourceD910BInstanceType] + if !ok { + return types.AscendResourceD910B, "376T" + } + d910bTypeStr, ok := d910bType.(string) + if !ok { + return types.AscendResourceD910B, "376T" + } + return types.AscendResourceD910B, d910bTypeStr +} + +// ConvertInstanceResource will convert instance resource +func ConvertInstanceResource(res commonTypes.Resources) *resspeckey.ResourceSpecification { + resSpec := &resspeckey.ResourceSpecification{ + CustomResources: make(map[string]int64, constant.DefaultMapSize), + } + for k, v := range res.Resources { + if k == constant.ResourceCPUName { + resSpec.CPU = int64(v.Scalar.Value) + continue + } + if k == constant.ResourceMemoryName { + resSpec.Memory = int64(v.Scalar.Value) + continue + } + if k == constant.ResourceEphemeralStorage { + resSpec.EphemeralStorage = int(v.Scalar.Value) + continue + } + resSpec.CustomResources[k] = int64(v.Scalar.Value) + } + return resSpec +} + +// AppendInstanceTypeToInstanceResource - +func AppendInstanceTypeToInstanceResource(resSpec *resspeckey.ResourceSpecification, npuInstanceType string) { + if resSpec == nil { + return + } + if resSpec.CustomResourcesSpec == nil { + resSpec.CustomResourcesSpec = make(map[string]interface{}) + } + resSpec.CustomResourcesSpec["instanceType"] = npuInstanceType +} + +// IsFaaSManager checks if a funcKey is t +func IsFaaSManager(funcKey string) bool { + items := strings.Split(funcKey, InnerFuncKeyDelimiter) + if len(items) != ValidFuncKeyLen { + return false + } + return items[funcNameIndexInFuncKey] == types.FaasManagerFuncName +} + +// GetConcurrentNum gets a valid concurrentNum +func GetConcurrentNum(concurrentNum int) int { + if concurrentNum == 0 { + return 1 + } + return concurrentNum +} + +// GenerateTraceID - +func GenerateTraceID() string { + return uuid.New().String() +} + +// IsUnrecoverableError checks if error should not be retried +func IsUnrecoverableError(err error) bool { + if snErr, ok := err.(snerror.SNError); ok { + if snerror.IsUserError(snErr) { + return true + } + if snErr.Code() == statuscode.UserFuncEntryNotFoundErrCode || snErr.Code() == statuscode.KernelEtcdWriteFailedCode || + snErr.Code() == statuscode.WiseCloudNuwaColdStartErrCode { + return true + } + } + return false +} + +// IsNoNeedToRePullError checks if error pod should not be repull +func IsNoNeedToRePullError(err error) bool { + if snErr, ok := err.(snerror.SNError); ok { + return snErr.Code() == statuscode.StsConfigErrCode + } + return false +} + +// GenerateStsSecretName - +func GenerateStsSecretName(etcdFuncKey string) string { + // keep same as CaaS, prevent secret resource leaks + return strings.ToLower(fmt.Sprintf("%s-sts", urnutils.CrNameByKey(etcdFuncKey))) +} + +// AddNodeSelector - +func AddNodeSelector(nodeSelectorMap map[string]string, schedulingOptions *types.SchedulingOptions, + resSpec *resspeckey.ResourceSpecification) { + if resSpec == nil { + return + } + if resSpec.CustomResources == nil || len(resSpec.CustomResources) == 0 { + if nodeSelectorMap != nil && len(nodeSelectorMap) != 0 { + for k, v := range nodeSelectorMap { + schedulingOptions.Extension[utils.NodeSelectorKey] = fmt.Sprintf(`{"%s": "%s"}`, k, v) + } + } + } +} + +// AddAffinityCPU - +func AddAffinityCPU(crName string, schedulingOptions *types.SchedulingOptions, + resSpec *resspeckey.ResourceSpecification, affinityType api.AffinityType) { + if resSpec == nil { + return + } + if resSpec.CustomResources == nil || len(resSpec.CustomResources) == 0 { + schedulingOptions.Affinity = commonUtils.CreatePodAffinity(crName, "", affinityType) + } +} + +// IsResSpecEmpty - +func IsResSpecEmpty(resSpec *resspeckey.ResourceSpecification) bool { + if resSpec == nil { + return true + } + return resSpec.CPU == 0 && resSpec.Memory == 0 && len(resSpec.CustomResources) == 0 +} + +// GetCreateTimeout - +func GetCreateTimeout(funcSpec *types.FunctionSpecification) time.Duration { + createTimeout := time.Duration(funcSpec.ExtendedMetaData.Initializer.Timeout) * time.Second + if funcSpec.FuncMetaData.Runtime == types.CustomContainerRuntimeType { + createTimeout += constant.CustomImageExtraTimeout * time.Second + } else { + createTimeout += (constant.KernelScheduleTimeout + constant.CommonExtraTimeout) * time.Second + } + if createTimeout < defaultCreateTimeout { + createTimeout = defaultCreateTimeout + } + return createTimeout +} + +// GetRequestTimeout - +func GetRequestTimeout(funcSpec *types.FunctionSpecification) time.Duration { + if funcSpec.InstanceMetaData.ScalePolicy == types.InstanceScalePolicyStaticFunction { + return time.Duration(funcSpec.FuncMetaData.Timeout) * time.Second + } else { + return GetCreateTimeout(funcSpec) + } +} + +// GenRandomString will generate random string with given length +func GenRandomString(n int) string { + randBytes := make([]byte, n/2) + if _, err := rand.Read(randBytes); err != nil { + return "" + } + return fmt.Sprintf("%x", randBytes) +} + +// GenFuncKeyWithRes will generate funcKeyWithRes +func GenFuncKeyWithRes(funcKey, resKey string) string { + return fmt.Sprintf("%s-%s", funcKey, resKey) +} + +// GetLeaseInterval returns the interval of lease +func GetLeaseInterval() time.Duration { + leaseInterval := time.Duration(config.GlobalConfig.LeaseSpan) * time.Millisecond + if leaseInterval < types.MinLeaseInterval { + leaseInterval = types.MinLeaseInterval + } + return leaseInterval +} + +func parseIPAndPort(address string) (string, string) { + var ( + IP, Port string + ) + Address := strings.Split(address, ":") + if len(Address) == ipPortLens { + IP = Address[0] + Port = Address[1] + } + return IP, Port +} + +// BuildInstanceFromInsSpec builds instance from instanceSpecification +func BuildInstanceFromInsSpec(insSpec *commonTypes.InstanceSpecification, + funcSpec *types.FunctionSpecification) *types.Instance { + instanceIP, instancePort := parseIPAndPort(insSpec.RuntimeAddress) + // FunctionProxyID is not "IP:Port" like string, need to find another way to get nodeIP and nodePort + nodeIP, nodePort := parseIPAndPort(insSpec.FunctionProxyID) + instanceName := insSpec.CreateOptions[types.InstanceNameNote] + var ( + instanceType types.InstanceType + isPermanent bool + schedulerID string + concurrentNum int + err error + ) + instanceNote, ok := insSpec.CreateOptions[types.InstanceTypeNote] + if ok { + instanceType = types.InstanceType(instanceNote) + } else { + instanceType = types.InstanceTypeUnknown + } + resSpecKey, err := resspeckey.GetResKeyFromStr(insSpec.CreateOptions[types.ResourceSpecNote]) + if err != nil { + log.GetLogger().Errorf("failed to GetResKeyFromStr from %s for instance %s", + insSpec.CreateOptions[types.ResourceSpecNote], insSpec.InstanceID) + } + schedulerNote := insSpec.CreateOptions[types.SchedulerIDNote] + if strings.HasSuffix(schedulerNote, types.PermanentInstance) { + isPermanent = true + schedulerID = strings.TrimSuffix(schedulerNote, types.PermanentInstance) + } else if strings.HasSuffix(schedulerNote, types.TemporaryInstance) { + schedulerID = strings.TrimSuffix(schedulerNote, types.TemporaryInstance) + } else { + schedulerID = schedulerNote + } + concurrentNote := insSpec.CreateOptions[types.ConcurrentNumKey] + if len(concurrentNote) != 0 { + concurrentNum, err = strconv.Atoi(concurrentNote) + if err != nil { + log.GetLogger().Errorf("failed to parse concurrentNum from %s for instance %s", concurrentNote, + insSpec.InstanceID) + concurrentNum = defaultConcurrentNum + } + } + metricsLabels := make([]string, 0) + if funcSpec != nil { + metricsLabels = wisecloudtool.GetMetricLabels(&funcSpec.FuncMetaData, resSpecKey.InvokeLabel, + insSpec.Extensions.PodNamespace, insSpec.Extensions.PodDeploymentName, insSpec.Extensions.PodName) + } + return &types.Instance{ + InstanceStatus: insSpec.InstanceStatus, + InstanceType: instanceType, + ResKey: resSpecKey, + InstanceID: insSpec.InstanceID, + InstanceName: instanceName, + ParentID: insSpec.ParentID, + FuncSig: insSpec.CreateOptions[types.FunctionSign], + FuncKey: insSpec.Function, + // this instance maybe an old one with previous concurrentNum, need to optimize this logic in future + ConcurrentNum: concurrentNum, + CreateSchedulerID: schedulerID, + Permanent: isPermanent, + NodeIP: nodeIP, + NodePort: nodePort, + InstanceIP: instanceIP, + InstancePort: instancePort, + MetricLabelValues: metricsLabels, + PodID: insSpec.Extensions.PodNamespace + ":" + insSpec.Extensions.PodName, + PodDeploymentName: insSpec.Extensions.PodDeploymentName, + } +} + +// CheckInstanceSessionValid checks if InstanceSessionConfig is valid +func CheckInstanceSessionValid(insSess commonTypes.InstanceSessionConfig) bool { + if len(insSess.SessionID) == 0 || len(insSess.SessionID) > maxSessionLength { + return false + } + if insSess.SessionTTL < 0 { + return false + } + return true +} + +// GetInvokeLabelFromResKey get invoke label from res key +func GetInvokeLabelFromResKey(s string) string { + if !strings.Contains(s, "-invoke-label-") { + return "" + } + matches := labelRegexp.FindStringSubmatch(s) + // 第一个匹配项是整个匹配的字符串,第二个匹配项是我们想要的 + if len(matches) > 1 { + return matches[1] + } + return "" +} + +// IntMax returns the maximum of the params +func IntMax(a, b int) int { + if b > a { + return b + } + return a +} diff --git a/yuanrong/pkg/functionscaler/utils/utils_test.go b/yuanrong/pkg/functionscaler/utils/utils_test.go new file mode 100644 index 0000000..31efab4 --- /dev/null +++ b/yuanrong/pkg/functionscaler/utils/utils_test.go @@ -0,0 +1,274 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "testing" + "time" + + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/resspeckey" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" +) + +func TestIsFaaSManager(t *testing.T) { + convey.Convey("error funcKey", t, func() { + is := IsFaaSManager("") + convey.So(is, convey.ShouldBeFalse) + }) + convey.Convey("right funcKey", t, func() { + is := IsFaaSManager("0-system-faasmanager") + convey.So(is, convey.ShouldBeTrue) + }) +} + +func TestAddAffinityCPU(t *testing.T) { + type args struct { + crName string + schedulingOptions *types.SchedulingOptions + resSpec *resspeckey.ResourceSpecification + affinityType api.AffinityType + } + tests := []struct { + name string + args args + }{ + {"case1", args{ + crName: "yyrk1234-0-yrservice-test-image-env-call-latest-670698364", + schedulingOptions: &types.SchedulingOptions{}, + resSpec: &resspeckey.ResourceSpecification{CPU: 300, Memory: 128}, + affinityType: api.PreferredAntiAffinity, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + AddAffinityCPU(tt.args.crName, tt.args.schedulingOptions, tt.args.resSpec, tt.args.affinityType) + }) + } +} + +func TestAppendInstanceTypeToInstanceResource(t *testing.T) { + convey.Convey("test AppendInstanceTypeToInstanceResource", t, func() { + convey.Convey("resSpec is nil", func() { + AppendInstanceTypeToInstanceResource(nil, "") + }) + convey.Convey("success", func() { + resSpec := &resspeckey.ResourceSpecification{} + AppendInstanceTypeToInstanceResource(resSpec, "NPU") + convey.So(resSpec.CustomResourcesSpec["instanceType"], convey.ShouldEqual, "NPU") + }) + }) +} + +func TestGetNpuTypeAndInstanceTypeFromStr(t *testing.T) { + convey.Convey("Test GetNpuTypeAndInstanceTypeFromStr", t, func() { + convey.Convey("Test GetNpuTypeAndInstanceTypeFromStr", func() { + npuType, extraType := GetNpuTypeAndInstanceTypeFromStr(`{"huawei.com/ascend-1980":1}`, + `{"instanceType":"280t"}`) + convey.So(npuType, convey.ShouldEqual, "huawei.com/ascend-1980") + convey.So(extraType, convey.ShouldEqual, "280t") + }) + }) +} + +func TestAddNodeSelector(t *testing.T) { + convey.Convey("Test AddNodeSelector", t, func() { + convey.Convey("Test AddNodeSelector", func() { + scheduleOption := &types.SchedulingOptions{Extension: make(map[string]string)} + AddNodeSelector(map[string]string{"aaa": "123"}, scheduleOption, &resspeckey.ResourceSpecification{}) + convey.So(scheduleOption.Extension[utils.NodeSelectorKey], convey.ShouldEqual, `{"aaa": "123"}`) + }) + }) +} + +func TestGetCreateTimeout(t *testing.T) { + convey.Convey("test GetCreateTimeout", t, func() { + convey.Convey("get create timeout", func() { + timeout := GetCreateTimeout(&types.FunctionSpecification{}) + convey.So(timeout, convey.ShouldEqual, defaultCreateTimeout) + timeout = GetCreateTimeout(&types.FunctionSpecification{ + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Initializer: commonTypes.Initializer{ + Timeout: 50, + }, + }, + }) + convey.So(timeout, convey.ShouldEqual, (50+constant.CommonExtraTimeout+constant.KernelScheduleTimeout)*time.Second) + timeout = GetCreateTimeout(&types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{ + Runtime: types.CustomContainerRuntimeType, + }, + ExtendedMetaData: commonTypes.ExtendedMetaData{ + Initializer: commonTypes.Initializer{ + Timeout: 50, + }, + }, + }) + convey.So(timeout, convey.ShouldEqual, (50+constant.CustomImageExtraTimeout)*time.Second) + }) + }) +} + +func TestGetLeaseInterval(t *testing.T) { + convey.Convey("test GetLeaseInterval", t, func() { + convey.Convey("get lease interval", func() { + config.GlobalConfig.LeaseSpan = 50 + interval := GetLeaseInterval() + convey.So(interval, convey.ShouldEqual, 500*time.Millisecond) + config.GlobalConfig.LeaseSpan = 200 + interval = GetLeaseInterval() + convey.So(interval, convey.ShouldEqual, 500*time.Millisecond) + config.GlobalConfig.LeaseSpan = 600 + interval = GetLeaseInterval() + convey.So(interval, convey.ShouldEqual, 600*time.Millisecond) + }) + }) +} + +func TestBuildInstanceFromInsSpec(t *testing.T) { + convey.Convey("test BuildInstanceFromInsSpec", t, func() { + convey.Convey("get instance IP and port", func() { + instance := BuildInstanceFromInsSpec(&commonTypes.InstanceSpecification{ + RuntimeAddress: "1.2.3.4:1234", + }, nil) + convey.So(instance.InstanceIP, convey.ShouldEqual, "1.2.3.4") + convey.So(instance.InstancePort, convey.ShouldEqual, "1234") + }) + convey.Convey("get instance type", func() { + instance := BuildInstanceFromInsSpec(&commonTypes.InstanceSpecification{}, nil) + convey.So(instance.InstanceType, convey.ShouldEqual, types.InstanceTypeUnknown) + instance = BuildInstanceFromInsSpec(&commonTypes.InstanceSpecification{ + CreateOptions: map[string]string{ + types.InstanceTypeNote: string(types.InstanceTypeReserved), + }, + }, nil) + convey.So(instance.InstanceType, convey.ShouldEqual, types.InstanceTypeReserved) + }) + convey.Convey("get scheduler id", func() { + instance := BuildInstanceFromInsSpec(&commonTypes.InstanceSpecification{ + CreateOptions: map[string]string{ + types.SchedulerIDNote: "scheduler1-permanent", + }, + }, nil) + convey.So(instance.Permanent, convey.ShouldEqual, true) + convey.So(instance.CreateSchedulerID, convey.ShouldEqual, "scheduler1") + instance = BuildInstanceFromInsSpec(&commonTypes.InstanceSpecification{ + CreateOptions: map[string]string{ + types.SchedulerIDNote: "scheduler1-temporary", + }, + }, nil) + convey.So(instance.Permanent, convey.ShouldEqual, false) + convey.So(instance.CreateSchedulerID, convey.ShouldEqual, "scheduler1") + }) + convey.Convey("get concurrentNum", func() { + instance := BuildInstanceFromInsSpec(&commonTypes.InstanceSpecification{ + CreateOptions: map[string]string{ + types.ConcurrentNumKey: "", + }, + }, nil) + convey.So(instance.ConcurrentNum, convey.ShouldEqual, 0) + instance = BuildInstanceFromInsSpec(&commonTypes.InstanceSpecification{ + CreateOptions: map[string]string{ + types.ConcurrentNumKey: "wrong data", + }, + }, nil) + convey.So(instance.ConcurrentNum, convey.ShouldEqual, defaultConcurrentNum) + instance = BuildInstanceFromInsSpec(&commonTypes.InstanceSpecification{ + CreateOptions: map[string]string{ + types.ConcurrentNumKey: "50", + }, + }, nil) + convey.So(instance.ConcurrentNum, convey.ShouldEqual, 50) + }) + convey.Convey("get metricLabelValue ", func() { + instance := BuildInstanceFromInsSpec(&commonTypes.InstanceSpecification{ + Extensions: commonTypes.Extensions{}, + }, nil) + convey.So(instance.MetricLabelValues, convey.ShouldBeEmpty) + instance = BuildInstanceFromInsSpec(&commonTypes.InstanceSpecification{ + Extensions: commonTypes.Extensions{ + PodName: "aaa", + PodNamespace: "bbb", + PodDeploymentName: "ccc", + }, + }, &types.FunctionSpecification{ + FuncMetaData: commonTypes.FuncMetaData{ + Name: "111", + TenantID: "222", + BusinessID: "333", + FuncName: "444", + }, + }) + convey.So(len(instance.MetricLabelValues), convey.ShouldEqual, 8) + }) + }) +} + +func TestCheckInstanceSessionValid(t *testing.T) { + convey.Convey("test CheckInstanceSessionValid", t, func() { + convey.Convey("CheckInstanceSessionValid", func() { + res := CheckInstanceSessionValid(commonTypes.InstanceSessionConfig{ + SessionID: "_123&0", + SessionTTL: 10, + }) + convey.So(res, convey.ShouldEqual, true) + res = CheckInstanceSessionValid(commonTypes.InstanceSessionConfig{ + SessionTTL: 10, + }) + convey.So(res, convey.ShouldEqual, false) + res = CheckInstanceSessionValid(commonTypes.InstanceSessionConfig{ + SessionID: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + SessionTTL: 10, + }) + convey.So(res, convey.ShouldEqual, false) + res = CheckInstanceSessionValid(commonTypes.InstanceSessionConfig{ + SessionID: "aaa", + SessionTTL: 0, + }) + convey.So(res, convey.ShouldEqual, true) + }) + }) +} + +func TestGetInvokeLabelFromResKey(t *testing.T) { + convey.Convey("test GetInvokeLabelFromResKey", t, func() { + convey.Convey("get GetInvokeLabelFromResKey", func() { + res := resspeckey.ResourceSpecification{ + CPU: 500, + Memory: 1000, + InvokeLabel: "aaaaa", + EphemeralStorage: 0, + } + label := GetInvokeLabelFromResKey(res.String()) + convey.So(label, convey.ShouldEqual, "aaaaa") + }) + }) +} + +func TestIntMax(t *testing.T) { + assert.Equal(t, IntMax(3, 1), 3) + assert.Equal(t, IntMax(1, 3), 3) +} diff --git a/yuanrong/pkg/functionscaler/workermanager/lease.go b/yuanrong/pkg/functionscaler/workermanager/lease.go new file mode 100644 index 0000000..84c6ea5 --- /dev/null +++ b/yuanrong/pkg/functionscaler/workermanager/lease.go @@ -0,0 +1,113 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package workermanager + +import ( + "errors" + "sync" + + coordinationv1 "k8s.io/api/coordination/v1" + "k8s.io/client-go/informers" + informercdv1 "k8s.io/client-go/informers/coordination/v1" + "k8s.io/client-go/tools/cache" + + "yuanrong/pkg/common/faas_common/k8sclient" + "yuanrong/pkg/common/faas_common/logger/log" +) + +var ( + leaseInformer informercdv1.LeaseInformer + wmAddr string + wmAddrMutex sync.RWMutex +) + +func getWmAddr() string { + wmAddrMutex.RLock() + ret := wmAddr + wmAddrMutex.RUnlock() + return ret +} + +func updateWmAddr(addr string) { + wmAddrMutex.Lock() + wmAddr = addr + wmAddrMutex.Unlock() +} + +// InitLeaseInformer - +func InitLeaseInformer(stopCh <-chan struct{}) error { + + informerFactory := informers.NewSharedInformerFactoryWithOptions(k8sclient.GetkubeClient().Client, + 0, informers.WithNamespace("default")) + leaseInformer = informerFactory.Coordination().V1().Leases() + + leaseInformer.Informer().AddEventHandler(generateInformerHandler()) + + go leaseInformer.Informer().Run(stopCh) + if ok := cache.WaitForCacheSync(stopCh, leaseInformer.Informer().HasSynced); !ok { + return errors.New("failed to sync") + } + return nil +} + +func generateInformerHandler() cache.FilteringResourceEventHandler { + return cache.FilteringResourceEventHandler{ + FilterFunc: func(obj interface{}) bool { + lease, ok := obj.(*coordinationv1.Lease) + if !ok { + return false + } + + if lease.Name == "worker-manager" { + return true + } + + return false + }, + Handler: cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + lease, ok := obj.(*coordinationv1.Lease) + if !ok { + return + } + if lease.Spec.HolderIdentity == nil { + return + } + updateWmAddr(*lease.Spec.HolderIdentity) + log.GetLogger().Infof("ip of worker-manger leader is: %s", *lease.Spec.HolderIdentity) + }, + UpdateFunc: func(oldObj, newObj interface{}) { + lease, ok := newObj.(*coordinationv1.Lease) + if !ok { + return + } + if lease.Spec.HolderIdentity == nil { + return + } + oldAddr := getWmAddr() + if oldAddr != *lease.Spec.HolderIdentity { + updateWmAddr(*lease.Spec.HolderIdentity) + log.GetLogger().Infof("ip of worker-manager leader update to %s", *lease.Spec.HolderIdentity) + } + }, + DeleteFunc: func(obj interface{}) { + log.GetLogger().Errorf("leader of worker-manager lost") + updateWmAddr("") + }, + }, + } +} diff --git a/yuanrong/pkg/functionscaler/workermanager/lease_test.go b/yuanrong/pkg/functionscaler/workermanager/lease_test.go new file mode 100644 index 0000000..36fa96a --- /dev/null +++ b/yuanrong/pkg/functionscaler/workermanager/lease_test.go @@ -0,0 +1,118 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package workermanager + +import ( + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/coordination/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/informers" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/fake" + "yuanrong/pkg/common/faas_common/k8sclient" +) + +func Test_initLeaseInformer(t *testing.T) { + addr1 := "10.0.0.1:58866" + leases := []*v1.Lease{ + { + ObjectMeta: metav1.ObjectMeta{Name: "state-manager"}, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "worker-manager"}, + Spec: v1.LeaseSpec{HolderIdentity: &addr1}, + }, + } + + fakeClient := fake.NewSimpleClientset(leases[0], leases[1]) + fakeInformer := informers.NewSharedInformerFactory(fakeClient, 0) + + patches := []*gomonkey.Patches{ + gomonkey.ApplyFunc(k8sclient.GetkubeClient, func() *k8sclient.KubeClient { + return &k8sclient.KubeClient{ + Client: &kubernetes.Clientset{}, + } + }), + gomonkey.ApplyFunc(informers.NewSharedInformerFactoryWithOptions, func(client kubernetes.Interface, defaultResync time.Duration, options ...informers.SharedInformerOption) informers.SharedInformerFactory { + return fakeInformer + }), + } + defer func() { + for i := range patches { + patches[i].Reset() + } + }() + + stopCh := make(chan struct{}) + err := InitLeaseInformer(stopCh) + assert.Nil(t, err) + close(stopCh) +} + +func Test_generateInformerHandler(t *testing.T) { + addr1 := "10.0.0.1:58866" + addr2 := "10.0.0.2:58866" + addr3 := "10.0.0.3:58866" + emptyLease := &v1.Lease{ + ObjectMeta: metav1.ObjectMeta{Name: "worker-manager"}, + } + lease1 := &v1.Lease{ + ObjectMeta: metav1.ObjectMeta{Name: "worker-manager"}, + Spec: v1.LeaseSpec{HolderIdentity: &addr1}, + } + lease2 := &v1.Lease{ + ObjectMeta: metav1.ObjectMeta{Name: "worker-manager"}, + Spec: v1.LeaseSpec{HolderIdentity: &addr2}, + } + + stateLease := &v1.Lease{ + ObjectMeta: metav1.ObjectMeta{Name: "state-manager"}, + Spec: v1.LeaseSpec{HolderIdentity: &addr3}, + } + + updateWmAddr("") + handlers := generateInformerHandler() + handlers.OnAdd(*lease1, true) + assert.Equal(t, "", wmAddr) + handlers.OnAdd(emptyLease, true) + assert.Equal(t, "", wmAddr) + handlers.OnAdd(lease1, true) + assert.Equal(t, addr1, wmAddr) + handlers.OnAdd(stateLease, true) + assert.Equal(t, addr1, wmAddr) + + filter := handlers.FilterFunc + handlers.FilterFunc = func(obj interface{}) bool { + return true + } + handlers.OnUpdate(lease1, *lease1) + assert.Equal(t, addr1, wmAddr) + handlers.OnUpdate(lease1, emptyLease) + assert.Equal(t, addr1, wmAddr) + handlers.OnUpdate(lease1, lease2) + assert.Equal(t, addr2, wmAddr) + + handlers.FilterFunc = filter + handlers.OnDelete(stateLease) + assert.Equal(t, addr2, wmAddr) + handlers.OnDelete(lease2) + assert.Equal(t, "", wmAddr) +} diff --git a/yuanrong/pkg/functionscaler/workermanager/workermanager_client.go b/yuanrong/pkg/functionscaler/workermanager/workermanager_client.go new file mode 100644 index 0000000..edcf52c --- /dev/null +++ b/yuanrong/pkg/functionscaler/workermanager/workermanager_client.go @@ -0,0 +1,105 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package workermanager worker manager client +package workermanager + +import ( + "errors" + "net" + "net/http" + "sync" + "time" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/functionscaler/config" +) + +const ( + workerManagerDeployURL = "/worker-manager/v1/functions/sn:cn:yrk:%s:function:%s:%s/worker" + workerManagerDeleteURL = "/worker-manager/v1/functions/worker/delete" + statusOKCode = 150200 + scaleDownTimeout = 30 * time.Second + defaultRequestWorkerManagerImageTimeout = 25 * time.Minute + appID = "ondemand" + dialTimeout = 3 * time.Second + idleConnTimeout = 90 * time.Second + connKeepAlive = 30 * time.Second + maxIdleConns = 100 + httpScheme = "http://" + httpsScheme = "https://" +) + +var ( + once sync.Once + httpClient *http.Client +) + +// GetWorkerManagerClient - +func GetWorkerManagerClient() *http.Client { + once.Do(func() { + tr := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: dialTimeout, + KeepAlive: connKeepAlive, + }).DialContext, + MaxIdleConns: maxIdleConns, + ForceAttemptHTTP2: true, + IdleConnTimeout: idleConnTimeout, + TLSHandshakeTimeout: dialTimeout, + } + if config.GlobalConfig.HTTPSConfig.HTTPSEnable { + tr.TLSClientConfig = tls.GetClientTLSConfig() + } + + httpClient = &http.Client{ + Timeout: defaultRequestWorkerManagerImageTimeout, + Transport: tr, + } + }) + return httpClient +} + +// GetWorkerManagerBaseURL - +func GetWorkerManagerBaseURL() (string, error) { + scheme := httpScheme + if config.GlobalConfig.HTTPSConfig != nil && config.GlobalConfig.HTTPSConfig.HTTPSEnable { + scheme = httpsScheme + } + addr := getWmAddr() + if addr == "" { + return "", errors.New("worker manager address is empty") + } + return scheme + addr, nil +} + +// FillInWorkerManagerRequestHeaders contains authorization, source and so on +func FillInWorkerManagerRequestHeaders(request *http.Request) { + authorization, timestamp := generateAuthorization() + request.Header.Set(constant.HeaderAuthTimestamp, timestamp) + request.Header.Set(constant.HeaderAuthorization, authorization) + request.Header.Set(constant.HeaderEventSourceID, appID) + request.Header.Set(constant.HeaderCallType, "active") +} + +func generateAuthorization() (string, string) { + var authorization, timestamp string + authorization, timestamp = localauth.SignLocally(config.GlobalConfig.LocalAuth.AKey, + config.GlobalConfig.LocalAuth.SKey, appID, config.GlobalConfig.LocalAuth.Duration) + return authorization, timestamp +} diff --git a/yuanrong/pkg/functionscaler/workermanager/workermanager_client_test.go b/yuanrong/pkg/functionscaler/workermanager/workermanager_client_test.go new file mode 100644 index 0000000..37b1976 --- /dev/null +++ b/yuanrong/pkg/functionscaler/workermanager/workermanager_client_test.go @@ -0,0 +1,114 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package workermanager worker manager client +package workermanager + +import ( + "net/http" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/localauth" + commontls "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" +) + +func TestGetWorkerManagerClient(t *testing.T) { + bakHttpsConfig := config.GlobalConfig.HTTPSConfig + config.GlobalConfig.HTTPSConfig = &commontls.InternalHTTPSConfig{ + HTTPSEnable: true, + TLSProtocol: "tls", + TLSCiphers: "ciphers", + } + client := GetWorkerManagerClient() + if client == nil { + t.Errorf("failed to get worker manager client") + } + config.GlobalConfig.HTTPSConfig = bakHttpsConfig +} + +// TestFillInWorkerManagerRequestHeaders - +func TestFillInWorkerManagerRequestHeaders(t *testing.T) { + convey.Convey("test: fill in workermanager request header", t, func() { + patch := gomonkey.ApplyFunc(localauth.SignLocally, func(ak, sk, appID string, duration int) (string, string) { + return "authorization", "timestamp" + }) + config.GlobalConfig = types.Configuration{ + LocalAuth: localauth.AuthConfig{}, + } + defer func() { + patch.Reset() + config.GlobalConfig = types.Configuration{} + }() + req := http.Request{ + Header: map[string][]string{}, + } + FillInWorkerManagerRequestHeaders(&req) + convey.So(req.Header.Get(constant.HeaderAuthTimestamp), convey.ShouldEqual, "timestamp") + convey.So(req.Header.Get(constant.HeaderAuthorization), convey.ShouldEqual, "authorization") + convey.So(req.Header.Get(constant.HeaderEventSourceID), convey.ShouldEqual, appID) + convey.So(req.Header.Get(constant.HeaderCallType), convey.ShouldEqual, "active") + }) +} + +// TestGetWorkerManagerBaseURL - +func TestGetWorkerManagerBaseURL(t *testing.T) { + convey.Convey("test: get worker manager base url", t, func() { + updateWmAddr("worker-manager:58866") + tasks := []struct { + caseName string + httpsEnable bool + patch func() + recover func() + expect string + }{ + { + caseName: "http", + httpsEnable: false, + patch: func() { + config.GlobalConfig = types.Configuration{HTTPSConfig: &commontls.InternalHTTPSConfig{HTTPSEnable: false}} + }, + recover: func() { + config.GlobalConfig = types.Configuration{} + }, + expect: "http://worker-manager:58866", + }, + { + caseName: "https", + httpsEnable: true, + patch: func() { + config.GlobalConfig = types.Configuration{HTTPSConfig: &commontls.InternalHTTPSConfig{HTTPSEnable: true}} + }, + recover: func() { + config.GlobalConfig = types.Configuration{} + }, + expect: "https://worker-manager:58866", + }, + } + // test + for _, task := range tasks { + task.patch() + actual, _ := GetWorkerManagerBaseURL() + convey.So(actual, convey.ShouldEqual, task.expect) + task.recover() + } + }) +} diff --git a/yuanrong/pkg/functionscaler/workermanager/workermanager_request.go b/yuanrong/pkg/functionscaler/workermanager/workermanager_request.go new file mode 100644 index 0000000..31f2bb4 --- /dev/null +++ b/yuanrong/pkg/functionscaler/workermanager/workermanager_request.go @@ -0,0 +1,329 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package workermanager + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "os" + "strconv" + "time" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + "yuanrong/pkg/common/faas_common/urnutils" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/selfregister" + "yuanrong/pkg/functionscaler/types" +) + +var nodeID = os.Getenv("NODE_IP") + +// ScaleUpParam scale up param +type ScaleUpParam struct { + TraceID string + FunctionKey string + NodeLabel string + Timeout time.Duration + CPU int + Memory int +} + +// GetWorkerSuccessResponse define +type GetWorkerSuccessResponse struct { + Worker *Worker `json:"worker"` + Instance *types.WmInstance `json:"instance"` + Code int `json:"code"` + Message string `json:"message"` +} + +// Worker define a worker +type Worker struct { + Instances []*types.WmInstance `json:"instances"` + FunctionName string `json:"functionname"` + FunctionVersion string `json:"functionversion"` + Tenant string `json:"tenant"` + Business string `json:"business"` +} + +// DeployParam Deploy function param +type DeployParam struct { + FuncName string // deployed function name + Applier string // applier instance id + DeployNode string // deploy node id + Business string // business + TenantID string // tenant id + Version string // version + OwnerIP string // owner IP + TraceID string // trace id + TriggerFlag string // This is a trigger request flag. + StateID string + CPU int + Memory int + OwnedWorkerView WorkersView // proxy owned worker view +} + +// WorkersView proxy workers view +type WorkersView struct { + // OwnedNum -1 num means scale up 1 worker by force + // other num means proxy owned worker nums + OwnedNum int `json:"ownedNum"` + // CurrentWorkerNum - + CurrentWorkerNum int `json:"currentWorkerNum"` + // ScalingWorkersNum - + ScalingWorkersNum int `json:"scalingWorkersNum"` +} + +// DeleteParam Delete function param +type DeleteParam struct { + InstanceID string `json:"instance_id"` + FuncName string `json:"function_name"` + FuncVersion string `json:"function_version"` + BusinessID string `json:"business_id"` + TenantID string `json:"tenant_id"` + Applier string `json:"applier"` + IsBrokenConnection bool `json:"broken_connection_status"` +} + +// DeleteWorkerResponse is response to delete worker +type DeleteWorkerResponse struct { + Code int `json:"code"` + Reserved bool `json:"reserved"` + Message string `json:"message"` +} + +// ScaleUpInstance send scale up req to worker manager +func ScaleUpInstance(scaleUpParam *ScaleUpParam) (*types.WmInstance, error) { + anonymizedFuncKeyWithRes := urnutils.AnonymizeTenantKey(scaleUpParam.FunctionKey) + ctx := context.TODO() + ctx, cancel := context.WithTimeout(ctx, scaleUpParam.Timeout) + defer cancel() + request := makeScaleUpRequest(ctx, scaleUpParam) + if request == nil { + log.GetLogger().Errorf("traceID: %s | FuncKeyWithRes: %s | failed to make scale up request", + scaleUpParam.TraceID, anonymizedFuncKeyWithRes) + return nil, errors.New("make scale up request failed") + } + resp, err := GetWorkerManagerClient().Do(request) + if err != nil { + log.GetLogger().Errorf("traceID: %s | FuncKeyWithRes: %s | failed to send scale up request "+ + "to worker manager: %s", scaleUpParam.TraceID, anonymizedFuncKeyWithRes, err.Error()) + return nil, err + } + instance, err := handleScaleUpResponseFromWorkerManager(resp, scaleUpParam, anonymizedFuncKeyWithRes) + if err != nil { + return nil, err + } + return instance, nil +} + +func handleScaleUpResponseFromWorkerManager(resp *http.Response, + scaleUpParam *ScaleUpParam, anonymizedFuncKeyWithRes string) (*types.WmInstance, error) { + respBody, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + log.GetLogger().Errorf("failed to get response body: %s, functionKey: %s, traceID: %s", err.Error(), + urnutils.AnonymizeTenantKey(scaleUpParam.FunctionKey), scaleUpParam.TraceID) + return nil, snerror.New(statuscode.ScaleUpRequestErrCode, statuscode.ScaleUpRequestErrMsg) + } + if resp.StatusCode != http.StatusOK { + snErr := &snerror.BadResponse{} + err = json.Unmarshal(respBody, snErr) + if err != nil { + log.GetLogger().Errorf("traceID: %s | FuncKeyWithRes: %s | failed to unmarshal scale up "+ + "response: %s", scaleUpParam.TraceID, anonymizedFuncKeyWithRes, err.Error()) + return nil, err + } + log.GetLogger().Errorf("traceID: %s | FuncKeyWithRes: %s | errCode: %d | errMessage: %s | failed to "+ + "scale up function instance", scaleUpParam.TraceID, anonymizedFuncKeyWithRes, snErr.Code, snErr.Message) + return nil, snerror.New(snErr.Code, snErr.Message) + } + scaleUpResp := GetWorkerSuccessResponse{} + if err = json.Unmarshal(respBody, &scaleUpResp); err != nil { + log.GetLogger().Errorf("failed to unmarshal scaleUp response: %s, functionKey: %s, traceID: %s, body %s", + err, urnutils.AnonymizeTenantKey(scaleUpParam.FunctionKey), scaleUpParam.TraceID, string(respBody)) + return nil, snerror.New(statuscode.ScaleUpRequestErrCode, statuscode.ScaleUpRequestErrMsg) + } + if scaleUpResp.Code == statuscode.InnerRuntimeInitTimeoutCode && scaleUpResp.Worker != nil { + log.GetLogger().Infof("deploy functionKey %s instance: %v traceID: %s with code %d msg %s", + urnutils.AnonymizeTenantKey(scaleUpParam.FunctionKey), scaleUpResp.Instance, scaleUpParam.TraceID, + scaleUpResp.Code, scaleUpResp.Message) + return scaleUpResp.Instance, nil + } + if scaleUpResp.Code != statusOKCode || scaleUpResp.Worker == nil { + log.GetLogger().Warnf("deploy response code %d, functionKey: %s, traceID: %s", scaleUpResp.Code, + urnutils.AnonymizeTenantKey(scaleUpParam.FunctionKey), scaleUpParam.TraceID) + if scaleUpResp.Code == 0 { + return nil, snerror.New(statuscode.ScaleUpRequestErrCode, scaleUpResp.Message) + } + return nil, snerror.New(scaleUpResp.Code, scaleUpResp.Message) + } + log.GetLogger().Infof("deploy functionKey %s instance: %v traceID: %s successfully", + urnutils.AnonymizeTenantKey(scaleUpParam.FunctionKey), scaleUpResp.Instance, scaleUpParam.TraceID) + return scaleUpResp.Instance, nil +} + +// ScaleDownInstance sends request to workerManager scale down instance. +func ScaleDownInstance(instanceID, functionKey, traceID string) error { + anonymizedFuncKey := urnutils.AnonymizeTenantKey(functionKey) + log.GetLogger().Infof("traceID: %s | FuncKeyWithRes: %s | instanceID: %s | start to scale down instance", + traceID, anonymizedFuncKey, instanceID) + ctx := context.TODO() + ctx, cancel := context.WithTimeout(ctx, scaleDownTimeout) + defer cancel() + request := makeScaleDownRequest(ctx, instanceID, functionKey) + if request == nil { + log.GetLogger().Errorf("traceID: %s | FuncKeyWithRes: %s | failed to make scale down request", + traceID, anonymizedFuncKey) + return errors.New("failed to make scale down request") + } + resp, err := GetWorkerManagerClient().Do(request) + if err != nil { + log.GetLogger().Errorf("traceID: %s | FuncKeyWithRes: %s | error: %s | sent http request error and "+ + "failed to send scale down request to worker manager", traceID, anonymizedFuncKey, err.Error()) + return err + } + return handleScaleDownResponseFromWorkerManager(instanceID, functionKey, traceID, resp) +} + +func makeScaleUpRequest(ctx context.Context, scaleUpParam *ScaleUpParam) *http.Request { + tenantID, functionName, functionVersion := utils.ParseFuncKey(scaleUpParam.FunctionKey) + baseURL, err := GetWorkerManagerBaseURL() + if err != nil { + log.GetLogger().Errorf("failed to get worker manager baseURL: %s", err.Error()) + return nil + } + requestURI := baseURL + fmt.Sprintf(workerManagerDeployURL, tenantID, functionName, functionVersion) + param := DeployParam{ + FuncName: functionName, + Applier: selfregister.SelfInstanceID, + DeployNode: nodeID, + Business: "yrk", + TenantID: tenantID, + Version: functionVersion, + OwnerIP: nodeID, + CPU: scaleUpParam.CPU, + Memory: scaleUpParam.Memory, + OwnedWorkerView: WorkersView{ + OwnedNum: -1, + }, + } + data, err := json.Marshal(param) + if err != nil { + log.GetLogger().Errorf("failed to marshal scaleUp request: %s", err.Error()) + return nil + } + request, err := http.NewRequestWithContext(ctx, "POST", requestURI, bytes.NewBuffer(data)) + if err != nil { + log.GetLogger().Errorf("action failed when make scaleUp request, err %s", err.Error()) + return nil + } + FillInWorkerManagerRequestHeaders(request) + log.GetLogger().Debugf("succeeded to sign the authorization of function: %s, traceID: %s", + urnutils.AnonymizeTenantKey(scaleUpParam.FunctionKey), scaleUpParam.TraceID) + request.Header.Set(constant.HeaderTraceID, scaleUpParam.TraceID) + request.Header.Set(constant.HeaderForceDeploy, strconv.FormatBool(false)) + return request +} + +func makeScaleDownRequest(ctx context.Context, instanceID, functionKey string) *http.Request { + tenantID, functionName, functionVersion := utils.ParseFuncKey(functionKey) + baseURL, err := GetWorkerManagerBaseURL() + if err != nil { + log.GetLogger().Errorf("failed to get worker manager baseURL: %s", err.Error()) + return nil + } + requestURI := baseURL + workerManagerDeleteURL + param := DeleteParam{ + InstanceID: instanceID, + FuncName: functionName, + FuncVersion: functionVersion, + BusinessID: "yrk", + TenantID: tenantID, + Applier: selfregister.SelfInstanceID, + } + data, err := json.Marshal(param) + if err != nil { + log.GetLogger().Errorf("marshal error when make scale down request") + return nil + } + request, err := http.NewRequestWithContext(ctx, http.MethodDelete, requestURI, bytes.NewBuffer(data)) + if err != nil { + log.GetLogger().Errorf("error when make scale down request: %s", err.Error()) + return nil + } + FillInWorkerManagerRequestHeaders(request) + return request +} + +func handleScaleDownResponseFromWorkerManager(instanceID, functionKey, traceID string, resp *http.Response) error { + anonymizedFuncKey := urnutils.AnonymizeTenantKey(functionKey) + respBody, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + log.GetLogger().Errorf("traceID: %s | FuncKeyWithRes: %s | error: %s | failed to get response body", + traceID, anonymizedFuncKey, err.Error()) + return err + } + if resp.StatusCode != http.StatusOK { + snErr := &snerror.BadResponse{} + err = json.Unmarshal(respBody, snErr) + if err != nil { + log.GetLogger().Errorf("traceID: %s | FuncKeyWithRes: %s | error: %s | failed to unmarshal scale "+ + "down bad response", traceID, anonymizedFuncKey, err.Error()) + return err + } + log.GetLogger().Errorf("traceID: %s | FuncKeyWithRes: %s | responseCode: %d | responseMessage: %s | "+ + "worker manager returned bad response", traceID, anonymizedFuncKey, snErr.Code, snErr.Message) + return snerror.New(snErr.Code, snErr.Message) + } + scaleDownResp := DeleteWorkerResponse{} + if err := json.Unmarshal(respBody, &scaleDownResp); err != nil { + log.GetLogger().Errorf("traceID: %s | FuncKeyWithRes: %s | failed to unmarshal scale down response: "+ + "DeleteWorkerResponse", traceID, anonymizedFuncKey) + return err + } + if scaleDownResp.Code != http.StatusOK { + log.GetLogger().Warnf("traceID: %s | FuncKeyWithRes: %s | responseCode: %d | responseMessage: %s | "+ + "worker manager returned error", traceID, anonymizedFuncKey, scaleDownResp.Code, scaleDownResp.Message) + return errors.New("worker manager returned error") + } + log.GetLogger().Infof("traceID: %s | FuncKeyWithRes: %s | instanceID: %s | succeed to scale down instance", + traceID, anonymizedFuncKey, instanceID) + return nil +} + +// NeedTryError no response error and wait for retrying scaling +// or local worker turned idle until request queued timeout +func NeedTryError(err error) bool { + if snErr, ok := err.(snerror.SNError); ok { + if snErr.Code() == statuscode.GettingPodErrorCode || + snErr.Code() == statuscode.CancelGeneralizePod || + snErr.Code() == statuscode.ReachMaxInstancesCode || + snErr.Code() == statuscode.ReachMaxOnDemandInstancesPerTenant { + return true + } + } + return false +} diff --git a/yuanrong/pkg/functionscaler/workermanager/workermanager_request_test.go b/yuanrong/pkg/functionscaler/workermanager/workermanager_request_test.go new file mode 100644 index 0000000..72597a9 --- /dev/null +++ b/yuanrong/pkg/functionscaler/workermanager/workermanager_request_test.go @@ -0,0 +1,509 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package workermanager + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "io/ioutil" + "net/http" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/snerror" + "yuanrong/pkg/common/faas_common/statuscode" + "yuanrong/pkg/common/faas_common/tls" + mockUtils "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/functionscaler/config" + "yuanrong/pkg/functionscaler/types" +) + +const mockFuncKey = "7e1ad6a6-cc5c-44fa-bd54-25873f72a86a/0@default@testyrhttpwebsocket/latest" + +func TestScaleUpInstance(t *testing.T) { + type args struct { + scaleUpParam *ScaleUpParam + } + config.GlobalConfig = types.Configuration{ + LocalAuth: localauth.AuthConfig{ + AKey: "000", + SKey: "111", + Duration: 5, + }, + HTTPSConfig: &tls.InternalHTTPSConfig{ + HTTPSEnable: false, + TLSProtocol: "TLSv1.2", + TLSCiphers: "TLS_ECDHE_RSA", + }, + } + mockArg := args{ + scaleUpParam: &ScaleUpParam{ + TraceID: "testTraceID", + FunctionKey: mockFuncKey, + Timeout: 1 * time.Second, + CPU: 400, + Memory: 256, + }, + } + mockRespBody := `{"code":150200,"instance":{"instanceID":"pool30-xxx"},"worker":{"instances":[], + "functionname":"","functionversion":"","tenant":"","business":""}}` + + tests := []struct { + name string + args args + patchesFunc mockUtils.PatchesFunc + want *types.WmInstance + wantErr error + }{ + { + name: "case 1 scale up instance successfully", + args: mockArg, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader( + []byte(mockRespBody))), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + want: &types.WmInstance{ + InstanceID: "pool30-xxx", + }, + wantErr: nil, + }, + { + name: "case 2 scale up instance failed", + args: mockArg, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 500, + Body: ioutil.NopCloser(bytes.NewReader( + []byte(`{"code": 5000,"message":"mockMessage"}`))), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + want: nil, + wantErr: snerror.New(5000, "mockMessage"), + }, + { + name: "case 3 scale up instance instance failed when do request error", + args: mockArg, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return nil, errors.New("do request mock error") + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + want: nil, + wantErr: errors.New("do request mock error"), + }, + { + name: "case 4 scale up instance failed when unmarshal scale up bad response error", + args: mockArg, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 500, + Body: ioutil.NopCloser(bytes.NewReader([]byte(`{`))), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + want: nil, + wantErr: errors.New("unexpected end of JSON input"), + }, + { + name: "case 5 scale up instance instance successfully when unmarshal scale up response error", + args: mockArg, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader( + []byte("}"))), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + want: nil, + wantErr: snerror.New(statuscode.ScaleUpRequestErrCode, statuscode.ScaleUpRequestErrMsg), + }, + { + name: "case 6 response code in GetWorkerSuccessResponse is not 200", + args: mockArg, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader([]byte("{}"))), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + want: nil, + wantErr: snerror.New(0, ""), + }, + { + name: "case 7 scale up instance successfully when failed to make scale up request", + args: mockArg, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc(json.Marshal, func(v any) ([]byte, error) { + return nil, errors.New("json marshal error") + }), + }) + return patches + }, + want: nil, + wantErr: errors.New("make scale up request failed"), + }, + { + name: "case 8 scale up instance instance failed when read body error", + args: mockArg, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader( + []byte(mockRespBody))), + }, nil + }), + ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return nil, errors.New("read body mock error") + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + want: nil, + wantErr: snerror.New(statuscode.ScaleUpRequestErrCode, statuscode.ScaleUpRequestErrMsg), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + wmInstance, err := ScaleUpInstance(tt.args.scaleUpParam) + if err != nil { + assert.Equal(t, tt.wantErr.Error(), err.Error()) + } else { + assert.Equal(t, tt.want.InstanceID, wmInstance.InstanceID) + } + patches.ResetAll() + }) + } +} + +func TestScaleDownInstance(t *testing.T) { + config.GlobalConfig = types.Configuration{ + LocalAuth: localauth.AuthConfig{ + AKey: "000", + SKey: "111", + Duration: 5, + }, + HTTPSConfig: &tls.InternalHTTPSConfig{ + HTTPSEnable: false, + TLSProtocol: "TLSv1.2", + TLSCiphers: "TLS_ECDHE_RSA", + }, + } + type args struct { + instanceID string + functionKey string + traceID string + } + tests := []struct { + name string + args args + patchesFunc mockUtils.PatchesFunc + wantErr error + }{ + { + name: "case 1 scale down function instance", + args: args{ + instanceID: "default-#-pool30", + functionKey: mockFuncKey, + traceID: "mockTraceID", + }, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader([]byte(`{"code":200,"message":"mock message"}`))), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + wantErr: nil, + }, + { + name: "case 2 scale down function instance when do request err", + args: args{ + instanceID: "default-#-pool30", + functionKey: mockFuncKey, + traceID: "mockTraceID", + }, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return nil, errors.New("do request mock error") + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + wantErr: errors.New("do request mock error"), + }, + { + name: "case 3 scale down function instance when ioutil.ReadAll error", + args: args{ + instanceID: "default-#-pool30", + functionKey: mockFuncKey, + traceID: "mockTraceID", + }, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader([]byte(`{"code":200,"message":"mock message"}`))), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + ApplyFunc(ioutil.ReadAll, func(r io.Reader) ([]byte, error) { + return nil, errors.New("read body mock error") + }), + }) + return patches + }, + wantErr: errors.New("read body mock error"), + }, + { + name: "case 4 scale down function instance when returned bad response body", + args: args{ + instanceID: "default-#-pool30", + functionKey: mockFuncKey, + traceID: "mockTraceID", + }, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 500, + Body: ioutil.NopCloser(bytes.NewReader([]byte(`{"code":500,"message":"mock message"}`))), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + wantErr: snerror.New(500, "mock message"), + }, + { + name: "case 5 scale down function instance when failed to unmarshal bad response body", + args: args{ + instanceID: "default-#-pool30", + functionKey: mockFuncKey, + traceID: "mockTraceID", + }, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 500, + Body: ioutil.NopCloser(bytes.NewReader( + []byte(`{`), + )), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + wantErr: errors.New("unexpected end of JSON input"), + }, + { + name: "case 6 scale down function instance when failed to unmarshal DeleteWorkerResponse", + args: args{ + instanceID: "default-#-pool30", + functionKey: mockFuncKey, + traceID: "mockTraceID", + }, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader([]byte(`{`))), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + wantErr: errors.New("unexpected end of JSON input"), + }, + { + name: "case 7 scale down function instance when status code in DeleteWorkerResponse is not 200", + args: args{ + instanceID: "default-#-pool30", + functionKey: mockFuncKey, + traceID: "mockTraceID", + }, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc((*http.Client).Do, func(_ *http.Client, req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader([]byte(`{"code":500}`))), + }, nil + }), + ApplyFunc(FillInWorkerManagerRequestHeaders, func(request *http.Request) { return }), + }) + return patches + }, + wantErr: errors.New("worker manager returned error"), + }, + { + name: "case 8 scale down function instance failed because json marshal failed", + args: args{ + instanceID: "default-#-pool30", + functionKey: mockFuncKey, + traceID: "mockTraceID", + }, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc(json.Marshal, func(v any) ([]byte, error) { + return nil, errors.New("json marshal error") + }), + }) + return patches + }, + wantErr: errors.New("failed to make scale down request"), + }, + { + name: "case 9 scale down function instance failed because new http request failed", + args: args{ + instanceID: "default-#-pool30", + functionKey: mockFuncKey, + traceID: "mockTraceID", + }, + patchesFunc: func() mockUtils.PatchSlice { + patches := mockUtils.InitPatchSlice() + patches.Append(mockUtils.PatchSlice{ + ApplyFunc(http.NewRequestWithContext, func( + ctx context.Context, method string, url string, body io.Reader) (*http.Request, error) { + return nil, errors.New("http.NewRequestWithContext") + }), + }) + return patches + }, + wantErr: errors.New("failed to make scale down request"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patches := tt.patchesFunc() + err := ScaleDownInstance(tt.args.instanceID, tt.args.functionKey, tt.args.traceID) + if err != nil { + assert.Equal(t, tt.wantErr.Error(), err.Error()) + } else { + assert.Equal(t, tt.wantErr, err) + } + patches.ResetAll() + }) + } +} + +func TestNeedTryError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "SNError with GettingPodErrorCode", + err: snerror.New(statuscode.GettingPodErrorCode, ""), + want: true, + }, + { + name: "SNError with CancelGeneralizePod", + err: snerror.New(statuscode.ReachMaxOnDemandInstancesPerTenant, ""), + want: true, + }, + { + name: "Non-SNError", + err: errors.New("non-SNError"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NeedTryError(tt.err); got != tt.want { + t.Errorf("NeedTryError() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/yuanrong/pkg/system_function_controller/config/config.go b/yuanrong/pkg/system_function_controller/config/config.go new file mode 100644 index 0000000..c4bf659 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/config/config.go @@ -0,0 +1,215 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package config - +package config + +import ( + "encoding/json" + "fmt" + "os" + "sync" + + "github.com/asaskevich/govalidator/v11" + + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/sts" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/system_function_controller/state" + "yuanrong/pkg/system_function_controller/types" +) + +const ( + // MetaEtcdPwdKey - + MetaEtcdPwdKey = "metaEtcdPwd" +) + +var ( + faaSControllerConfig *types.Config + faaSSchedulerConfig *types.SchedulerConfig + faaSFrontendConfig *types.FrontendConfig + faaSManagerConfig *types.ManagerConfig + + // SchedulerConfigLock scheduler config rw lock + SchedulerConfigLock sync.RWMutex + // FrontendConfigLock frontend config rw lock + FrontendConfigLock sync.RWMutex + // ManagerConfigLock manager config rw lock + ManagerConfigLock sync.RWMutex +) + +// RecoverConfig will recover config +func RecoverConfig() error { + stateConf := state.GetState() + faaSControllerConfig = &types.Config{} + err := utils.DeepCopyObj(stateConf.FaaSControllerConfig, faaSControllerConfig) + if err != nil { + return err + } + if err = setFaaSConfigurations(); err != nil { + return err + } + log.GetLogger().Infof("configuration recovered") + return nil +} + +// InitConfig will initialize global config +func InitConfig(configData []byte) error { + faaSControllerConfig = &types.Config{} + err := json.Unmarshal(configData, faaSControllerConfig) + if err != nil { + log.GetLogger().Errorf("json unmarshal faaS controller config error: %v", err) + return err + } + + if err = setFaaSConfigurations(); err != nil { + return err + } + if _, err = govalidator.ValidateStruct(faaSControllerConfig); err != nil { + return err + } + if faaSControllerConfig.RawStsConfig.StsEnable { + if err := sts.InitStsSDK(faaSControllerConfig.RawStsConfig.ServerConfig); err != nil { + log.GetLogger().Errorf("failed to init sts sdk, err: %s", err.Error()) + return err + } + if err = os.Setenv(sts.EnvSTSEnable, "true"); err != nil { + log.GetLogger().Errorf("failed to set env of %s, err: %s", sts.EnvSTSEnable, err.Error()) + return err + } + } + if faaSControllerConfig.SccConfig.Enable && crypto.InitializeSCC(faaSControllerConfig.SccConfig) != nil { + return fmt.Errorf("failed to initialize scc") + } + return nil +} + +func setFaaSConfigurations() error { + if faaSControllerConfig == nil { + return fmt.Errorf("faaSController config is nil") + } + if faaSControllerConfig.RouterETCD.UseSecret { + etcd3.SetETCDTLSConfig(&faaSControllerConfig.RouterETCD) + } else { + faaSControllerConfig.RouterETCD.CaFile = faaSControllerConfig.TLSConfig.CaContent + faaSControllerConfig.RouterETCD.CertFile = faaSControllerConfig.TLSConfig.CertContent + faaSControllerConfig.RouterETCD.KeyFile = faaSControllerConfig.TLSConfig.KeyContent + } + if faaSControllerConfig.MetaETCD.UseSecret { + etcd3.SetETCDTLSConfig(&faaSControllerConfig.MetaETCD) + } + etcdConfig, err := DecryptEtcdConfig(faaSControllerConfig.MetaETCD) + if err != nil { + return err + } + faaSControllerConfig.MetaETCD = *etcdConfig + + err = setAlarmEnv(faaSControllerConfig) + if err != nil { + return err + } + return nil +} + +func setAlarmEnv(faaSControllerConfig *types.Config) error { + if faaSControllerConfig == nil || !faaSControllerConfig.AlarmConfig.EnableAlarm { + log.GetLogger().Infof("enable alarm is false") + return nil + } + utils.SetClusterNameEnv(faaSControllerConfig.ClusterName) + alarm.SetAlarmEnv(faaSControllerConfig.AlarmConfig.AlarmLogConfig) + alarm.SetXiangYunFourConfigEnv(faaSControllerConfig.AlarmConfig.XiangYunFourConfig) + err := alarm.SetPodIP() + if err != nil { + return err + } + return nil +} + +// GetFaaSControllerConfig will get faas controller config +func GetFaaSControllerConfig() types.Config { + return *faaSControllerConfig +} + +// GetFaaSSchedulerConfig will get faas scheduler config +func GetFaaSSchedulerConfig() *types.SchedulerConfig { + return faaSSchedulerConfig +} + +// GetFaaSFrontendConfig will get faas frontend config +func GetFaaSFrontendConfig() *types.FrontendConfig { + return faaSFrontendConfig +} + +// GetFaaSManagerConfig will get faas manager config +func GetFaaSManagerConfig() *types.ManagerConfig { + return faaSManagerConfig +} + +// DecryptEtcdConfig decrypt etcd secret +func DecryptEtcdConfig(config etcd3.EtcdConfig) (*etcd3.EtcdConfig, error) { + decryptEnvMap, err := localauth.GetDecryptFromEnv() + if err != nil { + log.GetLogger().Errorf("get decrypt from env error: %v", err) + return nil, err + } + if decryptEnvMap[MetaEtcdPwdKey] != "" { + config.Password = decryptEnvMap[MetaEtcdPwdKey] + } + return &config, nil +} + +// InitEtcd - init router etcd and meta etcd +func InitEtcd(stopCh <-chan struct{}) error { + if faaSControllerConfig == nil { + return fmt.Errorf("config is not initialized") + } + if err := etcd3.InitRouterEtcdClient(faaSControllerConfig.RouterETCD, + faaSControllerConfig.AlarmConfig, stopCh); err != nil { + return fmt.Errorf("faaSController failed to init route etcd: %s", err.Error()) + } + + if err := etcd3.InitMetaEtcdClient(faaSControllerConfig.MetaETCD, + faaSControllerConfig.AlarmConfig, stopCh); err != nil { + return fmt.Errorf("faaSController failed to init metadata etcd: %s", err.Error()) + } + return nil +} + +// UpdateSchedulerConfig update scheduler config +func UpdateSchedulerConfig(cfg *types.SchedulerConfig) { + SchedulerConfigLock.Lock() + faaSSchedulerConfig = cfg + SchedulerConfigLock.Unlock() +} + +// UpdateFrontendConfig update frontend config +func UpdateFrontendConfig(cfg *types.FrontendConfig) { + FrontendConfigLock.Lock() + faaSFrontendConfig = cfg + FrontendConfigLock.Unlock() +} + +// UpdateManagerConfig update manager config +func UpdateManagerConfig(cfg *types.ManagerConfig) { + ManagerConfigLock.Lock() + faaSManagerConfig = cfg + ManagerConfigLock.Unlock() +} diff --git a/yuanrong/pkg/system_function_controller/config/config_test.go b/yuanrong/pkg/system_function_controller/config/config_test.go new file mode 100644 index 0000000..081e86f --- /dev/null +++ b/yuanrong/pkg/system_function_controller/config/config_test.go @@ -0,0 +1,308 @@ +package config + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + . "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/localauth" + "yuanrong/pkg/common/faas_common/sts" + "yuanrong/pkg/common/faas_common/sts/raw" + ftypes "yuanrong/pkg/frontend/types" + stypes "yuanrong/pkg/functionscaler/types" + "yuanrong/pkg/system_function_controller/state" + "yuanrong/pkg/system_function_controller/types" +) + +var ( + configString = `{ + "routerEtcd": { + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd": { + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + }, + "alarmConfig":{"enableAlarm": true} + } + ` + schedulerConfigString = `{ + "cpu":999, + "memory":999, + "autoScaleConfig":{ + "slaQuota": 1000, + "scaleDownTime": 60000, + "burstScaleNum": 1000 + }, + "leaseSpan":600000, + "functionLimitRate":0, + "routerEtcd":{"servers":["1.2.3.4:1234"],"user":"tom","password":"**","sslEnable":false,"CaFile":"","CertFile":"","KeyFile":""}, + "metaEtcd":{"servers":["1.2.3.4:5678"],"user":"tom","password":"**","sslEnable":false,"CaFile":"","CertFile":"","KeyFile":""}, + "schedulerNum":100 + }` + frontendConfigString = `{ + "instanceNum":100, + "cpu":777, + "memory":777, + "slaQuota":1000, + "trafficLimitDisable":true, + "routerEtcd":{ + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd":{ + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + }, + "http":{"maxRequestBodySize": 6} + }` + managerConfigString = `{ + "managerInstanceNum":0, + "routerEtcd":{ + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd":{ + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + }, + "alarmConfig":{"enableAlarm": true} + }` + lostSTSConfigString = `{ + "routerEtcd": { + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd": { + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + }, + "rawStsConfig": { + "stsEnable": true + } + } + ` + + invalidConfigString = `{ + "x3": { + "url": ["1.2.3.4:1234"], + "username": "tom", + "password": "**" + } + } + ` + + routerEtcdConfig = etcd3.EtcdConfig{ + Servers: []string{"1.2.3.4:1234"}, + User: "tom", + Password: "**", + } + + metaEtcdConfig = etcd3.EtcdConfig{ + Servers: []string{"1.2.3.4:5678"}, + User: "tom", + Password: "**", + } + + schedulerBasicConfig = types.SchedulerConfig{ + Configuration: stypes.Configuration{ + CPU: 999, + Memory: 999, + AutoScaleConfig: stypes.AutoScaleConfig{ + SLAQuota: 1000, + ScaleDownTime: 60000, + BurstScaleNum: 1000, + }, + LeaseSpan: 600000, + RouterETCDConfig: routerEtcdConfig, + MetaETCDConfig: metaEtcdConfig, + }, + SchedulerNum: 100, + } + + frontendConfig = types.FrontendConfig{ + Config: ftypes.Config{ + InstanceNum: 100, + CPU: 777, + Memory: 777, + SLAQuota: 1000, + AuthenticationEnable: false, + HTTPConfig: &ftypes.FrontendHTTP{ + MaxRequestBodySize: 6, + }, + RouterEtcd: routerEtcdConfig, + MetaEtcd: metaEtcdConfig, + }, + } + + managerConf = &types.ManagerConfig{ + ManagerInstanceNum: 0, + RouterEtcd: routerEtcdConfig, + MetaEtcd: metaEtcdConfig, + AlarmConfig: alarm.Config{EnableAlarm: true}, + } + + expectedFaaSControllerConfig = types.Config{ + RouterETCD: routerEtcdConfig, + MetaETCD: metaEtcdConfig, + AlarmConfig: alarm.Config{EnableAlarm: true}, + } + + expectedFaaSSchedulerConfig = &types.SchedulerConfig{ + Configuration: schedulerBasicConfig.Configuration, + SchedulerNum: 100, + } + + expectedFaaSFrontendConfig = &frontendConfig +) + +func TestInitConfig(t *testing.T) { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + Convey("Test InitConfig", t, func() { + Convey("Test InitConfig with invalid config", func() { + faaSControllerConfig = nil + err := InitConfig([]byte("123")) + So(err, ShouldNotBeNil) + + err = InitConfig([]byte(invalidConfigString)) + So(err, ShouldNotBeNil) + }) + + Convey("Test InitConfig with valid config", func() { + faaSControllerConfig = nil + defer gomonkey.ApplyFunc(localauth.GetDecryptFromEnv, func() (map[string]string, error) { + return map[string]string{"metaEtcdPwd": "123"}, nil + }).Reset() + err := InitConfig([]byte(configString)) + So(err, ShouldBeNil) + }) + + Convey("Test InitConfig decrypt error", func() { + faaSControllerConfig = nil + defer gomonkey.ApplyFunc(localauth.GetDecryptFromEnv, func() (map[string]string, error) { + return nil, errors.New("decrypt error") + }).Reset() + err := InitConfig([]byte(configString)) + So(err, ShouldNotBeNil) + }) + + Convey("Test InitConfig STS error", func() { + faaSControllerConfig = nil + err := InitConfig([]byte(lostSTSConfigString)) + So(err, ShouldNotBeNil) + }) + + Convey("Test InitConfig STS success", func() { + faaSControllerConfig = nil + defer gomonkey.ApplyFunc(sts.InitStsSDK, func(serverCfg raw.ServerConfig) error { + return nil + }).Reset() + err := InitConfig([]byte(lostSTSConfigString)) + So(err, ShouldBeNil) + }) + }) +} + +func TestGetFaaSControllerConfig(t *testing.T) { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + faaSControllerConfig = nil + Convey("Test GetFaaSControllerConfig", t, func() { + InitConfig([]byte(configString)) + got := GetFaaSControllerConfig() + fmt.Printf("%+v \n", got) + fmt.Printf("%+v", expectedFaaSControllerConfig) + So(reflect.DeepEqual(got, expectedFaaSControllerConfig), ShouldBeTrue) + }) +} + +func TestGetFaaSSchedulerConfig(t *testing.T) { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + Convey("Test GetFaaSSchedulerConfig", t, func() { + var cfg *types.SchedulerConfig + _ = json.Unmarshal([]byte(schedulerConfigString), &cfg) + UpdateSchedulerConfig(cfg) + got := GetFaaSSchedulerConfig() + So(reflect.DeepEqual(got, expectedFaaSSchedulerConfig), ShouldBeTrue) + }) +} + +func TestGetFaaSFrontendConfig(t *testing.T) { + defer gomonkey.ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + Convey("Test GetFaaSFrontendConfig", t, func() { + var cfg *types.FrontendConfig + _ = json.Unmarshal([]byte(frontendConfigString), &cfg) + UpdateFrontendConfig(cfg) + got := GetFaaSFrontendConfig() + So(got, ShouldResemble, expectedFaaSFrontendConfig) + }) +} + +func TestGetFaaSManagerConfig(t *testing.T) { + Convey("TestGetFaaSManagerConfig", t, func() { + var cfg *types.ManagerConfig + _ = json.Unmarshal([]byte(managerConfigString), &cfg) + UpdateManagerConfig(cfg) + got := GetFaaSManagerConfig() + So(reflect.DeepEqual(got, managerConf), ShouldBeTrue) + }) +} + +func TestRecoverConfig(t *testing.T) { + Convey("RecoverConfig", t, func() { + conf := []byte(`{"FaaSControllerConfig":{ + "routerEtcd": { + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd": { + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + } + }}`) + state.SetState(conf) + err := RecoverConfig() + So(err, ShouldBeNil) + }) +} + +func TestInitEtcd(t *testing.T) { + Convey("TestInitEtcd", t, func() { + Convey("failed config is not initialized", func() { + faaSControllerConfig = nil + err := InitEtcd(make(chan struct{})) + So(err, ShouldNotBeNil) + }) + + Convey("success", func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdInitParam{}), "InitClient", + func(_ *etcd3.EtcdInitParam) error { + return nil + }).Reset() + InitConfig([]byte(configString)) + err := InitEtcd(make(chan struct{})) + So(err, ShouldBeNil) + }) + }) +} diff --git a/yuanrong/pkg/system_function_controller/constant/constant.go b/yuanrong/pkg/system_function_controller/constant/constant.go new file mode 100644 index 0000000..d8d2dc9 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/constant/constant.go @@ -0,0 +1,88 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package constant - +package constant + +import ( + "math" + "time" +) + +const ( + // ConcurrencyKey is the key for concurrency in CreateOption + ConcurrencyKey = "ConcurrentNum" +) + +const ( + // NamespaceDefault - + NamespaceDefault = "default" + // SystemFuncName for pod labels + SystemFuncName = "systemFuncName" + // FuncNameFaasfrontend - + FuncNameFaasfrontend = "faasfrontend" + // FuncNameFaasscheduler - + FuncNameFaasscheduler = "faasscheduler" + // FuncNameFaasmanager - + FuncNameFaasmanager = "faasmanager" + // ConcurrentNumKey - + ConcurrentNumKey = "ConcurrentNum" + // DefaultConcurrentNum - + DefaultConcurrentNum = 32 + // SchedulerExclusivity - + SchedulerExclusivity = "exclusivity" + // RetryCounts - + RetryCounts = 3 + // RetryInterval - + RetryInterval = 3 * time.Second + // DefaultInstanceNum - + DefaultInstanceNum = 1 + // RetryIntervalIncrement - + RetryIntervalIncrement = 2 + // MinSleepTime - + MinSleepTime = 1 * time.Second + // MaxSleepTime - + MaxSleepTime = 60 * time.Second + // SystemFunctionKinds - + SystemFunctionKinds = 3 + // DefaultCreateRetryTime - + DefaultCreateRetryTime = math.MaxInt + // DefaultCreateRetryDuration - + DefaultCreateRetryDuration = 2 * time.Second + // DefaultCreateRetryFactor - + DefaultCreateRetryFactor = 1 + // DefaultCreateRetryJitter - + DefaultCreateRetryJitter = 5 + // DefaultChannelSize - + DefaultChannelSize = 100 + // DefaultMultiples - + DefaultMultiples = 2 + // RecreateSleepTime - + RecreateSleepTime = 10 * MinSleepTime + // MaxConcurrency - + MaxConcurrency = 1000 + // InitCallTimeoutKey is the key for init call timeout in CreateOption + InitCallTimeoutKey = "init_call_timeout" +) + +const ( + // ServiceFrontendPort - + ServiceFrontendPort = 8888 + // ServiceFrontendTargetPort - + ServiceFrontendTargetPort = 8888 + // ServiceFrontendNodePort is k8s service NodePort + ServiceFrontendNodePort = 31222 +) diff --git a/yuanrong/pkg/system_function_controller/faascontroller/fasscontroller.go b/yuanrong/pkg/system_function_controller/faascontroller/fasscontroller.go new file mode 100644 index 0000000..d0878a1 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/faascontroller/fasscontroller.go @@ -0,0 +1,277 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faascontroller - +package faascontroller + +import ( + "encoding/base64" + "encoding/json" + "sync" + + "go.uber.org/zap" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/uuid" + "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/instancemanager" + "yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager" + "yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager" + "yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager" + "yuanrong/pkg/system_function_controller/registry" + "yuanrong/pkg/system_function_controller/service" + "yuanrong/pkg/system_function_controller/types" +) + +// FaaSController define the controller that manages the faas scheduler instances +type FaaSController struct { + instanceManager *instancemanager.InstanceManager + sdkClient api.LibruntimeAPI + frontendOnce sync.Once + schedulerOnce sync.Once + managerOnce sync.Once + funcCh chan types.SubEvent + stopCh chan struct{} + allocRecord sync.Map + sync.RWMutex +} + +const ( + defaultChanSize = 100 +) + +// NewFaaSControllerLibruntime will create a new scheduler instance manager by new sdk of multi libruntime +func NewFaaSControllerLibruntime(libruntimeAPI api.LibruntimeAPI, stopCh chan struct{}) (*FaaSController, error) { + faaSController := &FaaSController{ + instanceManager: &instancemanager.InstanceManager{}, + sdkClient: libruntimeAPI, + funcCh: make(chan types.SubEvent, defaultChanSize), + allocRecord: sync.Map{}, + RWMutex: sync.RWMutex{}, + stopCh: stopCh, + } + go faaSController.processFunctionSubscription() + return faaSController, nil +} + +// NewFaaSController will create a new scheduler instance manager +func NewFaaSController(sdkClient api.LibruntimeAPI, stopCh chan struct{}) (*FaaSController, error) { + faaSController := &FaaSController{ + instanceManager: &instancemanager.InstanceManager{}, + sdkClient: sdkClient, + funcCh: make(chan types.SubEvent, defaultChanSize), + allocRecord: sync.Map{}, + RWMutex: sync.RWMutex{}, + stopCh: stopCh, + } + if err := service.CreateFrontendService(); err != nil { + log.GetLogger().Errorf("failed to create service, reason: %s", err.Error()) + } + go faaSController.processFunctionSubscription() + return faaSController, nil +} + +func (fc *FaaSController) processFunctionSubscription() { + for { + select { + case event, ok := <-fc.funcCh: + if !ok { + log.GetLogger().Warnf("function channel is closed") + return + } + functionSpec, ok := event.EventMsg.(*types.InstanceSpecification) + if !ok { + log.GetLogger().Warnf("event message doesn't contain instance specification") + continue + } + if event.EventType == types.SubEventTypeUpdate { + log.GetLogger().Debugf("receive update event. instanceID=%s", functionSpec.InstanceID) + go fc.instanceManager.HandleEventUpdate(functionSpec, event.EventKind) + } + if event.EventType == types.SubEventTypeDelete { + log.GetLogger().Debugf("receive delete event. instanceID=%s", functionSpec.InstanceID) + go fc.instanceManager.HandleEventDelete(functionSpec, event.EventKind) + } + if event.EventType == types.SubEventTypeRecover { + log.GetLogger().Debugf("receive recover event. instanceID=%s", functionSpec.InstanceID) + go fc.instanceManager.HandleEventRecover(functionSpec, event.EventKind) + } + } + } +} + +// FrontendSignalHandler frontend signal handler +func (fc *FaaSController) FrontendSignalHandler(data []byte) error { + traceID := uuid.New().String() + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + if len(data) == 0 { + logger.Warnf("config is empty, exit frontend hot update") + return nil + } + cfgStr, err := base64.StdEncoding.DecodeString(string(data)) + if err != nil { + logger.Errorf("failed to decode frontend config bytes: %v, config: %s", err, data) + return err + } + log.GetLogger().Infof("frontend config string is: %s", cfgStr) + frontendConfig := &types.FrontendConfig{} + err = json.Unmarshal(cfgStr, frontendConfig) + if err != nil { + logger.Errorf("failed to parse frontend config: %v, config: %s", err, cfgStr) + return err + } + etcdConfig, err := config.DecryptEtcdConfig(frontendConfig.MetaEtcd) + if err != nil { + return err + } + frontendConfig.MetaEtcd = *etcdConfig + fc.frontendOnce.Do(func() { + config.UpdateFrontendConfig(frontendConfig) + frontendManager := faasfrontendmanager.NewFaaSFrontendManager(fc.sdkClient, + etcd3.GetRouterEtcdClient(), fc.stopCh, frontendConfig.InstanceNum, frontendConfig.DynamicPoolEnable) + fc.instanceManager.FrontendManager = frontendManager + frontendRegistry := registry.NewFrontendRegistry(fc.stopCh) + registry.GlobalRegistry.AddFunctionRegistry(frontendRegistry, types.EventKindFrontend) + frontendRegistry.AddSubscriberChan(fc.funcCh) + frontendRegistry.InitWatcher() + frontendRegistry.RunWatcher() + }) + cfgEvent := &types.ConfigChangeEvent{ + FrontendCfg: frontendConfig, + TraceID: traceID, + } + cfgEvent.Add(1) + fc.instanceManager.FrontendManager.ConfigChangeCh <- cfgEvent + cfgEvent.Wait() + log.GetLogger().Infof("frontend signal handler deal complete, err: %v", cfgEvent.Error) + return cfgEvent.Error +} + +// SchedulerSignalHandler scheduler signal handler +func (fc *FaaSController) SchedulerSignalHandler(data []byte) error { + traceID := uuid.New().String() + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + if len(data) == 0 { + logger.Warnf("config is empty, exit scheduler hot update") + return nil + } + cfgStr, err := base64.StdEncoding.DecodeString(string(data)) + if err != nil { + logger.Errorf("failed to decode scheduler config bytes: %v, config: %s", err, data) + return err + } + log.GetLogger().Infof("scheduler config string is: %s", cfgStr) + schedulerConfig := &types.SchedulerConfig{} + err = json.Unmarshal(cfgStr, schedulerConfig) + if err != nil { + logger.Errorf("failed to parse scheduler config: %v, config: %s", err, cfgStr) + return err + } + etcdConfig, err := config.DecryptEtcdConfig(schedulerConfig.MetaETCDConfig) + if err != nil { + return err + } + schedulerConfig.MetaETCDConfig = *etcdConfig + fc.schedulerOnce.Do(func() { + config.UpdateSchedulerConfig(schedulerConfig) + // init default scheduler manager and other scheduler managers for exclusivity tenant + fc.initSchedulerManagers(schedulerConfig) + schedulerRegistry := registry.NewSchedulerRegistry(fc.stopCh) + registry.GlobalRegistry.AddFunctionRegistry(schedulerRegistry, types.EventKindScheduler) + schedulerRegistry.AddSubscriberChan(fc.funcCh) + schedulerRegistry.InitWatcher() + schedulerRegistry.RunWatcher() + }) + cfgEvent := &types.ConfigChangeEvent{ + SchedulerCfg: schedulerConfig, + TraceID: traceID, + } + cfgEvent.Add(1) + fc.instanceManager.CommonSchedulerManager.ConfigChangeCh <- cfgEvent + for tenantID := range fc.instanceManager.ExclusivitySchedulerManagers { + cfgEvent.Add(1) + if fc.instanceManager.ExclusivitySchedulerManagers[tenantID] != nil { + fc.instanceManager.ExclusivitySchedulerManagers[tenantID].ConfigChangeCh <- cfgEvent + } + } + cfgEvent.Wait() + log.GetLogger().Infof("scheduler signal handler deal complete, err: %v", cfgEvent.Error) + return cfgEvent.Error +} + +// ManagerSignalHandler manager signal handler +func (fc *FaaSController) ManagerSignalHandler(data []byte) error { + traceID := uuid.New().String() + logger := log.GetLogger().With(zap.Any("traceID", traceID)) + if len(data) == 0 { + logger.Warnf("config is empty, exit funcManager hot update") + return nil + } + cfgStr, err := base64.StdEncoding.DecodeString(string(data)) + if err != nil { + logger.Errorf("failed to decode funcManager config bytes: %v, config: %s", err, data) + return err + } + log.GetLogger().Infof("funcManager config string is: %s", cfgStr) + managerConfig := &types.ManagerConfig{} + err = json.Unmarshal(cfgStr, managerConfig) + if err != nil { + logger.Errorf("failed to parse funcManager config: %v, config: %s", err, cfgStr) + return err + } + etcdConfig, err := config.DecryptEtcdConfig(managerConfig.MetaEtcd) + if err != nil { + return err + } + managerConfig.MetaEtcd = *etcdConfig + fc.managerOnce.Do(func() { + config.UpdateManagerConfig(managerConfig) + funcManager := faasfunctionmanager.NewFaaSFunctionManager(fc.sdkClient, + etcd3.GetRouterEtcdClient(), + fc.stopCh, managerConfig.ManagerInstanceNum) + fc.instanceManager.FunctionManager = funcManager + managerRegistry := registry.NewManagerRegistry(fc.stopCh) + registry.GlobalRegistry.AddFunctionRegistry(managerRegistry, types.EventKindManager) + managerRegistry.AddSubscriberChan(fc.funcCh) + managerRegistry.InitWatcher() + managerRegistry.RunWatcher() + }) + cfgEvent := &types.ConfigChangeEvent{ + ManagerCfg: managerConfig, + TraceID: traceID, + } + cfgEvent.Add(1) + fc.instanceManager.FunctionManager.ConfigChangeCh <- cfgEvent + cfgEvent.Wait() + log.GetLogger().Infof("funcManager signal handler deal complete, err: %v", cfgEvent.Error) + return cfgEvent.Error +} + +func (fc *FaaSController) initSchedulerManagers(schedulerConfig *types.SchedulerConfig) { + schedulerManager := faasschedulermanager.NewFaaSSchedulerManager(fc.sdkClient, etcd3.GetRouterEtcdClient(), + fc.stopCh, schedulerConfig.SchedulerNum, "") + fc.instanceManager.CommonSchedulerManager = schedulerManager + fc.instanceManager.ExclusivitySchedulerManagers = make(map[string]*faasschedulermanager.SchedulerManager, + len(config.GetFaaSControllerConfig().SchedulerExclusivity)) + for _, tenantID := range config.GetFaaSControllerConfig().SchedulerExclusivity { + log.GetLogger().Infof("begin to new scheduler manager for tenant %s", tenantID) + fc.instanceManager.ExclusivitySchedulerManagers[tenantID] = faasschedulermanager.NewFaaSSchedulerManager( + fc.sdkClient, etcd3.GetRouterEtcdClient(), fc.stopCh, 1, tenantID) + } +} diff --git a/yuanrong/pkg/system_function_controller/faascontroller/fasscontroller_test.go b/yuanrong/pkg/system_function_controller/faascontroller/fasscontroller_test.go new file mode 100644 index 0000000..c631df9 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/faascontroller/fasscontroller_test.go @@ -0,0 +1,356 @@ +package faascontroller + +import ( + "encoding/base64" + "encoding/json" + "reflect" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/client/v3" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/etcd3" + mockUtils "yuanrong/pkg/common/faas_common/utils" + ftypes "yuanrong/pkg/frontend/types" + stypes "yuanrong/pkg/functionscaler/types" + fcConfig "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager" + "yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager" + "yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager" + "yuanrong/pkg/system_function_controller/registry" + "yuanrong/pkg/system_function_controller/types" +) + +var ( + configString = `{ + "frontendInstanceNum": 5, + "schedulerInstanceNum": 5, + "managerInstanceNum": 5, + "faasschedulerConfig": { + "slaQuota": 1000, + "scaleDownTime": 60000, + "burstScaleNum": 1000, + "leaseSpan": 600000 + }, + "etcd": { + "url": ["1.2.3.4:1234"], + "username": "tom", + "password": "**" + }, + "enableRetry": true + } + ` + + invalidConfigString = `{ + "frontendInstanceNum": 0, + "schedulerInstanceNum": 0, + "managerInstanceNum": 0, + "faasschedulerConfig": { + "slaQuota": 1000, + "scaleDownTime": 60000, + "burstScaleNum": 1000, + "leaseSpan": 600000 + }, + "etcd": { + "url": ["1.2.3.4:1234"], + "username": "tom", + "password": "**" + } + } + ` +) + +func newFaaSControllerHelper() (*FaaSController, error) { + stopCh := make(chan struct{}) + fcConfig.InitConfig([]byte(configString)) + + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&etcd3.EtcdInitParam{}), "InitClient", func(_ *etcd3.EtcdInitParam) error { + return nil + }), + ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }), + ApplyMethod(reflect.TypeOf(&faasschedulermanager.SchedulerManager{}), "GetInstanceCountFromEtcd", + func(_ *faasschedulermanager.SchedulerManager) map[string]struct{} { + return map[string]struct{}{} + }), + ApplyMethod(reflect.TypeOf(&faasfrontendmanager.FrontendManager{}), "GetInstanceCountFromEtcd", + func(_ *faasfrontendmanager.FrontendManager) map[string]struct{} { + return map[string]struct{}{} + }), + ApplyMethod(reflect.TypeOf(&faasfunctionmanager.FunctionManager{}), "GetInstanceCountFromEtcd", + func(_ *faasfunctionmanager.FunctionManager) map[string]struct{} { + return map[string]struct{}{} + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + controller, _ := NewFaaSController(&mockUtils.FakeLibruntimeSdkClient{}, stopCh) + controller.instanceManager.FrontendManager = &faasfrontendmanager.FrontendManager{ConfigChangeCh: make(chan *types.ConfigChangeEvent, + 1)} + controller.instanceManager.CommonSchedulerManager = &faasschedulermanager.SchedulerManager{ConfigChangeCh: make(chan *types.ConfigChangeEvent, + 1)} + return controller, nil +} + +// TestNewFaaSController - +func TestNewFaaSController(t *testing.T) { + defer ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "Put", func(_ *etcd3.EtcdClient, + ctxInfo etcd3.EtcdCtxInfo, key string, value string, opts ...clientv3.OpOption) error { + return nil + }).Reset() + // state.InitState() + stopCh := make(chan struct{}) + Convey("Test NewFaaSController for failure", t, func() { + + Convey("Test NewFaaSController when passed invalid instance number", func() { + fcConfig.InitConfig([]byte(invalidConfigString)) + controller, err := NewFaaSController(&mockUtils.FakeLibruntimeSdkClient{}, stopCh) + So(controller, ShouldNotBeNil) + So(err, ShouldBeNil) + }) + }) + time.Sleep(50 * time.Millisecond) +} +func Test_processFunctionSubscription(t *testing.T) { + Convey("processFunctionSubscription", t, func() { + defer ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer ApplyMethod(reflect.TypeOf(&etcd3.EtcdClient{}), "Put", func(_ *etcd3.EtcdClient, + ctxInfo etcd3.EtcdCtxInfo, key string, value string, opts ...clientv3.OpOption) error { + return nil + }).Reset() + controller, err := newFaaSControllerHelper() + So(controller, ShouldNotBeNil) + So(err, ShouldBeNil) + + Convey("assert failed and close ch", func() { + go func() { + controller.funcCh <- types.SubEvent{EventMsg: "invalid"} + time.Sleep(500 * time.Millisecond) + close(controller.funcCh) + }() + controller.processFunctionSubscription() + }) + + Convey("SubEventTypeUpdate", func() { + controller.funcCh = make(chan types.SubEvent, 1) + go func() { + controller.funcCh <- types.SubEvent{ + EventType: types.SubEventTypeUpdate, + EventMsg: &types.InstanceSpecification{}, + } + time.Sleep(500 * time.Millisecond) + close(controller.funcCh) + }() + controller.processFunctionSubscription() + }) + + Convey("SubEventTypeDelete", func() { + controller.funcCh = make(chan types.SubEvent, 1) + go func() { + controller.funcCh <- types.SubEvent{ + EventType: types.SubEventTypeDelete, + EventMsg: &types.InstanceSpecification{}, + } + time.Sleep(500 * time.Millisecond) + close(controller.funcCh) + }() + controller.processFunctionSubscription() + }) + + Convey("SubEventTypeRecover", func() { + controller.funcCh = make(chan types.SubEvent, 1) + go func() { + controller.funcCh <- types.SubEvent{ + EventType: types.SubEventTypeRecover, + EventMsg: &types.InstanceSpecification{}, + } + time.Sleep(500 * time.Millisecond) + close(controller.funcCh) + }() + controller.processFunctionSubscription() + }) + time.Sleep(50 * time.Millisecond) + }) +} + +func TestFaaSController_FrontendSignalHandler(t *testing.T) { + controller, err := newFaaSControllerHelper() + assert.NotNil(t, controller) + assert.Nil(t, err) + Convey("FrontendSignalHandler", t, func() { + Convey("config is emtpy", func() { + err = controller.FrontendSignalHandler([]byte{}) + So(err, ShouldBeNil) + }) + Convey("failed to decode frontend config", func() { + err = controller.FrontendSignalHandler([]byte("123")) + So(err, ShouldBeError) + }) + Convey("success", func() { + defer ApplyFunc(faasfrontendmanager.NewFaaSFrontendManager, func(libruntimeApi api.LibruntimeAPI, + etcdClient *etcd3.EtcdClient, stopCh chan struct{}, size int, isDynamic bool) *faasfrontendmanager.FrontendManager { + return &faasfrontendmanager.FrontendManager{ConfigChangeCh: make(chan *types.ConfigChangeEvent)} + }).Reset() + defer ApplyMethod(reflect.TypeOf(®istry.FaaSFrontendRegistry{}), "InitWatcher", + func(fr *registry.FaaSFrontendRegistry) {}).Reset() + defer ApplyMethod(reflect.TypeOf(®istry.FaaSFrontendRegistry{}), "RunWatcher", + func(fr *registry.FaaSFrontendRegistry) {}).Reset() + frontendConfig := types.FrontendConfig{ + Config: ftypes.Config{ + InstanceNum: 100, + CPU: 777, + Memory: 777, + SLAQuota: 1000, + AuthenticationEnable: false, + HTTPConfig: &ftypes.FrontendHTTP{ + MaxRequestBodySize: 6, + }, + }, + } + bytes, _ := json.Marshal(frontendConfig) + encodeToString := base64.StdEncoding.EncodeToString(bytes) + registry.InitRegistry() + go func() { + time.Sleep(50 * time.Millisecond) + event := <-controller.instanceManager.FrontendManager.ConfigChangeCh + event.Error = nil + event.Done() + }() + err = controller.FrontendSignalHandler([]byte(encodeToString)) + So(err, ShouldBeNil) + }) + }) +} + +func TestFaaSController_SchedulerSignalHandler(t *testing.T) { + mockTenantID001 := "mock-tenant-001" + mockTenantID002 := "mock-tenant-001" + mockTenantSchedulerManager001 := &faasschedulermanager.SchedulerManager{ConfigChangeCh: make(chan *types.ConfigChangeEvent, + 1)} + mockTenantSchedulerManager002 := &faasschedulermanager.SchedulerManager{ConfigChangeCh: make(chan *types.ConfigChangeEvent, + 1)} + controller, err := newFaaSControllerHelper() + assert.NotNil(t, controller) + assert.Nil(t, err) + Convey("SchedulerSignalHandler", t, func() { + Convey("config is emtpy", func() { + err = controller.SchedulerSignalHandler([]byte{}) + So(err, ShouldBeNil) + }) + Convey("failed to decode scheduler config", func() { + err = controller.SchedulerSignalHandler([]byte("123")) + So(err, ShouldBeError) + }) + Convey("success", func() { + defer ApplyFunc(fcConfig.GetFaaSControllerConfig, func() types.Config { + return types.Config{SchedulerExclusivity: []string{mockTenantID001, mockTenantID002}} + }).Reset() + defer ApplyFunc(faasschedulermanager.NewFaaSSchedulerManager, func(libruntimeAPI api.LibruntimeAPI, + etcdClient *etcd3.EtcdClient, stopCh chan struct{}, size int, + tenantID string) *faasschedulermanager.SchedulerManager { + if tenantID == mockTenantID001 { + return mockTenantSchedulerManager001 + } + if tenantID == mockTenantID002 { + return mockTenantSchedulerManager002 + } + return &faasschedulermanager.SchedulerManager{ConfigChangeCh: make(chan *types.ConfigChangeEvent)} + }).Reset() + defer ApplyMethod(reflect.TypeOf(®istry.FaaSSchedulerRegistry{}), "InitWatcher", + func(fr *registry.FaaSSchedulerRegistry) {}).Reset() + defer ApplyMethod(reflect.TypeOf(®istry.FaaSSchedulerRegistry{}), "RunWatcher", + func(fr *registry.FaaSSchedulerRegistry) {}).Reset() + schedulerBasicConfig := types.SchedulerConfig{ + Configuration: stypes.Configuration{ + CPU: 999, + Memory: 999, + AutoScaleConfig: stypes.AutoScaleConfig{ + SLAQuota: 1000, + ScaleDownTime: 60000, + BurstScaleNum: 1000, + }, + LeaseSpan: 600000, + }, + SchedulerNum: 100, + } + bytes, _ := json.Marshal(schedulerBasicConfig) + encodeToString := base64.StdEncoding.EncodeToString(bytes) + registry.InitRegistry() + go func() { + time.Sleep(50 * time.Millisecond) + event := <-controller.instanceManager.CommonSchedulerManager.ConfigChangeCh + event.Error = nil + event.Done() + }() + go func() { + time.Sleep(50 * time.Millisecond) + event := <-mockTenantSchedulerManager001.ConfigChangeCh + event.Error = nil + event.Done() + }() + go func() { + time.Sleep(50 * time.Millisecond) + event := <-mockTenantSchedulerManager002.ConfigChangeCh + event.Error = nil + event.Done() + }() + err = controller.SchedulerSignalHandler([]byte(encodeToString)) + So(err, ShouldBeNil) + }) + }) +} + +func TestFaaSController_ManagerSignalHandler(t *testing.T) { + controller, err := newFaaSControllerHelper() + assert.NotNil(t, controller) + assert.Nil(t, err) + Convey("SchedulerSignalHandler", t, func() { + Convey("config is emtpy", func() { + err = controller.ManagerSignalHandler([]byte{}) + So(err, ShouldBeNil) + }) + Convey("failed to decode scheduler config", func() { + err = controller.ManagerSignalHandler([]byte("123")) + So(err, ShouldBeError) + }) + Convey("success", func() { + defer ApplyFunc(faasfunctionmanager.NewFaaSFunctionManager, func(libruntimeAPI api.LibruntimeAPI, + etcdClient *etcd3.EtcdClient, stopCh chan struct{}, size int) *faasfunctionmanager.FunctionManager { + return &faasfunctionmanager.FunctionManager{ConfigChangeCh: make(chan *types.ConfigChangeEvent)} + }).Reset() + defer ApplyMethod(reflect.TypeOf(®istry.FaaSManagerRegistry{}), "InitWatcher", + func(fr *registry.FaaSManagerRegistry) {}).Reset() + defer ApplyMethod(reflect.TypeOf(®istry.FaaSManagerRegistry{}), "RunWatcher", + func(fr *registry.FaaSManagerRegistry) {}).Reset() + managerBasicConfig := types.ManagerConfig{ + CPU: 999, + Memory: 999, + } + bytes, _ := json.Marshal(managerBasicConfig) + encodeToString := base64.StdEncoding.EncodeToString(bytes) + registry.InitRegistry() + go func() { + time.Sleep(50 * time.Millisecond) + event := <-controller.instanceManager.FunctionManager.ConfigChangeCh + event.Error = nil + event.Done() + }() + err = controller.ManagerSignalHandler([]byte(encodeToString)) + So(err, ShouldBeNil) + }) + }) +} diff --git a/yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager/faasfrontendmanager.go b/yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager/faasfrontendmanager.go new file mode 100644 index 0000000..8638b92 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager/faasfrontendmanager.go @@ -0,0 +1,855 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faasfrontendmanager manages faasfrontend status and instance ID +package faasfrontendmanager + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" + + "go.etcd.io/etcd/client/v3" + "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/wait" + + "yuanrong.org/kernel/runtime/libruntime/api" + + commonconstant "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/sts" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/uuid" + "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/constant" + "yuanrong/pkg/system_function_controller/state" + "yuanrong/pkg/system_function_controller/types" + controllerutils "yuanrong/pkg/system_function_controller/utils" +) + +const ( + dsWorkerUnreadyKey = "is-ds-worker-unready" +) + +var ( + once sync.Once + // faasFrontendManager is the singaleton of Manager + frontendManager *FrontendManager + + createInstanceBackoff = wait.Backoff{ + Steps: constant.DefaultCreateRetryTime, // retry times (include first time) + Duration: constant.DefaultCreateRetryDuration, + Factor: constant.DefaultCreateRetryFactor, + Jitter: constant.DefaultCreateRetryJitter, + } +) + +// FrontendManager manages faasfrontend status and instance ID +type FrontendManager struct { + instanceCache map[string]*types.InstanceSpecification + terminalCache map[string]*types.InstanceSpecification + etcdClient *etcd3.EtcdClient + sdkClient api.LibruntimeAPI + ConfigChangeCh chan *types.ConfigChangeEvent + recreateInstanceIDCh chan string + recreateInstanceIDMap sync.Map + stopCh chan struct{} + sync.RWMutex + count int + isDynamic bool +} + +// GetFrontendManager can only be called after NewFrontendManager +func GetFrontendManager() *FrontendManager { + return frontendManager +} + +// NewFaaSFrontendManager supply a singleton frontend manager +func NewFaaSFrontendManager(libruntimeAPI api.LibruntimeAPI, + etcdClient *etcd3.EtcdClient, stopCh chan struct{}, size int, isDynamic bool) *FrontendManager { + once.Do(func() { + frontendManager = &FrontendManager{ + instanceCache: make(map[string]*types.InstanceSpecification), + terminalCache: map[string]*types.InstanceSpecification{}, + etcdClient: etcdClient, + sdkClient: libruntimeAPI, + count: size, + stopCh: stopCh, + recreateInstanceIDCh: make(chan string, constant.DefaultChannelSize), + ConfigChangeCh: make(chan *types.ConfigChangeEvent, constant.DefaultChannelSize), + isDynamic: isDynamic, + } + go frontendManager.recreateInstance() + frontendManager.initInstanceCache(etcdClient) + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + err := frontendManager.CreateExpectedInstanceCount(ctx) + if err != nil { + log.GetLogger().Errorf("Failed to create expected frontend instance count, error: %v", err) + } + }() + go frontendManager.configChangeProcessor(ctx, cancelFunc) + }) + return frontendManager +} + +func (ffm *FrontendManager) initInstanceCache(etcdClient *etcd3.EtcdClient) { + response, err := etcdClient.Client.Get(context.Background(), types.FasSFrontendPrefixKey, clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("Failed to get frontend instance: %v", err) + return + } + config.FrontendConfigLock.RLock() + targetSign := controllerutils.GetFrontendConfigSignature(config.GetFaaSFrontendConfig()) + config.FrontendConfigLock.RUnlock() + for _, kv := range response.Kvs { + meta := &types.InstanceSpecificationMeta{} + err := json.Unmarshal(kv.Value, meta) + if err != nil { + log.GetLogger().Errorf("Failed to unmarshal instance specification: %v", err) + continue + } + if isExceptInstance(meta, targetSign) { + funcCtx, cancelFunc := context.WithCancel(context.TODO()) + ffm.instanceCache[meta.InstanceID] = &types.InstanceSpecification{ + FuncCtx: funcCtx, + CancelFunc: cancelFunc, + InstanceID: meta.InstanceID, + InstanceSpecificationMeta: *meta, + } + log.GetLogger().Infof("find expected frontend instance %s add to cache", meta.InstanceID) + } + } +} + +func isExceptInstance(meta *types.InstanceSpecificationMeta, targetSign string) bool { + if len(meta.Args) == 0 { + log.GetLogger().Errorf("args is empty, %v", meta) + return false + } + value := meta.Args[0]["value"] + s, err := base64.StdEncoding.DecodeString(value) + if err != nil { + log.GetLogger().Errorf("Failed to decode args: %v", err) + return false + } + cfg := &types.FrontendConfig{} + err = json.Unmarshal(s, cfg) + if err != nil && len(s) > commonconstant.LibruntimeHeaderSize { + // args in libruntime create request with 16 bytes header. + // except libruntime, if other modules use this field, should try yo delete the header + err = json.Unmarshal(s[commonconstant.LibruntimeHeaderSize:], cfg) + if err != nil { + log.GetLogger().Errorf("Failed to unmarshal frontend config: %v, value: %s", err, s) + return false + } + } + oldSign := controllerutils.GetFrontendConfigSignature(cfg) + if oldSign == "" { + log.GetLogger().Errorf("old sign is empty, insID:%s", meta.InstanceID) + return false + } + log.GetLogger().Infof("frontend(%s) sign: %s, expect sign: %s", meta.InstanceID, oldSign, targetSign) + return strings.Compare(oldSign, targetSign) == 0 +} + +// GetInstanceCountFromEtcd get current instance count from etcd +func (ffm *FrontendManager) GetInstanceCountFromEtcd() map[string]struct{} { + resp, err := ffm.etcdClient.Client.Get(context.TODO(), types.FasSFrontendPrefixKey, clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("failed to search etcd key, prefixKey=%s, err=%s", types.FasSFrontendPrefixKey, err.Error()) + return nil + } + instanceIDs := make(map[string]struct{}, resp.Count) + for _, kv := range resp.Kvs { + instanceID := controllerutils.ExtractInfoFromEtcdKey(string(kv.Key), commonconstant.InstanceIDIndexForInstance) + if instanceID != "" { + instanceIDs[instanceID] = struct{}{} + } + } + log.GetLogger().Infof("get etcd frontend instance count=%d, %+v", resp.Count, instanceIDs) + return instanceIDs +} + +// CreateExpectedInstanceCount create expected frontend instance count +func (ffm *FrontendManager) CreateExpectedInstanceCount(ctx context.Context) error { + // 此方法在启动的时候调用一次或者在配置更新的时候调用, 目的是将frontend实例数量补充至设置的实例数 + // 不需要删除多余实例, 多余实例会在 HandleInstanceUpdate 中删除 + // frontend的 instanceID不需要保持一致 + ffm.RLock() + currentCount := len(ffm.instanceCache) + expectedCount := ffm.count - currentCount + ffm.RUnlock() + return ffm.CreateMultiInstances(ctx, expectedCount) +} + +// CreateMultiInstances create multi instances +func (ffm *FrontendManager) CreateMultiInstances(ctx context.Context, count int) error { + if count <= 0 { + log.GetLogger().Infof("no need to create frontend instance, kill %d instances instead.", -count) + return ffm.KillExceptInstance(-count) + } + log.GetLogger().Infof("need to create %d faas frontend instances", count) + + args, params, err := genFunctionConfig() + if err != nil { + return err + } + + var createErr error + g := &sync.WaitGroup{} + for i := 0; i < count; i++ { + g.Add(1) + go func() { + defer g.Done() + if err = ffm.createOrRetry(ctx, args, *params, + config.GetFaaSControllerConfig().EnableRetry); err != nil { + createErr = err + } + }() + } + g.Wait() + if createErr != nil { + return createErr + } + log.GetLogger().Infof("succeed to create %d faaS frontend instances", count) + return nil +} + +func createExtraParams(conf *types.FrontendConfig) (*commonTypes.ExtraParams, error) { + extraParams := &commonTypes.ExtraParams{} + extraParams.Resources = utils.GenerateResourcesMap(conf.CPU, conf.Memory) + extraParams.CustomExtensions = utils.CreateCustomExtensions(extraParams.CustomExtensions, + utils.MonopolyPolicyValue) + extraParams.ScheduleAffinities = utils.CreatePodAffinity(constant.SystemFuncName, constant.FuncNameFaasfrontend, + api.PreferredAntiAffinity) + utils.AddNodeSelector(conf.NodeSelector, extraParams) + createOpt, err := prepareCreateOptions(conf) + extraParams.Label = []string{constant.FuncNameFaasfrontend} + extraParams.CreateOpt = createOpt + return extraParams, err +} + +func genFunctionConfig() ([]api.Arg, *commonTypes.ExtraParams, error) { + config.FrontendConfigLock.RLock() + frontendConfig := config.GetFaaSFrontendConfig() + extraParams, err := createExtraParams(frontendConfig) + if err != nil { + config.FrontendConfigLock.RUnlock() + log.GetLogger().Errorf("failed to prepare faaSFrontend createExtraParams, err:%s", err.Error()) + return nil, nil, err + } + frontendConf, err := json.Marshal(frontendConfig) + config.FrontendConfigLock.RUnlock() + if err != nil { + log.GetLogger().Errorf("faaSFrontend config json marshal failed, err:%s", err.Error()) + return nil, nil, err + } + args := []api.Arg{ + { + Type: api.Value, + Data: frontendConf, + }, + } + return args, extraParams, nil +} + +func (ffm *FrontendManager) createOrRetry(ctx context.Context, args []api.Arg, extraParams commonTypes.ExtraParams, + enableRetry bool) error { + // 只有首次拉起/扩容时insID为空,需要在拉起失败时防止进入失败回调中的重试逻辑 + if extraParams.DesignatedInstanceID == "" { + instanceID := uuid.New().String() + ffm.recreateInstanceIDMap.Store(instanceID, nil) + extraParams.DesignatedInstanceID = instanceID + } + defer ffm.recreateInstanceIDMap.Delete(extraParams.DesignatedInstanceID) + if enableRetry { + err := ffm.CreateWithRetry(ctx, args, &extraParams) + if err != nil { + return err + } + } else { + instanceID := ffm.CreateInstance(ctx, types.FasSFrontendFunctionKey, args, &extraParams) + if instanceID == "" { + log.GetLogger().Errorf("failed to create frontend instance") + return errors.New("failed to create frontend instance") + } + } + return nil +} + +// CreateWithRetry - +func (ffm *FrontendManager) CreateWithRetry( + ctx context.Context, args []api.Arg, extraParams *commonTypes.ExtraParams, +) error { + err := wait.ExponentialBackoffWithContext( + ctx, createInstanceBackoff, func(context.Context) (done bool, err error) { + instanceID := ffm.CreateInstance(ctx, types.FasSFrontendFunctionKey, args, extraParams) + if instanceID == "" { + log.GetLogger().Errorf("failed to create frontend instance") + return false, nil + } + if instanceID == "cancelled" { + return true, fmt.Errorf("create has been cancelled") + } + return true, nil + }) + return err +} + +// CreateInstance create an instance of system function, faaS frontend +func (ffm *FrontendManager) CreateInstance(ctx context.Context, function string, args []api.Arg, + extraParams *commonTypes.ExtraParams) string { + instanceID := extraParams.DesignatedInstanceID + funcMeta := api.FunctionMeta{FuncID: function, Api: api.PosixApi, Name: &instanceID} + invokeOpts := api.InvokeOptions{ + Cpu: int(extraParams.Resources[controllerutils.ResourcesCPU]), + Memory: int(extraParams.Resources[controllerutils.ResourcesMemory]), + ScheduleAffinities: extraParams.ScheduleAffinities, + CustomExtensions: extraParams.CustomExtensions, + CreateOpt: extraParams.CreateOpt, + Labels: extraParams.Label, + Timeout: 150, + } + createCh := make(chan api.ErrorInfo, 1) + go func() { + _, createErr := ffm.sdkClient.CreateInstance(funcMeta, args, invokeOpts) + if createErr != nil { + if errorInfo, ok := createErr.(api.ErrorInfo); ok { + createCh <- errorInfo + } else { + createCh <- api.ErrorInfo{Code: commonconstant.KernelInnerSystemErrCode, Err: createErr} + } + return + } + createCh <- api.ErrorInfo{Code: api.Ok} + }() + timer := time.NewTimer(types.CreatedTimeout) + defer timer.Stop() + select { + case err, ok := <-createCh: + if !ok { + log.GetLogger().Errorf("result channel of frontend instance request is closed") + return "" + } + if !err.IsOk() { + log.GetLogger().Errorf("failed to bring up frontend instance(id=%s),code:%d,err:%s", + instanceID, err.Code, err.Error()) + ffm.clearInstanceAfterError(instanceID) + return "" + } + case <-timer.C: + log.GetLogger().Errorf("time out waiting for instance creation") + ffm.clearInstanceAfterError(instanceID) + return "" + case <-ctx.Done(): + log.GetLogger().Errorf("create instance has been cancelled") + ffm.clearInstanceAfterError(instanceID) + return "cancelled" + } + log.GetLogger().Infof("succeed to create frontend instance(id=%s)", instanceID) + ffm.addInstance(instanceID) + return instanceID +} + +func (ffm *FrontendManager) clearInstanceAfterError(instanceID string) { + var err error + err = ffm.KillInstance(instanceID) + if err != nil { + log.GetLogger().Errorf("failed to kill frontend instance: %s", instanceID) + } +} + +func (ffm *FrontendManager) addInstance(instanceID string) { + ffm.Lock() + defer ffm.Unlock() + _, exist := ffm.instanceCache[instanceID] + if exist { + log.GetLogger().Warnf("the frontend instance(id=%s) already exist", instanceID) + return + } + log.GetLogger().Infof("add instance(id=%s) to local cache", instanceID) + ffm.instanceCache[instanceID] = &types.InstanceSpecification{InstanceID: instanceID} + state.Update(instanceID, types.StateUpdate, types.FaasFrontendInstanceState) +} + +// GetInstanceCache supply local instance cache +func (ffm *FrontendManager) GetInstanceCache() map[string]*types.InstanceSpecification { + return ffm.instanceCache +} + +// SyncKillAllInstance kill all instances of system function, faaS frontend +func (ffm *FrontendManager) SyncKillAllInstance() { + var wg sync.WaitGroup + ffm.Lock() + defer ffm.Unlock() + var deletedInstance []string + for instanceID := range ffm.instanceCache { + wg.Add(1) + go func(instanceID string) { + defer wg.Done() + if err := ffm.sdkClient.Kill(instanceID, types.SyncKillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to kill frontend instance(id=%s), err:%s", instanceID, err.Error()) + return + } + deletedInstance = append(deletedInstance, instanceID) + log.GetLogger().Infof("success to kill frontend instance(id=%s)", instanceID) + }(instanceID) + } + wg.Wait() + for _, instanceID := range deletedInstance { + log.GetLogger().Infof("delete frontend instance(id=%s) from local cache", instanceID) + delete(ffm.instanceCache, instanceID) + state.Update(instanceID, types.StateDelete, types.FaasFrontendInstanceState) + ffm.count-- + } +} + +// KillInstance kill an instance of system function, faaS frontend +func (ffm *FrontendManager) KillInstance(instanceID string) error { + log.GetLogger().Infof("start to kill instance %s", instanceID) + return wait.ExponentialBackoffWithContext( + context.Background(), createInstanceBackoff, func(context.Context) (bool, error) { + var err error + err = ffm.sdkClient.Kill(instanceID, types.KillSignalVal, []byte{}) + if err != nil && !strings.Contains(err.Error(), "instance not found") { + log.GetLogger().Warnf("failed to kill instanceID: %s, err: %s", instanceID, err.Error()) + return false, nil + } + return true, nil + }) +} + +// KillExceptInstance - +func (ffm *FrontendManager) KillExceptInstance(count int) error { + if len(ffm.instanceCache) < count { + return nil + } + if ffm.isDynamic { + // will be scaled down by function-master, no need to process it + return nil + } + for instanceID := range ffm.instanceCache { + if count <= 0 { + return nil + } + if err := ffm.KillInstance(instanceID); err != nil { + log.GetLogger().Errorf("kill frontend instance:%s, error:%s", instanceID, err.Error()) + return err + } + count-- + } + return nil +} + +// RecoverInstance recover a faaS frontend instance when faults occur +func (ffm *FrontendManager) RecoverInstance(info *types.InstanceSpecification) { + err := ffm.KillInstance(info.InstanceID) + if err != nil { + log.GetLogger().Warnf("failed to kill instanceID: %s, err: %s", info.InstanceID, err.Error()) + } +} + +func (ffm *FrontendManager) recreateInstance() { + for { + select { + case <-ffm.stopCh: + return + case instanceID, ok := <-ffm.recreateInstanceIDCh: + if !ok { + log.GetLogger().Warnf("recreateInstanceIDCh is closed") + return + } + ffm.RLock() + if _, exist := ffm.instanceCache[instanceID]; exist || len(ffm.instanceCache) >= ffm.count { + log.GetLogger().Infof("current frontend num is %d, no need to recreate instance:%s", + len(ffm.instanceCache), instanceID) + ffm.RUnlock() + break + } + ffm.RUnlock() + ctx, cancel := context.WithCancel(context.Background()) + _, loaded := ffm.recreateInstanceIDMap.LoadOrStore(instanceID, cancel) + if loaded { + log.GetLogger().Warnf("instance[%s] is recreating", instanceID) + break + } + args, extraParams, err := genFunctionConfig() + if err != nil { + log.GetLogger().Errorf("failed to prepare createExtraParams, err:%s", err.Error()) + break + } + extraParams.DesignatedInstanceID = instanceID + go func() { + time.Sleep(constant.RecreateSleepTime) + log.GetLogger().Infof("start to recover faaSFrontend instance: %s", instanceID) + if err = ffm.createOrRetry(ctx, args, *extraParams, + config.GetFaaSControllerConfig().EnableRetry); err != nil { + log.GetLogger().Errorf("failed to recreate instance: %s", instanceID) + } + }() + } + } +} + +// SyncCreateInstance - +func (ffm *FrontendManager) SyncCreateInstance(ctx context.Context) error { + log.GetLogger().Infof("start to sync create frontend instance") + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + args, extraParams, err := genFunctionConfig() + if err != nil { + return err + } + err = ffm.createOrRetry(ctx, args, *extraParams, config.GetFaaSControllerConfig().EnableRetry) + if err != nil { + return err + } + return nil +} + +// HandleInstanceUpdate handle the etcd PUT event +func (ffm *FrontendManager) HandleInstanceUpdate(instanceSpec *types.InstanceSpecification) { + log.GetLogger().Infof("handling frontend instance %s update", instanceSpec.InstanceID) + if instanceSpec.InstanceSpecificationMeta.InstanceStatus.Code == int(commonconstant.KernelInstanceStatusExiting) { + log.GetLogger().Infof("frontend instance %s is exiting,no need to update", instanceSpec.InstanceID) + return + } + config.FrontendConfigLock.RLock() + signature := controllerutils.GetFrontendConfigSignature(config.GetFaaSFrontendConfig()) + config.FrontendConfigLock.RUnlock() + + if isExceptInstance(&instanceSpec.InstanceSpecificationMeta, signature) { + ffm.Lock() + currentNum := len(ffm.instanceCache) + _, exist := ffm.instanceCache[instanceSpec.InstanceID] + if currentNum > ffm.count || (currentNum == ffm.count && !exist) { + log.GetLogger().Infof("current frontend num is %s, kill the new instance %s", + currentNum, instanceSpec.InstanceID) + delete(ffm.instanceCache, instanceSpec.InstanceID) + ffm.Unlock() + if err := ffm.sdkClient.Kill(instanceSpec.InstanceID, types.KillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to kill instance %s error:%s", instanceSpec.InstanceID, + err.Error()) + } + return + } + // add instance to cache if not exist, otherwise update the instance + if !exist { + ffm.instanceCache[instanceSpec.InstanceID] = instanceSpec + log.GetLogger().Infof("add frontend instance %s to cache", instanceSpec.InstanceID) + ffm.Unlock() + state.Update(instanceSpec.InstanceID, types.StateUpdate, types.FaasFrontendInstanceState) + return + } + ffm.instanceCache[instanceSpec.InstanceID].InstanceSpecificationMeta = instanceSpec.InstanceSpecificationMeta + log.GetLogger().Infof("frontend instance %s is updated, refresh instance cache", instanceSpec.InstanceID) + ffm.Unlock() + return + } + ffm.RLock() + _, exist := ffm.terminalCache[instanceSpec.InstanceID] + ffm.RUnlock() + if !exist { + log.GetLogger().Infof("frontend instance %s is not expected, start to delete", instanceSpec.InstanceID) + if err := ffm.KillInstance(instanceSpec.InstanceID); err != nil { + log.GetLogger().Errorf("failed to kill instance %s error:%s", instanceSpec.InstanceID, err.Error()) + } + } +} + +// HandleInstanceDelete handle the etcd DELETE event +func (ffm *FrontendManager) HandleInstanceDelete(instanceSpec *types.InstanceSpecification) { + log.GetLogger().Infof("handling frontend instance %s delete", instanceSpec.InstanceID) + config.FrontendConfigLock.RLock() + signature := controllerutils.GetFrontendConfigSignature(config.GetFaaSFrontendConfig()) + config.FrontendConfigLock.RUnlock() + ffm.Lock() + delete(ffm.instanceCache, instanceSpec.InstanceID) + ffm.Unlock() + state.Update(instanceSpec.InstanceID, types.StateDelete, types.FaasFrontendInstanceState) + if isExceptInstance(&instanceSpec.InstanceSpecificationMeta, signature) { + ffm.RLock() + if len(ffm.instanceCache) < ffm.count { + log.GetLogger().Infof("current faaSFrontend instance num is %d, need to recreate instance: %s", + len(ffm.instanceCache), instanceSpec.InstanceID) + ffm.RUnlock() + ffm.recreateInstanceIDCh <- instanceSpec.InstanceID + return + } + ffm.RUnlock() + } + cancel, exist := ffm.recreateInstanceIDMap.Load(instanceSpec.InstanceID) + if exist { + if cancelFunc, ok := cancel.(context.CancelFunc); ok { + cancelFunc() + log.GetLogger().Infof("instance %s bring up has been canceled", instanceSpec.InstanceID) + return + } + log.GetLogger().Errorf("get cancel func failed from instanceIDMap, instanceID:%s", + instanceSpec.InstanceID) + } +} + +// prepareCreateOptions for create faasfrontend createOpt +func prepareCreateOptions(conf *types.FrontendConfig) (map[string]string, error) { + podLabels := map[string]string{ + constant.SystemFuncName: constant.FuncNameFaasfrontend, + } + labels, err := json.Marshal(podLabels) + if err != nil { + return nil, fmt.Errorf("pod labels json marshal failed, err:%s", err.Error()) + } + + delegateRuntime, err := json.Marshal(map[string]interface{}{ + "image": conf.Image, + }) + if err != nil { + return nil, err + } + + createOptions := map[string]string{ + commonconstant.DelegatePodLabels: string(labels), + commonconstant.DelegateRuntimeManagerTag: string(delegateRuntime), + commonconstant.InstanceLifeCycle: commonconstant.InstanceLifeCycleDetached, + commonconstant.DelegateNodeAffinity: conf.NodeAffinity, + commonconstant.DelegateNodeAffinityPolicy: conf.NodeAffinityPolicy, + } + + if config.GetFaaSControllerConfig().RawStsConfig.StsEnable { + secretVolumeMounts, err := sts.GenerateSecretVolumeMounts(sts.FaasfrontendName, utils.NewVolumeBuilder()) + if err != nil { + return nil, fmt.Errorf("secretVolumeMounts json marshal failed, err:%s", err.Error()) + } + createOptions[commonconstant.DelegateVolumeMountKey] = string(secretVolumeMounts) + } + + if conf.Affinity != "" { + createOptions[commonconstant.DelegateAffinity] = conf.Affinity + } + + if conf.NodeSelector != nil && len(conf.NodeSelector) != 0 { + var podTolerations []v1.Toleration + for k, v := range config.GetFaaSFrontendConfig().NodeSelector { + podTolerations = append(podTolerations, v1.Toleration{ + Key: k, + Operator: v1.TolerationOpEqual, + Value: v, + Effect: v1.TaintEffectNoSchedule, + }) + } + podTolerations = append(podTolerations, v1.Toleration{ + Key: dsWorkerUnreadyKey, + Operator: v1.TolerationOpEqual, + Value: "true", + Effect: v1.TaintEffectPreferNoSchedule, + }) + tolerations, err := json.Marshal(podTolerations) + if err != nil { + return nil, fmt.Errorf("pod tolerations json marshal failed, err:%s", err.Error()) + } + createOptions[commonconstant.DelegateTolerations] = string(tolerations) + } + + err = prepareVolumesAndMounts(conf, createOptions) + if err != nil { + return nil, err + } + + return createOptions, nil +} + +func prepareVolumesAndMounts(conf *types.FrontendConfig, createOptions map[string]string) error { + if createOptions == nil { + return fmt.Errorf("createOptions is nil") + } + var delegateVolumes string + var delegateVolumesMounts string + var delegateInitVolumesMounts string + var err error + builder := utils.NewVolumeBuilder() + if conf.RawStsConfig.StsEnable { + delegateVolumesMountsData, err := sts.GenerateSecretVolumeMounts(sts.FaasfrontendName, builder) + if err != nil { + return fmt.Errorf("secretVolumeMounts json marshal failed, err:%s", err.Error()) + } + delegateVolumesMounts = string(delegateVolumesMountsData) + } + if conf.HTTPSConfig != nil && conf.HTTPSConfig.HTTPSEnable { + delegateVolumes, delegateVolumesMounts, err = sts.GenerateHTTPSAndLocalSecretVolumeMounts(*conf.HTTPSConfig, builder) + if err != nil { + log.GetLogger().Errorf("failed to generate https volumes and mounts") + return fmt.Errorf("httpsVolumeMounts json marshal failed, err:%s", err.Error()) + } + } + + delegateInitVolumesMounts = delegateVolumesMounts + + if delegateVolumes != "" { + createOptions[commonconstant.DelegateVolumesKey] = delegateVolumes + } + if delegateVolumesMounts != "" { + createOptions[commonconstant.DelegateVolumeMountKey] = delegateVolumesMounts + } + if delegateInitVolumesMounts != "" { + createOptions[commonconstant.DelegateInitVolumeMountKey] = delegateInitVolumesMounts + } + log.GetLogger().Debugf("delegateVolumes: %s, delegateVolumesMounts: %s, delegateInitVolumesMounts: %s", + delegateVolumes, delegateVolumesMounts, delegateInitVolumesMounts) + return nil +} + +// RollingUpdate rolling update frontend +func (ffm *FrontendManager) RollingUpdate(ctx context.Context, event *types.ConfigChangeEvent) { + // 1. 更新 预期实例个数 + // 2. 把不符合预期的instanceCache -> terminalCache + // 3. 从terminalCache随机删除一个实例 同步 + // 4. 创建新实例 同步 + // 5. terminalCache为空时将实例数调谐为预期实例数(同步), instanceCache到达预期数量时清空terminalCache(异步) + newSign := controllerutils.GetFrontendConfigSignature(event.FrontendCfg) + ffm.Lock() + ffm.count = event.FrontendCfg.InstanceNum + for _, ins := range ffm.instanceCache { + if !isExceptInstance(&ins.InstanceSpecificationMeta, newSign) { + ffm.terminalCache[ins.InstanceID] = ins + delete(ffm.instanceCache, ins.InstanceID) + } + } + ffm.Unlock() + for { + select { + case <-ctx.Done(): + event.Error = fmt.Errorf("rolling update has stopped") + event.Done() + return + default: + } + ffm.RLock() + if len(ffm.instanceCache) == ffm.count { + ffm.RUnlock() + log.GetLogger().Infof("frontend instance count arrive at expectation:%d,"+ + " delete all terminating instance", ffm.count) + go ffm.killAllTerminalInstance() + event.Done() + return + } + if len(ffm.terminalCache) == 0 { + ffm.RUnlock() + log.GetLogger().Infof("no frontend instance need to terminate, pull up missing instance count:%d", + ffm.count-len(ffm.instanceCache)) + err := ffm.CreateExpectedInstanceCount(ctx) + if err != nil { + event.Error = err + } + event.Done() + return + } + insID := "" + for _, ins := range ffm.terminalCache { + insID = ins.InstanceID + break + } + ffm.RUnlock() + log.GetLogger().Infof("start to terminate instance:%s", insID) + var err error + if err = ffm.sdkClient.Kill(insID, types.SyncKillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to kill frontend instance(id=%s), err:%v", insID, err) + } + ffm.Lock() + delete(ffm.terminalCache, insID) + ffm.Unlock() + time.Sleep(constant.RecreateSleepTime) + err = ffm.SyncCreateInstance(ctx) + if err != nil { + event.Error = err + event.Done() + return + } + } +} + +func (ffm *FrontendManager) killAllTerminalInstance() { + ffm.Lock() + insMap := ffm.terminalCache + ffm.terminalCache = map[string]*types.InstanceSpecification{} + ffm.Unlock() + for _, ins := range insMap { + insID := ins.InstanceID + go func() { + err := ffm.KillInstance(insID) + if err != nil { + log.GetLogger().Errorf("Failed to kill instance %v, err: %v", insID, err) + } + }() + } +} + +func (ffm *FrontendManager) configChangeProcessor(ctx context.Context, cancel context.CancelFunc) { + if ctx == nil || cancel == nil || ffm.ConfigChangeCh == nil { + return + } + for { + select { + case cfgEvent, ok := <-ffm.ConfigChangeCh: + if !ok { + cancel() + return + } + frontendConfig, frontendInsNum := ffm.ConfigDiff(cfgEvent) + if frontendConfig != nil || frontendInsNum != -1 { + log.GetLogger().Infof("frontend config or instance num is changed," + + " need to update frontend instance") + cancel() + ctx, cancel = context.WithCancel(context.Background()) + config.UpdateFrontendConfig(cfgEvent.FrontendCfg) + cfgEvent.Add(1) + go ffm.RollingUpdate(ctx, cfgEvent) + } else { + log.GetLogger().Infof("frontend config is same as current, no need to update") + } + cfgEvent.Done() + } + } +} + +// ConfigDiff config diff +func (ffm *FrontendManager) ConfigDiff(event *types.ConfigChangeEvent) (*types.FrontendConfig, int) { + newSign := controllerutils.GetFrontendConfigSignature(event.FrontendCfg) + config.FrontendConfigLock.RLock() + frontendOldCfg := config.GetFaaSFrontendConfig() + config.FrontendConfigLock.RUnlock() + if strings.Compare(newSign, + controllerutils.GetFrontendConfigSignature(frontendOldCfg)) == 0 { + if event.FrontendCfg.InstanceNum != frontendOldCfg.InstanceNum { + config.FrontendConfigLock.Lock() + frontendOldCfg.InstanceNum = event.FrontendCfg.InstanceNum + config.FrontendConfigLock.Unlock() + return nil, event.FrontendCfg.InstanceNum + } + return nil, -1 + } + return event.FrontendCfg, event.FrontendCfg.InstanceNum +} diff --git a/yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager/frontendmanager_test.go b/yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager/frontendmanager_test.go new file mode 100644 index 0000000..2b888da --- /dev/null +++ b/yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager/frontendmanager_test.go @@ -0,0 +1,917 @@ +package faasfrontendmanager + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "k8s.io/apimachinery/pkg/util/wait" + + commonTypes "yuanrong/pkg/common/faas_common/types" + + . "github.com/agiledragon/gomonkey/v2" + . "github.com/smartystreets/goconvey/convey" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/alarm" + commonconstant "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/redisclient" + "yuanrong/pkg/common/faas_common/sts" + "yuanrong/pkg/common/faas_common/sts/raw" + "yuanrong/pkg/common/faas_common/tls" + "yuanrong/pkg/common/faas_common/utils" + mockUtils "yuanrong/pkg/common/faas_common/utils" + faasfrontendconf "yuanrong/pkg/frontend/types" + ftypes "yuanrong/pkg/frontend/types" + fcConfig "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/types" + controllerutils "yuanrong/pkg/system_function_controller/utils" +) + +type KvMock struct { +} + +func (k *KvMock) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + response := &clientv3.GetResponse{} + response.Count = 10 + return response, nil +} + +func (k *KvMock) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Txn(ctx context.Context) clientv3.Txn { + // TODO implement me + panic("implement me") +} + +func initConfig(configString string) { + fcConfig.InitConfig([]byte(configString)) + routerEtcdConfig := etcd3.EtcdConfig{ + Servers: []string{"1.2.3.4:1234"}, + User: "tom", + Password: "**", + } + + metaEtcdConfig := etcd3.EtcdConfig{ + Servers: []string{"1.2.3.4:5678"}, + User: "tom", + Password: "**", + } + frontendConfig := types.FrontendConfig{ + Config: ftypes.Config{ + InstanceNum: 10, + CPU: 777, + Memory: 777, + SLAQuota: 1000, + AuthenticationEnable: false, + HTTPConfig: &ftypes.FrontendHTTP{ + MaxRequestBodySize: 6, + }, + RouterEtcd: routerEtcdConfig, + MetaEtcd: metaEtcdConfig, + NodeSelector: map[string]string{"testkey": "testvalue"}, + }, + } + fcConfig.UpdateFrontendConfig(&frontendConfig) +} + +func newFaaSFrontendManager(size int) (*FrontendManager, error) { + stopCh := make(chan struct{}) + initConfig(`{ + "routerEtcd": { + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd": { + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + }, + "alarmConfig":{"enableAlarm": true} + } + `) + + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &etcd3.EtcdClient{Client: client} + manager := NewFaaSFrontendManager(&mockUtils.FakeLibruntimeSdkClient{}, etcdClient, stopCh, size, false) + time.Sleep(50 * time.Millisecond) + manager.count = size + manager.instanceCache = make(map[string]*types.InstanceSpecification) + manager.terminalCache = map[string]*types.InstanceSpecification{} + return manager, nil +} + +func newFaaSFrontendManagerWithRetry(size int) (*FrontendManager, error) { + stopCh := make(chan struct{}) + initConfig(`{ + "routerEtcd": { + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd": { + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + }, + "alarmConfig":{"enableAlarm": true}, + "enableRetry": true + } + `) + + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &etcd3.EtcdClient{Client: client} + manager := NewFaaSFrontendManager(&mockUtils.FakeLibruntimeSdkClient{}, etcdClient, stopCh, size, false) + time.Sleep(50 * time.Millisecond) + manager.count = size + manager.instanceCache = make(map[string]*types.InstanceSpecification) + manager.terminalCache = map[string]*types.InstanceSpecification{} + return manager, nil +} + +func Test_initInstanceCache(t *testing.T) { + Convey("test initInstanceCache", t, func() { + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &etcd3.EtcdClient{Client: client} + frontendManager = &FrontendManager{ + instanceCache: make(map[string]*types.InstanceSpecification), + terminalCache: map[string]*types.InstanceSpecification{}, + etcdClient: etcdClient, + sdkClient: &mockUtils.FakeLibruntimeSdkClient{}, + count: 1, + stopCh: make(chan struct{}), + recreateInstanceIDCh: make(chan string, 1), + ConfigChangeCh: make(chan *types.ConfigChangeEvent, 100), + } + defer ApplyMethod(reflect.TypeOf(kv), "Get", + func(k *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + bytes, _ := json.Marshal(fcConfig.GetFaaSFrontendConfig()) + meta := types.InstanceSpecificationMeta{Function: "test-function", InstanceID: "123", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}} + marshal, _ := json.Marshal(meta) + kvs := []*mvccpb.KeyValue{{Value: marshal}} + response := &clientv3.GetResponse{Kvs: kvs} + response.Count = 1 + return response, nil + }).Reset() + frontendManager.initInstanceCache(etcdClient) + cache := frontendManager.GetInstanceCache() + So(cache["123"], ShouldNotBeNil) + close(frontendManager.stopCh) + }) +} + +func TestNewInstanceManager(t *testing.T) { + Convey("Test NewInstanceManager", t, func() { + Convey("Test NewInstanceManager with correct size", func() { + got, err := newFaaSFrontendManager(16) + So(err, ShouldBeNil) + So(got, ShouldNotBeNil) + }) + }) +} + +func TestInstanceManager_CreateMultiInstances(t *testing.T) { + Convey("Test CreateMultiInstances", t, func() { + Convey("Test CreateMultiInstances with retry", func() { + p := ApplyFunc((*mockUtils.FakeLibruntimeSdkClient).CreateInstance, func(_ *mockUtils.FakeLibruntimeSdkClient, _ api.FunctionMeta, _ []api.Arg, _ api.InvokeOptions) (string, error) { + time.Sleep(100 * time.Millisecond) + return "", api.ErrorInfo{ + Code: 10, + Err: fmt.Errorf("xxxxx"), + StackTracesInfo: api.StackTracesInfo{}, + } + }) + p2 := ApplyFunc(wait.ExponentialBackoffWithContext, func(ctx context.Context, backoff wait.Backoff, condition wait.ConditionWithContextFunc) error { + _, err := condition(ctx) + return err + }) + manager, err := newFaaSFrontendManagerWithRetry(1) + So(err, ShouldBeNil) + ctx, cancel := context.WithCancel(context.TODO()) + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + err = manager.CreateMultiInstances(ctx, 1) + So(err, ShouldBeError) + + err = manager.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldBeNil) + p.Reset() + p2.Reset() + }) + + instanceMgr, err := newFaaSFrontendManager(3) + So(err, ShouldBeNil) + defer ApplyMethod(reflect.TypeOf(instanceMgr), "CreateWithRetry", + func(ffm *FrontendManager, ctx context.Context, args []api.Arg, extraParams *commonTypes.ExtraParams) error { + if value := ctx.Value("err"); value != nil { + if s := value.(string); s == "canceled" { + return fmt.Errorf("create has been cancelled") + } + } + return nil + }).Reset() + + Convey("Test CreateMultiInstances when passed different count", func() { + err = instanceMgr.CreateMultiInstances(context.TODO(), -1) + So(err, ShouldBeNil) + + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldBeNil) + }) + + Convey("Test CreateMultiInstances when failed to get scheduler config", func() { + patches := ApplyFunc(json.Marshal, func(_ interface{}) ([]byte, error) { + return nil, errors.New("json Marshal failed") + }) + defer patches.Reset() + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldNotBeNil) + }) + + Convey("Test CreateMultiInstances when failed to create instance", func() { + patches := ApplyMethod(reflect.TypeOf(instanceMgr), "CreateInstance", + func(_ *FrontendManager, ctx context.Context, _ string, _ []api.Arg, _ *commonTypes.ExtraParams) string { + return "" + }) + defer patches.Reset() + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldNotBeNil) + }) + + Convey("Test CreateMultiInstances when https enable", func() { + fcConfig.GetFaaSFrontendConfig().HTTPSConfig = &tls.InternalHTTPSConfig{ + HTTPSEnable: true, + TLSProtocol: "TLS", + TLSCiphers: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + } + defer func() { fcConfig.GetFaaSFrontendConfig().HTTPSConfig = nil }() + + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldBeNil) + }) + }) +} + +func TestInstanceManager_GetInstanceCountFromEtcd(t *testing.T) { + Convey("GetInstanceCountFromEtcd", t, func() { + Convey("failed", func() { + instanceMgr, err := newFaaSFrontendManager(3) + So(err, ShouldBeNil) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&KvMock{}), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return nil, errors.New("get etcd error") + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + count := instanceMgr.GetInstanceCountFromEtcd() + So(len(count), ShouldEqual, 0) + }) + }) + +} + +func TestInstanceManager_CreateExpectedInstanceCount(t *testing.T) { + Convey("CreateExpectedInstanceCount", t, func() { + Convey("success", func() { + instanceMgr, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&KvMock{}), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + response := &clientv3.GetResponse{} + response.Count = 2 + response.Kvs = []*mvccpb.KeyValue{ + { + Key: []byte("/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasfrontend/version/$latest/defaultaz/task-66ccf050-50f6-4835-aa24-c1b15dddb00e/12996c08-0000-4000-8000-db6c3db0fcbf"), + }, + { + Key: []byte("/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasfrontend/version/$latest/defaultaz/task-66ccf050-50f6-4835-aa24-c1b15dddb00e/12996c08-0000-4000-8000-db6c3db0fcb2"), + }, + } + + return response, nil + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + count := instanceMgr.GetInstanceCountFromEtcd() + So(len(count), ShouldEqual, 2) + + err = instanceMgr.CreateExpectedInstanceCount(context.TODO()) + So(err, ShouldBeNil) + }) + }) +} + +func TestInstanceManager_RecoverInstance(t *testing.T) { + Convey("RecoverInstance", t, func() { + instanceMgr, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + Convey("create failed", func() { + patches := []*Patches{ + ApplyFunc((*FrontendManager).KillInstance, func(_ *FrontendManager, _ string) error { + return nil + }), + ApplyFunc((*FrontendManager).CreateMultiInstances, func(_ *FrontendManager, ctx context.Context, _ int) error { + return errors.New("failed to create instances") + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + + instanceMgr.RecoverInstance(&types.InstanceSpecification{}) + }) + }) +} + +func TestKillExceptInstance(t *testing.T) { + Convey("KillExceptInstance", t, func() { + instanceMgr, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + Convey("success", func() { + instanceMgr.instanceCache["instanceID1"] = &types.InstanceSpecification{} + instanceMgr.instanceCache["instanceID2"] = &types.InstanceSpecification{} + err = instanceMgr.KillExceptInstance(0) + So(err, ShouldBeNil) + err = instanceMgr.KillExceptInstance(1) + So(err, ShouldBeNil) + defer ApplyMethod(reflect.TypeOf(instanceMgr), "KillInstance", func(ffm *FrontendManager, instanceID string) error { + return fmt.Errorf("err") + }).Reset() + err = instanceMgr.KillExceptInstance(1) + So(err, ShouldBeError) + }) + }) +} + +func TestInstanceManager_HandleInstanceUpdate(t *testing.T) { + Convey("HandleInstanceUpdate", t, func() { + Convey("no need to update", func() { + instanceMgr, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + instanceMgr.HandleInstanceUpdate(&types.InstanceSpecification{ + InstanceID: "456", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", InstanceStatus: types.InstanceStatus{ + Code: int(commonconstant.KernelInstanceStatusExiting), + }}}) + So(instanceMgr.instanceCache["456"], ShouldEqual, nil) + }) + + Convey("not exist", func() { + instanceMgr, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + + instanceMgr.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + + instanceMgr.HandleInstanceUpdate(&types.InstanceSpecification{ + InstanceID: "456", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function"}}) + + So(instanceMgr.instanceCache["456"], ShouldEqual, nil) + }) + + Convey("exist", func() { + instanceMgr, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + + instanceMgr.instanceCache = make(map[string]*types.InstanceSpecification) + frontendConfig := fcConfig.GetFaaSFrontendConfig() + bytes, _ := json.Marshal(frontendConfig) + instanceMgr.HandleInstanceUpdate(&types.InstanceSpecification{ + InstanceID: "123", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}}, + ) + + So(instanceMgr.instanceCache["123"].InstanceSpecificationMeta.Function, ShouldEqual, "test-function") + + instanceMgr.HandleInstanceUpdate(&types.InstanceSpecification{ + InstanceID: "123", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", RequestID: "test-runtimeID", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}}, + ) + So(instanceMgr.instanceCache["123"].InstanceSpecificationMeta.RequestID, ShouldEqual, "test-runtimeID") + }) + + Convey("delete extra instance", func() { + instanceMgr, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + instanceMgr.instanceCache = make(map[string]*types.InstanceSpecification) + + defer ApplyMethod(reflect.TypeOf(instanceMgr.etcdClient.Client.KV), "Get", + func(k *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + bytes, _ := json.Marshal(fcConfig.GetFaaSFrontendConfig()) + meta := types.InstanceSpecificationMeta{Function: "test-function", InstanceID: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}} + marshal, _ := json.Marshal(meta) + kvs := []*mvccpb.KeyValue{{Value: marshal}} + response := &clientv3.GetResponse{Kvs: kvs} + response.Count = 1 + return response, nil + }).Reset() + instanceMgr.initInstanceCache(instanceMgr.etcdClient) + So(instanceMgr.instanceCache["test-function"], ShouldNotBeNil) + frontendConfig := fcConfig.GetFaaSFrontendConfig() + bytes, _ := json.Marshal(frontendConfig) + instanceMgr.HandleInstanceUpdate(&types.InstanceSpecification{ + InstanceID: "123", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}}, + ) + So(instanceMgr.instanceCache["123"], ShouldBeNil) + }) + }) +} + +func TestInstanceManager_HandleInstanceDelete(t *testing.T) { + Convey("HandleInstanceDelete", t, func() { + Convey("not exist", func() { + instanceMgr, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + + instanceMgr.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + instanceMgr.recreateInstanceIDMap.Store("456", "error type") + instanceMgr.HandleInstanceDelete(&types.InstanceSpecification{ + InstanceID: "456", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function"}}) + _, exist := instanceMgr.instanceCache["123"] + So(exist, ShouldEqual, true) + }) + Convey("not exist2", func() { + instanceMgr, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + + instanceMgr.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + _, cancel := context.WithCancel(context.Background()) + instanceMgr.recreateInstanceIDMap.Store("456", cancel) + instanceMgr.HandleInstanceDelete(&types.InstanceSpecification{ + InstanceID: "456", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function"}}) + _, exist := instanceMgr.instanceCache["123"] + So(exist, ShouldEqual, true) + }) + + Convey("exist", func() { + size := 1 + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + NewFaaSFrontendManager(&mockUtils.FakeLibruntimeSdkClient{}, &etcd3.EtcdClient{Client: client}, + make(chan struct{}), size, false) + manager := GetFrontendManager() + frontendConfig := fcConfig.GetFaaSFrontendConfig() + bytes, _ := json.Marshal(frontendConfig) + manager.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + manager.HandleInstanceDelete(&types.InstanceSpecification{ + InstanceID: "123", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}}) + _, exist := manager.instanceCache["123"] + So(exist, ShouldEqual, false) + }) + }) +} + +func Test_recreateInstance(t *testing.T) { + Convey("", t, func() { + instanceMgr, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + defer ApplyFunc(json.Marshal, func(v any) ([]byte, error) { + return []byte{}, nil + }).Reset() + defer ApplyFunc(fcConfig.GetFaaSControllerConfig, func() *types.Config { + return &types.Config{ + RawStsConfig: raw.StsConfig{StsEnable: true}, + } + }).Reset() + defer ApplyFunc(sts.GenerateSecretVolumeMounts, func(systemFunctionName string) ([]byte, error) { + return []byte{}, nil + }).Reset() + instanceMgr.recreateInstanceIDCh <- "instanceID1" + go instanceMgr.recreateInstance() + time.Sleep(11 * time.Second) + So(len(instanceMgr.instanceCache) > 0, ShouldBeTrue) + + instanceMgr.recreateInstanceIDCh <- "instanceID1" + So(len(instanceMgr.instanceCache), ShouldEqual, 1) + close(instanceMgr.recreateInstanceIDCh) + }) +} + +func TestSyncCreateInstance(t *testing.T) { + Convey("test SyncCreateInstance", t, func() { + manager, err := newFaaSFrontendManager(1) + So(err, ShouldBeNil) + err = manager.SyncCreateInstance(context.TODO()) + So(err, ShouldBeNil) + + Convey("context canceled", func() { + ctx, cancelFunc := context.WithCancel(context.TODO()) + cancelFunc() + err = manager.SyncCreateInstance(ctx) + So(err, ShouldNotBeNil) + }) + }) +} + +func TestRollingUpdate(t *testing.T) { + Convey("test RollingUpdate", t, func() { + manager, err := newFaaSFrontendManager(2) + So(err, ShouldBeNil) + Convey("same config", func() { + frontendConfig := fcConfig.GetFaaSFrontendConfig() + bytes, _ := json.Marshal(frontendConfig) + manager.instanceCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + manager.instanceCache["test-2"] = &types.InstanceSpecification{ + InstanceID: "test-2", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + cfgEvent := &types.ConfigChangeEvent{ + FrontendCfg: fcConfig.GetFaaSFrontendConfig(), + TraceID: "traceID-123456789", + } + cfgEvent.FrontendCfg.InstanceNum = 2 + cfgEvent.Add(1) + manager.ConfigChangeCh <- cfgEvent + cfgEvent.Wait() + So(len(manager.instanceCache), ShouldEqual, cfgEvent.FrontendCfg.InstanceNum) + So(manager.instanceCache["test-1"], ShouldNotBeNil) + So(manager.instanceCache["test-2"], ShouldNotBeNil) + }) + + Convey("different config", func() { + frontendConfig := fcConfig.GetFaaSFrontendConfig() + bytes, _ := json.Marshal(frontendConfig) + manager.instanceCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + manager.instanceCache["test-2"] = &types.InstanceSpecification{ + InstanceID: "test-2", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + cfg := &types.FrontendConfig{} + utils.DeepCopyObj(fcConfig.GetFaaSFrontendConfig(), cfg) + cfgEvent := &types.ConfigChangeEvent{ + FrontendCfg: cfg, + TraceID: "traceID-123456789", + } + cfgEvent.FrontendCfg.CPU = 888 + cfgEvent.FrontendCfg.InstanceNum = 10 + cfgEvent.Add(1) + manager.ConfigChangeCh <- cfgEvent + cfgEvent.Wait() + close(manager.ConfigChangeCh) + So(len(manager.terminalCache), ShouldEqual, 0) + So(len(manager.instanceCache), ShouldEqual, cfgEvent.FrontendCfg.InstanceNum) + }) + + Convey("killAllTerminalInstance", func() { + frontendConfig := fcConfig.GetFaaSFrontendConfig() + bytes, _ := json.Marshal(frontendConfig) + manager.instanceCache = make(map[string]*types.InstanceSpecification) + manager.instanceCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + manager.instanceCache["test-2"] = &types.InstanceSpecification{ + InstanceID: "test-2", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + cfgEvent := &types.ConfigChangeEvent{ + FrontendCfg: fcConfig.GetFaaSFrontendConfig(), + TraceID: "traceID-123456789", + } + cfgEvent.FrontendCfg.InstanceNum = 2 + cfgEvent.Add(1) + manager.RollingUpdate(context.TODO(), cfgEvent) + cfgEvent.Wait() + So(len(manager.instanceCache), ShouldEqual, 2) + }) + }) +} + +func Test_killAllTerminalInstance(t *testing.T) { + Convey("killAllTerminalInstance", t, func() { + manager, err := newFaaSFrontendManager(2) + So(err, ShouldBeNil) + frontendConfig := fcConfig.GetFaaSFrontendConfig() + bytes, _ := json.Marshal(frontendConfig) + manager.terminalCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + + defer ApplyMethod(reflect.TypeOf(manager), "KillInstance", func(ffm *FrontendManager, instanceID string) error { + return nil + }).Reset() + manager.killAllTerminalInstance() + So(len(manager.terminalCache), ShouldEqual, 0) + }) +} + +func TestConfigDiff(t *testing.T) { + Convey("ConfigDiff", t, func() { + manager, err := newFaaSFrontendManager(2) + So(err, ShouldBeNil) + Convey("same config ,different num", func() { + cfg := &types.FrontendConfig{} + utils.DeepCopyObj(fcConfig.GetFaaSFrontendConfig(), cfg) + cfgEvent := &types.ConfigChangeEvent{ + FrontendCfg: cfg, + TraceID: "traceID-123456789", + } + cfgEvent.FrontendCfg.InstanceNum = 100 + _, num := manager.ConfigDiff(cfgEvent) + So(num, ShouldEqual, 100) + }) + }) +} + +func TestFrontendManager_KillInstance(t *testing.T) { + Convey("kill instance test", t, func() { + Convey("baseline", func() { + manager, err := newFaaSFrontendManager(2) + So(err, ShouldBeNil) + i := 0 + p := ApplyFunc((*mockUtils.FakeLibruntimeSdkClient).Kill, func(_ *mockUtils.FakeLibruntimeSdkClient, instanceID string, _ int, _ []byte) error { + if i == 0 { + i = 1 + return fmt.Errorf("error") + } + return nil + }) + defer p.Reset() + err = manager.KillInstance("aaa") + So(err, ShouldBeNil) + }) + }) +} + +func TestFrontendManager_SyncKillAllInstance(t *testing.T) { + Convey("kill instance test", t, func() { + Convey("baseline", func() { + manager, err := newFaaSFrontendManager(2) + So(err, ShouldBeNil) + p := ApplyFunc((*mockUtils.FakeLibruntimeSdkClient).Kill, func(_ *mockUtils.FakeLibruntimeSdkClient, instanceID string, _ int, _ []byte) error { + return nil + }) + defer p.Reset() + manager.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + manager.SyncKillAllInstance() + So(len(manager.instanceCache), ShouldBeZeroValue) + + }) + }) +} + +func Test_prepareCreateOptions_for_InstanceLifeCycle(t *testing.T) { + Convey("testInstanceLifeCycle", t, func() { + conf := &types.FrontendConfig{ + Config: ftypes.Config{ + Image: "", + NodeSelector: map[string]string{}, + }, + } + newFaaSFrontendManager(2) + fcConfig.GetFaaSFrontendConfig() + creatOpt, err := prepareCreateOptions(conf) + So(err, ShouldBeNil) + So(creatOpt[commonconstant.InstanceLifeCycle], ShouldEqual, commonconstant.InstanceLifeCycleDetached) + }) +} + +func Test_prepareCreateOptions_for_NodeAffinity(t *testing.T) { + Convey("testInstanceNodeAffinity", t, func() { + tt := []struct { + name string + nodeAffinity string + nodeAffinityPolicy string + }{ + { + name: "case1", + nodeAffinity: "", + nodeAffinityPolicy: "", + }, + { + name: "case2", + nodeAffinity: "{\"requiredDuringSchedulingIgnoredDuringExecution\":{\"nodeSelectorTerms\":[{\"matchExpressions\":[{\"key\":\"node-role\",\"operator\":\"In\",\"values\":[\"edge\"]}]}]}}", + nodeAffinityPolicy: commonconstant.DelegateNodeAffinityPolicyCoverage, + }, + { + name: "case3", + nodeAffinity: "{\"requiredDuringSchedulingIgnoredDuringExecution\":{\"nodeSelectorTerms\":[{\"matchExpressions\":[{\"key\":\"node-role\",\"operator\":\"In\",\"values\":[\"edge-tagw\"]}]}]}}", + nodeAffinityPolicy: commonconstant.DelegateNodeAffinityPolicyCoverage, + }, + } + conf := &types.FrontendConfig{ + Config: ftypes.Config{ + Image: "", + NodeSelector: map[string]string{}, + }, + } + for _, ttt := range tt { + newFaaSFrontendManager(2) + conf.NodeAffinityPolicy = ttt.nodeAffinityPolicy + conf.NodeAffinity = ttt.nodeAffinity + fcConfig.GetFaaSFrontendConfig() + creatOpt, err := prepareCreateOptions(conf) + So(err, ShouldBeNil) + So(creatOpt[commonconstant.InstanceLifeCycle], ShouldEqual, commonconstant.InstanceLifeCycleDetached) + So(creatOpt[commonconstant.DelegateNodeAffinityPolicy], ShouldEqual, ttt.nodeAffinityPolicy) + So(creatOpt[commonconstant.DelegateNodeAffinity], ShouldEqual, ttt.nodeAffinity) + } + }) +} + +func Test_createExtraParams(t *testing.T) { + Convey("test createExtraParams", t, func() { + Convey("baseline", func() { + aff := "{\"nodeAffinity\":{\"preferredDuringSchedulingIgnoredDuringExecution\":[{\"preference\":{\"matchExpressions\":[{\"key\":\"node-type\",\"operator\":\"In\",\"values\":[\"system\"]}]},\"weight\":1}]}}" + cfg := &types.FrontendConfig{ + Config: ftypes.Config{ + CPU: 100, + Memory: 1024, + Affinity: aff, + }, + } + p, err := createExtraParams(cfg) + So(err, ShouldBeNil) + So(p.CreateOpt[commonconstant.DelegateAffinity], ShouldEqual, aff) + }) + }) +} + +func TestIsExceptInstance(t *testing.T) { + Convey("Test isExceptInstance", t, func() { + conf := &faasfrontendconf.Config{ + InstanceNum: 1, + CPU: 500, + Memory: 500, + SLAQuota: 1000, + Runtime: faasfrontendconf.RuntimeConfig{}, + LocalAuth: nil, + MetaEtcd: etcd3.EtcdConfig{ + Servers: []string{"127.0.0.1:32379"}, + User: "", + Password: "", + SslEnable: false, + AuthType: "Noauth", + UseSecret: false, + SecretName: "etcd-client-secret", + LimitRate: 0, + LimitBurst: 0, + LimitTimeout: 0, + CaFile: "", + CertFile: "", + KeyFile: "", + PassphraseFile: "", + }, + RouterEtcd: etcd3.EtcdConfig{}, + RedisConfig: faasfrontendconf.RedisConfig{ + ClusterID: "", + ServerAddr: "", + ServerMode: "", + Password: "", + EnableTLS: false, + TimeoutConf: redisclient.TimeoutConf{}, + }, + HTTPConfig: nil, + HTTPSConfig: nil, + DataSystemConfig: nil, + BusinessType: 0, + SccConfig: crypto.SccConfig{}, + Image: "", + MemoryControlConfig: nil, + MemoryEvaluatorConfig: nil, + DefaultTenantLimitQuota: 0, + AuthenticationEnable: false, + RawStsConfig: raw.StsConfig{}, + TrafficLimitParams: nil, + NodeSelector: nil, + AzID: "", + ClusterID: "", + ClusterName: "", + AlarmConfig: alarm.Config{}, + Version: "", + AuthConfig: faasfrontendconf.AuthConfig{}, + FunctionNameSeparator: "", + AlarmServerAddress: "", + InvokeMaxRetryTimes: 0, + EtcdLeaseConfig: nil, + HeartbeatConfig: nil, + E2EMaxDelayTime: 0, + RetryConfig: nil, + ShareKeys: faasfrontendconf.ShareKeys{}, + Affinity: "", + } + confStr, err := json.Marshal(conf) + if err != nil { + return + } + + meta := &types.InstanceSpecificationMeta{ + Args: []map[string]string{}, + } + targetSign := "testSign" + if isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return false when args is empty") + } + meta.Args = []map[string]string{ + { + "value": "invalidBase64", + }, + } + if isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return false when base64 decoding fails") + } + + // 测试参数反序列化为前端配置失败的情况 + meta.Args[0]["value"] = base64.StdEncoding.EncodeToString([]byte("invalidJson")) + if isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return false when unmarshalling frontend config fails") + } + + meta.Args[0]["value"] = base64.StdEncoding.EncodeToString([]byte(confStr)) + patches := ApplyFunc(controllerutils.GetFrontendConfigSignature, func(frontendCfg *types.FrontendConfig) string { + return targetSign + }) + defer patches.Reset() + if !isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return true") + } + newStr := make([]byte, len(confStr)+16) + newStr = append([]byte("0000000000000000"), confStr[:]...) + meta.Args[0]["value"] = base64.StdEncoding.EncodeToString(newStr) + if !isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return true") + } + + newStr = append([]byte("0000000000000000"), newStr[:]...) + meta.Args[0]["value"] = base64.StdEncoding.EncodeToString(newStr) + if isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return false") + } + }) + +} diff --git a/yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager/faasfunctionmanager.go b/yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager/faasfunctionmanager.go new file mode 100644 index 0000000..43dc188 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager/faasfunctionmanager.go @@ -0,0 +1,763 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faasfunctionmanager manages faasfunction status and instance ID +package faasfunctionmanager + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" + + "go.etcd.io/etcd/client/v3" + "k8s.io/apimachinery/pkg/util/wait" + + "yuanrong.org/kernel/runtime/libruntime/api" + + commonconstant "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/uuid" + "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/constant" + "yuanrong/pkg/system_function_controller/types" + controllerutils "yuanrong/pkg/system_function_controller/utils" +) + +var ( + once sync.Once + // faasFunctionManager is the singaleton of Manager + functionManager *FunctionManager + + createInstanceBackoff = wait.Backoff{ + Steps: constant.DefaultCreateRetryTime, // retry times (include first time) + Duration: constant.DefaultCreateRetryDuration, + Factor: constant.DefaultCreateRetryFactor, + Jitter: constant.DefaultCreateRetryJitter, + } +) + +// FunctionManager manages faasfunction status and instance ID +type FunctionManager struct { + ConfigChangeCh chan *types.ConfigChangeEvent + instanceCache map[string]*types.InstanceSpecification + terminalCache map[string]*types.InstanceSpecification + etcdClient *etcd3.EtcdClient + sdkClient api.LibruntimeAPI // add sdkClientLibruntime to adaptor multi runtime + stopCh chan struct{} + recreateInstanceIDCh chan string + recreateInstanceIDMap sync.Map + sync.RWMutex + count int +} + +// NewFaaSFunctionManager supply a singleton function manager +func NewFaaSFunctionManager(libruntimeAPI api.LibruntimeAPI, + etcdClient *etcd3.EtcdClient, stopCh chan struct{}, size int) *FunctionManager { + once.Do(func() { + functionManager = &FunctionManager{ + instanceCache: make(map[string]*types.InstanceSpecification), + terminalCache: map[string]*types.InstanceSpecification{}, + etcdClient: etcdClient, + sdkClient: libruntimeAPI, + stopCh: stopCh, + count: size, + ConfigChangeCh: make(chan *types.ConfigChangeEvent, constant.DefaultChannelSize), + recreateInstanceIDCh: make(chan string, constant.DefaultChannelSize), + } + go functionManager.recreateInstance() + functionManager.initInstanceCache(etcdClient) + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + err := functionManager.CreateExpectedInstanceCount(ctx) + if err != nil { + log.GetLogger().Errorf("Failed to create expected faasManager instance count, error: %v", err) + } + }() + go functionManager.configChangeProcessor(ctx, cancelFunc) + }) + return functionManager +} + +func (fm *FunctionManager) initInstanceCache(etcdClient *etcd3.EtcdClient) { + response, err := etcdClient.Client.Get(context.Background(), types.FasSManagerPrefixKey, clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("Failed to get faasManager instance: %v", err) + return + } + config.ManagerConfigLock.RLock() + targetSign := controllerutils.GetManagerConfigSignature(config.GetFaaSManagerConfig()) + config.ManagerConfigLock.RUnlock() + for _, kv := range response.Kvs { + meta := &types.InstanceSpecificationMeta{} + err := json.Unmarshal(kv.Value, meta) + if err != nil { + log.GetLogger().Errorf("Failed to unmarshal instance specification: %v", err) + continue + } + if isExceptInstance(meta, targetSign) { + funcCtx, cancelFunc := context.WithCancel(context.TODO()) + fm.instanceCache[meta.InstanceID] = &types.InstanceSpecification{ + FuncCtx: funcCtx, + CancelFunc: cancelFunc, + InstanceID: meta.InstanceID, + InstanceSpecificationMeta: *meta, + } + log.GetLogger().Infof("find expected manager instance %s add to cache", meta.InstanceID) + } + } +} + +func isExceptInstance(meta *types.InstanceSpecificationMeta, targetSign string) bool { + if len(meta.Args) == 0 { + log.GetLogger().Errorf("manager ins args is empty, %v", meta) + return false + } + v := meta.Args[0]["value"] + s, err := base64.StdEncoding.DecodeString(v) + if err != nil { + log.GetLogger().Errorf("manager ins failed to decode args: %v", err) + return false + } + cfg := &types.ManagerConfig{} + err = json.Unmarshal(s, cfg) + if err != nil && len(s) > commonconstant.LibruntimeHeaderSize { + // args in libruntime create request with 16 bytes header. + // except libruntime, if other modules use this field, should try yo delete the header + err = json.Unmarshal(s[commonconstant.LibruntimeHeaderSize:], cfg) + if err != nil { + log.GetLogger().Errorf("Failed to unmarshal faasManager config: %v, value: %s", err, s) + return false + } + } + oldSign := controllerutils.GetManagerConfigSignature(cfg) + if oldSign == "" { + log.GetLogger().Errorf("old sign is empty, insID:%s", meta.InstanceID) + return false + } + log.GetLogger().Infof("manager(%s) sign: %s, expect sign: %s", meta.InstanceID, oldSign, targetSign) + return strings.Compare(oldSign, targetSign) == 0 +} + +// GetInstanceCountFromEtcd get current instance count from etcd +func (fm *FunctionManager) GetInstanceCountFromEtcd() map[string]struct{} { + resp, err := fm.etcdClient.Client.Get(context.TODO(), types.FasSManagerPrefixKey, clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("failed to search etcd key, prefixKey=%s, err=%s", types.FasSManagerPrefixKey, + err.Error()) + return nil + } + instanceIDs := make(map[string]struct{}, resp.Count) + for _, kv := range resp.Kvs { + instanceID := controllerutils.ExtractInfoFromEtcdKey(string(kv.Key), commonconstant.InstanceIDIndexForInstance) + if instanceID != "" { + instanceIDs[instanceID] = struct{}{} + } + } + log.GetLogger().Infof("get etcd function instance count=%d, %+v", resp.Count, instanceIDs) + return instanceIDs +} + +// CreateExpectedInstanceCount create expected faasManager instance count +func (fm *FunctionManager) CreateExpectedInstanceCount(ctx context.Context) error { + // 此方法在启动的时候调用一次或者在配置更新的时候调用, 目的是将manager实例数量补充至设置的实例数 + // 不需要删除多余实例, 多余实例会在 HandleInstanceUpdate 中删除 + // manager的 instanceID不需要保持一致 + fm.RLock() + currentCount := len(fm.instanceCache) + expectedCount := fm.count - currentCount + fm.RUnlock() + return fm.CreateMultiInstances(ctx, expectedCount) +} + +// CreateMultiInstances create multi instances +func (fm *FunctionManager) CreateMultiInstances(ctx context.Context, count int) error { + if count <= 0 { + log.GetLogger().Infof("no need to create manager instance, kill %d instances instead.", -count) + return fm.KillExceptInstance(-count) + } + log.GetLogger().Infof("need to create %d faas manager instances", count) + + functionArgs, params, err := genFunctionConfig() + if err != nil { + return err + } + group := &sync.WaitGroup{} + var createErr error + for i := 0; i < count; i++ { + group.Add(1) + go func() { + defer group.Done() + if err = fm.createOrRetry(ctx, functionArgs, *params, + config.GetFaaSControllerConfig().EnableRetry); err != nil { + createErr = err + } + }() + } + group.Wait() + if createErr != nil { + return createErr + } + log.GetLogger().Infof("succeed to create %d faas manager instances", count) + return nil +} + +func genFunctionConfig() ([]api.Arg, *commonTypes.ExtraParams, error) { + config.ManagerConfigLock.RLock() + managerConfig := config.GetFaaSManagerConfig() + extraParams, err := createExtraParams(managerConfig) + if err != nil { + config.ManagerConfigLock.RUnlock() + log.GetLogger().Errorf("failed to prepare faasManager createExtraParams, err:%s", err.Error()) + return nil, nil, err + } + managerConf, err := json.Marshal(managerConfig) + config.ManagerConfigLock.RUnlock() + if err != nil { + log.GetLogger().Errorf("faasManager config json marshal failed, err:%s", err.Error()) + return nil, nil, err + } + args := []api.Arg{ + { + Type: api.Value, + Data: managerConf, + }, + } + return args, extraParams, nil +} + +func createExtraParams(conf *types.ManagerConfig) (*commonTypes.ExtraParams, error) { + extraParams := &commonTypes.ExtraParams{} + extraParams.Resources = utils.GenerateResourcesMap(conf.CPU, conf.Memory) + extraParams.CustomExtensions = utils.CreateCustomExtensions(extraParams.CustomExtensions, + utils.MonopolyPolicyValue) + extraParams.ScheduleAffinities = utils.CreatePodAffinity(constant.SystemFuncName, constant.FuncNameFaasmanager, + api.PreferredAntiAffinity) + utils.AddNodeSelector(conf.NodeSelector, extraParams) + createOpts, err := prepareCreateOptions(conf) + extraParams.CreateOpt = createOpts + extraParams.Label = []string{constant.FuncNameFaasmanager} + return extraParams, err +} + +func (fm *FunctionManager) createOrRetry(ctx context.Context, args []api.Arg, + extraParams commonTypes.ExtraParams, enableRetry bool) error { + // 只有首次拉起/扩容时insID为空,需要在拉起失败时防止进入失败回调中的重试逻辑 + if extraParams.DesignatedInstanceID == "" { + instanceID := uuid.New().String() + fm.recreateInstanceIDMap.Store(instanceID, nil) + extraParams.DesignatedInstanceID = instanceID + } + defer fm.recreateInstanceIDMap.Delete(extraParams.DesignatedInstanceID) + if enableRetry { + err := fm.CreateWithRetry(ctx, args, extraParams) + if err != nil { + return err + } + } else { + instanceID := fm.CreateInstance(ctx, types.FasSManagerFunctionKey, args, &extraParams) + if instanceID == "" { + log.GetLogger().Errorf("failed to create function manager instance") + return errors.New("failed to create function manager instance") + } + } + return nil +} + +// CreateWithRetry - +func (fm *FunctionManager) CreateWithRetry(ctx context.Context, args []api.Arg, + extraParams commonTypes.ExtraParams) error { + err := wait.ExponentialBackoffWithContext(ctx, createInstanceBackoff, func(context.Context) (bool, error) { + instanceID := fm.CreateInstance(ctx, types.FasSManagerFunctionKey, args, &extraParams) + if instanceID == "" { + log.GetLogger().Errorf("failed to create funcManager instance") + return false, nil + } + if instanceID == "cancelled" { + return true, fmt.Errorf("create has been cancelled") + } + return true, nil + }) + return err +} + +// CreateInstance create an instance of system function, faaS function +func (fm *FunctionManager) CreateInstance(ctx context.Context, function string, args []api.Arg, + extraParams *commonTypes.ExtraParams) string { + instanceID := extraParams.DesignatedInstanceID + log.GetLogger().Infof("start to create funcManager instance(id=%s)", + extraParams.DesignatedInstanceID) + funcMeta := api.FunctionMeta{Name: &instanceID, FuncID: function, Api: api.PosixApi} + invokeOpts := api.InvokeOptions{ + Cpu: int(extraParams.Resources[controllerutils.ResourcesCPU]), + Memory: int(extraParams.Resources[controllerutils.ResourcesMemory]), + ScheduleAffinities: extraParams.ScheduleAffinities, + CustomExtensions: extraParams.CustomExtensions, + CreateOpt: extraParams.CreateOpt, + Labels: extraParams.Label, + Timeout: 150, + } + createCh := make(chan api.ErrorInfo, 1) + go func() { + _, createErr := fm.sdkClient.CreateInstance(funcMeta, args, invokeOpts) + if createErr != nil { + if errorInfo, ok := createErr.(api.ErrorInfo); ok { + createCh <- errorInfo + } else { + createCh <- api.ErrorInfo{Code: commonconstant.KernelInnerSystemErrCode, Err: createErr} + } + return + } + createCh <- api.ErrorInfo{Code: api.Ok} + }() + if result := fm.waitForInstanceCreation(ctx, createCh, instanceID); result != instanceID { + return result + } + log.GetLogger().Infof("succeed to create manager instance(id=%s)", instanceID) + fm.addInstance(instanceID) + return instanceID +} +func (fm *FunctionManager) waitForInstanceCreation(ctx context.Context, createCh <-chan api.ErrorInfo, + instanceID string) string { + timer := time.NewTimer(types.CreatedTimeout) + select { + case err, ok := <-createCh: + defer timer.Stop() + if !ok { + log.GetLogger().Errorf("result channel of manager instance request is closed") + return "" + } + if !err.IsOk() { + log.GetLogger().Errorf("failed to bring up manager instance(id=%s), err: %v", + instanceID, err.Error()) + fm.clearInstanceAfterError(instanceID) + return "" + } + return instanceID + case <-timer.C: + log.GetLogger().Errorf("time out waiting for instance creation") + fm.clearInstanceAfterError(instanceID) + return "" + case <-ctx.Done(): + log.GetLogger().Errorf("create instance has been cancelled") + fm.clearInstanceAfterError(instanceID) + return "cancelled" + } +} +func (fm *FunctionManager) recreateInstance() { + for { + select { + case <-fm.stopCh: + return + case instanceID, ok := <-fm.recreateInstanceIDCh: + if !ok { + log.GetLogger().Warnf("recreateInstanceIDCh is closed") + return + } + fm.RLock() + if _, exist := fm.instanceCache[instanceID]; exist || len(fm.instanceCache) >= fm.count { + log.GetLogger().Infof("current manager num is %s, no need to recreate instance:%s", + len(fm.instanceCache), instanceID) + fm.RUnlock() + break + } + fm.RUnlock() + ctx, cancel := context.WithCancel(context.Background()) + _, loaded := fm.recreateInstanceIDMap.LoadOrStore(instanceID, cancel) + if loaded { + log.GetLogger().Warnf("instance[%s] is recreating", instanceID) + break + } + args, extraParams, err := genFunctionConfig() + if err != nil { + break + } + extraParams.DesignatedInstanceID = instanceID + if err != nil { + log.GetLogger().Errorf("failed to prepare createExtraParams, err:%s", err.Error()) + break + } + go func() { + time.Sleep(constant.RecreateSleepTime) + log.GetLogger().Infof("start to recover faaSManager instance: %s", instanceID) + if err = fm.createOrRetry(ctx, args, *extraParams, + config.GetFaaSControllerConfig().EnableRetry); err != nil { + log.GetLogger().Errorf("failed to recreate manager instance: %s", instanceID) + } + }() + } + } +} + +func (fm *FunctionManager) clearInstanceAfterError(instanceID string) { + if err := fm.sdkClient.Kill(instanceID, types.KillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to kill manager instance: %s", instanceID) + } +} + +func (fm *FunctionManager) addInstance(instanceID string) { + fm.Lock() + defer fm.Unlock() + _, exist := fm.instanceCache[instanceID] + if exist { + log.GetLogger().Warnf("the manager instance(id=%s) already exist", instanceID) + return + } + log.GetLogger().Infof("add instance(id=%s) to local cache", instanceID) + fm.instanceCache[instanceID] = &types.InstanceSpecification{InstanceID: instanceID} +} + +// SyncCreateInstance - +func (fm *FunctionManager) SyncCreateInstance(ctx context.Context) error { + log.GetLogger().Infof("start to sync create funcManager instance") + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + args, extraParams, err := genFunctionConfig() + if err != nil { + return err + } + err = fm.createOrRetry(ctx, args, *extraParams, config.GetFaaSControllerConfig().EnableRetry) + if err != nil { + return err + } + return nil +} + +// GetInstanceCache supply local instance cache +func (fm *FunctionManager) GetInstanceCache() map[string]*types.InstanceSpecification { + return fm.instanceCache +} + +// KillInstance kill an instance of system function, faaS function +func (fm *FunctionManager) KillInstance(instanceID string) error { + log.GetLogger().Infof("start to kill funcManager instance %s", instanceID) + return wait.ExponentialBackoffWithContext( + context.Background(), createInstanceBackoff, func(context.Context) (bool, error) { + var err error + err = fm.sdkClient.Kill(instanceID, types.KillSignalVal, []byte{}) + if err != nil && !strings.Contains(err.Error(), "instance not found") { + log.GetLogger().Warnf("failed to kill funcManager instanceID: %s, err: %s", + instanceID, err.Error()) + return false, nil + } + return true, nil + }) +} + +// SyncKillAllInstance kill all instances of system function, faaS manager +func (fm *FunctionManager) SyncKillAllInstance() { + var wg sync.WaitGroup + var deletedInstance []string + fm.Lock() + defer fm.Unlock() + for instanceID := range fm.instanceCache { + wg.Add(1) + go func(instanceID string) { + defer wg.Done() + if err := fm.sdkClient.Kill(instanceID, types.SyncKillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to kill manager instance(id=%s), err:%s", instanceID, err.Error()) + return + } + deletedInstance = append(deletedInstance, instanceID) + log.GetLogger().Infof("success to kill manager instance(id=%s)", instanceID) + }(instanceID) + } + wg.Wait() + for _, instanceID := range deletedInstance { + log.GetLogger().Infof("delete manager instance(id=%s) from local cache", instanceID) + delete(fm.instanceCache, instanceID) + fm.count-- + } +} + +// KillExceptInstance - +func (fm *FunctionManager) KillExceptInstance(count int) error { + if len(fm.instanceCache) < count { + return nil + } + for instanceID := range fm.instanceCache { + if count <= 0 { + return nil + } + if err := fm.KillInstance(instanceID); err != nil { + log.GetLogger().Errorf("kill manager instance:%s, error:%s", instanceID, err.Error()) + return err + } + count-- + } + return nil +} + +// RecoverInstance recover a faaS manager instance when faults occur +func (fm *FunctionManager) RecoverInstance(info *types.InstanceSpecification) { + err := fm.KillInstance(info.InstanceID) + if err != nil { + log.GetLogger().Warnf("failed to kill instanceID: %s, err: %s", info.InstanceID, err.Error()) + } +} + +// HandleInstanceUpdate handle the etcd PUT event +func (fm *FunctionManager) HandleInstanceUpdate(instanceSpec *types.InstanceSpecification) { + log.GetLogger().Infof("handling funcManager instance %s update", instanceSpec.InstanceID) + if instanceSpec.InstanceSpecificationMeta.InstanceStatus.Code == int(commonconstant.KernelInstanceStatusExiting) { + log.GetLogger().Infof("funcManager instance %s is exiting,no need to update", instanceSpec.InstanceID) + return + } + config.ManagerConfigLock.RLock() + signature := controllerutils.GetManagerConfigSignature(config.GetFaaSManagerConfig()) + config.ManagerConfigLock.RUnlock() + + if isExceptInstance(&instanceSpec.InstanceSpecificationMeta, signature) { + fm.Lock() + currentNum := len(fm.instanceCache) + _, exist := fm.instanceCache[instanceSpec.InstanceID] + if currentNum > fm.count || (currentNum == fm.count && !exist) { + log.GetLogger().Infof("current funcManager num is %s, kill the new instance %s", + currentNum, instanceSpec.InstanceID) + delete(fm.instanceCache, instanceSpec.InstanceID) + fm.Unlock() + if err := fm.sdkClient.Kill(instanceSpec.InstanceID, types.KillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to kill instance %s error:%s", instanceSpec.InstanceID, + err.Error()) + } + return + } + // add instance to cache if not exist, otherwise update the instance + if !exist { + fm.instanceCache[instanceSpec.InstanceID] = instanceSpec + log.GetLogger().Infof("add funcManager instance %s to cache", instanceSpec.InstanceID) + fm.Unlock() + return + } + fm.instanceCache[instanceSpec.InstanceID].InstanceSpecificationMeta = instanceSpec.InstanceSpecificationMeta + log.GetLogger().Infof("funcManager instance %s is updated, refresh instance cache", + instanceSpec.InstanceID) + fm.Unlock() + return + } + fm.RLock() + _, exist := fm.terminalCache[instanceSpec.InstanceID] + fm.RUnlock() + if !exist { + log.GetLogger().Infof("funcManager instance %s is not expected, start to delete", + instanceSpec.InstanceID) + if err := fm.KillInstance(instanceSpec.InstanceID); err != nil { + log.GetLogger().Errorf("failed to kill funcManager instance %s error:%s", + instanceSpec.InstanceID, err.Error()) + } + } +} + +// HandleInstanceDelete handle the etcd DELETE event +func (fm *FunctionManager) HandleInstanceDelete(instanceSpec *types.InstanceSpecification) { + log.GetLogger().Infof("handling funcManager instance %s delete", instanceSpec.InstanceID) + config.ManagerConfigLock.RLock() + signature := controllerutils.GetManagerConfigSignature(config.GetFaaSManagerConfig()) + config.ManagerConfigLock.RUnlock() + fm.Lock() + delete(fm.instanceCache, instanceSpec.InstanceID) + fm.Unlock() + if isExceptInstance(&instanceSpec.InstanceSpecificationMeta, signature) { + fm.RLock() + if len(fm.instanceCache) < fm.count { + log.GetLogger().Infof("current funcManager instance num is %d, need to recreate instance: %s", + len(fm.instanceCache), instanceSpec.InstanceID) + fm.RUnlock() + fm.recreateInstanceIDCh <- instanceSpec.InstanceID + return + } + fm.RUnlock() + } + cancel, exist := fm.recreateInstanceIDMap.Load(instanceSpec.InstanceID) + if exist { + if cancelFunc, ok := cancel.(context.CancelFunc); ok { + cancelFunc() + log.GetLogger().Infof("funcManager instance %s bring up has been canceled", instanceSpec.InstanceID) + return + } + log.GetLogger().Errorf("get cancel func failed from instanceIDMap, instanceID:%s", + instanceSpec.InstanceID) + } +} + +func (fm *FunctionManager) killAllTerminalInstance() { + fm.Lock() + insMap := fm.terminalCache + fm.terminalCache = map[string]*types.InstanceSpecification{} + fm.Unlock() + for _, ins := range insMap { + insID := ins.InstanceID + go func() { + err := fm.KillInstance(insID) + if err != nil { + log.GetLogger().Errorf("Failed to kill funcManager instance %v, err: %v", insID, err) + } + }() + } +} + +// RollingUpdate - +func (fm *FunctionManager) RollingUpdate(ctx context.Context, event *types.ConfigChangeEvent) { + // 1. 更新 预期实例个数 + // 2. 把不符合预期的instanceCache -> terminalCache + // 3. 从terminalCache随机删除一个实例 同步 + // 4. 创建新实例 同步 + // 5. terminalCache为空时将实例数调谐为预期实例数(同步), instanceCache到达预期数量时清空terminalCache(异步) + newSign := controllerutils.GetManagerConfigSignature(event.ManagerCfg) + fm.Lock() + fm.count = event.ManagerCfg.ManagerInstanceNum + for _, ins := range fm.instanceCache { + if !isExceptInstance(&ins.InstanceSpecificationMeta, newSign) { + fm.terminalCache[ins.InstanceID] = ins + delete(fm.instanceCache, ins.InstanceID) + } + } + fm.Unlock() + for { + select { + case <-ctx.Done(): + event.Error = fmt.Errorf("rolling update has stopped") + event.Done() + return + default: + } + fm.RLock() + if len(fm.instanceCache) == fm.count { + fm.RUnlock() + log.GetLogger().Infof("faasManager instance count arrive at expectation:%d,"+ + " delete all terminating instance", fm.count) + go fm.killAllTerminalInstance() + event.Done() + return + } + if len(fm.terminalCache) == 0 { + fm.RUnlock() + log.GetLogger().Infof("no faasManager instance need to terminate, pull up missing instance count:%d", + fm.count-len(fm.instanceCache)) + err := fm.CreateExpectedInstanceCount(ctx) + if err != nil { + event.Error = err + } + event.Done() + return + } + insID := "" + for _, ins := range fm.terminalCache { + insID = ins.InstanceID + break + } + fm.RUnlock() + log.GetLogger().Infof("start to terminate instance:%d", insID) + var err error + if err = fm.sdkClient.Kill(insID, types.SyncKillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to kill faasManager instance(id=%s), err:%v", insID, err) + } + fm.Lock() + delete(fm.terminalCache, insID) + fm.Unlock() + time.Sleep(constant.RecreateSleepTime) + err = fm.SyncCreateInstance(ctx) + if err != nil { + event.Error = err + event.Done() + return + } + } +} + +func prepareCreateOptions(conf *types.ManagerConfig) (map[string]string, error) { + podLabels := map[string]string{ + constant.SystemFuncName: constant.FuncNameFaasmanager, + } + labels, err := json.Marshal(podLabels) + if err != nil { + return nil, fmt.Errorf("pod labels json marshal failed, err:%s", err.Error()) + } + delegateRuntime, err := json.Marshal(map[string]interface{}{ + "image": conf.Image, + }) + if err != nil { + return nil, err + } + createOpts := map[string]string{ + commonconstant.DelegatePodLabels: string(labels), + commonconstant.DelegateRuntimeManagerTag: string(delegateRuntime), + commonconstant.InstanceLifeCycle: commonconstant.InstanceLifeCycleDetached, + commonconstant.DelegateNodeAffinity: conf.NodeAffinity, + commonconstant.DelegateNodeAffinityPolicy: conf.NodeAffinityPolicy, + } + if config.GetFaaSManagerConfig().Affinity != "" { + createOpts[commonconstant.DelegateAffinity] = config.GetFaaSManagerConfig().Affinity + } + return createOpts, nil +} + +func (fm *FunctionManager) configChangeProcessor(ctx context.Context, cancel context.CancelFunc) { + if ctx == nil || cancel == nil || fm.ConfigChangeCh == nil { + return + } + for { + select { + case cfgEvent, ok := <-fm.ConfigChangeCh: + if !ok { + cancel() + return + } + managerConfig, insNum := fm.ConfigDiff(cfgEvent) + if managerConfig != nil || insNum != -1 { + log.GetLogger().Infof("manager config or instance num is changed," + + " need to update manager instance") + cancel() + ctx, cancel = context.WithCancel(context.Background()) + config.UpdateManagerConfig(cfgEvent.ManagerCfg) + cfgEvent.Add(1) + go fm.RollingUpdate(ctx, cfgEvent) + } else { + log.GetLogger().Infof("manager config is same as current, no need to update") + } + cfgEvent.Done() + } + } +} + +// ConfigDiff config diff +func (fm *FunctionManager) ConfigDiff(event *types.ConfigChangeEvent) (*types.ManagerConfig, int) { + newSign := controllerutils.GetManagerConfigSignature(event.ManagerCfg) + config.ManagerConfigLock.RLock() + managerOldCfg := config.GetFaaSManagerConfig() + config.ManagerConfigLock.RUnlock() + if strings.Compare(newSign, + controllerutils.GetManagerConfigSignature(managerOldCfg)) == 0 { + if event.ManagerCfg.ManagerInstanceNum != managerOldCfg.ManagerInstanceNum { + config.ManagerConfigLock.Lock() + managerOldCfg.ManagerInstanceNum = event.ManagerCfg.ManagerInstanceNum + config.ManagerConfigLock.Unlock() + return nil, event.ManagerCfg.ManagerInstanceNum + } + return nil, -1 + } + return event.ManagerCfg, event.ManagerCfg.ManagerInstanceNum +} diff --git a/yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager/faasfunctionmanager_test.go b/yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager/faasfunctionmanager_test.go new file mode 100644 index 0000000..772d062 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager/faasfunctionmanager_test.go @@ -0,0 +1,905 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package faasfunctionmanager + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "k8s.io/apimachinery/pkg/util/wait" + "reflect" + "sync" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + . "github.com/smartystreets/goconvey/convey" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/alarm" + commonconstant "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/sts" + "yuanrong/pkg/common/faas_common/sts/raw" + commontype "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/utils" + mockUtils "yuanrong/pkg/common/faas_common/utils" + faasmanagerconf "yuanrong/pkg/functionmanager/types" + fcConfig "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/types" + controllerutils "yuanrong/pkg/system_function_controller/utils" +) + +type KvMock struct { +} + +func (k *KvMock) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + response := &clientv3.GetResponse{} + response.Count = 2 + return response, nil +} + +func (k *KvMock) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, + error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Txn(ctx context.Context) clientv3.Txn { + // TODO implement me + panic("implement me") +} + +func initConfig(configString string) { + fcConfig.InitConfig([]byte(configString)) + routerEtcdConfig := etcd3.EtcdConfig{ + Servers: []string{"1.2.3.4:1234"}, + User: "tom", + Password: "**", + } + metaEtcdConfig := etcd3.EtcdConfig{ + Servers: []string{"1.2.3.4:5678"}, + User: "tom", + Password: "**", + } + managerConfig := types.ManagerConfig{ + ManagerInstanceNum: 10, + CPU: 777, + Memory: 777, + RouterEtcd: routerEtcdConfig, + MetaEtcd: metaEtcdConfig, + NodeSelector: map[string]string{"testkey": "testvalue"}, + } + fcConfig.UpdateManagerConfig(&managerConfig) +} + +func newFaaSFunctionManager(size int) (*FunctionManager, error) { + stopCh := make(chan struct{}) + initConfig(`{ + "faasfunctionConfig": { + "slaQuota": 1000, + "functionCapability": 1, + "authenticationEnable": false, + "trafficLimitDisable": true, + "http": { + "resptimeout": 5, + "workerInstanceReadTimeOut": 5, + "maxRequestBodySize": 6 + } + } + } + `) + + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &etcd3.EtcdClient{Client: client} + once = sync.Once{} + manager := NewFaaSFunctionManager(&mockUtils.FakeLibruntimeSdkClient{}, etcdClient, stopCh, size) + time.Sleep(50 * time.Millisecond) + manager.count = size + manager.instanceCache = make(map[string]*types.InstanceSpecification) + manager.terminalCache = map[string]*types.InstanceSpecification{} + functionManager = manager + return functionManager, nil +} + +func newFaaSFunctionManagerWithRetry(size int) (*FunctionManager, error) { + stopCh := make(chan struct{}) + initConfig(`{ + "faasfunctionConfig": { + "slaQuota": 1000, + "functionCapability": 1, + "authenticationEnable": false, + "trafficLimitDisable": true, + "http": { + "resptimeout": 5, + "workerInstanceReadTimeOut": 5, + "maxRequestBodySize": 6 + } + }, + "enableRetry": true + } + `) + + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &etcd3.EtcdClient{Client: client} + once = sync.Once{} + manager := NewFaaSFunctionManager(&mockUtils.FakeLibruntimeSdkClient{}, etcdClient, stopCh, size) + time.Sleep(50 * time.Millisecond) + manager.count = size + manager.instanceCache = make(map[string]*types.InstanceSpecification) + manager.terminalCache = map[string]*types.InstanceSpecification{} + functionManager = manager + return functionManager, nil +} + +func TestNewInstanceManager(t *testing.T) { + Convey("Test NewInstanceManager", t, func() { + Convey("Test NewInstanceManager with correct size", func() { + got, err := newFaaSFunctionManager(16) + So(err, ShouldBeNil) + So(got, ShouldNotBeNil) + }) + }) +} + +func Test_initInstanceCache(t *testing.T) { + Convey("test initInstanceCache", t, func() { + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &etcd3.EtcdClient{Client: client} + functionManager = &FunctionManager{ + instanceCache: make(map[string]*types.InstanceSpecification), + terminalCache: map[string]*types.InstanceSpecification{}, + etcdClient: etcdClient, + sdkClient: &mockUtils.FakeLibruntimeSdkClient{}, + count: 1, + stopCh: make(chan struct{}), + recreateInstanceIDCh: make(chan string, 1), + ConfigChangeCh: make(chan *types.ConfigChangeEvent, 100), + } + defer ApplyMethod(reflect.TypeOf(kv), "Get", + func(k *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + bytes, _ := json.Marshal(fcConfig.GetFaaSManagerConfig()) + meta := types.InstanceSpecificationMeta{Function: "test-function", InstanceID: "123", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}} + marshal, _ := json.Marshal(meta) + kvs := []*mvccpb.KeyValue{{Value: marshal}} + response := &clientv3.GetResponse{Kvs: kvs} + response.Count = 1 + return response, nil + }).Reset() + functionManager.initInstanceCache(etcdClient) + cache := functionManager.GetInstanceCache() + So(cache["123"], ShouldNotBeNil) + close(functionManager.stopCh) + }) +} + +func TestInstanceManager_CreateMultiInstances(t *testing.T) { + Convey("Test CreateMultiInstances", t, func() { + Convey("Test CreateMultiInstances with retry", func() { + p := ApplyFunc((*mockUtils.FakeLibruntimeSdkClient).CreateInstance, func(_ *mockUtils.FakeLibruntimeSdkClient, _ api.FunctionMeta, _ []api.Arg, _ api.InvokeOptions) (string, error) { + time.Sleep(100 * time.Millisecond) + return "", api.ErrorInfo{ + Code: 10, + Err: fmt.Errorf("xxxxx"), + StackTracesInfo: api.StackTracesInfo{}, + } + }) + p2 := ApplyFunc(wait.ExponentialBackoffWithContext, func(ctx context.Context, backoff wait.Backoff, condition wait.ConditionWithContextFunc) error { + _, err := condition(ctx) + return err + }) + manager, err := newFaaSFunctionManagerWithRetry(1) + So(err, ShouldBeNil) + ctx, cancel := context.WithCancel(context.TODO()) + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + err = manager.CreateMultiInstances(ctx, 1) + So(err, ShouldBeError) + + err = manager.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldBeNil) + p.Reset() + p2.Reset() + }) + + instanceMgr, err := newFaaSFunctionManager(3) + So(err, ShouldBeNil) + defer ApplyMethod(reflect.TypeOf(instanceMgr), "CreateWithRetry", + func(fm *FunctionManager, ctx context.Context, args []api.Arg, extraParams commontype.ExtraParams) error { + if value := ctx.Value("err"); value != nil { + if s := value.(string); s == "canceled" { + return fmt.Errorf("create has been cancelled") + } + } + return nil + }).Reset() + + Convey("Test CreateMultiInstances when passed different count", func() { + err = instanceMgr.CreateMultiInstances(context.TODO(), -1) + So(err, ShouldBeNil) + + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldBeNil) + }) + + Convey("Test CreateMultiInstances when failed to get scheduler config", func() { + patches := ApplyFunc(json.Marshal, func(_ interface{}) ([]byte, error) { + return nil, errors.New("json Marshal failed") + }) + defer patches.Reset() + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldNotBeNil) + }) + + Convey("Test CreateMultiInstances when failed to create instance", func() { + patches := ApplyMethod(reflect.TypeOf(instanceMgr), "CreateInstance", + func(_ *FunctionManager, ctx context.Context, _ string, _ []api.Arg, _ *commontype.ExtraParams) string { + return "" + }) + defer patches.Reset() + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldNotBeNil) + }) + }) +} + +func TestInstanceManager_GetInstanceCountFromEtcd(t *testing.T) { + Convey("GetInstanceCountFromEtcd", t, func() { + Convey("failed", func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&KvMock{}), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, + error) { + return nil, errors.New("get etcd error") + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + count := instanceMgr.GetInstanceCountFromEtcd() + So(len(count), ShouldEqual, 0) + }) + }) + +} + +func TestInstanceManager_CreateExpectedInstanceCount(t *testing.T) { + Convey("CreateExpectedInstanceCount", t, func() { + Convey("success", func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&KvMock{}), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, + error) { + response := &clientv3.GetResponse{} + response.Count = 2 + response.Kvs = []*mvccpb.KeyValue{ + { + Key: []byte("/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasmanager/version/$latest/defaultaz/task-66ccf050-50f6-4835-aa24-c1b15dddb00e/12996c08-0000-4000-8000-db6c3db0fcbf"), + }, + { + Key: []byte("/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasmanager/version/$latest/defaultaz/task-66ccf050-50f6-4835-aa24-c1b15dddb00e/12996c08-0000-4000-8000-db6c3db0fcb2"), + }, + } + return response, nil + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + count := instanceMgr.GetInstanceCountFromEtcd() + So(len(count), ShouldEqual, 2) + + err = instanceMgr.CreateExpectedInstanceCount(context.TODO()) + So(err, ShouldBeNil) + }) + }) +} + +func TestInstanceManager_SyncKillAllInstance(t *testing.T) { + Convey("SyncKillAllInstance", t, func() { + Convey("success", func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&mockUtils.FakeLibruntimeSdkClient{}), "Kill", + func(_ *mockUtils.FakeLibruntimeSdkClient, instanceID string, signal int, payload []byte) error { + return nil + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + + instanceMgr.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + instanceMgr.SyncKillAllInstance() + So(len(instanceMgr.instanceCache), ShouldEqual, 0) + _, exist := instanceMgr.instanceCache["123"] + So(exist, ShouldEqual, false) + }) + }) +} + +func TestInstanceManager_RecoverInstance(t *testing.T) { + Convey("RecoverInstance", t, func() { + Convey("create failed", func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + patches := []*Patches{ + ApplyFunc((*FunctionManager).KillInstance, func(_ *FunctionManager, _ string) error { + return nil + }), + ApplyFunc((*FunctionManager).CreateMultiInstances, + func(_ *FunctionManager, ctx context.Context, _ int) error { + return errors.New("failed to create instances") + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + + instanceMgr.RecoverInstance(&types.InstanceSpecification{}) + }) + }) +} + +func TestKillExceptInstance(t *testing.T) { + Convey("KillExceptInstance", t, func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + Convey("success", func() { + instanceMgr.instanceCache["instanceID1"] = &types.InstanceSpecification{} + instanceMgr.instanceCache["instanceID2"] = &types.InstanceSpecification{} + err = instanceMgr.KillExceptInstance(2) + So(err, ShouldBeNil) + }) + }) +} + +func TestInstanceManager_HandleInstanceUpdate(t *testing.T) { + Convey("HandleInstanceUpdate", t, func() { + Convey("not exist", func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + + instanceMgr.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + + instanceMgr.HandleInstanceUpdate(&types.InstanceSpecification{ + InstanceID: "456", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function"}}) + + So(instanceMgr.instanceCache["456"], ShouldEqual, nil) + }) + + Convey("exist", func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + + instanceMgr.instanceCache = make(map[string]*types.InstanceSpecification) + managerConfig := fcConfig.GetFaaSManagerConfig() + bytes, _ := json.Marshal(managerConfig) + instanceMgr.HandleInstanceUpdate(&types.InstanceSpecification{ + InstanceID: "123", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}}, + ) + + So(instanceMgr.instanceCache["123"].InstanceSpecificationMeta.Function, ShouldEqual, "test-function") + + instanceMgr.HandleInstanceUpdate(&types.InstanceSpecification{ + InstanceID: "123", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + RequestID: "test-runtimeID", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}}, + ) + So(instanceMgr.instanceCache["123"].InstanceSpecificationMeta.RequestID, ShouldEqual, "test-runtimeID") + }) + + Convey("delete extra instance", func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + instanceMgr.instanceCache = make(map[string]*types.InstanceSpecification) + + defer ApplyMethod(reflect.TypeOf(instanceMgr.etcdClient.Client.KV), "Get", + func(k *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, + error) { + bytes, _ := json.Marshal(fcConfig.GetFaaSManagerConfig()) + meta := types.InstanceSpecificationMeta{Function: "test-function", InstanceID: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}} + marshal, _ := json.Marshal(meta) + kvs := []*mvccpb.KeyValue{{Value: marshal}} + response := &clientv3.GetResponse{Kvs: kvs} + response.Count = 1 + return response, nil + }).Reset() + instanceMgr.initInstanceCache(instanceMgr.etcdClient) + So(instanceMgr.instanceCache["test-function"], ShouldNotBeNil) + managerConfig := fcConfig.GetFaaSManagerConfig() + bytes, _ := json.Marshal(managerConfig) + instanceMgr.HandleInstanceUpdate(&types.InstanceSpecification{ + InstanceID: "123", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}}, + ) + So(instanceMgr.instanceCache["123"], ShouldBeNil) + }) + }) +} + +func TestInstanceManager_HandleInstanceDelete(t *testing.T) { + Convey("HandleInstanceDelete", t, func() { + Convey("not exist", func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + + instanceMgr.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + instanceMgr.recreateInstanceIDMap.Store("456", "error type") + instanceMgr.HandleInstanceDelete(&types.InstanceSpecification{ + InstanceID: "456", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function"}}) + _, exist := instanceMgr.instanceCache["123"] + So(exist, ShouldEqual, true) + }) + + Convey("not exist2", func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + + instanceMgr.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + _, cancel := context.WithCancel(context.Background()) + instanceMgr.recreateInstanceIDMap.Store("456", cancel) + instanceMgr.HandleInstanceDelete(&types.InstanceSpecification{ + InstanceID: "456", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function"}}) + _, exist := instanceMgr.instanceCache["123"] + So(exist, ShouldEqual, true) + }) + + Convey("exist", func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + + instanceMgr.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + instanceMgr.HandleInstanceDelete(&types.InstanceSpecification{ + InstanceID: "123", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function"}}) + _, exist := instanceMgr.instanceCache["123"] + So(exist, ShouldEqual, false) + }) + + Convey("is not except instance", func() { + defer ApplyFunc(isExceptInstance, func(meta *types.InstanceSpecificationMeta, targetSign string) bool { + return true + }).Reset() + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + + instanceMgr.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + instanceMgr.HandleInstanceDelete(&types.InstanceSpecification{ + InstanceID: "123", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function"}}) + _, exist := instanceMgr.instanceCache["123"] + So(exist, ShouldEqual, false) + }) + }) +} + +func TestFunctionManager_recreateInstance(t *testing.T) { + Convey("test recreate instance", t, func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + defer ApplyFunc(json.Marshal, func(v any) ([]byte, error) { + return []byte{}, nil + }).Reset() + defer ApplyFunc(fcConfig.GetFaaSControllerConfig, func() *types.Config { + return &types.Config{ + RawStsConfig: raw.StsConfig{StsEnable: true}, + } + }).Reset() + defer ApplyFunc(sts.GenerateSecretVolumeMounts, func(systemFunctionName string) ([]byte, error) { + return []byte{}, nil + }).Reset() + instanceMgr.recreateInstanceIDCh <- "instanceID1" + go instanceMgr.recreateInstance() + time.Sleep(11 * time.Second) + _, ok := instanceMgr.recreateInstanceIDMap.Load("instanceID1") + So(ok, ShouldBeFalse) + + instanceMgr.recreateInstanceIDCh <- "instanceID1" + So(len(instanceMgr.instanceCache), ShouldEqual, 1) + close(instanceMgr.recreateInstanceIDCh) + }) +} + +func TestSyncCreateInstance(t *testing.T) { + Convey("test SyncCreateInstance", t, func() { + manager, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + err = manager.SyncCreateInstance(context.TODO()) + So(err, ShouldBeNil) + + Convey("context canceled", func() { + ctx, cancelFunc := context.WithCancel(context.TODO()) + cancelFunc() + err = manager.SyncCreateInstance(ctx) + So(err, ShouldNotBeNil) + }) + }) +} + +func TestRollingUpdate(t *testing.T) { + Convey("test RollingUpdate", t, func() { + manager, err := newFaaSFunctionManager(2) + So(err, ShouldBeNil) + Convey("same config", func() { + managerConfig := fcConfig.GetFaaSManagerConfig() + bytes, _ := json.Marshal(managerConfig) + manager.instanceCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + manager.instanceCache["test-2"] = &types.InstanceSpecification{ + InstanceID: "test-2", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + cfgEvent := &types.ConfigChangeEvent{ + ManagerCfg: fcConfig.GetFaaSManagerConfig(), + TraceID: "traceID-123456789", + } + cfgEvent.ManagerCfg.ManagerInstanceNum = 2 + cfgEvent.Add(1) + manager.ConfigChangeCh <- cfgEvent + cfgEvent.Wait() + So(len(manager.instanceCache), ShouldEqual, cfgEvent.ManagerCfg.ManagerInstanceNum) + So(manager.instanceCache["test-1"], ShouldNotBeNil) + So(manager.instanceCache["test-2"], ShouldNotBeNil) + }) + + Convey("different config", func() { + managerConfig := fcConfig.GetFaaSManagerConfig() + bytes, _ := json.Marshal(managerConfig) + manager.instanceCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + manager.instanceCache["test-2"] = &types.InstanceSpecification{ + InstanceID: "test-2", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + cfg := &types.ManagerConfig{} + utils.DeepCopyObj(fcConfig.GetFaaSManagerConfig(), cfg) + cfgEvent := &types.ConfigChangeEvent{ + ManagerCfg: cfg, + TraceID: "traceID-123456789", + } + cfgEvent.ManagerCfg.CPU = 888 + cfgEvent.ManagerCfg.ManagerInstanceNum = 10 + cfgEvent.Add(1) + manager.ConfigChangeCh <- cfgEvent + cfgEvent.Wait() + close(manager.ConfigChangeCh) + So(len(manager.terminalCache), ShouldEqual, 0) + So(len(manager.instanceCache), ShouldEqual, cfgEvent.ManagerCfg.ManagerInstanceNum) + }) + + Convey("killAllTerminalInstance", func() { + managerConfig := fcConfig.GetFaaSManagerConfig() + bytes, _ := json.Marshal(managerConfig) + manager.instanceCache = make(map[string]*types.InstanceSpecification) + manager.instanceCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + manager.instanceCache["test-2"] = &types.InstanceSpecification{ + InstanceID: "test-2", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + cfgEvent := &types.ConfigChangeEvent{ + ManagerCfg: fcConfig.GetFaaSManagerConfig(), + TraceID: "traceID-123456789", + } + cfgEvent.ManagerCfg.ManagerInstanceNum = 2 + cfgEvent.Add(1) + manager.RollingUpdate(context.TODO(), cfgEvent) + cfgEvent.Wait() + So(len(manager.instanceCache), ShouldEqual, 2) + }) + }) +} + +func Test_killAllTerminalInstance(t *testing.T) { + Convey("killAllTerminalInstance", t, func() { + manager, err := newFaaSFunctionManager(2) + So(err, ShouldBeNil) + managerConfig := fcConfig.GetFaaSManagerConfig() + bytes, _ := json.Marshal(managerConfig) + manager.terminalCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + + defer ApplyMethod(reflect.TypeOf(manager), "KillInstance", func(fm *FunctionManager, instanceID string) error { + return nil + }).Reset() + manager.killAllTerminalInstance() + So(len(manager.terminalCache), ShouldEqual, 0) + }) +} + +func TestConfigDiff(t *testing.T) { + Convey("ConfigDiff", t, func() { + manager, err := newFaaSFunctionManager(2) + So(err, ShouldBeNil) + Convey("same config ,different num", func() { + cfg := &types.ManagerConfig{} + utils.DeepCopyObj(fcConfig.GetFaaSManagerConfig(), cfg) + cfgEvent := &types.ConfigChangeEvent{ + ManagerCfg: cfg, + TraceID: "traceID-123456789", + } + cfgEvent.ManagerCfg.ManagerInstanceNum = 100 + _, num := manager.ConfigDiff(cfgEvent) + So(num, ShouldEqual, 100) + }) + }) +} + +func TestFunctionManager_KillExceptInstance(t *testing.T) { + Convey("KillExceptInstance", t, func() { + instanceMgr, err := newFaaSFunctionManager(1) + So(err, ShouldBeNil) + err = instanceMgr.KillExceptInstance(1) + So(err, ShouldBeNil) + }) +} + +func TestFunctionManager_KillInstance(t *testing.T) { + Convey("kill instance test", t, func() { + Convey("baseline", func() { + manager, err := newFaaSFunctionManager(2) + So(err, ShouldBeNil) + i := 0 + p := ApplyFunc((*mockUtils.FakeLibruntimeSdkClient).Kill, + func(_ *mockUtils.FakeLibruntimeSdkClient, instanceID string, _ int, _ []byte) error { + if i == 0 { + i = 1 + return fmt.Errorf("error") + } + return nil + }) + defer p.Reset() + err = manager.KillInstance("aaa") + So(err, ShouldBeNil) + }) + }) +} + +func TestFunctionManager_CreateWithRetry(t *testing.T) { + Convey("create with retry test", t, func() { + Convey("baseline", func() { + instanceMgr, _ := newFaaSFunctionManager(1) + p := ApplyFunc((*FunctionManager).CreateInstance, + func(_ *FunctionManager, ctx context.Context, function string, + args []api.Arg, extraParams *commontype.ExtraParams) string { + return "aaa" + }) + defer p.Reset() + err := instanceMgr.CreateWithRetry(context.TODO(), nil, commontype.ExtraParams{}) + So(err, ShouldBeNil) + }) + Convey("retry success", func() { + instanceMgr, _ := newFaaSFunctionManager(1) + ins := "" + p := ApplyFunc((*FunctionManager).CreateInstance, + func(_ *FunctionManager, ctx context.Context, function string, + args []api.Arg, extraParams *commontype.ExtraParams) string { + if ins == "" { + ins = "aaa" + return "" + } + return ins + }) + defer p.Reset() + err := instanceMgr.CreateWithRetry(context.TODO(), nil, commontype.ExtraParams{}) + So(err, ShouldBeNil) + }) + }) +} + +func Test_prepareCreateOptions(t *testing.T) { + Convey("prepareCreateOptions test", t, func() { + Convey("baseline", func() { + p := ApplyFunc(fcConfig.GetFaaSManagerConfig, func() *types.ManagerConfig { + return &types.ManagerConfig{ + Affinity: "aaaa", + } + }) + defer p.Reset() + options, err := prepareCreateOptions(&types.ManagerConfig{ + Affinity: "aaaa", + }) + So(err, ShouldBeNil) + So(options[commonconstant.DelegateAffinity], ShouldEqual, "aaaa") + }) + + Convey("baseline1", func() { + nodeAffinity := "{\"nodeAffinity\":{\"preferredDuringSchedulingIgnoredDuringExecution\":[{\"preference\":{\"matchExpressions\":[{\"key\":\"node-type\",\"operator\":\"In\",\"values\":[\"system\"]}]},\"weight\":1}]}}" + nodeAffinityPolicy := "coverage" + p := ApplyFunc(fcConfig.GetFaaSManagerConfig, func() *types.ManagerConfig { + return &types.ManagerConfig{ + NodeAffinity: nodeAffinity, + NodeAffinityPolicy: nodeAffinityPolicy, + } + }) + defer p.Reset() + options, err := prepareCreateOptions(&types.ManagerConfig{ + NodeAffinity: nodeAffinity, + NodeAffinityPolicy: nodeAffinityPolicy, + }) + So(err, ShouldBeNil) + So(options[commonconstant.DelegateNodeAffinity], ShouldEqual, nodeAffinity) + So(options[commonconstant.DelegateNodeAffinityPolicy], ShouldEqual, nodeAffinityPolicy) + }) + Convey("baseline2", func() { + nodeAffinity := "{\"nodeAffinity\":{\"preferredDuringSchedulingIgnoredDuringExecution\":[{\"preference\":{\"matchExpressions\":[{\"key\":\"node-type\",\"operator\":\"In\",\"values\":[\"system\"]}]},\"weight\":1}]}}" + nodeAffinityPolicy := "aggregation" + p := ApplyFunc(fcConfig.GetFaaSManagerConfig, func() *types.ManagerConfig { + return &types.ManagerConfig{ + NodeAffinity: nodeAffinity, + NodeAffinityPolicy: nodeAffinityPolicy, + } + }) + defer p.Reset() + options, err := prepareCreateOptions(&types.ManagerConfig{ + NodeAffinity: nodeAffinity, + NodeAffinityPolicy: nodeAffinityPolicy, + }) + So(err, ShouldBeNil) + So(options[commonconstant.DelegateNodeAffinity], ShouldEqual, nodeAffinity) + So(options[commonconstant.DelegateNodeAffinityPolicy], ShouldEqual, nodeAffinityPolicy) + }) + }) +} + +func TestIsExceptInstance(t *testing.T) { + Convey("Test isExceptInstance", t, func() { + conf := &faasmanagerconf.ManagerConfig{ + MetaEtcd: etcd3.EtcdConfig{ + Servers: []string{"127.0.0.1:32379"}, + User: "", + Password: "", + SslEnable: false, + AuthType: "Noauth", + UseSecret: false, + SecretName: "etcd-client-secret", + LimitRate: 0, + LimitBurst: 0, + LimitTimeout: 0, + CaFile: "", + CertFile: "", + KeyFile: "", + PassphraseFile: "", + }, + RouterEtcd: etcd3.EtcdConfig{}, + FunctionCapability: 0, + SccConfig: crypto.SccConfig{}, + AuthenticationEnable: false, + AlarmConfig: alarm.Config{}, + } + confStr, err := json.Marshal(conf) + if err != nil { + return + } + + meta := &types.InstanceSpecificationMeta{ + Args: []map[string]string{}, + } + targetSign := "testSign" + if isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return false when args is empty") + } + meta.Args = []map[string]string{ + { + "value": "invalidBase64", + }, + } + if isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return false when base64 decoding fails") + } + + // 测试参数反序列化为前端配置失败的情况 + meta.Args[0]["value"] = base64.StdEncoding.EncodeToString([]byte("invalidJson")) + if isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return false when unmarshalling frontend config fails") + } + + meta.Args[0]["value"] = base64.StdEncoding.EncodeToString([]byte(confStr)) + patches := ApplyFunc(controllerutils.GetManagerConfigSignature, func(managerCfg *types.ManagerConfig) string { + return targetSign + }) + defer patches.Reset() + if !isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return true") + } + newStr := make([]byte, len(confStr)+16) + newStr = append([]byte("0000000000000000"), confStr[:]...) + meta.Args[0]["value"] = base64.StdEncoding.EncodeToString(newStr) + if !isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return true") + } + + newStr = append([]byte("0000000000000000"), newStr[:]...) + meta.Args[0]["value"] = base64.StdEncoding.EncodeToString(newStr) + if isExceptInstance(meta, targetSign) { + t.Errorf("Expected isExceptInstance to return false") + } + }) + +} diff --git a/yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager/faasschedulermanager.go b/yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager/faasschedulermanager.go new file mode 100644 index 0000000..be4bc55 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager/faasschedulermanager.go @@ -0,0 +1,882 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package faasschedulermanager - +package faasschedulermanager + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "go.etcd.io/etcd/client/v3" + "k8s.io/apimachinery/pkg/util/wait" + + "yuanrong.org/kernel/runtime/libruntime/api" + + commonconstant "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/sts" + commonTypes "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/common/uuid" + "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/constant" + "yuanrong/pkg/system_function_controller/state" + "yuanrong/pkg/system_function_controller/types" + controllerutils "yuanrong/pkg/system_function_controller/utils" +) + +var ( + createInstanceBackoff = wait.Backoff{ + Steps: constant.DefaultCreateRetryTime, // retry times (include first time) + Duration: constant.DefaultCreateRetryDuration, + Factor: constant.DefaultCreateRetryFactor, + Jitter: constant.DefaultCreateRetryJitter, + } +) + +// SchedulerManager manages faaS scheduler specification +type SchedulerManager struct { + instanceCache map[string]*types.InstanceSpecification + terminalCache map[string]*types.InstanceSpecification + etcdClient *etcd3.EtcdClient + sdkClient api.LibruntimeAPI + ConfigChangeCh chan *types.ConfigChangeEvent + stopCh chan struct{} + recreateInstanceIDMap sync.Map + sync.RWMutex + count int + tenantID string +} + +// NewFaaSSchedulerManager create a scheduler instance manager +func NewFaaSSchedulerManager(libruntimeAPI api.LibruntimeAPI, + etcdClient *etcd3.EtcdClient, stopCh chan struct{}, size int, tenantID string) *SchedulerManager { + schedulerManager := &SchedulerManager{ + instanceCache: map[string]*types.InstanceSpecification{}, + terminalCache: map[string]*types.InstanceSpecification{}, + etcdClient: etcdClient, + sdkClient: libruntimeAPI, + count: size, + tenantID: tenantID, + stopCh: stopCh, + ConfigChangeCh: make(chan *types.ConfigChangeEvent, constant.DefaultChannelSize), + } + schedulerManager.initInstanceCache(etcdClient) + ctx, cancelFunc := context.WithCancel(context.Background()) + go schedulerManager.configChangeProcessor(ctx, cancelFunc) + go func() { + err := schedulerManager.CreateExpectedInstanceCount(ctx) + if err != nil { + log.GetLogger().Errorf("Failed to create expected scheduler instance count, error: %v", err) + return + } + log.GetLogger().Infof("succeed to create expected scheduler instance count %d for tenantID %s", + size, tenantID) + }() + return schedulerManager +} + +func (s *SchedulerManager) initInstanceCache(etcdClient *etcd3.EtcdClient) { + response, err := etcdClient.Client.Get(context.Background(), types.FaaSSchedulerPrefixKey, clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("Failed to get scheduler instance: %v", err) + return + } + config.SchedulerConfigLock.RLock() + targetSign := controllerutils.GetSchedulerConfigSignature(config.GetFaaSSchedulerConfig()) + config.SchedulerConfigLock.RUnlock() + for _, kv := range response.Kvs { + meta := &types.InstanceSpecificationMeta{} + err := json.Unmarshal(kv.Value, meta) + if err != nil { + log.GetLogger().Errorf("Failed to unmarshal instance specification: %v", err) + continue + } + if isExpectInstance(meta, targetSign) { + tenantID := "" + if meta.CreateOptions != nil { + tenantID = meta.CreateOptions[constant.SchedulerExclusivity] + } + if s.tenantID == tenantID { + funcCtx, cancelFunc := context.WithCancel(context.TODO()) + s.instanceCache[meta.InstanceID] = &types.InstanceSpecification{ + FuncCtx: funcCtx, + CancelFunc: cancelFunc, + InstanceID: meta.InstanceID, + InstanceSpecificationMeta: *meta, + } + log.GetLogger().Infof("find expected scheduler instance %s add to cache for tenantID %s", + meta.InstanceID, tenantID) + } + } + state.Update(meta.InstanceID, types.StateUpdate, types.FaasSchedulerInstanceState+s.tenantID) + } +} + +func isExpectInstance(meta *types.InstanceSpecificationMeta, targetSign string) bool { + if len(meta.Args) == 0 { + log.GetLogger().Errorf("args is empty, %v", meta) + return false + } + v := meta.Args[0]["value"] + s, err := base64.StdEncoding.DecodeString(v) + if err != nil { + log.GetLogger().Errorf("Failed to decode args: %v", err) + return false + } + cfg := &types.SchedulerConfig{} + err = json.Unmarshal(s, cfg) + if err != nil && len(s) > commonconstant.LibruntimeHeaderSize { + // args in libruntime create request with 16 bytes header. + // except libruntime, if other modules use this field, should try yo delete the header + err = json.Unmarshal(s[commonconstant.LibruntimeHeaderSize:], cfg) + if err != nil { + log.GetLogger().Errorf("Failed to unmarshal scheduler config: %v, value: %s", err, s) + return false + } + } + oldSign := controllerutils.GetSchedulerConfigSignature(cfg) + if oldSign == "" { + log.GetLogger().Errorf("old sign is empty, insID:%s", meta.InstanceID) + return false + } + log.GetLogger().Infof("scheduler(%s) sign: %s, expect sign: %s", meta.InstanceID, oldSign, targetSign) + return strings.Compare(oldSign, targetSign) == 0 +} + +// CreateExpectedInstanceCount create expected scheduler instance count +func (s *SchedulerManager) CreateExpectedInstanceCount(ctx context.Context) error { + // 此方法在启动的时候调用一次, 目的是将scheduler实例数量补充至设置的实例数 + // 不需要删除多余实例, 多余实例会在 HandleInstanceUpdate 中删除 + // 因为需要保证拉起的scheduler的instanceID不变, 该函数需要同步FaasInstance缓存 + s.Lock() + currentCount := len(s.instanceCache) + expectedCount := s.count - currentCount + recoverIns := state.GetState().FaasInstance[types.FaasSchedulerInstanceState+s.tenantID] + + // 这里将已有实例全部加到FaasInstance中, 保证FaasInstance的缓存包含已有实例(防止有些极端情况下FaasInstance缓存丢失) + for _, meta := range s.instanceCache { + _, exist := recoverIns[meta.InstanceID] + if !exist { + state.Update(meta.InstanceID, types.StateUpdate, types.FaasSchedulerInstanceState+s.tenantID) + } + } + + var insList []string + // 同步FaasInstance缓存 + // 如果已有实例数没有达到设置的值, 拉起新实例的instanceID需要从缓存中取出, 如果缓存中没有则新建 + // 如果已有实例数已经达到了设置的值, 将缓存中多余的instanceID删除 + for instanceID := range recoverIns { + if _, exist := s.instanceCache[instanceID]; !exist { + if expectedCount <= 0 { + state.Update(instanceID, types.StateDelete, types.FaasSchedulerInstanceState+s.tenantID) + continue + } + insList = append(insList, instanceID) + expectedCount-- + } + } + s.Unlock() + wg := sync.WaitGroup{} + for _, insID := range insList { + go func(instanceID string) { + log.GetLogger().Infof("instance %s not exist, need recover for %s", instanceID, s.tenantID) + wg.Add(1) + time.Sleep(constant.RecreateSleepTime) + err := s.SyncCreateInstanceByID(ctx, instanceID) + wg.Done() + if err != nil { + return + } + }(insID) + } + err := s.CreateMultiInstances(ctx, expectedCount) + wg.Wait() + return err +} + +// CreateMultiInstances create multi instances +func (s *SchedulerManager) CreateMultiInstances(ctx context.Context, count int) error { + if count <= 0 { + log.GetLogger().Infof("no need to create scheduler instance, kill %d instances instead.", -count) + return s.KillExceptInstance(-count) + } + log.GetLogger().Infof("need to create %d faas scheduler instances", count) + + args, extraParams, err := genFunctionConfig(s.tenantID) + if err != nil { + return err + } + group := &sync.WaitGroup{} + var createErr error + for i := 0; i < count; i++ { + group.Add(1) + go func() { + defer group.Done() + if err = s.createOrRetry(ctx, args, *extraParams, + config.GetFaaSControllerConfig().EnableRetry); err != nil { + createErr = err + } + }() + } + group.Wait() + if createErr != nil { + return createErr + } + log.GetLogger().Infof("succeed to create %d faaS scheduler instances", count) + return nil +} + +func createExtraParams(conf *types.SchedulerConfig, tenantID string) (*commonTypes.ExtraParams, error) { + config.SchedulerConfigLock.RLock() + defer config.SchedulerConfigLock.RUnlock() + extraParams := &commonTypes.ExtraParams{} + extraParams.Resources = utils.GenerateResourcesMap(conf.CPU, conf.Memory) + extraParams.CustomExtensions = utils.CreateCustomExtensions(extraParams.CustomExtensions, + utils.MonopolyPolicyValue) + extraParams.ScheduleAffinities = utils.CreatePodAffinity(constant.SystemFuncName, constant.FuncNameFaasscheduler, + api.PreferredAntiAffinity) + utils.AddNodeSelector(conf.NodeSelector, extraParams) + encryptMap := map[string]string{config.MetaEtcdPwdKey: conf.MetaETCDConfig.Password} + encryptData, err := json.Marshal(encryptMap) + if err != nil { + log.GetLogger().Errorf("encryptData json marshal failed, err:%s", err.Error()) + return nil, err + } + createOptions := make(map[string]string) + createOptions = utils.CreateCreateOptions(createOptions, commonconstant.EnvDelegateEncrypt, string(encryptData)) + + delegateRuntime, err := json.Marshal(map[string]interface{}{ + "image": conf.Image, + }) + if err != nil { + return nil, err + } + createOptions[commonconstant.DelegateRuntimeManagerTag] = string(delegateRuntime) + createOptions[constant.InitCallTimeoutKey] = strconv.Itoa(int(types.CreatedTimeout.Seconds())) + createOptions[constant.ConcurrencyKey] = strconv.Itoa(constant.MaxConcurrency) + createOptions[commonconstant.InstanceLifeCycle] = commonconstant.InstanceLifeCycleDetached + createOptions[constant.SchedulerExclusivity] = tenantID + createOptions[commonconstant.DelegateNodeAffinity] = conf.NodeAffinity + createOptions[commonconstant.DelegateNodeAffinityPolicy] = conf.NodeAffinityPolicy + if conf.Affinity != "" { + createOptions[commonconstant.DelegateAffinity] = conf.Affinity + } + makePodLabel(createOptions, tenantID) + if config.GetFaaSControllerConfig().RawStsConfig.StsEnable { + secretVolumeMounts, err := sts.GenerateSecretVolumeMounts(sts.FaaSSchedulerName, utils.NewVolumeBuilder()) + if err != nil { + return nil, err + } + createOptions = utils.CreateCreateOptions(createOptions, commonconstant.DelegateVolumeMountKey, + string(secretVolumeMounts)) + } + if config.GetFaaSSchedulerConfig().ConcurrentNum == 0 { + createOptions[constant.ConcurrentNumKey] = strconv.Itoa(constant.DefaultConcurrentNum) + } else { + createOptions[constant.ConcurrentNumKey] = strconv.Itoa(config.GetFaaSSchedulerConfig().ConcurrentNum) + } + err = prepareVolumesAndMounts(conf, createOptions) + if err != nil { + return nil, err + } + + extraParams.CreateOpt = createOptions + extraParams.Label = []string{constant.FuncNameFaasscheduler} + return extraParams, nil +} + +func prepareVolumesAndMounts(conf *types.SchedulerConfig, createOptions map[string]string) error { + if conf == nil || createOptions == nil { + return fmt.Errorf("parameter of prepareVolumesAndMounts is nil") + } + var delegateVolumes string + var delegateVolumesMounts string + var delegateInitVolumesMounts string + builder := utils.NewVolumeBuilder() + + if conf.RawStsConfig.StsEnable { + delegateVolumesMountsData, err := sts.GenerateSecretVolumeMounts(sts.FaaSSchedulerName, builder) + if err != nil { + return err + } + delegateVolumesMounts = string(delegateVolumesMountsData) + } + + delegateInitVolumesMounts = delegateVolumesMounts + log.GetLogger().Debugf("delegateVolumes: %s, delegateVolumesMounts: %s, delegateInitVolumesMounts: %s", + delegateVolumes, delegateVolumesMounts, delegateInitVolumesMounts) + if delegateVolumes != "" { + createOptions[commonconstant.DelegateVolumesKey] = delegateVolumes + } + if delegateVolumesMounts != "" { + createOptions[commonconstant.DelegateVolumeMountKey] = delegateVolumesMounts + } + if delegateInitVolumesMounts != "" { + createOptions[commonconstant.DelegateInitVolumeMountKey] = delegateInitVolumesMounts + } + return nil +} + +func (s *SchedulerManager) createOrRetry(ctx context.Context, args []api.Arg, extraParams commonTypes.ExtraParams, + enableRetry bool) error { + // 只有首次拉起/扩容时insID为空,需要在拉起失败时防止进入失败回调中的重试逻辑 + if extraParams.DesignatedInstanceID == "" { + instanceID := uuid.New().String() + s.recreateInstanceIDMap.Store(instanceID, nil) + extraParams.DesignatedInstanceID = instanceID + defer s.recreateInstanceIDMap.Delete(instanceID) + } + if enableRetry { + err := s.CreateWithRetry(ctx, args, &extraParams) + if err != nil { + return err + } + } else { + instanceID := s.CreateInstance(ctx, types.FaaSSchedulerFunctionKey, args, &extraParams) + if instanceID == "" { + log.GetLogger().Errorf("failed to create scheduler manager instance") + return errors.New("failed to create scheduler manager instance") + } + } + return nil +} + +// CreateWithRetry - +func (s *SchedulerManager) CreateWithRetry(ctx context.Context, args []api.Arg, + extraParams *commonTypes.ExtraParams) error { + err := wait.ExponentialBackoffWithContext(ctx, createInstanceBackoff, func(context.Context) (done bool, err error) { + instanceID := s.CreateInstance(ctx, types.FaaSSchedulerFunctionKey, args, extraParams) + if instanceID == "" { + log.GetLogger().Errorf("failed to create scheduler instance") + return false, nil + } + if instanceID == "cancelled" { + return true, fmt.Errorf("create has been cancelled") + } + return true, nil + }) + return err +} + +func makePodLabel(createOptions map[string]string, tenantID string) { + podLabels := map[string]string{ + constant.SystemFuncName: constant.FuncNameFaasscheduler, + constant.SchedulerExclusivity: tenantID, + } + labels, err := json.Marshal(podLabels) + if err != nil { + log.GetLogger().Warnf("faasscheduler label json marshal error : %s", err.Error()) + return + } + if createOptions == nil { + return + } + createOptions[commonconstant.DelegatePodLabels] = string(labels) +} + +// CreateInstance create scheduler instance +func (s *SchedulerManager) CreateInstance(ctx context.Context, function string, args []api.Arg, + extraParams *commonTypes.ExtraParams) string { + instanceID := extraParams.DesignatedInstanceID + funcMeta := api.FunctionMeta{FuncID: function, Api: api.PosixApi, Name: &instanceID} + log.GetLogger().Infof("start to create scheduler instance(FuncID=%s)", funcMeta.FuncID) + invokeOpts := api.InvokeOptions{ + Cpu: int(extraParams.Resources[controllerutils.ResourcesCPU]), + Memory: int(extraParams.Resources[controllerutils.ResourcesMemory]), + ScheduleAffinities: extraParams.ScheduleAffinities, + CustomExtensions: extraParams.CustomExtensions, + CreateOpt: extraParams.CreateOpt, + Labels: extraParams.Label, + Timeout: 150, + } + createCh := make(chan api.ErrorInfo, 1) + go func() { + _, createErr := s.sdkClient.CreateInstance(funcMeta, args, invokeOpts) + if createErr != nil { + if errorInfo, ok := createErr.(api.ErrorInfo); ok { + createCh <- errorInfo + } else { + createCh <- api.ErrorInfo{Code: commonconstant.KernelInnerSystemErrCode, Err: createErr} + } + return + } + createCh <- api.ErrorInfo{Code: api.Ok} + }() + timer := time.NewTimer(types.CreatedTimeout) + select { + case notifyErr, ok := <-createCh: + defer timer.Stop() + if !ok { + log.GetLogger().Errorf("result channel of scheduler instance request is closed") + return "" + } + if !notifyErr.IsOk() { + log.GetLogger().Errorf("failed to bring up scheduler instance(id=%s), code:%d, err:%s", instanceID, + notifyErr.Code, notifyErr.Error()) + s.clearInstanceAfterError(instanceID) + return "" + } + case <-timer.C: + log.GetLogger().Errorf("time out waiting for instance(id=%s) creation", instanceID) + s.clearInstanceAfterError(instanceID) + return "" + case <-ctx.Done(): + log.GetLogger().Errorf("instance(id=%s) creation has been cancelled", instanceID) + s.clearInstanceAfterError(instanceID) + return "cancelled" + } + s.addInstance(instanceID) + log.GetLogger().Infof("succeed to create scheduler instance(id=%s)", instanceID) + return instanceID +} + +func genFunctionConfig(tenantID string) ([]api.Arg, *commonTypes.ExtraParams, error) { + config.SchedulerConfigLock.RLock() + schedulerConfig := config.GetFaaSSchedulerConfig() + extraParams, err := createExtraParams(schedulerConfig, tenantID) + if err != nil { + config.SchedulerConfigLock.RUnlock() + log.GetLogger().Errorf("failed to prepare faasscheduler createExtraParams, err:%s", err.Error()) + return nil, nil, err + } + schedulerConf, err := json.Marshal(schedulerConfig) + config.SchedulerConfigLock.RUnlock() + if err != nil { + log.GetLogger().Errorf("faaSScheduler config json marshal failed, err:%s", err.Error()) + return nil, nil, err + } + schedulerArgs := []api.Arg{ + { + Type: api.Value, + Data: schedulerConf, + }, + } + return schedulerArgs, extraParams, nil +} + +// SyncCreateInstanceByID - +func (s *SchedulerManager) SyncCreateInstanceByID(ctx context.Context, instanceID string) error { + var cancel context.CancelFunc + if ctx == nil { + ctx, cancel = context.WithCancel(context.Background()) + } + _, exist := s.recreateInstanceIDMap.LoadOrStore(instanceID, cancel) + if exist { + log.GetLogger().Warnf("scheduler instance[%s] is creating", instanceID) + return nil + } + defer s.recreateInstanceIDMap.Delete(instanceID) + log.GetLogger().Infof("start to sync create scheduler instance:%s", instanceID) + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + args, extraParams, err := genFunctionConfig(s.tenantID) + if err != nil { + return err + } + extraParams.DesignatedInstanceID = instanceID + err = s.createOrRetry(ctx, args, *extraParams, config.GetFaaSControllerConfig().EnableRetry) + if err != nil { + return err + } + return nil +} + +// GetInstanceCountFromEtcd get current instance count from etcd +func (s *SchedulerManager) GetInstanceCountFromEtcd() map[string]struct{} { + resp, err := s.etcdClient.Client.Get(context.TODO(), types.FaaSSchedulerPrefixKey, clientv3.WithPrefix()) + if err != nil { + log.GetLogger().Errorf("failed to search etcd key, prefixKey=%s, err=%s", types.FaaSSchedulerPrefixKey, + err.Error()) + return nil + } + instanceIDs := make(map[string]struct{}, resp.Count) + for _, kv := range resp.Kvs { + instanceID := controllerutils.ExtractInfoFromEtcdKey(string(kv.Key), commonconstant.InstanceIDIndexForInstance) + if instanceID != "" { + instanceIDs[instanceID] = struct{}{} + } + } + log.GetLogger().Infof("get etcd scheduler instance count=%d, %+v", resp.Count, instanceIDs) + return instanceIDs +} + +func (s *SchedulerManager) clearInstanceAfterError(instanceID string) { + if err := s.sdkClient.Kill(instanceID, types.KillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to kill scheduler instance: %s", instanceID) + } +} + +// addInstance add an instance to cache +func (s *SchedulerManager) addInstance(instanceID string) { + s.Lock() + _, exists := s.instanceCache[instanceID] + if exists { + log.GetLogger().Warnf("the instance(id=%s) already exist for %s", instanceID, s.tenantID) + s.Unlock() + return + } + log.GetLogger().Infof("add instance(id=%s) to cache for %s", instanceID, s.tenantID) + s.instanceCache[instanceID] = &types.InstanceSpecification{InstanceID: instanceID} + s.Unlock() + state.Update(instanceID, types.StateUpdate, types.FaasSchedulerInstanceState+s.tenantID) +} + +// GetInstanceCache acquire instance cache +func (s *SchedulerManager) GetInstanceCache() map[string]*types.InstanceSpecification { + return s.instanceCache +} + +// SyncKillAllInstance kill all instances of system function, faaS scheduler +func (s *SchedulerManager) SyncKillAllInstance() { + s.Lock() + cache := s.instanceCache + s.instanceCache = map[string]*types.InstanceSpecification{} + s.Unlock() + var wg sync.WaitGroup + for instanceID := range cache { + wg.Add(1) + go func(instanceID string) { + defer wg.Done() + if err := s.sdkClient.Kill(instanceID, types.SyncKillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to kill scheduler instance(id=%s), err:%s", instanceID, err.Error()) + return + } + log.GetLogger().Infof("success to kill scheduler instance(id=%s)", instanceID) + }(instanceID) + } + wg.Wait() +} + +// KillInstance kill a scheduler instance +func (s *SchedulerManager) KillInstance(instanceID string) error { + log.GetLogger().Infof("start to kill instance %s", instanceID) + return wait.ExponentialBackoffWithContext( + context.Background(), createInstanceBackoff, func(context.Context) (bool, error) { + var err error + err = s.sdkClient.Kill(instanceID, types.KillSignalVal, []byte{}) + if err != nil && !strings.Contains(err.Error(), "instance not found") { + log.GetLogger().Warnf("failed to kill instanceID: %s, err: %s", instanceID, err.Error()) + return false, nil + } + return true, nil + }) +} + +// KillExceptInstance - +func (s *SchedulerManager) KillExceptInstance(count int) error { + if len(s.instanceCache) < count { + return nil + } + for instanceID := range s.instanceCache { + if count <= 0 { + return nil + } + if err := s.KillInstance(instanceID); err != nil { + log.GetLogger().Errorf("kill frontend instance:%s, error:%s", instanceID, err.Error()) + return err + } + count-- + } + return nil +} + +// RecoverInstance recover a faaS scheduler instance when faults occur +func (s *SchedulerManager) RecoverInstance(info *types.InstanceSpecification) { + err := s.KillInstance(info.InstanceID) + if err != nil { + log.GetLogger().Warnf("failed to kill instanceID: %s, err: %s", info.InstanceID, err.Error()) + } +} + +// HandleInstanceUpdate handles function update +func (s *SchedulerManager) HandleInstanceUpdate(schedulerSpec *types.InstanceSpecification) { + // 有多种情况会调进此函数 + // 1. 创建实例回调, 新建实例会多次进入此方法 + // 2. controller首次监听etcd, etcd当前存量实例会进入此方法 + + // 理论上进入到这个方法的实例只有以下两种 + // 1. 版本正确的实例, 需要写入缓存, 并判断实例个数, 超出的需要删除, 未超出不做处理 + // 2. 版本不正确的实例有两种情况: + // (1) 该实例存在于terminating缓存中, 实例可能在进行滚动更新时, 这时应该不进行处理 + // (2) 该实例不在terminating缓存中, controller重启时, 更新了实例版本, 首次监听到了老的版本实例, 这时应该直接将实例删除 + log.GetLogger().Infof("handling scheduler instance %s update", schedulerSpec.InstanceID) + if schedulerSpec.InstanceSpecificationMeta.InstanceStatus.Code == int(commonconstant.KernelInstanceStatusExiting) { + log.GetLogger().Infof("scheduler instance %s is exiting,no need to update", schedulerSpec.InstanceID) + return + } + config.SchedulerConfigLock.RLock() + signature := controllerutils.GetSchedulerConfigSignature(config.GetFaaSSchedulerConfig()) + config.SchedulerConfigLock.RUnlock() + if isExpectInstance(&schedulerSpec.InstanceSpecificationMeta, signature) { + s.Lock() + currentNum := len(s.instanceCache) + _, exist := s.instanceCache[schedulerSpec.InstanceID] + if currentNum > s.count || (currentNum == s.count && !exist) { + log.GetLogger().Infof("current scheduler num is %d, kill the new instance %s for %s with cache: %v", + currentNum, schedulerSpec.InstanceID, s.tenantID, s.instanceCache) + delete(s.instanceCache, schedulerSpec.InstanceID) + s.Unlock() + if err := s.sdkClient.Kill(schedulerSpec.InstanceID, types.KillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to kill instance %s error:%s", schedulerSpec.InstanceID, err.Error()) + } + return + } + // add instance to cache if not exist, otherwise update the instance + if !exist { + log.GetLogger().Infof("instance %s has been added to cache", schedulerSpec.InstanceID) + s.instanceCache[schedulerSpec.InstanceID] = schedulerSpec + state.Update(schedulerSpec.InstanceID, types.StateUpdate, types.FaasSchedulerInstanceState+s.tenantID) + s.Unlock() + return + } + s.instanceCache[schedulerSpec.InstanceID].InstanceSpecificationMeta = schedulerSpec.InstanceSpecificationMeta + log.GetLogger().Infof("scheduler instance %s is updated, refresh instance cache", + schedulerSpec.InstanceID) + s.Unlock() + return + } + s.RLock() + _, exist := s.terminalCache[schedulerSpec.InstanceID] + s.RUnlock() + if !exist { + if err := s.KillInstance(schedulerSpec.InstanceID); err != nil { + log.GetLogger().Errorf("failed to kill instance %s error:%s", schedulerSpec.InstanceID, err.Error()) + } + } +} + +// HandleInstanceDelete handles function delete +func (s *SchedulerManager) HandleInstanceDelete(schedulerSpec *types.InstanceSpecification) { + // Here will depend on the kernel recover capability, not need to maintain the number of instances + log.GetLogger().Infof("handling scheduler instance %s delete", schedulerSpec.InstanceID) + config.SchedulerConfigLock.RLock() + signature := controllerutils.GetSchedulerConfigSignature(config.GetFaaSSchedulerConfig()) + config.SchedulerConfigLock.RUnlock() + s.Lock() + delete(s.instanceCache, schedulerSpec.InstanceID) + s.Unlock() + if isExpectInstance(&schedulerSpec.InstanceSpecificationMeta, signature) { + s.RLock() + if len(s.instanceCache) < s.count { + log.GetLogger().Infof("current faasscheduler instance num is %d, need to recreate instance: %s", + len(s.instanceCache), schedulerSpec.InstanceID) + go func() { + err := s.SyncCreateInstanceByID(nil, schedulerSpec.InstanceID) + if err != nil { + log.GetLogger().Errorf("failed to create instance: %s, err:%v", schedulerSpec.InstanceID, + err) + } + }() + s.RUnlock() + return + } else { + state.Update(schedulerSpec.InstanceID, types.StateDelete, types.FaasSchedulerInstanceState+s.tenantID) + } + s.RUnlock() + } + cancel, exist := s.recreateInstanceIDMap.Load(schedulerSpec.InstanceID) + if exist { + if cancelFunc, ok := cancel.(context.CancelFunc); ok { + cancelFunc() + log.GetLogger().Infof("instance %s bring up has been canceled", schedulerSpec.InstanceID) + return + } + log.GetLogger().Errorf("get cancel func failed from instanceIDMap, instanceID:%s", + schedulerSpec.InstanceID) + } +} + +// RollingUpdate rolling update +func (s *SchedulerManager) RollingUpdate(ctx context.Context, event *types.ConfigChangeEvent) { + // 1. 更新 预期实例个数 + // 2. 把不符合预期的instanceCache -> terminalCache + // 3. 从terminalCache随机删除一个实例 同步 + // 4. 创建新实例 同步 + // 5. terminalCache为空时将实例数调谐为预期实例数(同步), instanceCache到达预期数量时清空terminalCache(异步) + // !!!scheduler启动时需要等待hash环上的所有scheduler全部启动才能初始化完成,所以当正在运行的scheduler数和预期拉起的数量不符时,需要全量升级 + newSign := controllerutils.GetSchedulerConfigSignature(event.SchedulerCfg) + s.Lock() + runningNum := 0 + s.count = event.SchedulerCfg.SchedulerNum + // exclusivity tenant only need 1 scheduler instance + if s.tenantID != "" { + s.count = 1 + } + for _, ins := range s.instanceCache { + if ins.InstanceSpecificationMeta.InstanceStatus.Code == int(commonconstant.KernelInstanceStatusRunning) { + runningNum++ + } + if !isExpectInstance(&ins.InstanceSpecificationMeta, newSign) { + log.GetLogger().Infof("instance:%s is waiting for termination", ins.InstanceID) + s.terminalCache[ins.InstanceID] = ins + delete(s.instanceCache, ins.InstanceID) + } + } + s.Unlock() + if runningNum != s.count { + log.GetLogger().Infof("running scheduler:%d, expected:%d, need full update", runningNum, s.count) + err := s.fullUpdate(ctx) + event.Error = err + event.Done() + return + } + for { + select { + case <-ctx.Done(): + event.Error = fmt.Errorf("rolling update has stopped") + event.Done() + return + default: + } + s.RLock() + if len(s.instanceCache) == s.count { + s.RUnlock() + log.GetLogger().Infof("instance count arrive at expectation:%d, delete all terminating instance", + s.count) + go s.killAllTerminalInstance() + event.Done() + return + } + if len(s.terminalCache) == 0 { + s.RUnlock() + log.GetLogger().Infof("no instance need to terminate, pull up missing instance count:%d", + s.count-len(s.instanceCache)) + err := s.CreateExpectedInstanceCount(ctx) + if err != nil { + event.Error = err + } + event.Done() + return + } + insID := "" + for _, ins := range s.terminalCache { + insID = ins.InstanceID + break + } + s.RUnlock() + var err error + if err = s.sdkClient.Kill(insID, types.SyncKillSignalVal, []byte{}); err != nil { + log.GetLogger().Errorf("failed to sync kill scheduler instance(id=%s), err:%v", insID, err) + } + s.Lock() + delete(s.terminalCache, insID) + s.Unlock() + time.Sleep(constant.RecreateSleepTime) + err = s.SyncCreateInstanceByID(ctx, insID) + if err != nil { + event.Error = err + event.Done() + return + } + } +} + +func (s *SchedulerManager) killAllTerminalInstance() { + s.Lock() + insMap := s.terminalCache + s.terminalCache = map[string]*types.InstanceSpecification{} + s.Unlock() + var wg sync.WaitGroup + for _, ins := range insMap { + insID := ins.InstanceID + wg.Add(1) + go func() { + defer wg.Done() + err := s.KillInstance(insID) + if err != nil { + log.GetLogger().Errorf("Failed to kill instance %v,err: %v", insID, err) + } + }() + } + wg.Wait() +} + +func (s *SchedulerManager) configChangeProcessor(ctx context.Context, cancel context.CancelFunc) { + if ctx == nil || cancel == nil || s.ConfigChangeCh == nil { + return + } + for { + select { + case cfgEvent, ok := <-s.ConfigChangeCh: + if !ok { + cancel() + return + } + schedulerConfig, schedulerInsNum := s.ConfigDiff(cfgEvent) + if schedulerConfig != nil || schedulerInsNum != -1 { + log.GetLogger().Infof("scheduler config or instance num is changed," + + " need to update scheduler instance") + cancel() + ctx, cancel = context.WithCancel(context.Background()) + config.UpdateSchedulerConfig(cfgEvent.SchedulerCfg) + cfgEvent.Add(1) + go s.RollingUpdate(ctx, cfgEvent) + } else { + log.GetLogger().Infof("scheduler config is same as current, no need to update") + } + cfgEvent.Done() + } + } +} + +// ConfigDiff config diff +func (s *SchedulerManager) ConfigDiff(event *types.ConfigChangeEvent) (*types.SchedulerConfig, int) { + newSign := controllerutils.GetSchedulerConfigSignature(event.SchedulerCfg) + config.SchedulerConfigLock.RLock() + schedulerOldCfg := config.GetFaaSSchedulerConfig() + config.SchedulerConfigLock.RUnlock() + if strings.Compare(newSign, + controllerutils.GetSchedulerConfigSignature(schedulerOldCfg)) == 0 { + if event.SchedulerCfg.SchedulerNum != schedulerOldCfg.SchedulerNum { + config.SchedulerConfigLock.Lock() + schedulerOldCfg.SchedulerNum = event.SchedulerCfg.SchedulerNum + config.SchedulerConfigLock.Unlock() + return nil, event.SchedulerCfg.SchedulerNum + } + return nil, -1 + } + return event.SchedulerCfg, event.SchedulerCfg.SchedulerNum +} + +func (s *SchedulerManager) fullUpdate(ctx context.Context) error { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + s.killAllTerminalInstance() + }() + wg.Add(1) + go func() { + defer wg.Done() + s.SyncKillAllInstance() + }() + wg.Wait() + time.Sleep(constant.RecreateSleepTime) + return s.CreateExpectedInstanceCount(ctx) +} diff --git a/yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager/schedulermanager_test.go b/yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager/schedulermanager_test.go new file mode 100644 index 0000000..2fbb11f --- /dev/null +++ b/yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager/schedulermanager_test.go @@ -0,0 +1,780 @@ +package faasschedulermanager + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "k8s.io/apimachinery/pkg/util/wait" + "reflect" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/api/v3/mvccpb" + "go.etcd.io/etcd/client/v3" + "yuanrong.org/kernel/runtime/libruntime/api" + + commonconstant "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + commontype "yuanrong/pkg/common/faas_common/types" + "yuanrong/pkg/common/faas_common/utils" + mockUtils "yuanrong/pkg/common/faas_common/utils" + stypes "yuanrong/pkg/functionscaler/types" + fcConfig "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/state" + "yuanrong/pkg/system_function_controller/types" +) + +type KvMock struct { +} + +func (k *KvMock) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + response := &clientv3.GetResponse{} + response.Count = 10 + return response, nil +} + +func (k *KvMock) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, + error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Txn(ctx context.Context) clientv3.Txn { + // TODO implement me + panic("implement me") +} + +func initConfig(configString string) { + fcConfig.InitConfig([]byte(configString)) + routerEtcdConfig := etcd3.EtcdConfig{ + Servers: []string{"1.2.3.4:1234"}, + User: "tom", + Password: "**", + } + + metaEtcdConfig := etcd3.EtcdConfig{ + Servers: []string{"1.2.3.4:5678"}, + User: "tom", + Password: "**", + } + + schedulerBasicConfig := types.SchedulerConfig{ + Configuration: stypes.Configuration{ + CPU: 999, + Memory: 999, + AutoScaleConfig: stypes.AutoScaleConfig{ + SLAQuota: 1000, + ScaleDownTime: 60000, + BurstScaleNum: 1000, + }, + LeaseSpan: 600000, + RouterETCDConfig: routerEtcdConfig, + MetaETCDConfig: metaEtcdConfig, + }, + SchedulerNum: 10, + } + fcConfig.UpdateSchedulerConfig(&schedulerBasicConfig) +} + +func newFaaSSchedulerManager(size int) (*SchedulerManager, error) { + stopCh := make(chan struct{}) + initConfig(`{ + "routerEtcd": { + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd": { + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + }, + "alarmConfig":{"enableAlarm": true} + } + `) + + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &etcd3.EtcdClient{Client: client} + manager := NewFaaSSchedulerManager(&mockUtils.FakeLibruntimeSdkClient{}, etcdClient, stopCh, size, "") + time.Sleep(50 * time.Millisecond) + manager.count = size + manager.instanceCache = make(map[string]*types.InstanceSpecification) + manager.terminalCache = map[string]*types.InstanceSpecification{} + return manager, nil +} + +func newFaaSSchedulerManagerWithRetry(size int) (*SchedulerManager, error) { + stopCh := make(chan struct{}) + initConfig(`{ + "routerEtcd": { + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd": { + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + }, + "alarmConfig":{"enableAlarm": true}, + "enableRetry": true + } + `) + + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &etcd3.EtcdClient{Client: client} + manager := NewFaaSSchedulerManager(&mockUtils.FakeLibruntimeSdkClient{}, etcdClient, stopCh, size, "") + time.Sleep(50 * time.Millisecond) + manager.count = size + manager.instanceCache = make(map[string]*types.InstanceSpecification) + manager.terminalCache = map[string]*types.InstanceSpecification{} + return manager, nil +} + +func TestNewInstanceManager(t *testing.T) { + Convey("Test NewInstanceManager", t, func() { + Convey("Test NewInstanceManager with correct size", func() { + got, err := newFaaSSchedulerManager(16) + So(err, ShouldBeNil) + So(got, ShouldNotBeNil) + }) + }) +} + +func Test_initInstanceCache(t *testing.T) { + Convey("initInstanceCache", t, func() { + manager, err := newFaaSSchedulerManager(1) + So(err, ShouldBeNil) + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &etcd3.EtcdClient{Client: client} + defer ApplyMethod(reflect.TypeOf(kv), "Get", + func(k *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + bytes, _ := json.Marshal(fcConfig.GetFaaSSchedulerConfig()) + meta := types.InstanceSpecificationMeta{Function: "test-function", InstanceID: "123", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}} + marshal, _ := json.Marshal(meta) + kvs := []*mvccpb.KeyValue{{Value: marshal}} + response := &clientv3.GetResponse{Kvs: kvs} + response.Count = 1 + return response, nil + }).Reset() + manager.initInstanceCache(etcdClient) + cache := manager.GetInstanceCache() + So(cache["123"], ShouldNotBeNil) + close(manager.stopCh) + }) +} + +func TestInstanceManager_CreateMultiInstances(t *testing.T) { + defer ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + Convey("Test CreateMultiInstances", t, func() { + Convey("Test CreateMultiInstances with retry", func() { + p := ApplyFunc((*mockUtils.FakeLibruntimeSdkClient).CreateInstance, func(_ *mockUtils.FakeLibruntimeSdkClient, _ api.FunctionMeta, _ []api.Arg, _ api.InvokeOptions) (string, error) { + time.Sleep(100 * time.Millisecond) + return "", api.ErrorInfo{ + Code: 10, + Err: fmt.Errorf("xxxxx"), + StackTracesInfo: api.StackTracesInfo{}, + } + }) + p2 := ApplyFunc(wait.ExponentialBackoffWithContext, func(ctx context.Context, backoff wait.Backoff, condition wait.ConditionWithContextFunc) error { + _, err := condition(ctx) + return err + }) + manager, err := newFaaSSchedulerManagerWithRetry(1) + So(err, ShouldBeNil) + ctx, cancel := context.WithCancel(context.TODO()) + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + err = manager.CreateMultiInstances(ctx, 1) + So(err, ShouldBeError) + + err = manager.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldBeNil) + p.Reset() + p2.Reset() + }) + + instanceMgr, err := newFaaSSchedulerManager(3) + So(err, ShouldBeNil) + defer ApplyMethod(reflect.TypeOf(instanceMgr), "CreateWithRetry", + func(ffm *SchedulerManager, ctx context.Context, args []api.Arg, + extraParams *commontype.ExtraParams) error { + if value := ctx.Value("err"); value != nil { + if s := value.(string); s == "canceled" { + return fmt.Errorf("create has been cancelled") + } + } + return nil + }).Reset() + Convey("Test CreateMultiInstances when passed different count", func() { + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldBeNil) + + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldBeNil) + }) + + Convey("Test CreateMultiInstances when failed to get scheduler config", func() { + patches := ApplyFunc(json.Marshal, func(_ interface{}) ([]byte, error) { + return nil, errors.New("json Marshal failed") + }) + defer patches.Reset() + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldNotBeNil) + }) + + Convey("Test CreateMultiInstances when failed to create instance", func() { + patches := ApplyMethod(reflect.TypeOf(instanceMgr), "CreateInstance", + func(_ *SchedulerManager, ctx context.Context, function string, args []api.Arg, + extraParams *commontype.ExtraParams) string { + return "" + }) + defer patches.Reset() + err = instanceMgr.CreateMultiInstances(context.TODO(), 1) + So(err, ShouldNotBeNil) + }) + }) +} + +func TestCreateWithRetry(t *testing.T) { + Convey("CreateWithRetry", t, func() { + Convey("retry", func() { + instanceMgr, err := newFaaSSchedulerManager(1) + So(err, ShouldBeNil) + + defer ApplyMethod(reflect.TypeOf(instanceMgr), "CreateInstance", + func(s *SchedulerManager, ctx context.Context, function string, args []api.Arg, + extraParams *commontype.ExtraParams) string { + select { + case <-ctx.Done(): + return "cancelled" + default: + return "" + } + }).Reset() + args, extraParams, _ := genFunctionConfig("") + ctx, cancelFunc := context.WithCancel(context.TODO()) + go func() { + time.Sleep(2 * time.Second) + cancelFunc() + }() + err = instanceMgr.CreateWithRetry(ctx, args, extraParams) + So(err, ShouldBeError) + }) + }) +} + +func TestInstanceManager_SyncKillAllInstance(t *testing.T) { + Convey("KillAllInstance", t, func() { + Convey("success", func() { + instanceMgr, err := newFaaSSchedulerManager(1) + So(err, ShouldBeNil) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&mockUtils.FakeLibruntimeSdkClient{}), "Kill", + func(_ *mockUtils.FakeLibruntimeSdkClient, instanceID string, signal int, payload []byte) error { + return nil + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + + instanceMgr.instanceCache = map[string]*types.InstanceSpecification{ + "123": &types.InstanceSpecification{}, + } + instanceMgr.SyncKillAllInstance() + So(len(instanceMgr.instanceCache), ShouldEqual, 0) + _, exist := instanceMgr.instanceCache["123"] + So(exist, ShouldEqual, false) + }) + }) +} + +func TestInstanceManager_GetInstanceCountFromEtcd(t *testing.T) { + Convey("GetInstanceCountFromEtcd", t, func() { + Convey("failed", func() { + instanceMgr, err := newFaaSSchedulerManager(1) + So(err, ShouldBeNil) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&KvMock{}), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, + error) { + return nil, errors.New("get etcd error") + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + count := instanceMgr.GetInstanceCountFromEtcd() + So(len(count), ShouldEqual, 0) + }) + }) + +} + +func TestInstanceManager_CreateExpectedInstanceCount(t *testing.T) { + Convey("CreateExpectedInstanceCount", t, func() { + Convey("success", func() { + instanceMgr, err := newFaaSSchedulerManager(2) + So(err, ShouldBeNil) + patches := []*Patches{ + ApplyMethod(reflect.TypeOf(&KvMock{}), "Get", + func(_ *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, + error) { + response := &clientv3.GetResponse{} + response.Count = 2 + response.Kvs = []*mvccpb.KeyValue{ + { + Key: []byte("/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasscheduler/version/$latest/defaultaz/task-66ccf050-50f6-4835-aa24-c1b15dddb00e/12996c08-0000-4000-8000-db6c3db0fcbf"), + }, + { + Key: []byte("/sn/instance/business/yrk/tenant/12345678901234561234567890123456/function/0-system-faasscheduler/version/$latest/defaultaz/task-66ccf050-50f6-4835-aa24-c1b15dddb00e/12996c08-0000-4000-8000-db6c3db0fcb2"), + }, + } + return response, nil + }), + ApplyFunc(state.GetState, func() state.ControllerState { + controllerState := state.ControllerState{FaasInstance: make(map[string]map[string]struct{})} + controllerState.FaasInstance[types.FaasSchedulerInstanceState] = map[string]struct{}{"test-2": {}} + return controllerState + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + count := instanceMgr.GetInstanceCountFromEtcd() + So(len(count), ShouldEqual, 2) + + schedulerConfig := fcConfig.GetFaaSSchedulerConfig() + bytes, _ := json.Marshal(schedulerConfig) + instanceMgr.instanceCache["test-1"] = &types.InstanceSpecification{ + FuncCtx: nil, + CancelFunc: nil, + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{ + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}, + }, + } + err = instanceMgr.CreateExpectedInstanceCount(context.TODO()) + So(err, ShouldBeNil) + }) + }) +} + +func TestInstanceManager_RecoverInstance(t *testing.T) { + Convey("RecoverInstance", t, func() { + Convey("create failed", func() { + instanceMgr, err := newFaaSSchedulerManager(1) + So(err, ShouldBeNil) + patches := []*Patches{ + ApplyFunc((*SchedulerManager).KillInstance, func(_ *SchedulerManager, _ string) error { + return nil + }), + ApplyFunc((*SchedulerManager).CreateMultiInstances, + func(_ *SchedulerManager, ctx context.Context, _ int) error { + return errors.New("failed to create instances") + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + + instanceMgr.RecoverInstance(&types.InstanceSpecification{}) + }) + }) +} + +func TestInstanceManager_HandleInstanceUpdate(t *testing.T) { + defer ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + instanceID := "1" + notFoundInstanceID := "2" + instanceMgr, err := newFaaSSchedulerManager(2) + assert.Nil(t, err) + instanceMgr.addInstance(instanceID) + schedulerConfig := fcConfig.GetFaaSSchedulerConfig() + bytes, _ := json.Marshal(schedulerConfig) + Convey("Test HandleInstanceEvent", t, func() { + Convey("Test HandleInstanceEvent when instance exists", func() { + specification := &types.InstanceSpecification{ + FuncCtx: nil, + CancelFunc: nil, + InstanceID: instanceID, + InstanceSpecificationMeta: types.InstanceSpecificationMeta{ + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}, + }, + } + instanceMgr.HandleInstanceUpdate(specification) + So(reflect.DeepEqual(instanceMgr.instanceCache[instanceID], specification), ShouldBeTrue) + specification.InstanceSpecificationMeta.Function = "testFunction" + instanceMgr.HandleInstanceUpdate(specification) + So(reflect.DeepEqual(instanceMgr.instanceCache[instanceID], specification), ShouldBeTrue) + }) + + Convey("Test HandleInstanceEvent when instance not exists", func() { + specification := &types.InstanceSpecification{ + FuncCtx: nil, + CancelFunc: nil, + InstanceID: notFoundInstanceID, + InstanceSpecificationMeta: types.InstanceSpecificationMeta{ + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}, + }, + } + instanceMgr.HandleInstanceUpdate(specification) + So(reflect.DeepEqual(instanceMgr.instanceCache[notFoundInstanceID], specification), ShouldBeTrue) + }) + + Convey("Test HandleInstanceDelete when create extra instance", func() { + instanceMgr, err = newFaaSSchedulerManager(1) + assert.Nil(t, err) + defer ApplyMethod(reflect.TypeOf(instanceMgr.etcdClient.Client.KV), "Get", + func(k *KvMock, ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, + error) { + bytes, _ := json.Marshal(fcConfig.GetFaaSSchedulerConfig()) + meta := types.InstanceSpecificationMeta{Function: "test-function", InstanceID: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}} + marshal, _ := json.Marshal(meta) + kvs := []*mvccpb.KeyValue{{Value: marshal}} + response := &clientv3.GetResponse{Kvs: kvs} + response.Count = 1 + return response, nil + }).Reset() + instanceMgr.initInstanceCache(instanceMgr.etcdClient) + So(instanceMgr.instanceCache["test-function"], ShouldNotBeNil) + specification := &types.InstanceSpecification{ + FuncCtx: nil, + CancelFunc: nil, + InstanceID: instanceID, + InstanceSpecificationMeta: types.InstanceSpecificationMeta{ + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}, + }, + } + instanceMgr.HandleInstanceUpdate(specification) + So(instanceMgr.instanceCache[instanceID], ShouldBeNil) + }) + + Convey("Test HandleInstanceUpdate when config change", func() { + instanceMgr, err = newFaaSSchedulerManager(1) + assert.Nil(t, err) + cfg := &types.SchedulerConfig{} + utils.DeepCopyObj(schedulerConfig, cfg) + cfg.CPU = 1000 + marshal, _ := json.Marshal(cfg) + specification := &types.InstanceSpecification{ + FuncCtx: nil, + CancelFunc: nil, + InstanceID: instanceID, + InstanceSpecificationMeta: types.InstanceSpecificationMeta{ + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(marshal)}}, + }, + } + instanceMgr.HandleInstanceUpdate(specification) + So(instanceMgr.instanceCache[instanceID], ShouldBeNil) + }) + }) +} + +func TestInstanceManager_HandleInstanceDelete(t *testing.T) { + defer ApplyFunc(state.Update, func(value interface{}, tags ...string) { + }).Reset() + Convey("Test HandleInstanceDelete", t, func() { + instanceID := "1" + notFoundInstanceID := "2" + instanceMgr, err := newFaaSSchedulerManager(2) + assert.Nil(t, err) + schedulerConfig := fcConfig.GetFaaSSchedulerConfig() + bytes, _ := json.Marshal(schedulerConfig) + Convey("Test HandleInstanceDelete when instance exist", func() { + instanceMgr.addInstance(instanceID) + specification := &types.InstanceSpecification{ + FuncCtx: nil, + CancelFunc: nil, + InstanceID: instanceID, + InstanceSpecificationMeta: types.InstanceSpecificationMeta{ + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}, + }, + } + instanceMgr.HandleInstanceDelete(specification) + So(len(instanceMgr.instanceCache), ShouldEqual, 0) + }) + + Convey("Test HandleInstanceDelete when instance not exist", func() { + specification := &types.InstanceSpecification{ + FuncCtx: nil, + CancelFunc: nil, + InstanceID: notFoundInstanceID, + InstanceSpecificationMeta: types.InstanceSpecificationMeta{}, + } + instanceMgr.HandleInstanceDelete(specification) + So(len(instanceMgr.instanceCache), ShouldEqual, 0) + }) + + Convey("Test HandleInstanceDelete when config change", func() { + instanceMgr, err = newFaaSSchedulerManager(1) + assert.Nil(t, err) + cfg := &types.SchedulerConfig{} + utils.DeepCopyObj(schedulerConfig, cfg) + cfg.CPU = 1000 + marshal, _ := json.Marshal(cfg) + specification := &types.InstanceSpecification{ + FuncCtx: nil, + CancelFunc: nil, + InstanceID: instanceID, + InstanceSpecificationMeta: types.InstanceSpecificationMeta{ + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(marshal)}}, + }, + } + instanceMgr.instanceCache[instanceID] = &types.InstanceSpecification{} + ctx, cancelFunc := context.WithCancel(context.TODO()) + instanceMgr.recreateInstanceIDMap.Store(instanceID, cancelFunc) + instanceMgr.HandleInstanceDelete(specification) + <-ctx.Done() + }) + }) +} + +func TestKillExceptInstance(t *testing.T) { + Convey("KillExceptInstance", t, func() { + instanceMgr, err := newFaaSSchedulerManager(1) + So(err, ShouldBeNil) + instanceMgr.instanceCache["testID"] = &types.InstanceSpecification{} + err = instanceMgr.KillExceptInstance(1) + So(err, ShouldBeNil) + }) +} + +func TestRollingUpdate(t *testing.T) { + Convey("test RollingUpdate", t, func() { + manager, err := newFaaSSchedulerManager(2) + So(err, ShouldBeNil) + Convey("same config", func() { + schedulerConfig := fcConfig.GetFaaSSchedulerConfig() + bytes, _ := json.Marshal(schedulerConfig) + manager.instanceCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + manager.instanceCache["test-2"] = &types.InstanceSpecification{ + InstanceID: "test-2", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + cfgEvent := &types.ConfigChangeEvent{ + SchedulerCfg: fcConfig.GetFaaSSchedulerConfig(), + TraceID: "traceID-123456789", + } + cfgEvent.SchedulerCfg.SchedulerNum = 2 + cfgEvent.Add(1) + manager.ConfigChangeCh <- cfgEvent + cfgEvent.Wait() + So(len(manager.instanceCache), ShouldEqual, cfgEvent.SchedulerCfg.SchedulerNum) + So(manager.instanceCache["test-1"], ShouldNotBeNil) + So(manager.instanceCache["test-2"], ShouldNotBeNil) + }) + + Convey("different config", func() { + schedulerConfig := fcConfig.GetFaaSSchedulerConfig() + bytes, _ := json.Marshal(schedulerConfig) + manager.instanceCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}, + InstanceStatus: types.InstanceStatus{Code: int(commonconstant.KernelInstanceStatusRunning)}}} + manager.instanceCache["test-2"] = &types.InstanceSpecification{ + InstanceID: "test-2", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}, + InstanceStatus: types.InstanceStatus{Code: int(commonconstant.KernelInstanceStatusRunning)}, + }} + cfg := &types.SchedulerConfig{} + utils.DeepCopyObj(fcConfig.GetFaaSSchedulerConfig(), cfg) + cfgEvent := &types.ConfigChangeEvent{ + SchedulerCfg: cfg, + TraceID: "traceID-123456789", + } + cfgEvent.SchedulerCfg.CPU = 888 + cfgEvent.SchedulerCfg.SchedulerNum = 2 + cfgEvent.Add(1) + manager.ConfigChangeCh <- cfgEvent + cfgEvent.Wait() + time.Sleep(2 * time.Second) + close(manager.ConfigChangeCh) + So(len(manager.terminalCache), ShouldEqual, 0) + So(len(manager.instanceCache), ShouldEqual, cfgEvent.SchedulerCfg.SchedulerNum) + }) + + Convey("killAllTerminalInstance", func() { + schedulerConfig := fcConfig.GetFaaSSchedulerConfig() + bytes, _ := json.Marshal(schedulerConfig) + manager.instanceCache = make(map[string]*types.InstanceSpecification) + manager.instanceCache["test-1"] = &types.InstanceSpecification{ + InstanceID: "test-1", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + manager.instanceCache["test-2"] = &types.InstanceSpecification{ + InstanceID: "test-2", + InstanceSpecificationMeta: types.InstanceSpecificationMeta{Function: "test-function", + Args: []map[string]string{{"value": base64.StdEncoding.EncodeToString(bytes)}}}} + cfgEvent := &types.ConfigChangeEvent{ + SchedulerCfg: fcConfig.GetFaaSSchedulerConfig(), + TraceID: "traceID-123456789", + } + cfgEvent.SchedulerCfg.SchedulerNum = 2 + cfgEvent.Add(1) + manager.RollingUpdate(context.TODO(), cfgEvent) + cfgEvent.Wait() + So(len(manager.instanceCache), ShouldEqual, 2) + + cfgEvent.SchedulerCfg.SchedulerNum = 2 + cfgEvent.Add(1) + cancel, cancelFunc := context.WithCancel(context.TODO()) + cancelFunc() + manager.RollingUpdate(cancel, cfgEvent) + cfgEvent.Wait() + So(len(manager.instanceCache), ShouldEqual, 0) + + cfgEvent.SchedulerCfg.SchedulerNum = 0 + cfgEvent.Add(1) + manager.RollingUpdate(context.TODO(), cfgEvent) + cfgEvent.Wait() + So(len(manager.instanceCache), ShouldEqual, 0) + }) + }) +} + +func TestConfigDiff(t *testing.T) { + Convey("ConfigDiff", t, func() { + manager, err := newFaaSSchedulerManager(2) + So(err, ShouldBeNil) + Convey("same config ,different num", func() { + cfg := &types.SchedulerConfig{} + utils.DeepCopyObj(fcConfig.GetFaaSSchedulerConfig(), cfg) + cfgEvent := &types.ConfigChangeEvent{ + SchedulerCfg: cfg, + TraceID: "traceID-123456789", + } + cfgEvent.SchedulerCfg.SchedulerNum = 100 + _, num := manager.ConfigDiff(cfgEvent) + So(num, ShouldEqual, 100) + }) + }) +} + +func TestSchedulerManager_KillInstance(t *testing.T) { + Convey("kill instance test", t, func() { + Convey("baseline", func() { + manager, err := newFaaSSchedulerManager(2) + So(err, ShouldBeNil) + i := 0 + p := ApplyFunc((*mockUtils.FakeLibruntimeSdkClient).Kill, + func(_ *mockUtils.FakeLibruntimeSdkClient, instanceID string, _ int, _ []byte) error { + if i == 0 { + i = 1 + return fmt.Errorf("error") + } + return nil + }) + defer p.Reset() + err = manager.KillInstance("aaa") + So(err, ShouldBeNil) + }) + }) +} + +func Test_createExtraParams_for_InstanceLifeCycle(t *testing.T) { + Convey("createExtraParams for InstanceLifeCycle", t, func() { + conf := &types.SchedulerConfig{ + Configuration: stypes.Configuration{ + CPU: 1332, + Memory: 324134, + }, + SchedulerNum: 0, + } + newFaaSSchedulerManager(2) + params, err := createExtraParams(conf, "") + So(err, ShouldBeNil) + So(params.CreateOpt[commonconstant.InstanceLifeCycle], ShouldEqual, commonconstant.InstanceLifeCycleDetached) + }) +} + +func Test_prepareCreateOptions_for_NodeAffinity(t *testing.T) { + Convey("testInstanceNodeAffinity", t, func() { + tt := []struct { + name string + nodeAffinity string + nodeAffinityPolicy string + }{ + { + name: "case1", + nodeAffinity: "", + nodeAffinityPolicy: "", + }, + { + name: "case2", + nodeAffinity: "{\"requiredDuringSchedulingIgnoredDuringExecution\":{\"nodeSelectorTerms\":[{\"matchExpressions\":[{\"key\":\"node-role\",\"operator\":\"In\",\"values\":[\"edge\"]}]}]}}", + nodeAffinityPolicy: commonconstant.DelegateNodeAffinityPolicyCoverage, + }, + { + name: "case3", + nodeAffinity: "{\"requiredDuringSchedulingIgnoredDuringExecution\":{\"nodeSelectorTerms\":[{\"matchExpressions\":[{\"key\":\"node-role\",\"operator\":\"In\",\"values\":[\"edge-tagw\"]}]}]}}", + nodeAffinityPolicy: commonconstant.DelegateNodeAffinityPolicyCoverage, + }, + } + conf := &types.SchedulerConfig{ + Configuration: stypes.Configuration{ + CPU: 1332, + Memory: 324134, + }, + SchedulerNum: 0, + } + for _, ttt := range tt { + newFaaSSchedulerManager(2) + conf.NodeAffinityPolicy = ttt.nodeAffinityPolicy + conf.NodeAffinity = ttt.nodeAffinity + params, err := createExtraParams(conf, "") + So(err, ShouldBeNil) + So(params.CreateOpt[commonconstant.InstanceLifeCycle], ShouldEqual, + commonconstant.InstanceLifeCycleDetached) + So(params.CreateOpt[commonconstant.DelegateNodeAffinityPolicy], ShouldEqual, ttt.nodeAffinityPolicy) + So(params.CreateOpt[commonconstant.DelegateNodeAffinity], ShouldEqual, ttt.nodeAffinity) + } + }) +} diff --git a/yuanrong/pkg/system_function_controller/instancemanager/instancemanager.go b/yuanrong/pkg/system_function_controller/instancemanager/instancemanager.go new file mode 100644 index 0000000..5cc0fbc --- /dev/null +++ b/yuanrong/pkg/system_function_controller/instancemanager/instancemanager.go @@ -0,0 +1,116 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package instancemanager - +package instancemanager + +import ( + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/system_function_controller/constant" + "yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager" + "yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager" + "yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager" + "yuanrong/pkg/system_function_controller/types" +) + +// InstanceManager define InstanceManager management, which manage all system function instance +type InstanceManager struct { + CommonSchedulerManager *faasschedulermanager.SchedulerManager + ExclusivitySchedulerManagers map[string]*faasschedulermanager.SchedulerManager + FrontendManager *faasfrontendmanager.FrontendManager + FunctionManager *faasfunctionmanager.FunctionManager +} + +// HandleEventUpdate handle update event +func (im *InstanceManager) HandleEventUpdate(functionSpec *types.InstanceSpecification, kind types.EventKind) error { + switch kind { + case types.EventKindFrontend: + im.FrontendManager.HandleInstanceUpdate(functionSpec) + case types.EventKindScheduler: + tenantID := "" + if functionSpec.InstanceSpecificationMeta.CreateOptions != nil { + tenantID = functionSpec.InstanceSpecificationMeta.CreateOptions[constant.SchedulerExclusivity] + } + if tenantID == "" { + im.CommonSchedulerManager.HandleInstanceUpdate(functionSpec) + } else { + schedulerManager, ok := im.ExclusivitySchedulerManagers[tenantID] + if ok && schedulerManager != nil { + schedulerManager.HandleInstanceUpdate(functionSpec) + } else { + // delete scheduler when it has no scheduler manager + im.CommonSchedulerManager.KillInstance(functionSpec.InstanceID) + } + } + case types.EventKindManager: + im.FunctionManager.HandleInstanceUpdate(functionSpec) + default: + log.GetLogger().Errorf("unknown event kind: %s", kind) + } + return nil +} + +// HandleEventDelete handle delete event +func (im *InstanceManager) HandleEventDelete(functionSpec *types.InstanceSpecification, kind types.EventKind) error { + switch kind { + case types.EventKindFrontend: + im.FrontendManager.HandleInstanceDelete(functionSpec) + case types.EventKindScheduler: + tenantID := "" + if functionSpec.InstanceSpecificationMeta.CreateOptions != nil { + tenantID = functionSpec.InstanceSpecificationMeta.CreateOptions[constant.SchedulerExclusivity] + } + if tenantID == "" { + im.CommonSchedulerManager.HandleInstanceDelete(functionSpec) + } else { + schedulerManager, ok := im.ExclusivitySchedulerManagers[tenantID] + if ok && schedulerManager != nil { + schedulerManager.HandleInstanceDelete(functionSpec) + } + } + case types.EventKindManager: + im.FunctionManager.HandleInstanceDelete(functionSpec) + default: + log.GetLogger().Errorf("unknown event kind: %s", kind) + } + return nil +} + +// HandleEventRecover handle recover event +func (im *InstanceManager) HandleEventRecover(functionSpec *types.InstanceSpecification, kind types.EventKind) error { + switch kind { + case types.EventKindFrontend: + im.FrontendManager.RecoverInstance(functionSpec) + case types.EventKindScheduler: + tenantID := "" + if functionSpec.InstanceSpecificationMeta.CreateOptions != nil { + tenantID = functionSpec.InstanceSpecificationMeta.CreateOptions[constant.SchedulerExclusivity] + } + if tenantID == "" { + im.CommonSchedulerManager.RecoverInstance(functionSpec) + } else { + schedulerManager, ok := im.ExclusivitySchedulerManagers[tenantID] + if ok && schedulerManager != nil { + schedulerManager.RecoverInstance(functionSpec) + } + } + case types.EventKindManager: + im.FunctionManager.RecoverInstance(functionSpec) + default: + log.GetLogger().Errorf("unknown event kind: %s", kind) + } + return nil +} diff --git a/yuanrong/pkg/system_function_controller/instancemanager/instancemanager_test.go b/yuanrong/pkg/system_function_controller/instancemanager/instancemanager_test.go new file mode 100644 index 0000000..c125c92 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/instancemanager/instancemanager_test.go @@ -0,0 +1,337 @@ +package instancemanager + +import ( + "context" + "reflect" + "testing" + "time" + + clientv3 "go.etcd.io/etcd/client/v3" + "yuanrong.org/kernel/runtime/libruntime/api" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/constant" + "yuanrong/pkg/system_function_controller/instancemanager/faasfrontendmanager" + "yuanrong/pkg/system_function_controller/instancemanager/faasfunctionmanager" + "yuanrong/pkg/system_function_controller/instancemanager/faasschedulermanager" + "yuanrong/pkg/system_function_controller/types" +) + +type KvMock struct { +} + +func (k *KvMock) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + response := &clientv3.GetResponse{} + response.Count = 10 + return response, nil +} + +func (k *KvMock) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + // TODO implement me + panic("implement me") +} + +func (k *KvMock) Txn(ctx context.Context) clientv3.Txn { + // TODO implement me + panic("implement me") +} + +func newInstanceManagerHelper(sdkClient api.LibruntimeAPI, stopCh chan struct{}, + frontendNum, schedulerNum int) (*InstanceManager, error) { + configString := `{ + "frontendInstanceNum": 100, + "schedulerInstanceNum": 100, + "faasschedulerConfig": { + "autoScaleConfig":{ + "slaQuota": 1000, + "scaleDownTime": 60000, + "burstScaleNum": 1000 + }, + "leaseSpan": 600000 + }, + "faasfrontendConfig": { + "slaQuota": 1000, + "functionCapability": 1, + "authenticationEnable": false, + "trafficLimitDisable": true, + "http": { + "resptimeout": 5, + "workerInstanceReadTimeOut": 5, + "maxRequestBodySize": 6 + } + }, + "routerEtcd": { + "servers": ["1.2.3.4:1234"], + "user": "tom", + "password": "**" + }, + "metaEtcd": { + "servers": ["1.2.3.4:5678"], + "user": "tom", + "password": "**" + } + } + ` + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&etcd3.EtcdInitParam{}), "InitClient", func(_ *etcd3.EtcdInitParam) error { + return nil + }), + gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + _ = config.InitConfig([]byte(configString)) + kv := &KvMock{} + client := &clientv3.Client{KV: kv} + etcdClient := &etcd3.EtcdClient{Client: client} + + return &InstanceManager{ + FrontendManager: faasfrontendmanager.NewFaaSFrontendManager(sdkClient, etcdClient, stopCh, 0, false), + FunctionManager: faasfunctionmanager.NewFaaSFunctionManager(sdkClient, etcdClient, stopCh, 0), + }, nil +} + +func TestHandleEventUpdate(t *testing.T) { + convey.Convey("HandleEventUpdate", t, func() { + mockTenantID := "mock-tenant-001" + manager, err := newInstanceManagerHelper(nil, make(chan struct{}), 1, 1) + convey.So(manager, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + manager.ExclusivitySchedulerManagers = map[string]*faasschedulermanager.SchedulerManager{ + mockTenantID: {}, + } + + convey.Convey("EventKindFrontend", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&faasfrontendmanager.FrontendManager{}), "HandleInstanceUpdate", + func(_ *faasfrontendmanager.FrontendManager, instanceSpec *types.InstanceSpecification) { + return + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + spec := &types.InstanceSpecification{} + manager.HandleEventUpdate(spec, types.EventKindFrontend) + }) + + convey.Convey("EventKindScheduler", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&faasschedulermanager.SchedulerManager{}), "HandleInstanceUpdate", + func(_ *faasschedulermanager.SchedulerManager, instanceSpec *types.InstanceSpecification) { + return + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + spec := &types.InstanceSpecification{} + manager.HandleEventUpdate(spec, types.EventKindScheduler) + }) + + convey.Convey("EventKindSchedulerWithNotEmptyTenantID", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&faasschedulermanager.SchedulerManager{}), "HandleInstanceUpdate", + func(_ *faasschedulermanager.SchedulerManager, instanceSpec *types.InstanceSpecification) { + return + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + spec := &types.InstanceSpecification{ + InstanceSpecificationMeta: types.InstanceSpecificationMeta{ + CreateOptions: map[string]string{ + constant.SchedulerExclusivity: mockTenantID, + }, + }, + } + manager.HandleEventUpdate(spec, types.EventKindScheduler) + }) + + convey.Convey("EventKindManager", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&faasfunctionmanager.FunctionManager{}), "HandleInstanceUpdate", + func(_ *faasfunctionmanager.FunctionManager, instanceSpec *types.InstanceSpecification) { + return + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + spec := &types.InstanceSpecification{} + manager.HandleEventUpdate(spec, types.EventKindManager) + }) + }) +} + +func TestHandleEventDelete(t *testing.T) { + convey.Convey("HandleEventDelete", t, func() { + mockTenantID := "mock-tenant-001" + manager, err := newInstanceManagerHelper(nil, make(chan struct{}), 1, 1) + convey.So(manager, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + manager.ExclusivitySchedulerManagers = map[string]*faasschedulermanager.SchedulerManager{ + mockTenantID: {}, + } + + convey.Convey("EventKindFrontend", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&faasfrontendmanager.FrontendManager{}), "HandleInstanceDelete", + func(_ *faasfrontendmanager.FrontendManager, instanceSpec *types.InstanceSpecification) { + return + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + spec := &types.InstanceSpecification{} + manager.HandleEventDelete(spec, types.EventKindFrontend) + }) + + convey.Convey("EventKindScheduler", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&faasschedulermanager.SchedulerManager{}), "HandleInstanceDelete", + func(_ *faasschedulermanager.SchedulerManager, instanceSpec *types.InstanceSpecification) { + return + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + spec := &types.InstanceSpecification{} + manager.HandleEventDelete(spec, types.EventKindScheduler) + }) + + convey.Convey("EventKindSchedulerWithNotEmptyTenantID", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&faasschedulermanager.SchedulerManager{}), "HandleInstanceDelete", + func(_ *faasschedulermanager.SchedulerManager, instanceSpec *types.InstanceSpecification) { + return + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + spec := &types.InstanceSpecification{ + InstanceSpecificationMeta: types.InstanceSpecificationMeta{ + CreateOptions: map[string]string{ + constant.SchedulerExclusivity: mockTenantID, + }, + }, + } + manager.HandleEventDelete(spec, types.EventKindScheduler) + }) + + convey.Convey("EventKindManager", func() { + patches := []*gomonkey.Patches{ + gomonkey.ApplyMethod(reflect.TypeOf(&faasfunctionmanager.FunctionManager{}), "HandleInstanceDelete", + func(_ *faasfunctionmanager.FunctionManager, instanceSpec *types.InstanceSpecification) { + return + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + spec := &types.InstanceSpecification{} + manager.HandleEventDelete(spec, types.EventKindManager) + }) + }) +} + +func TestHandleEventRecover(t *testing.T) { + convey.Convey("HandleEventRecover", t, func() { + mockTenantID := "mock-tenant-001" + manager, err := newInstanceManagerHelper(nil, make(chan struct{}), 1, 1) + convey.So(manager, convey.ShouldNotBeNil) + convey.So(err, convey.ShouldBeNil) + manager.ExclusivitySchedulerManagers = map[string]*faasschedulermanager.SchedulerManager{ + mockTenantID: {}, + } + + convey.Convey("EventKindFrontend", func() { + manager.FrontendManager = &faasfrontendmanager.FrontendManager{} + defer gomonkey.ApplyMethod(reflect.TypeOf(manager.FrontendManager), "KillInstance", + func(ffm *faasfrontendmanager.FrontendManager, instanceID string) error { + return nil + }).Reset() + spec := &types.InstanceSpecification{} + manager.HandleEventRecover(spec, types.EventKindFrontend) + }) + + convey.Convey("EventKindScheduler", func() { + manager.CommonSchedulerManager = &faasschedulermanager.SchedulerManager{} + defer gomonkey.ApplyMethod(reflect.TypeOf(manager.CommonSchedulerManager), "KillInstance", + func(ffm *faasschedulermanager.SchedulerManager, instanceID string) error { + return nil + }).Reset() + spec := &types.InstanceSpecification{} + manager.HandleEventRecover(spec, types.EventKindScheduler) + }) + + convey.Convey("EventKindSchedulerWithNotEmptyTenantID", func() { + manager.CommonSchedulerManager = &faasschedulermanager.SchedulerManager{} + defer gomonkey.ApplyMethod(reflect.TypeOf(manager.CommonSchedulerManager), "KillInstance", + func(ffm *faasschedulermanager.SchedulerManager, instanceID string) error { + return nil + }).Reset() + spec := &types.InstanceSpecification{ + InstanceSpecificationMeta: types.InstanceSpecificationMeta{ + CreateOptions: map[string]string{ + constant.SchedulerExclusivity: mockTenantID, + }, + }, + } + manager.HandleEventRecover(spec, types.EventKindScheduler) + }) + }) +} diff --git a/yuanrong/pkg/system_function_controller/registry/faasfrontendregistry.go b/yuanrong/pkg/system_function_controller/registry/faasfrontendregistry.go new file mode 100644 index 0000000..fbc8d28 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/registry/faasfrontendregistry.go @@ -0,0 +1,160 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "context" + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/system_function_controller/types" + "yuanrong/pkg/system_function_controller/utils" +) + +// FaaSFrontendRegistry watche faasfrontend event of etcd +type FaaSFrontendRegistry struct { + watcher etcd3.Watcher + frontendSpecs map[string]*types.InstanceSpecification + subscriberChans []chan types.SubEvent + stopCh <-chan struct{} + sync.RWMutex +} + +// NewFrontendRegistry will create faaSFrontendRegistry +func NewFrontendRegistry(stopCh <-chan struct{}) *FaaSFrontendRegistry { + registry := &FaaSFrontendRegistry{ + frontendSpecs: make(map[string]*types.InstanceSpecification, defaultMapSize), + stopCh: stopCh, + } + return registry +} + +// InitWatcher init watcher +func (fr *FaaSFrontendRegistry) InitWatcher() { + fr.watcher = etcd3.NewEtcdWatcher(constant.InstancePathPrefix, + fr.watcherFilter, fr.watcherHandler, fr.stopCh, etcd3.GetRouterEtcdClient()) + fr.watcher.StartList() +} + +// RunWatcher will start etcd watch process +func (fr *FaaSFrontendRegistry) RunWatcher() { + go fr.watcher.StartWatch() +} + +func (fr *FaaSFrontendRegistry) getFunctionSpec(instanceID string) *types.InstanceSpecification { + fr.RLock() + schedulerSpec := fr.frontendSpecs[instanceID] + fr.RUnlock() + return schedulerSpec +} + +// The etcd key format to be filtered out +// /sn/instance/business/yrk/tenant/0/function/faasfrontend/version/$latest/defaultaz/frontend-xx.xx.xx.xx +func (fr *FaaSFrontendRegistry) watcherFilter(event *etcd3.Event) bool { + etcdKey := event.Key + items := strings.Split(etcdKey, constant.KeySeparator) + if len(items) != constant.ValidEtcdKeyLenForInstance || + items[constant.TenantIndexForInstance] != "tenant" || items[constant.FunctionIndexForInstance] != "function" { + return true + } + if !strings.Contains(items[constant.FunctionNameIndexForInstance], constant.FaasFrontendMark) { + return true + } + return false +} + +func (fr *FaaSFrontendRegistry) watcherHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling instance event type %s key %s", event.Type, event.Key) + if event.Type == etcd3.SYNCED { + log.GetLogger().Infof("frontend registry ready to receive etcd kv") + return + } + instanceID := utils.ExtractInfoFromEtcdKey(event.Key, constant.InstanceIDIndexForInstance) + if len(instanceID) == 0 { + log.GetLogger().Errorf("ignoring invalid etcd key of key %s", event.Key) + return + } + fr.Lock() + defer fr.Unlock() + switch event.Type { + case etcd3.PUT, etcd3.HISTORYUPDATE: + faasFrontendSpec := utils.GetInstanceSpecFromEtcdValue(event.Value) + if faasFrontendSpec == nil { + log.GetLogger().Errorf("ignoring invalid etcd value of key %s", event.Key) + return + } + faasFrontendSpec.InstanceID = instanceID + targetFrontendSpec, needRecover := fr.createOrUpdateFrontendSpec(faasFrontendSpec) + fr.publishEvent(types.SubEventTypeUpdate, targetFrontendSpec) + if needRecover { + fr.publishEvent(types.SubEventTypeRecover, targetFrontendSpec) + } + return + case etcd3.DELETE, etcd3.HISTORYDELETE: + specification, exist := fr.frontendSpecs[instanceID] + if !exist { + log.GetLogger().Errorf("function faaS frontend %s does not exist in registry", instanceID) + return + } + delete(fr.frontendSpecs, instanceID) + fr.publishEvent(types.SubEventTypeDelete, specification) + default: + log.GetLogger().Errorf("unsupported event: %s", event.Key) + return + } +} + +func (fr *FaaSFrontendRegistry) createOrUpdateFrontendSpec( + faasFrontend *types.InstanceSpecification) (*types.InstanceSpecification, bool) { + specification, exist := fr.frontendSpecs[faasFrontend.InstanceID] + if !exist { + funcCtx, cancelFunc := context.WithCancel(context.TODO()) + faasFrontend.FuncCtx = funcCtx + faasFrontend.CancelFunc = cancelFunc + fr.frontendSpecs[faasFrontend.InstanceID] = faasFrontend + specification = faasFrontend + } else { + specification.InstanceSpecificationMeta = faasFrontend.InstanceSpecificationMeta + } + if utils.CheckNeedRecover(faasFrontend.InstanceSpecificationMeta) { + return faasFrontend, true + } + return specification, false +} + +// AddSubscriberChan add chan +func (fr *FaaSFrontendRegistry) AddSubscriberChan(subChan chan types.SubEvent) { + fr.Lock() + fr.subscriberChans = append(fr.subscriberChans, subChan) + fr.Unlock() +} + +func (fr *FaaSFrontendRegistry) publishEvent(eventType types.EventType, schedulerSpec *types.InstanceSpecification) { + for _, subChan := range fr.subscriberChans { + if subChan != nil { + subChan <- types.SubEvent{ + EventType: eventType, + EventKind: types.EventKindFrontend, + EventMsg: schedulerSpec, + } + } + } +} diff --git a/yuanrong/pkg/system_function_controller/registry/faasmanagerregistry.go b/yuanrong/pkg/system_function_controller/registry/faasmanagerregistry.go new file mode 100644 index 0000000..56f201c --- /dev/null +++ b/yuanrong/pkg/system_function_controller/registry/faasmanagerregistry.go @@ -0,0 +1,160 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "context" + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/system_function_controller/types" + "yuanrong/pkg/system_function_controller/utils" +) + +// FaaSManagerRegistry watche faasmanager event of etcd +type FaaSManagerRegistry struct { + watcher etcd3.Watcher + managerSpecs map[string]*types.InstanceSpecification + subscriberChans []chan types.SubEvent + stopCh <-chan struct{} + sync.RWMutex +} + +// NewManagerRegistry will create faaSManagerRegistry +func NewManagerRegistry(stopCh <-chan struct{}) *FaaSManagerRegistry { + registry := &FaaSManagerRegistry{ + managerSpecs: make(map[string]*types.InstanceSpecification, defaultMapSize), + stopCh: stopCh, + } + return registry +} + +// InitWatcher init watcher +func (fm *FaaSManagerRegistry) InitWatcher() { + fm.watcher = etcd3.NewEtcdWatcher(constant.InstancePathPrefix, + fm.watcherFilter, fm.watcherHandler, fm.stopCh, etcd3.GetRouterEtcdClient()) + fm.watcher.StartList() +} + +// RunWatcher will start etcd watch process +func (fm *FaaSManagerRegistry) RunWatcher() { + go fm.watcher.StartWatch() +} + +func (fm *FaaSManagerRegistry) getFunctionSpec(instanceID string) *types.InstanceSpecification { + fm.RLock() + schedulerSpec := fm.managerSpecs[instanceID] + fm.RUnlock() + return schedulerSpec +} + +// The etcd key format to be filtered out +func (fm *FaaSManagerRegistry) watcherFilter(event *etcd3.Event) bool { + etcdKey := event.Key + log.GetLogger().Infof("watcherFilter get etcdKey=%s", etcdKey) + items := strings.Split(etcdKey, constant.KeySeparator) + if len(items) != constant.ValidEtcdKeyLenForInstance || + items[constant.TenantIndexForInstance] != "tenant" || items[constant.FunctionIndexForInstance] != "function" || + !strings.Contains(etcdKey, "0-system-faasmanager") { + log.GetLogger().Warnf("invalid faaS manager instance key: %s", etcdKey) + return true + } + return false +} + +func (fm *FaaSManagerRegistry) watcherHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling instance event type %s key %s", event.Type, event.Key) + if event.Type == etcd3.SYNCED { + log.GetLogger().Infof("manager registry ready to receive etcd kv") + return + } + instanceID := utils.ExtractInfoFromEtcdKey(event.Key, constant.InstanceIDIndexForInstance) + if len(instanceID) == 0 { + log.GetLogger().Errorf("ignoring invalid etcd key of key %s", event.Key) + return + } + fm.Lock() + defer fm.Unlock() + switch event.Type { + case etcd3.PUT, etcd3.HISTORYUPDATE: + faasManagerSpec := utils.GetInstanceSpecFromEtcdValue(event.Value) + if faasManagerSpec == nil { + log.GetLogger().Errorf("ignoring invalid etcd value of key %s", event.Key) + return + } + faasManagerSpec.InstanceID = instanceID + targetFrontendSpec, needRecover := fm.createOrUpdateManagerSpec(faasManagerSpec) + if needRecover { + fm.publishEvent(types.SubEventTypeRecover, targetFrontendSpec) + } else { + fm.publishEvent(types.SubEventTypeUpdate, targetFrontendSpec) + } + return + case etcd3.DELETE, etcd3.HISTORYDELETE: + specification, exist := fm.managerSpecs[instanceID] + if !exist { + log.GetLogger().Errorf("function faaS manager %s does not exist in registry", instanceID) + return + } + delete(fm.managerSpecs, instanceID) + fm.publishEvent(types.SubEventTypeDelete, specification) + default: + log.GetLogger().Errorf("unsupported event: %s", event.Key) + return + } +} + +func (fm *FaaSManagerRegistry) createOrUpdateManagerSpec( + faasManager *types.InstanceSpecification) (*types.InstanceSpecification, bool) { + specification, exist := fm.managerSpecs[faasManager.InstanceID] + if !exist { + funcCtx, cancelFunc := context.WithCancel(context.TODO()) + faasManager.FuncCtx = funcCtx + faasManager.CancelFunc = cancelFunc + fm.managerSpecs[faasManager.InstanceID] = faasManager + specification = faasManager + } else { + specification.InstanceSpecificationMeta = faasManager.InstanceSpecificationMeta + } + if utils.CheckNeedRecover(faasManager.InstanceSpecificationMeta) { + return faasManager, true + } + return specification, false +} + +// AddSubscriberChan add chan +func (fm *FaaSManagerRegistry) AddSubscriberChan(subChan chan types.SubEvent) { + fm.Lock() + fm.subscriberChans = append(fm.subscriberChans, subChan) + fm.Unlock() +} + +func (fm *FaaSManagerRegistry) publishEvent(eventType types.EventType, schedulerSpec *types.InstanceSpecification) { + for _, subChan := range fm.subscriberChans { + if subChan != nil { + subChan <- types.SubEvent{ + EventType: eventType, + EventKind: types.EventKindManager, + EventMsg: schedulerSpec, + } + } + } +} diff --git a/yuanrong/pkg/system_function_controller/registry/faasschedulerregistry.go b/yuanrong/pkg/system_function_controller/registry/faasschedulerregistry.go new file mode 100644 index 0000000..003cbdb --- /dev/null +++ b/yuanrong/pkg/system_function_controller/registry/faasschedulerregistry.go @@ -0,0 +1,160 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "context" + "strings" + "sync" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/system_function_controller/types" + "yuanrong/pkg/system_function_controller/utils" +) + +// FaaSSchedulerRegistry watches scheduler event of etcd +type FaaSSchedulerRegistry struct { + schedulerSpecs map[string]*types.InstanceSpecification + watcher etcd3.Watcher + subscriberChans []chan types.SubEvent + stopCh <-chan struct{} + sync.RWMutex +} + +// NewSchedulerRegistry will create faaSSchedulerRegistry +func NewSchedulerRegistry(stopCh <-chan struct{}) *FaaSSchedulerRegistry { + schedulerRegistry := &FaaSSchedulerRegistry{ + schedulerSpecs: make(map[string]*types.InstanceSpecification, defaultMapSize), + stopCh: stopCh, + } + return schedulerRegistry +} + +// InitWatcher init watcher +func (sr *FaaSSchedulerRegistry) InitWatcher() { + sr.watcher = etcd3.NewEtcdWatcher(constant.InstancePathPrefix, + sr.watcherFilter, sr.watcherHandler, sr.stopCh, etcd3.GetRouterEtcdClient()) + sr.watcher.StartList() +} + +// RunWatcher will start etcd watch process +func (sr *FaaSSchedulerRegistry) RunWatcher() { + go sr.watcher.StartWatch() +} + +func (sr *FaaSSchedulerRegistry) getFunctionSpec(instanceID string) *types.InstanceSpecification { + sr.RLock() + schedulerSpec := sr.schedulerSpecs[instanceID] + sr.RUnlock() + return schedulerSpec +} + +func (sr *FaaSSchedulerRegistry) watcherFilter(event *etcd3.Event) bool { + log.GetLogger().Infof("watcherFilter get etcdKey=%s", event.Key) + items := strings.Split(event.Key, constant.KeySeparator) + if len(items) != constant.ValidEtcdKeyLenForInstance { + return true + } + if items[constant.FunctionsIndexForInstance] != "instance" || items[constant.TenantIndexForInstance] != "tenant" || + items[constant.FunctionIndexForInstance] != "function" { + return true + } + if !strings.Contains(items[constant.FunctionNameIndexForInstance], constant.FaasSchedulerMark) { + return true + } + return false +} + +func (sr *FaaSSchedulerRegistry) watcherHandler(event *etcd3.Event) { + log.GetLogger().Infof("handling function event type %s key %s", event.Type, event.Key) + instanceID := utils.ExtractInstanceIDFromEtcdKey(event.Key) + if event.Type == etcd3.SYNCED { + log.GetLogger().Infof("scheduler registry ready to receive etcd kv") + return + } + if len(instanceID) == 0 { + log.GetLogger().Warnf("ignore invalid etcd key of key %s", event.Key) + return + } + sr.Lock() + defer sr.Unlock() + switch event.Type { + case etcd3.PUT, etcd3.HISTORYUPDATE: + targetSchedulerSpec := utils.GetInstanceSpecFromEtcdValue(event.Value) + if targetSchedulerSpec == nil { + log.GetLogger().Errorf("ignoring invalid etcd value of key %s", event.Key) + return + } + targetSchedulerSpec.InstanceID = instanceID + schedulerSpec, needRecover := sr.createOrUpdateSchedulerSpec(targetSchedulerSpec) + sr.publishEvent(types.SubEventTypeUpdate, schedulerSpec) + if needRecover { + sr.publishEvent(types.SubEventTypeRecover, schedulerSpec) + } + case etcd3.DELETE, etcd3.HISTORYDELETE: + schedulerSpec, exist := sr.schedulerSpecs[instanceID] + if !exist { + log.GetLogger().Errorf("function faaS scheduler %s doesn't exist in registry", instanceID) + return + } + delete(sr.schedulerSpecs, instanceID) + sr.publishEvent(types.SubEventTypeDelete, schedulerSpec) + default: + log.GetLogger().Errorf("unsupported event: %s", event.Key) + return + } +} + +func (sr *FaaSSchedulerRegistry) createOrUpdateSchedulerSpec( + targetSchedulerSpec *types.InstanceSpecification) (*types.InstanceSpecification, bool) { + schedulerSpec, exist := sr.schedulerSpecs[targetSchedulerSpec.InstanceID] + if !exist { + funcCtx, cancelFunc := context.WithCancel(context.TODO()) + targetSchedulerSpec.FuncCtx = funcCtx + targetSchedulerSpec.CancelFunc = cancelFunc + sr.schedulerSpecs[targetSchedulerSpec.InstanceID] = targetSchedulerSpec + schedulerSpec = targetSchedulerSpec + } else { + schedulerSpec.InstanceSpecificationMeta = targetSchedulerSpec.InstanceSpecificationMeta + } + if utils.CheckNeedRecover(targetSchedulerSpec.InstanceSpecificationMeta) { + return targetSchedulerSpec, true + } + return schedulerSpec, false +} + +// AddSubscriberChan add chan +func (sr *FaaSSchedulerRegistry) AddSubscriberChan(subChan chan types.SubEvent) { + sr.Lock() + sr.subscriberChans = append(sr.subscriberChans, subChan) + sr.Unlock() +} + +func (sr *FaaSSchedulerRegistry) publishEvent(eventType types.EventType, schedulerSpec *types.InstanceSpecification) { + for _, subChan := range sr.subscriberChans { + if subChan != nil { + subChan <- types.SubEvent{ + EventType: eventType, + EventKind: types.EventKindScheduler, + EventMsg: schedulerSpec, + } + } + } +} diff --git a/yuanrong/pkg/system_function_controller/registry/registry.go b/yuanrong/pkg/system_function_controller/registry/registry.go new file mode 100644 index 0000000..a7c4e80 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/registry/registry.go @@ -0,0 +1,88 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/system_function_controller/types" +) + +const ( + defaultMapSize = 16 + functionRegistryNum = 2 +) + +// FunctionRegistry defines the interface for kind of registry +type FunctionRegistry interface { + AddSubscriberChan(subChan chan types.SubEvent) + publishEvent(eventType types.EventType, schedulerSpec *types.InstanceSpecification) + watcherFilter(event *etcd3.Event) bool + watcherHandler(event *etcd3.Event) + InitWatcher() + RunWatcher() + getFunctionSpec(instanceID string) *types.InstanceSpecification +} + +var ( + // GlobalRegistry is the global registry + GlobalRegistry *Registry +) + +// Registry watches etcd and builds registry cache based on etcd watch +type Registry struct { + functionRegistry map[types.EventKind]FunctionRegistry +} + +// InitRegistry will initialize registry +func InitRegistry() { + GlobalRegistry = &Registry{ + functionRegistry: make(map[types.EventKind]FunctionRegistry, functionRegistryNum), + } +} + +// AddFunctionRegistry add function registry +func (r *Registry) AddFunctionRegistry(registry FunctionRegistry, eventType types.EventKind) { + if registry == nil || GlobalRegistry == nil { + return + } + GlobalRegistry.functionRegistry[eventType] = registry +} + +// ProcessETCDList ETCD List +func (r *Registry) ProcessETCDList() { + for _, funcRegistry := range r.functionRegistry { + funcRegistry.InitWatcher() + } +} + +// RegisterSubscriberChan function registry can subscriber channel +func (r *Registry) RegisterSubscriberChan(subChan chan types.SubEvent) { + for _, funcRegistry := range r.functionRegistry { + funcRegistry.AddSubscriberChan(subChan) + } + log.GetLogger().Infof("all registry subscriber channel") +} + +// RunFunctionWatcher watch events and publish to subscribed channel +func (r *Registry) RunFunctionWatcher() { + for _, funcRegistry := range r.functionRegistry { + funcRegistry.RunWatcher() + } + log.GetLogger().Infof("all registry run watcher") +} diff --git a/yuanrong/pkg/system_function_controller/registry/registry_test.go b/yuanrong/pkg/system_function_controller/registry/registry_test.go new file mode 100644 index 0000000..f40c2e6 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/registry/registry_test.go @@ -0,0 +1,711 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package registry - +package registry + +import ( + "errors" + "reflect" + "testing" + "time" + + . "github.com/agiledragon/gomonkey/v2" + . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/client/v3" + + "yuanrong/pkg/common/faas_common/etcd3" + fcConfig "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/types" + "yuanrong/pkg/system_function_controller/utils" +) + +var ( + validSchedulerEtcdKey = "/sn/instance/business/yrk/tenant/12/function/0-system-faasscheduler/version/$latest/defaultaz/requestID/123" + validFrontendEtcdKey = "/sn/instance/business/yrk/tenant/0/function/0-system-faasfrontend/version/$latest/defaultaz/requestID/123" + validManagerEtcdKey = "/sn/instance/business/yrk/tenant/0/function/0-system-faasmanager/version/$latest/defaultaz/requestID/123" + invalidEtcdKey = "/instance/business/yrk/tenant/12/function/0-system-faasscheduler/version/$latest/defaultaz/123" + + validEtcdValue = `{ + "runtimeID":"16444dbc", + "deployedIP":"10.244.136.113", + "deployedNode":"dggphis36581", + "runtimeIP":"10.42.0.37", + "runtimePort":"32065", + "funcKey":"12345678901234561234567890123456/0-system-faasscheduler/$latest", + "resource":{"cpu":"500","memory":"500","customresources":{}}, + "concurrency":1, + "status":3, + "labels":null} + ` + + configString = `{ + "frontendInstanceNum": 10, + "schedulerInstanceNum": 10, + "faasschedulerConfig": { + "autoScaleConfig":{ + "slaQuota": 1000, + "scaleDownTime": 60000, + "burstScaleNum": 1000 + }, + "leaseSpan": 600000 + }, + "etcd": { + "url": ["1.2.3.4:1234"], + "username": "tom", + "password": "**" + } + } + ` +) + +func initConfig() { + fcConfig.InitConfig([]byte(configString)) +} + +func generateETCDevent(eventType int, key string) *etcd3.Event { + event := &etcd3.Event{ + Type: eventType, + Key: key, + Value: []byte(validEtcdValue), + } + return event +} + +func initRegistry() { + initConfig() + + patches := [...]*Patches{ + ApplyFunc(etcd3.GetSharedEtcdClient, + func(_ *etcd3.EtcdConfig) (*clientv3.Client, error) { + return nil, nil + }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", + func(_ *etcd3.EtcdWatcher) { + return + }), + } + + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + InitRegistry() +} + +func TestInitRegistryPublish(t *testing.T) { + initRegistry() + Convey("Test RegistryPublish", t, func() { + Convey("test scheduler registry publish", func() { + funCh := make(chan types.SubEvent, 16) + stopCh := make(chan struct{}) + GlobalRegistry.functionRegistry[types.EventKindScheduler] = NewSchedulerRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindScheduler].AddSubscriberChan(funCh) + + GlobalRegistry.functionRegistry[types.EventKindScheduler].publishEvent(types.SubEventTypeUpdate, + &types.InstanceSpecification{InstanceID: "test123"}) + event := <-funCh + functionSpec, ok := event.EventMsg.(*types.InstanceSpecification) + if !ok { + assert.Error(t, errors.New("event assert types.InstanceSpecification failed")) + } + So(functionSpec.InstanceID, ShouldEqual, "test123") + So(event.EventKind, ShouldEqual, types.EventKindScheduler) + So(event.EventType, ShouldEqual, types.SubEventTypeUpdate) + }) + + Convey("test manager registry publish", func() { + funCh := make(chan types.SubEvent, 16) + stopCh := make(chan struct{}) + GlobalRegistry.functionRegistry[types.EventKindManager] = NewManagerRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindManager].AddSubscriberChan(funCh) + + GlobalRegistry.functionRegistry[types.EventKindManager].publishEvent(types.SubEventTypeUpdate, + &types.InstanceSpecification{InstanceID: "test123"}) + event := <-funCh + functionSpec, ok := event.EventMsg.(*types.InstanceSpecification) + if !ok { + assert.Error(t, errors.New("event assert types.InstanceSpecification failed")) + } + So(functionSpec.InstanceID, ShouldEqual, "test123") + So(event.EventKind, ShouldEqual, types.EventKindManager) + So(event.EventType, ShouldEqual, types.SubEventTypeUpdate) + }) + + Convey("test frontend registry publish", func() { + funCh := make(chan types.SubEvent, 16) + stopCh := make(chan struct{}) + GlobalRegistry.functionRegistry[types.EventKindFrontend] = NewFrontendRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindFrontend].AddSubscriberChan(funCh) + + GlobalRegistry.functionRegistry[types.EventKindFrontend].publishEvent(types.SubEventTypeUpdate, + &types.InstanceSpecification{InstanceID: "test123"}) + event := <-funCh + functionSpec, ok := event.EventMsg.(*types.InstanceSpecification) + if !ok { + assert.Error(t, errors.New("event assert types.InstanceSpecification failed")) + } + So(functionSpec.InstanceID, ShouldEqual, "test123") + So(event.EventKind, ShouldEqual, types.EventKindFrontend) + So(event.EventType, ShouldEqual, types.SubEventTypeUpdate) + }) + + Convey("test all registry publish", func() { + funCh := make(chan types.SubEvent, 16) + GlobalRegistry.RegisterSubscriberChan(funCh) + GlobalRegistry.functionRegistry[types.EventKindScheduler].publishEvent(types.SubEventTypeUpdate, + &types.InstanceSpecification{InstanceID: "test123"}) + GlobalRegistry.functionRegistry[types.EventKindFrontend].publishEvent(types.SubEventTypeUpdate, + &types.InstanceSpecification{InstanceID: "test456"}) + GlobalRegistry.functionRegistry[types.EventKindManager].publishEvent(types.SubEventTypeUpdate, + &types.InstanceSpecification{InstanceID: "test789"}) + event := <-funCh + functionSpec, ok := event.EventMsg.(*types.InstanceSpecification) + if !ok { + assert.Error(t, errors.New("event assert types.InstanceSpecification failed")) + } + So(functionSpec.InstanceID, ShouldEqual, "test123") + So(event.EventKind, ShouldEqual, types.EventKindScheduler) + So(event.EventType, ShouldEqual, types.SubEventTypeUpdate) + + event = <-funCh + functionSpec, ok = event.EventMsg.(*types.InstanceSpecification) + if !ok { + assert.Error(t, errors.New("event assert types.InstanceSpecification failed")) + } + So(functionSpec.InstanceID, ShouldEqual, "test456") + So(event.EventKind, ShouldEqual, types.EventKindFrontend) + So(event.EventType, ShouldEqual, types.SubEventTypeUpdate) + + event = <-funCh + functionSpec, ok = event.EventMsg.(*types.InstanceSpecification) + if !ok { + assert.Error(t, errors.New("event assert types.InstanceSpecification failed")) + } + So(functionSpec.InstanceID, ShouldEqual, "test789") + So(event.EventKind, ShouldEqual, types.EventKindManager) + So(event.EventType, ShouldEqual, types.SubEventTypeUpdate) + }) + }) +} + +func TestRegistry_runWatcher(t *testing.T) { + var event *etcd3.Event + initConfig() + InitRegistry() + stopCh := make(chan struct{}) + GlobalRegistry.functionRegistry[types.EventKindFrontend] = NewFrontendRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindScheduler] = NewSchedulerRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindManager] = NewManagerRegistry(stopCh) + patches := [...]*Patches{ + ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartList", func(_ *etcd3.EtcdWatcher) { + }), + } + defer func() { + for _, patch := range patches { + patch.Reset() + } + }() + Convey("test registry runWatcher", t, func() { + Convey("test scheduler registry runWatcher", func() { + c := make(chan *etcd3.Event) + patches := [...]*Patches{ + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", + func(_ *etcd3.EtcdWatcher) { + if !GlobalRegistry.functionRegistry[types.EventKindScheduler].watcherFilter(event) { + GlobalRegistry.functionRegistry[types.EventKindScheduler].watcherHandler(event) + c <- event + return + } + }), + } + + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + + event = generateETCDevent(etcd3.HISTORYUPDATE, validSchedulerEtcdKey) + GlobalRegistry.functionRegistry[types.EventKindScheduler].InitWatcher() + GlobalRegistry.functionRegistry[types.EventKindScheduler].RunWatcher() + So(<-c, ShouldEqual, event) + }) + + Convey("test frontend registry runWatcher", func() { + c := make(chan *etcd3.Event) + patches := [...]*Patches{ + ApplyFunc(etcd3.GetSharedEtcdClient, + func(_ *etcd3.EtcdConfig) (*clientv3.Client, error) { + return &clientv3.Client{}, nil + }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", + func(_ *etcd3.EtcdWatcher) { + if !GlobalRegistry.functionRegistry[types.EventKindFrontend].watcherFilter(event) { + GlobalRegistry.functionRegistry[types.EventKindFrontend].watcherHandler(event) + c <- event + return + } + }), + } + + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + event = generateETCDevent(etcd3.PUT, validFrontendEtcdKey) + GlobalRegistry.functionRegistry[types.EventKindFrontend].InitWatcher() + GlobalRegistry.functionRegistry[types.EventKindFrontend].RunWatcher() + So(<-c, ShouldEqual, event) + }) + + Convey("test manager registry runWatcher", func() { + c := make(chan *etcd3.Event) + patches := [...]*Patches{ + ApplyFunc(etcd3.GetSharedEtcdClient, + func(_ *etcd3.EtcdConfig) (*clientv3.Client, error) { + return &clientv3.Client{}, nil + }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", + func(_ *etcd3.EtcdWatcher) { + if !GlobalRegistry.functionRegistry[types.EventKindManager].watcherFilter(event) { + GlobalRegistry.functionRegistry[types.EventKindManager].watcherHandler(event) + c <- event + return + } + }), + } + + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + event = generateETCDevent(etcd3.PUT, validManagerEtcdKey) + GlobalRegistry.functionRegistry[types.EventKindManager].InitWatcher() + GlobalRegistry.functionRegistry[types.EventKindManager].RunWatcher() + So(<-c, ShouldEqual, event) + }) + + Convey("test all registry runWatcher", func() { + eventCh := make(chan *etcd3.Event, 1) + frontendEventCh := make(chan *etcd3.Event, 1) + patches := [...]*Patches{ + ApplyFunc(etcd3.GetSharedEtcdClient, + func(_ *etcd3.EtcdConfig) (*clientv3.Client, error) { + return &clientv3.Client{}, nil + }), + ApplyMethod(reflect.TypeOf(&etcd3.EtcdWatcher{}), "StartWatch", + func(_ *etcd3.EtcdWatcher) { + event = generateETCDevent(etcd3.PUT, validSchedulerEtcdKey) + event1 := generateETCDevent(etcd3.HISTORYUPDATE, validFrontendEtcdKey) + event2 := generateETCDevent(etcd3.PUT, validManagerEtcdKey) + if !GlobalRegistry.functionRegistry[types.EventKindScheduler].watcherFilter(event) { + GlobalRegistry.functionRegistry[types.EventKindScheduler].watcherHandler(event) + eventCh <- event + } + if !GlobalRegistry.functionRegistry[types.EventKindFrontend].watcherFilter(event1) { + GlobalRegistry.functionRegistry[types.EventKindFrontend].watcherHandler(event1) + frontendEventCh <- event1 + } + if !GlobalRegistry.functionRegistry[types.EventKindManager].watcherFilter(event1) { + GlobalRegistry.functionRegistry[types.EventKindManager].watcherHandler(event1) + frontendEventCh <- event2 + } + }), + } + defer func() { + for _, patch := range patches { + time.Sleep(100 * time.Millisecond) + patch.Reset() + } + }() + GlobalRegistry.ProcessETCDList() + GlobalRegistry.RunFunctionWatcher() + e1, _ := <-eventCh + So(e1.Key, ShouldEqual, validSchedulerEtcdKey) + e2, _ := <-frontendEventCh + So(e2.Key, ShouldEqual, validFrontendEtcdKey) + }) + }) + +} + +func TestSchedulerRegistry_watcherFilter(t *testing.T) { + initConfig() + InitRegistry() + stopCh := make(chan struct{}) + GlobalRegistry.functionRegistry[types.EventKindFrontend] = NewFrontendRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindScheduler] = NewSchedulerRegistry(stopCh) + tests := []struct { + name string + etcdKey string + want bool + }{ + { + "test valid scheduler instance etcd key", + validSchedulerEtcdKey, + false, + }, + { + "test invalid length of scheduler instance etcd key", + invalidEtcdKey, + true, + }, + { + "test invalid content of scheduler instance etcd key", + "/sn/instance/business/yrk/tenant/12/functionxxx/0-system-faasscheduler/version/$latest/defaultaz/requestID/123", + true, + }, + { + "test invalid content of scheduler instance etcd key", + "/sn/instance/business/yrk/tenant/12/function/0-system-faascontroller/version/$latest/defaultaz/requestID/123", + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := generateETCDevent(etcd3.PUT, tt.etcdKey) + assert.Equalf(t, tt.want, GlobalRegistry.functionRegistry[types.EventKindScheduler].watcherFilter(event), "watcherFilter(%v)", event) + }) + } +} + +func TestFrontendRegistry_watcherFilter(t *testing.T) { + initConfig() + InitRegistry() + stopCh := make(chan struct{}) + GlobalRegistry.functionRegistry[types.EventKindFrontend] = NewFrontendRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindScheduler] = NewSchedulerRegistry(stopCh) + tests := []struct { + name string + etcdKey string + want bool + }{ + { + "test valid frontend instance etcd key", + validFrontendEtcdKey, + false, + }, + { + "test invalid length of frontend instance etcd key", + invalidEtcdKey, + true, + }, + { + "test invalid content of frontend instance etcd key", + "/sn/instance/business/yrk/tenant/0/functionxxx/0-system-faasfrontend/version/$latest/defaultaz/requestID/123", + true, + }, + { + "test invalid content of frontend instance etcd key", + "/sn/instance/business/yrk/tenant/0/function/0-system-faascontroller/version/$latest/defaultaz/requestID/123", + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := generateETCDevent(etcd3.PUT, tt.etcdKey) + assert.Equalf(t, tt.want, GlobalRegistry.functionRegistry[types.EventKindFrontend].watcherFilter(event), "watcherFilter(%v)", event) + }) + } +} + +func TestManagerRegistry_watcherFilter(t *testing.T) { + initConfig() + InitRegistry() + stopCh := make(chan struct{}) + GlobalRegistry.functionRegistry[types.EventKindFrontend] = NewFrontendRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindScheduler] = NewSchedulerRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindManager] = NewManagerRegistry(stopCh) + tests := []struct { + name string + etcdKey string + want bool + }{ + { + "test valid frontend instance etcd key", + validFrontendEtcdKey, + true, + }, + { + "test invalid length of frontend instance etcd key", + invalidEtcdKey, + true, + }, + { + "test invalid content of frontend instance etcd key", + "/sn/instance/business/yrk/tenant/0/functionxxx/0-system-faasfrontend/version/$latest/defaultaz/requestID/123/xxx", + true, + }, + { + "test invalid content of frontend instance etcd key", + "/sn/instance/business/yrk/tenant/0/function/0-system-faascontroller/version/$latest/defaultaz/requestID/123/xxx", + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := generateETCDevent(etcd3.PUT, tt.etcdKey) + assert.Equalf(t, tt.want, GlobalRegistry.functionRegistry[types.EventKindManager].watcherFilter(event), "watcherFilter(%v)", event) + }) + } +} + +func TestSchedulerRegistry_watcherHandler(t *testing.T) { + initConfig() + InitRegistry() + stopCh := make(chan struct{}) + GlobalRegistry.functionRegistry[types.EventKindFrontend] = NewFrontendRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindScheduler] = NewSchedulerRegistry(stopCh) + tests := []struct { + name string + etcdKey string + eventType int + wantNil bool + }{ + { + "test invalid instanceID", + "", + etcd3.PUT, + true, + }, + { + "test watcherHandler with put event", + validSchedulerEtcdKey, + etcd3.PUT, + false, + }, + { + "test watcherHandler with duplicated put event", + validSchedulerEtcdKey, + etcd3.PUT, + false, + }, + { + "test watcherHandler with delete event", + validSchedulerEtcdKey, + etcd3.DELETE, + true, + }, + { + "test watcherHandler with delete event", + validSchedulerEtcdKey, + etcd3.DELETE, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := generateETCDevent(tt.eventType, tt.etcdKey) + GlobalRegistry.functionRegistry[types.EventKindScheduler].watcherHandler(event) + + instanceID := utils.ExtractInstanceIDFromEtcdKey(tt.etcdKey) + ans := GlobalRegistry.functionRegistry[types.EventKindScheduler].getFunctionSpec(instanceID) + if !tt.wantNil { + assert.NotNil(t, ans) + } else { + assert.Nil(t, ans) + } + }) + } +} + +func TestFrontendRegistry_watcherHandler(t *testing.T) { + initConfig() + InitRegistry() + stopCh := make(chan struct{}) + GlobalRegistry.functionRegistry[types.EventKindFrontend] = NewFrontendRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindScheduler] = NewSchedulerRegistry(stopCh) + tests := []struct { + name string + etcdKey string + eventType int + wantNil bool + }{ + { + "test invalid instanceID", + "", + etcd3.PUT, + true, + }, + { + "test watcherHandler with put event", + validFrontendEtcdKey, + etcd3.PUT, + false, + }, + { + "test watcherHandler with duplicated put event", + validFrontendEtcdKey, + etcd3.PUT, + false, + }, + { + "test watcherHandler with delete event", + validFrontendEtcdKey, + etcd3.DELETE, + true, + }, + { + "test watcherHandler with delete event", + validFrontendEtcdKey, + etcd3.DELETE, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := generateETCDevent(tt.eventType, tt.etcdKey) + GlobalRegistry.functionRegistry[types.EventKindFrontend].watcherHandler(event) + + instanceID := utils.ExtractInstanceIDFromEtcdKey(tt.etcdKey) + ans := GlobalRegistry.functionRegistry[types.EventKindFrontend].getFunctionSpec(instanceID) + if !tt.wantNil { + assert.NotNil(t, ans) + } else { + assert.Nil(t, ans) + } + }) + } +} + +func TestManagerRegistry_watcherHandler(t *testing.T) { + initConfig() + InitRegistry() + stopCh := make(chan struct{}) + GlobalRegistry.functionRegistry[types.EventKindFrontend] = NewFrontendRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindScheduler] = NewSchedulerRegistry(stopCh) + GlobalRegistry.functionRegistry[types.EventKindManager] = NewManagerRegistry(stopCh) + tests := []struct { + name string + etcdKey string + eventType int + wantNil bool + }{ + { + "test invalid instanceID", + "", + etcd3.PUT, + true, + }, + { + "test watcherHandler with put event", + validFrontendEtcdKey, + etcd3.PUT, + false, + }, + { + "test watcherHandler with duplicated put event", + validFrontendEtcdKey, + etcd3.PUT, + false, + }, + { + "test watcherHandler with delete event", + validFrontendEtcdKey, + etcd3.DELETE, + true, + }, + { + "test watcherHandler with delete event", + validFrontendEtcdKey, + etcd3.DELETE, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := generateETCDevent(tt.eventType, tt.etcdKey) + GlobalRegistry.functionRegistry[types.EventKindManager].watcherHandler(event) + + instanceID := utils.ExtractInstanceIDFromEtcdKey(tt.etcdKey) + ans := GlobalRegistry.functionRegistry[types.EventKindManager].getFunctionSpec(instanceID) + if !tt.wantNil { + assert.NotNil(t, ans) + } else { + assert.Nil(t, ans) + } + }) + } +} + +func Test_extractInstanceIDFromEtcdKey(t *testing.T) { + type args struct { + etcdKey string + } + tests := []struct { + name string + args args + want string + }{ + { + "test valid etcd key", + args{etcdKey: validSchedulerEtcdKey}, + "123", + }, + { + "test invalid etcd key", + args{etcdKey: invalidEtcdKey}, + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := utils.ExtractInstanceIDFromEtcdKey(tt.args.etcdKey); got != tt.want { + t.Errorf("extractInstanceIDFromEtcdKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getSchedulerSpecFromEtcdValue(t *testing.T) { + type args struct { + etcdValue []byte + } + tests := []struct { + name string + args args + wantNil bool + }{ + { + "test valid etcd value", + args{etcdValue: []byte(validEtcdValue)}, + false, + }, + { + "test invalid etcd value", + args{etcdValue: []byte("123")}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := utils.GetInstanceSpecFromEtcdValue(tt.args.etcdValue) + assert.Equal(t, got == nil, tt.wantNil) + }) + } +} diff --git a/yuanrong/pkg/system_function_controller/service/frontendservice.go b/yuanrong/pkg/system_function_controller/service/frontendservice.go new file mode 100644 index 0000000..bbccb2a --- /dev/null +++ b/yuanrong/pkg/system_function_controller/service/frontendservice.go @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package service - +package service + +import ( + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/util/retry" + + "yuanrong/pkg/common/faas_common/k8sclient" + "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/constant" +) + +var backoff = wait.Backoff{ + Steps: 5, + Duration: 10000000, // 10 ms + Factor: 1.5, + Jitter: 0.3, +} + +// CreateFrontendService - +func CreateFrontendService() error { + nameSpace := config.GetFaaSControllerConfig().NameSpace + if nameSpace == "" { + nameSpace = constant.NamespaceDefault + } + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: constant.FuncNameFaasfrontend, + Namespace: nameSpace, + }, + Spec: v1.ServiceSpec{ + Selector: map[string]string{ + constant.SystemFuncName: constant.FuncNameFaasfrontend, + }, + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: "TCP", + Port: constant.ServiceFrontendPort, + TargetPort: intstr.IntOrString{ + Type: intstr.Int, + IntVal: constant.ServiceFrontendTargetPort, + }, + NodePort: constant.ServiceFrontendNodePort, + }, + }, + Type: v1.ServiceTypeNodePort, + }, + } + + createService := func() error { + // create Service + return k8sclient.GetkubeClient().CreateK8sService(service) + } + // Used to determine whether an error can be retried + isRetriable := func(err error) bool { + // always retry + return true + } + return retry.OnError(backoff, isRetriable, createService) +} diff --git a/yuanrong/pkg/system_function_controller/service/frontendservice_test.go b/yuanrong/pkg/system_function_controller/service/frontendservice_test.go new file mode 100644 index 0000000..ed17e67 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/service/frontendservice_test.go @@ -0,0 +1,28 @@ +package service + +import ( + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + v1 "k8s.io/api/core/v1" + + "yuanrong/pkg/common/faas_common/k8sclient" + "yuanrong/pkg/system_function_controller/config" + "yuanrong/pkg/system_function_controller/types" +) + +func TestCreateFrontendService(t *testing.T) { + convey.Convey("CreateFrontendService", t, func() { + defer gomonkey.ApplyMethod(reflect.TypeOf(&k8sclient.KubeClient{}), "CreateK8sService", + func(_ *k8sclient.KubeClient, service *v1.Service) error { + return nil + }).Reset() + defer gomonkey.ApplyFunc(config.GetFaaSControllerConfig, func() types.Config { + return types.Config{NameSpace: "default"} + }).Reset() + err := CreateFrontendService() + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/pkg/system_function_controller/state/state.go b/yuanrong/pkg/system_function_controller/state/state.go new file mode 100644 index 0000000..86becfa --- /dev/null +++ b/yuanrong/pkg/system_function_controller/state/state.go @@ -0,0 +1,178 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package state - +package state + +import ( + "encoding/json" + "fmt" + "os" + "sync" + + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/state" + "yuanrong/pkg/system_function_controller/types" +) + +// ControllerState add the status to be saved here. +type ControllerState struct { + FaaSControllerConfig types.Config `json:"FaaSControllerConfig" valid:"optional"` + // FaasInstance: type(frontend/scheduler/manager) -- instanceID + FaasInstance map[string]map[string]struct{} `json:"FaasInstance" valid:"optional"` +} + +const ( + defaultHandlerQueueSize = 100 + faasInstanceTypeNum = 3 + + faasInstanceTagsLen = 2 + faasInstanceTagsOptIndex = 0 + faasInstanceTagsTypeIndex = 1 +) + +var ( + // controllerState - + controllerState = &ControllerState{ + FaasInstance: make(map[string]map[string]struct{}, faasInstanceTypeNum), + } + controllerStateLock sync.RWMutex + controllerHandlerQueue *state.Queue + stateKey = "" + allFaaSStateKey = "/faas/state/recover" +) + +func init() { + stateKey = fmt.Sprintf("/faas/state/recover/faascontroller/%s", os.Getenv("INSTANCE_ID")) +} + +// InitState - +func InitState(schedulerExclusivity []string) { + if controllerHandlerQueue != nil { + return + } + controllerState.FaasInstance[types.FaasFrontendInstanceState] = map[string]struct{}{} + controllerState.FaasInstance[types.FaasSchedulerInstanceState] = map[string]struct{}{} + for _, tenantID := range schedulerExclusivity { + controllerState.FaasInstance[types.FaasSchedulerInstanceState+tenantID] = map[string]struct{}{} + } + controllerState.FaasInstance[types.FaasManagerInstanceState] = map[string]struct{}{} + controllerHandlerQueue = state.NewStateQueue(defaultHandlerQueueSize) + if controllerHandlerQueue == nil { + return + } + go controllerHandlerQueue.Run(updateState) + controllerStateLock.Lock() + defer controllerStateLock.Unlock() + stateBytes, err := controllerHandlerQueue.GetState(stateKey) + if err != nil { + log.GetLogger().Errorf("Failed to get state from etcd err: %v", err) + return + } + err = json.Unmarshal(stateBytes, controllerState) + if err != nil { + log.GetLogger().Errorf("unmarshal controller state error %s", err.Error()) + return + } +} + +// GetState - +func GetState() ControllerState { + controllerStateLock.RLock() + defer controllerStateLock.RUnlock() + return *controllerState +} + +// SetState - +func SetState(byte []byte) error { + return json.Unmarshal(byte, controllerState) +} + +// GetStateByte is used to obtain the local state +func GetStateByte() ([]byte, error) { + if controllerHandlerQueue == nil { + return nil, fmt.Errorf("controllerHandlerQueue is not initialized") + } + controllerStateLock.RLock() + defer controllerStateLock.RUnlock() + stateBytes, err := controllerHandlerQueue.GetState(stateKey) + if err != nil { + return nil, err + } + log.GetLogger().Debugf("get state from etcd controllerState: %v", string(stateBytes)) + return stateBytes, nil +} + +// DeleteStateByte - +func DeleteStateByte() error { + if controllerHandlerQueue == nil { + return fmt.Errorf("controllerHandlerQueue is not initialized") + } + controllerStateLock.Lock() + defer controllerStateLock.Unlock() + return controllerHandlerQueue.DeleteState(allFaaSStateKey) +} + +func updateState(value interface{}, tags ...string) { + if controllerHandlerQueue == nil { + log.GetLogger().Errorf("controller state controllerHandlerQueue is nil") + return + } + controllerStateLock.Lock() + defer controllerStateLock.Unlock() + switch v := value.(type) { + case types.Config: + controllerState.FaaSControllerConfig = v + log.GetLogger().Infof("update controller state for controller config") + case string: + if len(tags) != faasInstanceTagsLen { + log.GetLogger().Errorf("failed to operate the FaasInstance, tags length: %d", len(tags)) + return + } + if tags[faasInstanceTagsOptIndex] == types.StateUpdate { + controllerState.FaasInstance[tags[faasInstanceTagsTypeIndex]][v] = struct{}{} + } else if tags[faasInstanceTagsOptIndex] == types.StateDelete { + delete(controllerState.FaasInstance[tags[faasInstanceTagsTypeIndex]], v) + } else { + log.GetLogger().Errorf("failed to operate the FaasInstance, opt is error %s", tags[0]) + return + } + default: + log.GetLogger().Warnf("unknown data type for ControllerState") + return + } + + state, err := json.Marshal(controllerState) + if err != nil { + log.GetLogger().Errorf("get controller state error %s", err.Error()) + return + } + if err = controllerHandlerQueue.SaveState(state, stateKey); err != nil { + log.GetLogger().Errorf("save controller state error %s", err.Error()) + } + log.GetLogger().Infof("update controller state successfully") +} + +// Update is used to write controller state to the cache queue +func Update(value interface{}, tags ...string) { + if controllerHandlerQueue == nil { + log.GetLogger().Errorf("controller state controllerHandlerQueue is nil") + return + } + if err := controllerHandlerQueue.Push(value, tags...); err != nil { + log.GetLogger().Errorf("failed to push state to state queue: %s", err.Error()) + } +} diff --git a/yuanrong/pkg/system_function_controller/state/state_test.go b/yuanrong/pkg/system_function_controller/state/state_test.go new file mode 100644 index 0000000..953af4d --- /dev/null +++ b/yuanrong/pkg/system_function_controller/state/state_test.go @@ -0,0 +1,130 @@ +package state + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/state" + "yuanrong/pkg/system_function_controller/types" +) + +func TestGetState(t *testing.T) { + conf := []byte(`{"FaaSControllerConfig":{"frontendInstanceNum":100,"schedulerInstanceNum":100,"faasschedulerConfig":{ "autoScaleConfig":{"slaQuota": 1000,"scaleDownTime": 60000,"burstScaleNum": 1000},"leaseSpan":600000},"faasfrontendConfig":{"slaQuota":1000,"functionCapability":1,"authenticationEnable":false,"trafficLimitDisable":true,"instanceNum":0,"http":{"resptimeout":5,"workerInstanceReadTimeOut":5,"maxRequestBodySize":6},"metaEtcd":{"servers":null,"user":"","password":"","sslEnable":false,"CaFile":"","CertFile":"","KeyFile":""},"routerEtcd":{"servers":null,"user":"","password":"","sslEnable":false,"CaFile":"","CertFile":"","KeyFile":""}},"routerEtcd":{"servers":["1.2.3.4:1234"],"user":"tom","password":"**","sslEnable":false,"CaFile":"","CertFile":"","KeyFile":""},"metaEtcd":{"servers":["1.2.3.4:5678"],"user":"tom","password":"**","sslEnable":false,"CaFile":"","CertFile":"","KeyFile":""},"tlsConfig":{"caContent":"","keyContent":"","certContent":""}},"FaasInstance":{}}`) + SetState(conf) + _ = json.Unmarshal(conf, controllerState) + tests := []struct { + name string + want *ControllerState + wantErr bool + }{ + { + name: "get state", + want: controllerState, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetState() + if reflect.DeepEqual(got, tt.want) { + t.Errorf("GetState() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_updateState(t *testing.T) { + convey.Convey("updateState", t, func() { + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer gomonkey.ApplyFunc(etcd3.GetRouterEtcdClient, func() *etcd3.EtcdClient { + return &etcd3.EtcdClient{} + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&state.Queue{}), "SaveState", + func(q *state.Queue, state []byte, key string) error { + return nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(controllerHandlerQueue), "GetState", + func(q *state.Queue, key string) ([]byte, error) { + return nil, fmt.Errorf("etcd not init") + }).Reset() + config := types.Config{} + controllerState.FaaSControllerConfig = config + convey.Convey("nil controllerHandlerQueue", func() { + updateState(types.Config{ManagerInstanceNum: 1}) + convey.So(GetState().FaaSControllerConfig.ManagerInstanceNum, convey.ShouldEqual, 0) + }) + InitState(nil) + convey.Convey("config", func() { + method := gomonkey.ApplyMethod(reflect.TypeOf(&state.Queue{}), "SaveState", func(_ *state.Queue, state []byte) error { + return nil + }) + defer method.Reset() + updateState(types.Config{ManagerInstanceNum: 1}) + + state := GetState() + convey.So(state.FaaSControllerConfig.ManagerInstanceNum, convey.ShouldEqual, 1) + }) + + convey.Convey("string", func() { + updateState("instanceID", "wrong tag") + state := GetState() + convey.So(state.FaaSControllerConfig.ManagerInstanceNum, convey.ShouldEqual, 0) + + Update("instanceID", types.StateUpdate, types.FaasFrontendInstanceState) + time.Sleep(50 * time.Millisecond) + convey.So(len(GetState().FaasInstance[types.FaasFrontendInstanceState]), convey.ShouldEqual, 1) + + Update("instanceID", types.StateDelete, types.FaasFrontendInstanceState) + time.Sleep(50 * time.Millisecond) + convey.So(len(GetState().FaasInstance[types.FaasFrontendInstanceState]), convey.ShouldEqual, 0) + + Update("instanceID", "wrong opt", types.FaasFrontendInstanceState) + time.Sleep(50 * time.Millisecond) + convey.So(len(GetState().FaasInstance[types.FaasFrontendInstanceState]), convey.ShouldEqual, 0) + Update(123) + }) + }) + +} + +func TestGetStateByte(t *testing.T) { + convey.Convey("GetStateByte", t, func() { + stateBytes := []byte("123") + defer gomonkey.ApplyFunc(state.NewStateQueue, func(size int) *state.Queue { + return &state.Queue{} + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(&state.Queue{}), "GetState", + func(q *state.Queue, key string) ([]byte, error) { + return stateBytes, nil + }).Reset() + InitState(nil) + stateByte, err := GetStateByte() + convey.So(err, convey.ShouldBeNil) + convey.So(string(stateByte), convey.ShouldEqual, "123") + }) +} + +func TestDeleteStateByte(t *testing.T) { + convey.Convey("delete state byte test", t, func() { + p := gomonkey.ApplyFunc(state.NewStateQueue, func(size int) *state.Queue { + return &state.Queue{} + }) + defer p.Reset() + p2 := gomonkey.ApplyFunc((*state.Queue).DeleteState, func(_ *state.Queue, key string) error { + return nil + }) + defer p2.Reset() + InitState(nil) + err := DeleteStateByte() + convey.So(err, convey.ShouldBeNil) + }) +} diff --git a/yuanrong/pkg/system_function_controller/types/types.go b/yuanrong/pkg/system_function_controller/types/types.go new file mode 100644 index 0000000..5af9862 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/types/types.go @@ -0,0 +1,282 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package types - +package types + +import ( + "context" + "sync" + "time" + + "yuanrong.org/kernel/runtime/libruntime/api" + + "yuanrong/pkg/common/faas_common/alarm" + "yuanrong/pkg/common/faas_common/config" + "yuanrong/pkg/common/faas_common/crypto" + "yuanrong/pkg/common/faas_common/etcd3" + "yuanrong/pkg/common/faas_common/sts/raw" + "yuanrong/pkg/common/faas_common/types" + faasfrontendconf "yuanrong/pkg/frontend/types" + faasschedulerconf "yuanrong/pkg/functionscaler/types" +) + +// Config defines FaaS Controller configuration +type Config struct { + ManagerInstanceNum int `json:"managerInstanceNum"` + EnableRetry bool `json:"enableRetry"` + NameSpace string `json:"nameSpace"` + ClusterID string `json:"clusterID" valid:"optional"` + ClusterName string `json:"clusterName" valid:"optional"` + RegionName string `json:"regionName" valid:"optional"` + AlarmConfig alarm.Config `json:"alarmConfig" valid:"optional"` + RouterETCD etcd3.EtcdConfig `json:"routerEtcd" valid:"required"` + MetaETCD etcd3.EtcdConfig `json:"metaEtcd"` + TLSConfig config.TLSConfig `json:"tlsConfig"` + RawStsConfig raw.StsConfig `json:"rawStsConfig,omitempty"` + SccConfig crypto.SccConfig `json:"sccConfig"` + SchedulerExclusivity []string `json:"schedulerExclusivity" valid:"optional"` +} + +// SchedulerConfig defines configuration faas scheduler needs +type SchedulerConfig struct { + faasschedulerconf.Configuration + SchedulerNum int `json:"schedulerNum"` +} + +// FrontendConfig defines configuration faas frontend needs +type FrontendConfig struct { + faasfrontendconf.Config +} + +// ManagerConfig defines configuration faas manager needs +type ManagerConfig struct { + CPU float64 `json:"cpu"` + Memory float64 `json:"memory"` + Image string `json:"image"` + Version string `json:"version"` + NodeSelector map[string]string `json:"nodeSelector,omitempty"` + ManagerInstanceNum int `json:"managerInstanceNum"` + LeaseRenewMinute int `json:"leaseRenewMinute" valid:"optional"` + RouterEtcd etcd3.EtcdConfig `json:"routerEtcd" valid:"optional"` + MetaEtcd etcd3.EtcdConfig `json:"metaEtcd" valid:"optional"` + SccConfig crypto.SccConfig `json:"sccConfig" valid:"optional"` + AlarmConfig alarm.Config `json:"alarmConfig" valid:"optional"` + Affinity string `json:"affinity"` + NodeAffinity string `json:"nodeAffinity" valid:"optional"` + NodeAffinityPolicy string `json:"nodeAffinityPolicy" valid:"optional"` +} + +const ( + // CreatedTimeout max timeout of creating system function, kernel timeouts 120s + CreatedTimeout = 15*time.Minute - 10*time.Second + // KillSignalVal signal of kill + KillSignalVal = 1 + // SyncKillSignalVal Synchronize signal of kill + SyncKillSignalVal = 3 + // PreserveMetaKillSignalVal Preserve instance/app MeataData after sending the signal of kill + PreserveMetaKillSignalVal = 5 + // FaaSSchedulerFunctionKey etcd key of faaS scheduler + FaaSSchedulerFunctionKey = "0/0-system-faasscheduler/$latest" + // FaaSSchedulerPrefixKey prefix etcd key of faaS scheduler + FaaSSchedulerPrefixKey = "/sn/instance/business/yrk/tenant/0/function/" + + "0-system-faasscheduler/version" + // FasSFrontendFunctionKey etcd key of faaS frontend + FasSFrontendFunctionKey = "0/0-system-faasfrontend/$latest" + // FasSFrontendPrefixKey prefix etcd key of faaS frontend + FasSFrontendPrefixKey = "/sn/instance/business/yrk/tenant/0/function/" + + "0-system-faasfrontend/version/" + // FasSManagerFunctionKey etcd key of faaS manager + FasSManagerFunctionKey = "0/0-system-faasmanager/$latest" + // FasSManagerPrefixKey prefix etcd key of faaS manager + FasSManagerPrefixKey = "/sn/instance/business/yrk/tenant/0/function/" + + "0-system-faasmanager/version/$latest/" +) + +const ( + // SubEventTypeUpdate is update type of subscribe event + SubEventTypeUpdate EventType = "update" + // SubEventTypeDelete is delete type of subscribe event + SubEventTypeDelete EventType = "delete" + // SubEventTypeRecover is recover type of subscribe event + SubEventTypeRecover EventType = "recover" +) + +const ( + // EventKindInvalid is the wrong kind of function registry + EventKindInvalid EventKind = iota + // EventKindFrontend is the type of frontend registry + EventKindFrontend + // EventKindScheduler is the type of scheduler registry + EventKindScheduler + // EventKindManager is the type of function registry + EventKindManager +) + +type ( + // EventType defines registry event type + EventType string + // EventKind defines registry event kind + EventKind uint8 +) + +// SubEvent contains event published to subscribers +type SubEvent struct { + EventType + EventKind + EventMsg interface{} +} + +// SystemFunctionCreator is the creator interface of system function +type SystemFunctionCreator interface { + CreateExpectedInstanceCount(ctx context.Context) error + CreateMultiInstances(ctx context.Context, count int) error + CreateWithRetry(ctx context.Context, args []*api.Arg, extraParams *types.ExtraParams) error + CreateInstance(ctx context.Context, function string, args []*api.Arg, extraParams *types.ExtraParams) string + RollingUpdate(ctx context.Context, event *ConfigChangeEvent) +} + +// SystemFunctionGetter is the getter interface of system function +type SystemFunctionGetter interface { + GetInstanceCountFromEtcd() map[string]struct{} + GetInstanceCache() map[string]*InstanceSpecification +} + +// SystemFunctionKiller is the kill interface of system function +type SystemFunctionKiller interface { + SyncKillAllInstance() + KillInstance(instanceID string) error +} + +// SystemFunctionRestorer is the recover interface of system function +type SystemFunctionRestorer interface { + RecoverInstance(info *InstanceSpecification) +} + +// SystemFunctionHandler is the handle interface of system function +type SystemFunctionHandler interface { + HandleInstanceUpdate(instanceSpec *InstanceSpecification) + HandleInstanceDelete(instanceSpec *InstanceSpecification) +} + +// SystemFunction is group by system function interfaces +type SystemFunction interface { + SystemFunctionCreator + SystemFunctionGetter + SystemFunctionKiller + SystemFunctionRestorer + SystemFunctionHandler +} + +// InstanceSpecification contains specification of an instance +type InstanceSpecification struct { + FuncCtx context.Context + CancelFunc context.CancelFunc + InstanceID string + InstanceSpecificationMeta InstanceSpecificationMeta +} + +// InstanceSpecificationMeta contains specification meta of a faas scheduler +type InstanceSpecificationMeta struct { + InstanceID string `json:"instanceID"` + RequestID string `json:"requestID"` + RuntimeID string `json:"runtimeID"` + RuntimeAddress string `json:"runtimeAddress"` + FunctionAgentID string `json:"functionAgentID"` + FunctionProxyID string `json:"functionProxyID"` + Function string `json:"function"` + RestartPolicy string `json:"restartPolicy"` + Resources Resources `json:"resources" valid:",optional"` + ScheduleOption ScheduleOption `json:"scheduleOption"` + CreateOptions map[string]string `json:"createOptions"` + StartTime string `json:"startTime"` + InstanceStatus InstanceStatus `json:"instanceStatus"` + Labels []string `json:"labels"` + JobID string `json:"jobID"` + SchedulerChain []string `json:"schedulerChain"` + Args []map[string]string `json:"args"` +} + +// Resources contains resource specification of a scheduler instance in etcd +type Resources struct { + Resources map[string]Resource `json:"resources"` +} + +// Resource is system function resource +type Resource struct { + name string `json:"name"` + Scalar Scalar `json:"scalar"` +} + +// Scalar is system function scalar +type Scalar struct { + Value int `json:"value"` + Limit int `json:"limit"` +} + +// ScheduleOption is system function scheduleOption +type ScheduleOption struct { + SchedPolicyName string `json:"schedPolicyName"` + Priority int `json:"priority"` + Affinity Affinity `json:"affinity"` +} + +// Affinity is system function Affinity +type Affinity struct { + InstanceAffinity InstanceAffinity `json:"instanceAffinity"` + InstanceAntiAffinity InstanceAffinity `json:"instanceAntiAffinity"` + NodeAffinity NodeAffinity `json:"nodeAffinity"` +} + +// NodeAffinity is system function nodeAffinity +type NodeAffinity struct { + Affinity map[string]string `json:"affinity"` +} + +// InstanceAffinity is system function instanceAffinity +type InstanceAffinity struct { + Affinity map[string]string `json:"affinity"` +} + +// InstanceStatus is system function InstanceStatus +type InstanceStatus struct { + Code int `json:"code"` + Msg string `json:"msg"` +} + +// ConfigChangeEvent config change event +type ConfigChangeEvent struct { + FrontendCfg *FrontendConfig + SchedulerCfg *SchedulerConfig + ManagerCfg *ManagerConfig + Error error + sync.WaitGroup + TraceID string +} + +const ( + // StateUpdate - + StateUpdate = "update" + // StateDelete - + StateDelete = "delete" + + // FaasFrontendInstanceState - + FaasFrontendInstanceState = "FaasFrontend" + // FaasSchedulerInstanceState - + FaasSchedulerInstanceState = "FaasScheduler" + // FaasManagerInstanceState - + FaasManagerInstanceState = "FaasManager" +) diff --git a/yuanrong/pkg/system_function_controller/utils/utils.go b/yuanrong/pkg/system_function_controller/utils/utils.go new file mode 100644 index 0000000..dc75256 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/utils/utils.go @@ -0,0 +1,139 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package utils - +package utils + +import ( + "encoding/json" + "strings" + "time" + + "yuanrong/pkg/common/faas_common/constant" + "yuanrong/pkg/common/faas_common/logger/log" + "yuanrong/pkg/common/faas_common/utils" + "yuanrong/pkg/system_function_controller/types" +) + +const ( + // ResourcesCPU - + ResourcesCPU = "CPU" + // ResourcesMemory - + ResourcesMemory = "Memory" +) + +// ExtractInfoFromEtcdKey extracts infomation from etcdKey +func ExtractInfoFromEtcdKey(etcdKey string, index int) string { + items := strings.Split(etcdKey, constant.KeySeparator) + if len(items) != constant.ValidEtcdKeyLenForInstance { + return "" + } + if index > len(items)-1 { + log.GetLogger().Errorf("index is out of range") + return "" + } + return items[index] +} + +// ExtractInstanceIDFromEtcdKey will extract instanceID from etcdKey +func ExtractInstanceIDFromEtcdKey(etcdKey string) string { + items := strings.Split(etcdKey, constant.KeySeparator) + if len(items) != constant.ValidEtcdKeyLenForInstance { + return "" + } + return items[len(items)-1] +} + +// GetInstanceSpecFromEtcdValue will get schedulerSpec from etcd value +func GetInstanceSpecFromEtcdValue(etcdValue []byte) *types.InstanceSpecification { + specMeta := &types.InstanceSpecificationMeta{} + err := json.Unmarshal(etcdValue, specMeta) + if err != nil { + log.GetLogger().Errorf("failed to unmarshal etcd value to function specification %s", err.Error()) + return nil + } + spec := &types.InstanceSpecification{InstanceSpecificationMeta: *specMeta} + return spec +} + +// GetMinSleepTime - +func GetMinSleepTime(defaultSleepTime, maxSleepTime time.Duration) time.Duration { + if defaultSleepTime.Seconds() <= maxSleepTime.Seconds() { + return defaultSleepTime + } + return maxSleepTime +} + +// CheckNeedRecover - +func CheckNeedRecover(newInstanceMeta types.InstanceSpecificationMeta) bool { + // instance status from running to fatal + return newInstanceMeta.InstanceStatus.Code == int(constant.KernelInstanceStatusFatal) || + newInstanceMeta.InstanceStatus.Code == int(constant.KernelInstanceStatusEvicting) || + newInstanceMeta.InstanceStatus.Code == int(constant.KernelInstanceStatusScheduleFailed) || + newInstanceMeta.InstanceStatus.Code == int(constant.KernelInstanceStatusEvicted) +} + +// GetSchedulerConfigSignature - +func GetSchedulerConfigSignature(schedulerCfg *types.SchedulerConfig) string { + var cfg types.SchedulerConfig + err := utils.DeepCopyObj(schedulerCfg, &cfg) + if err != nil { + log.GetLogger().Warnf("Failed to copy scheduler config: %v", err) + return "" + } + cfg.SchedulerNum = 0 + cfgStr, err := json.Marshal(cfg) + if err != nil { + log.GetLogger().Errorf("Failed to marshal scheduler config: %v", err) + return "" + } + h := utils.FnvHash(string(cfgStr)) + return h +} + +// GetFrontendConfigSignature - +func GetFrontendConfigSignature(frontendCfg *types.FrontendConfig) string { + var cfg types.FrontendConfig + err := utils.DeepCopyObj(frontendCfg, &cfg) + if err != nil { + log.GetLogger().Warnf("Failed to copy frontend config: %v", err) + return "" + } + cfg.InstanceNum = 0 + cfgStr, err := json.Marshal(cfg) + if err != nil { + log.GetLogger().Errorf("Failed to marshal frontend config: %v", err) + return "" + } + return utils.FnvHash(string(cfgStr)) +} + +// GetManagerConfigSignature - +func GetManagerConfigSignature(managerCfg *types.ManagerConfig) string { + var cfg types.ManagerConfig + err := utils.DeepCopyObj(managerCfg, &cfg) + if err != nil { + log.GetLogger().Warnf("Failed to copy manager config: %v", err) + return "" + } + cfg.ManagerInstanceNum = 0 + cfgStr, err := json.Marshal(cfg) + if err != nil { + log.GetLogger().Errorf("Failed to marshal manager config: %v", err) + return "" + } + return utils.FnvHash(string(cfgStr)) +} diff --git a/yuanrong/pkg/system_function_controller/utils/utils_test.go b/yuanrong/pkg/system_function_controller/utils/utils_test.go new file mode 100644 index 0000000..7cfa772 --- /dev/null +++ b/yuanrong/pkg/system_function_controller/utils/utils_test.go @@ -0,0 +1,92 @@ +package utils + +import ( + "encoding/json" + "errors" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + + "yuanrong/pkg/system_function_controller/types" +) + +func TestCheckNeedRecover(t *testing.T) { + convey.Convey("checkNeedRecover", t, func() { + var newInstanceMeta types.InstanceSpecificationMeta + newInstanceMeta.InstanceStatus.Code = 6 + convey.So(CheckNeedRecover(newInstanceMeta), convey.ShouldBeTrue) + newInstanceMeta.InstanceStatus.Code = 7 + convey.So(CheckNeedRecover(newInstanceMeta), convey.ShouldBeTrue) + newInstanceMeta.InstanceStatus.Code = 10 + convey.So(CheckNeedRecover(newInstanceMeta), convey.ShouldBeTrue) + newInstanceMeta.InstanceStatus.Code = 5 + convey.So(CheckNeedRecover(newInstanceMeta), convey.ShouldBeFalse) + }) +} + +func TestExtractInfoFromEtcdKey(t *testing.T) { + convey.Convey("extractInfoFromEtcdKey", t, func() { + convey.So(ExtractInfoFromEtcdKey("/a/b/c", 10), convey.ShouldBeEmpty) + }) +} + +func TestGetMinSleepTime(t *testing.T) { + convey.Convey("getMinSleepTime", t, func() { + time1 := time.Duration(9) * time.Second + time2 := time.Duration(11) * time.Second + maxTime := time.Duration(10) * time.Second + convey.So(GetMinSleepTime(time1, maxTime), convey.ShouldEqual, time1) + convey.So(GetMinSleepTime(time2, maxTime), convey.ShouldEqual, maxTime) + }) +} + +func TestGetSchedulerConfigSignature(t *testing.T) { + convey.Convey("getSchedulerConfigSignature", t, func() { + outputs := []gomonkey.OutputCell{ + {Values: gomonkey.Params{nil, errors.New("err")}}, + {Values: gomonkey.Params{[]byte{'{', '}'}, nil}}, + {Values: gomonkey.Params{nil, errors.New("err")}}} + patch := gomonkey.ApplyFuncSeq(json.Marshal, outputs) + defer patch.Reset() + convey.So(GetSchedulerConfigSignature(nil), convey.ShouldBeEmpty) + faaSSchedulerConfig := &types.SchedulerConfig{} + convey.So(GetSchedulerConfigSignature(faaSSchedulerConfig), convey.ShouldBeEmpty) + }) +} + +func TestGetFrontendConfigSignature(t *testing.T) { + convey.Convey("getFrontendConfigSignature", t, func() { + outputs := []gomonkey.OutputCell{ + {Values: gomonkey.Params{nil, errors.New("err")}}, + {Values: gomonkey.Params{[]byte{'{', '}'}, nil}}, + {Values: gomonkey.Params{nil, errors.New("err")}}} + patch := gomonkey.ApplyFuncSeq(json.Marshal, outputs) + defer patch.Reset() + convey.So(GetFrontendConfigSignature(nil), convey.ShouldBeEmpty) + frontendConfig := &types.FrontendConfig{} + convey.So(GetFrontendConfigSignature(frontendConfig), convey.ShouldBeEmpty) + }) +} + +func TestGetManagerConfigSignature(t *testing.T) { + convey.Convey("getManagerConfigSignature", t, func() { + outputs := []gomonkey.OutputCell{ + {Values: gomonkey.Params{nil, errors.New("err")}}, + {Values: gomonkey.Params{[]byte{'{', '}'}, nil}}, + {Values: gomonkey.Params{nil, errors.New("err")}}} + patch := gomonkey.ApplyFuncSeq(json.Marshal, outputs) + defer patch.Reset() + convey.So(GetManagerConfigSignature(nil), convey.ShouldBeEmpty) + frontendConfig := &types.ManagerConfig{} + convey.So(GetManagerConfigSignature(frontendConfig), convey.ShouldBeEmpty) + }) +} + +func TestGetInstanceSpecFromEtcdValue(t *testing.T) { + convey.Convey("getInstanceSpecFromEtcdValue", t, func() { + spec := GetInstanceSpecFromEtcdValue([]byte("{\"instanceID\":\"bfc2c7a6-0f26-42b1-9f36-72c0e47b8daf\",\"requestID\":\"113f09375917969b00\",\"functionAgentID\":\"function-agent-72c0e47b8daf-500m-500mi-faasscheduler-dc310000ef\",\"functionProxyID\":\"dggphis190720\",\"function\":\"0/0-system-faasscheduler/$latest\",\"resources\":{\"resources\":{\"Memory\":{\"name\":\"Memory\",\"scalar\":{\"value\":500}},\"CPU\":{\"name\":\"CPU\",\"scalar\":{\"value\":500}}}},\"scheduleOption\":{\"schedPolicyName\":\"monopoly\",\"affinity\":{\"instanceAffinity\":{},\"resource\":{},\"instance\":{\"topologyKey\":\"node\"}},\"resourceSelector\":{\"resource.owner\":\"08d513ba-14f1-4cf0-b400-00000000009a\"},\"extension\":{\"schedule_policy\":\"monopoly\",\"DELEGATE_DIRECTORY_QUOTA\":\"512\"},\"range\":{}},\"createOptions\":{\"resource.owner\":\"system\",\"tenantId\":\"\",\"lifecycle\":\"detached\",\"DELEGATE_POD_LABELS\":\"{\\\"systemFuncName\\\":\\\"faasscheduler\\\"}\",\"RecoverRetryTimes\":\"0\",\"DELEGATE_RUNTIME_MANAGER\":\"{\\\"image\\\":\\\"\\\"}\",\"DELEGATE_DIRECTORY_QUOTA\":\"512\",\"DELEGATE_ENCRYPT\":\"{\\\"metaEtcdPwd\\\":\\\"\\\"}\",\"schedule_policy\":\"monopoly\",\"ConcurrentNum\":\"32\",\"DATA_AFFINITY_ENABLED\":\"false\",\"DELEGATE_NODE_AFFINITY\":\"{\\\"preferredDuringSchedulingIgnoredDuringExecution\\\":[{\\\"preference\\\":{\\\"matchExpressions\\\":[{\\\"key\\\":\\\"node-type\\\",\\\"operator\\\":\\\"In\\\",\\\"values\\\":[\\\"system\\\"]}]},\\\"weight\\\":1}]}\"},\"labels\":[\"faasscheduler\"],\"instanceStatus\":{\"code\":2,\"msg\":\"creating\"},\"jobID\":\"job-12345678\",\"schedulerChain\":[\"function-agent-72c0e47b8daf-500m-500mi-faasscheduler-dc310000ef\"],\"parentID\":\"0-system-faascontroller-0\",\"parentFunctionProxyAID\":\"dggphis190721-LocalSchedInstanceCtrlActor@10.28.83.232:22772\",\"storageType\":\"local\",\"scheduleTimes\":1,\"deployTimes\":1,\"args\":[{\"value\":\"eyJzY2VuYXJpbyI6IiIsImNwdSI6MTAwMCwibWVtb3J5Ijo0MDAwLCJwcmVkaWN0R3JvdXBXaW5kb3ciOjAsInNsYVF1b3RhIjoxMDAwLCJzY2FsZURvd25UaW1lIjo2MDAwMCwiYnVyc3RTY2FsZU51bSI6MTAwMCwibGVhc2VTcGFuIjoxMDAwLCJmdW5jdGlvbkxpbWl0UmF0ZSI6NTAwMCwicm91dGVyRXRjZCI6eyJzZXJ2ZXJzIjpbImRzLWNvcmUtZXRjZC5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsOjIzNzkiXSwidXNlciI6IiIsInBhc3N3b3JkIjoiIiwic3NsRW5hYmxlIjpmYWxzZSwiYXV0aFR5cGUiOiJOb2F1dGgiLCJ1c2VTZWNyZXQiOmZhbHNlLCJzZWNyZXROYW1lIjoiZXRjZC1jbGllbnQtc2VjcmV0IiwiQ2FGaWxlIjoiIiwiQ2VydEZpbGUiOiIiLCJLZXlGaWxlIjoiIiwiUGFzc3BocmFzZUZpbGUiOiIifSwibWV0YUV0Y2QiOnsic2VydmVycyI6WyJkcy1jb3JlLWV0Y2QuZGVmYXVsdC5zdmMuY2x1c3Rlci5sb2NhbDoyMzc5Il0sInVzZXIiOiIiLCJwYXNzd29yZCI6IiIsInNzbEVuYWJsZSI6ZmFsc2UsImF1dGhUeXBlIjoiTm9hdXRoIiwidXNlU2VjcmV0IjpmYWxzZSwic2VjcmV0TmFtZSI6ImV0Y2QtY2xpZW50LXNlY3JldCIsIkNhRmlsZSI6IiIsIkNlcnRGaWxlIjoiIiwiS2V5RmlsZSI6IiIsIlBhc3NwaHJhc2VGaWxlIjoiIn0sImRvY2tlclJvb3RQYXRoIjoiL3Zhci9saWIvZG9ja2VyIiwicmF3U3RzQ29uZmlnIjp7InNlbnNpdGl2ZUNvbmZpZ3MiOnsic2hhcmVLZXlzIjpudWxsfSwic2VydmVyQ29uZmlnIjp7InBhdGgiOiIvb3B0L2h1YXdlaS9jZXJ0cy9ITVNDbGllbnRDbG91ZEFjY2VsZXJhdGVTZXJ2aWNlL0hNU0NhYVNZdWFuUm9uZ1dvcmtlci9ITVNDYWFzWXVhblJvbmdXb3JrZXIuaW5pIn0sIm1nbXRTZXJ2ZXJDb25maWciOnt9fSwiY2x1c3RlcklEIjoiY2x1c3RlcjAwMSIsImNsdXN0ZXJOYW1lIjoiZHN3ZWJfY2NldHVyYm9fYmo0X2F1dG9fYXoxIiwiZGlza01vbml0b3JFbmFibGUiOmZhbHNlLCJyZWdpb25OYW1lIjoiYmVpamluZzQiLCJhbGFybUNvbmZpZyI6eyJlbmFibGVBbGFybSI6ZmFsc2UsImFsYXJtTG9nQ29uZmlnIjp7ImZpbGVwYXRoIjoiL29wdC9odWF3ZWkvbG9ncy9hbGFybXMiLCJsZXZlbCI6IkluZm8iLCJ0aWNrIjowLCJmaXJzdCI6MCwidGhlcmVhZnRlciI6MCwidHJhY2luZyI6ZmFsc2UsImRpc2FibGUiOmZhbHNlLCJzaW5nbGVzaXplIjo1MDAsInRocmVzaG9sZCI6M30sInhpYW5nWXVuRm91ckNvbmZpZyI6eyJzaXRlIjoiY25fZGV2X2RlZmF1bHQiLCJ0ZW5hbnRJRCI6IlQwMTQiLCJhcHBsaWNhdGlvbklEIjoiY29tLmh1YXdlaS5jbG91ZF9lbmhhbmNlX2RldmljZSIsInNlcnZpY2VJRCI6ImNvbS5odWF3ZWkuaG1zY29yZWNhbWVyYWNsb3VkZW5oYW5jZXNlcnZpY2UifSwibWluSW5zU3RhcnRJbnRlcnZhbCI6MTUsIm1pbkluc0NoZWNrSW50ZXJ2YWwiOjE1fSwiZXBoZW1lcmFsU3RvcmFnZSI6NTEyLCJob3N0YWxpYXNlc2hvc3RuYW1lIjpudWxsLCJmdW5jdGlvbkNvbmZpZyI6bnVsbCwibG9jYWxBdXRoIjp7ImFLZXkiOiIiLCJzS2V5IjoiIiwiZHVyYXRpb24iOjB9LCJjb25jdXJyZW50TnVtIjowLCJ2ZXJzaW9uIjoiIiwiaW1hZ2UiOiIiLCJuYW1lU3BhY2UiOiJkZWZhdWx0Iiwic2NjQ29uZmlnIjp7ImVuYWJsZSI6ZmFsc2UsInNlY3JldE5hbWUiOiJzY2Mta3Mtc2VjcmV0IiwiYWxnb3JpdGhtIjoiQUVTMjU2X0dDTSJ9LCJub2RlQWZmaW5pdHkiOiJ7XCJwcmVmZXJyZWREdXJpbmdTY2hlZHVsaW5nSWdub3JlZER1cmluZ0V4ZWN1dGlvblwiOlt7XCJwcmVmZXJlbmNlXCI6e1wibWF0Y2hFeHByZXNzaW9uc1wiOlt7XCJrZXlcIjpcIm5vZGUtdHlwZVwiLFwib3BlcmF0b3JcIjpcIkluXCIsXCJ2YWx1ZXNcIjpbXCJzeXN0ZW1cIl19XX0sXCJ3ZWlnaHRcIjoxfV19Iiwic2NoZWR1bGVyTnVtIjoyfQ==\"}],\"version\":\"2\",\"dataSystemHost\":\"7.185.111.125\",\"detached\":true,\"gracefulShutdownTime\":\"600\",\"tenantID\":\"0\",\"isSystemFunc\":true}")) + convey.So(spec, convey.ShouldNotBeNil) + }) +} diff --git a/yuanrong/proto/CMakeLists.txt b/yuanrong/proto/CMakeLists.txt new file mode 100644 index 0000000..9cd384a --- /dev/null +++ b/yuanrong/proto/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + +set(GRPC_PROTO_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/posix/common.proto" + "${CMAKE_CURRENT_SOURCE_DIR}/posix/core_service.proto" + "${CMAKE_CURRENT_SOURCE_DIR}/posix/runtime_rpc.proto" + "${CMAKE_CURRENT_SOURCE_DIR}/posix/runtime_service.proto" + "${CMAKE_CURRENT_SOURCE_DIR}/posix/affinity.proto" + "${CMAKE_CURRENT_SOURCE_DIR}/posix/inner_service.proto" + "${CMAKE_CURRENT_SOURCE_DIR}/posix/bus_service.proto" + "${CMAKE_CURRENT_SOURCE_DIR}/posix/message.proto" + "${CMAKE_CURRENT_SOURCE_DIR}/posix/resource.proto" + "${CMAKE_CURRENT_SOURCE_DIR}/posix/bus_adapter.proto") +GENERATE_GRPC_CPP(GRPCPB_SRCS GRPCPB_HDRS ${CMAKE_CURRENT_SOURCE_DIR}/pb/posix + PROTO_FILES ${GRPC_PROTO_SRCS} + SOURCE_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/posix) +add_library(posix_pb STATIC ${GRPCPB_SRCS}) +set_target_properties(posix_pb PROPERTIES UNITY_BUILD ON) +set_target_properties(posix_pb PROPERTIES UNITY_BUILD_BATCH_SIZE 2) +add_dependencies(posix_pb protobuf grpc) +target_link_libraries( + posix_pb + PRIVATE + dl + ${protobuf_LIB} + ${grpcpp_LIB} + ${grpc_LIB} + ${gpr_LIB}) \ No newline at end of file diff --git a/yuanrong/proto/pb/message_pb.h b/yuanrong/proto/pb/message_pb.h new file mode 100644 index 0000000..6f560c2 --- /dev/null +++ b/yuanrong/proto/pb/message_pb.h @@ -0,0 +1,177 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MESSAGE_PB_H +#define MESSAGE_PB_H + +#include "posix/message.pb.h" + +namespace functionsystem::messages { +using ResourceView = ::messages::ResourceView; +using ScheduleRequest = ::messages::ScheduleRequest; +using PluginContext = ::messages::PluginContext; +using ScheduleResponse = ::messages::ScheduleResponse; +using ScheduleResult = ::messages::ScheduleResult; + +using StaticFunctionChangeRequest = ::messages::StaticFunctionChangeRequest; +using StaticFunctionChangeResponse= ::messages::StaticFunctionChangeResponse; + +using Register = ::messages::Register; +using Registered = ::messages::Registered; +using ScheduleTopology = ::messages::ScheduleTopology; +using NotifySchedAbnormalRequest = ::messages::NotifySchedAbnormalRequest; +using NotifySchedAbnormalResponse = ::messages::NotifySchedAbnormalResponse; + +using Register = ::messages::Register; +using Registered = ::messages::Registered; +using ScheduleTopology = ::messages::ScheduleTopology; +using NotifySchedAbnormalRequest = ::messages::NotifySchedAbnormalRequest; + +using NotifyWorkerStatusRequest = ::messages::NotifyWorkerStatusRequest; +using NotifyWorkerStatusResponse = ::messages::NotifyWorkerStatusResponse; +using UpdateNodeTaintRequest = ::messages::UpdateNodeTaintRequest; +using UpdateNodeTaintResponse = ::messages::UpdateNodeTaintResponse; + +using CreateAgentRequest = ::messages::CreateAgentRequest; +using CreateAgentResponse = ::messages::CreateAgentResponse; + +using DeployInstanceRequest = ::messages::DeployInstanceRequest; +using DeployInstanceResponse = ::messages::DeployInstanceResponse; +using StartInstanceRequest = ::messages::StartInstanceRequest; +using StartInstanceResponse = ::messages::StartInstanceResponse; +using Layer = ::messages::Layer; +using FuncDeploySpec = ::messages::FuncDeploySpec; +using FuncMount = ::messages::FuncMount; +using FuncMountUser = ::messages::FuncMountUser; +using FuncMountConfig = ::messages::FuncMountConfig; +using DeployRequest = ::messages::DeployRequest; +using KillInstanceRequest = ::messages::KillInstanceRequest; +using KillInstanceResponse = ::messages::KillInstanceResponse; +using UpdateCredRequest = ::messages::UpdateCredRequest; +using UpdateCredResponse = ::messages::UpdateCredResponse; + +using RegisterRuntimeManagerRequest = ::messages::RegisterRuntimeManagerRequest; +using RegisterRuntimeManagerResponse = ::messages::RegisterRuntimeManagerResponse; + +using RuntimeInstanceInfo = ::messages::RuntimeInstanceInfo; +using RuntimeConfig = ::messages::RuntimeConfig; +using DeploymentConfig = ::messages::DeploymentConfig; +using CodePackageThresholds = ::messages::CodePackageThresholds; + +using DeployRequest = ::messages::DeployRequest; +using DeployResult = ::messages::DeployResult; +using DeployDuration = ::messages::DeployDuration; + +using StopInstanceRequest = ::messages::StopInstanceRequest; +using StopInstanceResponse = ::messages::StopInstanceResponse; + +using UpdateResourcesRequest = ::messages::UpdateResourcesRequest; +using UpdateInstanceStatusRequest = ::messages::UpdateInstanceStatusRequest; +using UpdateInstanceStatusResponse = ::messages::UpdateInstanceStatusResponse; + +using QueryInstanceStatusRequest = ::messages::QueryInstanceStatusRequest; +using QueryInstanceStatusResponse = ::messages::QueryInstanceStatusResponse; + +using InstanceStatusInfo = ::messages::InstanceStatusInfo; + +using UpdateAgentStatusRequest = ::messages::UpdateAgentStatusRequest; +using UpdateAgentStatusResponse = ::messages::UpdateAgentStatusResponse; + +using UpdateRuntimeStatusRequest = ::messages::UpdateRuntimeStatusRequest; +using UpdateRuntimeStatusResponse = ::messages::UpdateRuntimeStatusResponse; + +using SchedulerNode = ::messages::SchedulerNode; +using FuncAgentRegisInfo = ::messages::FuncAgentRegisInfo; +using FuncAgentRegisInfoCollection = ::messages::FuncAgentRegisInfoCollection; + +using CreateAgentRequest = ::messages::CreateAgentRequest; +using CreateAgentResponse = ::messages::CreateAgentResponse; + +using CleanStatusRequest = ::messages::CleanStatusRequest; +using CleanStatusResponse = ::messages::CleanStatusResponse; + +using ForwardKillRequest = ::messages::ForwardKillRequest; +using ForwardKillResponse = ::messages::ForwardKillResponse; + +using EvictAgentRequest = ::messages::EvictAgentRequest; +using EvictAgentAck = ::messages::EvictAgentAck; +using EvictAgentResult = ::messages::EvictAgentResult; +using EvictAgentResultAck = ::messages::EvictAgentResultAck; +using UpdateLocalStatusRequest = ::messages::UpdateLocalStatusRequest; +using UpdateLocalStatusResponse = ::messages::UpdateLocalStatusResponse; + +using QueryAgentInfoRequest = ::messages::QueryAgentInfoRequest; +using QueryAgentInfoResponse = ::messages::QueryAgentInfoResponse; +using ExternalAgentInfo = ::messages::ExternalAgentInfo; +using ExternalQueryAgentInfoResponse = ::messages::ExternalQueryAgentInfoResponse; +using FunctionSystemStatus = ::messages::FunctionSystemStatus; +using RuleType = ::messages::NetworkIsolationRuleType; +using SetNetworkIsolationRequest = ::messages::SetNetworkIsolationRequest; +using SetNetworkIsolationResponse = ::messages::SetNetworkIsolationResponse; +using QueryResourcesInfoRequest = ::messages::QueryResourcesInfoRequest; +using QueryResourcesInfoResponse = ::messages::QueryResourcesInfoResponse; + +using MetaStoreRequest = ::messages::MetaStoreRequest; +using MetaStoreResponse = ::messages::MetaStoreResponse; + +using GetAndWatchResponse = ::messages::GetAndWatchResponse; + +using GroupInfo = ::messages::GroupInfo; +using GroupResponse = ::messages::GroupResponse; +using KillGroup = ::messages::KillGroup; +using KillGroupResponse = ::messages::KillGroupResponse; +using DeletePodRequest = ::messages::DeletePodRequest; +using DeletePodResponse = ::messages::DeletePodResponse; + +using ResourceGroupInfo = ::messages::ResourceGroupInfo; +using BundleInfo = ::messages::BundleInfo; +using BundleCollection = ::messages::BundleCollection; +using RemoveBundleRequest = ::messages::RemoveBundleRequest; +using RemoveBundleResponse = ::messages::RemoveBundleResponse; + +using GetTokenRequest = ::messages::GetTokenRequest; +using GetTokenResponse = ::messages::GetTokenResponse; + +using GetAKSKByTenantIDRequest = ::messages::GetAKSKByTenantIDRequest; +using GetAKSKByAKRequest = ::messages::GetAKSKByAKRequest; +using GetAKSKResponse = ::messages::GetAKSKResponse; + +using CancelType = ::messages::CancelType; +using CancelSchedule = ::messages::CancelSchedule; +using CancelScheduleResponse = ::messages::CancelScheduleResponse; + +using QueryInstancesInfoRequest = ::messages::QueryInstancesInfoRequest; +using QueryInstancesInfoResponse = ::messages::QueryInstancesInfoResponse; + +using DebugInstanceInfo = ::messages::DebugInstanceInfo; +using QueryDebugInstanceInfosRequest = ::messages::QueryDebugInstanceInfosRequest; +using QueryDebugInstanceInfosResponse = ::messages::QueryDebugInstanceInfosResponse; + +using ReportAgentAbnormalRequest = ::messages::ReportAgentAbnormalRequest; +using ReportAgentAbnormalResponse = ::messages::ReportAgentAbnormalResponse; + +using CheckInstanceStateRequest = ::messages::CheckInstanceStateRequest; +using CheckInstanceStateResponse = ::messages::CheckInstanceStateResponse; + +namespace MetaStore { +using PutRequest = ::messages::PutRequest; +using PutResponse = ::messages::PutResponse; +using Lease = ::messages::Lease; +using ObserveResponse = ::messages::ObserveResponse; +using ObserveCancelRequest = ::messages::ObserveCancelRequest; +using ForwardWatchRequest = ::messages::ForwardWatchRequest; +} // namespace MetaStore +} // namespace functionsystem::messages +#endif \ No newline at end of file diff --git a/yuanrong/proto/pb/posix_pb.h b/yuanrong/proto/pb/posix_pb.h new file mode 100644 index 0000000..905b0c3 --- /dev/null +++ b/yuanrong/proto/pb/posix_pb.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POSIX_PB_H +#define POSIX_PB_H +#include "posix/core_service.pb.h" +#include "posix/runtime_service.pb.h" +#include "posix/inner_service.pb.h" +#include "posix/runtime_rpc.pb.h" +#include "posix/common.pb.h" +#include "posix/affinity.pb.h" + +namespace functionsystem { + using InvokeRequest = core_service::InvokeRequest; + using InvokeResponse = core_service::InvokeResponse; + using CreateRequest = core_service::CreateRequest; + using CreateResponse = core_service::CreateResponse; + using CallResult = core_service::CallResult; + using CallResultAck = core_service::CallResultAck; + using StateSaveRequest = core_service::StateSaveRequest; + using StateSaveResponse = core_service::StateSaveResponse; + using StateLoadRequest = core_service::StateLoadRequest; + using StateLoadResponse = core_service::StateLoadResponse; + using KillRequest = core_service::KillRequest; + using KillResponse = core_service::KillResponse; + using ExitRequest = core_service::ExitRequest; + using ExitResponse = core_service::ExitResponse; + using GroupOptions = core_service::GroupOptions; + using CreateRequests = core_service::CreateRequests; + using CreateResponses = core_service::CreateResponses; + using SharedStreamMsg = std::shared_ptr; + using InstanceRange = core_service::InstanceRange; + using CreateResourceGroupRequest = core_service::CreateResourceGroupRequest; + using CreateResourceGroupResponse = core_service::CreateResourceGroupResponse; +} + +namespace internal { + using ForwardCallRequest = inner_service::ForwardCallRequest; + using ForwardCallResponse = inner_service::ForwardCallResponse; + using ForwardCallResultRequest = inner_service::ForwardCallResultRequest; + using ForwardCallResultResponse = inner_service::ForwardCallResultResponse; + using ForwardKillRequest = inner_service::ForwardKillRequest; + using ForwardKillResponse = inner_service::ForwardKillResponse; + using RouteCallRequest = inner_service::RouteCallRequest; +} + +namespace runtime { + using CallRequest = runtime_service::CallRequest; + using CallResponse = runtime_service::CallResponse; + using NotifyRequest = runtime_service::NotifyRequest; + using NotifyResponse = runtime_service::NotifyResponse; + using SignalRequest = runtime_service::SignalRequest; + using SignalResponse = runtime_service::SignalResponse; + using ShutdownRequest = runtime_service::ShutdownRequest; + using ShutdownResponse = runtime_service::ShutdownResponse; + using HeartbeatRequest = runtime_service::HeartbeatRequest; + using HeartbeatResponse = runtime_service::HeartbeatResponse; + using CheckpointRequest = runtime_service::CheckpointRequest; + using CheckpointResponse = runtime_service::CheckpointResponse; + using RecoverRequest = runtime_service::RecoverRequest; + using RecoverResponse = runtime_service::RecoverResponse; +} + +namespace common { + using ErrorCode = common::ErrorCode; + using HealthCheckCode = common::HealthCheckCode; + using Arg = common::Arg; + using HeteroDeviceInfo = common::HeteroDeviceInfo; + using ServerInfo = common::ServerInfo; + using FunctionGroupRunningInfo = common::FunctionGroupRunningInfo; + using ResourceGroupSpec = common::ResourceGroupSpec; +} + +namespace affinity { + using LabelIn = affinity::LabelIn; + using LabelNotIn = affinity::LabelNotIn; + using LabelExists = affinity::LabelExists; + using LabelDoesNotExist = affinity::LabelDoesNotExist; + using LabelOperator = affinity::LabelOperator; + using SubCondition = affinity::SubCondition; + using Condition = affinity::Condition; + using Selector = affinity::Selector; + using AffinityType = affinity::AffinityType; + using AffinityScope = affinity::AffinityScope; + using ResourceAffinity = affinity::ResourceAffinity; + using InstanceAffinity = affinity::InstanceAffinity; + using Affinity = affinity::Affinity; +} + +#endif \ No newline at end of file diff --git a/yuanrong/proto/posix/affinity.proto b/yuanrong/proto/posix/affinity.proto new file mode 100644 index 0000000..0109b68 --- /dev/null +++ b/yuanrong/proto/posix/affinity.proto @@ -0,0 +1,98 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package affinity; + +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb/affinity;affinity"; + + +// IN: True if the candinate has label with desired label key and its label value is in the specified value set +message LabelIn { + repeated string values = 1; +} + +// NOT_IN: True if the candinate has label with desired label key and its label value is NOT in the specified value set +message LabelNotIn { + repeated string values = 1; +} + +// EXISTS: True if the candinate has label with desired label key +message LabelExists {} + +// DOES_NOT_EXIST: True if the candinate does't have any label with desired label key +message LabelDoesNotExist {} + +message LabelOperator { + oneof LabelOperator { + LabelIn in = 1; + LabelNotIn notIn = 2; + LabelExists exists = 3; + LabelDoesNotExist notExist = 4; + } +} + +message LabelExpression { + string key = 1; // label key + LabelOperator op = 2; +} + +message SubCondition { + repeated LabelExpression expressions = 1; // AND between expressions + int64 weight = 2; // weight of this sub condition for ranking +} + +message Condition { + repeated SubCondition subConditions = 1; // OR between sub conditions + bool orderPriority = 2; // in order of priority instead of weights rank +} + +message Selector { + Condition condition = 1; +} + +enum AffinityType { + PreferredAffinity = 0; + PreferredAntiAffinity = 1; + RequiredAffinity = 2; + RequiredAntiAffinity = 3; +} + +enum AffinityScope { + POD = 0; + NODE = 1; +} + +message ResourceAffinity { + Selector preferredAffinity = 1; + Selector preferredAntiAffinity = 2; + Selector requiredAffinity = 3; + Selector requiredAntiAffinity = 4; +} + +message InstanceAffinity { + Selector preferredAffinity = 1; + Selector preferredAntiAffinity = 2; + Selector requiredAffinity = 3; + Selector requiredAntiAffinity = 4; + AffinityScope scope = 5; +} + +message Affinity { + ResourceAffinity resource = 1; + InstanceAffinity instance = 2; +} diff --git a/yuanrong/proto/posix/bus_adapter.proto b/yuanrong/proto/posix/bus_adapter.proto new file mode 100644 index 0000000..b21d6f6 --- /dev/null +++ b/yuanrong/proto/posix/bus_adapter.proto @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package bus_adapter; +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb/busadapter;busadapter"; + +message QueryTerminatingLongRunPod { + string nodeIP = 1; +} + +message WorkerRouter { + string workerID = 1; + string workerAddr = 2; + string nodeIP = 3; + string functionKeyWithResource = 4; + bool isMinInstance = 5; + bool isDirectFunc = 6; + bool hasInitializer = 7; +} + +message QueryTerminatingLongRunPodResponse { + map workers = 1; + string nodeIP = 2; +} + +message AsyncResponse { + int32 statusCode = 1; + bytes body = 2; + map headers = 3; +} + +message CallResponse { + // Success: 0, Failed for others + uint32 ErrorCode = 1; + // Message for error + string ErrorMessage = 2; + // response data + bytes RawData = 3; + // logs + string Logs = 4; + // request ID + uint64 RequestID = 5; + // this summary counts init duration + string Summary = 6; + string FunctionKey = 7; + string InstanceID = 8; + map Labels = 9; + // DcCallStack distributed convergence call stack + bytes DcCallStack = 10; +} + +message PosixCallResult { + // Success: 0, Failed for others + uint32 ErrorCode = 1; + // Message for error + string ErrorMessage = 2; + // response data + bytes RawData = 3; + // logs + string Logs = 4; + // request ID + string RequestID = 5; + // this summary counts init duration + string Summary = 6; + string FunctionKey = 7; + string InstanceID = 8; + map Labels = 9; + // DcCallStack distributed convergence call stack + bytes DcCallStack = 10; +} + +message Message { + enum MessageType { + REGISTER = 0; + CALL_RESPONSE = 2; + POSIX_CALL_RESPONSE = 42; + } + MessageType type = 1; + CallResponse callResponse = 4; + PosixCallResult posixCallResult = 41; +} + +message HttpResponse { + uint32 code = 1; + string message = 2; +} diff --git a/yuanrong/proto/posix/bus_service.proto b/yuanrong/proto/posix/bus_service.proto new file mode 100644 index 0000000..831dfbb --- /dev/null +++ b/yuanrong/proto/posix/bus_service.proto @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package bus_service; + +import "common.proto"; + +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb/bus;bus"; + +// bus service provides APIs to runtime, +service BusService { + // notify bus to connect frontend + rpc DiscoverFrontend (DiscoverFrontendRequest) returns (DiscoverFrontendResponse) {} + // query instance info from frontend + rpc QueryInstance (QueryInstanceRequest) returns (QueryInstanceResponse) {} + // notify bus to connect driver + rpc DiscoverDriver (DiscoverDriverRequest) returns (DiscoverDriverResponse) {} +} + +message DiscoverDriverRequest { + string driverIP = 1; + string driverPort = 2; + string jobID = 3; + string instanceID = 4; + string functionName = 5; +} + +message DiscoverDriverResponse { + string serverVersion = 1; +} + +message DiscoverFrontendRequest { + string frontendIP = 1; + string frontendPort = 2; +} + +message DiscoverFrontendResponse {} + +message QueryInstanceRequest { + string instanceID = 1; +} + +message QueryInstanceResponse { + common.ErrorCode code = 1; + string message = 2; + string status = 3; +} diff --git a/yuanrong/proto/posix/common.proto b/yuanrong/proto/posix/common.proto new file mode 100644 index 0000000..db31530 --- /dev/null +++ b/yuanrong/proto/posix/common.proto @@ -0,0 +1,249 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package common; + +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb/common;common"; + +message Arg { + enum ArgType { + VALUE = 0; + OBJECT_REF = 1; + } + ArgType type = 1; + bytes value = 2; + repeated string nested_refs = 3; +} + +enum ErrorCode { + ERR_NONE = 0; + ERR_PARAM_INVALID = 1001; + ERR_RESOURCE_NOT_ENOUGH = 1002; + ERR_INSTANCE_NOT_FOUND = 1003; + ERR_INSTANCE_DUPLICATED = 1004; + ERR_INVOKE_RATE_LIMITED = 1005; + ERR_RESOURCE_CONFIG_ERROR = 1006; + ERR_INSTANCE_EXITED = 1007; + ERR_EXTENSION_META_ERROR = 1008; + ERR_INSTANCE_SUB_HEALTH = 1009; + ERR_GROUP_SCHEDULE_FAILED = 1010; + ERR_GROUP_EXIT_TOGETHER = 1011; + ERR_CREATE_RATE_LIMITED = 1012; + ERR_INSTANCE_EVICTED = 1013; + ERR_AUTHORIZE_FAILED = 1014; + ERR_FUNCTION_META_NOT_FOUND = 1015; + ERR_INSTANCE_INFO_INVALID = 1016; + ERR_SCHEDULE_CANCELED = 1017; + ERR_SCHEDULE_PLUGIN_CONFIG = 1018; + ERR_SUB_STATE_INVALID = 1019; + ERR_USER_CODE_LOAD = 2001; + ERR_USER_FUNCTION_EXCEPTION = 2002; + ERR_REQUEST_BETWEEN_RUNTIME_BUS = 3001; + ERR_INNER_COMMUNICATION = 3002; + ERR_INNER_SYSTEM_ERROR = 3003; + ERR_DISCONNECT_FRONTEND_BUS = 3004; + ERR_ETCD_OPERATION_ERROR = 3005; + ERR_BUS_DISCONNECTION = 3006; + ERR_REDIS_OPERATION_ERROR = 3007; + ERR_K8S_UNAVAILABLE = 3008; + ERR_FUNCTION_AGENT_OPERATION_ERROR = 3009; + ERR_STATE_MACHINE_ERROR = 3010; + ERR_LOCAL_SCHEDULER_OPERATION_ERROR = 3011; + ERR_RUNTIME_MANAGER_OPERATION_ERROR = 3012; + ERR_INSTANCE_MANAGER_OPERATION_ERROR= 3013; + ERR_LOCAL_SCHEDULER_ABNORMAL = 3014; + ERR_DS_UNAVAILABLE = 3015; + ERR_NPU_FAULT_ERROR = 3016; +} + +enum HealthCheckCode { + HEALTHY = 0; + HEALTH_CHECK_FAILED = 1; + SUB_HEALTH = 2; +} + +message SmallObject { + string id = 1; + bytes value = 2; // sbuffer +} + +message StackTraceInfo { + string type = 1; // type of exception thrown by user code + string message = 2; // message in user code thrown exception + repeated StackTraceElement stackTraceElements = 3; // stack trace elements in user code thrown exception + string language = 4; // language of user code +} + +message StackTraceElement { + string className = 1; // class name of user code exception + string methodName = 2; // method name of user code exception + string fileName = 3; // file name of user code exception + int64 lineNumber = 4; // line number of user code exception + map extensions = 5; // extensions for different language +} + +message TLSConfig { + bool dsAuthEnable = 1; + bool dsEncryptEnable = 2; + bytes dsClientPublicKey = 3; + bytes dsClientPrivateKey = 4; + bytes dsServerPublicKey = 5; + bool serverAuthEnable = 6; + bytes rootCertData = 7; + bytes moduleCertData = 8; + bytes moduleKeyData = 9; + string token = 10; + bool enableServerMode = 11; + string serverNameOverride = 12; + string posixPort = 13; + string salt = 14; + string accessKey = 15; // component-level access key + string securityKey = 16; // component-level security key +} + +message HeteroDeviceInfo +{ + int64 deviceId = 1; + string deviceIp = 2; + int64 rankId = 3; +} + +message ServerInfo +{ + repeated HeteroDeviceInfo devices = 1; + string serverId = 2; +} + +message FunctionGroupRunningInfo +{ + repeated ServerInfo serverList = 1; + int64 instanceRankId = 2; + int64 worldSize = 3; + string deviceName = 4; +} + +// message used in unix domain socket +message SocketMessage { + string magicNumber = 1; // header info(magicNumber/version/packetType/packetID) used to check + string version = 2; + string packetType = 3; + string packetID = 4; + BusinessMessage businessMsg = 5; +} + +message BusinessMessage { + MessageType type = 1; + oneof payload { + FunctionLog functionLog = 2; + } +} + +// Used in domain socket between runtime and runtime manager +enum MessageType { + LogProcess = 0; +} + +// user function log, one kind of businessMessage payload +message FunctionLog { + string level = 1; // log level + string timestamp = 2; + string content = 3; // log content + string invokeID = 4; + string traceID = 5; + string stage = 6; // log stage + bool isStart = 7; // first log sign + bool isFinish = 8; // last log sign + string logType = 9; // "tail": return log to user when invoke finishes, "": do not return log + int32 errorCode = 10; + string functionInfo = 11; // user function version urn + string instanceId = 12; + string logSource = 13; // std or logger + string logGroupId = 14; // used in FG + string logStreamId = 15; // used in FG +} + +message RuntimeInfo { + string serverIpAddr = 1; + int32 serverPort = 2; + string route = 3; // for low-reliability instance, format is "ip:port" +} + +message Bundle { + map resources = 1; + // custom label for reserved unit + repeated string labels = 2; // "key:value" or "key2" +} + +enum GroupPolicy { + None = 0; + Spread = 1; + StrictSpread = 2; + Pack = 3; + StrictPack = 4; +} + +message ResourceGroupSpec { + string name = 1; + // indicated which rg is the resource group was created from, default is primary + string owner = 2; + // indicated which app submitted(job/app) + string appID = 3; + string tenantID = 4; + // multiple units which is reserved defined in a resource group + repeated Bundle bundles = 5; + message Option { + // resource group schedule priority + int64 priority = 1; + GroupPolicy groupPolicy = 2; + // etc: + // "lifetime" : "detached" + map extension = 100; + } + Option opt = 6; +} + +message InstanceTermination { + string instanceID = 1; +} + +message FunctionMasterObserve {} + +message FunctionMasterEvent { + string address = 1; +} + +message SubscriptionPayload { + oneof Content { + InstanceTermination instanceTermination = 1; // Subscribe to instance termination event + FunctionMasterObserve functionMaster = 2; // Subscribe to function-master election changed + } +} + +message UnsubscriptionPayload { + oneof Content { + InstanceTermination instanceTermination = 1; // Unsubscribe specified instance's termination event + FunctionMasterObserve functionMaster = 2; // UnSubscribe to function-master election changed + } +} + +message NotificationPayload { + oneof Content { + InstanceTermination instanceTermination = 1; // Instance termination event notification + FunctionMasterEvent functionMasterEvent = 2; // function-master election changed event + } +} \ No newline at end of file diff --git a/yuanrong/proto/posix/core_service.proto b/yuanrong/proto/posix/core_service.proto new file mode 100644 index 0000000..5f00406 --- /dev/null +++ b/yuanrong/proto/posix/core_service.proto @@ -0,0 +1,200 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package core_service; + +import "common.proto"; +import "affinity.proto"; + +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb/core;core"; + +enum AffinityType { + PreferredAffinity = 0; + PreferredAntiAffinity = 1; + RequiredAffinity = 2; + RequiredAntiAffinity = 3; +} + +message SchedulingOptions { + int32 priority = 1; + map resources = 2; + map extension = 3; + // will deprecate in future + map affinity = 4; + affinity.Affinity scheduleAffinity = 5; + InstanceRange range = 6; + int64 scheduleTimeoutMs = 7; + bool preemptedAllowed = 8; + // indicated which rgroup submit to + string rGroupName = 9; +} + +message InstanceRange { + int32 min = 1; + int32 max = 2; + int32 step = 3; +} + +message CreateRequest { + string function = 1; + repeated common.Arg args = 2; + SchedulingOptions schedulingOps = 3; + string requestID = 4; + string traceID = 5; + repeated string labels = 6; // "key:value" or "key2" + // optional. if designated instanceID is not empty, the created instance id will be assigned designatedInstanceID + string designatedInstanceID = 7; + map createOptions = 8; +} + +message CreateResourceGroupRequest { + common.ResourceGroupSpec rGroupSpec = 1; + string requestID = 2; + string traceID = 3; +} + +message CreateResourceGroupResponse { + common.ErrorCode code = 1; + string message = 2; + string requestID = 3; +} + +message CreateResponse { + common.ErrorCode code = 1; + string message = 2; + string instanceID = 3; +} + +message GroupOptions { + // group schedule timeout (sec) + int64 timeout = 1; + // group alias name, this field cannot be used for life cycle management. + string groupName = 2; + bool sameRunningLifecycle = 3; + // indicated which rgroup submit to + string rGroupName = 4; + common.GroupPolicy groupPolicy = 5; +} + +// gang scheduling +message CreateRequests { + repeated CreateRequest requests = 1; + string tenantID = 2; + string requestID = 3; + string traceID = 4; + GroupOptions groupOpt = 5; +} + +// gang scheduling +message CreateResponses { + common.ErrorCode code = 1; + string message = 2; + repeated string instanceIDs = 3; + // used for life cycle management and the unique ID of the corresponding group. + // when you want to recycle a group, use signal 4 to send a kill request for the ID. + string groupID = 4; +} + +message InvokeOptions { + map customTag = 1; +} + +message InvokeRequest { + string function = 1; + repeated common.Arg args = 2; + string instanceID = 3; + string requestID = 4; + string traceID = 5; + repeated string returnObjectIDs = 6; + string spanID = 7; + InvokeOptions invokeOptions = 8; +} + +message InvokeResponse { + common.ErrorCode code = 1; + string message = 2; + string returnObjectID = 3; +} + +message CallResult { + common.ErrorCode code = 1; + string message = 2; + string instanceID = 3; + string requestID = 4; + repeated common.SmallObject smallObjects = 5; + repeated common.StackTraceInfo stackTraceInfos = 6; + common.RuntimeInfo runtimeInfo = 7; +} + +message CallResultAck { + common.ErrorCode code = 1; + string message = 2; +} + +message TerminateRequest { + string instanceID = 1; +} + +message TerminateResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message ExitRequest { + common.ErrorCode code = 1; + string message = 2; +} + +message ExitResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message StateSaveRequest { + bytes state = 1; + string requestID = 2; +} + +message StateSaveResponse { + common.ErrorCode code = 1; + string message = 2; + string checkpointID = 3; +} + +message StateLoadRequest { + string checkpointID = 1; + string requestID = 2; +} + +message StateLoadResponse { + common.ErrorCode code = 1; + string message = 2; + bytes state = 3; +} + +message KillRequest { + string instanceID = 1; + int32 signal = 2; + bytes payload = 3; + string requestID = 4; +} + +message KillResponse { + common.ErrorCode code = 1; + string message = 2; +} diff --git a/yuanrong/proto/posix/inner_service.proto b/yuanrong/proto/posix/inner_service.proto new file mode 100644 index 0000000..7efcef1 --- /dev/null +++ b/yuanrong/proto/posix/inner_service.proto @@ -0,0 +1,118 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package inner_service; + +import "common.proto"; +import "core_service.proto"; +import "bus_service.proto"; +import "runtime_service.proto"; +import "resource.proto"; +import "runtime_rpc.proto"; + +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb/inner;inner"; + +// Inner service provides APIs for bus to bus interaction +service InnerService { + // forward a recovery request to rebuild instance owned by other bus-proxy + rpc ForwardRecover (ForwardRecoverRequest) returns (ForwardRecoverResponse) {} + // notify the result of forward by other proxy request + rpc NotifyResult (NotifyRequest) returns (NotifyResponse) {} + // forward a killing request to signal instance owned by other bus-proxy + rpc ForwardKill (ForwardKillRequest) returns (ForwardKillResponse) {} + // forward a calling result request to other bus-proxy + rpc ForwardCallResult (ForwardCallResultRequest) returns (ForwardCallResultResponse) {} + // forward a call request to other bus-proxy + rpc ForwardCall (ForwardCallRequest) returns (ForwardCallResponse) {} + // forward a queryInstance request to other bus-proxy + rpc QueryInstance (bus_service.QueryInstanceRequest) returns (bus_service.QueryInstanceResponse) {} +} + +message NotifyRequest { + string requestID = 1; + common.ErrorCode code = 2; + string message = 3; +} + +message NotifyResponse {} + +message ForwardRecoverRequest { + string instanceID = 1; + string runtimeIP = 2; + string runtimePort = 3; + string runtimeID = 4; + string function = 5; +} + +message ForwardRecoverResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message ForwardKillRequest { + string requestID = 1; + string srcInstanceID = 2; + core_service.KillRequest req = 3; + string instanceID = 4; + string runtimeIP = 5; + string runtimePort = 6; + string runtimeID = 7; + string instanceRequestID = 8; + string srcTenantID = 9; +} + +message ForwardKillResponse { + string requestID = 1; + common.ErrorCode code = 2; + string message = 3; +} + +// Used to forward the result of an initcall across nodes +message ForwardCallResultRequest { + core_service.CallResult req = 1; + string instanceID = 2; + string runtimeID = 3; + string functionProxyID = 4; + // Fast channel for the caller node to obtain the status update of the created instance. + resources.InstanceInfo readyInstance = 5; +} + +message ForwardCallResultResponse { + common.ErrorCode code = 1; + string message = 2; + string requestID = 3; + string instanceID = 4; +} + +message ForwardCallRequest { + runtime_service.CallRequest req = 1; + string instanceID = 2; + string srcIP = 3; + string srcNode = 4; +} + +message ForwardCallResponse { + common.ErrorCode code = 1; + string message = 2; + string requestID = 3; +} + +message RouteCallRequest { + runtime_rpc.StreamingMessage req = 1; + string instanceID = 2; +} \ No newline at end of file diff --git a/yuanrong/proto/posix/log_service.proto b/yuanrong/proto/posix/log_service.proto new file mode 100644 index 0000000..f74e156 --- /dev/null +++ b/yuanrong/proto/posix/log_service.proto @@ -0,0 +1,105 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package log_service; + +import "common.proto"; + +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb/logservice;logservice"; + +service LogManagerService { + rpc Register (RegisterRequest) returns (RegisterResponse) {} + rpc ReportLog (ReportLogRequest) returns (ReportLogResponse) {} +} + +service LogCollectorService { + rpc ReadLog (ReadLogRequest) returns (stream ReadLogResponse) {} + rpc StopLogStream (StopLogStreamRequest) returns (StopLogStreamResponse) {} + rpc StartLogStream (StartLogStreamRequest) returns (StartLogStreamResponse) {} + rpc QueryLogStream (QueryLogStreamRequest) returns (QueryLogStreamResponse) {} +} + +message StartLogStreamRequest { + string streamName = 1; + LogItem item = 2; +} + +message StartLogStreamResponse { + int32 code = 1; + string message = 2; +} + +message StopLogStreamRequest { + string streamName = 1; +} + +message StopLogStreamResponse { + int32 code = 1; + string message = 2; +} + +message QueryLogStreamRequest {} + +message QueryLogStreamResponse { + int32 code = 1; + repeated string streams = 2; +} + +message RegisterRequest { + string collectorID = 1; + string address = 2; +} + +message RegisterResponse { + int32 code = 1; + string message = 2; +} + +enum LogTarget { + USER_STD = 0; + LIB_RUNTIME = 1; + RUNTIME_API = 2; +} + +message LogItem { + string filename = 1; + string collectorID = 2; + LogTarget target = 3; + string runtimeID = 4; // optional +} + +message ReadLogRequest { + LogItem item = 1; + uint32 startLine = 2; + uint32 endLine = 3; +} + +message ReadLogResponse { + int32 code = 1; + string message = 2; + bytes content = 3; +} + +message ReportLogRequest { + repeated LogItem items = 1; +} + +message ReportLogResponse { + int32 code = 1; + string message = 2; +} diff --git a/yuanrong/proto/posix/message.proto b/yuanrong/proto/posix/message.proto new file mode 100644 index 0000000..2e91219 --- /dev/null +++ b/yuanrong/proto/posix/message.proto @@ -0,0 +1,978 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +import "resource.proto"; +import "common.proto"; +import "core_service.proto"; + +package messages; +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb/message;message"; + +message ResourceView {} + + +// common message +message ScheduleTopology { + message Scheduler { + string name = 1; + string address = 2; + } + Scheduler leader = 1; + repeated Scheduler members = 2; +} + +message Register { + string name = 1; + string address = 2; + resources.ResourceUnit resource = 3; + string message = 4; + // multiple resource view is add for primary and virtual + // 3 was reserve for former compatible + map resources = 5; +} + +message Registered { + int32 code = 1; + string message = 2; + ScheduleTopology topo = 3; +} + +message AffinityContext { + map scheduledScore = 1; + map scheduledResult = 2; + int64 maxScore = 3; + bool isTopDownScheduling = 4; +} + +message DefaultPluginContext { + map filterCtx = 1; +} + +message GroupScheduleContext { + string reserved = 1; +} + +message PluginContext { + oneof PluginContext { + AffinityContext affinityCtx = 1; + DefaultPluginContext defaultCtx = 2; + GroupScheduleContext groupSchedCtx = 3; + } +} + +message ScheduleRequest { + resources.InstanceInfo instance = 1; + // initRequest should be deprecated + bytes initRequest = 2; + string requestID = 3; + string traceID = 4; + map updateResources = 5; + string sourceName = 6; + uint32 scheduleRound = 7; + // context of scheduler result to avoid + // duplicate scheduling resource conflict or fragment. + // key is plugin name + map contexts = 8; + // Indicates whether the current request is a range schedule. + bool isInsRangeScheduler = 9; + // store the range schedule infos + RangeOptions rangeOpts = 10; + // candidate result of Scheduler + CandidateResult candidateResult = 11; + // domain send current revision to local so that local use revision to update resourceUnitChange + PullResourceRequest unitSnapshot = 12; +} + +message CandidateResult { + // store candidate resource unit id + string unitID = 1; +} + +message RangeOptions { + core_service.InstanceRange range = 1; + int32 curRangeInstanceNum = 2; +} + +message ScheduleResult +{ + string nodeId = 1; + repeated common.HeteroDeviceInfo devices = 2; +} + +message ScheduleResponse { + int32 code = 1; + string message = 2; + string requestID = 3; + string instanceID = 4; + map updateResources = 5; + // context of scheduler result to avoid + // duplicate scheduling resource conflict or fragment. + // key is plugin name + map contexts = 6; + ScheduleResult scheduleResult = 7; + string traceID = 8; +} + +message NotifySchedAbnormalRequest { + string schedName = 1; + string ip = 2; +} + +message NotifySchedAbnormalResponse { + string schedName = 1; + string ip = 2; +} + +message NotifyWorkerStatusRequest { + bool healthy = 1; + string workerIP = 2; +} + +message NotifyWorkerStatusResponse { + bool healthy = 1; + string workerIP = 2; +} + +message UpdateNodeTaintRequest { + string requestID = 1; + bool healthy = 2; + string ip = 3; + string key = 4; +} + +message UpdateNodeTaintResponse { + string requestID = 1; + int32 code = 2; + string message = 3; +} + +// common end + +// domain scheduler and scaler message begin +message CreateAgentRequest { + resources.InstanceInfo instanceInfo = 1; + map labels = 2; +} + +message CreateAgentResponse { + string requestID = 1; + int32 code = 2; + string message = 3; + map updatedCreateOptions = 4; +} + +message DeletePodRequest { + string requestID = 1; + string functionAgentID = 2; + string message = 3; +} + +message DeletePodResponse { + string requestID = 1; + int32 code = 2; + string message = 3; +} + +// domain scheduler and scaler message end + +// local scheduler and function agent message + +message Layer { + string appID = 1; + string bucketID = 2; + string objectID = 3; + string bucketURL = 4; + string sha256 = 5; + string hostName = 6; + string securityToken = 7; + string temporaryAccessKey = 8; + string temporarySecretKey = 9; + string sha512 = 10; +} + +// function meta information +message FuncDeploySpec { + string bucketID = 1; + string objectID = 2; + repeated Layer layers = 3; + string deployDir = 4; + string storageType = 5; + string token = 6; + string accessKey = 7; + string secretAccessKey = 8; + string bucketURL = 9; +} + +message FuncMountUser { + int32 userID = 1; + int32 groupID = 2; +} + +message FuncMount { + string mountType = 1; + string mountResource = 2; + string mountSharePath = 3; + string localMountPath = 4; + string status = 5; +} + +message FuncMountConfig { + FuncMountUser funcMountUser = 1; + repeated FuncMount funcMounts = 2; +} + +message DeployInstanceRequest { + string instanceID = 1; + string traceID = 2; + string requestID = 3; + string entryFile = 4; + string envKey = 5; + string envInfo = 6; + string encryptedUserData = 7; + string language = 8; + string codeSha256 = 9; + resources.Resources resources = 10; + FuncDeploySpec funcDeploySpec = 11; + map hookHandler = 12; + int32 instanceLevel = 13; + map createOptions = 14; + FuncMountConfig funcMountConfig = 15; + resources.ScheduleOption scheduleOption = 16; + bool enableServerMode = 17; + bool enableAuthServerCert = 18; + string serverRootCertData = 19; + string serverAuthToken = 20; // Tenant's token used together with salt. + string serverNameOverride = 21; + int64 gracefulShutdownTime = 22; + string cryptoAlgorithm = 23; + bool runtimeDsAuthEnable = 24; + bool runtimeDsEncryptEnable = 25; + string runtimeDsClientPublicKey = 26; + string runtimeDsClientPrivateKey = 27; + string runtimeDsServerPublicKey = 28; + string posixPort = 29; + string salt = 30; // Tenant's salt used together with serverAuthToken. + string accessKey = 31; // component-level access key + string securityKey = 32; // component-level security key + string codeSha512 = 33; + TenantCredentials tenantCredentials = 34; + string tenantID = 35; +} + +message TenantCredentials { + bytes accessKey = 1; + bytes secretKey = 2; + bytes dataKey = 3; + bool isCredential = 4; // is credential enabled +} + +message DeployInstanceResponse { + int32 code = 1; + string message = 2; + string requestID = 3; + string timeInfo = 4; + string instanceID = 5; + string runtimeID = 6; + string address = 7; + string cpuType = 8; + int64 pid = 9; +} + +message KillInstanceRequest { + string instanceID = 1; + string runtimeID = 2; + string requestID = 3; + string traceID = 4; + string storageType = 5; + bool isMonopoly = 6; +} + +message KillInstanceResponse { + int32 code = 1; + string message = 2; + string requestID = 3; + string instanceID = 4; +} + +message UpdateCredRequest { + string requestID = 1; + string instanceID = 2; + string runtimeID = 3; + string token = 4; + string salt = 5; + TenantCredentials tenantCredentials = 6; +} + +message UpdateCredResponse { + string requestID = 1; + int32 code = 2; + string message = 3; +} + +message UpdateResourcesRequest { + resources.ResourceUnit resourceUnit = 1; +} + +message InstanceStatusInfo { + string instanceID = 1; + int32 status = 2; + string requestID = 3; + string instanceMsg = 4; + int32 type = 5; +} + +message UpdateInstanceStatusRequest { + InstanceStatusInfo instanceStatusInfo = 1; + string requestID = 2; +} + +message UpdateInstanceStatusResponse { + int32 status = 1; + string message = 2; + string requestID = 3; +} + +message QueryInstanceStatusRequest { + string instanceID = 1; + string runtimeID = 2; + string requestID = 3; +} + +message QueryInstanceStatusResponse { + InstanceStatusInfo instanceStatusInfo = 1; + string requestID = 2; +} + +message UpdateAgentStatusRequest { + string requestID = 1; + int32 status = 2; + string message = 3; +} + +message UpdateAgentStatusResponse { + string requestID = 1; + int32 status = 2; + string message = 3; +} +// local scheduler and function agent end + +message RuntimeConfig { + TLSConfig tlsConfig = 1; + string entryfile = 2; + string language = 3; + map hookHandler = 4; + map userEnvs = 5; + resources.Resources resources = 6; + FuncMountConfig funcMountConfig = 7; + map posixEnvs = 8; + CreateSubDirectoryConfig subDirectoryConfig = 9; + string debugServerPort = 10; +} + +message CreateSubDirectoryConfig { + bool isEnable = 1; + string parentDirectory = 2; + int64 quota = 3; +} + +message TLSConfig { + bool dsAuthEnable = 1; // enable ds component auth + bool dsEncryptEnable = 2; // enable ds zmq encrypt + bytes dsClientPublicKey = 3; + bytes dsClientPrivateKey = 4; + bytes dsServerPublicKey = 5; + bool serverAuthEnable = 6; + bytes rootCertData = 7; + bytes moduleCertData = 8; + bytes moduleKeyData = 9; + string token = 10; // Tenant token used together with the salt. + bool enableServerMode = 11; + string serverNameOverride = 12; + string posixPort = 13; + string salt = 14; // Tenant salt used together with the token. + string accessKey = 15; // component-level access key + string securityKey = 16; // component-level security key + TenantCredentials tenantCredentials = 17; +} + +message DeploymentConfig { + string bucketID = 1; + string objectID = 2; + repeated Layer layers = 3; + string deployDir = 4; + string storageType = 5; + string sha256 = 6; + string hostName = 7; + string securityToken = 8; + string temporaryAccessKey = 9; + string temporarySecretKey = 10; + string sha512 = 11; + string bucketURL = 12; + map deployOptions = 13; +} + +message CodePackageThresholds { + int32 fileCountsMax = 1; + int32 zipFileSizeMaxMB = 2; + int32 unzipFileSizeMaxMB = 3; + int32 dirDepthMax = 4; + int32 codeAgingTime = 5; +} + +message DeployRequest { + RuntimeConfig runtimeConfig = 1; + DeploymentConfig deploymentConfig = 2; + string instanceID = 3; + string schedPolicyName = 4; +} + +message RuntimeInstanceInfo { + RuntimeConfig runtimeConfig = 1; + DeploymentConfig deploymentConfig = 2; + string instanceID = 3; + string runtimeID = 4; + string traceID = 5; + string requestID = 6; + string address = 7; + int64 gracefulShutdownTime = 8; +} + +message DeployDuration { + double deployFuncTime = 1; + double deployLayerTime = 2; +} + +message DeployResult { + int32 code = 1; + string message = 2; + DeployDuration deployDuration = 3; + string errorMessage = 4; + string runtimePkgDir = 5; + string entryFile = 6; +} + +message RegisterRuntimeManagerRequest { + string name = 1; + string address = 2; + string id = 3; + resources.ResourceUnit resourceUnit = 4; + map runtimeInstanceInfos = 5; +} + +message RegisterRuntimeManagerResponse { + int32 code = 1; + string message = 2; +} + +message StartInstanceRequest { + RuntimeInstanceInfo runtimeInstanceInfo = 1; + int32 type = 2; + resources.ScheduleOption scheduleOption = 3; + string logPrefix = 4; +} + +message StartRuntimeInstanceResponse { + string runtimeID = 1; + string address = 2; + string port = 3; + int64 pid = 4; + string cpuType = 5; +} + +message StartInstanceResponse { + int32 code = 1; + string message = 2; + string requestID = 3; + StartRuntimeInstanceResponse startRuntimeInstanceResponse = 4; +} + +message StopInstanceRequest { + string runtimeID = 1; + string requestID = 2; + string traceID = 3; + int32 type = 4; +} + +message StopInstanceResponse { + int32 code = 1; + string message = 2; + string runtimeID = 3; + string requestID = 4; + string traceID = 5; + string instanceID = 6; +} + +message SchedulerNode { + string name = 1; + string address = 2; + int32 level = 3; + repeated SchedulerNode children = 4; +} + +message UpdateRuntimeStatusRequest { + string requestID = 1; + int32 status = 2; + string message = 3; +} + +message UpdateRuntimeStatusResponse { + string requestID = 1; + int32 status = 2; + string message = 3; +} + +// used for function proxy to persistent function agent registration information +message FuncAgentRegisInfo { + string agentAIDName = 1; // function-agent AID name + string agentAddress = 2; // function-agent address + string runtimeMgrAID = 3; // runtime-manager AID + string runtimeMgrID = 4; // runtime-manager RandomID + int32 statusCode = 5; // SUCCESS = 1, FAILED = 0, registration status of function-agent and runtime-manager + uint32 evictTimeoutSec = 6; + map extensionInfo = 7; +} + +message FuncAgentRegisInfoCollection { + map funcAgentRegisInfoMap = 1; + int32 localStatus = 2; +} + +message CleanStatusRequest { + string name = 1; // RuntimeManagerID +} + +message CleanStatusResponse {} + +message ForwardKillRequest { + string requestID = 1; + resources.InstanceInfo instance = 2; + core_service.KillRequest req = 3; +} + +message ForwardKillResponse { + string requestID = 1; + int32 code = 2; + string message = 3; + string instanceID = 4; +} + +message PullResourceRequest { + uint64 version = 1; + string localViewInitTime = 2; +} + +message UpdateLocalStatusRequest { + uint32 status = 1; +} + +message UpdateLocalStatusResponse { + uint32 status = 1; + bool healthy = 2; +} + +message EvictAgentRequest { + string agentID = 1; + uint32 timeoutSec = 2; + string requestID = 3; + string localID = 4; + repeated string instances = 5; +} + +message EvictAgentAck { + string agentID = 1; + int32 code = 2; + string message = 3; + string requestID = 4; +} + +message EvictAgentResult { + string agentID = 1; + int32 code = 2; + string message = 3; + string requestID = 4; +} + +message EvictAgentResultAck { + string agentID = 1; + string requestID = 4; +} + +message QueryAgentInfoRequest { + string requestID = 1; +} + +message QueryAgentInfoResponse { + string requestID = 1; + repeated resources.AgentInfo agentInfos = 2; +} + +message QueryResourcesInfoRequest { + string requestID = 1; +} + +message QueryResourcesInfoResponse { + string requestID = 1; + resources.ResourceUnit resource = 2; +} + +message QueryInstancesInfoRequest { + string requestID = 1; +} + +message QueryInstancesInfoResponse { + string requestID = 1; + common.ErrorCode code = 2; + repeated resources.InstanceInfo instanceInfos = 3; +} + +message ExternalAgentInfo { + string id = 1; + string alias = 2; +} + +message ExternalQueryAgentInfoResponse { + repeated ExternalAgentInfo data = 1; +} + +message FunctionSystemStatus { + common.ErrorCode code = 1; + string message = 2; +} + +enum NetworkIsolationRuleType { + IPSET_ADD = 0; + IPSET_DELETE = 1; + IPSET_FLUSH = 2; + IPTABLES_COMMAND = 3; +} + +message SetNetworkIsolationRequest { + string requestID = 1; + NetworkIsolationRuleType ruleType = 2; + // IPs to add to podIp-whitelist when ruleType = IPSET_ADD + // IPs to delete from podIp-whitelist when ruleType = IPSET_DELETE + repeated string rules = 3; +} + +message SetNetworkIsolationResponse { + string requestID = 1; + int32 code = 2; + string message = 3; +} + +message MetaStoreRequest { + string requestID = 1; + bytes requestMsg = 2; + bool asyncBackup = 3; +} + +message MetaStoreResponse { + string responseID = 1; + bytes responseMsg = 2; + int32 status = 3; + string errorMsg = 4; +} + +message ForwardWatchRequest { + string requestID = 1; + bytes requestMsg = 2; + string originAID = 3; +} + +message GroupInfo { + string requestID = 1; + string traceID = 2; + string groupID = 3; + string parentID = 4; + string ownerProxy = 5; + core_service.GroupOptions groupOpts = 6; + repeated ScheduleRequest requests = 7; + int32 status = 8; + string message = 9; + bool insRangeScheduler = 10; + core_service.InstanceRange insRange = 11; + repeated ScheduleRequest rangeRequests = 12; + resources.CreateTarget target = 13; + string rGroupName = 14; +} + +message GroupResponse { + string requestID = 1; + string traceID = 2; + int32 code = 4; + string message = 5; + int32 rangeSuccessNum = 6; + map updateResources = 7; + map scheduleResults = 8; +} + +message KillGroup { + string srcInstanceID = 1; + string groupID = 2; + // master send request to local, local clear groupCtx + string groupRequestID = 3; +} + +message KillGroupResponse { + string groupID = 1; + int32 code = 2; + string message = 3; +} + +message CommonStatus { + int32 code = 1; + string message = 2; +} + +message BundleInfo { + string bundleID = 1; + string rGroupName = 2; + // upper resource group, empty if it belongs to top resource group + string parentRGroupName = 3; + string functionProxyID = 4; + string functionAgentID = 5; + string tenantID = 6; + // bundle resource capacity + resources.Resources resources = 7; + // bundle labels + repeated string labels = 8; + CommonStatus status = 9; + // indicate which resource unit this bundle belongs to + string parentId = 10; +} + +message BundleCollection { + // bundleID: bundleInfo + map bundles = 1; +} + +message RemoveBundleRequest { + string requestID = 1; + string srcInstanceID = 2; + string rGroupName = 3; + string tenantId = 4; +} + +message RemoveBundleResponse { + string requestID = 1; + string rGroupName = 2; + CommonStatus status = 3; +} + +message ResourceGroupInfo { + string name = 1; + // indicated which rg is the resource group was created from, default is primary + string owner = 2; + // indicated which app submitted(job/app) + string appID = 3; + string tenantID = 4; + // multiple units which is reserved defined in a resource group + repeated BundleInfo bundles = 5; + CommonStatus status = 6; + string parentID = 7; + string requestID = 8; + string traceID = 9; + message Option { + // resource group schedule priority + int64 priority = 1; + common.GroupPolicy groupPolicy = 2; + // etc: + // "lifetime" : "detached" + map extension = 100; + } + Option opt = 10; +} + +message PutRequest { + // Used to request deduplication. + string requestID = 1; + + bytes key = 2; + + bytes value = 3; + + int64 lease = 4; + + // if true, the response contains previous key-value. + bool prevKv = 5; + + bool asyncBackup = 6; +} + +message PutResponse { + // Used to request deduplication. + string requestID = 1; + + int64 revision = 2; + + bytes prevKv = 3; + + int32 status = 4; + + string errorMsg = 5; +} + +message Lease { + int64 id = 1; + + int64 ttl = 2; + + int64 expiry = 3; + + repeated string items = 4; +} + +message GetAndWatchResponse { + bytes getResponseMsg = 1; + bytes watchResponseMsg = 2; +} + +message ObserveCancelRequest { + int64 cancelObserveID = 1; +} + +message ObserveResponse { + string name = 1; + int64 observeID = 2; + bytes responseMsg = 3; + bool isCreate = 4; + bool isCancel = 5; + string cancelMsg = 6; +} + +message GetTokenRequest { + string requestID = 1; + string tenantID = 2; + bool isCreate = 3; +} + +message GetTokenResponse { + string requestID = 1; + string tenantID = 2; + string newToken = 3; + string oldToken = 4; + string salt = 5; + int32 code = 6; + string message = 7; +} + +message GetAKSKByTenantIDRequest { + string requestID = 1; + string tenantID = 2; + bool isCreate = 3; + bool isPermanentValid = 4; +} + +message GetAKSKByAKRequest { + string requestID = 1; + string accessKey = 2; +} + +message GetAKSKResponse { + string requestID = 1; + string tenantID = 2; + string newAccessKey = 3; + string newSecretKey = 4; + string newDataKey = 5; + string newExpiredTimeStamp = 6; + string oldAccessKey = 7; + string oldSecretKey = 8; + string oldDataKey = 9; + string oldExpiredTimeStamp = 10; + string role = 11; + int32 code = 12; + string message = 13; +} + +// Indicates the cancel type +enum CancelType { + REQUEST = 0; + JOB = 1; + PARENT = 2; + GROUP = 3; + FUNCTION = 4; +} + +message CancelSchedule { + CancelType type = 1; + string id = 2; + string reason = 3; + string msgID = 4; +} + +message CancelScheduleResponse { + string msgID = 1; + FunctionSystemStatus status = 2; +} + +message ResourceInfo { + string requestID = 1; + resources.ResourceUnit resource = 2; +} + +message DebugInstanceInfo { + string instanceID = 1; + int32 pid = 2; + string debugServer = 3; + string status = 4; + string language = 5; +} + +message QueryDebugInstanceInfosRequest { + string requestID = 1; +} + +message QueryDebugInstanceInfosResponse { + string requestID = 1; + int32 code = 2; + repeated DebugInstanceInfo debugInstanceInfos = 3; +} + +message ReportAgentAbnormalRequest { + string requestID = 1; + repeated string bundleIDs = 2; +} + +message ReportAgentAbnormalResponse { + string requestID = 1; + int32 code = 2; + string message = 3; +} + +message CheckInstanceStateRequest { + string requestID = 1; + string instanceID = 2; +} + +message CheckInstanceStateResponse { + string requestID = 1; + int32 code = 2; +} + +message StaticFunctionChangeRequest { + string requestID = 1; + string instanceID = 2; + int32 status = 4; +} + +message StaticFunctionChangeResponse { + int32 code = 1; + string message = 2; + string requestID = 3; + string instanceID = 4; +} \ No newline at end of file diff --git a/yuanrong/proto/posix/resource.proto b/yuanrong/proto/posix/resource.proto new file mode 100644 index 0000000..174fd62 --- /dev/null +++ b/yuanrong/proto/posix/resource.proto @@ -0,0 +1,567 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package resources; + +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb/resource;resource"; + +import "common.proto"; +import "affinity.proto"; +import "core_service.proto"; + +message Value { + enum Type { + SCALAR = 0; + RANGES = 1; + SET = 2; + COUNTER = 3; + VECTORS = 4; + END = 5; + } + + message Scalar { + double value = 1; + double limit = 2; + } + + message Range { + uint64 begin = 1; + uint64 end = 2; + } + + message Ranges { + repeated Range range = 1; + } + + message Set { + repeated string items = 1; + } + + message Vectors { + message Vector { + repeated double values = 1; + } + + message Category { + // {uuid : Vector} + map vectors = 1; + } + + // CPU | Memory | HBM | Stream ... + map values = 1; + } + + message Counter { + map items = 1; + } +} + +message Volume { + enum VolumeSourceType { + UNKNOWN = 0; + LOCAL_DIR = 1; + CONFIG_MAP = 2; + EMPTY_DIR = 3; + ELARA = 4; + } + + enum VolumeMode { + RW = 0; + RO = 1; + } + + VolumeMode mode = 1; + VolumeSourceType sourceType = 2; + string hostPath = 3; + string containerPath = 4; + string configMapPath = 5; + string emptyDir = 6; + string elaraPath = 7; +} + +message DiskInfo { + Volume volume = 1; + string type = 2; + string devPath = 3; + string mountPath = 4; +} + +enum AffinityType { + PreferredAffinity = 0; + PreferredAntiAffinity = 1; + RequiredAffinity = 2; + RequiredAntiAffinity = 3; +} + +message NodeAffinity { + map affinity = 1; +} + +message InstanceAffinity { + map affinity = 1; +} + +message ObjAffinity { + affinity.Selector preferredAffinity = 1; +} + +message TenantAffinity { + affinity.Selector preferredAffinity = 1; + affinity.Selector requiredAntiAffinity = 2; +} + +message PreemptedAffinity { + affinity.Selector preferredAffinity = 1; + affinity.Selector preferredAntiAffinity = 2; +} + +message PendingAffinity { + repeated affinity.ResourceAffinity resources = 1; +} + +message ResourceGroupAffinity { + affinity.Selector requiredAffinity = 1; +} + +message GroupPolicyAffinity { + // for pack + affinity.Selector preferredAffinity = 1; + // for spread + affinity.Selector preferredAntiAffinity = 2; + // for strict spread + affinity.Selector requiredAntiAffinity = 3; +} + +message InnerSystemAffinity { + ObjAffinity data = 1; + TenantAffinity tenant = 2; + PreemptedAffinity preempt = 3; + // Used in fairness scheduling to filter resources similar to those required by pending requests. + PendingAffinity pending = 4; + ResourceGroupAffinity rgroup = 5; + GroupPolicyAffinity grouplb = 6; +} + +message Affinity { + NodeAffinity nodeAffinity = 1; + InstanceAffinity instanceAffinity = 2; + InstanceAffinity instanceAntiAffinity = 3; + affinity.ResourceAffinity resource = 4; + affinity.InstanceAffinity instance = 5; + InnerSystemAffinity inner = 6; +} + +message DiskContent { + string name = 1; // for dev name: vmachine-vda barametal-sda nvme0n1 + uint64 size = 2; // for disk size, like 40, unit: GB + string mountPoints = 3; // moutpoint path +} + +message Resource { + message Extension { + oneof Content { + DiskContent disk = 1; + } + } + + string name = 1; + + // Resource type, maybe scalar/ranges/set + Value.Type type = 2; + + // Resource of memory、cpu... + Value.Scalar scalar = 3; + + // Resource of ip address、network port... + Value.Ranges ranges = 4; + + // Resource of volume or disk + Value.Set set = 5; + + // Used for ordered resources, such as NUMA nodes, NPU devices, GPU devices. + Value.Vectors vectors = 6; + + string runtime = 7; + string driver = 8; + + DiskInfo disk = 9; + + // Resource info of NPU/GPU + // { "vendor": "huawei.com", + // "product_model", "Ascend910B4", + // "HBM": "1000,1000,1000" + // ... }: + map heterogeneousInfo = 10; + + // If the resource is expired, do delete when deduction. + // Used with Resource::vectors. + bool expired = 11; + + // for repeated resource type : Set Vectors + // The order of the array must correspond strictly to the index of the Set or Vectors. + repeated Extension extensions = 12; +} + +// resource map for resource common operater, such as +, - +message Resources { + map resources = 1; +} + +enum CreateTarget { + INSTANCE = 0; // indicate instance schedule + RESOURCE_GROUP = 1; // indicate resource group schedule +} + +message ScheduleOption { + string schedPolicyName = 1; + int32 priority = 2; + Affinity affinity = 3; + uint32 initCallTimeOut = 4; + + // resource_owner:uuid + map resourceSelector = 5; + // this field will be passed to k8s when create new pod + map nodeSelector = 6; + map extension = 7; + core_service.InstanceRange range = 8; + int64 scheduleTimeoutMs = 9; + bool preemptedAllowed = 10; + // while target is RESOURCE_GROUP, the instanceID should be formed like {rGroup}_{tenantID}_bundle_{index} + CreateTarget target = 11; + // which rGroup resource was specified + string rGroupName = 12; +} + +message InstanceStatus { + // code in state-machine + int32 code = 1; + // process exit code reported by runtime manager + int32 exitCode = 2; + // reason why the instance in this status, developer can understand + string msg = 3; + // process exit type + int32 type = 4; + // instance err code, defined in function system + int32 errCode = 5; +} + +message InstanceInfo { + // podname in K8S BCM, InstanceID in YuanRong system. + string instanceID = 1; + + // which request to create this instance + string requestID = 2; + + // hostname while be set to /etc/hostname when K8S BCM, runtime in YuanRong system + string runtimeID = 3; + + // runtime ip:port in YuanRong system + string runtimeAddress = 4; + + // functionAgentID in YuanRong system + string functionAgentID = 5; + + // K8S BCM is nodeName; + string functionProxyID = 6; + + // container image in K8S BCM, function name in YuanRong system + string function = 7; + + // the restart policy when instance running failed + string restartPolicy = 8; + Resources resources = 9; + + Resources actualUse = 10; + + // special option for scheduler + ScheduleOption scheduleOption = 11; + + // create options (eg.concurrency) + map createOptions = 12; + + // instance labels + repeated string labels = 13; + + // Instance start time + string startTime = 14; + + InstanceStatus instanceStatus = 15; + + string jobID = 16; + + // the topology is local->domain1->domain2 + repeated string schedulerChain = 17; + + // parentID is the instanceID of creator + string parentID = 18; + + // parentFunctionProxyAID is functionProxyAID of creator + string parentFunctionProxyAID = 19; + + // the storage type of the function corresponding to this instance. + string storageType = 20; + + // schedule retry times + int32 scheduleTimes = 21; + + // local redeploy times (in original local scheduler), default is 1 + int32 deployTimes = 22; + + // args in creating request + repeated common.Arg args = 23; + + bool isCheckpointed = 24; + + // version indicates the number of times that instance information is modified in etcd. + int64 version = 25; + + string dataSystemHost = 26; + + bool detached = 27; + + int64 gracefulShutdownTime = 28; + + string tenantID = 29; + + bool isSystemFunc = 30; + + string groupID = 31; + + // indicate an instance whether is a low reliability instance + bool lowReliability = 32; + // extension field + map extensions = 33; + // the instance was scheduled on this resource unit + string unitID = 34; +} + +message RouteInfo { + string instanceID = 1; + string runtimeAddress = 2; + string functionAgentID = 3; + string function = 4; + string functionProxyID = 5; + InstanceStatus instanceStatus = 6; + string jobID = 7; + string parentID = 8; + string requestID = 9; // Need while Update Instance + string tenantID = 10; // Need while iam + bool isSystemFunc = 11; // Need while iam + int64 version = 12; // indicates the number of times that instance information is modified in etcd +} + +message SystemInfo { + string architecture = 1; + string systemUUID = 2; + string machineID = 3; + string kernelVersion = 4; + string osImage = 5; + string agentVersion = 6; + string runtimeVersion = 7; + string bootId = 8; +} + +message AgentInfo { + string localID = 1; + string agentID = 2; + string alias = 3; +} + +message BucketIndex { + message Bucket { + message Info { + // the number of pod is shared by instances + int32 sharedNum = 1; + // the number of pod is monopolized by an instance + int32 monopolyNum = 2; + } + Info total = 1; + // key is scheduler name or function agent name + // value is Info + map allocatable = 2; + } + // key is mem of agent + map buckets = 1; +} + +message StatusChange { + uint32 status = 1; +} + +message InstanceChange { + enum InstanceChangeType { + ADD = 0; + DELETE = 1; + } + InstanceChangeType changeType = 1; + string instanceId = 2; + InstanceInfo instance = 3; +} + +message Modification { + StatusChange statusChange = 1; + repeated InstanceChange instanceChanges = 3; +} + +message Addition { + ResourceUnit resourceUnit = 1; +} + +message Deletion { +} + +message ResourceUnitChange { + string resourceUnitId = 1; + oneof Changed { + Addition addition = 2; + Deletion deletion = 3; + Modification modification = 4; + } +} + +message ResourceUnitChanges { + repeated ResourceUnitChange changes = 1; + uint64 startRevision = 2; + uint64 endRevision = 3; + string localId = 4; + string localViewInitTime = 5; +} + +message ResourceUnit { + // NodeName in K8S BCM, FunctionAgentID/DomainSchedulerID in YuanRong system + string id = 1; + + // Total Resource of this Unit, key is ResourceName + Resources capacity = 2; + + // Allocatable Resource of this Unit, key is ResourceName + Resources allocatable = 3; + + Resources actualUse = 4; + + // now only using in YuanRong system; key is FunctionAgentID/DomainSchedulerID + map fragment = 5; + + // pod in K8S BCM, instance in YuanRong system + map instances = 6; + + // nodeLabel(s) of resource unit + // map k is different label name + // map v is the occuring times counter for every different value + // eg. agent01:{x:y} agent02:{x:z} agent03:{x:z}, will be joined to {x:{y:1,z:2}} in upper level resource unit + // this representation supports add/sub operation, the key can be erased iff when counter decrease to 0 + map nodeLabels = 7; + + // node SystemInfo in K8S BCM + SystemInfo systemInfo = 8; + + // node support max instance num; + int32 maxInstanceNum = 9; + + // fragment index + // key is the cpu/mem proportion + // value is the index which including a map of mem-bucket + map bucketIndexs = 10; + + uint64 revision = 11; + // unit status. + // 0 - normal + // 1 - evicting, means the unit is unavailable + // 2 - recovering, means the unit is unavailable + // 3 - to be deleted by Scaler using K8S API, means the unit is unavailable + uint32 status = 12; + + // alias for the current resource unit; + // useful for users to evict resource + string alias = 13; + + // On domain, each ownerID of fragment is the specified local id + // On local, ownerID is real agent id. because of the fragment of bundle split from the agent id(1) is bundle id + // which is different from agent id. + string ownerId = 14; + + // Use the identifier at initialization as the initialization time of resourceunit + string viewInitTime = 15; +} + +enum InvokeType { + CreateInstance = 0; +} + +enum ApiType { + Actor = 0; + Faas = 1; + Posix = 2; + Serve = 3; +} + +enum LanguageType { + Cpp = 0; + Python =1; + Java = 2; + Golang = 3; + NodeJS = 4; + CSharp = 5; + Php = 6; +} + +message FunctionMeta { + string applicationName = 1; + string moduleName = 2; + string functionName = 3; + string className = 4; + LanguageType language = 5; + string codeID = 6; + string signature = 7; + ApiType apiType = 8; + string name = 9; // The designated actor name + string ns = 10; // The designated actor namespace + string functionID = 11; + string initializerCodeID = 12; + bool isAsync = 13; + bool isGenerator = 14; +} + +message FunctionID { + LanguageType language = 1; + string functionID = 2; +} + +message MetaConfig { + string jobID = 1; + repeated string codePaths = 2; + int64 recycleTime = 3; + int64 maxTaskInstanceNum = 4; + int64 maxConcurrencyCreateNum = 5; + bool enableMetrics = 6; + int64 threadPoolSize = 7; + repeated FunctionID functionIDs = 8; + string ns = 9; + repeated string schedulerInstanceIds = 10; + map customEnvs = 11; + string tenantID = 12; + int64 localThreadPoolSize = 13; + repeated string functionMasters = 14; + bool isLowReliabilityTask = 15; +} + +message MetaData { + InvokeType invokeType = 1; + FunctionMeta functionMeta = 2; + MetaConfig config = 3; +} diff --git a/yuanrong/proto/posix/runtime_rpc.proto b/yuanrong/proto/posix/runtime_rpc.proto new file mode 100644 index 0000000..e051db4 --- /dev/null +++ b/yuanrong/proto/posix/runtime_rpc.proto @@ -0,0 +1,119 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package runtime_rpc; + +import "core_service.proto"; +import "runtime_service.proto"; + +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb;api"; + +// RuntimeRPC provide bidirectional streaming RPC interface +service RuntimeRPC { + // build bidirection grpc communication channel, different message body type specify different api handler + rpc MessageStream (stream StreamingMessage) returns (stream StreamingMessage) {} +} + +message StreamingMessage { + string messageID = 1; + oneof body { + + // Create an instance for specify function + // handle by core + core_service.CreateRequest createReq = 2; + core_service.CreateResponse createRsp = 3; + + // invoke the created instance + // handle by core + core_service.InvokeRequest invokeReq = 4; + core_service.InvokeResponse invokeRsp = 5; + + // exit the created instance + // only support to be called by instance itself + // handle by core + core_service.ExitRequest exitReq = 6; + core_service.ExitResponse exitRsp = 7; + + // save state of the created instance + // handle by core + core_service.StateSaveRequest saveReq = 8; + core_service.StateSaveResponse saveRsp = 9; + + // load state of the created instance + // handle by core + core_service.StateLoadRequest loadReq = 10; + core_service.StateLoadResponse loadRsp = 11; + + // send the signal to instance or core + // 1 ~ 63: core defined signal + // 64 ~ 1024: custom runtime defined signal + // handle by core + core_service.KillRequest killReq = 12; + core_service.KillResponse killRsp = 13; + + // send call request result to sender + // handle by core + core_service.CallResult callResultReq = 14; + core_service.CallResultAck callResultAck = 15; + + // Call a method or init state of instance + // handle by runtime + runtime_service.CallRequest callReq = 16; + runtime_service.CallResponse callRsp = 17; + + // NotifyResult is applied to async notify result of create or invoke request invoked by runtime + // handle by runtime + runtime_service.NotifyRequest notifyReq = 18; + runtime_service.NotifyResponse notifyRsp = 19; + + // Checkpoint request a state to save for failure recovery and state migration + // handle by runtime + runtime_service.CheckpointRequest checkpointReq = 20; + runtime_service.CheckpointResponse checkpointRsp = 21; + + // Recover state + // handle by runtime + runtime_service.RecoverRequest recoverReq = 22; + runtime_service.RecoverResponse recoverRsp = 23; + + // request an instance to shutdown + // handle by runtime + runtime_service.ShutdownRequest shutdownReq = 24; + runtime_service.ShutdownResponse shutdownRsp = 25; + + // receive the signal send by other runtime or driver + // handle by runtime + runtime_service.SignalRequest signalReq = 26; + runtime_service.SignalResponse signalRsp = 27; + + // check whether the runtime is alive + // handle by runtime + runtime_service.HeartbeatRequest heartbeatReq = 28; + runtime_service.HeartbeatResponse heartbeatRsp = 29; + + // Create group instance for specify function + // handle by core + core_service.CreateRequests createReqs = 30; + core_service.CreateResponses createRsps = 31; + + // Create resource group to reserve multiple bundle of resource + core_service.CreateResourceGroupRequest rGroupReq = 32; + core_service.CreateResourceGroupResponse rGroupRsp = 33; + } + map metaData = 100; +} \ No newline at end of file diff --git a/yuanrong/proto/posix/runtime_service.proto b/yuanrong/proto/posix/runtime_service.proto new file mode 100644 index 0000000..e2bd3f1 --- /dev/null +++ b/yuanrong/proto/posix/runtime_service.proto @@ -0,0 +1,111 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package runtime_service; + +import "common.proto"; + +option go_package = "yuanrong/pkg/common/faas_common/grpc/pb/runtime;runtime"; + +message CallRequest { + string function = 1; + repeated common.Arg args = 2; + string traceID = 3; + string returnObjectID = 4; + // isCreate specify the request whether initialization or runtime invoke + bool isCreate = 5; + // senderID specify the caller identity + // while process done, it should be send back to core by CallResult.instanceID + string senderID = 6; + // while process done, it should be send back to core by CallResult.requestID + string requestID = 7; + repeated string returnObjectIDs = 8; + map createOptions = 9; + string spanID = 10; +} + +message CallResponse { + common.ErrorCode code = 1; + string message = 2; + +} + +message CheckpointRequest { + string checkpointID = 1; +} + +message CheckpointResponse { + common.ErrorCode code = 1; + string message = 2; + bytes state = 3; +} + +message RecoverRequest { + bytes state = 1; + map createOptions = 2; +} + +message RecoverResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message GracefulExitRequest { + uint64 gracePeriodSecond = 1; +} + +message GracefulExitResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message ShutdownRequest { + uint64 gracePeriodSecond = 1; +} + +message ShutdownResponse { + common.ErrorCode code = 1; + string message = 2; +} + +message NotifyRequest { + string requestID = 1; + common.ErrorCode code = 2; + string message = 3; + repeated common.SmallObject smallObjects = 4; + repeated common.StackTraceInfo stackTraceInfos = 5; + common.RuntimeInfo runtimeInfo = 7; +} + +message NotifyResponse {} + +message HeartbeatRequest {} + +message HeartbeatResponse { + common.HealthCheckCode code = 1; +} + +message SignalRequest { + int32 signal = 1; + bytes payload = 2; +} + +message SignalResponse { + common.ErrorCode code = 1; + string message = 2; +} \ No newline at end of file diff --git a/yuanrong/test/collector/test.sh b/yuanrong/test/collector/test.sh new file mode 100644 index 0000000..5a74030 --- /dev/null +++ b/yuanrong/test/collector/test.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +# global environment +CUR_DIR=$(dirname "$(readlink -f "$0")") +ROOT_PATH=${CUR_DIR}/../../ +SRC_PATH=${ROOT_PATH}/pkg/collector +CMD_PATH=${ROOT_PATH}/cmd/collector +OUTPUT_PATH=${CUR_DIR}/output + +# run go test and report +run_gocover_report() +{ + rm -rf "${OUTPUT_PATH}" + mkdir -p "${OUTPUT_PATH}" + + local proto_output=$(realpath ${ROOT_PATH}/..) + local posix_proto_path=$(realpath ${ROOT_PATH}/../posix/proto) + local fs_proto_path=${ROOT_PATH}/src/common/proto/posix/ + cp -f ${posix_proto_path}/*.proto ${fs_proto_path}/ + protoc --proto_path="${fs_proto_path}" --go_out="${proto_output}" --go-grpc_out="${proto_output}" "${fs_proto_path}"/*.proto + + cd ${CMD_PATH} + go test -v -gcflags=all=-l -covermode="${GOCOVER_MODE}" -coverpkg="./..." "./..." + + cd ${SRC_PATH} + go test -v -gcflags=all=-l -covermode="${GOCOVER_MODE}" -coverprofile="$OUTPUT_PATH/collector.cover" -coverpkg="./..." "./..." + + if [ $? -ne 0 ]; then + log_error "failed to go test collector" + exit 1 + fi + + # export llt coverage result + cd "$OUTPUT_PATH" + echo "mode: ${GOCOVER_MODE}" > coverage.out && cat ./*.cover | grep -v mode: | grep -v pb.go | sort -r | \ + awk '{if($1 != last) {print $0;last=$1}}' >> coverage.out + + gocov convert coverage.out > coverage.json + gocov report coverage.json > CoverResult.txt + gocov-html coverage.json > coverage.html +} + +run_gocover_report +exit 0 \ No newline at end of file diff --git a/yuanrong/test/common/test.sh b/yuanrong/test/common/test.sh new file mode 100644 index 0000000..b66410c --- /dev/null +++ b/yuanrong/test/common/test.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +# global environment +CUR_DIR=$(dirname "$(readlink -f "$0")") +ROOT_PATH=${CUR_DIR}/../../ +SRC_PATH=${ROOT_PATH}/pkg/common/faas_common +OUTPUT_PATH=${CUR_DIR}/output +echo LD_LIBRARY_PATH=$LD_LIBRARY_PATH + +# run go test and report +run_gocover_report() +{ + rm -rf "${OUTPUT_PATH}" + mkdir -p "${OUTPUT_PATH}" + + cd ${SRC_PATH} + go test -gcflags=all=-l -covermode="${GOCOVER_MODE}" -coverprofile="$OUTPUT_PATH/common.cover" -coverpkg="./..." "./..." + + if [ $? -ne 0 ]; then + log_error "failed to go test common" + exit 1 + fi + + # export llt coverage result + cd "$OUTPUT_PATH" + echo "mode: ${GOCOVER_MODE}" > coverage.out && cat ./*.cover | grep -v mode: | grep -v pb.go | sort -r | \ + awk '{if($1 != last) {print $0;last=$1}}' >> coverage.out + + gocov convert coverage.out > coverage.json + gocov report coverage.json > CoverResult.txt + gocov-html coverage.json > coverage.html +} + +run_gocover_report +exit 0 \ No newline at end of file diff --git a/yuanrong/test/dashboard/test.sh b/yuanrong/test/dashboard/test.sh new file mode 100644 index 0000000..6f3a9a8 --- /dev/null +++ b/yuanrong/test/dashboard/test.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +# global environment +CUR_DIR=$(dirname "$(readlink -f "$0")") +ROOT_PATH=${CUR_DIR}/../../ +SRC_PATH=${ROOT_PATH}/pkg/dashboard +CMD_PATH=${ROOT_PATH}/cmd/dashboard +OUTPUT_PATH=${CUR_DIR}/output + +# run go test and report +run_gocover_report() +{ + rm -rf "${OUTPUT_PATH}" + mkdir -p "${OUTPUT_PATH}" + + cd ${CMD_PATH} + go test -gcflags=all=-l -covermode="${GOCOVER_MODE}" -coverpkg="./..." "./..." + + cd ${SRC_PATH} + go test -gcflags=all=-l -covermode="${GOCOVER_MODE}" -coverprofile="$OUTPUT_PATH/dashboard.cover" -coverpkg="./..." "./..." + + if [ $? -ne 0 ]; then + log_error "failed to go test dashboard" + exit 1 + fi + + # export llt coverage result + cd "$OUTPUT_PATH" + echo "mode: ${GOCOVER_MODE}" > coverage.out && cat ./*.cover | grep -v mode: | grep -v pb.go | sort -r | \ + awk '{if($1 != last) {print $0;last=$1}}' >> coverage.out + + gocov convert coverage.out > coverage.json + gocov report coverage.json > CoverResult.txt + gocov-html coverage.json > coverage.html +} + +run_gocover_report +exit 0 \ No newline at end of file diff --git a/yuanrong/test/test.sh b/yuanrong/test/test.sh new file mode 100644 index 0000000..3372ae4 --- /dev/null +++ b/yuanrong/test/test.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +CUR_DIR=$(dirname "$(readlink -f "$0")") +PROJECT_DIR=$(cd "${CUR_DIR}/.."; pwd) +ROOT_PATH=$PROJECT_DIR + +# go module prepare +export GO111MODULE=on +export GONOSUMDB=* +export CGO_ENABLED=1 + +# resolve missing go.sum entry +go env -w "GOFLAGS"="-mod=mod" + +# coverage mode +# set: 每个语句是否执行? +# count: 每个语句执行了几次? +# atomic: 类似于count, 但表示的是并行程序中的精确计数 +export GOCOVER_MODE="set" + +# test module name +MODULE_LIST=(\ +"dashboard" +) + +. "${PROJECT_DIR}"/build/compile_functions.sh + +# $1: source file name, In the format of xxx.go +# $2: target file name, In the format of xxx_mock.go +function generate_mock() +{ + if ! mockgen -destination "$2" -source "$1" -package mock; then + log_error "Failed to generate mock file." + return 1; + fi +} +export -f generate_mock + +# create source code link, go cover report dependent on GOPATH src +link_source_code() +{ + rm -rf "${GOPATH}/pkg" + rm -rf "${GOPATH}/src/dashboard" + + mkdir -p "${GOPATH}"/src/ + ln -s "${ROOT_PATH}" "${GOPATH}"/src/dashboard +} + +link_source_code + +if [[ -z "${1}" ]]; then + for i in "${!MODULE_LIST[@]}"; + do + if ! sh -x "${CUR_DIR}/${MODULE_LIST[$i]}/test.sh"; then + echo "Failed to test ${MODULE_LIST[$i]}" + exit 1 + fi + echo "Succeed to test ${MODULE_LIST[$i]}" + done + echo "Succeed to test all module" +elif [[ "${MODULE_LIST[@]}" =~ "${1}" ]]; then + if ! sh -x "${CUR_DIR}/${1}/test.sh"; then + echo "Failed to test ${1}" + exit 1 + fi + echo "Succeed to test ${1}" +else + echo "Please input parameters 'module name: ${MODULE_LIST[@]}'" + exit 1 +fi + +exit 0 \ No newline at end of file -- Gitee From 13bbc2f8e923fa69a1e8fb784e8c2602b4d5255e Mon Sep 17 00:00:00 2001 From: mayuehit Date: Mon, 17 Nov 2025 22:00:28 +0800 Subject: [PATCH 2/3] sync ds code --- api/cpp/example/stream_example.cpp | 1 - api/cpp/include/yr/api/stream.h | 2 - api/cpp/src/stream_pubsub.cpp | 15 +---- api/cpp/src/stream_pubsub.h | 9 --- api/go/example/stream_example.go | 3 +- api/go/libruntime/api/types.go | 1 - api/go/libruntime/clibruntime/clibruntime.go | 10 ---- .../clibruntime/clibruntime_test.go | 6 -- api/go/libruntime/cpplibruntime/clibruntime.h | 1 - .../cpplibruntime/cpplibruntime.cpp | 7 --- .../cpplibruntime/mock/mock_cpplibruntime.cpp | 5 -- api/go/yr/stream.go | 7 --- api/go/yr/stream_test.go | 6 -- .../main/cpp/com_yuanrong_jni_Producer.cpp | 25 ++------- .../src/main/cpp/com_yuanrong_jni_Producer.h | 15 ++--- .../src/main/cpp/jni_types.cpp | 55 +++++++------------ .../function-common/src/main/cpp/jni_types.h | 25 ++++----- .../main/java/com/yuanrong/CreateParam.java | 12 ---- .../java/com/yuanrong/jni/JniProducer.java | 8 --- .../java/com/yuanrong/stream/Producer.java | 7 --- .../com/yuanrong/stream/ProducerImpl.java | 12 ---- .../com/yuanrong/stream/TestProducerImpl.java | 24 -------- api/python/yr/fnruntime.pyx | 11 ---- api/python/yr/runtime.py | 6 -- deploy/data_system/install.sh | 4 +- .../generator/stream_generator_notifier.cpp | 18 ------ .../objectstore/datasystem_object_store.cpp | 2 - .../streamstore/stream_producer_consumer.cpp | 8 --- .../streamstore/stream_producer_consumer.h | 2 - test/api/stream_pub_sub_test.cpp | 27 +-------- test/clibruntime/clibruntime_test.cpp | 7 +-- test/libruntime/generator_test.cpp | 6 -- test/libruntime/mock/mock_datasystem.h | 3 +- .../mock/mock_datasystem_client.cpp | 37 ++++--------- test/libruntime/stream_store_test.cpp | 6 +- 35 files changed, 68 insertions(+), 325 deletions(-) diff --git a/api/cpp/example/stream_example.cpp b/api/cpp/example/stream_example.cpp index 56918d8..ab39b5b 100644 --- a/api/cpp/example/stream_example.cpp +++ b/api/cpp/example/stream_example.cpp @@ -38,7 +38,6 @@ int main() std::string str = "hello"; YR::Element element((uint8_t *)(str.c_str()), str.size()); producer->Send(element); - producer->Flush(); //! [producer send] //! [consumer recv] // consumer receive data diff --git a/api/cpp/include/yr/api/stream.h b/api/cpp/include/yr/api/stream.h index 5a78dd4..70d1a67 100644 --- a/api/cpp/include/yr/api/stream.h +++ b/api/cpp/include/yr/api/stream.h @@ -190,8 +190,6 @@ public: virtual void Send(const Element &element, int64_t timeoutMs) = 0; - virtual void Flush() = 0; - virtual void Close() = 0; }; diff --git a/api/cpp/src/stream_pubsub.cpp b/api/cpp/src/stream_pubsub.cpp index dd3bca5..6f2ed18 100644 --- a/api/cpp/src/stream_pubsub.cpp +++ b/api/cpp/src/stream_pubsub.cpp @@ -43,16 +43,6 @@ void StreamProducer::Send(const Element &element, int64_t timeoutMs) } } -void StreamProducer::Flush() -{ - YR::Libruntime::ErrorInfo err = producer_->Flush(); - if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("Flush err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), - err.Msg()); - throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); - } -} - void StreamProducer::Close() { auto err = YR::Libruntime::LibruntimeManager::Instance().GetLibRuntime()->SetTraceId(traceId_); @@ -113,9 +103,8 @@ void StreamConsumer::Close() } err = consumer_->Close(); if (err.Code() != YR::Libruntime::ErrorCode::ERR_OK) { - YRLOG_ERROR("Close err: Code:{}, MCode:{}, Msg:{}", - fmt::underlying(err.Code()), - fmt::underlying(err.MCode()), err.Msg()); + YRLOG_ERROR("Close err: Code:{}, MCode:{}, Msg:{}", fmt::underlying(err.Code()), fmt::underlying(err.MCode()), + err.Msg()); throw YR::Exception(static_cast(err.Code()), static_cast(err.MCode()), err.Msg()); } } diff --git a/api/cpp/src/stream_pubsub.h b/api/cpp/src/stream_pubsub.h index 5cd5917..aef7068 100644 --- a/api/cpp/src/stream_pubsub.h +++ b/api/cpp/src/stream_pubsub.h @@ -53,15 +53,6 @@ public: */ void Send(const Element &element, int64_t timeoutMs); - /** - * @brief Manually flushes the buffer to make the data visible to consumers. - * @throws Exception - * - **4299**: producer failed to flush. - * - * @snippet{trimleft} stream_example.cpp producer send - */ - void Flush(); - /** * @brief Closes the producer, triggering an automatic flush of the buffer and indicating that the buffer will no * longer be used. Once closed, the producer cannot be used again. diff --git a/api/go/example/stream_example.go b/api/go/example/stream_example.go index 6a83b0c..4633888 100644 --- a/api/go/example/stream_example.go +++ b/api/go/example/stream_example.go @@ -52,8 +52,7 @@ func StreamExample() { } fmt.Println(producer.Send(data)) - fmt.Println(producer.Flush()) - + subDatas, err := consumer.Receive(1, 30000) if err != nil { fmt.Println("receive failed, err: ", err) diff --git a/api/go/libruntime/api/types.go b/api/go/libruntime/api/types.go index 17e4c31..4560981 100644 --- a/api/go/libruntime/api/types.go +++ b/api/go/libruntime/api/types.go @@ -62,7 +62,6 @@ type Element struct { type StreamProducer interface { Send(element Element) error SendWithTimeout(element Element, timeoutMs int64) error - Flush() error Close() error } diff --git a/api/go/libruntime/clibruntime/clibruntime.go b/api/go/libruntime/clibruntime/clibruntime.go index 29d5414..3839713 100644 --- a/api/go/libruntime/clibruntime/clibruntime.go +++ b/api/go/libruntime/clibruntime/clibruntime.go @@ -499,16 +499,6 @@ func (p *StreamProducerImpl) SendWithTimeout(element api.Element, timeoutMs int6 return nil } -// Flush ensure flush buffered data so that it is visible to the consumer. -func (p *StreamProducerImpl) Flush() error { - cErr := C.CProducerFlush(p.producer) - code := int(cErr.code) - if code != 0 { - return codeNotZeroErr(code, cErr, "stream producer flush: ") - } - return nil -} - // Close signals the producer to stop accepting new data and automatically flushes // any pending data in the buffer. Once closed, the producer is no longer available. func (p *StreamProducerImpl) Close() error { diff --git a/api/go/libruntime/clibruntime/clibruntime_test.go b/api/go/libruntime/clibruntime/clibruntime_test.go index 4c46c90..a12a375 100644 --- a/api/go/libruntime/clibruntime/clibruntime_test.go +++ b/api/go/libruntime/clibruntime/clibruntime_test.go @@ -406,12 +406,6 @@ func TestProducerSendAndFlush(t *testing.T) { convey.So(err, convey.ShouldBeNil) }, ) - convey.Convey( - "producer flush", func() { - err = producer.Flush() - convey.So(err, convey.ShouldBeNil) - }, - ) convey.Convey( "producer close", func() { err = producer.Close() diff --git a/api/go/libruntime/cpplibruntime/clibruntime.h b/api/go/libruntime/cpplibruntime/clibruntime.h index aa758ec..2a6c587 100644 --- a/api/go/libruntime/cpplibruntime/clibruntime.h +++ b/api/go/libruntime/cpplibruntime/clibruntime.h @@ -432,7 +432,6 @@ CErrorInfo CQueryGlobalConsumersNum(char *streamName, uint64_t *num); CErrorInfo CProducerSend(Producer_p producerPtr, uint8_t *ptr, uint64_t size, uint64_t id); CErrorInfo CProducerSendWithTimeout(Producer_p producerPtr, uint8_t *ptr, uint64_t size, uint64_t id, int64_t timeoutMs); -CErrorInfo CProducerFlush(Producer_p producerPtr); CErrorInfo CProducerClose(Producer_p producerPtr); CErrorInfo CConsumerReceive(Consumer_p consumerPtr, uint32_t timeoutMs, CElement **elements, uint64_t *count); CErrorInfo CConsumerReceiveExpectNum(Consumer_p consumerPtr, uint32_t expectNum, uint32_t timeoutMs, diff --git a/api/go/libruntime/cpplibruntime/cpplibruntime.cpp b/api/go/libruntime/cpplibruntime/cpplibruntime.cpp index 026017b..bcfb9c7 100644 --- a/api/go/libruntime/cpplibruntime/cpplibruntime.cpp +++ b/api/go/libruntime/cpplibruntime/cpplibruntime.cpp @@ -1335,13 +1335,6 @@ CErrorInfo CProducerSendWithTimeout(Producer_p producerPtr, uint8_t *ptr, uint64 return ErrorInfoToCError(err); } -CErrorInfo CProducerFlush(Producer_p producerPtr) -{ - auto producer = *reinterpret_cast *>(producerPtr); - auto err = producer->Flush(); - return ErrorInfoToCError(err); -} - CErrorInfo CProducerClose(Producer_p producerPtr) { auto producer = reinterpret_cast *>(producerPtr); diff --git a/api/go/libruntime/cpplibruntime/mock/mock_cpplibruntime.cpp b/api/go/libruntime/cpplibruntime/mock/mock_cpplibruntime.cpp index a478ae0..e0f2577 100644 --- a/api/go/libruntime/cpplibruntime/mock/mock_cpplibruntime.cpp +++ b/api/go/libruntime/cpplibruntime/mock/mock_cpplibruntime.cpp @@ -293,11 +293,6 @@ CErrorInfo CProducerSendWithTimeout(Producer_p producerPtr, uint8_t *ptr, uint64 return ErrorInfoToCError(ErrorInfo()); } -CErrorInfo CProducerFlush(Producer_p producerPtr) -{ - return ErrorInfoToCError(ErrorInfo()); -} - CErrorInfo CProducerClose(Producer_p producerPtr) { return ErrorInfoToCError(ErrorInfo()); diff --git a/api/go/yr/stream.go b/api/go/yr/stream.go index d3b37ce..9dbc979 100644 --- a/api/go/yr/stream.go +++ b/api/go/yr/stream.go @@ -69,13 +69,6 @@ func (producer *Producer) Send(data []byte) error { return producer.producer.Send(element) } -// Flush ensure flush buffered data so that it is visible to the consumer. -func (producer *Producer) Flush() error { - producer.mutex.Lock() - defer producer.mutex.Unlock() - return producer.producer.Flush() -} - // Close signals the producer to stop accepting new data and automatically flushes // any pending data in the buffer. Once closed, the producer is no longer available. func (producer *Producer) Close() error { diff --git a/api/go/yr/stream_test.go b/api/go/yr/stream_test.go index 8765f15..03fb309 100644 --- a/api/go/yr/stream_test.go +++ b/api/go/yr/stream_test.go @@ -79,12 +79,6 @@ func TestStream(t *testing.T) { convey.So(err, convey.ShouldBeNil) }, ) - convey.Convey( - "Flush success", func() { - err = producer.Flush() - convey.So(err, convey.ShouldBeNil) - }, - ) }, ) diff --git a/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.cpp b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.cpp index f1141b0..24ccb01 100644 --- a/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.cpp +++ b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.cpp @@ -28,9 +28,8 @@ extern "C" { #endif JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBufferDefaultTimeout(JNIEnv *env, jclass, - jlong handle, - jbyteArray bytes, - jlong len) + jlong handle, jbyteArray bytes, + jlong len) { auto producer = reinterpret_cast *>(handle); jbyte *bytekey = env->GetByteArrayElements(bytes, 0); @@ -51,8 +50,7 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBufferDefaul } JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBufferDefaultTimeout(JNIEnv *env, jclass, - jlong handle, - jobject buf) + jlong handle, jobject buf) { auto producer = reinterpret_cast *>(handle); auto body = env->GetDirectBufferAddress(buf); @@ -73,8 +71,7 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBufferDefa } JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBuffer(JNIEnv *env, jclass, jlong handle, - jbyteArray bytes, jlong len, - jint timeoutMs) + jbyteArray bytes, jlong len, jint timeoutMs) { auto producer = reinterpret_cast *>(handle); jbyte *bytekey = env->GetByteArrayElements(bytes, 0); @@ -95,7 +92,7 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBuffer(JNIEn } JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBuffer(JNIEnv *env, jclass, jlong handle, - jobject buf, jint timeoutMs) + jobject buf, jint timeoutMs) { auto producer = reinterpret_cast *>(handle); auto body = env->GetDirectBufferAddress(buf); @@ -115,18 +112,6 @@ JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBuffer(JNI return jerr; } -JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_flush(JNIEnv *env, jclass, jlong handle) -{ - auto producer = reinterpret_cast *>(handle); - auto err = (*producer)->Flush(); - jobject jerr = YR::jni::JNIErrorInfo::FromCc(env, err); - if (jerr == nullptr) { - YR::jni::JNILibruntimeException::ThrowNew(env, "failed to convert jerr when Producer_flush, get null"); - return nullptr; - } - return jerr; -} - JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_close(JNIEnv *env, jclass, jlong handle) { auto producer = reinterpret_cast *>(handle); diff --git a/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.h b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.h index 8b1d099..e7fbb29 100644 --- a/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.h +++ b/api/java/function-common/src/main/cpp/com_yuanrong_jni_Producer.h @@ -22,18 +22,15 @@ extern "C" { #endif JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBufferDefaultTimeout(JNIEnv *, jclass, jlong, - jbyteArray, jlong); + jbyteArray, jlong); -JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBufferDefaultTimeout(JNIEnv *, jclass, - jlong, jobject); +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBufferDefaultTimeout(JNIEnv *, jclass, jlong, + jobject); -JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBuffer(JNIEnv *, jclass, jlong, jbyteArray, - jlong, jint); +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendHeapBuffer(JNIEnv *, jclass, jlong, jbyteArray, jlong, + jint); -JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBuffer(JNIEnv *, jclass, jlong, jobject, - jint); - -JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_flush(JNIEnv *, jclass, jlong); +JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_sendDirectBuffer(JNIEnv *, jclass, jlong, jobject, jint); JNIEXPORT jobject JNICALL Java_com_yuanrong_jni_JniProducer_close(JNIEnv *, jclass, jlong); diff --git a/api/java/function-common/src/main/cpp/jni_types.cpp b/api/java/function-common/src/main/cpp/jni_types.cpp index a8b0e7e..f3469f4 100644 --- a/api/java/function-common/src/main/cpp/jni_types.cpp +++ b/api/java/function-common/src/main/cpp/jni_types.cpp @@ -110,8 +110,7 @@ void JNILibruntimeException::Throw(JNIEnv *env, const YR::Libruntime::ErrorCode const YR::Libruntime::ModuleCode &moduleCode, const std::string &msg) { jmethodID constructorId = env->GetMethodID( - clz_, "", - "(Lcom/yuanrong/errorcode/ErrorCode;Lcom/yuanrong/errorcode/ModuleCode;Ljava/lang/String;)V"); + clz_, "", "(Lcom/yuanrong/errorcode/ErrorCode;Lcom/yuanrong/errorcode/ModuleCode;Ljava/lang/String;)V"); jobject jerrorCode = JNIErrorCode::FromCc(env, errorCode); if (jerrorCode == nullptr) { @@ -372,8 +371,8 @@ void JNIInvokeType::Init(JNIEnv *env) { clz_ = LoadClass(env, "com/yuanrong/libruntime/generated/Libruntime$InvokeType"); jmGetNumber_ = GetJMethod(env, clz_, "getNumber", "()I"); - jmForNumber_ = GetStaticMethodID(env, clz_, "forNumber", - "(I)Lcom/yuanrong/libruntime/generated/Libruntime$InvokeType;"); + jmForNumber_ = + GetStaticMethodID(env, clz_, "forNumber", "(I)Lcom/yuanrong/libruntime/generated/Libruntime$InvokeType;"); } void JNIInvokeType::Recycle(JNIEnv *env) @@ -398,8 +397,8 @@ void JNILanguageType::Init(JNIEnv *env) { clz_ = LoadClass(env, "com/yuanrong/libruntime/generated/Libruntime$LanguageType"); jmGetNumber_ = GetJMethod(env, clz_, "getNumber", "()I"); - jmForNumber_ = GetStaticMethodID(env, clz_, "forNumber", - "(I)Lcom/yuanrong/libruntime/generated/Libruntime$LanguageType;"); + jmForNumber_ = + GetStaticMethodID(env, clz_, "forNumber", "(I)Lcom/yuanrong/libruntime/generated/Libruntime$LanguageType;"); } void JNILanguageType::Recycle(JNIEnv *env) @@ -487,8 +486,7 @@ void JNICodeExecutor::Init(JNIEnv *env) env, clz_, "execute", "(Lcom/yuanrong/libruntime/generated/Libruntime$FunctionMeta;Lcom/yuanrong/libruntime/generated/" "Libruntime$InvokeType;Ljava/util/List;)Lcom/yuanrong/executor/ReturnType;"); - jmDumpInstance_ = - GetStaticMethodID(env, clz_, "dumpInstance", "(Ljava/lang/String;)Lcom/yuanrong/errorcode/Pair;"); + jmDumpInstance_ = GetStaticMethodID(env, clz_, "dumpInstance", "(Ljava/lang/String;)Lcom/yuanrong/errorcode/Pair;"); jmLoadInstance_ = GetStaticMethodID(env, clz_, "loadInstance", "([B[B)V"); @@ -608,8 +606,8 @@ YR::Libruntime::ErrorInfo JNICodeExecutor::DumpInstance(JNIEnv *env, const std:: size_t clzNameSize = env->GetArrayLength(clzNameBytes); // data buffer format: [uint_8(size of buf1)|buf1(instanceBuf)|buf2(clsName)] // nativeBuffer is the combination of instanceBuf and clsName - if (instanceBufSize > (std::numeric_limits::max() - sizeof(size_t)) - || (sizeof(size_t) + instanceBufSize) > (std::numeric_limits::max() - clzNameSize)) { + if (instanceBufSize > (std::numeric_limits::max() - sizeof(size_t)) || + (sizeof(size_t) + instanceBufSize) > (std::numeric_limits::max() - clzNameSize)) { return YR::Libruntime::ErrorInfo(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "nativeBufferSize exceeds maximum allowed size"); } @@ -885,10 +883,11 @@ void JNIHashMap::Recycle(JNIEnv *env) template jobject JNIHashMap::FromCc(JNIEnv *env, const std::unordered_map &map, - std::function converterK, std::function converterV) + std::function converterK, + std::function converterV) { jobject hashMap = env->NewObject(clz_, init_); - for (const auto& kv : map) { + for (const auto &kv : map) { auto tmpKey = converterK(env, kv.first); auto tmpValue = converterV(env, kv.second); env->CallObjectMethod(hashMap, jmPut_, tmpKey, tmpValue); @@ -1823,8 +1822,7 @@ void JNIWriteMode::Init(JNIEnv *env) { clz_ = LoadClass(env, "com/yuanrong/WriteMode"); j_field_NONE_L2_CACHE_ = GetJStaticField(env, clz_, "NONE_L2_CACHE", "Lcom/yuanrong/WriteMode;"); - j_field_WRITE_THROUGH_L2_CACHE_ = - GetJStaticField(env, clz_, "WRITE_THROUGH_L2_CACHE", "Lcom/yuanrong/WriteMode;"); + j_field_WRITE_THROUGH_L2_CACHE_ = GetJStaticField(env, clz_, "WRITE_THROUGH_L2_CACHE", "Lcom/yuanrong/WriteMode;"); j_field_WRITE_BACK_L2_CACHE_ = GetJStaticField(env, clz_, "WRITE_BACK_L2_CACHE", "Lcom/yuanrong/WriteMode;"); j_object_field_NONE_L2_CACHE_ = GetJStaticObjectField(env, clz_, j_field_NONE_L2_CACHE_); @@ -2034,7 +2032,6 @@ YR::Libruntime::CacheType JNIMSetParam::GetCacheType(JNIEnv *env, jobject o) void JNICreateParam::Init(JNIEnv *env) { clz_ = LoadClass(env, "com/yuanrong/CreateParam"); - jGetWriteMode_ = GetJMethod(env, clz_, "getWriteMode", "()Lcom/yuanrong/WriteMode;"); jGetConsistencyType_ = GetJMethod(env, clz_, "getConsistencyType", "()Lcom/yuanrong/ConsistencyType;"); jGetCacheType_ = GetJMethod(env, clz_, "getCacheType", "()Lcom/yuanrong/CacheType;"); } @@ -2045,7 +2042,6 @@ YR::Libruntime::CreateParam JNICreateParam::FromJava(JNIEnv *env, jobject o) RETURN_IF_NULL(o, createParam); return YR::Libruntime::CreateParam{ - .writeMode = GetWriteMode(env, o), .consistencyType = GetConsistencyType(env, o), .cacheType = GetCacheType(env, o), }; @@ -2059,12 +2055,6 @@ void JNICreateParam::Recycle(JNIEnv *env) } } -YR::Libruntime::WriteMode JNICreateParam::GetWriteMode(JNIEnv *env, jobject o) -{ - jobject writeMode = env->CallObjectMethod(o, JNICreateParam::jGetWriteMode_); - return JNIWriteMode::FromJava(env, writeMode); -} - YR::Libruntime::ConsistencyType JNICreateParam::GetConsistencyType(JNIEnv *env, jobject o) { jobject consistencyType = env->CallObjectMethod(o, JNICreateParam::jGetConsistencyType_); @@ -2505,25 +2495,20 @@ void JNINode::Recycle(JNIEnv *env) jobject JNINode::GetResourcesFromResourceUnit(JNIEnv *env, const YR::Libruntime::ResourceUnit &resourceUnit) { - return JNIHashMap::FromCc(env, resourceUnit.capacity, - [](JNIEnv *env, const std::string &key) { return env->NewStringUTF(key.c_str()); }, - [](JNIEnv *env, const float &value) { return JNIFloat::FromCc(env, value); } - ); + return JNIHashMap::FromCc( + env, resourceUnit.capacity, [](JNIEnv *env, const std::string &key) { return env->NewStringUTF(key.c_str()); }, + [](JNIEnv *env, const float &value) { return JNIFloat::FromCc(env, value); }); } - jobject JNINode::GetLabelsFromResourceUnit(JNIEnv *env, const YR::Libruntime::ResourceUnit &resourceUnit) { - return JNIHashMap::FromCc>(env, resourceUnit.nodeLabels, - [](JNIEnv *env, const std::string &key) { - return env->NewStringUTF(key.c_str()); - }, + return JNIHashMap::FromCc>( + env, resourceUnit.nodeLabels, + [](JNIEnv *env, const std::string &key) { return env->NewStringUTF(key.c_str()); }, [](JNIEnv *env, const std::vector &value) { return JNIArrayList::FromCc( - env, value, [](JNIEnv *env, const std::string &s) { return env->NewStringUTF(s.c_str()); } - ); - } - ); + env, value, [](JNIEnv *env, const std::string &s) { return env->NewStringUTF(s.c_str()); }); + }); } jobject JNINode::FromCc(JNIEnv *env, const YR::Libruntime::ResourceUnit &resourceUnit) diff --git a/api/java/function-common/src/main/cpp/jni_types.h b/api/java/function-common/src/main/cpp/jni_types.h index d8a5aae..ff19dc5 100644 --- a/api/java/function-common/src/main/cpp/jni_types.h +++ b/api/java/function-common/src/main/cpp/jni_types.h @@ -121,18 +121,18 @@ using FunctionLog = ::libruntime::FunctionLog; } \ } while (false) -#define CHECK_NULL_THROW_NEW_AND_RETURN(env, ptr, returnValue, msg) \ - if ((ptr) == nullptr) { \ - YR::jni::JNILibruntimeException::Throw(env, YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, \ - YR::Libruntime::ModuleCode::RUNTIME, std::string(msg)); \ - return returnValue; \ +#define CHECK_NULL_THROW_NEW_AND_RETURN(env, ptr, returnValue, msg) \ + if ((ptr) == nullptr) { \ + YR::jni::JNILibruntimeException::Throw(env, YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, \ + YR::Libruntime::ModuleCode::RUNTIME, std::string(msg)); \ + return returnValue; \ } -#define CHECK_NULL_THROW_NEW_AND_RETURN_VOID(env, ptr, msg) \ - if ((ptr) == nullptr) { \ - YR::jni::JNILibruntimeException::Throw(env, YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, \ - YR::Libruntime::ModuleCode::RUNTIME, std::string(msg)); \ - return; \ +#define CHECK_NULL_THROW_NEW_AND_RETURN_VOID(env, ptr, msg) \ + if ((ptr) == nullptr) { \ + YR::jni::JNILibruntimeException::Throw(env, YR::Libruntime::ErrorCode::ERR_INNER_SYSTEM_ERROR, \ + YR::Libruntime::ModuleCode::RUNTIME, std::string(msg)); \ + return; \ } inline jclass LoadClass(JNIEnv *env, const std::string &className) @@ -409,7 +409,8 @@ public: static void Recycle(JNIEnv *env); template static jobject FromCc(JNIEnv *env, const std::unordered_map &map, - std::function converterK, std::function converterV); + std::function converterK, + std::function converterV); private: inline static jclass clz_ = nullptr; @@ -842,13 +843,11 @@ public: static void Init(JNIEnv *env); static void Recycle(JNIEnv *env); static YR::Libruntime::CreateParam FromJava(JNIEnv *env, jobject o); - static YR::Libruntime::WriteMode GetWriteMode(JNIEnv *env, jobject o); static YR::Libruntime::ConsistencyType GetConsistencyType(JNIEnv *env, jobject o); static YR::Libruntime::CacheType GetCacheType(JNIEnv *env, jobject o); private: inline static jclass clz_ = nullptr; - inline static jmethodID jGetWriteMode_ = nullptr; inline static jmethodID jGetConsistencyType_ = nullptr; inline static jmethodID jGetCacheType_ = nullptr; }; diff --git a/api/java/function-common/src/main/java/com/yuanrong/CreateParam.java b/api/java/function-common/src/main/java/com/yuanrong/CreateParam.java index 2ce38bd..00907f3 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/CreateParam.java +++ b/api/java/function-common/src/main/java/com/yuanrong/CreateParam.java @@ -25,7 +25,6 @@ import lombok.Data; */ @Data public class CreateParam { - private WriteMode writeMode = WriteMode.NONE_L2_CACHE; private ConsistencyType consistencyType = ConsistencyType.PRAM; @@ -53,17 +52,6 @@ public class CreateParam { createParam = new CreateParam(); } - /** - * set the writeMode - * - * @param writeMode the writeMode - * @return CreateParam Builder class object. - */ - public Builder writeMode(WriteMode writeMode) { - createParam.writeMode = writeMode; - return this; - } - /** * set the consistencyType * diff --git a/api/java/function-common/src/main/java/com/yuanrong/jni/JniProducer.java b/api/java/function-common/src/main/java/com/yuanrong/jni/JniProducer.java index be323b6..3e26870 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/jni/JniProducer.java +++ b/api/java/function-common/src/main/java/com/yuanrong/jni/JniProducer.java @@ -66,14 +66,6 @@ public class JniProducer { */ public static native ErrorInfo sendDirectBuffer(long producerPtr, ByteBuffer buffers, int timeoutMs); - /** - * flush - * - * @param producerPtr producerPtr - * @return ErrorInfo - */ - public static native ErrorInfo flush(long producerPtr); - /** * close * diff --git a/api/java/function-common/src/main/java/com/yuanrong/stream/Producer.java b/api/java/function-common/src/main/java/com/yuanrong/stream/Producer.java index 6cd8d84..4cd0580 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/stream/Producer.java +++ b/api/java/function-common/src/main/java/com/yuanrong/stream/Producer.java @@ -49,13 +49,6 @@ public interface Producer { */ void send(Element element, int timeoutMs) throws YRException; - /** - * Manually flush the buffer data to make it visible to consumers. - * - * @throws YRException Unified exception types thrown. - */ - void flush() throws YRException; - /** * Closing a producer triggers an automatic flush of the data buffer and indicates that the data buffer is no longer * in use. Once closed, the producer can no longer be used. diff --git a/api/java/function-common/src/main/java/com/yuanrong/stream/ProducerImpl.java b/api/java/function-common/src/main/java/com/yuanrong/stream/ProducerImpl.java index 158902f..7fee6fc 100644 --- a/api/java/function-common/src/main/java/com/yuanrong/stream/ProducerImpl.java +++ b/api/java/function-common/src/main/java/com/yuanrong/stream/ProducerImpl.java @@ -89,18 +89,6 @@ public class ProducerImpl implements Producer { } } - @Override - public void flush() throws YRException { - rLock.lock(); - try { - ensureOpen(); - ErrorInfo err = JniProducer.flush(this.producerPtr); - StackTraceUtils.checkErrorAndThrowForInvokeException(err, err.getErrorMessage()); - } finally { - rLock.unlock(); - } - } - /** * Checks to make sure that producer has not been closed. * diff --git a/api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerImpl.java b/api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerImpl.java index a222390..0cc1edb 100644 --- a/api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerImpl.java +++ b/api/java/function-common/src/test/java/com/yuanrong/stream/TestProducerImpl.java @@ -102,30 +102,6 @@ public class TestProducerImpl { Assert.assertFalse(isException); } - @Test - public void testFlush() { - ProducerImpl producer = new ProducerImpl(10L); - ProducerImpl producer1 = new ProducerImpl(0L); - ErrorInfo errorInfo = new ErrorInfo(ErrorCode.ERR_OK, ModuleCode.CORE, ""); - PowerMockito.mockStatic(JniProducer.class); - when(JniProducer.flush(anyLong())).thenReturn(errorInfo); - boolean isException = false; - try { - producer1.flush(); - } catch (Exception e) { - isException = true; - } - Assert.assertTrue(isException); - - isException = false; - try { - producer.flush(); - } catch (Exception e) { - isException = true; - } - Assert.assertFalse(isException); - } - @Test public void testClose() { ProducerImpl producer = new ProducerImpl(10L); diff --git a/api/python/yr/fnruntime.pyx b/api/python/yr/fnruntime.pyx index dda9e3f..d24b43c 100644 --- a/api/python/yr/fnruntime.pyx +++ b/api/python/yr/fnruntime.pyx @@ -1030,17 +1030,6 @@ cdef class Producer: f"failed to send, " f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") - def flush(self) -> None: - """ - Manually flushing the buffer makes the data visible to the consumer. - """ - cdef CErrorInfo ret - ret = self.producer.get().Flush() - if not ret.OK(): - raise RuntimeError( - f"failed to flush, " - f"code: {ret.Code()}, module code {ret.MCode()}, msg: {ret.Msg().decode()}") - def close(self) -> None: """ Closing the producer will trigger an automatic flush of the data buffer and diff --git a/api/python/yr/runtime.py b/api/python/yr/runtime.py index 6489726..1021743 100644 --- a/api/python/yr/runtime.py +++ b/api/python/yr/runtime.py @@ -128,12 +128,6 @@ class CreateParam: """ pass - #: Configure the reliability of the data. - #: When the server is configured to support a secondary cache for ensuring reliability, - #: such as a Redis service, this configuration can ensure the reliability of the data. - #: Defaults to WriteMode.NONE_L2_CACHE. - write_mode: WriteMode = WriteMode.NONE_L2_CACHE - #: Data consistency configuration. #: In a distributed scenario, different levels of consistency semantics can be configured. #: The optional parameters are ConsistencyType.PRAM (asynchronous) and ConsistencyType.CAUSAL (causal consistency). diff --git a/deploy/data_system/install.sh b/deploy/data_system/install.sh index dc3434b..64c5fd8 100644 --- a/deploy/data_system/install.sh +++ b/deploy/data_system/install.sh @@ -49,7 +49,7 @@ function install_ds_master() { -node_timeout_s=${DS_NODE_TIMEOUT_S} \ -rocksdb_store_dir="${data_system_install_dir}/rocksdb" \ -etcd_address="${ETCD_CLUSTER_ADDRESS}" \ - -az_name="${ETCD_TABLE_PREFIX}" \ + -cluster_name="${ETCD_TABLE_PREFIX}" \ -etcd_target_name_override="${ETCD_TARGET_NAME_OVERRIDE}" \ -enable_etcd_auth=${ENABLE_ETCD_AUTH} \ -arena_per_tenant=${DS_ARENA_PER_TENANT} \ @@ -137,7 +137,7 @@ function install_ds_worker() { -enable_huge_tlb=${DS_ENABLE_HUGE_TLB} \ -enable_thp=${DS_ENABLE_THP} \ -etcd_address="${ETCD_CLUSTER_ADDRESS}" \ - -az_name="${ETCD_TABLE_PREFIX}" \ + -cluster_name="${ETCD_TABLE_PREFIX}" \ -etcd_target_name_override="${ETCD_TARGET_NAME_OVERRIDE}" \ -enable_etcd_auth=${ENABLE_ETCD_AUTH} \ -etcd_ca=${ETCD_SSL_BASE_PATH}/${ETCD_CA_FILE} \ diff --git a/src/libruntime/generator/stream_generator_notifier.cpp b/src/libruntime/generator/stream_generator_notifier.cpp index 55ffb98..88ec641 100644 --- a/src/libruntime/generator/stream_generator_notifier.cpp +++ b/src/libruntime/generator/stream_generator_notifier.cpp @@ -158,12 +158,6 @@ ErrorInfo StreamGeneratorNotifier::NotifyResultByStream(const std::string &gener if (!err.OK()) { YRLOG_ERROR("failed to send notify result to stream, err code: {}, err message: {}", fmt::underlying(err.Code()), err.Msg()); - } else { - err = producer->Flush(); - if (!err.OK()) { - YRLOG_ERROR("failed to flush stream, err code: {}, err message: {}", fmt::underlying(err.Code()), - err.Msg()); - } } if (!resultErr.OK()) { @@ -197,12 +191,6 @@ ErrorInfo StreamGeneratorNotifier::NotifyFinishedByStream(const std::string &gen if (!err.OK()) { YRLOG_ERROR("failed to send notify finished to stream, err code: {}, err message: {}", fmt::underlying(err.Code()), err.Msg()); - } else { - err = producer->Flush(); - if (!err.OK()) { - YRLOG_ERROR("failed to flush stream, err code: {}, err message: {}", fmt::underlying(err.Code()), - err.Msg()); - } } if (!DecreaseProducerReference(topic)) { @@ -234,12 +222,6 @@ ErrorInfo StreamGeneratorNotifier::NotifyHeartbeatByStream(const std::string &ge if (!err.OK()) { YRLOG_ERROR("failed to send notify heartbeat to stream, err code: {}, err message: {}", fmt::underlying(err.Code()), err.Msg()); - } else { - err = producer->Flush(); - if (!err.OK()) { - YRLOG_ERROR("failed to flush stream, err code: {}, err message: {}", fmt::underlying(err.Code()), - err.Msg()); - } } if (!DecreaseProducerReference(topic)) { diff --git a/src/libruntime/objectstore/datasystem_object_store.cpp b/src/libruntime/objectstore/datasystem_object_store.cpp index 5e27296..6135628 100644 --- a/src/libruntime/objectstore/datasystem_object_store.cpp +++ b/src/libruntime/objectstore/datasystem_object_store.cpp @@ -217,7 +217,6 @@ ErrorInfo DSCacheObjectStore::CreateBuffer(const std::string &objectId, size_t d OBJ_STORE_INIT_ONCE(); std::shared_ptr dataBuffer; ds::CreateParam param; - param.writeMode = static_cast(createParam.writeMode); param.consistencyType = static_cast(createParam.consistencyType); param.cacheType = static_cast(createParam.cacheType); ds::Status status = dsClient->Create(objectId, dataSize, param, dataBuffer); @@ -272,7 +271,6 @@ ErrorInfo DSCacheObjectStore::Put(std::shared_ptr data, const std::strin std::string msg; std::shared_ptr dataBuffer; ds::CreateParam param; - param.writeMode = static_cast(createParam.writeMode); param.consistencyType = static_cast(createParam.consistencyType); param.cacheType = static_cast(createParam.cacheType); ds::Status status = dsClient->Create(objId, static_cast(data->GetSize()), param, dataBuffer); diff --git a/src/libruntime/streamstore/stream_producer_consumer.cpp b/src/libruntime/streamstore/stream_producer_consumer.cpp index b9f7ef4..5c1f04c 100644 --- a/src/libruntime/streamstore/stream_producer_consumer.cpp +++ b/src/libruntime/streamstore/stream_producer_consumer.cpp @@ -38,14 +38,6 @@ ErrorInfo StreamProducer::Send(const Element &element, int64_t timeoutMs) return ErrorInfo(); } -ErrorInfo StreamProducer::Flush() -{ - datasystem::Status status = dsProducer->Flush(); - auto msg = "producer failed to Flush, errMsg: " + status.ToString(); - RETURN_ERR_NOT_OK(status.IsOk(), status.GetCode(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, msg); - return ErrorInfo(); -} - ErrorInfo StreamProducer::Close() { datasystem::Status status = dsProducer->Close(); diff --git a/src/libruntime/streamstore/stream_producer_consumer.h b/src/libruntime/streamstore/stream_producer_consumer.h index 8a314c8..9dc0dd8 100644 --- a/src/libruntime/streamstore/stream_producer_consumer.h +++ b/src/libruntime/streamstore/stream_producer_consumer.h @@ -30,8 +30,6 @@ public: virtual ErrorInfo Send(const Element &element, int64_t timeoutMs); - virtual ErrorInfo Flush(); - virtual ErrorInfo Close(); std::shared_ptr &GetProducer(); diff --git a/test/api/stream_pub_sub_test.cpp b/test/api/stream_pub_sub_test.cpp index 5c1a557..7a9034e 100644 --- a/test/api/stream_pub_sub_test.cpp +++ b/test/api/stream_pub_sub_test.cpp @@ -28,7 +28,6 @@ class MockStreamProducer : public YR::Libruntime::StreamProducer { public: MOCK_METHOD1(Send, YR::Libruntime::ErrorInfo(const YR::Libruntime::Element &element)); MOCK_METHOD2(Send, YR::Libruntime::ErrorInfo(const YR::Libruntime::Element &element, int64_t timeoutMs)); - MOCK_METHOD0(Flush, YR::Libruntime::ErrorInfo()); MOCK_METHOD0(Close, YR::Libruntime::ErrorInfo()); }; @@ -49,8 +48,8 @@ using namespace testing; class StreamPubSubTest : public testing::Test { public: - StreamPubSubTest(){}; - ~StreamPubSubTest(){}; + StreamPubSubTest() {}; + ~StreamPubSubTest() {}; void SetUp() override { Mkdir("/tmp/log"); @@ -117,28 +116,6 @@ TEST_F(StreamPubSubTest, SendSuccessfullyTest) EXPECT_NO_THROW(streamProducer->Send(ele, 1000)); } -TEST_F(StreamPubSubTest, FlushFailedTest) -{ - YR::Libruntime::ErrorInfo err(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, - YR::Libruntime::ModuleCode::DATASYSTEM, "Flush failed."); - EXPECT_CALL(*(this->producer), Flush()).WillOnce(testing::Return(err)); - bool isThrow = false; - try { - this->streamProducer->Flush(); - } catch (YR::Exception &err) { - ASSERT_EQ(err.Code(), YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED); - EXPECT_THAT(err.Msg(), testing::HasSubstr("Flush failed.")); - isThrow = true; - } - EXPECT_TRUE(isThrow); -} - -TEST_F(StreamPubSubTest, FlushSuccessfullyTest) -{ - EXPECT_CALL(*(this->producer), Flush()).WillOnce(testing::Return(YR::Libruntime::ErrorInfo())); - EXPECT_NO_THROW(this->streamProducer->Flush()); -} - TEST_F(StreamPubSubTest, ProducerCloseFailedTest) { YR::Libruntime::ErrorInfo err(YR::Libruntime::ErrorCode::ERR_DATASYSTEM_FAILED, diff --git a/test/clibruntime/clibruntime_test.cpp b/test/clibruntime/clibruntime_test.cpp index ec3f0ca..8416799 100644 --- a/test/clibruntime/clibruntime_test.cpp +++ b/test/clibruntime/clibruntime_test.cpp @@ -20,8 +20,8 @@ #include #include "common/mock_libruntime.h" #define private public -#include "datasystem/kv_client.h" #include "api/go/libruntime/cpplibruntime/clibruntime.h" +#include "datasystem/kv_client.h" #include "src/libruntime/libruntime_manager.h" using namespace YR::utility; @@ -118,8 +118,8 @@ void freeCErrorIds(CErrorObject **errorIds, int size_errorIds) class CLibruntimeTest : public testing::Test { public: - CLibruntimeTest(){}; - ~CLibruntimeTest(){}; + CLibruntimeTest() {}; + ~CLibruntimeTest() {}; void SetUp() override { Mkdir("/tmp/log"); @@ -919,7 +919,6 @@ TEST_F(CLibruntimeTest, CProducerTest) ASSERT_EQ(cErr.code, 0); SafeFreeCErr(cErr); - cErr = CProducerFlush(producer); ASSERT_EQ(cErr.code, 0); SafeFreeCErr(cErr); diff --git a/test/libruntime/generator_test.cpp b/test/libruntime/generator_test.cpp index 993ca6d..9c35d87 100644 --- a/test/libruntime/generator_test.cpp +++ b/test/libruntime/generator_test.cpp @@ -271,8 +271,6 @@ TEST_F(GeneratorTest, GeneratorNotifierTest_NotifyResult) return ErrorInfo(); }); - EXPECT_CALL(*p, Flush()).Times(2).WillRepeatedly([]() { return ErrorInfo(); }); - { auto n = std::make_shared(streamStore, map); for (int i = 0; i < 2; i++) { @@ -320,8 +318,6 @@ TEST_F(GeneratorTest, GeneratorNotifierTest_NotifyError) return ErrorInfo(); }); - EXPECT_CALL(*p, Flush()).WillOnce([]() { return ErrorInfo(); }); - { auto n = std::make_shared(streamStore, map); GeneratorIdRecorder r(genId, rtId, map); @@ -362,8 +358,6 @@ TEST_F(GeneratorTest, GeneratorNotifierTest_NotifyFinished) return ErrorInfo(); }); - EXPECT_CALL(*p, Flush()).WillOnce([]() { return ErrorInfo(); }); - { auto n = std::make_shared(streamStore, map); GeneratorIdRecorder r(genId, rtId, map); diff --git a/test/libruntime/mock/mock_datasystem.h b/test/libruntime/mock/mock_datasystem.h index baaef48..6308a91 100644 --- a/test/libruntime/mock/mock_datasystem.h +++ b/test/libruntime/mock/mock_datasystem.h @@ -124,7 +124,7 @@ public: class MockHeretoStore : public HeteroStore { public: - MOCK_METHOD1(Init, ErrorInfo(datasystem::ConnectOptions & options)); + MOCK_METHOD1(Init, ErrorInfo(datasystem::ConnectOptions &options)); MOCK_METHOD0(Shutdown, void()); MOCK_METHOD2(DevDelete, ErrorInfo(const std::vector &objectIds, std::vector &failedObjectIds)); @@ -146,7 +146,6 @@ class MockStreamProducer : public StreamProducer { public: MOCK_METHOD1(Send, ErrorInfo(const Element &element)); MOCK_METHOD2(Send, ErrorInfo(const Element &element, int64_t timeoutMs)); - MOCK_METHOD0(Flush, ErrorInfo()); MOCK_METHOD0(Close, ErrorInfo()); }; diff --git a/test/libruntime/mock/mock_datasystem_client.cpp b/test/libruntime/mock/mock_datasystem_client.cpp index f3a4651..311d929 100644 --- a/test/libruntime/mock/mock_datasystem_client.cpp +++ b/test/libruntime/mock/mock_datasystem_client.cpp @@ -22,21 +22,15 @@ #define private public #include "datasystem/hetero_client.h" -#include "datasystem/object_client.h" #include "datasystem/kv_client.h" +#include "datasystem/object_client.h" #include "datasystem/stream_client.h" namespace datasystem { -class ThreadPool { -}; +class ThreadPool {}; -class StreamClientImpl { -}; -StreamClient::StreamClient(std::string ip, int port, const std::string &clientPublicKey, - const SensitiveValue &clientPrivateKey, const std::string &serverPublicKey, - const std::string &accessKey, const SensitiveValue &secretKey) -{ -} +class StreamClientImpl {}; +StreamClient::StreamClient(ConnectionOpts options) {} Status StreamClient::Init(bool reportWorkerLost) { @@ -86,11 +80,6 @@ Status Producer::Send(const Element &element, int64_t timeoutMs) return Status::OK(); } -Status Producer::Flush() -{ - return Status::OK(); -} - Status Producer::Close() { return Status::OK(); @@ -126,8 +115,7 @@ Status Consumer::Ack(uint64_t elementId) return Status::OK(); } -class ObjectClientImpl { -}; +class ObjectClientImpl {}; ObjectClient::ObjectClient(const ConnectOptions &connectOptions) {} @@ -265,10 +253,9 @@ Status Buffer::Publish(const std::unordered_set &nestedIds) return Status::OK(); } -class KVClientImpl { -}; +class KVClientImpl {}; -KVClient::KVClient(const ConnectOptions &connectOptions){}; +KVClient::KVClient(const ConnectOptions &connectOptions) {}; Status KVClient::Init() { @@ -309,7 +296,7 @@ Status KVClient::Get(const std::string &key, std::string &val, int32_t timeoutMs } Status KVClient::Get(const std::vector &keys, std::vector> &readOnlyBuffers, - int32_t timeoutMs) + int32_t timeoutMs) { // To test the if branch of partial get, // if a vector of len = 1, successfully get @@ -435,8 +422,7 @@ Buffer::Buffer(Buffer &&other) noexcept {} Buffer::~Buffer() {} -class HeteroClientImpl { -}; +class HeteroClientImpl {}; HeteroClient::HeteroClient(const ConnectOptions &connectOptions) {} @@ -453,7 +439,7 @@ Status HeteroClient::ShutDown() } Status HeteroClient::MGetH2D(const std::vector &objectIds, const std::vector &devBlobList, - std::vector &failList, int32_t timeoutMs) + std::vector &failList, int32_t timeoutMs) { return Status::OK(); } @@ -463,7 +449,8 @@ Status HeteroClient::DevDelete(const std::vector &objectIds, std::v return Status::OK(); } -Status HeteroClient::DevLocalDelete(const std::vector &objectIds, std::vector &failedObjectIds) +Status HeteroClient::DevLocalDelete(const std::vector &objectIds, + std::vector &failedObjectIds) { return Status::OK(); } diff --git a/test/libruntime/stream_store_test.cpp b/test/libruntime/stream_store_test.cpp index 7f1d00d..48f8f96 100644 --- a/test/libruntime/stream_store_test.cpp +++ b/test/libruntime/stream_store_test.cpp @@ -33,8 +33,8 @@ namespace YR { namespace test { class StreamStoreTest : public testing::Test { public: - StreamStoreTest(){}; - ~StreamStoreTest(){}; + StreamStoreTest() {}; + ~StreamStoreTest() {}; void SetUp() override { Mkdir("/tmp/log"); @@ -128,8 +128,6 @@ TEST_F(StreamStoreTest, TestProducer) ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); err = streamProducer->Send(element, 1000); ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); - err = streamProducer->Flush(); - ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); err = streamProducer->Close(); ASSERT_EQ(err.Code(), ErrorCode::ERR_OK); } -- Gitee From f90d688d319612ce4dcb7588d61726831e93dcd2 Mon Sep 17 00:00:00 2001 From: mayuehit Date: Wed, 19 Nov 2025 15:24:01 +0800 Subject: [PATCH 3/3] fix --- api/python/yr/fnruntime.pyx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/python/yr/fnruntime.pyx b/api/python/yr/fnruntime.pyx index d24b43c..30efb85 100644 --- a/api/python/yr/fnruntime.pyx +++ b/api/python/yr/fnruntime.pyx @@ -69,6 +69,8 @@ CErrorCode, CErrorInfo, CFunctionMeta, CInternalWaitResult, CInvokeArg, CInvokeOptions, CInvokeType, CModuleCode, CLanguageType, CLibruntimeConfig, CLibruntimeManager,move,CLibruntime, +CStreamProducer, CSubscriptionConfig, +CSubscriptionType, CExistenceOpt, CSetParam, CMSetParam, CCreateParam, CStackTraceInfo, CWriteMode, CCacheType, CConsistencyType, CGetParam, CGetParams, CMultipleReadResult, CDevice, CMultipleDelResult, CUInt64CounterData, CDoubleCounterData, NativeBuffer, StringNativeBuffer, CInstanceOptions, CGaugeData, CTensor, CDataType, CResourceUnit, CAlarmInfo, CAlarmSeverity, CFunctionGroupOptions, CBundleAffinity, CFunctionGroupRunningInfo, CFiberEvent, -- Gitee